From 39b6a0596fe600eff9308f14678e5653045bf98c Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 17 Aug 2023 09:20:11 -0400 Subject: [PATCH] fix: fix parameter inspection on ensure_thread decorators (alternate) (#185) * fix: use different approach * test: apply fixes * back to signature * fix get_max_args * IMPORT THE FUTURE * try or return None * check for callable * Update test_utils.py Co-authored-by: Grzegorz Bokota * style: [pre-commit.ci] auto fixes [...] --------- Co-authored-by: Grzegorz Bokota Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/superqt/utils/_ensure_thread.py | 30 +++++++---- src/superqt/utils/_util.py | 23 +++++++++ tests/test_ensure_thread.py | 80 +++++++++++++++++++++++++++++ tests/test_utils.py | 64 +++++++++++++++++++++++ 4 files changed, 186 insertions(+), 11 deletions(-) create mode 100644 src/superqt/utils/_util.py diff --git a/src/superqt/utils/_ensure_thread.py b/src/superqt/utils/_ensure_thread.py index 926adae..ab699a0 100644 --- a/src/superqt/utils/_ensure_thread.py +++ b/src/superqt/utils/_ensure_thread.py @@ -3,7 +3,7 @@ from __future__ import annotations from concurrent.futures import Future from functools import wraps -from typing import TYPE_CHECKING, Callable, ClassVar, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload from qtpy.QtCore import ( QCoreApplication, @@ -15,6 +15,8 @@ from qtpy.QtCore import ( Slot, ) +from ._util import get_max_args + if TYPE_CHECKING: from typing import TypeVar @@ -28,7 +30,7 @@ class CallCallable(QObject): finished = Signal(object) instances: ClassVar[list[CallCallable]] = [] - def __init__(self, callable, *args, **kwargs): + def __init__(self, callable: Callable, args: tuple, kwargs: dict): super().__init__() self._callable = callable self._args = args @@ -88,15 +90,17 @@ def ensure_main_thread( """ def _out_func(func_): + max_args = get_max_args(func_) + @wraps(func_) - def _func(*args, **kwargs): + def _func(*args, _max_args_=max_args, **kwargs): return _run_in_thread( func_, QCoreApplication.instance().thread(), await_return, timeout, - *args, - **kwargs, + args[:_max_args_], + kwargs, ) return _func @@ -150,10 +154,13 @@ def ensure_object_thread( """ def _out_func(func_): + max_args = get_max_args(func_) + @wraps(func_) - def _func(self, *args, **kwargs): + def _func(*args, _max_args_=max_args, **kwargs): + thread = args[0].thread() # self return _run_in_thread( - func_, self.thread(), await_return, timeout, self, *args, **kwargs + func_, thread, await_return, timeout, args[:_max_args_], kwargs ) return _func @@ -166,9 +173,9 @@ def _run_in_thread( thread: QThread, await_return: bool, timeout: int, - *args, - **kwargs, -): + args: tuple, + kwargs: dict, +) -> Any: future = Future() # type: ignore if thread is QThread.currentThread(): result = func(*args, **kwargs) @@ -176,7 +183,8 @@ def _run_in_thread( future.set_result(result) return future return result - f = CallCallable(func, *args, **kwargs) + + f = CallCallable(func, args, kwargs) f.moveToThread(thread) f.finished.connect(future.set_result, Qt.ConnectionType.DirectConnection) QMetaObject.invokeMethod(f, "call", Qt.ConnectionType.QueuedConnection) # type: ignore # noqa diff --git a/src/superqt/utils/_util.py b/src/superqt/utils/_util.py new file mode 100644 index 0000000..bdb9d61 --- /dev/null +++ b/src/superqt/utils/_util.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from inspect import signature +from typing import Callable + + +def get_max_args(func: Callable) -> int | None: + """Return the maximum number of positional arguments that func can accept.""" + if not callable(func): + raise TypeError(f"{func!r} is not callable") + + try: + sig = signature(func) + except Exception: + return None + + max_args = 0 + for param in sig.parameters.values(): + if param.kind == param.VAR_POSITIONAL: + return None + if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}: + max_args += 1 + return max_args diff --git a/tests/test_ensure_thread.py b/tests/test_ensure_thread.py index f4bde0a..e35afee 100644 --- a/tests/test_ensure_thread.py +++ b/tests/test_ensure_thread.py @@ -1,7 +1,10 @@ import inspect import os +import threading import time from concurrent.futures import Future, TimeoutError +from functools import wraps +from unittest.mock import Mock import pytest from qtpy.QtCore import QCoreApplication, QObject, QThread, Signal @@ -217,3 +220,80 @@ def test_object_thread(qtbot): assert ob.thread() is thread with qtbot.waitSignal(thread.finished): thread.quit() + + +@pytest.mark.parametrize("mode", ["method", "func", "wrapped"]) +@pytest.mark.parametrize("deco", [ensure_main_thread, ensure_object_thread]) +def test_ensure_thread_sig_inspection(deco, mode): + class Emitter(QObject): + sig = Signal(int, int, int) + + obj = Emitter() + mock = Mock() + + if mode == "method": + + class Receiver(QObject): + @deco + def func(self, a: int, b: int): + mock(a, b) + + r = Receiver() + obj.sig.connect(r.func) + elif deco == ensure_object_thread: + return # not compatible with function types + + elif mode == "wrapped": + + def wr(fun): + @wraps(fun) + def wr2(*args): + mock(*args) + return fun(*args) * 2 + + return wr2 + + @deco + @wr + def wrapped_func(a, b): + return a + b + + obj.sig.connect(wrapped_func) + + elif mode == "func": + + @deco + def func(a: int, b: int) -> None: + mock(a, b) + + obj.sig.connect(func) + + # this is the crux of the test... + # we emit 3 args, but the function only takes 2 + # this should normally work fine in Qt. + # testing here that the decorator doesn't break it. + obj.sig.emit(1, 2, 3) + mock.assert_called_once_with(1, 2) + + +def test_main_thread_function(qtbot): + """Testing decorator on a function rather than QObject method.""" + + mock = Mock() + + class Emitter(QObject): + sig = Signal(int, int, int) + + @ensure_main_thread + def func(x: int) -> None: + mock(x, QThread.currentThread()) + + e = Emitter() + e.sig.connect(func) + + with qtbot.waitSignal(e.sig): + thread = threading.Thread(target=e.sig.emit, args=(1, 2, 3)) + thread.start() + thread.join() + + mock.assert_called_once_with(1, QCoreApplication.instance().thread()) diff --git a/tests/test_utils.py b/tests/test_utils.py index f80c606..c794271 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ from unittest.mock import Mock from qtpy.QtCore import QObject, Signal from superqt.utils import signals_blocked +from superqt.utils._util import get_max_args def test_signal_blocker(qtbot): @@ -27,3 +28,66 @@ def test_signal_blocker(qtbot): qtbot.wait(10) receiver.assert_not_called() + + +def test_get_max_args_simple(): + def fun1(): + pass + + assert get_max_args(fun1) == 0 + + def fun2(a): + pass + + assert get_max_args(fun2) == 1 + + def fun3(a, b=1): + pass + + assert get_max_args(fun3) == 2 + + def fun4(a, *, b=2): + pass + + assert get_max_args(fun4) == 1 + + def fun5(a, *b): + pass + + assert get_max_args(fun5) is None + + assert get_max_args(print) is None + + +def test_get_max_args_wrapped(): + from functools import partial, wraps + + def fun1(a, b): + pass + + assert get_max_args(partial(fun1, 1)) == 1 + + def dec(fun): + @wraps(fun) + def wrapper(*args, **kwargs): + return fun(*args, **kwargs) + + return wrapper + + assert get_max_args(dec(fun1)) == 2 + + +def test_get_max_args_methods(): + class A: + def fun1(self): + pass + + def fun2(self, a): + pass + + def __call__(self, a, b=1): + pass + + assert get_max_args(A().fun1) == 0 + assert get_max_args(A().fun2) == 1 + assert get_max_args(A()) == 2