diff --git a/taskiq/cli/scheduler/args.py b/taskiq/cli/scheduler/args.py index 8f82b89f..bbf1fb82 100644 --- a/taskiq/cli/scheduler/args.py +++ b/taskiq/cli/scheduler/args.py @@ -20,6 +20,7 @@ class SchedulerArgs: skip_first_run: bool = False update_interval: int | None = None loop_interval: int | None = None + send_timeout: float | None = None @classmethod def from_cli(cls, args: Sequence[str] | None = None) -> "SchedulerArgs": @@ -111,6 +112,19 @@ def from_cli(cls, args: Sequence[str] | None = None) -> "SchedulerArgs": "If not specified, scheduler will run once a second." ), ) + parser.add_argument( + "--send-timeout", + type=float, + default=None, + help=( + "Optional per-send timeout (seconds). If set, each spawned send " + "task is wrapped in asyncio.wait_for with this timeout, " + "preventing a single hung broker.kick from permanently blocking " + "subsequent ticks for the same schedule_id. On timeout the send " + "is cancelled, a warning is logged, and the next scheduled tick " + "will retry. Default: no timeout (backwards-compatible)." + ), + ) namespace = parser.parse_args(args) # If there are any patterns specified, remove default. diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index 60390538..a2ab7720 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -174,6 +174,38 @@ async def send( await scheduler.on_ready(source, task) +async def send_with_timeout( + scheduler: TaskiqScheduler, + source: ScheduleSource, + task: ScheduledTask, + timeout: float, +) -> None: + """ + Send a task, cancelling it if it does not complete within ``timeout`` seconds. + + A timed-out send is logged at WARNING and swallowed — re-raising would propagate + out of the scheduler's main loop's ``add_done_callback``, which is not what we + want. Suppressing it lets the done_callback clear ``running_schedules`` so the + next cron boundary can re-dispatch the task. The slot is freed; the message is + dropped (the broker did not acknowledge it within the budget). + + :param scheduler: current scheduler. + :param source: source of the task. + :param task: task to send. + :param timeout: seconds to wait before cancelling the send. Must be > 0. + """ + try: + await asyncio.wait_for(send(scheduler, source, task), timeout=timeout) + except asyncio.TimeoutError: + logger.warning( + "Sending task %s with schedule_id %s timed out after %.1fs " + "and was cancelled. The next scheduled tick will retry.", + task.task_name, + task.schedule_id, + timeout, + ) + + async def _sleep_until_next_second() -> None: now = datetime.now(tz=timezone.utc) await asyncio.sleep(1 - now.microsecond / 1_000_000) @@ -292,6 +324,7 @@ async def run( update_interval: timedelta | None = None, loop_interval: timedelta | None = None, skip_first_run: bool = False, + send_timeout: float | None = None, ) -> None: """ Runs scheduler loop. @@ -303,11 +336,19 @@ async def run( :param loop_interval: interval to check tasks to send. :param skip_first_run: Wait for the beginning of the next minute to skip the first run. + :param send_timeout: optional per-send timeout (seconds). If set, each + spawned send task is wrapped in :func:`asyncio.wait_for` with this + timeout, preventing a single hung ``broker.kick`` from permanently + blocking subsequent ticks of the same ``schedule_id`` via the + ``running_schedules`` skip check. Default ``None`` (no timeout — + backwards-compatible behavior). """ if update_interval is None: update_interval = timedelta(minutes=1) if loop_interval is None: loop_interval = timedelta(seconds=1) + if send_timeout is not None and send_timeout <= 0: + raise ValueError("send_timeout must be > 0 when provided") running_schedules: dict[ScheduleId, asyncio.Task[Any]] = {} @@ -335,8 +376,17 @@ async def run( ) if is_ready_to_send and task.schedule_id not in running_schedules: + if send_timeout is not None: + send_coro = send_with_timeout( + self.scheduler, + source, + task, + timeout=send_timeout, + ) + else: + send_coro = send(self.scheduler, source, task) send_task = self._event_loop.create_task( - send(self.scheduler, source, task), + send_coro, # We need to set the name of the task # to be able to discard its reference # after it is done. @@ -412,6 +462,7 @@ async def run_scheduler(args: SchedulerArgs) -> None: update_interval=update_interval, loop_interval=loop_interval, skip_first_run=args.skip_first_run, + send_timeout=args.send_timeout, ) except asyncio.CancelledError: logger.info("Shutting down scheduler.") diff --git a/tests/cli/scheduler/test_send_with_timeout.py b/tests/cli/scheduler/test_send_with_timeout.py new file mode 100644 index 00000000..f97c94bc --- /dev/null +++ b/tests/cli/scheduler/test_send_with_timeout.py @@ -0,0 +1,198 @@ +import asyncio +import logging +from typing import Any + +import pytest + +from taskiq.abc.schedule_source import ScheduleSource +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.cli.scheduler.run import send, send_with_timeout +from taskiq.scheduler.scheduled_task import ScheduledTask +from taskiq.scheduler.scheduler import TaskiqScheduler + + +class _StubSource(ScheduleSource): + async def get_schedules(self) -> list[ScheduledTask]: + return [] + + +class _HangingScheduler(TaskiqScheduler): + """Test double whose ``on_ready`` blocks longer than any reasonable timeout. + + Models the production failure mode: ``broker.kick`` (deep inside ``on_ready``) + hangs on a stalled Redis socket and never returns. The scheduler's spawned + ``send`` task would normally retain its ``running_schedules`` entry forever, + silently skipping every subsequent tick. + """ + + def __init__(self, hang_seconds: float = 60.0) -> None: + super().__init__(broker=InMemoryBroker(), sources=[_StubSource()]) + self.hang_seconds = hang_seconds + self.on_ready_calls = 0 + + async def on_ready( # type: ignore[override] + self, source: ScheduleSource, task: ScheduledTask + ) -> None: + self.on_ready_calls += 1 + await asyncio.sleep(self.hang_seconds) + + +def _task() -> ScheduledTask: + return ScheduledTask( + task_name="dummy", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + schedule_id="dummy-schedule-id", + ) + + +@pytest.mark.anyio +async def test_send_with_timeout_returns_on_timeout_without_raising( + caplog: pytest.LogCaptureFixture, +) -> None: + """A send that exceeds the timeout must be cancelled cleanly and NOT raise. + + Raising out of the spawned task would propagate to the + ``add_done_callback``, but the callback only cares that the task completed; + we still want the entry removed from ``running_schedules`` so the next + cron boundary can retry. Returning normally is the right shape. + """ + caplog.set_level(logging.WARNING, logger="taskiq.cli.scheduler.run") + + scheduler = _HangingScheduler(hang_seconds=30.0) + task = _task() + source = _StubSource() + + # Must complete promptly (within ~0.1s) and NOT raise. + await asyncio.wait_for( + send_with_timeout(scheduler, source, task, timeout=0.05), + timeout=2.0, + ) + + assert scheduler.on_ready_calls == 1 + + warnings = [ + rec + for rec in caplog.records + if rec.name == "taskiq.cli.scheduler.run" and rec.levelno >= logging.WARNING + ] + assert len(warnings) == 1 + msg = warnings[0].getMessage() + assert "timed out" in msg + assert "dummy-schedule-id" in msg + assert "dummy" in msg + + +@pytest.mark.anyio +async def test_send_with_timeout_does_not_log_on_success( + caplog: pytest.LogCaptureFixture, +) -> None: + """A send that completes well within the timeout must not produce a warning.""" + caplog.set_level(logging.WARNING, logger="taskiq.cli.scheduler.run") + + class _FastScheduler(TaskiqScheduler): + def __init__(self) -> None: + super().__init__(broker=InMemoryBroker(), sources=[_StubSource()]) + self.calls = 0 + + async def on_ready( # type: ignore[override] + self, source: ScheduleSource, task: ScheduledTask + ) -> None: + self.calls += 1 + + scheduler = _FastScheduler() + task = _task() + source = _StubSource() + + await send_with_timeout(scheduler, source, task, timeout=5.0) + + assert scheduler.calls == 1 + warnings = [ + rec + for rec in caplog.records + if rec.name == "taskiq.cli.scheduler.run" and rec.levelno >= logging.WARNING + ] + assert warnings == [] + + +@pytest.mark.anyio +async def test_send_with_timeout_propagates_non_timeout_exceptions() -> None: + """Errors inside on_ready that AREN'T a timeout must still propagate. + + The wrapper only swallows asyncio.TimeoutError. Other exceptions (e.g., + SendTaskError raised by a broker) need to surface to the existing logging + and observability surfaces. + """ + + class _BoomScheduler(TaskiqScheduler): + def __init__(self) -> None: + super().__init__(broker=InMemoryBroker(), sources=[_StubSource()]) + + async def on_ready( # type: ignore[override] + self, source: ScheduleSource, task: ScheduledTask + ) -> None: + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + await send_with_timeout(_BoomScheduler(), _StubSource(), _task(), timeout=5.0) + + +@pytest.mark.anyio +async def test_send_with_timeout_cancels_inner_send() -> None: + """The inner ``send`` coroutine must actually be cancelled on timeout. + + Without this guarantee, a hung send could continue executing in the + background and double-fire when the next tick spawns a fresh send. + """ + cancelled = asyncio.Event() + + class _CancelObservingScheduler(TaskiqScheduler): + def __init__(self) -> None: + super().__init__(broker=InMemoryBroker(), sources=[_StubSource()]) + + async def on_ready( # type: ignore[override] + self, source: ScheduleSource, task: ScheduledTask + ) -> None: + try: + await asyncio.sleep(30.0) + except asyncio.CancelledError: + cancelled.set() + raise + + await send_with_timeout( + _CancelObservingScheduler(), + _StubSource(), + _task(), + timeout=0.05, + ) + # Give the event loop a chance to deliver the cancellation completion. + await asyncio.wait_for(cancelled.wait(), timeout=2.0) + + +@pytest.mark.anyio +async def test_plain_send_still_works_unchanged() -> None: + """The original ``send`` function must remain unchanged in behavior. + + The timeout path is opt-in via ``send_with_timeout`` only. + """ + + class _RecordingScheduler(TaskiqScheduler): + def __init__(self) -> None: + super().__init__(broker=InMemoryBroker(), sources=[_StubSource()]) + self.calls: list[Any] = [] + + async def on_ready( # type: ignore[override] + self, source: ScheduleSource, task: ScheduledTask + ) -> None: + self.calls.append((source, task)) + + scheduler = _RecordingScheduler() + task = _task() + source = _StubSource() + + await send(scheduler, source, task) + + assert len(scheduler.calls) == 1 + assert scheduler.calls[0][1].schedule_id == "dummy-schedule-id"