fix: fix parameter inspection on ensure_thread decorators (alternate) (#185)

* fix: use different approach

* test: apply fixes

* back to signature

* fix get_max_args

* IMPORT THE FUTURE

* try or return None

* check for callable

* Update test_utils.py

Co-authored-by: Grzegorz Bokota <bokota+github@gmail.com>

* style: [pre-commit.ci] auto fixes [...]

---------

Co-authored-by: Grzegorz Bokota <bokota+github@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Talley Lambert
2023-08-17 09:20:11 -04:00
committed by GitHub
parent 9ff01e757b
commit 39b6a0596f
4 changed files with 186 additions and 11 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from concurrent.futures import Future from concurrent.futures import Future
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Callable, ClassVar, overload from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload
from qtpy.QtCore import ( from qtpy.QtCore import (
QCoreApplication, QCoreApplication,
@@ -15,6 +15,8 @@ from qtpy.QtCore import (
Slot, Slot,
) )
from ._util import get_max_args
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeVar from typing import TypeVar
@@ -28,7 +30,7 @@ class CallCallable(QObject):
finished = Signal(object) finished = Signal(object)
instances: ClassVar[list[CallCallable]] = [] instances: ClassVar[list[CallCallable]] = []
def __init__(self, callable, *args, **kwargs): def __init__(self, callable: Callable, args: tuple, kwargs: dict):
super().__init__() super().__init__()
self._callable = callable self._callable = callable
self._args = args self._args = args
@@ -88,15 +90,17 @@ def ensure_main_thread(
""" """
def _out_func(func_): def _out_func(func_):
max_args = get_max_args(func_)
@wraps(func_) @wraps(func_)
def _func(*args, **kwargs): def _func(*args, _max_args_=max_args, **kwargs):
return _run_in_thread( return _run_in_thread(
func_, func_,
QCoreApplication.instance().thread(), QCoreApplication.instance().thread(),
await_return, await_return,
timeout, timeout,
*args, args[:_max_args_],
**kwargs, kwargs,
) )
return _func return _func
@@ -150,10 +154,13 @@ def ensure_object_thread(
""" """
def _out_func(func_): def _out_func(func_):
max_args = get_max_args(func_)
@wraps(func_) @wraps(func_)
def _func(self, *args, **kwargs): def _func(*args, _max_args_=max_args, **kwargs):
thread = args[0].thread() # self
return _run_in_thread( return _run_in_thread(
func_, self.thread(), await_return, timeout, self, *args, **kwargs func_, thread, await_return, timeout, args[:_max_args_], kwargs
) )
return _func return _func
@@ -166,9 +173,9 @@ def _run_in_thread(
thread: QThread, thread: QThread,
await_return: bool, await_return: bool,
timeout: int, timeout: int,
*args, args: tuple,
**kwargs, kwargs: dict,
): ) -> Any:
future = Future() # type: ignore future = Future() # type: ignore
if thread is QThread.currentThread(): if thread is QThread.currentThread():
result = func(*args, **kwargs) result = func(*args, **kwargs)
@@ -176,7 +183,8 @@ def _run_in_thread(
future.set_result(result) future.set_result(result)
return future return future
return result return result
f = CallCallable(func, *args, **kwargs)
f = CallCallable(func, args, kwargs)
f.moveToThread(thread) f.moveToThread(thread)
f.finished.connect(future.set_result, Qt.ConnectionType.DirectConnection) f.finished.connect(future.set_result, Qt.ConnectionType.DirectConnection)
QMetaObject.invokeMethod(f, "call", Qt.ConnectionType.QueuedConnection) # type: ignore # noqa QMetaObject.invokeMethod(f, "call", Qt.ConnectionType.QueuedConnection) # type: ignore # noqa

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from inspect import signature
from typing import Callable
def get_max_args(func: Callable) -> int | None:
"""Return the maximum number of positional arguments that func can accept."""
if not callable(func):
raise TypeError(f"{func!r} is not callable")
try:
sig = signature(func)
except Exception:
return None
max_args = 0
for param in sig.parameters.values():
if param.kind == param.VAR_POSITIONAL:
return None
if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}:
max_args += 1
return max_args

View File

@@ -1,7 +1,10 @@
import inspect import inspect
import os import os
import threading
import time import time
from concurrent.futures import Future, TimeoutError from concurrent.futures import Future, TimeoutError
from functools import wraps
from unittest.mock import Mock
import pytest import pytest
from qtpy.QtCore import QCoreApplication, QObject, QThread, Signal from qtpy.QtCore import QCoreApplication, QObject, QThread, Signal
@@ -217,3 +220,80 @@ def test_object_thread(qtbot):
assert ob.thread() is thread assert ob.thread() is thread
with qtbot.waitSignal(thread.finished): with qtbot.waitSignal(thread.finished):
thread.quit() thread.quit()
@pytest.mark.parametrize("mode", ["method", "func", "wrapped"])
@pytest.mark.parametrize("deco", [ensure_main_thread, ensure_object_thread])
def test_ensure_thread_sig_inspection(deco, mode):
class Emitter(QObject):
sig = Signal(int, int, int)
obj = Emitter()
mock = Mock()
if mode == "method":
class Receiver(QObject):
@deco
def func(self, a: int, b: int):
mock(a, b)
r = Receiver()
obj.sig.connect(r.func)
elif deco == ensure_object_thread:
return # not compatible with function types
elif mode == "wrapped":
def wr(fun):
@wraps(fun)
def wr2(*args):
mock(*args)
return fun(*args) * 2
return wr2
@deco
@wr
def wrapped_func(a, b):
return a + b
obj.sig.connect(wrapped_func)
elif mode == "func":
@deco
def func(a: int, b: int) -> None:
mock(a, b)
obj.sig.connect(func)
# this is the crux of the test...
# we emit 3 args, but the function only takes 2
# this should normally work fine in Qt.
# testing here that the decorator doesn't break it.
obj.sig.emit(1, 2, 3)
mock.assert_called_once_with(1, 2)
def test_main_thread_function(qtbot):
"""Testing decorator on a function rather than QObject method."""
mock = Mock()
class Emitter(QObject):
sig = Signal(int, int, int)
@ensure_main_thread
def func(x: int) -> None:
mock(x, QThread.currentThread())
e = Emitter()
e.sig.connect(func)
with qtbot.waitSignal(e.sig):
thread = threading.Thread(target=e.sig.emit, args=(1, 2, 3))
thread.start()
thread.join()
mock.assert_called_once_with(1, QCoreApplication.instance().thread())

View File

@@ -3,6 +3,7 @@ from unittest.mock import Mock
from qtpy.QtCore import QObject, Signal from qtpy.QtCore import QObject, Signal
from superqt.utils import signals_blocked from superqt.utils import signals_blocked
from superqt.utils._util import get_max_args
def test_signal_blocker(qtbot): def test_signal_blocker(qtbot):
@@ -27,3 +28,66 @@ def test_signal_blocker(qtbot):
qtbot.wait(10) qtbot.wait(10)
receiver.assert_not_called() receiver.assert_not_called()
def test_get_max_args_simple():
def fun1():
pass
assert get_max_args(fun1) == 0
def fun2(a):
pass
assert get_max_args(fun2) == 1
def fun3(a, b=1):
pass
assert get_max_args(fun3) == 2
def fun4(a, *, b=2):
pass
assert get_max_args(fun4) == 1
def fun5(a, *b):
pass
assert get_max_args(fun5) is None
assert get_max_args(print) is None
def test_get_max_args_wrapped():
from functools import partial, wraps
def fun1(a, b):
pass
assert get_max_args(partial(fun1, 1)) == 1
def dec(fun):
@wraps(fun)
def wrapper(*args, **kwargs):
return fun(*args, **kwargs)
return wrapper
assert get_max_args(dec(fun1)) == 2
def test_get_max_args_methods():
class A:
def fun1(self):
pass
def fun2(self, a):
pass
def __call__(self, a, b=1):
pass
assert get_max_args(A().fun1) == 0
assert get_max_args(A().fun2) == 1
assert get_max_args(A()) == 2