# Source code for tabulous.threading

from __future__ import annotations

from typing import Any, Callable, Generator, overload, TypeVar
from typing_extensions import ParamSpec
import inspect
from functools import lru_cache
from superqt.utils import (
    thread_worker as _thread_worker,
    GeneratorWorker,
    FunctionWorker,
)

# Only the decorator is exported as the public API of this module.
__all__ = ["thread_worker"]

# Type variables used by the ``thread_worker`` overloads.
_Y = TypeVar("_Y")  # yield type of a decorated generator function
_S = TypeVar("_S")  # send type of a decorated generator function
_R = TypeVar("_R")  # return type of the decorated function
_P = ParamSpec("_P")  # call parameters of the decorated function


# Overload for generator functions: the decorated function keeps its call
# parameters and returns a ``GeneratorWorker`` typed by the generator's
# yield/send/return types.
# ``desc`` and ``total`` may also be callables; ``create_worker`` calls them
# with the bound call arguments whose names match their parameters.
@overload
def thread_worker(
    function: Callable[_P, Generator[_Y, _S, _R]],
    *,
    desc: str | Callable[..., str] | None = None,
    total: int | Callable[..., int] = 0,
) -> Callable[_P, GeneratorWorker[_Y, _S, _R]]:
    ...


# Overload for plain (non-generator) functions: the decorated function keeps
# its call parameters and returns a ``FunctionWorker`` typed by the function's
# return type.
# ``desc`` and ``total`` may also be callables; ``create_worker`` calls them
# with the bound call arguments whose names match their parameters.
@overload
def thread_worker(
    function: Callable[_P, _R],
    *,
    desc: str | Callable[..., str] | None = None,
    total: int | Callable[..., int] = 0,
) -> Callable[_P, FunctionWorker[_R]]:
    ...


# Overload for use with arguments only (``@thread_worker(desc=..., total=...)``):
# no function is given yet, so a decorator is returned that is applied to the
# function afterwards.
# ``desc`` and ``total`` may also be callables; ``create_worker`` calls them
# with the bound call arguments whose names match their parameters.
@overload
def thread_worker(
    function: None = None,
    *,
    desc: str | Callable[..., str] | None = None,
    total: int | Callable[..., int] = 0,
) -> Callable[
    [Callable[_P, _R]], Callable[_P, FunctionWorker[_R] | GeneratorWorker[Any, Any, _R]]
]:
    ...


def thread_worker(function=None, *, desc=None, total=0):
    """
    Convert the returned value of a function into a worker.

    Can be used either as a bare decorator or with keyword arguments.

    >>> import time
    >>> from tabulous.threading import thread_worker
    >>> @thread_worker
    ... def func():
    ...     time.sleep(1)

    >>> @thread_worker(desc="sleeping")
    ... def func_with_label():
    ...     time.sleep(1)

    Parameters
    ----------
    function : callable, optional
        Function to be called in another thread. If not given, a decorator
        is returned instead.
    desc : str or callable, optional
        Label that will be shown beside the progress indicator. The function
        name will be used if not provided. A callable is called with the
        arguments of the decorated function whose names match its own
        parameters, and must return a str.
    total : int or callable, default is 0
        Total number of steps for the progress indicator. A callable is
        called the same way as ``desc`` and must return an int.

    Returns
    -------
    callable
        The worker-creating function if ``function`` is given, otherwise a
        decorator that produces one.
    """

    def _inner(fn: Callable):
        return create_worker(fn, desc=desc, total=total)

    # Support both ``@thread_worker`` and ``@thread_worker(...)``.
    return _inner if function is None else _inner(function)


def create_worker( fn: Callable, *, desc: str | None = None, total: int = 0, ): worker_constructor = _thread_worker(fn) sig = inspect.signature(fn) def _create_worker(*args, **kwargs): from tabulous._qt._mainwindow import QMainWindow nonlocal desc, total bound = sig.bind_partial(*args, **kwargs) if desc is None: _desc = getattr(fn, "__name__", repr(fn)) elif callable(desc): _desc = _call_with_filtered(desc, bound.arguments) else: _desc = desc if not isinstance(_desc, str): raise TypeError("`desc` did not return a str.") if callable(total): _total = _call_with_filtered(total, bound.arguments) else: _total = total if not isinstance(_total, int): raise TypeError("`total` did not return an int.") viewer = QMainWindow.currentViewer() worker = worker_constructor(*args, **kwargs) viewer.native._tablestack._info_stack.addWorker( worker, desc=_desc, total=_total ) return worker return _create_worker @lru_cache(maxsize=32) def _make_filter(fn: Callable[..., _R]) -> Callable[[dict[str, Any]], dict[str, Any]]: sig = inspect.signature(fn) arg_names: list[str] = [] for name, param in sig.parameters.items(): if param.kind in ( param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY, ): arg_names.append(name) elif param.kind == param.VAR_POSITIONAL: raise NotImplementedError("Cannot use *args or **kwargs") elif param.kind == param.VAR_KEYWORD: raise NotImplementedError("Cannot use *args or **kwargs") def _filter(kwargs: dict): return {k: v for k, v in kwargs.items() if k in arg_names} return _filter def _call_with_filtered(fn: Callable[..., _R], kwargs: dict[str, Any]) -> _R: filt = _make_filter(fn) return fn(**filt(kwargs))