From 6bf77afb1f8039c8f4766a8d847ede2b8348cad9 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Mon, 17 Oct 2022 17:23:07 +0200
Subject: [PATCH] Refactor out different test cases to separate tests
The previous solution had two test conditions (strict and not strict)
and several scikit-learn versions, because of two distinct changes
within scikit-learn (the removal of min_impurity_split in 1.0, and the
restructuring of public/private models in 0.24).
I refactored out the separate test cases to greatly simplify the
individual tests, and I added a test case for scikit-learn>=1.0,
which was previously not covered.
---
tests/test_flows/test_flow_functions.py | 67 +++++++++++++++++--------
1 file changed, 45 insertions(+), 22 deletions(-)
diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py
index eb80c2861..fe058df23 100644
--- a/tests/test_flows/test_flow_functions.py
+++ b/tests/test_flows/test_flow_functions.py
@@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self):
)
@unittest.skipIf(
- LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1"
+ LooseVersion(sklearn.__version__) == "0.19.1",
+ reason="Requires scikit-learn!=0.19.1, because target flow is from that version.",
)
- def test_get_flow_reinstantiate_model_wrong_version(self):
- # Note that CI does not test against 0.19.1.
+ def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self):
openml.config.server = self.production_server
- _, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3]
- if sklearn_major > 23:
- flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23
- flow_sklearn_version = "0.23.1"
- else:
- flow = 8175
- flow_sklearn_version = "0.19.1"
- expected = (
- "Trying to deserialize a model with dependency "
- "sklearn=={} not satisfied.".format(flow_sklearn_version)
- )
+ flow = 8175
+ expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied."
self.assertRaisesRegex(
- ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True
+ ValueError,
+ expected,
+ openml.flows.get_flow,
+ flow_id=flow,
+ reinstantiate=True,
+ strict_version=True,
)
- if LooseVersion(sklearn.__version__) > "0.19.1":
- # 0.18 actually can't deserialize this because of incompatibility
- flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False)
- # ensure that a new flow was created
- assert flow.flow_id is None
- assert "sklearn==0.19.1" not in flow.dependencies
- assert "sklearn>=0.19.1" not in flow.dependencies
+
+ @unittest.skipIf(
+ LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0",
+ reason="Requires scikit-learn < 1.0.1."
+ # Because scikit-learn dropped min_impurity_split hyperparameter in 1.0,
+ # and the requested flow is from 1.0.0 exactly.
+ )
+ def test_get_flow_reinstantiate_flow_not_strict_post_1(self):
+ openml.config.server = self.production_server
+ flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False)
+ assert flow.flow_id is None
+ assert "sklearn==1.0.0" not in flow.dependencies
+
+ @unittest.skipIf(
+ (LooseVersion(sklearn.__version__) < "0.23.2")
+ or ("1.0" < LooseVersion(sklearn.__version__)),
+ reason="Requires scikit-learn 0.23.2 or ~0.24."
+ # Because these still have min_impurity_split, but with new scikit-learn module structure."
+ )
+ def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self):
+ openml.config.server = self.production_server
+ flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False)
+ assert flow.flow_id is None
+ assert "sklearn==0.23.1" not in flow.dependencies
+
+ @unittest.skipIf(
+ "0.23" < LooseVersion(sklearn.__version__),
+ reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.",
+ )
+ def test_get_flow_reinstantiate_flow_not_strict_pre_023(self):
+ openml.config.server = self.production_server
+ flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False)
+ assert flow.flow_id is None
+ assert "sklearn==0.19.1" not in flow.dependencies
def test_get_flow_id(self):
if self.long_version: