From b4f95c781a33e347d0e54e3d0547076aa17316c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Tue, 7 Jan 2025 10:17:24 +0100 Subject: [PATCH 1/3] improve handling of same nested namespaces previously, have a system :load with a subsystem :load could lead to problems --- ext/MTKExt.jl | 2 +- ext/MTKUtils.jl | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ext/MTKExt.jl b/ext/MTKExt.jl index 10dcdb87..0bb0cdf3 100644 --- a/ext/MTKExt.jl +++ b/ext/MTKExt.jl @@ -189,7 +189,7 @@ For a given system and name, extract all the relevant meta we want to keep for t function _get_metadata(sys, name) nt = (;) sym = try - getproperty_symbolic(sys, name) + getproperty_symbolic(sys, name; might_contain_toplevel_ns=false) catch e if !endswith(string(name), "ˍt") # known for "internal" derivatives @warn "Could not extract metadata for $name $(e.msg)" diff --git a/ext/MTKUtils.jl b/ext/MTKUtils.jl index 001e9fee..47f9acb7 100644 --- a/ext/MTKUtils.jl +++ b/ext/MTKUtils.jl @@ -93,18 +93,24 @@ function _collect_differentials!(found, ex) end """ - getproperty_symbolic(sys, var) + getproperty_symbolic(sys, var; might_contain_toplevel_ns=true) Like `getproperty` but works on a greater varaity of "var" - var can be Num or Symbolic (resolved using genname) -- strip namespace of sys if present +- strip namespace of sys if present (don't strip if `might_contain_top_level_ns=false`) - for nested variables (foo₊bar₊baz) resolve them one by one """ -function getproperty_symbolic(sys, var) +function getproperty_symbolic(sys, var; might_contain_toplevel_ns=true) ns = string(getname(sys)) varname = string(getname(var)) - varname_nons = replace(varname, r"^"*ns*"₊" => "") - parts = split(varname_nons, "₊") + # split of the toplevel namespace if necessary + if might_contain_toplevel_ns && startswith(varname, ns*"₊") + if getname(sys) ∈ getname.(ModelingToolkit.get_systems(sys)) + @warn "Namespace :$ns appears multiple times, this might lead to unexpected, since it is not clear whether the namespace should be stripped or not." + end + varname = replace(varname, r"^"*ns*"₊" => "") + end + parts = split(varname, "₊") r = getproperty(sys, Symbol(parts[1]); namespace=false) for part in parts[2:end] r = getproperty(r, Symbol(part); namespace=true) From ec96855d5f815fba5f0a108a1b8139edaf43925a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Tue, 7 Jan 2025 12:43:43 +0100 Subject: [PATCH 2/3] allow observing of inputs --- src/component_functions.jl | 17 ++++++++++------- src/symbolicindexing.jl | 26 +++++++++++++++++++++++++- src/utils.jl | 3 +++ test/ComponentLibrary.jl | 8 ++++---- test/symbolicindexing_test.jl | 20 +++++++++++++++++++- test/testutils.jl | 4 ++++ 6 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/component_functions.jl b/src/component_functions.jl index 06751882..6450e020 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -487,6 +487,9 @@ edges it returns a named tuple `(; src, dst)` with two symbol vectors. insym(c::VertexModel)::Vector{Symbol} = c.insym insym(c::EdgeModel)::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} = c.insym +insym_all(c::VertexModel) = c.insym +insym_all(c::EdgeModel) = Iterators.flatten(values(c.insym)) + """ indim(c::VertexModel)::Int indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int} @@ -950,16 +953,16 @@ function _fill_defaults(T, @nospecialize(kwargs)) #### #### Cached outsymflat/outsymall #### - _outsym_flat = if T <: VertexModel - outsym - elseif T <: EdgeModel - vcat(outsym.src, outsym.dst) - else - error() - end + _outsym_flat = flatten_sym(outsym) dict[:_outsym_flat] = _outsym_flat + dict[:_obssym_all] = setdiff(_outsym_flat, sym) ∪ obssym + if !isnothing(insym) + insym_flat = flatten_sym(insym) + dict[:_obssym_all] = dict[:_obssym_all] ∪ insym_flat + end + #### #### External Inputs #### diff --git a/src/symbolicindexing.jl b/src/symbolicindexing.jl index 1e0b17aa..4389d7c1 100644 --- a/src/symbolicindexing.jl +++ b/src/symbolicindexing.jl @@ -412,6 +412,8 @@ function SII.observed(nw::Network, snis) stateidx = Dict{Int, Int}() # mapping i -> index in output outidx = Dict{Int, Int}() + # mapping i -> index in aggbuf + aggidx = Dict{Int, Int}() # mapping i -> f(fullstate, p, t) (component observables) obsfuns = Dict{Int, Function}() for (i, sni) in enumerate(_snis) @@ -427,12 +429,28 @@ function SII.observed(nw::Network, snis) elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed _obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni)) obsfuns[i] = (u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[idx] + elseif hasinsym(cf) && sni.subidx ∈ insym_all(cf) # found in input + if sni isa SymbolicVertexIndex + idx = findfirst(isequal(sni.subidx), insym_all(cf)) + aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx] + elseif sni isa SymbolicEdgeIndex + edge = nw.im.edgevec[resolvecompidx(nw, sni)] + if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing + outidx[i] = nw.im.v_out[edge.src][idx] + elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing + outidx[i] = nw.im.v_out[edge.dst][idx] + else + error() + end + else + error() + end else throw(ArgumentError("Cannot resolve observable $sni")) end end end - initbufs = !isempty(outidx) || !isempty(obsfuns) + initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns) if isscalar @closure (u, p, t) -> begin @@ -443,6 +461,9 @@ function SII.observed(nw::Network, snis) elseif !isempty(outidx) idx = only(outidx).second outbuf[idx] + elseif !isempty(aggidx) + idx = only(aggidx).second + aggbuf[idx] else obsf = only(obsfuns).second obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u) @@ -459,6 +480,9 @@ function SII.observed(nw::Network, snis) for (i, outi) in outidx out[i] = outbuf[outi] end + for (i, aggi) in aggidx + out[i] = aggbuf[aggi] + end for (i, obsf) in obsfuns out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u) end diff --git a/src/utils.jl b/src/utils.jl index 1e184477..e6e4ea04 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -134,3 +134,6 @@ rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf) abstract type SymbolicIndex{C,S} end abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end + +flatten_sym(v::NamedTuple) = reduce(vcat, values(v)) +flatten_sym(v::AbstractVector{Symbol}) = v diff --git a/test/ComponentLibrary.jl b/test/ComponentLibrary.jl index 6131e4f6..dbddb1fd 100644 --- a/test/ComponentLibrary.jl +++ b/test/ComponentLibrary.jl @@ -47,9 +47,9 @@ diffusion_vertex() = VertexModel(f=diffusionvertex!, dim=1, g=1:1) Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t) e .= K .* sin(θ_s[1] - θ_d[1]) end -function kuramoto_edge(; name=:kuramoto_edge) +function kuramoto_edge(; name=:kuramoto_edge, kwargs...) EdgeModel(;g=AntiSymmetric(kuramoto_edge!), - outsym=[:P], psym=[:K], name) + outsym=[:P], psym=[:K], name, kwargs...) end Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) @@ -57,9 +57,9 @@ Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t) dv[1] = v[2] dv[2] = 1 / M * (Pm - D * v[2] + acc[1]) end -function kuramoto_second(; name=:kuramoto_second) +function kuramoto_second(; name=:kuramoto_second, kwargs...) VertexModel(; f=kuramoto_inertia!, sym=[:δ=>0, :ω=>0], - psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name) + psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name, kwargs...) end Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t) diff --git a/test/symbolicindexing_test.jl b/test/symbolicindexing_test.jl index 300e261a..63eaa8c6 100644 --- a/test/symbolicindexing_test.jl +++ b/test/symbolicindexing_test.jl @@ -306,7 +306,7 @@ for idx in idxtypes println(idx, " => ", b.allocs, " allocations") end if VERSION ≥ v"1.11" - @test b.allocs <= 12 + @test b.allocs <= 13 end end @@ -470,3 +470,21 @@ nw = Network(g, [n1, n2, n3], [e1, e2]) @test s.p.e[:e2, 1] == s[EPIndex(2,1)] @test s.p.e[:e3, 1] == s[EPIndex(3,1)] end + +# test observed for inputs +@testset "test observing of model input" begin + v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin]) + v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin]) + v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin]) + e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin]) + e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin]) + nw = Network([v1,v2,v3], [e1,e2]) + s = NWState(nw, rand(dim(nw)), rand(pdim(nw))) + @test s[VIndex(:v1, :Pin)] == s[EIndex(:e1, :₋P)] + @test s[VIndex(:v2, :Pin)] == s[EIndex(:e1, :P)] + s[EIndex(:e2, :₋P)] + @test s[VIndex(:v3, :Pin)] == s[EIndex(:e2, :P)] + @test s[EIndex(:e1, :src₊δin)] == s[VIndex(:v1, :δ)] + @test s[EIndex(:e1, :dst₊δin)] == s[VIndex(:v2, :δ)] + @test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, :δ)] + @test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, :δ)] +end diff --git a/test/testutils.jl b/test/testutils.jl index 68d0df91..42da4742 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -1,3 +1,7 @@ +using CUDA +using Adapt +using NetworkDynamics: iscudacompatible, NaiveAggregator + """ Test utility, which rebuilds the Network with all different execution styles and compares the results of the coreloop. From fe84b3460f8c4fccdb8e770ce0adac49714ea31d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Tue, 7 Jan 2025 12:54:15 +0100 Subject: [PATCH 3/3] fix missing insym variable --- src/component_functions.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/component_functions.jl b/src/component_functions.jl index 6450e020..0fa389b1 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -958,6 +958,7 @@ function _fill_defaults(T, @nospecialize(kwargs)) dict[:_obssym_all] = setdiff(_outsym_flat, sym) ∪ obssym + insym = dict[:insym] if !isnothing(insym) insym_flat = flatten_sym(insym) dict[:_obssym_all] = dict[:_obssym_all] ∪ insym_flat @@ -994,10 +995,8 @@ function _fill_defaults(T, @nospecialize(kwargs)) _is = if isnothing(__is) Symbol[] - elseif __is isa NamedTuple - vcat(__is.src, __is.dst) else - __is + flatten_sym(insym) end if !allunique(vcat(_s, _ps, _obss, _is, _os)) throw(ArgumentError("Symbol names must be unique. There are clashes in sym, psym, outsym, obssym and insym."))