feat(install)!: migrate to latest async.nvim impl (#7856)

Provides significantly simpler blocking installation and update.
This commit is contained in:
Lewis Russell
2025-05-16 15:44:26 +01:00
committed by Christian Clason
parent 7a4a35de3e
commit 69371f0148
7 changed files with 788 additions and 175 deletions

7
.gitattributes vendored
View File

@@ -1,3 +1,4 @@
runtime/queries/**/*.scm linguist-language=Tree-sitter-Query
doc/*.txt linguist-documentation
SUPPORTED_LANGUAGES.md linguist-generated
runtime/queries/**/*.scm linguist-language=tsq
doc/*.txt linguist-documentation
SUPPORTED_LANGUAGES.md linguist-generated
lua/nvim-treesitter/async.lua linguist-vendored

View File

@@ -68,13 +68,9 @@ Parsers and queries can then be installed with
require'nvim-treesitter'.install { 'rust', 'javascript', 'zig' }
```
(This is a no-op if the parsers are already installed.) Note that this function runs asynchronously; for synchronous installation in a script context ("bootstrapping"), use something like
(This is a no-op if the parsers are already installed.) Note that this function runs asynchronously; for synchronous installation in a script context ("bootstrapping"), you need to `wait()` for it to finish:
```lua
local done = nil
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' }, function(success)
done = success
end)
vim.wait(3000000, function() return done ~= nil end)
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' }):wait(300000) -- wait max. 5 minutes
```
Check [`:h nvim-treesitter-commands`](doc/nvim-treesitter.txt) for a list of all available commands.

View File

@@ -4,7 +4,6 @@ This document lists the planned and finished changes in this rewrite towards [Nv
## TODO
- [ ] **`install.lua`:** migrate to async v2
- [ ] **tests:** remove custom crate, plenary dependency
- [ ] **indents:** rewrite (Helix or Zed compatible)
- [ ] **textobjects:** include simple(!) `node`, `scope` (using `locals`) objects
@@ -26,3 +25,4 @@ This document lists the planned and finished changes in this rewrite towards [Nv
- [X] drop ensure_install (replace with install)
- [X] **CI:** switch to ts_query_ls, add update readme as check (remove update job)
- [X] **CI:** track versioned releases for tier 1
- [X] **`install.lua`:** migrate to async v2

View File

@@ -102,7 +102,7 @@ setup({opts}) *nvim-treesitter.setup()*
directory to install parsers and queries to. Note: will be
appended to |runtimepath|.
install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
install({languages} [, {opts}]) *nvim-treesitter.install()*
Download, compile, and install the specified treesitter parsers and copy
the corresponding queries to a directory on |runtimepath|, enabling their
@@ -110,13 +110,9 @@ install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
Note: This operation is performed asynchronously by default. For
synchronous operation (e.g., in a bootstrapping script), you need to
provide a suitable {callback}: >lua
local done = nil
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' },
function(success)
done = success
end)
vim.wait(3000000, function() return done ~= nil end)
`wait()` for it: >lua
require('nvim-treesitter').install({ 'rust', 'javascript', 'zig' })
:wait(300000) -- max. 5 minutes
<
Parameters: ~
• {languages} `(string[]|string)` (List of) languages or tiers (`stable`,
@@ -129,7 +125,6 @@ install({languages}, {opts}, {callback}) *nvim-treesitter.install()*
compiling.
• {max_jobs} (`integer?`) limit parallel tasks (useful in
combination with {generate} on memory-limited systems).
• {callback} `(function?`) Callback for synchronous execution.
uninstall({languages}) *nvim-treesitter.uninstall()*
@@ -139,25 +134,19 @@ uninstall({languages}) *nvim-treesitter.uninstall()
• {languages} `(string[]|string)` (List of) languages or tiers (`stable`,
`unstable`) to update.
update({languages}, {callback}) *nvim-treesitter.update()*
update([{languages}]) *nvim-treesitter.update()*
Update the parsers and queries if older than the revision specified in the
manifest.
Note: This operation is performed asynchronously by default. For
synchronous operation (e.g., in a bootstrapping script), you need to
provide a suitable {callback}: >lua
local done = nil
require('nvim-treesitter').update(),
function(success)
done = success
end)
vim.wait(3000000, function() return done ~= nil end)
`wait()` for it: >lua
require('nvim-treesitter').update():wait(300000) -- max. 5 minutes
<
Parameters: ~
• {languages} `(string[]|string)` (List of) languages or tiers to
uninstall.
• {callback} `(function?`) Callback for synchronous execution.
• {languages} `(string[]|string)?` (List of) languages or tiers to update
(default: all installed).
indentexpr() *nvim-treesitter.indentexpr()*

View File

@@ -1,112 +1,725 @@
local co = coroutine
local pcall = copcall or pcall
--- @param ... any
--- @return {[integer]: any, n: integer}
local function pack_len(...)
return { n = select('#', ...), ... }
end
--- like unpack() but use the length set by F.pack_len if present
--- @param t? { [integer]: any, n?: integer }
--- @param first? integer
--- @return ...any
local function unpack_len(t, first)
if t then
return unpack(t, first or 1, t.n or table.maxn(t))
end
end
--- @class async
local M = {}
---Executes a future with a callback when it is done
---@param func function
---@param callback function
---@param ... unknown
local function execute(func, callback, ...)
local thread = co.create(func)
--- Weak table to keep track of running tasks
--- @type table<thread,async.Task?>
local threads = setmetatable({}, { __mode = 'k' })
local function step(...)
local ret = { co.resume(thread, ...) }
---@type boolean, any
local stat, nargs_or_err = unpack(ret)
--- @return async.Task?
local function running()
local task = threads[coroutine.running()]
if task and not (task:_completed() or task._closing) then
return task
end
end
if not stat then
error(
string.format(
'The coroutine failed with this message: %s\n%s',
nargs_or_err,
debug.traceback(thread)
)
)
end
--- Base class for async tasks. Async functions should return a subclass of
--- this. This is designed specifically to be a base class of uv_handle_t
--- @class async.Handle
--- @field close fun(self: async.Handle, callback?: fun())
--- @field is_closing? fun(self: async.Handle): boolean
if co.status(thread) == 'dead' then
if callback then
callback(unpack(ret, 3, table.maxn(ret)))
end
return
end
--- @alias async.CallbackFn fun(...: any): async.Handle?
---@type function, any[]
local fn, args = ret[3], { unpack(ret, 4, table.maxn(ret)) }
args[nargs_or_err] = step
fn(unpack(args, 1, nargs_or_err))
--- @class async.Task : async.Handle
--- @field package _callbacks table<integer,fun(err?: any, ...: any)>
--- @field package _callback_pos integer
--- @field private _thread thread
---
--- Tasks can call other async functions (task of callback functions)
--- when we are waiting on a child, we store the handle to it here so we can
--- cancel it.
--- @field private _current_child? async.Handle
---
--- Error result of the task is an error occurs.
--- Must use `await` to get the result.
--- @field private _err? any
---
--- Result of the task.
--- Must use `await` to get the result.
--- @field private _result? any[]
local Task = {}
Task.__index = Task
--- @private
--- @param func function
--- @return async.Task
function Task._new(func)
local thread = coroutine.create(func)
local self = setmetatable({
_closing = false,
_thread = thread,
_callbacks = {},
_callback_pos = 1,
}, Task)
threads[thread] = self
return self
end
--- @param callback fun(err?: any, ...: any)
function Task:await(callback)
if self._closing then
callback('closing')
elseif self:_completed() then -- TODO(lewis6991): test
-- Already finished or closed
callback(self._err, unpack_len(self._result))
else
self._callbacks[self._callback_pos] = callback
self._callback_pos = self._callback_pos + 1
end
end
--- @package
function Task:_completed()
return (self._err or self._result) ~= nil
end
-- Use max 32-bit signed int value to avoid overflow on 32-bit systems.
-- Do not use `math.huge` as it is not interpreted as a positive integer on all
-- platforms.
local MAX_TIMEOUT = 2 ^ 31 - 1
--- Synchronously wait (protected) for a task to finish (blocking)
---
--- If an error is returned, `Task:traceback()` can be used to get the
--- stack trace of the error.
---
--- Example:
--- ```lua
---
--- local ok, err_or_result = task:pwait(10)
---
--- if not ok then
--- error(task:traceback(err_or_result))
--- end
---
--- local _, result = assert(task:pwait(10))
--- ```
---
--- Can be called if a task is closing.
--- @param timeout? integer
--- @return boolean status
--- @return any ... result or error
function Task:pwait(timeout)
local done = vim.wait(timeout or MAX_TIMEOUT, function()
-- Note we use self:_completed() instead of self:await() to avoid creating a
-- callback. This avoids having to cleanup/unregister any callback in the
-- case of a timeout.
return self:_completed()
end)
if not done then
return false, 'timeout'
elseif self._err then
return false, self._err
else
return true, unpack_len(self._result)
end
end
--- Synchronously wait for a task to finish (blocking)
---
--- Example:
--- ```lua
--- local result = task:wait(10) -- wait for 10ms or else error
---
--- local result = task:wait() -- wait indefinitely
--- ```
--- @param timeout? integer Timeout in milliseconds
--- @return any ... result
function Task:wait(timeout)
local res = pack_len(self:pwait(timeout))
local stat = res[1]
if not stat then
error(self:traceback(res[2]))
end
step(...)
return unpack_len(res, 2)
end
--- @private
--- @param msg? string
--- @param _lvl? integer
--- @return string
function Task:_traceback(msg, _lvl)
_lvl = _lvl or 0
local thread = ('[%s] '):format(self._thread)
local child = self._current_child
if getmetatable(child) == Task then
--- @cast child async.Task
msg = child:_traceback(msg, _lvl + 1)
end
local tblvl = getmetatable(child) == Task and 2 or nil
msg = (msg or '') .. debug.traceback(self._thread, '', tblvl):gsub('\n\t', '\n\t' .. thread)
if _lvl == 0 then
--- @type string
msg = msg
:gsub('\nstack traceback:\n', '\nSTACK TRACEBACK:\n', 1)
:gsub('\nstack traceback:\n', '\n')
:gsub('\nSTACK TRACEBACK:\n', '\nstack traceback:\n', 1)
end
return msg
end
--- Get the traceback of a task when it is not active.
--- Will also get the traceback of nested tasks.
---
--- @param msg? string
--- @return string
function Task:traceback(msg)
return self:_traceback(msg)
end
--- If a task completes with an error, raise the error
function Task:raise_on_error()
self:await(function(err)
if err then
error(self:_traceback(err), 0)
end
end)
return self
end
--- @private
--- @param err? any
--- @param result? {[integer]: any, n: integer}
function Task:_finish(err, result)
self._current_child = nil
self._err = err
self._result = result
threads[self._thread] = nil
local errs = {} --- @type string[]
for _, cb in pairs(self._callbacks) do
--- @type boolean, string
local ok, cb_err = pcall(cb, err, unpack_len(result))
if not ok then
errs[#errs + 1] = cb_err
end
end
if #errs > 0 then
error(table.concat(errs, '\n'), 0)
end
end
--- @return boolean
function Task:is_closing()
return self._closing
end
--- Close the task and all its children.
--- If callback is provided it will run asynchronously,
--- else it will run synchronously.
---
--- @param callback? fun()
function Task:close(callback)
if self:_completed() then
if callback then
callback()
end
return
end
if self._closing then
return
end
self._closing = true
if callback then -- async
if self._current_child then
self._current_child:close(function()
self:_finish('closed')
callback()
end)
else
self:_finish('closed')
callback()
end
else -- sync
if self._current_child then
self._current_child:close(function()
self:_finish('closed')
end)
else
self:_finish('closed')
end
vim.wait(0, function()
return self:_completed()
end)
end
end
--- @param obj any
--- @return boolean
local function is_async_handle(obj)
local ty = type(obj)
return (ty == 'table' or ty == 'userdata') and vim.is_callable(obj.close)
end
--- @param ... any
function Task:_resume(...)
--- @type [boolean, string|async.CallbackFn]
local ret = pack_len(coroutine.resume(self._thread, ...))
local stat = ret[1]
if not stat then
-- Coroutine had error
self:_finish(ret[2])
elseif coroutine.status(self._thread) == 'dead' then
-- Coroutine finished
local result = pack_len(unpack_len(ret, 2))
self:_finish(nil, result)
else
local fn = ret[2]
--- @cast fn -string
-- TODO(lewis6991): refine error handler to be more specific
local ok, r
ok, r = pcall(fn, function(...)
if is_async_handle(r) then
--- @cast r async.Handle
-- We must close children before we resume to ensure
-- all resources are collected.
local args = pack_len(...)
r:close(function()
self:_resume(unpack_len(args))
end)
else
self:_resume(...)
end
end)
if not ok then
self:_finish(r)
elseif is_async_handle(r) then
self._current_child = r
end
end
end
--- @return 'running'|'suspended'|'normal'|'dead'?
function Task:status()
return coroutine.status(self._thread)
end
--- Run a function in an async context, asynchronously.
---
--- Examples:
--- ```lua
--- -- The two below blocks are equivalent:
---
--- -- Run a uv function and wait for it
--- local stat = async.arun(function()
--- return async.await(2, vim.uv.fs_stat, 'foo.txt')
--- end):wait()
---
--- -- Since uv functions have sync versions. You can just do:
--- local stat = vim.fs_stat('foo.txt')
--- ```
--- @param func function
--- @param ... any
--- @return async.Task
function M.arun(func, ...)
local task = Task._new(func)
task:_resume(...)
return task
end
--- @class async.TaskFun
--- @field package _fun fun(...: any): any
--- @operator call(...): any
local TaskFun = {}
TaskFun.__index = TaskFun
function TaskFun:__call(...)
return M.arun(self._fun, ...)
end
--- Create an async function
--- @param fun function
--- @return async.TaskFun
function M.async(fun)
return setmetatable({ _fun = fun }, TaskFun)
end
--- Returns the status of a tasks thread.
---
--- @param task? async.Task
--- @return 'running'|'suspended'|'normal'|'dead'?
function M.status(task)
task = task or running()
if task then
assert(getmetatable(task) == Task, 'Expected Task')
return task:status()
end
end
--- @async
--- @generic R1, R2, R3, R4
--- @param fun fun(callback: fun(r1: R1, r2: R2, r3: R3, r4: R4)): any?
--- @return R1, R2, R3, R4
local function yield(fun)
assert(type(fun) == 'function', 'Expected function')
return coroutine.yield(fun)
end
--- @async
--- @param task async.Task
--- @return any ...
local function await_task(task)
--- @param callback fun(err?: string, ...: any)
--- @return function
local res = pack_len(yield(function(callback)
task:await(callback)
return task
end))
local err = res[1]
if err then
-- TODO(lewis6991): what is the correct level to pass?
error(err, 0)
end
return unpack_len(res, 2)
end
--- Asynchronous blocking wait
--- @param argc integer
--- @param fun async.CallbackFn
--- @param ... any func arguments
--- @return any ...
local function await_cbfun(argc, fun, ...)
local args = pack_len(...)
--- @param callback fun(...:any)
--- @return any?
return yield(function(callback)
args[argc] = callback
args.n = math.max(args.n, argc)
return fun(unpack_len(args))
end)
end
--- @param taskfun async.TaskFun
--- @param ... any
--- @return any ...
local function await_taskfun(taskfun, ...)
return taskfun._fun(...)
end
--- Asynchronous blocking wait
---
--- Example:
--- ```lua
--- local task = async.arun(function()
--- return 1, 'a'
--- end)
---
--- local task_fun = async.async(function(arg)
--- return 2, 'b', arg
--- end)
---
--- async.arun(function()
--- do -- await a callback function
--- async.await(1, vim.schedule)
--- end
---
--- do -- await a task (new async context)
--- local n, s = async.await(task)
--- assert(n == 1 and s == 'a')
--- end
---
--- do -- await a started task function (new async context)
--- local n, s, arg = async.await(task_fun('A'))
--- assert(n == 2)
--- assert(s == 'b')
--- assert(args == 'A')
--- end
---
--- do -- await a task function (re-using the current async context)
--- local n, s, arg = async.await(task_fun, 'B')
--- assert(n == 2)
--- assert(s == 'b')
--- assert(args == 'B')
--- end
--- end)
--- ```
--- @async
--- @overload fun(argc: integer, func: async.CallbackFn, ...:any): any ...
--- @overload fun(task: async.Task): any ...
--- @overload fun(taskfun: async.TaskFun): any ...
function M.await(...)
assert(running(), 'Not in async context')
local arg1 = select(1, ...)
if type(arg1) == 'number' then
return await_cbfun(...)
elseif getmetatable(arg1) == Task then
return await_task(...)
elseif getmetatable(arg1) == TaskFun then
return await_taskfun(...)
end
error('Invalid arguments, expected Task or (argc, func) got: ' .. type(arg1), 2)
end
--- Creates an async function with a callback style function.
---@generic F: function
---@param func F
---@param argc integer
---@return F
function M.wrap(func, argc)
vim.validate('func', func, 'function')
vim.validate('argc', argc, 'number')
---@param ... unknown
---@return unknown
---
--- Example:
---
--- ```lua
--- --- Note the callback argument is not present in the return function
--- --- @type fun(timeout: integer)
--- local sleep = async.awrap(2, function(timeout, callback)
--- local timer = vim.uv.new_timer()
--- timer:start(timeout * 1000, 0, callback)
--- -- uv_timer_t provides a close method so timer will be
--- -- cleaned up when this function finishes
--- return timer
--- end)
---
--- async.arun(function()
--- print('hello')
--- sleep(2)
--- print('world')
--- end)
--- ```
---
--- local atimer = async.awrap(
--- @param argc integer
--- @param func async.CallbackFn
--- @return async function
function M.awrap(argc, func)
assert(type(argc) == 'number')
assert(type(func) == 'function')
--- @async
return function(...)
return co.yield(argc, func, ...)
return M.await(argc, func, ...)
end
end
---Use this to create a function which executes in an async context but
---called from a non-async context. Inherently this cannot return anything
---since it is non-blocking
---@generic F: function
---@param func async F
---@param nargs? integer
---@return F
function M.sync(func, nargs)
nargs = nargs or 0
return function(...)
local callback = select(nargs + 1, ...)
execute(func, callback, unpack({ ... }, 1, nargs))
end
if vim.schedule then
--- An async function that when called will yield to the Neovim scheduler to be
--- able to call the API.
M.schedule = M.awrap(1, vim.schedule)
end
---@param n integer max number of concurrent jobs
---@param interrupt_check? function
---@param thunks function[]
---@return any
function M.join(n, interrupt_check, thunks)
return co.yield(1, function(finish)
if #thunks == 0 then
return finish()
--- Create a function that runs a function when it is garbage collected.
--- @generic F
--- @param f F
--- @param gc fun()
--- @return F
local function gc_fun(f, gc)
local proxy = newproxy(true)
local proxy_mt = getmetatable(proxy)
proxy_mt.__gc = gc
proxy_mt.__call = function(_, ...)
return f(...)
end
return proxy
end
--- @param task_cbs table<async.Task,function>
local function gc_cbs(task_cbs)
for task, tcb in pairs(task_cbs) do
for j, cb in pairs(task._callbacks) do
if cb == tcb then
task._callbacks[j] = nil
break
end
end
end
end
local remaining = { select(n + 1, unpack(thunks)) }
local to_go = #thunks
--- @async
--- Example:
--- ```lua
--- local task1 = async.arun(function()
--- return 1, 'a'
--- end)
---
--- local task2 = async.arun(function()
--- return 1, 'a'
--- end)
---
--- local task3 = async.arun(function()
--- error('task3 error')
--- end)
---
--- async.arun(function()
--- for i, err, r1, r2 in async.iter({task1, task2, task3})
--- print(i, err, r1, r2)
--- end
--- end)
--- ```
---
--- Prints:
--- ```
--- 1 nil 1 'a'
--- 2 nil 2 'b'
--- 3 'task3 error' nil nil
--- ```
---
--- @param tasks async.Task[]
--- @return fun(): (integer?, any?, ...)
function M.iter(tasks)
assert(running(), 'Not in async context')
local ret = {} ---@type any[]
local results = {} --- @type [integer, any, ...][]
local function cb(...)
ret[#ret + 1] = { ... }
to_go = to_go - 1
if to_go == 0 then
finish(ret)
elseif not interrupt_check or not interrupt_check() then
if #remaining > 0 then
local next_task = table.remove(remaining)
next_task(cb)
end
-- Iter blocks in an async context so only one waiter is needed
local waiter = nil
local task_cbs = {} --- @type table<async.Task,function>
local remaining = #tasks
--- If can_gc_cbs is true, then the iterator function has been garbage
--- collected and means any awaiters can also be garbage collected. The
--- only time we can't do this is if with the special case when iter() is
--- called anonymously (`local i = async.iter(tasks)()`), so we should not
--- garbage collect the callbacks until at least one awaiter is called.
local can_gc_cbs = false
for i, task in ipairs(tasks) do
local function cb(err, ...)
if can_gc_cbs == true then
gc_cbs(task_cbs)
end
local callback = waiter
-- Clear waiter before calling it
waiter = nil
remaining = remaining - 1
if callback then
-- Iterator is waiting, yield to it
callback(i, err, ...)
else
-- Task finished before Iterator was called. Store results.
table.insert(results, pack_len(i, err, ...))
end
end
for i = 1, math.min(n, #thunks) do
thunks[i](cb)
task_cbs[task] = cb
task:await(cb)
end
return gc_fun(
M.awrap(1, function(callback)
if next(results) then
local res = table.remove(results, 1)
callback(unpack_len(res))
elseif remaining == 0 then
callback() -- finish
else
assert(not waiter, 'internal error: waiter already set')
waiter = callback
end
end),
function()
-- Don't gc callbacks just yet. Wait until at least one of them is called.
can_gc_cbs = true
end
end, 1)
)
end
---An async function that when called will yield to the Neovim scheduler to be
---able to call the API.
---@type fun()
M.main = M.wrap(vim.schedule, 1)
do -- join()
--- @param results table<integer,table>
--- @param i integer
--- @param ... any
--- @return boolean
local function collect(results, i, ...)
if i then
results[i] = pack_len(...)
end
return i ~= nil
end
--- @param iter fun(): ...
--- @return table<integer,table>
local function drain_iter(iter)
local results = {} --- @type table<integer,table>
while collect(results, iter()) do
end
return results
end
--- @async
--- Wait for all tasks to finish and return their results.
---
--- Example:
--- ```lua
--- local task1 = async.arun(function()
--- return 1, 'a'
--- end)
---
--- local task2 = async.arun(function()
--- return 1, 'a'
--- end)
---
--- local task3 = async.arun(function()
--- error('task3 error')
--- end)
---
--- async.arun(function()
--- local results = async.join({task1, task2, task3})
--- print(vim.inspect(results))
--- end)
--- ```
---
--- Prints:
--- ```
--- {
--- [1] = { nil, 1, 'a' },
--- [2] = { nil, 2, 'b' },
--- [3] = { 'task2 error' },
--- }
--- ```
--- @param tasks async.Task[]
--- @return table<integer,[any?,...?]>
function M.join(tasks)
assert(running(), 'Not in async context')
return drain_iter(M.iter(tasks))
end
--- @async
--- @param tasks async.Task[]
--- @return integer?, any?, ...?
function M.joinany(tasks)
return M.iter(tasks)()
end
end
return M

View File

@@ -9,23 +9,53 @@ local parsers = require('nvim-treesitter.parsers')
local util = require('nvim-treesitter.util')
---@type fun(path: string, new_path: string, flags?: table): string?
local uv_copyfile = a.wrap(uv.fs_copyfile, 4)
local uv_copyfile = a.awrap(4, uv.fs_copyfile)
---@type fun(path: string, mode: integer): string?
local uv_mkdir = a.wrap(uv.fs_mkdir, 3)
local uv_mkdir = a.awrap(3, uv.fs_mkdir)
---@type fun(path: string, new_path: string): string?
local uv_rename = a.wrap(uv.fs_rename, 3)
local uv_rename = a.awrap(3, uv.fs_rename)
---@type fun(path: string, new_path: string, flags?: table): string?
local uv_symlink = a.wrap(uv.fs_symlink, 4)
local uv_symlink = a.awrap(4, uv.fs_symlink)
---@type fun(path: string): string?
local uv_unlink = a.wrap(uv.fs_unlink, 2)
local uv_unlink = a.awrap(2, uv.fs_unlink)
local MAX_JOBS = 100
local INSTALL_TIMEOUT = 60000
--- @async
--- @param max_jobs integer
--- @param task_funs async.TaskFun[]
local function join(max_jobs, task_funs)
if #task_funs == 0 then
return
end
max_jobs = math.min(max_jobs, #task_funs)
local remaining = { select(max_jobs + 1, unpack(task_funs)) }
local to_go = #task_funs
a.await(1, function(finish)
local function cb()
to_go = to_go - 1
if to_go == 0 then
finish()
elseif #remaining > 0 then
local next_task = table.remove(remaining)
next_task():await(cb)
end
end
for i = 1, max_jobs do
task_funs[i]():await(cb)
end
end)
end
---@async
---@param cmd string[]
---@param opts? vim.SystemOpts
@@ -33,8 +63,8 @@ local INSTALL_TIMEOUT = 60000
local function system(cmd, opts)
local cwd = opts and opts.cwd or uv.cwd()
log.trace('running job: (cwd=%s) %s', cwd, table.concat(cmd, ' '))
local r = a.wrap(vim.system, 3)(cmd, opts) --[[@as vim.SystemCompleted]]
a.main()
local r = a.await(3, vim.system, cmd, opts) --[[@as vim.SystemCompleted]]
a.schedule()
if r.stdout and r.stdout ~= '' then
log.trace('stdout -> %s', r.stdout)
end
@@ -190,7 +220,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
do -- Create tmp dir
logger:debug('Creating temporary directory: %s', tmp)
local err = mkpath(tmp)
a.main()
a.schedule()
if err then
return logger:error('Could not create %s-tmp: %s', project_name, err)
end
@@ -211,7 +241,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
do -- Remove tarball
logger:debug('Removing %s...', tarball_path)
local err = uv_unlink(tarball_path)
a.main()
a.schedule()
if err then
return logger:error('Could not remove tarball: %s', err)
end
@@ -223,7 +253,7 @@ local function do_download(logger, url, project_name, cache_dir, revision, outpu
local extracted = fs.joinpath(tmp, repo_project_name .. '-' .. dir_rev)
logger:debug('Moving %s to %s/...', extracted, output_dir)
local err = uv_rename(extracted, output_dir)
a.main()
a.schedule()
if err then
return logger:error('Could not rename temp: %s', err)
end
@@ -265,7 +295,7 @@ local function do_install(logger, compile_location, target_location)
end
local err = uv_copyfile(compile_location, target_location)
a.main()
a.schedule()
if err then
return logger:error('Error during parser installation: %s', err)
end
@@ -343,7 +373,7 @@ local function try_install_lang(lang, cache_dir, install_dir, generate)
local queries_src = M.get_package_path('runtime', 'queries', lang)
uv_unlink(queries)
local err = uv_symlink(queries_src, queries, { dir = true, junction = true })
a.main()
a.schedule()
if err then
return logger:error(err)
end
@@ -403,20 +433,20 @@ end
---@field max_jobs? integer
--- Install a parser
---@async
---@param languages string[]
---@param options? InstallOptions
---@param callback? fun(boolean)
local function install(languages, options, callback)
local function install(languages, options)
options = options or {}
local cache_dir = fs.normalize(fn.stdpath('cache'))
local install_dir = config.get_install_dir('parser')
local tasks = {} ---@type fun()[]
local task_funs = {} ---@type async.TaskFun[]
local done = 0
for _, lang in ipairs(languages) do
tasks[#tasks + 1] = a.sync(function()
a.main()
task_funs[#task_funs + 1] = a.async(function()
a.schedule()
local status = install_lang(lang, cache_dir, install_dir, options.force, options.generate)
if status ~= 'failed' then
done = done + 1
@@ -424,29 +454,24 @@ local function install(languages, options, callback)
end)
end
a.join(options and options.max_jobs or MAX_JOBS, nil, tasks)
if #tasks > 1 then
a.main()
log.info('Installed %d/%d languages', done, #tasks)
end
if callback then
callback(done == #tasks)
join(options and options.max_jobs or MAX_JOBS, task_funs)
if #task_funs > 1 then
a.schedule()
log.info('Installed %d/%d languages', done, #task_funs)
end
return done == #task_funs
end
---@param languages string[]|string
---@param options? InstallOptions
---@param callback? fun(boolean)
M.install = a.sync(function(languages, options, callback)
M.install = a.async(function(languages, options)
reload_parsers()
languages = config.norm_languages(languages, { unsupported = true })
install(languages, options, callback)
end, 3)
return install(languages, options)
end)
---@param languages? string[]|string
---@param _options? table
---@param callback? function
M.update = a.sync(function(languages, _options, callback)
M.update = a.async(function(languages)
reload_parsers()
if not languages or #languages == 0 then
languages = 'all'
@@ -455,14 +480,12 @@ M.update = a.sync(function(languages, _options, callback)
languages = vim.tbl_filter(needs_update, languages) ---@type string[]
if #languages > 0 then
install(languages, { force = true }, callback)
return install(languages, { force = true })
else
log.info('All parsers are up-to-date')
if callback then
callback(true)
end
return true
end
end, 3)
end)
---@async
---@param logger Logger
@@ -477,7 +500,7 @@ local function uninstall_lang(logger, lang, parser, queries)
if fn.filereadable(parser) == 1 then
logger:debug('Unlinking ' .. parser)
local perr = uv_unlink(parser)
a.main()
a.schedule()
if perr then
return logger:error(perr)
@@ -487,7 +510,7 @@ local function uninstall_lang(logger, lang, parser, queries)
if fn.isdirectory(queries) == 1 then
logger:debug('Unlinking ' .. queries)
local qerr = uv_unlink(queries)
a.main()
a.schedule()
if qerr then
return logger:error(qerr)
@@ -498,16 +521,14 @@ local function uninstall_lang(logger, lang, parser, queries)
end
---@param languages string[]|string
---@param _options? table
---@param _callback? fun()
M.uninstall = a.sync(function(languages, _options, _callback)
M.uninstall = a.async(function(languages)
languages = config.norm_languages(languages or 'all', { missing = true, dependencies = true })
local parser_dir = config.get_install_dir('parser')
local query_dir = config.get_install_dir('queries')
local installed = config.installed_parsers()
local tasks = {} ---@type fun()[]
local task_funs = {} ---@type async.TaskFun[]
local done = 0
for _, lang in ipairs(languages) do
local logger = log.new('uninstall/' .. lang)
@@ -516,7 +537,7 @@ M.uninstall = a.sync(function(languages, _options, _callback)
else
local parser = fs.joinpath(parser_dir, lang) .. '.so'
local queries = fs.joinpath(query_dir, lang)
tasks[#tasks + 1] = a.sync(function()
task_funs[#task_funs + 1] = a.async(function()
local err = uninstall_lang(logger, lang, parser, queries)
if not err then
done = done + 1
@@ -525,11 +546,11 @@ M.uninstall = a.sync(function(languages, _options, _callback)
end
end
a.join(MAX_JOBS, nil, tasks)
if #tasks > 1 then
a.main()
log.info('Uninstalled %d/%d languages', done, #tasks)
join(MAX_JOBS, task_funs)
if #task_funs > 1 then
a.schedule()
log.info('Uninstalled %d/%d languages', done, #task_funs)
end
end, 2)
end)
return M

View File

@@ -21,24 +21,17 @@ vim.opt.runtimepath:append('.')
-- needed on CI
vim.fn.mkdir(vim.fn.stdpath('cache'), 'p')
local ok = nil
if update then
require('nvim-treesitter.install').update('all', {}, function(success)
ok = success
end)
else
require('nvim-treesitter.install').install(
---@type async.Task
local task = update and require('nvim-treesitter.install').update('all')
or require('nvim-treesitter.install').install(
#parsers > 0 and parsers or 'all',
{ force = true, generate = generate, max_jobs = max_jobs },
function(success)
ok = success
end
{ force = true, generate = generate, max_jobs = max_jobs }
)
end
vim.wait(6000000, function()
return ok ~= nil
end)
local ok, err_or_ok = task:pwait(1800000) -- wait max. 30 minutes
if not ok then
print('ERROR: ', err_or_ok)
vim.cmd.cq()
elseif not err_or_ok then
vim.cmd.cq()
end