diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index ff5fe3ab772b3..8b792c49dfb64 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -266,12 +266,11 @@ def __call__(self, parser, namespace, values, option_string=None): help="The DAG ID of the DAG to pause or unpause", ) -# Variable Commands Args -ARG_VARIABLE_ACTION_ON_EXISTING_KEY = Arg( +ARG_ACTION_ON_EXISTING_KEY = Arg( flags=("-a", "--action-on-existing-key"), type=str, default="overwrite", - help="Action to take if we encounter a variable key that already exists.", + help="Action to take if the entity already exists.", choices=("overwrite", "fail", "skip"), ) @@ -877,7 +876,10 @@ def merge_commands( name="import", help="Import connections from a file exported with local CLI.", func=lazy_load_command("airflowctl.ctl.commands.connection_command.import_"), - args=(Arg(flags=("file",), metavar="FILEPATH", help="Connections JSON file"),), + args=( + Arg(flags=("file",), metavar="FILEPATH", help="Connections JSON file"), + ARG_ACTION_ON_EXISTING_KEY, + ), ), ) @@ -907,7 +909,7 @@ def merge_commands( name="import", help="Import pools", func=lazy_load_command("airflowctl.ctl.commands.pool_command.import_"), - args=(ARG_FILE,), + args=(ARG_FILE, ARG_ACTION_ON_EXISTING_KEY), ), ActionCommand( name="export", @@ -925,7 +927,7 @@ def merge_commands( name="import", help="Import variables from a file exported with local CLI.", func=lazy_load_command("airflowctl.ctl.commands.variable_command.import_"), - args=(ARG_FILE, ARG_VARIABLE_ACTION_ON_EXISTING_KEY), + args=(ARG_FILE, ARG_ACTION_ON_EXISTING_KEY), ), ) diff --git a/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py b/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py index b689083faa28e..b1a8a820998ac 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/connection_command.py @@ -62,7 +62,7 @@ def import_(args, api_client=NEW_API_CLIENT) -> None: connection_create_action = BulkCreateActionConnectionBody( action="create", entities=list(connections_data.values()), - action_on_existence=BulkActionOnExistence("fail"), + action_on_existence=BulkActionOnExistence(args.action_on_existing_key), ) response = api_client.connections.bulk(BulkBodyConnectionBody(actions=[connection_create_action])) if response.create.errors: diff --git a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py index bc0009e339147..b6a57de4bc4b4 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py @@ -40,7 +40,7 @@ def import_(args, api_client: Client = NEW_API_CLIENT) -> None: if not filepath.exists(): raise SystemExit(f"Missing pools file {args.file}") - success, errors = _import_helper(api_client, filepath) + success, errors = _import_helper(api_client, filepath, BulkActionOnExistence(args.action_on_existing_key)) if errors: raise SystemExit(f"Failed to update pool(s): {errors}") rich.print(success) @@ -83,7 +83,7 @@ def export(args, api_client: Client = NEW_API_CLIENT) -> None: raise SystemExit(f"Failed to export pools: {e}") -def _import_helper(api_client: Client, filepath: Path): +def _import_helper(api_client: Client, filepath: Path, action_on_existence: BulkActionOnExistence): """Help import pools from the json file.""" try: with open(filepath) as f: @@ -113,7 +113,7 @@ def _import_helper(api_client: Client, filepath: Path): BulkCreateActionPoolBody( action="create", entities=pools_to_update, - action_on_existence=BulkActionOnExistence.FAIL, + action_on_existence=action_on_existence, ) ] ) diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py index 02b56eda99b0e..f944e66ab1f24 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_connections_command.py @@ -17,12 +17,14 @@ from __future__ import annotations import json +from unittest import mock from unittest.mock import patch import pytest -from airflowctl.api.client import ClientKind +from airflowctl.api.client import Client, ClientKind from airflowctl.api.datamodels.generated import ( + BulkActionOnExistence, BulkActionResponse, BulkResponse, ConnectionBody, @@ -176,3 +178,47 @@ def test_import_without_extra_field(self, api_client_maker, tmp_path, monkeypatc extra=None, description="", ) + + @pytest.mark.parametrize( + ("action_on_existing_key", "expected_enum"), + [ + ("overwrite", BulkActionOnExistence.OVERWRITE), + ("skip", BulkActionOnExistence.SKIP), + ("fail", BulkActionOnExistence.FAIL), + ], + ) + def test_import_action_on_existing_key(self, tmp_path, action_on_existing_key, expected_enum): + expected_json_path = tmp_path / self.export_file_name + connection_file = { + self.connection_id: { + "conn_type": "test_type", + "host": "test_host", + "extra": "{}", + "connection_id": self.connection_id, + } + } + expected_json_path.write_text(json.dumps(connection_file)) + + mock_client = mock.MagicMock(spec=Client) + mock_response = mock.MagicMock() + mock_response.create.success = [self.connection_id] + mock_response.create.errors = [] + mock_client.connections.bulk.return_value = mock_response + + connection_command.import_( + self.parser.parse_args( + [ + "connections", + "import", + expected_json_path.as_posix(), + "--action-on-existing-key", + action_on_existing_key, + ] + ), + api_client=mock_client, + ) + + mock_client.connections.bulk.assert_called_once() + bulk_body = mock_client.connections.bulk.call_args[0][0] + action = bulk_body.actions[0] + assert action.action_on_existence == expected_enum diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py index 5ebfdc869c874..6946c3987954e 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py @@ -48,21 +48,21 @@ def test_import_missing_file(self, mock_client, tmp_path): """Test import with missing file.""" non_existent = tmp_path / "non_existent.json" with pytest.raises(SystemExit, match=f"Missing pools file {non_existent}"): - pool_command.import_(mock.MagicMock(file=non_existent)) + pool_command.import_(mock.MagicMock(file=non_existent, action_on_existing_key="fail")) def test_import_invalid_json(self, mock_client, tmp_path): """Test import with invalid JSON file.""" invalid_json = tmp_path / "invalid.json" invalid_json.write_text("invalid json") with pytest.raises(SystemExit, match="Invalid json file"): - pool_command.import_(mock.MagicMock(file=invalid_json)) + pool_command.import_(mock.MagicMock(file=invalid_json, action_on_existing_key="fail")) def test_import_invalid_pool_config(self, mock_client, tmp_path): """Test import with invalid pool configuration.""" invalid_pool = tmp_path / "invalid_pool.json" invalid_pool.write_text(json.dumps([{"invalid": "config"}])) with pytest.raises(SystemExit, match="Invalid pool configuration: {'invalid': 'config'}"): - pool_command.import_(mock.MagicMock(file=invalid_pool)) + pool_command.import_(mock.MagicMock(file=invalid_pool, action_on_existing_key="fail")) def test_import_success(self, mock_client, tmp_path, capsys): """Test successful pool import.""" @@ -87,7 +87,7 @@ def test_import_success(self, mock_client, tmp_path, capsys): mock_client.pools.bulk.return_value = mock_bulk_builder - pool_command.import_(mock.MagicMock(file=pools_file)) + pool_command.import_(mock.MagicMock(file=pools_file, action_on_existing_key="fail")) # Verify bulk operation was called with correct parameters mock_client.pools.bulk.assert_called_once() @@ -108,6 +108,34 @@ def test_import_success(self, mock_client, tmp_path, capsys): captured = capsys.readouterr() assert str(["test_pool"]) in captured.out + @pytest.mark.parametrize( + ("action_on_existing_key", "expected_enum"), + [ + ("overwrite", BulkActionOnExistence.OVERWRITE), + ("skip", BulkActionOnExistence.SKIP), + ("fail", BulkActionOnExistence.FAIL), + ], + ) + def test_import_action_on_existing_key( + self, mock_client, tmp_path, action_on_existing_key, expected_enum + ): + """Test that --action-on-existing-key is passed through to the bulk API.""" + pools_file = tmp_path / "pools.json" + pools_file.write_text(json.dumps([{"name": "test_pool", "slots": 1}])) + + mock_response = mock.MagicMock() + mock_response.success = ["test_pool"] + mock_response.errors = [] + mock_bulk_builder = mock.MagicMock() + mock_bulk_builder.create = mock_response + mock_client.pools.bulk.return_value = mock_bulk_builder + + pool_command.import_(mock.MagicMock(file=pools_file, action_on_existing_key=action_on_existing_key)) + + call_args = mock_client.pools.bulk.call_args[1] + action = call_args["pools"].actions[0] + assert action.action_on_existence == expected_enum + class TestPoolExportCommand: """Test cases for pool export command."""