From 6bf77afb1f8039c8f4766a8d847ede2b8348cad9 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Mon, 17 Oct 2022 17:23:07 +0200 Subject: [PATCH] Refactor out different test cases to separate tests The previous solution had two test conditions (strict and not strict) and several scikit-learn versions, because of two distinct changes within scikit-learn (the removal of min_impurity_split in 1.0, and the restructuring of public/private models in 0.24). I refactored out the separate test cases to greatly simplify the individual tests, and I added a test case for scikit-learn>=1.0, which was previously not covered. --- tests/test_flows/test_flow_functions.py | 67 +++++++++++++++++-------- 1 file changed, 45 insertions(+), 22 deletions(-) diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py index eb80c2861..fe058df23 100644 --- a/tests/test_flows/test_flow_functions.py +++ b/tests/test_flows/test_flow_functions.py @@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self): ) @unittest.skipIf( - LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1" + LooseVersion(sklearn.__version__) == "0.19.1", + reason="Requires scikit-learn!=0.19.1, because target flow is from that version.", ) - def test_get_flow_reinstantiate_model_wrong_version(self): - # Note that CI does not test against 0.19.1. 
+ def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self): openml.config.server = self.production_server - _, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3] - if sklearn_major > 23: - flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23 - flow_sklearn_version = "0.23.1" - else: - flow = 8175 - flow_sklearn_version = "0.19.1" - expected = ( - "Trying to deserialize a model with dependency " - "sklearn=={} not satisfied.".format(flow_sklearn_version) - ) + flow = 8175 + expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied." self.assertRaisesRegex( - ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True + ValueError, + expected, + openml.flows.get_flow, + flow_id=flow, + reinstantiate=True, + strict_version=True, ) - if LooseVersion(sklearn.__version__) > "0.19.1": - # 0.18 actually can't deserialize this because of incompatibility - flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False) - # ensure that a new flow was created - assert flow.flow_id is None - assert "sklearn==0.19.1" not in flow.dependencies - assert "sklearn>=0.19.1" not in flow.dependencies + + @unittest.skipIf( + LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0", + reason="Requires scikit-learn < 1.0.1." + # Because scikit-learn dropped min_impurity_split hyperparameter in 1.0, + # and the requested flow is from 1.0.0 exactly. + ) + def test_get_flow_reinstantiate_flow_not_strict_post_1(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==1.0.0" not in flow.dependencies + + @unittest.skipIf( + (LooseVersion(sklearn.__version__) < "0.23.2") + or ("1.0" < LooseVersion(sklearn.__version__)), + reason="Requires scikit-learn 0.23.2 or ~0.24." 
+ # Because these still have min_impurity_split, but with new scikit-learn module structure. + ) + def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==0.23.1" not in flow.dependencies + + @unittest.skipIf( + "0.23" < LooseVersion(sklearn.__version__), + reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.", + ) + def test_get_flow_reinstantiate_flow_not_strict_pre_023(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==0.19.1" not in flow.dependencies def test_get_flow_id(self): if self.long_version: