From 952036533a8e769d2238a3ec00767b966d1d3752 Mon Sep 17 00:00:00 2001 From: Kevin Phan <98072684+ph-kev@users.noreply.github.com> Date: Fri, 3 Jan 2025 14:22:12 -0800 Subject: [PATCH] Add call_at_end and save_positions to callbacks --- src/Callbacks.jl | 38 ++++++++++++++++++++++++++++++++------ test/callbacks.jl | 17 ++++++++++++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/Callbacks.jl b/src/Callbacks.jl index c4fe189d..7e9d43ea 100644 --- a/src/Callbacks.jl +++ b/src/Callbacks.jl @@ -88,8 +88,13 @@ Trigger `f!(integrator)` every `Δt` simulation time. If `atinit=true`, then `f!` will additionally be triggered at initialization. Otherwise the first trigger will be after `Δt` simulation time. + +If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise +there is no guaranteed call to `f!` at the end of the time span. + +The boolean tuple `save_positions` determines whether to save before or after `f!`. """ -function EveryXSimulationTime(f!, Δt; atinit = false) +function EveryXSimulationTime(f!, Δt; atinit = false, call_at_end = false, save_positions = (true, true)) t_next = zero(Δt) function _initialize(c, u, t, integrator) @@ -111,14 +116,22 @@ function EveryXSimulationTime(f!, Δt; atinit = false) t_next += Δt end return true + elseif (call_at_end && t == integrator.sol.prob.tspan[2]) + return true else return false end end if isdefined(DiffEqBase, :finalize!) - SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) + SciMLBase.DiscreteCallback( + condition, + f!; + initialize = _initialize, + finalize = _finalize, + save_positions = save_positions, + ) else - SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions) end end @@ -131,8 +144,13 @@ Trigger `f!(integrator)` every `Δsteps` simulation steps. If `atinit==true`, then `f!` will additionally be triggered at initialization. Otherwise the first trigger will be after `Δsteps`. + +If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise +there is no guaranteed call to `f!` at the end of the time span. + +The boolean tuple `save_positions` determines whether to save before or after `f!`. """ -function EveryXSimulationSteps(f!, Δsteps; atinit = false) +function EveryXSimulationSteps(f!, Δsteps; atinit = false, call_at_end = false, save_positions = (true, true)) steps = 0 steps_next = 0 @@ -154,15 +172,23 @@ function EveryXSimulationSteps(f!, Δsteps; atinit = false) if steps >= steps_next steps_next += Δsteps return true + elseif (call_at_end && t == integrator.sol.prob.tspan[2]) + return true else return false end end if isdefined(DiffEqBase, :finalize!) - SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) + SciMLBase.DiscreteCallback( + condition, + f!; + initialize = _initialize, + finalize = _finalize, + save_positions = save_positions, + ) else - SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions) end end diff --git a/test/callbacks.jl b/test/callbacks.jl index 22ce24d2..ced8688e 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -14,8 +14,9 @@ mutable struct MyCallback initialized::Bool calls::Int finalized::Bool + last_t::Real end -MyCallback() = MyCallback(false, 0, false) +MyCallback() = MyCallback(false, 0, false, -1.0) function Callbacks.initialize!(cb::MyCallback, integrator) cb.initialized = true @@ -25,6 +26,7 @@ function Callbacks.finalize!(cb::MyCallback, integrator) end function (cb::MyCallback)(integrator) cb.calls += 1 + cb.last_t = integrator.t end cb1 = MyCallback() @@ -32,6 +34,10 @@ cb2 = MyCallback() cb3 = MyCallback() cb4 = MyCallback() cb5 = MyCallback() +cb6 = MyCallback() +cb7 = MyCallback() +cb8 = MyCallback() +cb9 = MyCallback() cbs = CallbackSet( EveryXSimulationTime(cb1, 1 / 4), @@ -40,6 +46,10 @@ cbs = CallbackSet( EveryXSimulationSteps(cb4, 4, atinit = true), EveryXSimulationSteps(_ -> sleep(1 / 32), 1), EveryXWallTimeSeconds(cb5, 0.49, comm_ctx), + EveryXSimulationTime(cb6, 0.49, call_at_end = true), + EveryXSimulationSteps(cb7, 3, call_at_end = true), + EveryXSimulationTime(cb8, 0.3, call_at_end = false), + EveryXSimulationSteps(cb9, 3, call_at_end = false), ) const_prob_inc = ODEProblem( @@ -63,6 +73,11 @@ solve(const_prob_inc, LSRKEulerMethod(), dt = 1 / 32, callback = cbs) @test cb4.calls == 9 @test cb5.calls >= 2 +@test cb6.last_t == 1.0 +@test cb7.last_t == 1.0 +@test cb8.last_t == (1 / 32) * 29 +@test cb9.last_t == (1 / 32) * 30 + if isdefined(DiffEqBase, :finalize!) @test cb1.finalized