Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

observe inputs #189

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/MTKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ For a given system and name, extract all the relevant meta we want to keep for t
function _get_metadata(sys, name)
nt = (;)
sym = try
getproperty_symbolic(sys, name)
getproperty_symbolic(sys, name; might_contain_toplevel_ns=false)
catch e
if !endswith(string(name), "ˍt") # known for "internal" derivatives
@warn "Could not extract metadata for $name $(e.msg)"
Expand Down
16 changes: 11 additions & 5 deletions ext/MTKUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,24 @@ function _collect_differentials!(found, ex)
end

"""
getproperty_symbolic(sys, var)
getproperty_symbolic(sys, var; might_contain_toplevel_ns=true)

Like `getproperty` but works on a greater varaity of "var"
- var can be Num or Symbolic (resolved using genname)
- strip namespace of sys if present
- strip namespace of sys if present (don't strip if `might_contain_top_level_ns=false`)
- for nested variables (foo₊bar₊baz) resolve them one by one
"""
function getproperty_symbolic(sys, var)
function getproperty_symbolic(sys, var; might_contain_toplevel_ns=true)
ns = string(getname(sys))
varname = string(getname(var))
varname_nons = replace(varname, r"^"*ns*"₊" => "")
parts = split(varname_nons, "₊")
# split of the toplevel namespace if necessary
if might_contain_toplevel_ns && startswith(varname, ns*"₊")
if getname(sys) ∈ getname.(ModelingToolkit.get_systems(sys))
@warn "Namespace :$ns appears multiple times, this might lead to unexpected, since it is not clear whether the namespace should be stripped or not."
end
varname = replace(varname, r"^"*ns*"₊" => "")
end
parts = split(varname, "₊")
r = getproperty(sys, Symbol(parts[1]); namespace=false)
for part in parts[2:end]
r = getproperty(r, Symbol(part); namespace=true)
Expand Down
22 changes: 12 additions & 10 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ edges it returns a named tuple `(; src, dst)` with two symbol vectors.
insym(c::VertexModel)::Vector{Symbol} = c.insym
insym(c::EdgeModel)::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} = c.insym

insym_all(c::VertexModel) = c.insym
insym_all(c::EdgeModel) = Iterators.flatten(values(c.insym))

"""
indim(c::VertexModel)::Int
indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int}
Expand Down Expand Up @@ -950,16 +953,17 @@ function _fill_defaults(T, @nospecialize(kwargs))
####
#### Cached outsymflat/outsymall
####
_outsym_flat = if T <: VertexModel
outsym
elseif T <: EdgeModel
vcat(outsym.src, outsym.dst)
else
error()
end
_outsym_flat = flatten_sym(outsym)
dict[:_outsym_flat] = _outsym_flat

dict[:_obssym_all] = setdiff(_outsym_flat, sym) ∪ obssym

insym = dict[:insym]
if !isnothing(insym)
insym_flat = flatten_sym(insym)
dict[:_obssym_all] = dict[:_obssym_all] ∪ insym_flat
end

####
#### External Inputs
####
Expand Down Expand Up @@ -991,10 +995,8 @@ function _fill_defaults(T, @nospecialize(kwargs))

_is = if isnothing(__is)
Symbol[]
elseif __is isa NamedTuple
vcat(__is.src, __is.dst)
else
__is
flatten_sym(insym)
end
if !allunique(vcat(_s, _ps, _obss, _is, _os))
throw(ArgumentError("Symbol names must be unique. There are clashes in sym, psym, outsym, obssym and insym."))
Expand Down
26 changes: 25 additions & 1 deletion src/symbolicindexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ function SII.observed(nw::Network, snis)
stateidx = Dict{Int, Int}()
# mapping i -> index in output
outidx = Dict{Int, Int}()
# mapping i -> index in aggbuf
aggidx = Dict{Int, Int}()
# mapping i -> f(fullstate, p, t) (component observables)
obsfuns = Dict{Int, Function}()
for (i, sni) in enumerate(_snis)
Expand All @@ -427,12 +429,28 @@ function SII.observed(nw::Network, snis)
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
obsfuns[i] = (u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[idx]
elseif hasinsym(cf) && sni.subidx ∈ insym_all(cf) # found in input
if sni isa SymbolicVertexIndex
idx = findfirst(isequal(sni.subidx), insym_all(cf))
aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx]
elseif sni isa SymbolicEdgeIndex
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
outidx[i] = nw.im.v_out[edge.src][idx]
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
outidx[i] = nw.im.v_out[edge.dst][idx]
else
error()
end
else
error()
end
else
throw(ArgumentError("Cannot resolve observable $sni"))
end
end
end
initbufs = !isempty(outidx) || !isempty(obsfuns)
initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns)

if isscalar
@closure (u, p, t) -> begin
Expand All @@ -443,6 +461,9 @@ function SII.observed(nw::Network, snis)
elseif !isempty(outidx)
idx = only(outidx).second
outbuf[idx]
elseif !isempty(aggidx)
idx = only(aggidx).second
aggbuf[idx]
else
obsf = only(obsfuns).second
obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
Expand All @@ -459,6 +480,9 @@ function SII.observed(nw::Network, snis)
for (i, outi) in outidx
out[i] = outbuf[outi]
end
for (i, aggi) in aggidx
out[i] = aggbuf[aggi]
end
for (i, obsf) in obsfuns
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,6 @@ rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf)
abstract type SymbolicIndex{C,S} end
abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end
abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end

flatten_sym(v::NamedTuple) = reduce(vcat, values(v))
flatten_sym(v::AbstractVector{Symbol}) = v
8 changes: 4 additions & 4 deletions test/ComponentLibrary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ diffusion_vertex() = VertexModel(f=diffusionvertex!, dim=1, g=1:1)
Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t)
e .= K .* sin(θ_s[1] - θ_d[1])
end
function kuramoto_edge(; name=:kuramoto_edge)
function kuramoto_edge(; name=:kuramoto_edge, kwargs...)
EdgeModel(;g=AntiSymmetric(kuramoto_edge!),
outsym=[:P], psym=[:K], name)
outsym=[:P], psym=[:K], name, kwargs...)
end

Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t)
M, D, Pm = p
dv[1] = v[2]
dv[2] = 1 / M * (Pm - D * v[2] + acc[1])
end
function kuramoto_second(; name=:kuramoto_second)
function kuramoto_second(; name=:kuramoto_second, kwargs...)
VertexModel(; f=kuramoto_inertia!, sym=[:δ=>0, :ω=>0],
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name)
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name, kwargs...)
end

Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t)
Expand Down
20 changes: 19 additions & 1 deletion test/symbolicindexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ for idx in idxtypes
println(idx, " => ", b.allocs, " allocations")
end
if VERSION ≥ v"1.11"
@test b.allocs <= 12
@test b.allocs <= 13
end
end

Expand Down Expand Up @@ -470,3 +470,21 @@ nw = Network(g, [n1, n2, n3], [e1, e2])
@test s.p.e[:e2, 1] == s[EPIndex(2,1)]
@test s.p.e[:e3, 1] == s[EPIndex(3,1)]
end

# test observed for inputs
@testset "test observing of model input" begin
v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin])
v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin])
v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin])
e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin])
e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin])
nw = Network([v1,v2,v3], [e1,e2])
s = NWState(nw, rand(dim(nw)), rand(pdim(nw)))
@test s[VIndex(:v1, :Pin)] == s[EIndex(:e1, :₋P)]
@test s[VIndex(:v2, :Pin)] == s[EIndex(:e1, :P)] + s[EIndex(:e2, :₋P)]
@test s[VIndex(:v3, :Pin)] == s[EIndex(:e2, :P)]
@test s[EIndex(:e1, :src₊δin)] == s[VIndex(:v1, :δ)]
@test s[EIndex(:e1, :dst₊δin)] == s[VIndex(:v2, :δ)]
@test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, :δ)]
@test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, :δ)]
end
4 changes: 4 additions & 0 deletions test/testutils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using CUDA
using Adapt
using NetworkDynamics: iscudacompatible, NaiveAggregator

"""
Test utility, which rebuilds the Network with all different execution styles and compares the
results of the coreloop.
Expand Down
Loading