Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions airflow-ctl/src/airflowctl/ctl/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,11 @@ def __call__(self, parser, namespace, values, option_string=None):
help="The DAG ID of the DAG to pause or unpause",
)

# Variable Commands Args
ARG_VARIABLE_ACTION_ON_EXISTING_KEY = Arg(
ARG_ACTION_ON_EXISTING_KEY = Arg(
Comment thread
bugraoz93 marked this conversation as resolved.
flags=("-a", "--action-on-existing-key"),
type=str,
default="overwrite",
help="Action to take if we encounter a variable key that already exists.",
help="Action to take if the entity already exists.",
choices=("overwrite", "fail", "skip"),
)

Expand Down Expand Up @@ -877,7 +876,10 @@ def merge_commands(
name="import",
help="Import connections from a file exported with local CLI.",
func=lazy_load_command("airflowctl.ctl.commands.connection_command.import_"),
args=(Arg(flags=("file",), metavar="FILEPATH", help="Connections JSON file"),),
args=(
Arg(flags=("file",), metavar="FILEPATH", help="Connections JSON file"),
ARG_ACTION_ON_EXISTING_KEY,
),
),
)

Expand Down Expand Up @@ -907,7 +909,7 @@ def merge_commands(
name="import",
help="Import pools",
func=lazy_load_command("airflowctl.ctl.commands.pool_command.import_"),
args=(ARG_FILE,),
args=(ARG_FILE, ARG_ACTION_ON_EXISTING_KEY),
),
ActionCommand(
name="export",
Expand All @@ -925,7 +927,7 @@ def merge_commands(
name="import",
help="Import variables from a file exported with local CLI.",
func=lazy_load_command("airflowctl.ctl.commands.variable_command.import_"),
args=(ARG_FILE, ARG_VARIABLE_ACTION_ON_EXISTING_KEY),
args=(ARG_FILE, ARG_ACTION_ON_EXISTING_KEY),
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def import_(args, api_client=NEW_API_CLIENT) -> None:
connection_create_action = BulkCreateActionConnectionBody(
action="create",
entities=list(connections_data.values()),
action_on_existence=BulkActionOnExistence("fail"),
action_on_existence=BulkActionOnExistence(args.action_on_existing_key),
)
response = api_client.connections.bulk(BulkBodyConnectionBody(actions=[connection_create_action]))
if response.create.errors:
Expand Down
6 changes: 3 additions & 3 deletions airflow-ctl/src/airflowctl/ctl/commands/pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def import_(args, api_client: Client = NEW_API_CLIENT) -> None:
if not filepath.exists():
raise SystemExit(f"Missing pools file {args.file}")

success, errors = _import_helper(api_client, filepath)
success, errors = _import_helper(api_client, filepath, BulkActionOnExistence(args.action_on_existing_key))
if errors:
raise SystemExit(f"Failed to update pool(s): {errors}")
rich.print(success)
Expand Down Expand Up @@ -83,7 +83,7 @@ def export(args, api_client: Client = NEW_API_CLIENT) -> None:
raise SystemExit(f"Failed to export pools: {e}")


def _import_helper(api_client: Client, filepath: Path):
def _import_helper(api_client: Client, filepath: Path, action_on_existence: BulkActionOnExistence):
"""Help import pools from the json file."""
try:
with open(filepath) as f:
Expand Down Expand Up @@ -113,7 +113,7 @@ def _import_helper(api_client: Client, filepath: Path):
BulkCreateActionPoolBody(
action="create",
entities=pools_to_update,
action_on_existence=BulkActionOnExistence.FAIL,
action_on_existence=action_on_existence,
)
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from __future__ import annotations

import json
from unittest import mock
from unittest.mock import patch

import pytest

from airflowctl.api.client import ClientKind
from airflowctl.api.client import Client, ClientKind
from airflowctl.api.datamodels.generated import (
BulkActionOnExistence,
BulkActionResponse,
BulkResponse,
ConnectionBody,
Expand Down Expand Up @@ -176,3 +178,47 @@ def test_import_without_extra_field(self, api_client_maker, tmp_path, monkeypatc
extra=None,
description="",
)

@pytest.mark.parametrize(
("action_on_existing_key", "expected_enum"),
[
("overwrite", BulkActionOnExistence.OVERWRITE),
("skip", BulkActionOnExistence.SKIP),
("fail", BulkActionOnExistence.FAIL),
],
)
def test_import_action_on_existing_key(self, tmp_path, action_on_existing_key, expected_enum):
expected_json_path = tmp_path / self.export_file_name
connection_file = {
self.connection_id: {
"conn_type": "test_type",
"host": "test_host",
"extra": "{}",
"connection_id": self.connection_id,
}
}
expected_json_path.write_text(json.dumps(connection_file))

mock_client = mock.MagicMock(spec=Client)
mock_response = mock.MagicMock()
mock_response.create.success = [self.connection_id]
mock_response.create.errors = []
mock_client.connections.bulk.return_value = mock_response

connection_command.import_(
self.parser.parse_args(
[
"connections",
"import",
expected_json_path.as_posix(),
"--action-on-existing-key",
action_on_existing_key,
]
),
api_client=mock_client,
)

mock_client.connections.bulk.assert_called_once()
bulk_body = mock_client.connections.bulk.call_args[0][0]
action = bulk_body.actions[0]
assert action.action_on_existence == expected_enum
36 changes: 32 additions & 4 deletions airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ def test_import_missing_file(self, mock_client, tmp_path):
"""Test import with missing file."""
non_existent = tmp_path / "non_existent.json"
with pytest.raises(SystemExit, match=f"Missing pools file {non_existent}"):
pool_command.import_(mock.MagicMock(file=non_existent))
pool_command.import_(mock.MagicMock(file=non_existent, action_on_existing_key="fail"))

def test_import_invalid_json(self, mock_client, tmp_path):
"""Test import with invalid JSON file."""
invalid_json = tmp_path / "invalid.json"
invalid_json.write_text("invalid json")
with pytest.raises(SystemExit, match="Invalid json file"):
pool_command.import_(mock.MagicMock(file=invalid_json))
pool_command.import_(mock.MagicMock(file=invalid_json, action_on_existing_key="fail"))

def test_import_invalid_pool_config(self, mock_client, tmp_path):
"""Test import with invalid pool configuration."""
invalid_pool = tmp_path / "invalid_pool.json"
invalid_pool.write_text(json.dumps([{"invalid": "config"}]))
with pytest.raises(SystemExit, match="Invalid pool configuration: {'invalid': 'config'}"):
pool_command.import_(mock.MagicMock(file=invalid_pool))
pool_command.import_(mock.MagicMock(file=invalid_pool, action_on_existing_key="fail"))

def test_import_success(self, mock_client, tmp_path, capsys):
"""Test successful pool import."""
Expand All @@ -87,7 +87,7 @@ def test_import_success(self, mock_client, tmp_path, capsys):

mock_client.pools.bulk.return_value = mock_bulk_builder

pool_command.import_(mock.MagicMock(file=pools_file))
pool_command.import_(mock.MagicMock(file=pools_file, action_on_existing_key="fail"))

# Verify bulk operation was called with correct parameters
mock_client.pools.bulk.assert_called_once()
Expand All @@ -108,6 +108,34 @@ def test_import_success(self, mock_client, tmp_path, capsys):
captured = capsys.readouterr()
assert str(["test_pool"]) in captured.out

@pytest.mark.parametrize(
("action_on_existing_key", "expected_enum"),
[
("overwrite", BulkActionOnExistence.OVERWRITE),
("skip", BulkActionOnExistence.SKIP),
("fail", BulkActionOnExistence.FAIL),
],
)
def test_import_action_on_existing_key(
self, mock_client, tmp_path, action_on_existing_key, expected_enum
):
"""Test that --action-on-existing-key is passed through to the bulk API."""
pools_file = tmp_path / "pools.json"
pools_file.write_text(json.dumps([{"name": "test_pool", "slots": 1}]))

mock_response = mock.MagicMock()
mock_response.success = ["test_pool"]
mock_response.errors = []
mock_bulk_builder = mock.MagicMock()
mock_bulk_builder.create = mock_response
mock_client.pools.bulk.return_value = mock_bulk_builder

pool_command.import_(mock.MagicMock(file=pools_file, action_on_existing_key=action_on_existing_key))

call_args = mock_client.pools.bulk.call_args[1]
action = call_args["pools"].actions[0]
assert action.action_on_existence == expected_enum


class TestPoolExportCommand:
"""Test cases for pool export command."""
Expand Down