From 396324f22c163ee23fcb43b927c5991c2986cb19 Mon Sep 17 00:00:00 2001
From: tstrilka
Date: Mon, 12 Jan 2026 16:06:30 +0100
Subject: [PATCH] Add on_task_instance_skipped support to OpenLineage listener

---
 .../providers/openlineage/plugins/listener.py | 136 +++++++++++++++++-
 .../unit/openlineage/plugins/test_listener.py |  88 ++++++++++++
 2 files changed, 223 insertions(+), 1 deletion(-)

diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
index bc94294d77020..a827a8a01e053 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py
@@ -512,6 +512,140 @@ def on_failure():
 
         self._execute(on_failure, "on_failure", use_fork=True)
 
+    if AIRFLOW_V_3_0_PLUS:
+
+        @hookimpl
+        def on_task_instance_skipped(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: RuntimeTaskInstance | TaskInstance,
+        ) -> None:
+            """Report a skipped task instance to OpenLineage as a COMPLETE run.
+
+            A ``TaskInstance`` model is delegated to ``_on_task_instance_manual_state_change``;
+            a ``RuntimeTaskInstance`` is resolved through its template context and passed
+            to ``_on_task_instance_skipped``.
+            """
+            self.log.debug("OpenLineage listener got notification about task instance skip")
+
+            if isinstance(task_instance, TaskInstance):
+                self._on_task_instance_manual_state_change(
+                    ti=task_instance,
+                    dagrun=task_instance.dag_run,
+                    ti_state=TaskInstanceState.SKIPPED,
+                )
+                return
+
+            context = task_instance.get_template_context()
+            task = context["task"]
+            if TYPE_CHECKING:
+                assert task
+            dagrun = context["dag_run"]
+            dag = context["dag"]
+            self._on_task_instance_skipped(task_instance, dag, dagrun, task)
+
+    def _on_task_instance_skipped(
+        self,
+        task_instance: RuntimeTaskInstance,
+        dag,
+        dagrun,
+        task,
+    ) -> None:
+        """Build the COMPLETE event for a skipped runtime task instance via ``adapter.complete_task``.
+
+        Nothing is reported when the operator is disabled for OpenLineage or when
+        selective lineage enablement excludes the task.
+        """
+        end_date = timezone.utcnow()
+
+        if is_operator_disabled(task):
+            self.log.debug(
+                "Skipping OpenLineage event emission for operator `%s` "
+                "due to its presence in [openlineage] disabled_for_operators.",
+                task.task_type,
+            )
+            return
+
+        if not is_selective_lineage_enabled(task):
+            self.log.debug(
+                "Skipping OpenLineage event emission for task `%s` "
+                "due to lack of explicit lineage enablement for task or DAG while "
+                "[openlineage] selective_enable is on.",
+                task_instance.task_id,
+            )
+            return
+
+        @print_warning(self.log)
+        def on_skipped():
+            date = dagrun.logical_date
+            if AIRFLOW_V_3_0_PLUS and date is None:
+                date = dagrun.run_after
+
+            parent_run_id = self.adapter.build_dag_run_id(
+                dag_id=task_instance.dag_id,
+                logical_date=date,
+                clear_number=dagrun.clear_number,
+            )
+
+            task_uuid = self.adapter.build_task_instance_run_id(
+                dag_id=task_instance.dag_id,
+                task_id=task_instance.task_id,
+                try_number=task_instance.try_number,
+                logical_date=date,
+                map_index=task_instance.map_index,
+            )
+            event_type = RunState.COMPLETE.value.lower()
+            operator_name = task.task_type.lower()
+
+            data_interval_start = dagrun.data_interval_start
+            if isinstance(data_interval_start, datetime):
+                data_interval_start = data_interval_start.isoformat()
+            data_interval_end = dagrun.data_interval_end
+            if isinstance(data_interval_end, datetime):
+                data_interval_end = data_interval_end.isoformat()
+
+            doc, doc_type = get_task_documentation(task)
+            if not doc:
+                doc, doc_type = get_dag_documentation(dag)
+
+            with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
+                task_metadata = self.extractor_manager.extract_metadata(
+                    dagrun=dagrun,
+                    task=task,
+                    task_instance_state=TaskInstanceState.SKIPPED,
+                    task_instance=task_instance,
+                )
+
+            redacted_event = self.adapter.complete_task(
+                run_id=task_uuid,
+                job_name=get_job_name(task_instance),
+                end_time=end_date.isoformat(),
+                task=task_metadata,
+                # If task owner is default ("airflow"), use DAG owner instead that may have more details
+                owners=[x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")],
+                tags=dag.tags,
+                job_description=doc,
+                job_description_type=doc_type,
+                nominal_start_time=data_interval_start,
+                nominal_end_time=data_interval_end,
+                run_facets={
+                    **get_user_provided_run_facets(task_instance, TaskInstanceState.SKIPPED),
+                    **get_task_parent_run_facet(
+                        parent_run_id=parent_run_id,
+                        parent_job_name=dag.dag_id,
+                        dr_conf=getattr(dagrun, "conf", {}),
+                    ),
+                    **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid),
+                    **get_airflow_debug_facet(),
+                },
+            )
+            Stats.gauge(
+                f"ol.event.size.{event_type}.{operator_name}",
+                len(Serde.to_json(redacted_event).encode("utf-8")),
+            )
+
+        self._execute(on_skipped, "on_skipped", use_fork=True)
+
     def _on_task_instance_manual_state_change(
         self,
         ti: TaskInstance,
@@ -563,7 +697,7 @@ def on_state_change():
             if ti_state == TaskInstanceState.FAILED:
                 event_type = RunState.FAIL.value.lower()
                 redacted_event = self.adapter.fail_task(**adapter_kwargs, error=error)
-            elif ti_state == TaskInstanceState.SUCCESS:
+            elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED):
                 event_type = RunState.COMPLETE.value.lower()
                 redacted_event = self.adapter.complete_task(**adapter_kwargs)
             else:
diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py
index d226abb8c74ef..705775f258f8e 100644
--- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py
+++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py
@@ -1797,6 +1797,94 @@ def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_ope
         listener.extractor_manager.extract_metadata.assert_not_called()
         listener.adapter.complete_task.assert_not_called()
 
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
+    )
+    def test_on_task_instance_skipped_correctly_calls_openlineage_adapter_run_id_method(self):
+        """Tests the OpenLineageListener's response when a task instance is skipped.
+
+        This test ensures that when an Airflow task instance is skipped via AirflowSkipException,
+        the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct
+        parameters derived from the task instance.
+        """
+        listener, task_instance = self._create_listener_and_task_instance()
+        listener.on_task_instance_skipped(previous_state=None, task_instance=task_instance)
+        listener.adapter.build_task_instance_run_id.assert_called_once_with(
+            dag_id="dag_id",
+            task_id="task_id",
+            logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1),
+            try_number=1,
+            map_index=-1,
+        )
+
+    @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
+    def test_listener_on_task_instance_skipped_do_not_call_adapter_when_disabled_operator(
+        self, mock_get_user_provided_run_facets, mock_disabled
+    ):
+        listener, task_instance = self._create_listener_and_task_instance()
+        mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
+        mock_disabled.return_value = True
+
+        listener.on_task_instance_skipped(previous_state=None, task_instance=task_instance)
+        mock_disabled.assert_called_once_with(task_instance.task)
+        listener.adapter.build_dag_run_id.assert_not_called()
+        listener.adapter.build_task_instance_run_id.assert_not_called()
+        listener.extractor_manager.extract_metadata.assert_not_called()
+        listener.adapter.complete_task.assert_not_called()
+
+    @mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
+    @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True)
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
+    @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
+    @mock.patch(
+        "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
+    )
+    def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model_on_skip(
+        self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine
+    ):
+        """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments.
+
+        This particular test is using TaskInstance model available on API Server and not on worker,
+        to simulate the listener being called after task's state has been manually set to SKIPPED via API.
+        """
+        time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)
+
+        listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
+        mock_get_task_parent_run_facet.return_value = {"parent": 4}
+        mock_debug_facet.return_value = {"debug": "packages"}
+
+        listener.on_task_instance_skipped(previous_state=None, task_instance=task_instance)
+        calls = listener.adapter.complete_task.call_args_list
+        assert len(calls) == 1
+        mock_get_task_parent_run_facet.assert_called_once_with(
+            parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0",
+            parent_job_name=task_instance.dag_id,
+            dr_conf={},
+        )
+        expected_args = dict(
+            end_time="2023-01-03T13:01:01+00:00",
+            job_name="dag_id.task_id",
+            run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1",
+            task=OperatorLineage(),
+            nominal_start_time=None,
+            nominal_end_time=None,
+            tags=None,
+            owners=None,
+            job_description=None,
+            job_description_type=None,
+            run_facets={
+                "parent": 4,
+                "debug": "packages",
+            },
+        )
+        assert calls[0][1] == expected_args
+
+        expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e"
+        adapter = OpenLineageAdapter()
+        adapter.complete_task(**expected_args)
+        mock_emit.assert_called_once()
+
     @pytest.mark.parametrize(
         ("max_workers", "expected"),
         [