From db9fe5d95dfcf292268f24d56e5e05db20b902e5 Mon Sep 17 00:00:00 2001 From: TheLeoP Date: Thu, 29 May 2025 13:31:50 -0500 Subject: [PATCH] refactor: mostly remove nvim-treesitter dependency --- .gitignore | 1 + after/queries/c/matchup.scm | 3 +- after/queries/ecma/matchup.scm | 3 +- after/queries/elm/matchup.scm | 5 +- after/queries/go/matchup.scm | 3 +- after/queries/rust/matchup.scm | 3 +- after/queries/zig/matchup.scm | 3 +- autoload/matchup/loader.vim | 15 - autoload/matchup/ts_engine.vim | 7 - lua/match-up.lua | 6 +- lua/treesitter-matchup.lua | 9 +- lua/treesitter-matchup/internal.lua | 402 ++++++++++++------ lua/treesitter-matchup/syntax.lua | 56 +-- lua/treesitter-matchup/third-party/query.lua | 394 ----------------- .../third-party/ts-utils.lua | 28 ++ 15 files changed, 330 insertions(+), 608 deletions(-) delete mode 100644 lua/treesitter-matchup/third-party/query.lua create mode 100644 lua/treesitter-matchup/third-party/ts-utils.lua diff --git a/.gitignore b/.gitignore index 926ccaa..12604df 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ doc/tags +.nvim.lua diff --git a/after/queries/c/matchup.scm b/after/queries/c/matchup.scm index 48caa18..74668cd 100644 --- a/after/queries/c/matchup.scm +++ b/after/queries/c/matchup.scm @@ -19,8 +19,7 @@ ; 'else' and 'else if' (else_clause - "else" @_start (if_statement "if" @_end)? - (#make-range! "mid.if.1" @_start @_end)) + "else" @mid.if.1 (if_statement "if" @mid.if.1)?) ; if ((if_statement diff --git a/after/queries/ecma/matchup.scm b/after/queries/ecma/matchup.scm index 9b9f1da..3259186 100644 --- a/after/queries/ecma/matchup.scm +++ b/after/queries/ecma/matchup.scm @@ -17,8 +17,7 @@ ; 'else' and 'else if' (else_clause - "else" @_start (if_statement "if" @_end)? - (#make-range! "mid.if.1" @_start @_end)) + "else" @mid.if.1 (if_statement "if" @mid.if.1)?) ; if ((if_statement diff --git a/after/queries/elm/matchup.scm b/after/queries/elm/matchup.scm index 0c98830..1411f9e 100644 --- a/after/queries/elm/matchup.scm +++ b/after/queries/elm/matchup.scm @@ -2,9 +2,8 @@ . "if" @open.if) @scope.if (if_else_expr - "else" @_else - "if"? @_if - (#make-range! "mid.if.1" @_else @_if)) + "else" @mid.if.1 + "if"? @mid.if.1) (let_in_expr "let" @open.let diff --git a/after/queries/go/matchup.scm b/after/queries/go/matchup.scm index ed8415f..9058c18 100644 --- a/after/queries/go/matchup.scm +++ b/after/queries/go/matchup.scm @@ -6,8 +6,7 @@ ; 'else' and 'else if' (if_statement - "else" @_start (if_statement "if" @_end)? - (#make-range! "mid.if.1" @_start @_end)) + "else" @mid.if.1 (if_statement "if" @mid.if.1)?) ; if (block (if_statement "if" @open.if) @scope.if) diff --git a/after/queries/rust/matchup.scm b/after/queries/rust/matchup.scm index ebd01b4..a77ad56 100644 --- a/after/queries/rust/matchup.scm +++ b/after/queries/rust/matchup.scm @@ -16,8 +16,7 @@ (else_clause "else" @mid.if_.1 (block)) (else_clause - "else" @_start (if_expression "if" @_end) - (#make-range! "mid.if_.2" @_start @_end)) + "else" @mid.if_.2 (if_expression "if" @mid.if_.2)) ; --------------- async/await --------------- (function_item (function_modifiers "async" @open.async)) @scope.async diff --git a/after/queries/zig/matchup.scm b/after/queries/zig/matchup.scm index 5ec4bfb..5e1edd7 100644 --- a/after/queries/zig/matchup.scm +++ b/after/queries/zig/matchup.scm @@ -7,8 +7,7 @@ ; 'else' and 'else if' (else_clause - "else" @_start (if_statement "if" @_end)? - (#make-range! "mid.if.1" @_start @_end)) + "else" @mid.if.1 (if_statement "if" @mid.if.1)?) ; if ((if_statement diff --git a/autoload/matchup/loader.vim b/autoload/matchup/loader.vim index e78243b..61f99c7 100644 --- a/autoload/matchup/loader.vim +++ b/autoload/matchup/loader.vim @@ -33,21 +33,6 @@ function! matchup#loader#init_buffer() abort " {{{1 endif let l:has_ts_hl = 0 - if s:ts_may_be_supported && matchup#ts_engine#is_hl_enabled(bufnr('%')) - let l:has_ts_hl = 1 - - if matchup#ts_engine#get_option( - \ bufnr('%'), 'additional_vim_regex_highlighting') - if empty(&syntax) - set syntax=ON - else - augroup matchup_syntax - au! - autocmd VimEnter * if empty(&syntax) | set syntax=ON | endif - augroup END - endif - endif - endif " initialize lists of delimiter pairs and regular expressions " this is the data obtained from parsing b:match_words diff --git a/autoload/matchup/ts_engine.vim b/autoload/matchup/ts_engine.vim index 96f788f..4bea3b6 100644 --- a/autoload/matchup/ts_engine.vim +++ b/autoload/matchup/ts_engine.vim @@ -21,13 +21,6 @@ function! matchup#ts_engine#is_enabled(bufnr) abort return +s:forward('is_enabled', a:bufnr) endfunction -function! matchup#ts_engine#is_hl_enabled(bufnr) abort - if !has('nvim-0.5.0') - return 0 - endif - return +s:forward('is_hl_enabled', a:bufnr) -endfunction - function! matchup#ts_engine#get_option(bufnr, opt_name) abort return s:forward('get_option', a:bufnr, a:opt_name) endfunction diff --git a/lua/match-up.lua b/lua/match-up.lua index 4378b04..e394990 100644 --- a/lua/match-up.lua +++ b/lua/match-up.lua @@ -66,11 +66,15 @@ local M = {} ---@class matchup.TransmuteConfig ---@field enabled 0|1 +-- TODO: remove vim syntax related g: vars +-- TODO: modify vimscript to work without nvim-treesitter -- TODO: add documentation for g: vars --- TODO: add defautls for this g: vars? ---@class matchup.TreesitterConfig ---@field enabled boolean ---@field disabled string[] +---@field include_match_words boolean +---@field disable_virtual_text boolean +---@field enable_quotes boolean ---@class matchup.Config ---@field delim matchup.DelimConfig diff --git a/lua/treesitter-matchup.lua b/lua/treesitter-matchup.lua index cec414f..27dfcb2 100644 --- a/lua/treesitter-matchup.lua +++ b/lua/treesitter-matchup.lua @@ -1,9 +1,4 @@ -if not pcall(require, 'nvim-treesitter') then - return {init = function() end} -end - -local treesitter = require 'nvim-treesitter' -local queries = require 'nvim-treesitter.query' +-- TODO: remove module structure local M = {} @@ -12,7 +7,7 @@ function M.init() matchup = { module_path = 'treesitter-matchup.internal', is_supported = function(lang) - return queries.has_query_files(lang, 'matchup') + return vim.treesitter.query.get(lang, 'matchup') ~= nil end } } diff --git a/lua/treesitter-matchup/internal.lua b/lua/treesitter-matchup/internal.lua index 1e12e56..6305ba5 100644 --- a/lua/treesitter-matchup/internal.lua +++ b/lua/treesitter-matchup/internal.lua @@ -1,18 +1,16 @@ -if not pcall(require, 'nvim-treesitter') then - return {is_enabled = function(bufnr) return 0 end, - is_hl_enabled = function(bufnr) return 0 end} -end - local vim = vim local api = vim.api -local ts = require'treesitter-matchup.compat' -local configs = require'nvim-treesitter.configs' -local parsers = require'nvim-treesitter.parsers' -local queries = require'treesitter-matchup.third-party.query' -local ts_utils = require'nvim-treesitter.ts_utils' +local ts = vim.treesitter +local memoize = require'treesitter-matchup.third-party.ts-utils'.memoize + +vim.g.matchup_treesitter_enabled = false +vim.g.matchup_treesitter_disabled = {} +vim.g.matchup_treesitter_include_match_words = false +vim.g.matchup_treesitter_enable_quotes = true + +-- TODO: update this dependencies local lru = require'treesitter-matchup.third-party.lru' local util = require'treesitter-matchup.util' -local utils2 = require'treesitter-matchup.third-party.utils' local unpack = unpack or table.unpack @@ -20,69 +18,179 @@ local M = {} local cache = lru.new(150) + +---@param lang string +---@param bufnr integer +local function is_enabled(lang, bufnr) + local enabled = vim.g.matchup_treesitter_enabled == 1 + local buf_enabled = vim.b[bufnr].matchup_treesitter_enabled == 1 + local lang_disabled = vim.list_contains(vim.g.matchup_treesitter_disabled, lang) + + return enabled and buf_enabled and not lang_disabled +end +-- TODO: this is following the old module structure of nvim-treesitter. Change it + +---@param bufnr integer? +---@return boolean function M.is_enabled(bufnr) bufnr = bufnr or api.nvim_get_current_buf() - local lang = parsers.get_buf_lang(bufnr) - return configs.is_enabled('matchup', lang, bufnr) + local lang = ts.language.get_lang(vim.bo[bufnr].filetype) + if not lang then + return false + end + assert(lang) + return is_enabled(lang, bufnr) end -function M.is_hl_enabled(bufnr) - bufnr = bufnr or api.nvim_get_current_buf() - local lang = parsers.get_buf_lang(bufnr) - return configs.is_enabled('highlight', lang, bufnr) +-- TODO: I had to remove the `is_hl_enabled` function and the related logic. On +-- the `main` branch of nvim-treesitter it's not possible to tell wether or not +-- the hl is enabled for a given buffer and there is no +-- `additional_vim_regex_highlighting` option anymore. Now, users will have to +-- enable syntax themselves after doing `vim.treesitter.start()`. Mention this +-- as a possible workaround and possible regression in the PR. +-- +-- Technically, the undocumented `vim.treesitter.highlighter` table can be +-- accessed. But, should we rely in undocumented features? + +---@param bufnr integer +---@param root TSNode +---@param lang string +---@return string +local function buf_root_lang_hash(bufnr, root, lang) + return tostring(bufnr) .. root:id() .. '_' .. lang end -M.get_matches = ts_utils.memoize_by_buf_tick(function(bufnr) - local parser = parsers.get_parser(bufnr) - local matches = {} +---@class matchup.treesitter.MatchInfo +---@field range Range4 +---@field length integer +---@field last_node TSNode +---@field text string + +---@class matchup.treesitter.MatchInfoWrapper +---@field info matchup.treesitter.MatchInfo + +---@class matchup.treesitter.Match +---@field scope? table +---@field open? table +---@field mid? table> +---@field close? table +---@field skip? matchup.treesitter.MatchInfoWrapper + +---@param bufnr integer +---@param root TSNode +---@param lang string +---@return matchup.treesitter.Match[] +local get_memoized_matches = memoize(function(bufnr, root, lang) + local query_name = 'matchup' + local query = ts.query.get(lang, query_name) + + if not query then + return {} + end + + local out = {} ---@type matchup.treesitter.Match[] + for _, match, metadata in query:iter_matches(root, bufnr) do + local match_info = {} + for id, nodes in pairs(match) do + local first = nodes[1] + local last = nodes[#nodes] + + ---@type integer, integer, integer + local start_row, start_col , start_byte = unpack(ts.get_range(first, bufnr, metadata)) + ---@type integer, integer, integer, integer, integer, integer + local _, _, _, end_row, end_col , end_byte = unpack(ts.get_range(last, bufnr, metadata)) + local range = { start_row, start_col, end_row, end_col } + local length = end_byte - start_byte + + if end_col == 0 then + if start_row == end_row then + start_col = -1 + start_row = start_row - 1 + end + end_col = -1 + end_row = end_row - 1 + end + local lines = api.nvim_buf_get_text(bufnr, start_row, end_row, start_col, end_col, {}) + local text = table.concat(lines, '\n') + + local name = query.captures[id] + local path = vim.split(name, '.', { plain = true }) + + local current = match_info ---@type table> + for _, segment in ipairs(path) do + current[segment] = current[segment] or {} + current = current[segment] + end + current.info = { + range = range, + length = length, + last_node = last, + text = text, + } + end + table.insert(out, match_info) + end + + return out +end, buf_root_lang_hash) + +---@param bufnr integer +---@return matchup.treesitter.Match[] +M.get_matches = function(bufnr) + local parser = ts.get_parser(bufnr) + local matches = {} ---@type matchup.treesitter.Match[] if parser then + -- TODO: g:matchup_delim_stopline could be used, but this functions needs to + -- know on which window it should look for in order to get the current + -- cursor position of that window + parser:parse(nil) parser:for_each_tree(function(tree, lang_tree) if not tree or lang_tree:lang() == 'comment' then return end local lang = lang_tree:lang() - local group_results = queries.collect_group_results( - bufnr, 'matchup', tree:root(), lang) or {} + local group_results = get_memoized_matches(bufnr, tree:root(), lang) vim.list_extend(matches, group_results) end) end return matches -end) +end local function _time() - local s, u = vim.loop.gettimeofday() + local s, u = vim.uv.gettimeofday() return s * 1000 + u * 1e-3 end ---- Returns a (mostly) unique id for this node --- Also supports nvim-treesitter's range object -local function _node_id(node) - if not node then - return nil - end - if node:type() == 'nvim-treesitter-range' then - return string.format('range_%d_%d_%d_%d', node:range()) - end - return node:id() +--- Returns a (mostly) unique id for this range +---@param range Range4 +---@return string +function M.range_id(range) + return ('range_%d_%d_%d_%d'):format(unpack(range)) end +-- TODO: mention this in the PR. this is not memoized because: +-- - get_matches is already memoized +-- - this function does not have access to the treesitter root and memoizing by +-- buf_tick is unreliable (buf_tick may be out-of-sync with treesitter changes +-- because of undo, for example) +-- --- Get all nodes belonging to defined scopes (organized by key) -M.get_scopes = ts_utils.memoize_by_buf_tick(function(bufnr) +---@param bufnr integer +---@return table> +M.get_scopes = function(bufnr) local matches = M.get_matches(bufnr) - local scopes = {} + local scopes = {} ---@type table> for _, match in ipairs(matches) do if match.scope then for key, scope in pairs(match.scope) do - local id = _node_id(scope.node) - if scope.node then - if not scopes[key] then - scopes[key] = {} - end + if scope.info then + local id = M.range_id(scope.info.range) + scopes[key] = scopes[key] or {} scopes[key][id] = true end end @@ -90,49 +198,52 @@ M.get_scopes = ts_utils.memoize_by_buf_tick(function(bufnr) end return scopes -end) +end -M.get_active_nodes = ts_utils.memoize_by_buf_tick(function(bufnr) - -- TODO: why do we need to force a parse? - if not pcall(function() parsers.get_parser():parse() end) then - -- TODO workaround a crash due to tree-sitter parsing - return {{ open={}, mid={}, close={} }, {}} - end +---@class matchup.treesitter.Matches +---@field open matchup.treesitter.MatchInfo[] +---@field mid matchup.treesitter.MatchInfo[] +---@field close matchup.treesitter.MatchInfo[] +---@param bufnr integer +---@return [matchup.treesitter.Matches, table] +M.get_active_matches = function(bufnr) local matches = M.get_matches(bufnr) - local nodes = { open = {}, mid = {}, close = {} } + ---@type matchup.treesitter.Matches + local info = { open = {}, mid = {}, close = {} } + ---@type table local symbols = {} + local enable_quotes = vim.g.matchup_treesitter_enable_quotes for _, match in ipairs(matches) do if match.open then for key, open in pairs(match.open) do - local reject = key:find('quote') - and not M.get_option(bufnr, 'enable_quotes') - local id = _node_id(open.node) - if not reject and open.node and symbols[id] == nil then - table.insert(nodes.open, open.node) + local reject = key:find('quote') and not enable_quotes + local id = M.range_id(open.info.range) + if not reject and open.info and symbols[id] == nil then + table.insert(info.open, open.info) symbols[id] = key end end end if match.close then for key, close in pairs(match.close) do - local reject = key:find('quote') - and not M.get_option(bufnr, 'enable_quotes') - local id = _node_id(close.node) - if not reject and close.node and symbols[id] == nil then - table.insert(nodes.close, close.node) + local reject = key:find('quote') and not enable_quotes + local id = M.range_id(close.info.range) + if not reject and close.info and symbols[id] == nil then + table.insert(info.close, close.info) symbols[id] = key end end end if match.mid then for key, mid_group in pairs(match.mid) do + -- TODO: mid type is wrong, fix everywhere for _, mid in pairs(mid_group) do - local id = _node_id(mid.node) - if mid.node and symbols[id] == nil then - table.insert(nodes.mid, mid.node) + local id = M.range_id(mid.info.range) + if mid.info and symbols[id] == nil then + table.insert(info.mid, mid.info) symbols[id] = key end end @@ -140,19 +251,25 @@ M.get_active_nodes = ts_utils.memoize_by_buf_tick(function(bufnr) end end - return {nodes, symbols} -end) + return {info, symbols} +end -function M.containing_scope(node, bufnr, key) +---@param info matchup.treesitter.MatchInfo? +---@param bufnr integer? +---@param key string +---@return TSNode|nil +function M.containing_scope(info, bufnr, key) bufnr = bufnr or api.nvim_get_current_buf() local scopes = M.get_scopes(bufnr) - if not node or not scopes or not scopes[key] then return end + if not info or not scopes or not scopes[key] then return end - local iter_node = node + ---@type TSNode|nil + local iter_node = info.last_node while iter_node ~= nil do - if scopes[key][_node_id(iter_node)] then + ---@diagnostic disable-next-line: missing-fields LuaLS bug + if scopes[key][M.range_id({iter_node:range()})] then return iter_node end iter_node = iter_node:parent() @@ -161,27 +278,35 @@ function M.containing_scope(node, bufnr, key) return nil end -local function _node_text(node, bufnr) - local text = ts.get_node_text(node, bufnr) +---@param info matchup.treesitter.MatchInfo +---@return string +local function text_until_newline(info) + local text = info.text return text:match("([^\n]+).*") end --- Fill in a match result based on a seed node -function M.do_node_result(initial_node, bufnr, opts, side, key) +---@param info matchup.treesitter.MatchInfo +---@param bufnr integer +---@param opts table +---@param side matchup.Side? +---@param key string? +function M.do_match_result(info, bufnr, opts, side, key) if not side or not key then return nil end - local scope = M.containing_scope(initial_node, bufnr, key) + local scope = M.containing_scope(info, bufnr, key) if not scope then return nil end - local row, col, _ = initial_node:start() + ---@type integer, integer + local row, col = unpack(info.range) local result = { type = 'delim_py', - match = _node_text(initial_node, bufnr), + match = text_until_newline(info), side = side, lnum = row + 1, cnum = col + 1, @@ -191,9 +316,9 @@ function M.do_node_result(initial_node, bufnr, opts, side, key) _id = util.uuid4(), } - local info = { + local cached_info = { bufnr = bufnr, - initial_node = initial_node, + info = info, row = row, col = col, key = key, @@ -201,11 +326,36 @@ function M.do_node_result(initial_node, bufnr, opts, side, key) search_range = {scope:range()}, } - cache:set(result._id, info) + cache:set(result._id, cached_info) return result end +---@param info matchup.treesitter.MatchInfo +---@param line integer +---@param col integer +---@return boolean +local function is_in_range(info, line, col) + ---@type integer, integer, integer, integer + local r_start_row, r_start_col, r_end_row, r_end_col = unpack(info.range) + local p_start_row, p_start_col, p_end_row, p_end_col = line, col, line, col + 1 + + if p_start_row < r_start_row then + return false + elseif p_start_row == r_start_row and p_start_col < r_start_col then + return false + end + + if p_end_row > r_end_row then + return false + elseif p_end_row == r_end_row and p_end_col > r_end_col then + return false + end + + return true +end + +---@type table local side_table = { open = {'open'}, mid = {'mid'}, @@ -215,25 +365,32 @@ local side_table = { open_mid = {'mid', 'open'}, } +---@alias matchup.Side 'open'|'mid'|'close'|'both'|'both_all'|'open_mid' +---@alias matchup.Direction 'current'|'next'|'prev' +---@alias matchup.Type 'delim_text'|'delim_all'|'all' + +---@param bufnr integer +---@param opts {direction: matchup.Direction, side: matchup.Side, type: matchup.Type} function M.get_delim(bufnr, opts) if opts.direction == 'current' then -- get current by query - local active_nodes, symbols = unpack(M.get_active_nodes(bufnr)) + local active_matches, symbols = unpack(M.get_active_matches(bufnr)) local cursor = api.nvim_win_get_cursor(0) local smallest_len = 1e31 + ---@type {info: matchup.treesitter.MatchInfo, side: matchup.Side, key: string}|nil local result_info = nil for _, side in ipairs(side_table[opts.side]) do if not(side == 'mid' and vim.g.matchup_delim_nomids > 0) then - for _, node in ipairs(active_nodes[side]) do - if utils2.is_in_node_range(node, cursor[1]-1, cursor[2]) then - local len = ts_utils.node_length(node) + for _, info in ipairs(active_matches[side] --[=[@as matchup.treesitter.MatchInfo[]]=]) do + if is_in_range(info, cursor[1] - 1, cursor[2]) then + local len = info.length if len < smallest_len then smallest_len = len result_info = { - node = node, + info = info, side = side, - key = symbols[_node_id(node)] + key = symbols[M.range_id(info.range)] } end end @@ -242,7 +399,7 @@ function M.get_delim(bufnr, opts) end if result_info then - return M.do_node_result(result_info.node, bufnr, opts, + return M.do_match_result(result_info.info, bufnr, opts, result_info.side, result_info.key) end @@ -253,16 +410,17 @@ function M.get_delim(bufnr, opts) -- look forwards or backwards for an active node local max_col = 1e5 - local active_nodes, symbols = unpack(M.get_active_nodes(bufnr)) + local active_matches, symbols = unpack(M.get_active_matches(bufnr)) local cursor = api.nvim_win_get_cursor(0) local cur_pos = max_col * (cursor[1]-1) + cursor[2] - local closest_node, closest_dist = nil, 1e31 + local closest_match, closest_dist = nil, 1e31 local result_info = {} for _, side in ipairs(side_table[opts.side]) do - for _, node in ipairs(active_nodes[side]) do - local row, col, _ = node:start() + for _, info in ipairs(active_matches[side]--[=[@as matchup.treesitter.MatchInfo[]]=]) do + ---@type integer, integer + local row, col = unpack(info.range) local pos = max_col * row + col if opts.direction == 'next' and pos >= cur_pos @@ -271,61 +429,62 @@ function M.get_delim(bufnr, opts) local dist = math.abs(pos - cur_pos) if dist < closest_dist then closest_dist = dist - closest_node = node - result_info = { side=side, key=symbols[_node_id(node)] } + closest_match = info + result_info = { side=side, key=symbols[M.range_id(info.range)] } end end end end - if closest_node == nil then + if closest_match == nil then return nil end - return M.do_node_result(closest_node, bufnr, opts, + return M.do_match_result(closest_match, bufnr, opts, result_info.side, result_info.key) end function M.get_matching(delim, down, bufnr) down = down > 0 - local info = cache:get(delim._id) or {} - if info.bufnr ~= bufnr then + local cached_info = cache:get(delim._id) or {} + if cached_info.bufnr ~= bufnr then return {} end - local matches = {} + local matches = {} ---@type [string, integer, integer][] - local sides + local sides ---@type ('open'|'mid'|'close')[] if vim.g.matchup_delim_nomids > 0 then sides = down and {'close'} or {'open'} else sides = down and {'mid', 'close'} or {'mid', 'open'} end - local active_nodes, symbols = unpack(M.get_active_nodes(bufnr)) + local active_matches, symbols = unpack(M.get_active_matches(bufnr)) local got_close = false - local stop_time = _time() + vim.fn['matchup#perf#timeout']() + local stop_time = _time() + vim.fn['matchup#perf#timeout']() ---@type number for _, side in ipairs(sides) do - for _, node in ipairs(active_nodes[side]) do - local row, col, _ = node:start() + for _, info in ipairs(active_matches[side]--[=[@as matchup.treesitter.MatchInfo[]]=]) do + ---@type integer, integer + local row, col = unpack(info.range) if _time() > stop_time then return {} end - if info.initial_node ~= node and symbols[_node_id(node)] == info.key - and (down and (row > info.row or row == info.row and col > info.col) - or not down and (row < info.row or row == info.row and col < info.col)) - and (row >= info.search_range[1] - and row <= info.search_range[3]) then + if cached_info.info ~= info and symbols[M.range_id(info.range)] == cached_info.key + and (down and (row > cached_info.row or row == cached_info.row and col > cached_info.col) + or not down and (row < cached_info.row or row == cached_info.row and col < cached_info.col)) + and (row >= cached_info.search_range[1] + and row <= cached_info.search_range[3]) then - local target_scope = M.containing_scope(node, bufnr, info.key) - if info.scope == target_scope then - local text = _node_text(node, bufnr) or '' + local target_scope = M.containing_scope(info, bufnr, cached_info.key) + if cached_info.scope == target_scope then + local text = text_until_newline(info) or '' table.insert(matches, {text, row + 1, col + 1}) if side == 'close' then @@ -343,39 +502,14 @@ function M.get_matching(delim, down, bufnr) -- no stop marker is found, use enclosing scope if down and not got_close then - local row, col, _ = info.scope:end_() + local row, col, _ = cached_info.scope:end_() table.insert(matches, {'', row + 1, col + 1}) end return matches end -local function opt_tbl_for_lang(opt, lang) - local is_table = type(opt) == "table" - if opt and (not is_table or vim.tbl_contains(opt, lang)) then - return true - end - return false -end - -function M.get_option(bufnr, opt_name) - local config = configs.get_module('matchup') or {} - local lang = parsers.get_buf_lang(bufnr) - if (opt_name == 'include_match_words' - or opt_name == 'additional_vim_regex_highlighting' - or opt_name == 'disable_virtual_text' - or opt_name == 'enable_quotes') then - return opt_tbl_for_lang(config[opt_name], lang) - end - error('invalid option ' .. opt_name) -end - function M.attach(bufnr, lang) - if M.get_option(bufnr, 'additional_vim_regex_highlighting') - and api.nvim_buf_get_option(bufnr, 'syntax') == '' then - api.nvim_buf_set_option(bufnr, 'syntax', 'ON') - end - api.nvim_call_function('matchup#ts_engine#attach', {bufnr, lang}) end diff --git a/lua/treesitter-matchup/syntax.lua b/lua/treesitter-matchup/syntax.lua index 1f2f4a7..23b343f 100644 --- a/lua/treesitter-matchup/syntax.lua +++ b/lua/treesitter-matchup/syntax.lua @@ -1,60 +1,37 @@ -if not pcall(require, 'nvim-treesitter') then - return { - is_active = function() return false end, - synID = function(lnum, col, transparent) - return vim.fn.synID(lnum, col, transparent) - end - } -end - local api = vim.api +local vts = vim.treesitter local hl_info = require'treesitter-matchup.third-party.hl-info' -local queries = require'treesitter-matchup.third-party.query' -local ts_utils = require'nvim-treesitter.ts_utils' -local parsers = require'nvim-treesitter.parsers' +local internal = require'treesitter-matchup.internal' local M = {} +---@param bufnr integer? +---@return boolean function M.is_active(bufnr) bufnr = bufnr or api.nvim_get_current_buf() return (hl_info.active() - and api.nvim_buf_get_option(bufnr, 'syntax') == '') + and vim.bo[bufnr].syntax == '') end --- Get all nodes that are marked as skip +---@param bufnr integer function M.get_skips(bufnr) - local matches = queries.get_matches(bufnr, 'matchup') + local matches = internal.get_matches(bufnr) - local skips = {} + local skips = {} ---@type table for _, match in ipairs(matches) do if match.skip then - skips[match.skip.node:id()] = 1 + skips[internal.range_id(match.skip.info.range)] = 1 end end return skips end -local function get_node_at_pos(cursor) - local cursor_range = { cursor[1] - 1, cursor[2] } - - local buf = vim.api.nvim_win_get_buf(0) - local root_lang_tree = parsers.get_parser(buf) - if not root_lang_tree then - return - end - local root = ts_utils.get_root_for_position( - cursor_range[1], cursor_range[2], root_lang_tree) - - if not root then - return - end - - return root:named_descendant_for_range( - cursor_range[1], cursor_range[2], cursor_range[1], cursor_range[2]) -end - +---@param lnum integer +---@param col integer +---@return boolean function M.lang_skip(lnum, col) local bufnr = api.nvim_get_current_buf() local skips = M.get_skips(bufnr) @@ -63,17 +40,22 @@ function M.lang_skip(lnum, col) return false end - local node = get_node_at_pos({lnum, col - 1}) + -- TODO: is lnum - 1 ok? + local node = vts.get_node({pos = {lnum - 1, col - 1}}) if not node then return false end - if skips[node:id()] then + ---@diagnostic disable-next-line: missing-fields LuaLS bug + if skips[internal.range_id({node:range()})] then return true end return false end +---@param lnum integer +---@param col integer +---@param transparent 1|0 function M.synID(lnum, col, transparent) if not M.is_active() then return vim.fn.synID(lnum, col, transparent) diff --git a/lua/treesitter-matchup/third-party/query.lua b/lua/treesitter-matchup/third-party/query.lua deleted file mode 100644 index 5c8960a..0000000 --- a/lua/treesitter-matchup/third-party/query.lua +++ /dev/null @@ -1,394 +0,0 @@ --- From https://github.com/nvim-treesitter/nvim-treesitter --- Copyright 2021 --- licensed under the Apache License 2.0 --- See nvim-treesitter.LICENSE-APACHE-2.0 - -local api = vim.api -local ts = require 'treesitter-matchup.compat' -local tsrange = require "nvim-treesitter.tsrange" -local utils = require "nvim-treesitter.utils" -local parsers = require "nvim-treesitter.parsers" -local caching = require "nvim-treesitter.caching" - -local M = {} - -local EMPTY_ITER = function() end - -do - local query_cache = caching.create_buffer_cache() - - local function update_cached_matches(bufnr, changed_tick, query_group) - query_cache.set(query_group, bufnr, { - tick = changed_tick, - cache = M.collect_group_results(bufnr, query_group) or {}, - }) - end - - function M.get_matches(bufnr, query_group) - bufnr = bufnr or api.nvim_get_current_buf() - local cached_local = query_cache.get(query_group, bufnr) - if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then - update_cached_matches(bufnr, api.nvim_buf_get_changedtick(bufnr), query_group) - end - - return query_cache.get(query_group, bufnr).cache - end -end - -do - local mt = {} - mt.__index = function(tbl, key) - if rawget(tbl, key) == nil then - rawset(tbl, key, {}) - end - return rawget(tbl, key) - end - - -- cache will auto set the table for each lang if it is nil - local cache = setmetatable({}, mt) - - --- Same as `vim.treesitter.query` except will return cached values - ---@param lang string - ---@param query_name string - function M.get_query(lang, query_name) - if cache[lang][query_name] == nil then - cache[lang][query_name] = ts.get_query(lang, query_name) - end - - return cache[lang][query_name] - end - - --- Invalidates the query file cache. - --- If lang and query_name is both present, will reload for only the lang and query_name. - --- If only lang is present, will reload all query_names for that lang - --- If none are present, will reload everything - ---@param lang string - ---@param query_name string - function M.invalidate_query_cache(lang, query_name) - if lang and query_name then - cache[lang][query_name] = nil - elseif lang and not query_name then - for query_name0, _ in pairs(cache[lang]) do - M.invalidate_query_cache(lang, query_name0) - end - elseif not lang and not query_name then - for lang0, _ in pairs(cache) do - for query_name0, _ in pairs(cache[lang0]) do - M.invalidate_query_cache(lang0, query_name0) - end - end - else - error "Cannot have query_name by itself!" - end - end -end - ---- This function is meant for an autocommand and not to be used. Only use if file is a query file. ----@param fname string -function M.invalidate_query_file(fname) - local fnamemodify = vim.fn.fnamemodify - M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r")) -end - ----@class QueryInfo ----@field root LanguageTree ----@field source integer ----@field start integer ----@field stop integer - ----@param bufnr integer ----@param query_name string ----@param root LanguageTree ----@param root_lang string|nil ----@return Query|nil, QueryInfo|nil -local function prepare_query(bufnr, query_name, root, root_lang) - local buf_lang = parsers.get_buf_lang(bufnr) - - if not buf_lang then - return - end - - local parser = parsers.get_parser(bufnr, buf_lang) - if not parser then - return - end - - if not root then - local first_tree = parser:trees()[1] - - if first_tree then - root = first_tree:root() - end - end - - if not root then - return - end - - local range = { root:range() } - - if not root_lang then - local lang_tree = parser:language_for_range(range) - - if lang_tree then - root_lang = lang_tree:lang() - end - end - - if not root_lang then - return - end - - local query = M.get_query(root_lang, query_name) - if not query then - return - end - - return query, - { - root = root, - source = bufnr, - start = range[1], - -- The end row is exclusive so we need to add 1 to it. - stop = range[3] + 1, - } -end - -local function get_byte_offset(buf, row, col) - local lines = api.nvim_buf_get_lines(buf, row, row + 1, false) - if #lines < 1 then - return - end - return api.nvim_buf_get_offset(buf, row) + vim.fn.byteidx(lines[1], col) -end - -local function TSRange_from_table(buf, range) - return setmetatable( - { - start_pos = {range[1], range[2], get_byte_offset(buf, range[1], range[2])}, - end_pos = {range[3], range[4], get_byte_offset(buf, range[3], range[4])}, - buf = buf, - [1] = range[1], - [2] = range[2], - [3] = range[3], - [4] = range[4], - }, - tsrange.TSRange) -end - ----@param query Query ----@param bufnr integer ----@param start_row integer ----@param end_row integer -function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) - -- A function that splits a string on '.' - local function split(string) - local t = {} - for str in string.gmatch(string, "([^.]+)") do - table.insert(t, str) - end - - return t - end - -- Given a path (i.e. a List(String)) this functions inserts value at path - local function insert_to_path(object, path, value) - local curr_obj = object - - for index = 1, (#path - 1) do - if curr_obj[path[index]] == nil then - curr_obj[path[index]] = {} - end - - curr_obj = curr_obj[path[index]] - end - - curr_obj[path[#path]] = value - end - - local matches = query:iter_matches(qnode, bufnr, start_row, end_row, { all = false }) - - local function iterator() - local pattern, match, metadata = matches() - if pattern ~= nil then - local prepared_match = {} - - -- Extract capture names from each match - for id, node in pairs(match) do - local name = query.captures[id] -- name of the capture in the query - if name ~= nil then - local path = split(name .. ".node") - insert_to_path(prepared_match, path, node) - local metadata_path = split(name .. ".metadata") - insert_to_path(prepared_match, metadata_path, metadata[id]) - end - end - - -- Add some predicates for testing - local preds = query.info.patterns[pattern] - if preds then - for _, pred in pairs(preds) do - -- functions - if pred[1] == "set!" and type(pred[2]) == "string" then - insert_to_path(prepared_match, split(pred[2]), pred[3]) - end - if pred[1] == "make-range!" and #pred == 4 then - assert(type(pred[2]) == "string") - local path = pred[2] - insert_to_path( - prepared_match, - split(path .. ".node"), - tsrange.TSRange.from_nodes(bufnr, match[pred[3]], match[pred[4]]) - ) - end - if pred[1] == "offset!" then - local path = type(pred[2]) == "string" and pred[2] or query.captures[pred[2]] - - local offset_node = match[pred[2]] - local range = {offset_node:range()} - local start_row_offset = pred[3] or 0 - local start_col_offset = pred[4] or 0 - local end_row_offset = pred[5] or 0 - local end_col_offset = pred[6] or 0 - - range[1] = range[1] + start_row_offset - range[2] = range[2] + start_col_offset - range[3] = range[3] + end_row_offset - range[4] = range[4] + end_col_offset - - insert_to_path(prepared_match, split(path..'.node'), - TSRange_from_table(bufnr, range)) - end - end - end - - return prepared_match - end - end - return iterator -end - ---- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) ----Works like M.get_references or M.get_scopes except you can choose the capture ----Can also be a nested capture like @definition.function to get all nodes defining a function. ---- ----@param bufnr integer the buffer ----@param captures string|string[] ----@param query_group string the name of query group (highlights or injections for example) ----@param root LanguageTree|nil node from where to start the search ----@param lang string|nil the language from where to get the captures. ---- Root nodes can have several languages. ----@return table|nil -function M.get_capture_matches(bufnr, captures, query_group, root, lang) - if type(captures) == "string" then - captures = { captures } - end - local strip_captures = {} - for i, capture in ipairs(captures) do - if capture:sub(1, 1) ~= "@" then - error 'Captures must start with "@"' - return - end - -- Remove leading "@". - strip_captures[i] = capture:sub(2) - end - - local matches = {} - for match in M.iter_group_results(bufnr, query_group, root, lang) do - for _, capture in ipairs(strip_captures) do - local insert = utils.get_at_path(match, capture) - if insert then - table.insert(matches, insert) - end - end - end - return matches -end - -function M.iter_captures(bufnr, query_name, root, lang) - local query, params = prepare_query(bufnr, query_name, root, lang) - if not query then - return EMPTY_ITER - end - assert(params) - - local iter = query:iter_captures(params.root, params.source, params.start, params.stop) - - local function wrapped_iter() - local id, node, metadata = iter() - if not id then - return - end - - local name = query.captures[id] - if string.sub(name, 1, 1) == "_" then - return wrapped_iter() - end - - return name, node, metadata - end - - return wrapped_iter -end - ----Iterates matches from a query file. ----@param bufnr integer the buffer ----@param query_group string the query file to use ----@param root LanguageTree the root node ----@param root_lang string|nil the root node lang, if known -function M.iter_group_results(bufnr, query_group, root, root_lang) - local query, params = prepare_query(bufnr, query_group, root, root_lang) - if not query then - return EMPTY_ITER - end - assert(params) - - return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop) -end - -function M.collect_group_results(bufnr, query_group, root, lang) - local matches = {} - - for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do - table.insert(matches, prepared_match) - end - - return matches -end - ----@alias CaptureResFn function(string, LanguageTree, LanguageTree): string, string - ---- Same as get_capture_matches except this will recursively get matches for every language in the tree. ----@param bufnr integer The bufnr ----@param capture_or_fn string|CaptureResFn The capture to get. If a function is provided then that ---- function will be used to resolve both the capture and query argument. ---- The function can return `nil` to ignore that tree. ----@param query_type string The query to get the capture from. This is ignore if a function is provided ---- for the captuer argument. -function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type) - ---@type CaptureResFn - local type_fn - if type(capture_or_fn) == "function" then - type_fn = capture_or_fn - else - type_fn = function(_, _, _) - return capture_or_fn, query_type - end - end - local parser = parsers.get_parser(bufnr) - local matches = {} - - if parser then - parser:for_each_tree(function(tree, lang_tree) - local lang = lang_tree:lang() - local capture, type_ = type_fn(lang, tree, lang_tree) - - if capture then - vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang)) - end - end) - end - - return matches -end - -return M diff --git a/lua/treesitter-matchup/third-party/ts-utils.lua b/lua/treesitter-matchup/third-party/ts-utils.lua new file mode 100644 index 0000000..41db334 --- /dev/null +++ b/lua/treesitter-matchup/third-party/ts-utils.lua @@ -0,0 +1,28 @@ +-- From https://github.com/nvim-treesitter/nvim-treesitter +-- Copyright 2021 +-- licensed under the Apache License 2.0 +-- See nvim-treesitter.LICENSE-APACHE-2.0 + +local M = {} + +---Memoize a function using hash_fn to hash the arguments. +---@generic F: function +---@param fn F +---@param hash_fn fun(...): any +---@return F +function M.memoize(fn, hash_fn) + local cache = setmetatable({}, { __mode = 'kv' }) ---@type table + + return function(...) + local key = hash_fn(...) + if cache[key] == nil then + local v = fn(...) ---@type any + cache[key] = v ~= nil and v or vim.NIL + end + + local v = cache[key] + return v ~= vim.NIL and v or nil + end +end + +return M