Skip to content

Commit

Permalink
allow observing of inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
hexaeder committed Jan 7, 2025
1 parent b4f95c7 commit ec96855
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 13 deletions.
17 changes: 10 additions & 7 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ edges it returns a named tuple `(; src, dst)` with two symbol vectors.
insym(c::VertexModel)::Vector{Symbol} = c.insym
insym(c::EdgeModel)::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} = c.insym

insym_all(c::VertexModel) = c.insym
insym_all(c::EdgeModel) = Iterators.flatten(values(c.insym))

"""
indim(c::VertexModel)::Int
indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int}
Expand Down Expand Up @@ -950,16 +953,16 @@ function _fill_defaults(T, @nospecialize(kwargs))
####
#### Cached outsymflat/outsymall
####
_outsym_flat = if T <: VertexModel
outsym
elseif T <: EdgeModel
vcat(outsym.src, outsym.dst)
else
error()
end
_outsym_flat = flatten_sym(outsym)
dict[:_outsym_flat] = _outsym_flat

dict[:_obssym_all] = setdiff(_outsym_flat, sym) obssym

if !isnothing(insym)
insym_flat = flatten_sym(insym)
dict[:_obssym_all] = dict[:_obssym_all] insym_flat
end

####
#### External Inputs
####
Expand Down
26 changes: 25 additions & 1 deletion src/symbolicindexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ function SII.observed(nw::Network, snis)
stateidx = Dict{Int, Int}()
# mapping i -> index in output
outidx = Dict{Int, Int}()
# mapping i -> index in aggbuf
aggidx = Dict{Int, Int}()
# mapping i -> f(fullstate, p, t) (component observables)
obsfuns = Dict{Int, Function}()
for (i, sni) in enumerate(_snis)
Expand All @@ -427,12 +429,28 @@ function SII.observed(nw::Network, snis)
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
obsfuns[i] = (u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[idx]
elseif hasinsym(cf) && sni.subidx insym_all(cf) # found in input
if sni isa SymbolicVertexIndex
idx = findfirst(isequal(sni.subidx), insym_all(cf))
aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx]
elseif sni isa SymbolicEdgeIndex
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
outidx[i] = nw.im.v_out[edge.src][idx]
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
outidx[i] = nw.im.v_out[edge.dst][idx]
else
error()
end
else
error()
end
else
throw(ArgumentError("Cannot resolve observable $sni"))
end
end
end
initbufs = !isempty(outidx) || !isempty(obsfuns)
initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns)

if isscalar
@closure (u, p, t) -> begin
Expand All @@ -443,6 +461,9 @@ function SII.observed(nw::Network, snis)
elseif !isempty(outidx)
idx = only(outidx).second
outbuf[idx]
elseif !isempty(aggidx)
idx = only(aggidx).second
aggbuf[idx]
else
obsf = only(obsfuns).second
obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
Expand All @@ -459,6 +480,9 @@ function SII.observed(nw::Network, snis)
for (i, outi) in outidx
out[i] = outbuf[outi]
end
for (i, aggi) in aggidx
out[i] = aggbuf[aggi]
end
for (i, obsf) in obsfuns
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,6 @@ rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf)
abstract type SymbolicIndex{C,S} end
abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end
abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end

flatten_sym(v::NamedTuple) = reduce(vcat, values(v))
flatten_sym(v::AbstractVector{Symbol}) = v
8 changes: 4 additions & 4 deletions test/ComponentLibrary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ diffusion_vertex() = VertexModel(f=diffusionvertex!, dim=1, g=1:1)
Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t)
e .= K .* sin(θ_s[1] - θ_d[1])
end
function kuramoto_edge(; name=:kuramoto_edge)
function kuramoto_edge(; name=:kuramoto_edge, kwargs...)
EdgeModel(;g=AntiSymmetric(kuramoto_edge!),
outsym=[:P], psym=[:K], name)
outsym=[:P], psym=[:K], name, kwargs...)
end

Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t)
M, D, Pm = p
dv[1] = v[2]
dv[2] = 1 / M * (Pm - D * v[2] + acc[1])
end
function kuramoto_second(; name=:kuramoto_second)
function kuramoto_second(; name=:kuramoto_second, kwargs...)
VertexModel(; f=kuramoto_inertia!, sym=[=>0, =>0],
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name)
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name, kwargs...)
end

Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t)
Expand Down
20 changes: 19 additions & 1 deletion test/symbolicindexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ for idx in idxtypes
println(idx, " => ", b.allocs, " allocations")
end
if VERSION v"1.11"
@test b.allocs <= 12
@test b.allocs <= 13
end
end

Expand Down Expand Up @@ -470,3 +470,21 @@ nw = Network(g, [n1, n2, n3], [e1, e2])
@test s.p.e[:e2, 1] == s[EPIndex(2,1)]
@test s.p.e[:e3, 1] == s[EPIndex(3,1)]
end

# test observed for inputs
@testset "test observing of model input" begin
v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin])
v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin])
v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin])
e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin])
e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin])
nw = Network([v1,v2,v3], [e1,e2])
s = NWState(nw, rand(dim(nw)), rand(pdim(nw)))
@test s[VIndex(:v1, :Pin)] == s[EIndex(:e1, :₋P)]
@test s[VIndex(:v2, :Pin)] == s[EIndex(:e1, :P)] + s[EIndex(:e2, :₋P)]
@test s[VIndex(:v3, :Pin)] == s[EIndex(:e2, :P)]
@test s[EIndex(:e1, :src₊δin)] == s[VIndex(:v1, )]
@test s[EIndex(:e1, :dst₊δin)] == s[VIndex(:v2, )]
@test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, )]
@test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, )]
end
4 changes: 4 additions & 0 deletions test/testutils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using CUDA
using Adapt
using NetworkDynamics: iscudacompatible, NaiveAggregator

"""
Test utility, which rebuilds the Network with all different execution styles and compares the
results of the coreloop.
Expand Down

0 comments on commit ec96855

Please sign in to comment.