Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions tests/test_flows/test_flow_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,32 +324,55 @@ def test_get_flow_reinstantiate_model_no_extension(self):
)

@unittest.skipIf(
LooseVersion(sklearn.__version__) == "0.19.1", reason="Target flow is from sklearn 0.19.1"
LooseVersion(sklearn.__version__) == "0.19.1",
reason="Requires scikit-learn!=0.19.1, because target flow is from that version.",
)
def test_get_flow_reinstantiate_model_wrong_version(self):
# Note that CI does not test against 0.19.1.
def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self):
openml.config.server = self.production_server
_, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3]
if sklearn_major > 23:
flow = 18587 # 18687, 18725 --- flows building random forest on >= 0.23
flow_sklearn_version = "0.23.1"
else:
flow = 8175
flow_sklearn_version = "0.19.1"
expected = (
"Trying to deserialize a model with dependency "
"sklearn=={} not satisfied.".format(flow_sklearn_version)
)
flow = 8175
expected = "Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied."
self.assertRaisesRegex(
ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True
ValueError,
expected,
openml.flows.get_flow,
flow_id=flow,
reinstantiate=True,
strict_version=True,
)
if LooseVersion(sklearn.__version__) > "0.19.1":
# 0.18 actually can't deserialize this because of incompatibility
flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, strict_version=False)
# ensure that a new flow was created
assert flow.flow_id is None
assert "sklearn==0.19.1" not in flow.dependencies
assert "sklearn>=0.19.1" not in flow.dependencies

@unittest.skipIf(
LooseVersion(sklearn.__version__) < "1" and LooseVersion(sklearn.__version__) != "1.0.0",
reason="Requires scikit-learn < 1.0.1."
# Because scikit-learn dropped min_impurity_split hyperparameter in 1.0,
# and the requested flow is from 1.0.0 exactly.
)
def test_get_flow_reinstantiate_flow_not_strict_post_1(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==1.0.0" not in flow.dependencies

@unittest.skipIf(
(LooseVersion(sklearn.__version__) < "0.23.2")
or ("1.0" < LooseVersion(sklearn.__version__)),
reason="Requires scikit-learn 0.23.2 or ~0.24."
# Because these still have min_impurity_split, but with new scikit-learn module structure."
)
def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==0.23.1" not in flow.dependencies

@unittest.skipIf(
"0.23" < LooseVersion(sklearn.__version__),
reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.",
)
def test_get_flow_reinstantiate_flow_not_strict_pre_023(self):
openml.config.server = self.production_server
flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False)
assert flow.flow_id is None
assert "sklearn==0.19.1" not in flow.dependencies

def test_get_flow_id(self):
if self.long_version:
Expand Down