Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
DagVersionResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.api_fastapi.core_api.security import (
ReadableDagVersionsFilterDep,
requires_access_dag,
)
from airflow.models.dag_version import DagVersion

dag_versions_router = AirflowRouter(tags=["DagVersion"], prefix="/dags/{dag_id}/dagVersions")
Expand Down Expand Up @@ -102,6 +105,7 @@ def get_dag_versions(
),
],
dag_bag: DagBagDep,
readable_dag_versions_filter: ReadableDagVersionsFilterDep,
) -> DAGVersionCollectionResponse:
"""
Get all DAG Versions.
Expand All @@ -116,7 +120,7 @@ def get_dag_versions(

dag_versions_select, total_entries = paginated_select(
statement=query,
filters=[version_number, bundle_name, bundle_version],
filters=[version_number, bundle_name, bundle_version, readable_dag_versions_filter],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
11 changes: 11 additions & 0 deletions airflow-core/src/airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from airflow.models import Connection, Pool, Variable
from airflow.models.backfill import Backfill
from airflow.models.dag import DagModel, DagRun, DagTag
from airflow.models.dag_version import DagVersion
from airflow.models.dagwarning import DagWarning
from airflow.models.log import Log
from airflow.models.taskinstance import TaskInstance as TI
Expand Down Expand Up @@ -211,6 +212,13 @@ def to_orm(self, select: Select) -> Select:
return select.where(DagTag.dag_id.in_(self.value))


class PermittedDagVersionFilter(PermittedDagFilter):
"""A parameter that filters the permitted dag versions for the user."""

def to_orm(self, select: Select) -> Select:
return select.where(DagVersion.dag_id.in_(self.value or set()))


def permitted_dag_filter_factory(
method: ResourceMethod, filter_class=PermittedDagFilter
) -> Callable[[Request, BaseUser], PermittedDagFilter]:
Expand Down Expand Up @@ -253,6 +261,9 @@ def depends_permitted_dags_filter(
ReadableTagsFilterDep = Annotated[
PermittedTagFilter, Depends(permitted_dag_filter_factory("GET", PermittedTagFilter))
]
ReadableDagVersionsFilterDep = Annotated[
PermittedDagVersionFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagVersionFilter))
]


def requires_access_backfill(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class TestGetDagVersions(TestDagVersionEndpoint):
],
"total_entries": 4,
},
2,
3,
],
[
"dag_with_multiple_versions",
Expand Down Expand Up @@ -301,7 +301,7 @@ class TestGetDagVersions(TestDagVersionEndpoint):
],
"total_entries": 3,
},
4,
5,
],
],
)
Expand All @@ -312,6 +312,22 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected
assert response.status_code == 200
assert response.json() == expected_response

@pytest.mark.usefixtures("make_dag_with_multiple_versions")
@mock.patch(
"airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids",
return_value={"dag_with_multiple_versions"},
)
def test_get_dag_versions_permission_filtering(self, _, test_client):
"""Test that listing all DAG versions with ~ only returns versions for permitted DAGs."""
with assert_queries_count(4):
response = test_client.get("/dags/~/dagVersions")

assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 3
dag_ids = {v["dag_id"] for v in body["dag_versions"]}
assert dag_ids == {"dag_with_multiple_versions"}

@pytest.mark.parametrize(
"dag_id, expected_response, expected_query_count",
[
Expand Down Expand Up @@ -362,7 +378,7 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected
],
"total_entries": 4,
},
2,
3,
],
[
"dag_with_multiple_versions",
Expand Down Expand Up @@ -401,7 +417,7 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected
],
"total_entries": 3,
},
4,
5,
],
],
)
Expand Down Expand Up @@ -488,7 +504,7 @@ def test_get_dag_versions_with_url_template(
def test_get_dag_versions_parameters(
self, test_client, params, expected_versions, expected_total_entries
):
with assert_queries_count(2):
with assert_queries_count(3):
response = test_client.get("/dags/~/dagVersions", params=params)
assert response.status_code == 200
response_payload = response.json()
Expand Down
Loading