diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py index 5f1c734148b2b..bbb99441f8169 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py @@ -38,7 +38,10 @@ DagVersionResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import requires_access_dag +from airflow.api_fastapi.core_api.security import ( + ReadableDagVersionsFilterDep, + requires_access_dag, +) from airflow.models.dag_version import DagVersion dag_versions_router = AirflowRouter(tags=["DagVersion"], prefix="/dags/{dag_id}/dagVersions") @@ -102,6 +105,7 @@ def get_dag_versions( ), ], dag_bag: DagBagDep, + readable_dag_versions_filter: ReadableDagVersionsFilterDep, ) -> DAGVersionCollectionResponse: """ Get all DAG Versions. @@ -116,7 +120,7 @@ def get_dag_versions( dag_versions_select, total_entries = paginated_select( statement=query, - filters=[version_number, bundle_name, bundle_version], + filters=[version_number, bundle_name, bundle_version, readable_dag_versions_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index 9d31ed4bd5653..4e2d7f6c7ce70 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -64,6 +64,7 @@ from airflow.models import Connection, Pool, Variable from airflow.models.backfill import Backfill from airflow.models.dag import DagModel, DagRun, DagTag +from airflow.models.dag_version import DagVersion from airflow.models.dagwarning import DagWarning from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance as TI @@ -211,6 +212,13 @@ def to_orm(self, select: Select) -> Select: return select.where(DagTag.dag_id.in_(self.value)) +class PermittedDagVersionFilter(PermittedDagFilter): + """A parameter that filters the permitted dag versions for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(DagVersion.dag_id.in_(self.value or set())) + + def permitted_dag_filter_factory( method: ResourceMethod, filter_class=PermittedDagFilter ) -> Callable[[Request, BaseUser], PermittedDagFilter]: @@ -253,6 +261,9 @@ def depends_permitted_dags_filter( ReadableTagsFilterDep = Annotated[ PermittedTagFilter, Depends(permitted_dag_filter_factory("GET", PermittedTagFilter)) ] +ReadableDagVersionsFilterDep = Annotated[ + PermittedDagVersionFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagVersionFilter)) +] def requires_access_backfill( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py index 506875b2e351b..6e7971f76d6bc 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py @@ -262,7 +262,7 @@ class TestGetDagVersions(TestDagVersionEndpoint): ], "total_entries": 4, }, - 2, + 3, ], [ "dag_with_multiple_versions", @@ -301,7 +301,7 @@ class TestGetDagVersions(TestDagVersionEndpoint): ], "total_entries": 3, }, - 4, + 5, ], ], ) @@ -312,6 +312,22 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected assert response.status_code == 200 assert response.json() == expected_response + @pytest.mark.usefixtures("make_dag_with_multiple_versions") + @mock.patch( + "airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids", + return_value={"dag_with_multiple_versions"}, + ) + def test_get_dag_versions_permission_filtering(self, _, test_client): + """Test that listing all DAG versions with ~ only returns versions for permitted DAGs.""" + with assert_queries_count(4): + response = test_client.get("/dags/~/dagVersions") + + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 3 + dag_ids = {v["dag_id"] for v in body["dag_versions"]} + assert dag_ids == {"dag_with_multiple_versions"} + @pytest.mark.parametrize( "dag_id, expected_response, expected_query_count", [ @@ -362,7 +378,7 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected ], "total_entries": 4, }, - 2, + 3, ], [ "dag_with_multiple_versions", @@ -401,7 +417,7 @@ def test_get_dag_versions(self, test_client, dag_id, expected_response, expected ], "total_entries": 3, }, - 4, + 5, ], ], ) @@ -488,7 +504,7 @@ def test_get_dag_versions_with_url_template( def test_get_dag_versions_parameters( self, test_client, params, expected_versions, expected_total_entries ): - with assert_queries_count(2): + with assert_queries_count(3): response = test_client.get("/dags/~/dagVersions", params=params) assert response.status_code == 200 response_payload = response.json()