From 820104434b28438e90f7897d06260dccbd76f6fc Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Mar 2022 08:57:18 +0100 Subject: [PATCH 1/5] Fixed edge case in which _clip_to_end was increasing the size of a rejected step, causing an infinite loop. --- diffrax/integrate.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 5f9337b8..ede17af4 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -88,8 +88,12 @@ def _save(state: _State, t: Scalar) -> _State: ) -def _clip_to_end(tnext, t1): - return jnp.where(tnext > t1 - 1e-6, t1, tnext) +def _clip_to_end(tnext, t1, keep_step): + if tnext.dtype is jnp.dtype("float64"): + tol = 1e-10 + else: + tol = 1e-6 + return jnp.where(keep_step & (tnext > t1 - tol), t1, tnext) def loop( @@ -161,7 +165,7 @@ def body_fun(state, inplace): # The 1e-6 tolerance means that we don't end up with too-small intervals for # dense output, which then gives numerically unstable answers due to floating # point errors. - tnext = _clip_to_end(tnext, t1) + tnext = _clip_to_end(tnext, t1, keep_step) tprev = jnp.minimum(tprev, t1) # The other parts of the mutable state are kept/not-kept (based on whether the From 6a4dd2da5971a1eff25937d112500461ee45c54c Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Mar 2022 09:18:22 +0100 Subject: [PATCH 2/5] typo --- diffrax/integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index ede17af4..dbfcdc57 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -407,7 +407,7 @@ def _cond_fun(_state): def _body_fun(_state): _step, _t = _state - return _step + 1, _clip_to_end(_t + dt0, t1) + return _step + 1, _clip_to_end(_t + dt0, t1, True) compiled_num_steps, _ = lax.while_loop( _cond_fun, _body_fun, (0, t0) From b34f40c95f476573c35cec0e30af0b7691507dca Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Mar 2022 09:52:17 +0100 Subject: [PATCH 3/5] In the reject-and-clip case we override the stepsize controller and take a half step. --- diffrax/integrate.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index dbfcdc57..05d42568 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -88,12 +88,14 @@ def _save(state: _State, t: Scalar) -> _State: ) -def _clip_to_end(tnext, t1, keep_step): +def _clip_to_end(tprev, tnext, t1, keep_step): if tnext.dtype is jnp.dtype("float64"): tol = 1e-10 else: tol = 1e-6 - return jnp.where(keep_step & (tnext > t1 - tol), t1, tnext) + clip = tnext > t1 - tol + tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev)) + return jnp.where(clip, tclip, tnext) def loop( @@ -165,8 +167,8 @@ def body_fun(state, inplace): # The 1e-6 tolerance means that we don't end up with too-small intervals for # dense output, which then gives numerically unstable answers due to floating # point errors. - tnext = _clip_to_end(tnext, t1, keep_step) tprev = jnp.minimum(tprev, t1) + tnext = _clip_to_end(tprev, tnext, t1, keep_step) # The other parts of the mutable state are kept/not-kept (based on whether the # step was accepted) by the stepsize controller. But it doesn't get access to @@ -407,7 +409,7 @@ def _cond_fun(_state): def _body_fun(_state): _step, _t = _state - return _step + 1, _clip_to_end(_t + dt0, t1, True) + return _step + 1, _clip_to_end(_t, _t + dt0, t1, True) compiled_num_steps, _ = lax.while_loop( _cond_fun, _body_fun, (0, t0) From 49e457308b9bfec004fe721f678733beaa753b33 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Mar 2022 09:53:08 +0100 Subject: [PATCH 4/5] bump version --- diffrax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index b207173a..6ba2c41e 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -81,4 +81,4 @@ ) -__version__ = "0.0.5" +__version__ = "0.0.6" From b116eacf064c65ad054ad31e253848680a233e88 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Mar 2022 14:36:56 +0100 Subject: [PATCH 5/5] Test fix --- test/test_detest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_detest.py b/test/test_detest.py index 78e40bef..9ca511f1 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -404,7 +404,7 @@ def _test(solver_ctr, problems, higher): # build up by t=20. # Teeny-tiny steps fix this. dt0 = 0.000001 - max_steps = 20_000_000 + max_steps = 20_000_001 stepsize_controller = diffrax.ConstantStepSize() elif solver_ctr is diffrax.ReversibleHeun and problem is _a1: # ReversibleHeun is a bit like LeapfrogMidpoint, and therefore bad over