Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,22 @@
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy import (
CTE,
Text,
and_,
case,
cast as sql_cast,
delete,
exists,
func,
inspect,
or_,
select,
text,
tuple_,
update,
)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
Expand Down Expand Up @@ -941,9 +956,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
opt_in_names.add(exc.name.module_path)
whens = []
if opt_in_names:
whens.append((TI.executor.in_(opt_in_names), random_db_uuid()))
whens.append((TI.executor.in_(opt_in_names), sql_cast(random_db_uuid(), Text)))
if default_opts_in:
whens.append((TI.executor.is_(None), random_db_uuid()))
whens.append((TI.executor.is_(None), sql_cast(random_db_uuid(), Text)))
if whens:
queued_values["external_executor_id"] = case(*whens, else_=TI.external_executor_id)

Expand Down
38 changes: 29 additions & 9 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,31 +2788,51 @@ def test_executable_task_instances_to_queued_sets_external_executor_id(self, dag
dag_id = "SchedulerJobTest.test_executable_sets_external_executor_id"
session = settings.Session()
with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, session=session):
EmptyOperator(task_id="dummy")
EmptyOperator(task_id="a_task_pre_assign")
EmptyOperator(task_id="b_task_regular")

class PreAssigningExecutor(MockExecutor):
pre_assigns_external_executor_id = True
mock_module_path = "mock.pre_assigning.executor"
mock_alias = "pre_assigning_executor"

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[PreAssigningExecutor()])
regular_exec = MockExecutor()
assert regular_exec.pre_assigns_external_executor_id is False, "Pre-condition"

pre_assigning_exec = PreAssigningExecutor()

self.job_runner = SchedulerJobRunner(job=Job(), executors=(regular_exec, pre_assigning_exec))

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("dummy", session)
ti.state = State.SCHEDULED
session.merge(ti)
ti_pre_assign = dr.get_task_instance("a_task_pre_assign", session)
ti_regular = dr.get_task_instance("b_task_regular", session)

ti_regular.state = State.SCHEDULED
ti_regular.executor = regular_exec.name.module_path
ti_pre_assign.state = State.SCHEDULED
ti_pre_assign.executor = pre_assigning_exec.name.module_path
session.flush()

returned_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
returned_tis.sort(key=lambda ti: ti.task_id)

assert len(returned_tis) == 2

assert len(returned_tis) == 1
# In-memory object (post make_transient) should carry the UUID
assert returned_tis[0].id == ti_pre_assign.id
assert returned_tis[0].external_executor_id is not None
UUID(returned_tis[0].external_executor_id)
assert UUID(returned_tis[0].external_executor_id), "is valid uuid"

# DB row should also have it (the whole point — survives a crash)
db_value = session.scalar(select(TaskInstance.external_executor_id).where(TaskInstance.id == ti.id))
db_value = session.scalar(
select(TaskInstance.external_executor_id).where(TaskInstance.id == ti_pre_assign.id)
)
assert db_value == returned_tis[0].external_executor_id

# In mixed-executor mode, only TIs routed to a pre-assigning executor get an external_executor_id.
assert returned_tis[1].id == ti_regular.id
assert returned_tis[1].external_executor_id is None

session.rollback()

@pytest.mark.parametrize("state", [State.FAILED, State.SUCCESS])
Expand Down
Loading