From 64dfb43d9e4579a6d7408f3b5638324653febdc6 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 17 Aug 2023 11:05:02 -0400 Subject: [PATCH] fix: fix callback of throttled/debounced decorated functions with mismatched args (#184) * fix: fix throttled inspection * build: change typing-ext deps * fix: use inspect.signature * use get_max_args * fix: fix typing --- .github/workflows/test_and_deploy.yml | 2 +- pyproject.toml | 2 +- src/superqt/utils/_throttler.py | 126 ++++++++++++-------------- tests/test_throttler.py | 29 ++++++ 4 files changed, 88 insertions(+), 71 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5be982a..c343528 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -100,7 +100,7 @@ jobs: run: | python -m pip install -U pip python -m pip install -e .[test,pyqt5] - python -m pip install qtpy==1.1.0 typing-extensions==3.10.0.0 + python -m pip install qtpy==1.1.0 typing-extensions==3.7.4.3 - name: Test uses: aganders3/headless-gui@v1.2 diff --git a/pyproject.toml b/pyproject.toml index b45abf1..d1f3199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "packaging", "pygments>=2.4.0", "qtpy>=1.1.0", - "typing-extensions", + "typing-extensions >=3.7.4.3,!=3.10.0.0", ] # extras diff --git a/src/superqt/utils/_throttler.py b/src/superqt/utils/_throttler.py index 0065c8f..d5f69b9 100644 --- a/src/superqt/utils/_throttler.py +++ b/src/superqt/utils/_throttler.py @@ -26,17 +26,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import sys +from __future__ import annotations + from concurrent.futures import Future from enum import IntFlag, auto from functools import wraps -from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload from qtpy.QtCore import QObject, Qt, QTimer, Signal +from ._util import get_max_args + if TYPE_CHECKING: - from qtpy.QtCore import SignalInstance - from typing_extensions import Literal, ParamSpec + from typing_extensions import ParamSpec P = ParamSpec("P") # maintain runtime compatibility with older typing_extensions @@ -70,7 +72,7 @@ class GenericSignalThrottler(QObject): self, kind: Kind, emissionPolicy: EmissionPolicy, - parent: Optional[QObject] = None, + parent: QObject | None = None, ) -> None: super().__init__(parent) @@ -166,7 +168,7 @@ class QSignalThrottler(GenericSignalThrottler): def __init__( self, policy: EmissionPolicy = EmissionPolicy.Leading, - parent: Optional[QObject] = None, + parent: QObject | None = None, ) -> None: super().__init__(Kind.Throttler, policy, parent) @@ -181,7 +183,7 @@ class QSignalDebouncer(GenericSignalThrottler): def __init__( self, policy: EmissionPolicy = EmissionPolicy.Trailing, - parent: Optional[QObject] = None, + parent: QObject | None = None, ) -> None: super().__init__(Kind.Debouncer, policy, parent) @@ -189,30 +191,44 @@ class QSignalDebouncer(GenericSignalThrottler): # below here part is unique to superqt (not from KD) -if TYPE_CHECKING: - from typing_extensions import Protocol +class ThrottledCallable(GenericSignalThrottler, Generic[P, R]): + def __init__( + self, + func: Callable[P, R], + kind: Kind, + emissionPolicy: EmissionPolicy, + parent: QObject | None = None, + ) -> None: + super().__init__(kind, emissionPolicy, parent) - class ThrottledCallable(Generic[P, R], Protocol): - triggered: "SignalInstance" + self._future: Future[R] = Future() + self.__wrapped__ = func - def cancel(self) -> None: - ... + self._args: tuple = () + self._kwargs: dict = {} + self.triggered.connect(self._set_future_result) - def flush(self) -> None: - ... + # even if we were to compile __call__ with a signature matching that of func, + # PySide wouldn't correctly inspect the signature of the ThrottledCallable + # instance: https://bugreports.qt.io/browse/PYSIDE-2423 + # so we do it ourselfs and limit the number of positional arguments + # that we pass to func + self._max_args: int | None = get_max_args(func) - def set_timeout(self, timeout: int) -> None: - ... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Future[R]": # noqa + if not self._future.done(): + self._future.cancel() - if sys.version_info < (3, 9): + self._future = Future() + self._args = args + self._kwargs = kwargs - def __call__(self, *args: "P.args", **kwargs: "P.kwargs") -> Future: - ... + self.throttle() + return self._future - else: - - def __call__(self, *args: "P.args", **kwargs: "P.kwargs") -> Future[R]: - ... + def _set_future_result(self): + result = self.__wrapped__(*self._args[: self._max_args], **self._kwargs) + self._future.set_result(result) @overload @@ -221,28 +237,26 @@ def qthrottled( timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> "ThrottledCallable[P, R]": +) -> ThrottledCallable[P, R]: ... @overload def qthrottled( - func: Optional["Literal[None]"] = None, + func: None = ..., timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]: +) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]: ... def qthrottled( - func: Optional[Callable[P, R]] = None, + func: Callable[P, R] | None = None, timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> Union[ - "ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"] -]: +) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: """Creates a throttled function that invokes func at most once per timeout. The throttled function comes with a `cancel` method to cancel delayed func @@ -280,28 +294,26 @@ def qdebounced( timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> "ThrottledCallable[P, R]": +) -> ThrottledCallable[P, R]: ... @overload def qdebounced( - func: Optional["Literal[None]"] = None, + func: None = ..., timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]: +) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]: ... def qdebounced( - func: Optional[Callable[P, R]] = None, + func: Callable[P, R] | None = None, timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, -) -> Union[ - "ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"] -]: +) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: """Creates a debounced function that delays invoking `func`. `func` will not be invoked until `timeout` ms have elapsed since the last time @@ -337,41 +349,17 @@ def qdebounced( def _make_decorator( - func: Optional[Callable[P, R]], + func: Callable[P, R] | None, timeout: int, leading: bool, timer_type: Qt.TimerType, kind: Kind, -) -> Union[ - "ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"] -]: - def deco(func: Callable[P, R]) -> "ThrottledCallable[P, R]": +) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: + def deco(func: Callable[P, R]) -> ThrottledCallable[P, R]: policy = EmissionPolicy.Leading if leading else EmissionPolicy.Trailing - throttle = GenericSignalThrottler(kind, policy) - throttle.setTimerType(timer_type) - throttle.setTimeout(timeout) - last_f = None - future: Optional[Future] = None - - @wraps(func) - def inner(*args: "P.args", **kwargs: "P.kwargs") -> Future: - nonlocal last_f - nonlocal future - if last_f is not None: - throttle.triggered.disconnect(last_f) - if future is not None and not future.done(): - future.cancel() - - future = Future() - last_f = lambda: future.set_result(func(*args, **kwargs)) # noqa - throttle.triggered.connect(last_f) - throttle.throttle() - return future - - inner.cancel = throttle.cancel - inner.flush = throttle.flush - inner.set_timeout = throttle.setTimeout - inner.triggered = throttle.triggered - return inner # type: ignore + obj = ThrottledCallable(func, kind, policy) + obj.setTimerType(timer_type) + obj.setTimeout(timeout) + return wraps(func)(obj) return deco(func) if func is not None else deco diff --git a/tests/test_throttler.py b/tests/test_throttler.py index f0c9daa..577a482 100644 --- a/tests/test_throttler.py +++ b/tests/test_throttler.py @@ -1,5 +1,8 @@ from unittest.mock import Mock +import pytest +from qtpy.QtCore import QObject, Signal + from superqt.utils import qdebounced, qthrottled @@ -41,3 +44,29 @@ def test_throttled(qtbot): qtbot.wait(5) assert mock1.call_count == 2 assert mock2.call_count == 10 + + +@pytest.mark.parametrize("deco", [qthrottled, qdebounced]) +def test_ensure_throttled_sig_inspection(deco, qtbot): + mock = Mock() + + class Emitter(QObject): + sig = Signal(int, int, int) + + @deco + def func(a: int, b: int): + """docstring""" + mock(a, b) + + obj = Emitter() + obj.sig.connect(func) + + # this is the crux of the test... + # we emit 3 args, but the function only takes 2 + # this should normally work fine in Qt. + # testing here that the decorator doesn't break it. + with qtbot.waitSignal(func.triggered, timeout=1000): + obj.sig.emit(1, 2, 3) + mock.assert_called_once_with(1, 2) + assert func.__doc__ == "docstring" + assert func.__name__ == "func"