diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py index eb80c2861..fe058df23 100644 --- a/tests/test_flows/test_flow_functions.py +++ b/tests/test_flows/test_flow_functions.py @@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self): ) @unittest.skipIf( - LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1" + LooseVersion(sklearn.__version__) == "0.19.1", + reason="Requires scikit-learn!=0.19.1, because target flow is from that version.", ) - def test_get_flow_reinstantiate_model_wrong_version(self): - # Note that CI does not test against 0.19.1. + def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self): openml.config.server = self.production_server - _, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3] - if sklearn_major > 23: - flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23 - flow_sklearn_version = "0.23.1" - else: - flow = 8175 - flow_sklearn_version = "0.19.1" - expected = ( - "Trying to deserialize a model with dependency " - "sklearn=={} not satisfied.".format(flow_sklearn_version) - ) + flow = 8175 + expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied." self.assertRaisesRegex( - ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True + ValueError, + expected, + openml.flows.get_flow, + flow_id=flow, + reinstantiate=True, + strict_version=True, ) - if LooseVersion(sklearn.__version__) > "0.19.1": - # 0.18 actually can't deserialize this because of incompatibility - flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False) - # ensure that a new flow was created - assert flow.flow_id is None - assert "sklearn==0.19.1" not in flow.dependencies - assert "sklearn>=0.19.1" not in flow.dependencies + + @unittest.skipIf( + LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0", + reason="Requires scikit-learn < 1.0.1." + # Because scikit-learn dropped min_impurity_split hyperparameter in 1.0, + # and the requested flow is from 1.0.0 exactly. + ) + def test_get_flow_reinstantiate_flow_not_strict_post_1(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==1.0.0" not in flow.dependencies + + @unittest.skipIf( + (LooseVersion(sklearn.__version__) < "0.23.2") + or ("1.0" < LooseVersion(sklearn.__version__)), + reason="Requires scikit-learn 0.23.2 or ~0.24." + # Because these still have min_impurity_split, but with new scikit-learn module structure." + ) + def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==0.23.1" not in flow.dependencies + + @unittest.skipIf( + "0.23" < LooseVersion(sklearn.__version__), + reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.", + ) + def test_get_flow_reinstantiate_flow_not_strict_pre_023(self): + openml.config.server = self.production_server + flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False) + assert flow.flow_id is None + assert "sklearn==0.19.1" not in flow.dependencies def test_get_flow_id(self): if self.long_version: