Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions taskiq/cli/scheduler/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SchedulerArgs:
skip_first_run: bool = False
update_interval: int | None = None
loop_interval: int | None = None
send_timeout: float | None = None

@classmethod
def from_cli(cls, args: Sequence[str] | None = None) -> "SchedulerArgs":
Expand Down Expand Up @@ -111,6 +112,19 @@ def from_cli(cls, args: Sequence[str] | None = None) -> "SchedulerArgs":
"If not specified, scheduler will run once a second."
),
)
parser.add_argument(
"--send-timeout",
type=float,
default=None,
help=(
"Optional per-send timeout (seconds). If set, each spawned send "
"task is wrapped in asyncio.wait_for with this timeout, "
"preventing a single hung broker.kick from permanently blocking "
"subsequent ticks for the same schedule_id. On timeout the send "
"is cancelled, a warning is logged, and the next scheduled tick "
"will retry. Default: no timeout (backwards-compatible)."
),
)

namespace = parser.parse_args(args)
# If there are any patterns specified, remove default.
Expand Down
53 changes: 52 additions & 1 deletion taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,38 @@ async def send(
await scheduler.on_ready(source, task)


async def send_with_timeout(
    scheduler: TaskiqScheduler,
    source: ScheduleSource,
    task: ScheduledTask,
    timeout: float,
) -> None:
    """
    Send a task, cancelling the send if it exceeds ``timeout`` seconds.

    The send runs in its own asyncio task so that, when a ``TimeoutError``
    surfaces, we can tell the two possible origins apart:

    * the deadline expired: ``asyncio.wait_for`` cancelled the send task.
      This is logged at WARNING and swallowed, so this coroutine returns
      normally and its caller sees an ordinary completion;
    * the send itself raised ``TimeoutError`` (on Python 3.11+
      ``asyncio.TimeoutError`` is the builtin ``TimeoutError``, which socket
      operations may raise). That is a genuine send failure and is re-raised,
      exactly like any other exception coming out of the send.

    A send cancelled by the deadline may or may not have reached the broker;
    this function makes no attempt to find out.

    :param scheduler: current scheduler.
    :param source: source of the task.
    :param task: task to send.
    :param timeout: seconds to wait before cancelling the send. Expected to
        be > 0; this function does not validate it.
    :raises Exception: whatever the underlying send raises, including a
        ``TimeoutError`` that did not originate from this wrapper's deadline.
    """
    send_task = asyncio.create_task(send(scheduler, source, task))
    try:
        await asyncio.wait_for(send_task, timeout=timeout)
    except asyncio.TimeoutError:
        # wait_for cancels the awaited task when the deadline expires. If the
        # task finished on its own (done, but not cancelled), the TimeoutError
        # came from inside the send and must not be hidden.
        if send_task.done() and not send_task.cancelled():
            raise
        # %g keeps sub-second budgets readable (0.05 -> "0.05s"), whereas a
        # fixed one-decimal format would print "0.1s" or even "0.0s".
        logger.warning(
            "Sending task %s with schedule_id %s timed out after %gs "
            "and was cancelled. The next scheduled tick will retry.",
            task.task_name,
            task.schedule_id,
            timeout,
        )


async def _sleep_until_next_second() -> None:
    """Sleep for whatever is left of the current wall-clock second (UTC)."""
    # Only the sub-second part of "now" matters: one second minus the
    # microseconds already elapsed is the time left to the next boundary.
    now = datetime.now(tz=timezone.utc)
    await asyncio.sleep(1 - now.microsecond / 1_000_000)
Expand Down Expand Up @@ -292,6 +324,7 @@ async def run(
update_interval: timedelta | None = None,
loop_interval: timedelta | None = None,
skip_first_run: bool = False,
send_timeout: float | None = None,
) -> None:
"""
Runs scheduler loop.
Expand All @@ -303,11 +336,19 @@ async def run(
:param loop_interval: interval to check tasks to send.
:param skip_first_run: Wait for the beginning of the next minute
to skip the first run.
:param send_timeout: optional per-send timeout (seconds). If set, each
spawned send task is wrapped in :func:`asyncio.wait_for` with this
timeout, preventing a single hung ``broker.kick`` from permanently
blocking subsequent ticks of the same ``schedule_id`` via the
``running_schedules`` skip check. Default ``None`` (no timeout —
backwards-compatible behavior).
"""
if update_interval is None:
update_interval = timedelta(minutes=1)
if loop_interval is None:
loop_interval = timedelta(seconds=1)
if send_timeout is not None and send_timeout <= 0:
raise ValueError("send_timeout must be > 0 when provided")

running_schedules: dict[ScheduleId, asyncio.Task[Any]] = {}

Expand Down Expand Up @@ -335,8 +376,17 @@ async def run(
)

if is_ready_to_send and task.schedule_id not in running_schedules:
if send_timeout is not None:
send_coro = send_with_timeout(
self.scheduler,
source,
task,
timeout=send_timeout,
)
else:
send_coro = send(self.scheduler, source, task)
send_task = self._event_loop.create_task(
send(self.scheduler, source, task),
send_coro,
# We need to set the name of the task
# to be able to discard its reference
# after it is done.
Expand Down Expand Up @@ -412,6 +462,7 @@ async def run_scheduler(args: SchedulerArgs) -> None:
update_interval=update_interval,
loop_interval=loop_interval,
skip_first_run=args.skip_first_run,
send_timeout=args.send_timeout,
)
except asyncio.CancelledError:
logger.info("Shutting down scheduler.")
Expand Down
198 changes: 198 additions & 0 deletions tests/cli/scheduler/test_send_with_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import asyncio
import logging
from typing import Any

import pytest

from taskiq.abc.schedule_source import ScheduleSource
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.cli.scheduler.run import send, send_with_timeout
from taskiq.scheduler.scheduled_task import ScheduledTask
from taskiq.scheduler.scheduler import TaskiqScheduler


class _StubSource(ScheduleSource):
    """Schedule source that never provides any schedules."""

    async def get_schedules(self) -> list[ScheduledTask]:
        # A fresh empty list on every call; callers get nothing to schedule.
        no_schedules: list[ScheduledTask] = []
        return no_schedules


class _HangingScheduler(TaskiqScheduler):
    """Scheduler double whose ``on_ready`` sleeps for ``hang_seconds``.

    It stands in for a send that does not return in a reasonable time (for
    instance a ``broker.kick`` stuck on a stalled connection), so tests can
    exercise the timeout path. ``on_ready_calls`` counts how many times
    ``on_ready`` was entered.
    """

    def __init__(self, hang_seconds: float = 60.0) -> None:
        broker = InMemoryBroker()
        sources = [_StubSource()]
        super().__init__(broker=broker, sources=sources)
        self.hang_seconds = hang_seconds
        self.on_ready_calls = 0

    async def on_ready(  # type: ignore[override]
        self,
        source: ScheduleSource,
        task: ScheduledTask,
    ) -> None:
        # Count the call before blocking so the counter is visible even if
        # the sleep below is cancelled.
        self.on_ready_calls = self.on_ready_calls + 1
        await asyncio.sleep(self.hang_seconds)


def _task() -> ScheduledTask:
    """Build the every-minute cron schedule used throughout these tests."""
    fields: dict[str, Any] = {
        "task_name": "dummy",
        "labels": {},
        "args": [],
        "kwargs": {},
        "cron": "* * * * *",
        "schedule_id": "dummy-schedule-id",
    }
    return ScheduledTask(**fields)


@pytest.mark.anyio
async def test_send_with_timeout_returns_on_timeout_without_raising(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A send that exceeds the timeout must be cancelled and must NOT raise.

    ``send_with_timeout`` is expected to return normally after its deadline
    expires and to emit exactly one WARNING that names both the task and its
    schedule id.
    """
    caplog.set_level(logging.WARNING, logger="taskiq.cli.scheduler.run")

    scheduler = _HangingScheduler(hang_seconds=30.0)
    task = _task()
    source = _StubSource()

    # The outer wait_for is only a safety net: if the inner 0.05s timeout is
    # not honoured, the test fails after 2s instead of hanging for 30s.
    await asyncio.wait_for(
        send_with_timeout(scheduler, source, task, timeout=0.05),
        timeout=2.0,
    )

    assert scheduler.on_ready_calls == 1

    warnings = [
        rec
        for rec in caplog.records
        if rec.name == "taskiq.cli.scheduler.run" and rec.levelno >= logging.WARNING
    ]
    assert len(warnings) == 1
    msg = warnings[0].getMessage()
    assert "timed out" in msg
    assert task.schedule_id in msg
    # The task name ("dummy") is a prefix of the schedule id
    # ("dummy-schedule-id"), so a plain substring check would pass even if
    # the name were missing. Strip the id before looking for the name.
    assert task.task_name in msg.replace(task.schedule_id, "")


@pytest.mark.anyio
async def test_send_with_timeout_does_not_log_on_success(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """No warning may be logged when the send finishes inside the timeout."""
    run_logger_name = "taskiq.cli.scheduler.run"
    caplog.set_level(logging.WARNING, logger=run_logger_name)

    class _FastScheduler(TaskiqScheduler):
        """Scheduler whose ``on_ready`` returns immediately and counts calls."""

        def __init__(self) -> None:
            super().__init__(broker=InMemoryBroker(), sources=[_StubSource()])
            self.calls = 0

        async def on_ready(  # type: ignore[override]
            self,
            source: ScheduleSource,
            task: ScheduledTask,
        ) -> None:
            self.calls += 1

    fast_scheduler = _FastScheduler()
    scheduled = _task()
    stub_source = _StubSource()

    await send_with_timeout(fast_scheduler, stub_source, scheduled, timeout=5.0)

    assert fast_scheduler.calls == 1
    # Nothing at WARNING or above may have come from the scheduler's logger.
    assert not any(
        record.name == run_logger_name and record.levelno >= logging.WARNING
        for record in caplog.records
    )


@pytest.mark.anyio
async def test_send_with_timeout_propagates_non_timeout_exceptions() -> None:
    """Failures other than a timeout must escape ``send_with_timeout``.

    Here ``on_ready`` raises ``RuntimeError``; the wrapper is expected to let
    it through to the caller rather than suppress it.
    """

    class _BoomScheduler(TaskiqScheduler):
        """Scheduler whose ``on_ready`` always fails with ``RuntimeError``."""

        def __init__(self) -> None:
            super().__init__(broker=InMemoryBroker(), sources=[_StubSource()])

        async def on_ready(  # type: ignore[override]
            self,
            source: ScheduleSource,
            task: ScheduledTask,
        ) -> None:
            raise RuntimeError("boom")

    with pytest.raises(RuntimeError, match="boom"):
        await send_with_timeout(
            scheduler=_BoomScheduler(),
            source=_StubSource(),
            task=_task(),
            timeout=5.0,
        )


@pytest.mark.anyio
async def test_send_with_timeout_cancels_inner_send() -> None:
    """On timeout the inner send must observe a real cancellation.

    ``on_ready`` records the ``CancelledError`` it receives; the test fails
    if that record never appears, which would mean the send was left running
    in the background after the wrapper returned.
    """
    cancel_seen = asyncio.Event()

    class _CancelObservingScheduler(TaskiqScheduler):
        """Scheduler whose ``on_ready`` flags the cancellation it receives."""

        def __init__(self) -> None:
            super().__init__(broker=InMemoryBroker(), sources=[_StubSource()])

        async def on_ready(  # type: ignore[override]
            self,
            source: ScheduleSource,
            task: ScheduledTask,
        ) -> None:
            try:
                await asyncio.sleep(30.0)
            except asyncio.CancelledError:
                # Record the cancellation, then let it continue to propagate.
                cancel_seen.set()
                raise

    observing_scheduler = _CancelObservingScheduler()
    await send_with_timeout(
        observing_scheduler,
        _StubSource(),
        _task(),
        timeout=0.05,
    )
    # Bounded wait, so a missing cancellation fails the test quickly.
    await asyncio.wait_for(cancel_seen.wait(), timeout=2.0)


@pytest.mark.anyio
async def test_plain_send_still_works_unchanged() -> None:
    """Plain ``send`` must still hand the task to ``on_ready`` exactly once.

    This exercises ``send`` directly, without the timeout wrapper.
    """

    class _RecordingScheduler(TaskiqScheduler):
        """Scheduler that stores every ``(source, task)`` pair it receives."""

        def __init__(self) -> None:
            super().__init__(broker=InMemoryBroker(), sources=[_StubSource()])
            self.calls: list[Any] = []

        async def on_ready(  # type: ignore[override]
            self,
            source: ScheduleSource,
            task: ScheduledTask,
        ) -> None:
            self.calls.append((source, task))

    recorder = _RecordingScheduler()
    scheduled = _task()
    stub_source = _StubSource()

    await send(recorder, stub_source, scheduled)

    assert len(recorder.calls) == 1
    _, delivered = recorder.calls[0]
    assert delivered.schedule_id == "dummy-schedule-id"
Loading