mirror of
https://github.com/chenasraf/nvim-treesitter.git
synced 2026-05-17 17:38:02 +00:00
feat(install)!: migrate to latest async.nvim impl (#7856)
Provides significantly simpler blocking installation and update.
This commit is contained in:
committed by
Christian Clason
parent
7a4a35de3e
commit
69371f0148
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -1,3 +1,4 @@
|
||||
runtime/queries/**/*.scm linguist-language=Tree-sitter-Query
|
||||
runtime/queries/**/*.scm linguist-language=tsq
|
||||
doc/*.txt linguist-documentation
|
||||
SUPPORTED_LANGUAGES.md linguist-generated
|
||||
lua/nvim-treesitter/async.lua linguist-vendored
|
||||
|
||||
@@ -68,13 +68,9 @@ Parsers and queries can then be installed with
|
||||
require'nvim-treesitter'.install { 'rust', 'javascript', 'zig' }
|
||||
```
|
||||
|
||||
(This is a no-op if the parsers are already installed.) Note that this function runs asynchronously; for synchronous installation in a script context ("bootstrapping"), use something like
|
||||
(This is a no-op if the parsers are already installed.) Note that this function runs asynchronously; for synchronous installation in a script context ("bootstrapping"), you need to `wait()` for it to finish:
|
||||
```lua
|
||||
local done = nil
|
||||
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' }, function(success)
|
||||
done = success
|
||||
end)
|
||||
vim.wait(3000000, function() return done ~= nil end)
|
||||
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' }):wait(300000) -- wait max. 5 minutes
|
||||
```
|
||||
|
||||
Check [`:h nvim-treesitter-commands`](doc/nvim-treesitter.txt) for a list of all available commands.
|
||||
|
||||
2
TODO.md
2
TODO.md
@@ -4,7 +4,6 @@ This document lists the planned and finished changes in this rewrite towards [Nv
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] **`install.lua`:** migrate to async v2
|
||||
- [ ] **tests:** remove custom crate, plenary dependency
|
||||
- [ ] **indents:** rewrite (Helix or Zed compatible)
|
||||
- [ ] **textobjects:** include simple(!) `node`, `scope` (using `locals`) objects
|
||||
@@ -26,3 +25,4 @@ This document lists the planned and finished changes in this rewrite towards [Nv
|
||||
- [X] drop ensure_install (replace with install)
|
||||
- [X] **CI:** switch to ts_query_ls, add update readme as check (remove update job)
|
||||
- [X] **CI:** track versioned releases for tier 1
|
||||
- [X] **`install.lua`:** migrate to async v2
|
||||
|
||||
@@ -102,7 +102,7 @@ setup({opts}) *nvim-treesitter.setup()*
|
||||
directory to install parsers and queries to. Note: will be
|
||||
appended to |runtimepath|.
|
||||
|
||||
install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
|
||||
install({languages} [, {opts}]) *nvim-treesitter.install()*
|
||||
|
||||
Download, compile, and install the specified treesitter parsers and copy
|
||||
the corresponding queries to a directory on |runtimepath|, enabling their
|
||||
@@ -110,13 +110,9 @@ install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
|
||||
|
||||
Note: This operation is performed asynchronously by default. For
|
||||
synchronous operation (e.g., in a bootstrapping script), you need to
|
||||
provide a suitable {callback}: >lua
|
||||
local done = nil
|
||||
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' },
|
||||
function(success)
|
||||
done = success
|
||||
end)
|
||||
vim.wait(3000000, function() return done ~= nil end)
|
||||
`wait()` for it: >lua
|
||||
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' })
|
||||
:wait(300000) -- max. 5 minutes
|
||||
<
|
||||
Parameters: ~
|
||||
• {languages} `(string[]|string)` (List of) languages or tiers (`stable`,
|
||||
@@ -129,7 +125,6 @@ install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
|
||||
compiling.
|
||||
• {max_jobs} (`integer?`) limit parallel tasks (useful in
|
||||
combination with {generate} on memory-limited systems).
|
||||
• {callback} `(function?`) Callback for synchronous execution.
|
||||
|
||||
uninstall({languages}) *nvim-treesitter.uninstall()*
|
||||
|
||||
@@ -139,25 +134,19 @@ uninstall({languages}) *nvim-treesitter.uninstall()
|
||||
• {languages} `(string[]|string)` (List of) languages or tiers (`stable`,
|
||||
`unstable`) to update.
|
||||
|
||||
update({languages}, {callback}) *nvim-treesitter.update()*
|
||||
update([{languages}]) *nvim-treesitter.update()*
|
||||
|
||||
Update the parsers and queries if older than the revision specified in the
|
||||
manifest.
|
||||
|
||||
Note: This operation is performed asynchronously by default. For
|
||||
synchronous operation (e.g., in a bootstrapping script), you need to
|
||||
provide a suitable {callback}: >lua
|
||||
local done = nil
|
||||
require('nvim-treesitter').update(),
|
||||
function(success)
|
||||
done = success
|
||||
end)
|
||||
vim.wait(3000000, function() return done ~= nil end)
|
||||
`wait()` for it: >lua
|
||||
require('nvim-treesitter').update():wait(300000) -- max. 5 minutes
|
||||
<
|
||||
Parameters: ~
|
||||
• {languages} `(string[]|string)` (List of) languages or tiers to
|
||||
uninstall.
|
||||
• {callback} `(function?`) Callback for synchronous execution.
|
||||
• {languages} `(string[]|string)?` (List of) languages or tiers to update
|
||||
(default: all installed).
|
||||
|
||||
indentexpr() *nvim-treesitter.indentexpr()*
|
||||
|
||||
|
||||
763
lua/nvim-treesitter/async.lua
vendored
763
lua/nvim-treesitter/async.lua
vendored
@@ -1,112 +1,725 @@
|
||||
local co = coroutine
|
||||
local pcall = copcall or pcall
|
||||
|
||||
--- @param ... any
|
||||
--- @return {[integer]: any, n: integer}
|
||||
local function pack_len(...)
|
||||
return { n = select('#', ...), ... }
|
||||
end
|
||||
|
||||
--- like unpack() but use the length set by F.pack_len if present
|
||||
--- @param t? { [integer]: any, n?: integer }
|
||||
--- @param first? integer
|
||||
--- @return ...any
|
||||
local function unpack_len(t, first)
|
||||
if t then
|
||||
return unpack(t, first or 1, t.n or table.maxn(t))
|
||||
end
|
||||
end
|
||||
|
||||
--- @class async
|
||||
local M = {}
|
||||
|
||||
---Executes a future with a callback when it is done
|
||||
---@param func function
|
||||
---@param callback function
|
||||
---@param ... unknown
|
||||
local function execute(func, callback, ...)
|
||||
local thread = co.create(func)
|
||||
--- Weak table to keep track of running tasks
|
||||
--- @type table<thread,async.Task?>
|
||||
local threads = setmetatable({}, { __mode = 'k' })
|
||||
|
||||
local function step(...)
|
||||
local ret = { co.resume(thread, ...) }
|
||||
---@type boolean, any
|
||||
local stat, nargs_or_err = unpack(ret)
|
||||
--- @return async.Task?
|
||||
local function running()
|
||||
local task = threads[coroutine.running()]
|
||||
if task and not (task:_completed() or task._closing) then
|
||||
return task
|
||||
end
|
||||
end
|
||||
|
||||
--- Base class for async tasks. Async functions should return a subclass of
|
||||
--- this. This is designed specifically to be a base class of uv_handle_t
|
||||
--- @class async.Handle
|
||||
--- @field close fun(self: async.Handle, callback?: fun())
|
||||
--- @field is_closing? fun(self: async.Handle): boolean
|
||||
|
||||
--- @alias async.CallbackFn fun(...: any): async.Handle?
|
||||
|
||||
--- @class async.Task : async.Handle
|
||||
--- @field package _callbacks table<integer,fun(err?: any, ...: any)>
|
||||
--- @field package _callback_pos integer
|
||||
--- @field private _thread thread
|
||||
---
|
||||
--- Tasks can call other async functions (task of callback functions)
|
||||
--- when we are waiting on a child, we store the handle to it here so we can
|
||||
--- cancel it.
|
||||
--- @field private _current_child? async.Handle
|
||||
---
|
||||
--- Error result of the task is an error occurs.
|
||||
--- Must use `await` to get the result.
|
||||
--- @field private _err? any
|
||||
---
|
||||
--- Result of the task.
|
||||
--- Must use `await` to get the result.
|
||||
--- @field private _result? any[]
|
||||
local Task = {}
|
||||
Task.__index = Task
|
||||
|
||||
--- @private
|
||||
--- @param func function
|
||||
--- @return async.Task
|
||||
function Task._new(func)
|
||||
local thread = coroutine.create(func)
|
||||
|
||||
local self = setmetatable({
|
||||
_closing = false,
|
||||
_thread = thread,
|
||||
_callbacks = {},
|
||||
_callback_pos = 1,
|
||||
}, Task)
|
||||
|
||||
threads[thread] = self
|
||||
|
||||
return self
|
||||
end
|
||||
|
||||
--- @param callback fun(err?: any, ...: any)
|
||||
function Task:await(callback)
|
||||
if self._closing then
|
||||
callback('closing')
|
||||
elseif self:_completed() then -- TODO(lewis6991): test
|
||||
-- Already finished or closed
|
||||
callback(self._err, unpack_len(self._result))
|
||||
else
|
||||
self._callbacks[self._callback_pos] = callback
|
||||
self._callback_pos = self._callback_pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
--- @package
|
||||
function Task:_completed()
|
||||
return (self._err or self._result) ~= nil
|
||||
end
|
||||
|
||||
-- Use max 32-bit signed int value to avoid overflow on 32-bit systems.
|
||||
-- Do not use `math.huge` as it is not interpreted as a positive integer on all
|
||||
-- platforms.
|
||||
local MAX_TIMEOUT = 2 ^ 31 - 1
|
||||
|
||||
--- Synchronously wait (protected) for a task to finish (blocking)
|
||||
---
|
||||
--- If an error is returned, `Task:traceback()` can be used to get the
|
||||
--- stack trace of the error.
|
||||
---
|
||||
--- Example:
|
||||
--- ```lua
|
||||
---
|
||||
--- local ok, err_or_result = task:pwait(10)
|
||||
---
|
||||
--- if not ok then
|
||||
--- error(task:traceback(err_or_result))
|
||||
--- end
|
||||
---
|
||||
--- local _, result = assert(task:pwait(10))
|
||||
--- ```
|
||||
---
|
||||
--- Can be called if a task is closing.
|
||||
--- @param timeout? integer
|
||||
--- @return boolean status
|
||||
--- @return any ... result or error
|
||||
function Task:pwait(timeout)
|
||||
local done = vim.wait(timeout or MAX_TIMEOUT, function()
|
||||
-- Note we use self:_completed() instead of self:await() to avoid creating a
|
||||
-- callback. This avoids having to cleanup/unregister any callback in the
|
||||
-- case of a timeout.
|
||||
return self:_completed()
|
||||
end)
|
||||
|
||||
if not done then
|
||||
return false, 'timeout'
|
||||
elseif self._err then
|
||||
return false, self._err
|
||||
else
|
||||
return true, unpack_len(self._result)
|
||||
end
|
||||
end
|
||||
|
||||
--- Synchronously wait for a task to finish (blocking)
|
||||
---
|
||||
--- Example:
|
||||
--- ```lua
|
||||
--- local result = task:wait(10) -- wait for 10ms or else error
|
||||
---
|
||||
--- local result = task:wait() -- wait indefinitely
|
||||
--- ```
|
||||
--- @param timeout? integer Timeout in milliseconds
|
||||
--- @return any ... result
|
||||
function Task:wait(timeout)
|
||||
local res = pack_len(self:pwait(timeout))
|
||||
local stat = res[1]
|
||||
|
||||
if not stat then
|
||||
error(
|
||||
string.format(
|
||||
'The coroutine failed with this message: %s\n%s',
|
||||
nargs_or_err,
|
||||
debug.traceback(thread)
|
||||
)
|
||||
)
|
||||
error(self:traceback(res[2]))
|
||||
end
|
||||
|
||||
if co.status(thread) == 'dead' then
|
||||
return unpack_len(res, 2)
|
||||
end
|
||||
|
||||
--- @private
|
||||
--- @param msg? string
|
||||
--- @param _lvl? integer
|
||||
--- @return string
|
||||
function Task:_traceback(msg, _lvl)
|
||||
_lvl = _lvl or 0
|
||||
|
||||
local thread = ('[%s] '):format(self._thread)
|
||||
|
||||
local child = self._current_child
|
||||
if getmetatable(child) == Task then
|
||||
--- @cast child async.Task
|
||||
msg = child:_traceback(msg, _lvl + 1)
|
||||
end
|
||||
|
||||
local tblvl = getmetatable(child) == Task and 2 or nil
|
||||
msg = (msg or '') .. debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. thread)
|
||||
|
||||
if _lvl == 0 then
|
||||
--- @type string
|
||||
msg = msg
|
||||
:gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1)
|
||||
:gsub('\nstack traceback:\n', '\n')
|
||||
:gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1)
|
||||
end
|
||||
|
||||
return msg
|
||||
end
|
||||
|
||||
--- Get the traceback of a task when it is not active.
|
||||
--- Will also get the traceback of nested tasks.
|
||||
---
|
||||
--- @param msg? string
|
||||
--- @return string
|
||||
function Task:traceback(msg)
|
||||
return self:_traceback(msg)
|
||||
end
|
||||
|
||||
--- If a task completes with an error, raise the error
|
||||
function Task:raise_on_error()
|
||||
self:await(function(err)
|
||||
if err then
|
||||
error(self:_traceback(err), 0)
|
||||
end
|
||||
end)
|
||||
return self
|
||||
end
|
||||
|
||||
--- @private
|
||||
--- @param err? any
|
||||
--- @param result? {[integer]: any, n: integer}
|
||||
function Task:_finish(err, result)
|
||||
self._current_child = nil
|
||||
self._err = err
|
||||
self._result = result
|
||||
threads[self._thread] = nil
|
||||
|
||||
local errs = {} --- @type string[]
|
||||
for _, cb in pairs(self._callbacks) do
|
||||
--- @type boolean, string
|
||||
local ok, cb_err = pcall(cb, err, unpack_len(result))
|
||||
if not ok then
|
||||
errs[#errs + 1] = cb_err
|
||||
end
|
||||
end
|
||||
|
||||
if #errs > 0 then
|
||||
error(table.concat(errs, '\n'), 0)
|
||||
end
|
||||
end
|
||||
|
||||
--- @return boolean
|
||||
function Task:is_closing()
|
||||
return self._closing
|
||||
end
|
||||
|
||||
--- Close the task and all its children.
|
||||
--- If callback is provided it will run asynchronously,
|
||||
--- else it will run synchronously.
|
||||
---
|
||||
--- @param callback? fun()
|
||||
function Task:close(callback)
|
||||
if self:_completed() then
|
||||
if callback then
|
||||
callback(unpack(ret, 3, table.maxn(ret)))
|
||||
callback()
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
---@type function, any[]
|
||||
local fn, args = ret[3], { unpack(ret, 4, table.maxn(ret)) }
|
||||
args[nargs_or_err] = step
|
||||
fn(unpack(args, 1, nargs_or_err))
|
||||
if self._closing then
|
||||
return
|
||||
end
|
||||
|
||||
step(...)
|
||||
self._closing = true
|
||||
|
||||
if callback then -- async
|
||||
if self._current_child then
|
||||
self._current_child:close(function()
|
||||
self:_finish('closed')
|
||||
callback()
|
||||
end)
|
||||
else
|
||||
self:_finish('closed')
|
||||
callback()
|
||||
end
|
||||
else -- sync
|
||||
if self._current_child then
|
||||
self._current_child:close(function()
|
||||
self:_finish('closed')
|
||||
end)
|
||||
else
|
||||
self:_finish('closed')
|
||||
end
|
||||
vim.wait(0, function()
|
||||
return self:_completed()
|
||||
end)
|
||||
end
|
||||
end
|
||||
|
||||
--- @param obj any
|
||||
--- @return boolean
|
||||
local function is_async_handle(obj)
|
||||
local ty = type(obj)
|
||||
return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close)
|
||||
end
|
||||
|
||||
--- @param ... any
|
||||
function Task:_resume(...)
|
||||
--- @type [boolean, string|async.CallbackFn]
|
||||
local ret = pack_len(coroutine.resume(self._thread, ...))
|
||||
local stat = ret[1]
|
||||
|
||||
if not stat then
|
||||
-- Coroutine had error
|
||||
self:_finish(ret[2])
|
||||
elseif coroutine.status(self._thread) == 'dead' then
|
||||
-- Coroutine finished
|
||||
local result = pack_len(unpack_len(ret, 2))
|
||||
self:_finish(nil, result)
|
||||
else
|
||||
local fn = ret[2]
|
||||
--- @cast fn -string
|
||||
|
||||
-- TODO(lewis6991): refine error handler to be more specific
|
||||
local ok, r
|
||||
ok, r = pcall(fn, function(...)
|
||||
if is_async_handle(r) then
|
||||
--- @cast r async.Handle
|
||||
-- We must close children before we resume to ensure
|
||||
-- all resources are collected.
|
||||
local args = pack_len(...)
|
||||
r:close(function()
|
||||
self:_resume(unpack_len(args))
|
||||
end)
|
||||
else
|
||||
self:_resume(...)
|
||||
end
|
||||
end)
|
||||
|
||||
if not ok then
|
||||
self:_finish(r)
|
||||
elseif is_async_handle(r) then
|
||||
self._current_child = r
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- @return 'running'|'suspended'|'normal'|'dead'?
|
||||
function Task:status()
|
||||
return coroutine.status(self._thread)
|
||||
end
|
||||
|
||||
--- Run a function in an async context, asynchronously.
|
||||
---
|
||||
--- Examples:
|
||||
--- ```lua
|
||||
--- -- The two below blocks are equivalent:
|
||||
---
|
||||
--- -- Run a uv function and wait for it
|
||||
--- local stat = async.arun(function()
|
||||
--- return async.await(2, vim.uv.fs_stat, 'foo.txt')
|
||||
--- end):wait()
|
||||
---
|
||||
--- -- Since uv functions have sync versions. You can just do:
|
||||
--- local stat = vim.fs_stat('foo.txt')
|
||||
--- ```
|
||||
--- @param func function
|
||||
--- @param ... any
|
||||
--- @return async.Task
|
||||
function M.arun(func, ...)
|
||||
local task = Task._new(func)
|
||||
task:_resume(...)
|
||||
return task
|
||||
end
|
||||
|
||||
--- @class async.TaskFun
|
||||
--- @field package _fun fun(...: any): any
|
||||
--- @operator call(...): any
|
||||
local TaskFun = {}
|
||||
TaskFun.__index = TaskFun
|
||||
|
||||
function TaskFun:__call(...)
|
||||
return M.arun(self._fun, ...)
|
||||
end
|
||||
|
||||
--- Create an async function
|
||||
--- @param fun function
|
||||
--- @return async.TaskFun
|
||||
function M.async(fun)
|
||||
return setmetatable({ _fun = fun }, TaskFun)
|
||||
end
|
||||
|
||||
--- Returns the status of a task’s thread.
|
||||
---
|
||||
--- @param task? async.Task
|
||||
--- @return 'running'|'suspended'|'normal'|'dead'?
|
||||
function M.status(task)
|
||||
task = task or running()
|
||||
if task then
|
||||
assert(getmetatable(task) == Task, 'Expected Task')
|
||||
return task:status()
|
||||
end
|
||||
end
|
||||
|
||||
--- @async
|
||||
--- @generic R1, R2, R3, R4
|
||||
--- @param fun fun(callback: fun(r1: R1, r2: R2, r3: R3, r4: R4)): any?
|
||||
--- @return R1, R2, R3, R4
|
||||
local function yield(fun)
|
||||
assert(type(fun) == 'function', 'Expected function')
|
||||
return coroutine.yield(fun)
|
||||
end
|
||||
|
||||
--- @async
|
||||
--- @param task async.Task
|
||||
--- @return any ...
|
||||
local function await_task(task)
|
||||
--- @param callback fun(err?: string, ...: any)
|
||||
--- @return function
|
||||
local res = pack_len(yield(function(callback)
|
||||
task:await(callback)
|
||||
return task
|
||||
end))
|
||||
|
||||
local err = res[1]
|
||||
|
||||
if err then
|
||||
-- TODO(lewis6991): what is the correct level to pass?
|
||||
error(err, 0)
|
||||
end
|
||||
|
||||
return unpack_len(res, 2)
|
||||
end
|
||||
|
||||
--- Asynchronous blocking wait
|
||||
--- @param argc integer
|
||||
--- @param fun async.CallbackFn
|
||||
--- @param ... any func arguments
|
||||
--- @return any ...
|
||||
local function await_cbfun(argc, fun, ...)
|
||||
local args = pack_len(...)
|
||||
|
||||
--- @param callback fun(...:any)
|
||||
--- @return any?
|
||||
return yield(function(callback)
|
||||
args[argc] = callback
|
||||
args.n = math.max(args.n, argc)
|
||||
return fun(unpack_len(args))
|
||||
end)
|
||||
end
|
||||
|
||||
--- @param taskfun async.TaskFun
|
||||
--- @param ... any
|
||||
--- @return any ...
|
||||
local function await_taskfun(taskfun, ...)
|
||||
return taskfun._fun(...)
|
||||
end
|
||||
|
||||
--- Asynchronous blocking wait
|
||||
---
|
||||
--- Example:
|
||||
--- ```lua
|
||||
--- local task = async.arun(function()
|
||||
--- return 1, 'a'
|
||||
--- end)
|
||||
---
|
||||
--- local task_fun = async.async(function(arg)
|
||||
--- return 2, 'b', arg
|
||||
--- end)
|
||||
---
|
||||
--- async.arun(function()
|
||||
--- do -- await a callback function
|
||||
--- async.await(1, vim.schedule)
|
||||
--- end
|
||||
---
|
||||
--- do -- await a task (new async context)
|
||||
--- local n, s = async.await(task)
|
||||
--- assert(n == 1 and s == 'a')
|
||||
--- end
|
||||
---
|
||||
--- do -- await a started task function (new async context)
|
||||
--- local n, s, arg = async.await(task_fun('A'))
|
||||
--- assert(n == 2)
|
||||
--- assert(s == 'b')
|
||||
--- assert(args == 'A')
|
||||
--- end
|
||||
---
|
||||
--- do -- await a task function (re-using the current async context)
|
||||
--- local n, s, arg = async.await(task_fun, 'B')
|
||||
--- assert(n == 2)
|
||||
--- assert(s == 'b')
|
||||
--- assert(args == 'B')
|
||||
--- end
|
||||
--- end)
|
||||
--- ```
|
||||
--- @async
|
||||
--- @overload fun(argc: integer, func: async.CallbackFn, ...:any): any ...
|
||||
--- @overload fun(task: async.Task): any ...
|
||||
--- @overload fun(taskfun: async.TaskFun): any ...
|
||||
function M.await(...)
|
||||
assert(running(), 'Not in async context')
|
||||
|
||||
local arg1 = select(1, ...)
|
||||
|
||||
if type(arg1) == 'number' then
|
||||
return await_cbfun(...)
|
||||
elseif getmetatable(arg1) == Task then
|
||||
return await_task(...)
|
||||
elseif getmetatable(arg1) == TaskFun then
|
||||
return await_taskfun(...)
|
||||
end
|
||||
|
||||
error('Invalid arguments, expected Task or (argc, func) got: ' .. type(arg1), 2)
|
||||
end
|
||||
|
||||
--- Creates an async function with a callback style function.
|
||||
---@generic F: function
|
||||
---@param func F
|
||||
---@param argc integer
|
||||
---@return F
|
||||
function M.wrap(func, argc)
|
||||
vim.validate('func', func, 'function')
|
||||
vim.validate('argc', argc, 'number')
|
||||
---@param ... unknown
|
||||
---@return unknown
|
||||
---
|
||||
--- Example:
|
||||
---
|
||||
--- ```lua
|
||||
--- --- Note the callback argument is not present in the return function
|
||||
--- --- @type fun(timeout: integer)
|
||||
--- local sleep = async.awrap(2, function(timeout, callback)
|
||||
--- local timer = vim.uv.new_timer()
|
||||
--- timer:start(timeout * 1000, 0, callback)
|
||||
--- -- uv_timer_t provides a close method so timer will be
|
||||
--- -- cleaned up when this function finishes
|
||||
--- return timer
|
||||
--- end)
|
||||
---
|
||||
--- async.arun(function()
|
||||
--- print('hello')
|
||||
--- sleep(2)
|
||||
--- print('world')
|
||||
--- end)
|
||||
--- ```
|
||||
---
|
||||
--- local atimer = async.awrap(
|
||||
--- @param argc integer
|
||||
--- @param func async.CallbackFn
|
||||
--- @return async function
|
||||
function M.awrap(argc, func)
|
||||
assert(type(argc) == 'number')
|
||||
assert(type(func) == 'function')
|
||||
--- @async
|
||||
return function(...)
|
||||
return co.yield(argc, func, ...)
|
||||
return M.await(argc, func, ...)
|
||||
end
|
||||
end
|
||||
|
||||
---Use this to create a function which executes in an async context but
|
||||
---called from a non-async context. Inherently this cannot return anything
|
||||
---since it is non-blocking
|
||||
---@generic F: function
|
||||
---@param func async F
|
||||
---@param nargs? integer
|
||||
---@return F
|
||||
function M.sync(func, nargs)
|
||||
nargs = nargs or 0
|
||||
return function(...)
|
||||
local callback = select(nargs + 1, ...)
|
||||
execute(func, callback, unpack({ ... }, 1, nargs))
|
||||
if vim.schedule then
|
||||
--- An async function that when called will yield to the Neovim scheduler to be
|
||||
--- able to call the API.
|
||||
M.schedule = M.awrap(1, vim.schedule)
|
||||
end
|
||||
|
||||
--- Create a function that runs a function when it is garbage collected.
|
||||
--- @generic F
|
||||
--- @param f F
|
||||
--- @param gc fun()
|
||||
--- @return F
|
||||
local function gc_fun(f, gc)
|
||||
local proxy = newproxy(true)
|
||||
local proxy_mt = getmetatable(proxy)
|
||||
proxy_mt.__gc = gc
|
||||
proxy_mt.__call = function(_, ...)
|
||||
return f(...)
|
||||
end
|
||||
|
||||
return proxy
|
||||
end
|
||||
|
||||
--- @param task_cbs table<async.Task,function>
|
||||
local function gc_cbs(task_cbs)
|
||||
for task, tcb in pairs(task_cbs) do
|
||||
for j, cb in pairs(task._callbacks) do
|
||||
if cb == tcb then
|
||||
task._callbacks[j] = nil
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
---@param n integer max number of concurrent jobs
|
||||
---@param interrupt_check? function
|
||||
---@param thunks function[]
|
||||
---@return any
|
||||
function M.join(n, interrupt_check, thunks)
|
||||
return co.yield(1, function(finish)
|
||||
if #thunks == 0 then
|
||||
return finish()
|
||||
--- @async
|
||||
--- Example:
|
||||
--- ```lua
|
||||
--- local task1 = async.arun(function()
|
||||
--- return 1, 'a'
|
||||
--- end)
|
||||
---
|
||||
--- local task2 = async.arun(function()
|
||||
--- return 1, 'a'
|
||||
--- end)
|
||||
---
|
||||
--- local task3 = async.arun(function()
|
||||
--- error('task3 error')
|
||||
--- end)
|
||||
---
|
||||
--- async.arun(function()
|
||||
--- for i, err, r1, r2 in async.iter({task1, task2, task3})
|
||||
--- print(i, err, r1, r2)
|
||||
--- end
|
||||
--- end)
|
||||
--- ```
|
||||
---
|
||||
--- Prints:
|
||||
--- ```
|
||||
--- 1 nil 1 'a'
|
||||
--- 2 nil 2 'b'
|
||||
--- 3 'task3 error' nil nil
|
||||
--- ```
|
||||
---
|
||||
--- @param tasks async.Task[]
|
||||
--- @return fun(): (integer?, any?, ...)
|
||||
function M.iter(tasks)
|
||||
assert(running(), 'Not in async context')
|
||||
|
||||
local results = {} --- @type [integer, any, ...][]
|
||||
|
||||
-- Iter blocks in an async context so only one waiter is needed
|
||||
local waiter = nil
|
||||
local task_cbs = {} --- @type table<async.Task,function>
|
||||
local remaining = #tasks
|
||||
|
||||
--- If can_gc_cbs is true, then the iterator function has been garbage
|
||||
--- collected and means any awaiters can also be garbage collected. The
|
||||
--- only time we can't do this is if with the special case when iter() is
|
||||
--- called anonymously (`local i = async.iter(tasks)()`), so we should not
|
||||
--- garbage collect the callbacks until at least one awaiter is called.
|
||||
local can_gc_cbs = false
|
||||
|
||||
for i, task in ipairs(tasks) do
|
||||
local function cb(err, ...)
|
||||
if can_gc_cbs == true then
|
||||
gc_cbs(task_cbs)
|
||||
end
|
||||
|
||||
local remaining = { select(n + 1, unpack(thunks)) }
|
||||
local to_go = #thunks
|
||||
local callback = waiter
|
||||
|
||||
local ret = {} ---@type any[]
|
||||
-- Clear waiter before calling it
|
||||
waiter = nil
|
||||
|
||||
local function cb(...)
|
||||
ret[#ret + 1] = { ... }
|
||||
to_go = to_go - 1
|
||||
if to_go == 0 then
|
||||
finish(ret)
|
||||
elseif not interrupt_check or not interrupt_check() then
|
||||
if #remaining > 0 then
|
||||
local next_task = table.remove(remaining)
|
||||
next_task(cb)
|
||||
end
|
||||
remaining = remaining - 1
|
||||
if callback then
|
||||
-- Iterator is waiting, yield to it
|
||||
callback(i, err, ...)
|
||||
else
|
||||
-- Task finished before Iterator was called. Store results.
|
||||
table.insert(results, pack_len(i, err, ...))
|
||||
end
|
||||
end
|
||||
|
||||
for i = 1, math.min(n, #thunks) do
|
||||
thunks[i](cb)
|
||||
task_cbs[task] = cb
|
||||
task:await(cb)
|
||||
end
|
||||
end, 1)
|
||||
|
||||
return gc_fun(
|
||||
M.awrap(1, function(callback)
|
||||
if next(results) then
|
||||
local res = table.remove(results, 1)
|
||||
callback(unpack_len(res))
|
||||
elseif remaining == 0 then
|
||||
callback() -- finish
|
||||
else
|
||||
assert(not waiter, 'internal error: waiter already set')
|
||||
waiter = callback
|
||||
end
|
||||
end),
|
||||
function()
|
||||
-- Don't gc callbacks just yet. Wait until at least one of them is called.
|
||||
can_gc_cbs = true
|
||||
end
|
||||
)
|
||||
end
|
||||
|
||||
---An async function that when called will yield to the Neovim scheduler to be
|
||||
---able to call the API.
|
||||
---@type fun()
|
||||
M.main = M.wrap(vim.schedule, 1)
|
||||
do -- join()
|
||||
--- @param results table<integer,table>
|
||||
--- @param i integer
|
||||
--- @param ... any
|
||||
--- @return boolean
|
||||
local function collect(results, i, ...)
|
||||
if i then
|
||||
results[i] = pack_len(...)
|
||||
end
|
||||
return i ~= nil
|
||||
end
|
||||
|
||||
--- @param iter fun(): ...
|
||||
--- @return table<integer,table>
|
||||
local function drain_iter(iter)
|
||||
local results = {} --- @type table<integer,table>
|
||||
while collect(results, iter()) do
|
||||
end
|
||||
return results
|
||||
end
|
||||
|
||||
--- @async
|
||||
--- Wait for all tasks to finish and return their results.
|
||||
---
|
||||
--- Example:
|
||||
--- ```lua
|
||||
--- local task1 = async.arun(function()
|
||||
--- return 1, 'a'
|
||||
--- end)
|
||||
---
|
||||
--- local task2 = async.arun(function()
|
||||
--- return 1, 'a'
|
||||
--- end)
|
||||
---
|
||||
--- local task3 = async.arun(function()
|
||||
--- error('task3 error')
|
||||
--- end)
|
||||
---
|
||||
--- async.arun(function()
|
||||
--- local results = async.join({task1, task2, task3})
|
||||
--- print(vim.inspect(results))
|
||||
--- end)
|
||||
--- ```
|
||||
---
|
||||
--- Prints:
|
||||
--- ```
|
||||
--- {
|
||||
--- [1] = { nil, 1, 'a' },
|
||||
--- [2] = { nil, 2, 'b' },
|
||||
--- [3] = { 'task2 error' },
|
||||
--- }
|
||||
--- ```
|
||||
--- @param tasks async.Task[]
|
||||
--- @return table<integer,[any?,...?]>
|
||||
function M.join(tasks)
|
||||
assert(running(), 'Not in async context')
|
||||
return drain_iter(M.iter(tasks))
|
||||
end
|
||||
|
||||
--- @async
|
||||
--- @param tasks async.Task[]
|
||||
--- @return integer?, any?, ...?
|
||||
function M.joinany(tasks)
|
||||
return M.iter(tasks)()
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -9,23 +9,53 @@ local parsers = require('nvim-treesitter.parsers')
|
||||
local util = require('nvim-treesitter.util')
|
||||
|
||||
---@type fun(path: string, new_path: string, flags?: table): string?
|
||||
local uv_copyfile = a.wrap(uv.fs_copyfile, 4)
|
||||
local uv_copyfile = a.awrap(4, uv.fs_copyfile)
|
||||
|
||||
---@type fun(path: string, mode: integer): string?
|
||||
local uv_mkdir = a.wrap(uv.fs_mkdir, 3)
|
||||
local uv_mkdir = a.awrap(3, uv.fs_mkdir)
|
||||
|
||||
---@type fun(path: string, new_path: string): string?
|
||||
local uv_rename = a.wrap(uv.fs_rename, 3)
|
||||
local uv_rename = a.awrap(3, uv.fs_rename)
|
||||
|
||||
---@type fun(path: string, new_path: string, flags?: table): string?
|
||||
local uv_symlink = a.wrap(uv.fs_symlink, 4)
|
||||
local uv_symlink = a.awrap(4, uv.fs_symlink)
|
||||
|
||||
---@type fun(path: string): string?
|
||||
local uv_unlink = a.wrap(uv.fs_unlink, 2)
|
||||
local uv_unlink = a.awrap(2, uv.fs_unlink)
|
||||
|
||||
local MAX_JOBS = 100
|
||||
local INSTALL_TIMEOUT = 60000
|
||||
|
||||
--- @async
|
||||
--- @param max_jobs integer
|
||||
--- @param task_funs async.TaskFun[]
|
||||
local function join(max_jobs, task_funs)
|
||||
if #task_funs == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
max_jobs = math.min(max_jobs, #task_funs)
|
||||
|
||||
local remaining = { select(max_jobs + 1, unpack(task_funs)) }
|
||||
local to_go = #task_funs
|
||||
|
||||
a.await(1, function(finish)
|
||||
local function cb()
|
||||
to_go = to_go - 1
|
||||
if to_go == 0 then
|
||||
finish()
|
||||
elseif #remaining > 0 then
|
||||
local next_task = table.remove(remaining)
|
||||
next_task():await(cb)
|
||||
end
|
||||
end
|
||||
|
||||
for i = 1, max_jobs do
|
||||
task_funs[i]():await(cb)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
---@async
|
||||
---@param cmd string[]
|
||||
---@param opts? vim.SystemOpts
|
||||
@@ -33,8 +63,8 @@ local INSTALL_TIMEOUT = 60000
|
||||
local function system(cmd, opts)
|
||||
local cwd = opts and opts.cwd or uv.cwd()
|
||||
log.trace('running job: (cwd=%s) %s', cwd, table.concat(cmd, ' '))
|
||||
local r = a.wrap(vim.system, 3)(cmd, opts) --[[@as vim.SystemCompleted]]
|
||||
a.main()
|
||||
local r = a.await(3, vim.system, cmd, opts) --[[@as vim.SystemCompleted]]
|
||||
a.schedule()
|
||||
if r.stdout and r.stdout ~= '' then
|
||||
log.trace('stdout -> %s', r.stdout)
|
||||
end
|
||||
@@ -190,7 +220,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
|
||||
do -- Create tmp dir
|
||||
logger:debug('Creating temporary directory: %s', tmp)
|
||||
local err = mkpath(tmp)
|
||||
a.main()
|
||||
a.schedule()
|
||||
if err then
|
||||
return logger:error('Could not create %s-tmp: %s', project_name, err)
|
||||
end
|
||||
@@ -211,7 +241,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
|
||||
do -- Remove tarball
|
||||
logger:debug('Removing %s...', tarball_path)
|
||||
local err = uv_unlink(tarball_path)
|
||||
a.main()
|
||||
a.schedule()
|
||||
if err then
|
||||
return logger:error('Could not remove tarball: %s', err)
|
||||
end
|
||||
@@ -223,7 +253,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
|
||||
local extracted = fs.joinpath(tmp, repo_project_name .. '-' .. dir_rev)
|
||||
logger:debug('Moving %s to %s/...', extracted, output_dir)
|
||||
local err = uv_rename(extracted, output_dir)
|
||||
a.main()
|
||||
a.schedule()
|
||||
if err then
|
||||
return logger:error('Could not rename temp: %s', err)
|
||||
end
|
||||
@@ -265,7 +295,7 @@ local function do_install(logger, compile_location, target_location)
|
||||
end
|
||||
|
||||
local err = uv_copyfile(compile_location, target_location)
|
||||
a.main()
|
||||
a.schedule()
|
||||
if err then
|
||||
return logger:error('Error during parser installation: %s', err)
|
||||
end
|
||||
@@ -343,7 +373,7 @@ local function try_install_lang(lang, cache_dir, install_dir, generate)
|
||||
local queries_src = M.get_package_path('runtime', 'queries', lang)
|
||||
uv_unlink(queries)
|
||||
local err = uv_symlink(queries_src, queries, { dir = true, junction = true })
|
||||
a.main()
|
||||
a.schedule()
|
||||
if err then
|
||||
return logger:error(err)
|
||||
end
|
||||
@@ -403,20 +433,20 @@ end
|
||||
---@field max_jobs? integer
|
||||
|
||||
--- Install a parser
|
||||
---@async
|
||||
---@param languages string[]
|
||||
---@param options? InstallOptions
|
||||
---@param callback? fun(boolean)
|
||||
local function install(languages, options, callback)
|
||||
local function install(languages, options)
|
||||
options = options or {}
|
||||
|
||||
local cache_dir = fs.normalize(fn.stdpath('cache'))
|
||||
local install_dir = config.get_install_dir('parser')
|
||||
|
||||
local tasks = {} ---@type fun()[]
|
||||
local task_funs = {} ---@type async.TaskFun[]
|
||||
local done = 0
|
||||
for _, lang in ipairs(languages) do
|
||||
tasks[#tasks + 1] = a.sync(function()
|
||||
a.main()
|
||||
task_funs[#task_funs + 1] = a.async(function()
|
||||
a.schedule()
|
||||
local status = install_lang(lang, cache_dir, install_dir, options.force, options.generate)
|
||||
if status ~= 'failed' then
|
||||
done = done + 1
|
||||
@@ -424,29 +454,24 @@ local function install(languages, options, callback)
|
||||
end)
|
||||
end
|
||||
|
||||
a.join(options and options.max_jobs or MAX_JOBS, nil, tasks)
|
||||
if #tasks > 1 then
|
||||
a.main()
|
||||
log.info('Installed %d/%d languages', done, #tasks)
|
||||
end
|
||||
if callback then
|
||||
callback(done == #tasks)
|
||||
join(options and options.max_jobs or MAX_JOBS, task_funs)
|
||||
if #task_funs > 1 then
|
||||
a.schedule()
|
||||
log.info('Installed %d/%d languages', done, #task_funs)
|
||||
end
|
||||
return done == #task_funs
|
||||
end
|
||||
|
||||
---@param languages string[]|string
|
||||
---@param options? InstallOptions
|
||||
---@param callback? fun(boolean)
|
||||
M.install = a.sync(function(languages, options, callback)
|
||||
M.install = a.async(function(languages, options)
|
||||
reload_parsers()
|
||||
languages = config.norm_languages(languages, { unsupported = true })
|
||||
install(languages, options, callback)
|
||||
end, 3)
|
||||
return install(languages, options)
|
||||
end)
|
||||
|
||||
---@param languages? string[]|string
|
||||
---@param _options? table
|
||||
---@param callback? function
|
||||
M.update = a.sync(function(languages, _options, callback)
|
||||
M.update = a.async(function(languages)
|
||||
reload_parsers()
|
||||
if not languages or #languages == 0 then
|
||||
languages = 'all'
|
||||
@@ -455,14 +480,12 @@ M.update = a.sync(function(languages, _options, callback)
|
||||
languages = vim.tbl_filter(needs_update, languages) ---@type string[]
|
||||
|
||||
if #languages > 0 then
|
||||
install(languages, { force = true }, callback)
|
||||
return install(languages, { force = true })
|
||||
else
|
||||
log.info('All parsers are up-to-date')
|
||||
if callback then
|
||||
callback(true)
|
||||
return true
|
||||
end
|
||||
end
|
||||
end, 3)
|
||||
end)
|
||||
|
||||
---@async
|
||||
---@param logger Logger
|
||||
@@ -477,7 +500,7 @@ local function uninstall_lang(logger, lang, parser, queries)
|
||||
if fn.filereadable(parser) == 1 then
|
||||
logger:debug('Unlinking ' .. parser)
|
||||
local perr = uv_unlink(parser)
|
||||
a.main()
|
||||
a.schedule()
|
||||
|
||||
if perr then
|
||||
return logger:error(perr)
|
||||
@@ -487,7 +510,7 @@ local function uninstall_lang(logger, lang, parser, queries)
|
||||
if fn.isdirectory(queries) == 1 then
|
||||
logger:debug('Unlinking ' .. queries)
|
||||
local qerr = uv_unlink(queries)
|
||||
a.main()
|
||||
a.schedule()
|
||||
|
||||
if qerr then
|
||||
return logger:error(qerr)
|
||||
@@ -498,16 +521,14 @@ local function uninstall_lang(logger, lang, parser, queries)
|
||||
end
|
||||
|
||||
---@param languages string[]|string
|
||||
---@param _options? table
|
||||
---@param _callback? fun()
|
||||
M.uninstall = a.sync(function(languages, _options, _callback)
|
||||
M.uninstall = a.async(function(languages)
|
||||
languages = config.norm_languages(languages or 'all', { missing = true, dependencies = true })
|
||||
|
||||
local parser_dir = config.get_install_dir('parser')
|
||||
local query_dir = config.get_install_dir('queries')
|
||||
local installed = config.installed_parsers()
|
||||
|
||||
local tasks = {} ---@type fun()[]
|
||||
local task_funs = {} ---@type async.TaskFun[]
|
||||
local done = 0
|
||||
for _, lang in ipairs(languages) do
|
||||
local logger = log.new('uninstall/' .. lang)
|
||||
@@ -516,7 +537,7 @@ M.uninstall = a.sync(function(languages, _options, _callback)
|
||||
else
|
||||
local parser = fs.joinpath(parser_dir, lang) .. '.so'
|
||||
local queries = fs.joinpath(query_dir, lang)
|
||||
tasks[#tasks + 1] = a.sync(function()
|
||||
task_funs[#task_funs + 1] = a.async(function()
|
||||
local err = uninstall_lang(logger, lang, parser, queries)
|
||||
if not err then
|
||||
done = done + 1
|
||||
@@ -525,11 +546,11 @@ M.uninstall = a.sync(function(languages, _options, _callback)
|
||||
end
|
||||
end
|
||||
|
||||
a.join(MAX_JOBS, nil, tasks)
|
||||
if #tasks > 1 then
|
||||
a.main()
|
||||
log.info('Uninstalled %d/%d languages', done, #tasks)
|
||||
join(MAX_JOBS, task_funs)
|
||||
if #task_funs > 1 then
|
||||
a.schedule()
|
||||
log.info('Uninstalled %d/%d languages', done, #task_funs)
|
||||
end
|
||||
end, 2)
|
||||
end)
|
||||
|
||||
return M
|
||||
|
||||
@@ -21,24 +21,17 @@ vim.opt.runtimepath:append('.')
|
||||
-- needed on CI
|
||||
vim.fn.mkdir(vim.fn.stdpath('cache'), 'p')
|
||||
|
||||
local ok = nil
|
||||
if update then
|
||||
require('nvim-treesitter.install').update('all', {}, function(success)
|
||||
ok = success
|
||||
end)
|
||||
else
|
||||
require('nvim-treesitter.install').install(
|
||||
---@type async.Task
|
||||
local task = update and require('nvim-treesitter.install').update('all')
|
||||
or require('nvim-treesitter.install').install(
|
||||
#parsers > 0 and parsers or 'all',
|
||||
{ force = true, generate = generate, max_jobs = max_jobs },
|
||||
function(success)
|
||||
ok = success
|
||||
end
|
||||
{ force = true, generate = generate, max_jobs = max_jobs }
|
||||
)
|
||||
end
|
||||
|
||||
vim.wait(6000000, function()
|
||||
return ok ~= nil
|
||||
end)
|
||||
local ok, err_or_ok = task:pwait(1800000) -- wait max. 30 minutes
|
||||
if not ok then
|
||||
print('ERROR: ', err_or_ok)
|
||||
vim.cmd.cq()
|
||||
elseif not err_or_ok then
|
||||
vim.cmd.cq()
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user