diff --git a/lua/treesitter-matchup/compat.lua b/lua/treesitter-matchup/compat.lua new file mode 100644 index 0000000..6635722 --- /dev/null +++ b/lua/treesitter-matchup/compat.lua @@ -0,0 +1,14 @@ +local M = {} + +local ts = vim.treesitter +local tsq = vim.treesitter.query + +M.get_node_text = function(node, source, opts) + return (ts.get_node_text or tsq.get_node_text)(node, source, opts) +end + +M.get_query = function(lang, query_name) + return (tsq.get or tsq.get_query)(lang, query_name) +end + +return M diff --git a/lua/treesitter-matchup/internal.lua b/lua/treesitter-matchup/internal.lua index a396001..d808217 100644 --- a/lua/treesitter-matchup/internal.lua +++ b/lua/treesitter-matchup/internal.lua @@ -5,7 +5,7 @@ end local vim = vim local api = vim.api -local ts_compat = require'nvim-treesitter.compat' +local ts = require'treesitter-matchup.compat' local configs = require'nvim-treesitter.configs' local parsers = require'nvim-treesitter.parsers' local queries = require'treesitter-matchup.third-party.query' @@ -158,7 +158,7 @@ function M.containing_scope(node, bufnr, key) end local function _node_text(node, bufnr) - local text = ts_compat.get_node_text(node, bufnr) + local text = ts.get_node_text(node, bufnr) return text:match("(%S+).*") end diff --git a/lua/treesitter-matchup/third-party/query.lua b/lua/treesitter-matchup/third-party/query.lua index 58618e9..add114c 100644 --- a/lua/treesitter-matchup/third-party/query.lua +++ b/lua/treesitter-matchup/third-party/query.lua @@ -4,7 +4,7 @@ -- See nvim-treesitter.LICENSE-APACHE-2.0 local api = vim.api -local ts_compat = require 'nvim-treesitter.compat' +local ts = require 'treesitter-matchup.compat' local tsrange = require "nvim-treesitter.tsrange" local utils = require "nvim-treesitter.utils" local parsers = require "nvim-treesitter.parsers" @@ -52,7 +52,7 @@ do ---@param query_name string function M.get_query(lang, query_name) if cache[lang][query_name] == nil then - cache[lang][query_name] = ts_compat.get_query(lang, query_name) + cache[lang][query_name] = ts.get_query(lang, query_name) end return cache[lang][query_name]