diff --git a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py index b6e537af97d2b..08a39d5e3d857 100644 --- a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py +++ b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py @@ -93,6 +93,12 @@ def date_param(): "dags update --dag-id=example_bash_operator --no-is-paused", # DAG Run commands "dagrun list --dag-id example_bash_operator --state success --limit=1", + # XCom commands - need a DAG run with completed tasks + f'xcom add --dag-id=example_bash_operator --dag-run-id="manual__{ONE_DATE_PARAM}" --task-id=runme_0 --key=test_xcom_key --value=\'{{"test": "value"}}\'', + f'xcom get --dag-id=example_bash_operator --dag-run-id="manual__{ONE_DATE_PARAM}" --task-id=runme_0 --key=test_xcom_key', + f'xcom list --dag-id=example_bash_operator --dag-run-id="manual__{ONE_DATE_PARAM}" --task-id=runme_0', + f'xcom edit --dag-id=example_bash_operator --dag-run-id="manual__{ONE_DATE_PARAM}" --task-id=runme_0 --key=test_xcom_key --value=\'{{"updated": "value"}}\'', + f'xcom delete --dag-id=example_bash_operator --dag-run-id="manual__{ONE_DATE_PARAM}" --task-id=runme_0 --key=test_xcom_key', # Jobs commands "jobs list", # Pools commands diff --git a/airflow-ctl/docs/images/command_hashes.txt b/airflow-ctl/docs/images/command_hashes.txt index c0e3b5995fae9..8a450901218a4 100644 --- a/airflow-ctl/docs/images/command_hashes.txt +++ b/airflow-ctl/docs/images/command_hashes.txt @@ -1,4 +1,4 @@ -main:deacf21c6300eae16afbf8cbd538f1ef +main:65249416abad6ad24c276fb44326ae15 assets:b3ae2b933e54528bf486ff28e887804d auth:f396d4bce90215599dde6ad0a8f30f29 backfill:bbce9859a2d1ce054ad22db92dea8c05 diff --git a/airflow-ctl/docs/images/output_main.svg b/airflow-ctl/docs/images/output_main.svg index f6c7225a4ebc9..8e4ef71bdb016 100644 --- a/airflow-ctl/docs/images/output_main.svg +++ b/airflow-ctl/docs/images/output_main.svg @@ -1,4 +1,4 @@ - + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + + + + - + - + - - Usage:airflowctl [-hGROUP_OR_COMMAND... - -Positional Arguments: -GROUP_OR_COMMAND - -    Groups -assetsPerform Assets operations -authManage authentication for CLI. Either pass -token from environment variable/parameter -or pass username and password. -backfillPerform Backfill operations -configPerform Config operations -connectionsPerform Connections operations -dagrunPerform DagRun operations -dagsPerform Dags operations -jobsPerform Jobs operations -poolsPerform Pools operations -providersPerform Providers operations -variablesPerform Variables operations - -    Commands: -versionShow version information - -Options: --h--helpshow this help message and exit + + Usage:airflowctl [-hGROUP_OR_COMMAND... + +Positional Arguments: +GROUP_OR_COMMAND + +    Groups +assetsPerform Assets operations +authManage authentication for CLI. Either pass token from +environment variable/parameter or pass username and +password. +backfillPerform Backfill operations +configPerform Config operations +connectionsPerform Connections operations +dagrunPerform DagRun operations +dagsPerform Dags operations +jobsPerform Jobs operations +poolsPerform Pools operations +providersPerform Providers operations +variablesPerform Variables operations +xcomPerform XCom operations + +    Commands: +versionShow version information + +Options: +-h--helpshow this help message and exit diff --git a/airflow-ctl/src/airflowctl/api/client.py b/airflow-ctl/src/airflowctl/api/client.py index 9e138ffc87b55..5a719cac164cc 100644 --- a/airflow-ctl/src/airflowctl/api/client.py +++ b/airflow-ctl/src/airflowctl/api/client.py @@ -48,6 +48,7 @@ ServerResponseError, VariablesOperations, VersionOperations, + XComOperations, ) from airflowctl.exceptions import ( AirflowCtlCredentialNotFoundException, @@ -301,6 +302,12 @@ def version(self): """Get the version of the server.""" return VersionOperations(self) + @lru_cache() # type: ignore[prop-decorator] + @property + def xcom(self): + """Operations related to XComs.""" + return XComOperations(self) + # API Client Decorator for CLI Actions @contextlib.contextmanager diff --git a/airflow-ctl/src/airflowctl/api/operations.py b/airflow-ctl/src/airflowctl/api/operations.py index 31e0298a19d82..588d85569e3b5 100644 --- a/airflow-ctl/src/airflowctl/api/operations.py +++ b/airflow-ctl/src/airflowctl/api/operations.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import json from typing import TYPE_CHECKING, Any, TypeVar import httpx @@ -70,6 +71,10 @@ VariableCollectionResponse, VariableResponse, VersionInfo, + XComCollectionResponse, + XComCreateBody, + XComResponseNative, + XComUpdateBody, ) from airflowctl.exceptions import AirflowCtlConnectionException @@ -697,3 +702,125 @@ def get(self) -> VersionInfo | ServerResponseError: return VersionInfo.model_validate_json(self.response.content) except ServerResponseError as e: raise e + + +class XComOperations(BaseOperations): + """XCom operations.""" + + def get( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + map_index: int = None, # type: ignore + ) -> XComResponseNative | ServerResponseError: + """Get an XCom entry.""" + try: + params: dict[str, Any] = {} + if map_index is not None: + params["map_index"] = map_index + self.response = self.client.get( + f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{key}", + params=params, + ) + return XComResponseNative.model_validate_json(self.response.content) + except ServerResponseError as e: + raise e + + def list( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + map_index: int = None, # type: ignore + key: str = None, # type: ignore + ) -> XComCollectionResponse | ServerResponseError: + """List XCom entries.""" + params: dict[str, Any] = {} + if map_index is not None: + params["map_index"] = map_index + if key is not None: + params["xcom_key"] = key + return super().execute_list( + path=f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries", + data_model=XComCollectionResponse, + params=params, + ) + + def add( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + value: str, + map_index: int = None, # type: ignore + ) -> XComResponseNative | ServerResponseError: + """Add an XCom entry.""" + try: + parsed_value = json.loads(value) + except (ValueError, TypeError): + parsed_value = value + + body_dict: dict[str, Any] = {"key": key, "value": parsed_value} + if map_index is not None: + body_dict["map_index"] = map_index + body = XComCreateBody(**body_dict) + try: + self.response = self.client.post( + f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries", + json=body.model_dump(mode="json", exclude_unset=True), + ) + return XComResponseNative.model_validate_json(self.response.content) + except ServerResponseError as e: + raise e + + def edit( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + value: str, + map_index: int = None, # type: ignore + ) -> XComResponseNative | ServerResponseError: + """Edit an XCom entry.""" + try: + parsed_value = json.loads(value) + except (ValueError, TypeError): + parsed_value = value + + body_dict: dict[str, Any] = {"value": parsed_value} + if map_index is not None: + body_dict["map_index"] = map_index + body = XComUpdateBody(**body_dict) + try: + self.response = self.client.patch( + f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{key}", + json=body.model_dump(mode="json", exclude_unset=True), + ) + return XComResponseNative.model_validate_json(self.response.content) + except ServerResponseError as e: + raise e + + def delete( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + key: str, + map_index: int = None, # type: ignore + ) -> str | ServerResponseError: + """Delete an XCom entry.""" + try: + params: dict[str, Any] = {} + if map_index is not None: + params["map_index"] = map_index + self.client.delete( + f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{key}", + params=params, + ) + return key + except ServerResponseError as e: + raise e diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index a578396acd492..467306c19724f 100644 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -381,7 +381,7 @@ def __init__(self, file_path: str | Path | None = None): # Exclude parameters that are not needed for CLI from datamodels self.excluded_parameters = ["schema_"] # This list is used to determine if the command/operation needs to output data - self.output_command_list = ["list", "get", "create", "delete", "update", "trigger"] + self.output_command_list = ["list", "get", "create", "delete", "update", "trigger", "add", "edit"] self.exclude_operation_names = ["LoginOperations", "VersionOperations", "BaseOperations"] self.exclude_method_names = [ "error", diff --git a/airflow-ctl/tests/airflow_ctl/api/test_operations.py b/airflow-ctl/tests/airflow_ctl/api/test_operations.py index f0a638475c480..4e8c0ab75907b 100644 --- a/airflow-ctl/tests/airflow_ctl/api/test_operations.py +++ b/airflow-ctl/tests/airflow_ctl/api/test_operations.py @@ -92,6 +92,9 @@ VariableCollectionResponse, VariableResponse, VersionInfo, + XComCollectionResponse, + XComResponse, + XComResponseNative, ) from airflowctl.api.operations import BaseOperations from airflowctl.exceptions import AirflowCtlConnectionException @@ -1265,3 +1268,337 @@ def handle_request(request: httpx.Request) -> httpx.Response: ) ) assert response.access_token == "NO_TOKEN" + + +class TestXComOperations: + """Test suite for XCom operations.""" + + dag_id: str = "test_dag" + dag_run_id: str = "manual__2025-01-24T00:00:00+00:00" + task_id: str = "test_task" + key: str = "test_key" + map_index: int = 0 + + xcom_response_native = XComResponseNative( + key=key, + timestamp=datetime.datetime(2025, 1, 24, 0, 0, 0), + logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0), + map_index=-1, + task_id=task_id, + dag_id=dag_id, + run_id=dag_run_id, + dag_display_name=dag_id, + task_display_name=task_id, + value={"result": "success"}, + ) + + xcom_response = XComResponse( + key=key, + timestamp=datetime.datetime(2025, 1, 24, 0, 0, 0), + logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0), + map_index=-1, + task_id=task_id, + dag_id=dag_id, + run_id=dag_run_id, + dag_display_name=dag_id, + task_display_name=task_id, + ) + + xcom_collection_response = XComCollectionResponse( + xcom_entries=[xcom_response], + total_entries=1, + ) + + def test_get(self): + """Test fetching a single XCom entry without map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify map_index is not in query params when not provided + assert "map_index" not in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.get( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + ) + assert response == self.xcom_response_native + + def test_get_with_map_index(self): + """Test fetching XCom entry for a mapped task with map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify map_index is included in query params + assert f"map_index={self.map_index}" in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.get( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + map_index=self.map_index, + ) + assert response == self.xcom_response_native + + def test_list(self): + """Test listing all XCom entries for a task instance without filters.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify no filters in query params + assert "map_index" not in str(request.url.query) + assert "xcom_key" not in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_collection_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.list( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + ) + assert response == self.xcom_collection_response + + def test_list_with_map_index_filter(self): + """Test listing XCom entries filtered by map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify map_index filter is included + assert f"map_index={self.map_index}" in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_collection_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.list( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + map_index=self.map_index, + ) + assert response == self.xcom_collection_response + + def test_list_with_key_filter(self): + """Test listing XCom entries filtered by key.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify xcom_key filter is included + assert f"xcom_key={self.key}" in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_collection_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.list( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + ) + assert response == self.xcom_collection_response + + def test_list_with_both_filters(self): + """Test listing XCom entries with both map_index and key filters.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify both filters are included + assert f"map_index={self.map_index}" in str(request.url.query) + assert f"xcom_key={self.key}" in str(request.url.query) + return httpx.Response(200, json=json.loads(self.xcom_collection_response.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.list( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + map_index=self.map_index, + key=self.key, + ) + assert response == self.xcom_collection_response + + def test_add_with_json_value(self): + """Test adding a new XCom entry with JSON value.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify request body + request_body = json.loads(request.content) + assert request_body["key"] == self.key + assert request_body["value"] == {"result": "success"} + # Verify map_index is NOT in body when not provided + assert "map_index" not in request_body + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.add( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + value='{"result": "success"}', + ) + assert response == self.xcom_response_native + + def test_add_with_string_value(self): + """Test adding XCom entry with non-JSON string value.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify plain string is stored as-is + request_body = json.loads(request.content) + assert request_body["value"] == "plain string value" + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.add( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + value="plain string value", + ) + assert response == self.xcom_response_native + + def test_add_with_map_index(self): + """Test adding XCom entry for a mapped task with map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries" + ) + # Verify map_index is included in request body + request_body = json.loads(request.content) + assert request_body["map_index"] == self.map_index + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.add( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + value='{"result": "success"}', + map_index=self.map_index, + ) + assert response == self.xcom_response_native + + def test_edit_with_json_value(self): + """Test editing an existing XCom entry with JSON value.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify request body + request_body = json.loads(request.content) + assert request_body["value"] == {"updated": "value"} + # Verify map_index is NOT in body when not provided + assert "map_index" not in request_body + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.edit( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + value='{"updated": "value"}', + ) + assert response == self.xcom_response_native + + def test_edit_with_map_index(self): + """Test editing XCom entry for a mapped task with map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify map_index is included in request body + request_body = json.loads(request.content) + assert request_body["map_index"] == self.map_index + return httpx.Response(200, json=json.loads(self.xcom_response_native.model_dump_json())) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.edit( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + value='{"updated": "value"}', + map_index=self.map_index, + ) + assert response == self.xcom_response_native + + def test_delete(self): + """Test deleting an XCom entry without map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify map_index is NOT in query params when not provided + assert "map_index" not in str(request.url.query) + return httpx.Response(204) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.delete( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + ) + assert response == self.key + + def test_delete_with_map_index(self): + """Test deleting XCom entry for a mapped task with map_index.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == ( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/" + f"taskInstances/{self.task_id}/xcomEntries/{self.key}" + ) + # Verify map_index is included in query params + assert f"map_index={self.map_index}" in str(request.url.query) + return httpx.Response(204) + + client = make_api_client(transport=httpx.MockTransport(handle_request)) + response = client.xcom.delete( + dag_id=self.dag_id, + dag_run_id=self.dag_run_id, + task_id=self.task_id, + key=self.key, + map_index=self.map_index, + ) + assert response == self.key