mirror of
https://github.com/chenasraf/nvim-treesitter.git
synced 2026-05-18 01:39:00 +00:00
refactor(all): language tree adaption (#1105)
This commit is contained in:
@@ -333,12 +333,16 @@ Swaps the nodes or ranges.
|
||||
set `cursor_to_second` to true to move the cursor to the second node
|
||||
|
||||
*ts_utils.memoize_by_buf_tick*
|
||||
memoize_by_buf_tick(fn)~
|
||||
memoize_by_buf_tick(fn, options)~
|
||||
|
||||
Cache values by bufnr tick change
|
||||
Caches the return value for a function and returns the cache value if the tick
|
||||
of the buffer has not changed from the previous.
|
||||
|
||||
`fn`: a function that takes a bufnr as argument
|
||||
`fn`: a function that takes any arguments
|
||||
and returns a value to store.
|
||||
`options?`: <table>
|
||||
- `bufnr`: a function/value that extracts the bufnr from the given arguments.
|
||||
- `key`: a function/value that extracts the cache key from the given arguments.
|
||||
`returns`: a function to call with bufnr as argument to
|
||||
retrieve the value from the cache
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
local api = vim.api
|
||||
local utils = require'nvim-treesitter.ts_utils'
|
||||
local tsutils = require'nvim-treesitter.ts_utils'
|
||||
local query = require'nvim-treesitter.query'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
|
||||
@@ -7,18 +7,19 @@ local M = {}
|
||||
|
||||
-- This is cached on buf tick to avoid computing that multiple times
|
||||
-- Especially not for every line in the file when `zx` is hit
|
||||
local folds_levels = utils.memoize_by_buf_tick(function(bufnr)
|
||||
local lang = parsers.get_buf_lang(bufnr)
|
||||
local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr)
|
||||
local max_fold_level = api.nvim_win_get_option(0, 'foldnestmax')
|
||||
local parser = parsers.get_parser(bufnr)
|
||||
|
||||
local matches
|
||||
if query.has_folds(lang) then
|
||||
matches = query.get_capture_matches(bufnr, "@fold", "folds")
|
||||
elseif query.has_locals(lang) then
|
||||
matches = query.get_capture_matches(bufnr, "@scope", "locals")
|
||||
else
|
||||
return {}
|
||||
end
|
||||
if not parser then return {} end
|
||||
|
||||
local matches = query.get_capture_matches_recursively(bufnr, function(lang)
|
||||
if query.has_folds(lang) then
|
||||
return "@fold", "folds"
|
||||
elseif query.has_locals(lang) then
|
||||
return "@scope", "locals"
|
||||
end
|
||||
end)
|
||||
|
||||
local levels_tmp = {}
|
||||
|
||||
@@ -35,7 +36,6 @@ local folds_levels = utils.memoize_by_buf_tick(function(bufnr)
|
||||
levels_tmp[start] = (levels_tmp[start] or 0) + 1
|
||||
levels_tmp[stop] = (levels_tmp[stop] or 0) - 1
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
local levels = {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
local queries = require'nvim-treesitter.query'
|
||||
local utils = require'nvim-treesitter.ts_utils'
|
||||
local tsutils = require'nvim-treesitter.ts_utils'
|
||||
|
||||
local M = {}
|
||||
|
||||
@@ -21,9 +21,9 @@ local function node_fmt(node)
|
||||
return tostring(node)
|
||||
end
|
||||
|
||||
local get_indents = utils.memoize_by_buf_tick(function(bufnr)
|
||||
local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang)
|
||||
local get_map = function(capture)
|
||||
local matches = queries.get_capture_matches(bufnr, capture, 'indents') or {}
|
||||
local matches = queries.get_capture_matches(bufnr, capture, 'indents', root, lang) or {}
|
||||
local map = {}
|
||||
for _, node in ipairs(matches) do
|
||||
map[tostring(node)] = true
|
||||
@@ -37,14 +37,23 @@ local get_indents = utils.memoize_by_buf_tick(function(bufnr)
|
||||
returns = get_map('@return.node'),
|
||||
ignores = get_map('@ignore.node'),
|
||||
}
|
||||
end)
|
||||
end, {
|
||||
-- Memoize by bufnr and lang together.
|
||||
key = function(bufnr, _, lang)
|
||||
return tostring(bufnr) .. '_' .. lang
|
||||
end
|
||||
})
|
||||
|
||||
function M.get_indent(lnum)
|
||||
local parser = parsers.get_parser()
|
||||
if not parser or not lnum then return -1 end
|
||||
|
||||
local q = get_indents(vim.api.nvim_get_current_buf())
|
||||
local root = parser:parse()[1]:root()
|
||||
local root, _, lang_tree = tsutils.get_root_for_position(lnum, 0, parser)
|
||||
|
||||
-- Not likely, but just in case...
|
||||
if not root then return 0 end
|
||||
|
||||
local q = get_indents(vim.api.nvim_get_current_buf(), root, lang_tree:lang())
|
||||
local node = get_node_at_line(root, lnum-1)
|
||||
|
||||
local indent = 0
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
-- its the way nvim-treesitter uses to "understand" the code
|
||||
|
||||
local queries = require'nvim-treesitter.query'
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
local ts_utils = require'nvim-treesitter.ts_utils'
|
||||
local api = vim.api
|
||||
|
||||
@@ -91,13 +90,12 @@ end
|
||||
--- Iterates over a nodes scopes moving from the bottom up
|
||||
function M.iter_scope_tree(node, bufnr)
|
||||
local last_node = node
|
||||
|
||||
return function()
|
||||
if not last_node then
|
||||
return
|
||||
end
|
||||
|
||||
local scope = M.containing_scope(last_node, bufnr, false) or parsers.get_tree_root(bufnr)
|
||||
local scope = M.containing_scope(last_node, bufnr, false) or ts_utils.get_root_for_node(node)
|
||||
|
||||
last_node = scope:parent()
|
||||
|
||||
@@ -222,7 +220,7 @@ function M.find_definition(node, bufnr)
|
||||
end
|
||||
end
|
||||
|
||||
return node, parsers.get_tree_root(bufnr), nil
|
||||
return node, ts_utils.get_root_for_node(node), nil
|
||||
end
|
||||
|
||||
-- Finds usages of a node in a given scope.
|
||||
@@ -235,7 +233,7 @@ function M.find_usages(node, scope_node, bufnr)
|
||||
|
||||
if not node_text or #node_text < 1 then return {} end
|
||||
|
||||
local scope_node = scope_node or parsers.get_parser(bufnr):parse()[1]:root()
|
||||
local scope_node = scope_node or ts_utils.get_root_for_node(node)
|
||||
local usages = {}
|
||||
|
||||
for match in M.iter_locals(bufnr, scope_node) do
|
||||
|
||||
@@ -584,6 +584,8 @@ function M.get_parser(bufnr, lang)
|
||||
end
|
||||
end
|
||||
|
||||
-- @deprecated This is only kept for legacy purposes.
|
||||
-- All root nodes should be accounted for.
|
||||
function M.get_tree_root(bufnr)
|
||||
local bufnr = bufnr or api.nvim_get_current_buf()
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ local caching = require'nvim-treesitter.caching'
|
||||
|
||||
local M = {}
|
||||
|
||||
local EMPTY_ITER = function() end
|
||||
|
||||
M.built_in_query_groups = {'highlights', 'locals', 'folds', 'indents'}
|
||||
|
||||
-- Creates a function that checks whether a given query exists
|
||||
@@ -166,7 +168,7 @@ end
|
||||
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
|
||||
-- Works like M.get_references or M.get_scopes except you can choose the capture
|
||||
-- Can also be a nested capture like @definition.function to get all nodes defining a function
|
||||
function M.get_capture_matches(bufnr, capture_string, query_group)
|
||||
function M.get_capture_matches(bufnr, capture_string, query_group, root, lang)
|
||||
if not string.sub(capture_string, 1,2) == '@' then
|
||||
print('capture_string must start with "@"')
|
||||
return
|
||||
@@ -176,7 +178,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
|
||||
capture_string = string.sub(capture_string, 2)
|
||||
|
||||
local matches = {}
|
||||
for match in M.iter_group_results(bufnr, query_group) do
|
||||
for match in M.iter_group_results(bufnr, query_group, root, lang) do
|
||||
local insert = utils.get_at_path(match, capture_string)
|
||||
|
||||
if insert then
|
||||
@@ -186,7 +188,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group)
|
||||
return matches
|
||||
end
|
||||
|
||||
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
|
||||
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
|
||||
if not string.sub(capture_string, 1,2) == '@' then
|
||||
api.nvim_err_writeln('capture_string must start with "@"')
|
||||
return
|
||||
@@ -198,7 +200,7 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate,
|
||||
local best
|
||||
local best_score
|
||||
|
||||
for maybe_match in M.iter_group_results(bufnr, query_group) do
|
||||
for maybe_match in M.iter_group_results(bufnr, query_group, root) do
|
||||
local match = utils.get_at_path(maybe_match, capture_string)
|
||||
|
||||
if match and filter_predicate(match) then
|
||||
@@ -220,31 +222,82 @@ end
|
||||
-- @param bufnr the buffer
|
||||
-- @param query_group the query file to use
|
||||
-- @param root the root node
|
||||
function M.iter_group_results(bufnr, query_group, root)
|
||||
local lang = parsers.get_buf_lang(bufnr)
|
||||
if not lang then return function() end end
|
||||
-- @param root the root node lang, if known
|
||||
function M.iter_group_results(bufnr, query_group, root, root_lang)
|
||||
local buf_lang = parsers.get_buf_lang(bufnr)
|
||||
|
||||
local query = M.get_query(lang, query_group)
|
||||
if not query then return function() end end
|
||||
if not buf_lang then return EMPTY_ITER end
|
||||
|
||||
local parser = parsers.get_parser(bufnr, lang)
|
||||
if not parser then return function() end end
|
||||
local parser = parsers.get_parser(bufnr, buf_lang)
|
||||
if not parser then return EMPTY_ITER end
|
||||
|
||||
local root = root or parser:parse()[1]:root()
|
||||
local start_row, _, end_row, _ = root:range()
|
||||
if not root then
|
||||
local first_tree = parser:trees()[1]
|
||||
|
||||
if first_tree then
|
||||
root = first_tree:root()
|
||||
end
|
||||
end
|
||||
|
||||
if not root then return EMPTY_ITER end
|
||||
|
||||
local range = {root:range()}
|
||||
|
||||
if not root_lang then
|
||||
local lang_tree = parser:language_for_range(range)
|
||||
|
||||
if lang_tree then
|
||||
root_lang = lang_tree:lang()
|
||||
end
|
||||
end
|
||||
|
||||
if not root_lang then return EMPTY_ITER end
|
||||
|
||||
local query = M.get_query(root_lang, query_group)
|
||||
if not query then return EMPTY_ITER end
|
||||
|
||||
-- The end row is exclusive so we need to add 1 to it.
|
||||
return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1)
|
||||
return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1)
|
||||
end
|
||||
|
||||
function M.collect_group_results(bufnr, query_group, root)
|
||||
function M.collect_group_results(bufnr, query_group, root, lang)
|
||||
local matches = {}
|
||||
|
||||
for prepared_match in M.iter_group_results(bufnr, query_group, root) do
|
||||
for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do
|
||||
table.insert(matches, prepared_match)
|
||||
end
|
||||
|
||||
return matches
|
||||
end
|
||||
|
||||
--- Same as get_capture_matches except this will recursively get matches for every language in the tree.
|
||||
-- @param bufnr The bufnr
|
||||
-- @param capture_or_fn The capture to get. If a function is provided then that
|
||||
-- function will be used to resolve both the capture and query argument.
|
||||
-- The function can return `nil` to ignore that tree.
|
||||
-- @param query_type The query to get the capture from. This is ignore if a function is provided
|
||||
-- for the captuer argument.
|
||||
function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type)
|
||||
local type_fn = type(capture_or_fn) == 'function'
|
||||
and capture_or_fn
|
||||
or function()
|
||||
return capture_or_fn, query_type
|
||||
end
|
||||
local parser = parsers.get_parser(bufnr)
|
||||
local matches = {}
|
||||
|
||||
if parser then
|
||||
parser:for_each_tree(function(tree, lang_tree)
|
||||
local lang = lang_tree:lang()
|
||||
local capture, type_ = type_fn(lang, tree, lang_tree)
|
||||
|
||||
if capture then
|
||||
vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang))
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
return matches
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -114,10 +114,46 @@ function M.get_named_children(node)
|
||||
end
|
||||
|
||||
function M.get_node_at_cursor(winnr)
|
||||
if not parsers.has_parser() then return end
|
||||
local cursor = api.nvim_win_get_cursor(winnr or 0)
|
||||
local root = parsers.get_parser():parse()[1]:root()
|
||||
return root:named_descendant_for_range(cursor[1]-1,cursor[2],cursor[1]-1,cursor[2])
|
||||
local cursor_range = { cursor[1] - 1, cursor[2] }
|
||||
local root = M.get_root_for_position(unpack(cursor_range))
|
||||
|
||||
if not root then return end
|
||||
|
||||
return root:named_descendant_for_range(cursor_range[1], cursor_range[2], cursor_range[1], cursor_range[2])
|
||||
end
|
||||
|
||||
function M.get_root_for_position(line, col, root_lang_tree)
|
||||
if not root_lang_tree then
|
||||
if not parsers.has_parser() then return end
|
||||
|
||||
root_lang_tree = parsers.get_parser()
|
||||
end
|
||||
|
||||
local lang_tree = root_lang_tree:language_for_range({ line, col, line, col })
|
||||
|
||||
for _, tree in ipairs(lang_tree:trees()) do
|
||||
local root = tree:root()
|
||||
|
||||
if root and M.is_in_node_range(root, line, col) then
|
||||
return root, tree, lang_tree
|
||||
end
|
||||
end
|
||||
|
||||
-- This isn't a likely scenario, since the position must belong to a tree somewhere.
|
||||
return nil, nil, lang_tree
|
||||
end
|
||||
|
||||
function M.get_root_for_node(node)
|
||||
local parent = node
|
||||
local result = node
|
||||
|
||||
while parent ~= nil do
|
||||
result = parent
|
||||
parent = result:parent()
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function M.highlight_node(node, buf, hl_namespace, hl_group)
|
||||
@@ -213,25 +249,44 @@ end
|
||||
--- Memoizes a function based on the buffer tick of the provided bufnr.
|
||||
-- The cache entry is cleared when the buffer is detached to avoid memory leaks.
|
||||
-- @param fn: the fn to memoize, taking the bufnr as first argument
|
||||
-- @param options:
|
||||
-- - bufnr: extracts a bufnr from the given arguments.
|
||||
-- - key: extracts the cache key from the given arguments.
|
||||
-- @returns a memoized function
|
||||
function M.memoize_by_buf_tick(fn)
|
||||
local cache = {}
|
||||
function M.memoize_by_buf_tick(fn, options)
|
||||
options = options or {}
|
||||
|
||||
return function(bufnr)
|
||||
if cache[bufnr] then
|
||||
return cache[bufnr]
|
||||
local cache = {}
|
||||
local bufnr_fn = utils.to_func(options.bufnr or utils.identity)
|
||||
local key_fn = utils.to_func(options.key or utils.identity)
|
||||
|
||||
return function(...)
|
||||
local bufnr = bufnr_fn(...)
|
||||
local key = key_fn(...)
|
||||
local tick = api.nvim_buf_get_changedtick(bufnr)
|
||||
|
||||
if cache[key] then
|
||||
if cache[key].last_tick == tick then
|
||||
return cache[key].result
|
||||
end
|
||||
else
|
||||
cache[bufnr] = {}
|
||||
api.nvim_buf_attach(bufnr, false,
|
||||
{
|
||||
on_changedtick = function() cache[bufnr] = fn(bufnr) end,
|
||||
on_detach = function() cache[bufnr] = nil end
|
||||
}
|
||||
)
|
||||
local function detach_handler()
|
||||
cache[key] = nil
|
||||
end
|
||||
|
||||
-- Clean up logic only!
|
||||
api.nvim_buf_attach(bufnr, false, {
|
||||
on_detach = detach_handler,
|
||||
on_reload = detach_handler
|
||||
})
|
||||
end
|
||||
|
||||
cache[bufnr] = fn(bufnr)
|
||||
return cache[bufnr]
|
||||
cache[key] = {
|
||||
result = fn(...),
|
||||
last_tick = tick
|
||||
}
|
||||
|
||||
return cache[key].result
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ TSRange.__index = TSRange
|
||||
|
||||
local api = vim.api
|
||||
local parsers = require'nvim-treesitter.parsers'
|
||||
local ts_utils = require'nvim-treesitter.ts_utils'
|
||||
|
||||
local function get_byte_offset(buf, row, col)
|
||||
return api.nvim_buf_get_offset(buf, row)
|
||||
@@ -57,8 +58,11 @@ end
|
||||
|
||||
function TSRange:parent(range)
|
||||
local parser = parsers.get_parser(self.buf, parsers.get_buf_lang(range))
|
||||
local root = parser:parse()[1]:root()
|
||||
return root:named_descendant_for_range(self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2])
|
||||
local root = ts_utils.get_root_for_position(range[1], range[2], parser)
|
||||
|
||||
return root
|
||||
and root:named_descendant_for_range(self.start_pos[1], self.start_pos[2], self.end_pos[1], self.end_pos[2])
|
||||
or nil
|
||||
end
|
||||
|
||||
function TSRange:field()
|
||||
|
||||
@@ -166,4 +166,16 @@ function M.difference(tbl1, tbl2)
|
||||
end)
|
||||
end
|
||||
|
||||
function M.identity(a)
|
||||
return a
|
||||
end
|
||||
|
||||
function M.constant(a)
|
||||
return function() return a end
|
||||
end
|
||||
|
||||
function M.to_func(a)
|
||||
return type(a) == 'function' and a or M.constant(a)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
Reference in New Issue
Block a user