diff --git a/src/superqt/utils/_throttler.py b/src/superqt/utils/_throttler.py index d5f69b9..925433c 100644 --- a/src/superqt/utils/_throttler.py +++ b/src/superqt/utils/_throttler.py @@ -32,6 +32,7 @@ from concurrent.futures import Future from enum import IntFlag, auto from functools import wraps from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload +from weakref import WeakKeyDictionary from qtpy.QtCore import QObject, Qt, QTimer, Signal @@ -202,18 +203,26 @@ class ThrottledCallable(GenericSignalThrottler, Generic[P, R]): super().__init__(kind, emissionPolicy, parent) self._future: Future[R] = Future() + if isinstance(func, staticmethod): + self._func = func.__func__ + else: + self._func = func + self.__wrapped__ = func self._args: tuple = () self._kwargs: dict = {} self.triggered.connect(self._set_future_result) + self._name = None + + self._obj_dkt = WeakKeyDictionary() # even if we were to compile __call__ with a signature matching that of func, # PySide wouldn't correctly inspect the signature of the ThrottledCallable # instance: https://bugreports.qt.io/browse/PYSIDE-2423 # so we do it ourselfs and limit the number of positional arguments # that we pass to func - self._max_args: int | None = get_max_args(func) + self._max_args: int | None = get_max_args(self._func) def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Future[R]": # noqa if not self._future.done(): @@ -227,9 +236,45 @@ class ThrottledCallable(GenericSignalThrottler, Generic[P, R]): return self._future def _set_future_result(self): - result = self.__wrapped__(*self._args[: self._max_args], **self._kwargs) + result = self._func(*self._args[: self._max_args], **self._kwargs) self._future.set_result(result) + def __set_name__(self, owner, name): + if not isinstance(self.__wrapped__, staticmethod): + self._name = name + + def _get_throttler(self, instance, owner, parent, obj): + throttler = ThrottledCallable( + self.__wrapped__.__get__(instance, owner), + self._kind, + self._emissionPolicy, + parent=parent, + ) + throttler.setTimerType(self.timerType()) + throttler.setTimeout(self.timeout()) + try: + setattr( + obj, + self._name, + throttler, + ) + except AttributeError: + self._obj_dkt[obj] = throttler + return throttler + + def __get__(self, instance, owner): + if instance is None or not self._name: + return self + + if instance in self._obj_dkt: + return self._obj_dkt[instance] + + parent = self.parent() + if parent is None and isinstance(instance, QObject): + parent = instance + + return self._get_throttler(instance, owner, parent, instance) + @overload def qthrottled( @@ -237,6 +282,7 @@ def qthrottled( timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> ThrottledCallable[P, R]: ... @@ -247,6 +293,7 @@ def qthrottled( timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]: ... @@ -256,6 +303,7 @@ def qthrottled( timeout: int = 100, leading: bool = True, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: """Creates a throttled function that invokes func at most once per timeout. @@ -284,8 +332,11 @@ def qthrottled( - `Qt.CoarseTimer`: Coarse timers try to keep accuracy within 5% of the desired interval - `Qt.VeryCoarseTimer`: Very coarse timers only keep full second accuracy + parent: QObject or None + Parent object for timer. If using qthrottled as function it may be usefull + for cleaning data """ - return _make_decorator(func, timeout, leading, timer_type, Kind.Throttler) + return _make_decorator(func, timeout, leading, timer_type, Kind.Throttler, parent) @overload @@ -294,6 +345,7 @@ def qdebounced( timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> ThrottledCallable[P, R]: ... @@ -304,6 +356,7 @@ def qdebounced( timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]: ... @@ -313,6 +366,7 @@ def qdebounced( timeout: int = 100, leading: bool = False, timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer, + parent: QObject | None = None, ) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: """Creates a debounced function that delays invoking `func`. @@ -344,8 +398,11 @@ def qdebounced( - `Qt.CoarseTimer`: Coarse timers try to keep accuracy within 5% of the desired interval - `Qt.VeryCoarseTimer`: Very coarse timers only keep full second accuracy + parent: QObject or None + Parent object for timer. If using qthrottled as function it may be usefull + for cleaning data """ - return _make_decorator(func, timeout, leading, timer_type, Kind.Debouncer) + return _make_decorator(func, timeout, leading, timer_type, Kind.Debouncer, parent) def _make_decorator( @@ -354,10 +411,16 @@ def _make_decorator( leading: bool, timer_type: Qt.TimerType, kind: Kind, + parent: QObject | None = None, ) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]: def deco(func: Callable[P, R]) -> ThrottledCallable[P, R]: + nonlocal parent + + instance: object | None = getattr(func, "__self__", None) + if isinstance(instance, QObject) and parent is None: + parent = instance policy = EmissionPolicy.Leading if leading else EmissionPolicy.Trailing - obj = ThrottledCallable(func, kind, policy) + obj = ThrottledCallable(func, kind, policy, parent=parent) obj.setTimerType(timer_type) obj.setTimeout(timeout) return wraps(func)(obj) diff --git a/tests/test_throttler.py b/tests/test_throttler.py index 577a482..884d98e 100644 --- a/tests/test_throttler.py +++ b/tests/test_throttler.py @@ -4,6 +4,7 @@ import pytest from qtpy.QtCore import QObject, Signal from superqt.utils import qdebounced, qthrottled +from superqt.utils._throttler import ThrottledCallable def test_debounced(qtbot): @@ -26,6 +27,66 @@ def test_debounced(qtbot): assert mock2.call_count == 10 +def test_debouncer_method(qtbot): + class A(QObject): + def __init__(self): + super().__init__() + self.count = 0 + + def callback(self): + self.count += 1 + + a = A() + assert all(not isinstance(x, ThrottledCallable) for x in a.children()) + b = qdebounced(a.callback, timeout=4) + assert any(isinstance(x, ThrottledCallable) for x in a.children()) + for _ in range(10): + b() + + qtbot.wait(5) + + assert a.count == 1 + + +def test_debouncer_method_definition(qtbot): + mock1 = Mock() + mock2 = Mock() + + class A(QObject): + def __init__(self): + super().__init__() + self.count = 0 + + @qdebounced(timeout=4) + def callback(self): + self.count += 1 + + @qdebounced(timeout=4) + @staticmethod + def call1(): + mock1() + + @staticmethod + @qdebounced(timeout=4) + def call2(): + mock2() + + a = A() + assert all(not isinstance(x, ThrottledCallable) for x in a.children()) + for _ in range(10): + a.callback(1) + A.call1(34) + a.call1(22) + a.call2(22) + A.call2(32) + + qtbot.wait(5) + + assert a.count == 1 + mock1.assert_called_once() + mock2.assert_called_once() + + def test_throttled(qtbot): mock1 = Mock() mock2 = Mock()