mirror of
https://github.com/pyapp-kit/superqt.git
synced 2026-01-04 11:21:09 +01:00
fix: fix parameter inspection on ensure_thread decorators (alternate) (#185)
* fix: use different approach * test: apply fixes * back to signature * fix get_max_args * IMPORT THE FUTURE * try or return None * check for callable * Update test_utils.py Co-authored-by: Grzegorz Bokota <bokota+github@gmail.com> * style: [pre-commit.ci] auto fixes [...] --------- Co-authored-by: Grzegorz Bokota <bokota+github@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TYPE_CHECKING, Callable, ClassVar, overload
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, overload
|
||||||
|
|
||||||
from qtpy.QtCore import (
|
from qtpy.QtCore import (
|
||||||
QCoreApplication,
|
QCoreApplication,
|
||||||
@@ -15,6 +15,8 @@ from qtpy.QtCore import (
|
|||||||
Slot,
|
Slot,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ._util import get_max_args
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
@@ -28,7 +30,7 @@ class CallCallable(QObject):
|
|||||||
finished = Signal(object)
|
finished = Signal(object)
|
||||||
instances: ClassVar[list[CallCallable]] = []
|
instances: ClassVar[list[CallCallable]] = []
|
||||||
|
|
||||||
def __init__(self, callable, *args, **kwargs):
|
def __init__(self, callable: Callable, args: tuple, kwargs: dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._callable = callable
|
self._callable = callable
|
||||||
self._args = args
|
self._args = args
|
||||||
@@ -88,15 +90,17 @@ def ensure_main_thread(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _out_func(func_):
|
def _out_func(func_):
|
||||||
|
max_args = get_max_args(func_)
|
||||||
|
|
||||||
@wraps(func_)
|
@wraps(func_)
|
||||||
def _func(*args, **kwargs):
|
def _func(*args, _max_args_=max_args, **kwargs):
|
||||||
return _run_in_thread(
|
return _run_in_thread(
|
||||||
func_,
|
func_,
|
||||||
QCoreApplication.instance().thread(),
|
QCoreApplication.instance().thread(),
|
||||||
await_return,
|
await_return,
|
||||||
timeout,
|
timeout,
|
||||||
*args,
|
args[:_max_args_],
|
||||||
**kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _func
|
return _func
|
||||||
@@ -150,10 +154,13 @@ def ensure_object_thread(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _out_func(func_):
|
def _out_func(func_):
|
||||||
|
max_args = get_max_args(func_)
|
||||||
|
|
||||||
@wraps(func_)
|
@wraps(func_)
|
||||||
def _func(self, *args, **kwargs):
|
def _func(*args, _max_args_=max_args, **kwargs):
|
||||||
|
thread = args[0].thread() # self
|
||||||
return _run_in_thread(
|
return _run_in_thread(
|
||||||
func_, self.thread(), await_return, timeout, self, *args, **kwargs
|
func_, thread, await_return, timeout, args[:_max_args_], kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return _func
|
return _func
|
||||||
@@ -166,9 +173,9 @@ def _run_in_thread(
|
|||||||
thread: QThread,
|
thread: QThread,
|
||||||
await_return: bool,
|
await_return: bool,
|
||||||
timeout: int,
|
timeout: int,
|
||||||
*args,
|
args: tuple,
|
||||||
**kwargs,
|
kwargs: dict,
|
||||||
):
|
) -> Any:
|
||||||
future = Future() # type: ignore
|
future = Future() # type: ignore
|
||||||
if thread is QThread.currentThread():
|
if thread is QThread.currentThread():
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
@@ -176,7 +183,8 @@ def _run_in_thread(
|
|||||||
future.set_result(result)
|
future.set_result(result)
|
||||||
return future
|
return future
|
||||||
return result
|
return result
|
||||||
f = CallCallable(func, *args, **kwargs)
|
|
||||||
|
f = CallCallable(func, args, kwargs)
|
||||||
f.moveToThread(thread)
|
f.moveToThread(thread)
|
||||||
f.finished.connect(future.set_result, Qt.ConnectionType.DirectConnection)
|
f.finished.connect(future.set_result, Qt.ConnectionType.DirectConnection)
|
||||||
QMetaObject.invokeMethod(f, "call", Qt.ConnectionType.QueuedConnection) # type: ignore # noqa
|
QMetaObject.invokeMethod(f, "call", Qt.ConnectionType.QueuedConnection) # type: ignore # noqa
|
||||||
|
|||||||
23
src/superqt/utils/_util.py
Normal file
23
src/superqt/utils/_util.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from inspect import signature
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_args(func: Callable) -> int | None:
|
||||||
|
"""Return the maximum number of positional arguments that func can accept."""
|
||||||
|
if not callable(func):
|
||||||
|
raise TypeError(f"{func!r} is not callable")
|
||||||
|
|
||||||
|
try:
|
||||||
|
sig = signature(func)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
max_args = 0
|
||||||
|
for param in sig.parameters.values():
|
||||||
|
if param.kind == param.VAR_POSITIONAL:
|
||||||
|
return None
|
||||||
|
if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}:
|
||||||
|
max_args += 1
|
||||||
|
return max_args
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import Future, TimeoutError
|
from concurrent.futures import Future, TimeoutError
|
||||||
|
from functools import wraps
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from qtpy.QtCore import QCoreApplication, QObject, QThread, Signal
|
from qtpy.QtCore import QCoreApplication, QObject, QThread, Signal
|
||||||
@@ -217,3 +220,80 @@ def test_object_thread(qtbot):
|
|||||||
assert ob.thread() is thread
|
assert ob.thread() is thread
|
||||||
with qtbot.waitSignal(thread.finished):
|
with qtbot.waitSignal(thread.finished):
|
||||||
thread.quit()
|
thread.quit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("mode", ["method", "func", "wrapped"])
|
||||||
|
@pytest.mark.parametrize("deco", [ensure_main_thread, ensure_object_thread])
|
||||||
|
def test_ensure_thread_sig_inspection(deco, mode):
|
||||||
|
class Emitter(QObject):
|
||||||
|
sig = Signal(int, int, int)
|
||||||
|
|
||||||
|
obj = Emitter()
|
||||||
|
mock = Mock()
|
||||||
|
|
||||||
|
if mode == "method":
|
||||||
|
|
||||||
|
class Receiver(QObject):
|
||||||
|
@deco
|
||||||
|
def func(self, a: int, b: int):
|
||||||
|
mock(a, b)
|
||||||
|
|
||||||
|
r = Receiver()
|
||||||
|
obj.sig.connect(r.func)
|
||||||
|
elif deco == ensure_object_thread:
|
||||||
|
return # not compatible with function types
|
||||||
|
|
||||||
|
elif mode == "wrapped":
|
||||||
|
|
||||||
|
def wr(fun):
|
||||||
|
@wraps(fun)
|
||||||
|
def wr2(*args):
|
||||||
|
mock(*args)
|
||||||
|
return fun(*args) * 2
|
||||||
|
|
||||||
|
return wr2
|
||||||
|
|
||||||
|
@deco
|
||||||
|
@wr
|
||||||
|
def wrapped_func(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
obj.sig.connect(wrapped_func)
|
||||||
|
|
||||||
|
elif mode == "func":
|
||||||
|
|
||||||
|
@deco
|
||||||
|
def func(a: int, b: int) -> None:
|
||||||
|
mock(a, b)
|
||||||
|
|
||||||
|
obj.sig.connect(func)
|
||||||
|
|
||||||
|
# this is the crux of the test...
|
||||||
|
# we emit 3 args, but the function only takes 2
|
||||||
|
# this should normally work fine in Qt.
|
||||||
|
# testing here that the decorator doesn't break it.
|
||||||
|
obj.sig.emit(1, 2, 3)
|
||||||
|
mock.assert_called_once_with(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_thread_function(qtbot):
|
||||||
|
"""Testing decorator on a function rather than QObject method."""
|
||||||
|
|
||||||
|
mock = Mock()
|
||||||
|
|
||||||
|
class Emitter(QObject):
|
||||||
|
sig = Signal(int, int, int)
|
||||||
|
|
||||||
|
@ensure_main_thread
|
||||||
|
def func(x: int) -> None:
|
||||||
|
mock(x, QThread.currentThread())
|
||||||
|
|
||||||
|
e = Emitter()
|
||||||
|
e.sig.connect(func)
|
||||||
|
|
||||||
|
with qtbot.waitSignal(e.sig):
|
||||||
|
thread = threading.Thread(target=e.sig.emit, args=(1, 2, 3))
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
mock.assert_called_once_with(1, QCoreApplication.instance().thread())
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from unittest.mock import Mock
|
|||||||
from qtpy.QtCore import QObject, Signal
|
from qtpy.QtCore import QObject, Signal
|
||||||
|
|
||||||
from superqt.utils import signals_blocked
|
from superqt.utils import signals_blocked
|
||||||
|
from superqt.utils._util import get_max_args
|
||||||
|
|
||||||
|
|
||||||
def test_signal_blocker(qtbot):
|
def test_signal_blocker(qtbot):
|
||||||
@@ -27,3 +28,66 @@ def test_signal_blocker(qtbot):
|
|||||||
qtbot.wait(10)
|
qtbot.wait(10)
|
||||||
|
|
||||||
receiver.assert_not_called()
|
receiver.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_max_args_simple():
|
||||||
|
def fun1():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(fun1) == 0
|
||||||
|
|
||||||
|
def fun2(a):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(fun2) == 1
|
||||||
|
|
||||||
|
def fun3(a, b=1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(fun3) == 2
|
||||||
|
|
||||||
|
def fun4(a, *, b=2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(fun4) == 1
|
||||||
|
|
||||||
|
def fun5(a, *b):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(fun5) is None
|
||||||
|
|
||||||
|
assert get_max_args(print) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_max_args_wrapped():
|
||||||
|
from functools import partial, wraps
|
||||||
|
|
||||||
|
def fun1(a, b):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(partial(fun1, 1)) == 1
|
||||||
|
|
||||||
|
def dec(fun):
|
||||||
|
@wraps(fun)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return fun(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
assert get_max_args(dec(fun1)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_max_args_methods():
|
||||||
|
class A:
|
||||||
|
def fun1(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fun2(self, a):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, a, b=1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert get_max_args(A().fun1) == 0
|
||||||
|
assert get_max_args(A().fun2) == 1
|
||||||
|
assert get_max_args(A()) == 2
|
||||||
|
|||||||
Reference in New Issue
Block a user