
Physics informed neural operator ode #806

Merged: 160 commits, Oct 30, 2024
Changes from 50 commits

Commits (160)
9988925
pino_ode_draft
KirillZubov Feb 12, 2024
5fc72b9
pino ode prot
KirillZubov Feb 13, 2024
25e4764
add physics loss function
KirillZubov Feb 13, 2024
5888e12
add doc draft
KirillZubov Feb 13, 2024
b6ae2b4
add test mapping inital condition u0 -> solution u(t){u0}
KirillZubov Feb 15, 2024
ad0902a
support train family ode by initial conditions
KirillZubov Feb 16, 2024
c7b10f8
generate prob set
KirillZubov Feb 28, 2024
911ec4e
ode system
KirillZubov Feb 29, 2024
a4ff413
support Fourier Neural Operator
KirillZubov Mar 1, 2024
cbf01af
update
KirillZubov Mar 5, 2024
db50090
clear up code
KirillZubov Mar 6, 2024
a08af21
Merge branch 'master' into pino_ode
KirillZubov Mar 6, 2024
60cda53
update tests
KirillZubov Mar 7, 2024
5781cb5
gpu
KirillZubov Mar 8, 2024
b562ebe
fix spelling
KirillZubov Mar 8, 2024
eb2e9c1
add PINOsolution
KirillZubov Mar 15, 2024
ca2b802
add NeuralOperator deps
KirillZubov Mar 15, 2024
96057a3
Merge branch 'master' into pino_ode
KirillZubov Mar 15, 2024
5b8d226
update Project.toml
KirillZubov Mar 18, 2024
c711b0d
update Project.toml
KirillZubov Mar 18, 2024
2ad84ac
update Project.toml
KirillZubov Mar 18, 2024
a71c469
update
KirillZubov Mar 18, 2024
fb73003
update
KirillZubov Mar 18, 2024
6c524ab
update
KirillZubov Mar 18, 2024
2addf6c
update gpu test
KirillZubov Mar 18, 2024
c058f5e
updste version Statistics
KirillZubov Mar 18, 2024
48c762b
add docs
KirillZubov Mar 19, 2024
671cedc
fix
KirillZubov Mar 19, 2024
fe8a819
fix
KirillZubov Mar 19, 2024
2c87492
fix TRAINSET
KirillZubov Mar 19, 2024
3905b14
fix typo
KirillZubov Mar 19, 2024
652fb72
add dep NeuralOperators
KirillZubov Mar 19, 2024
f74a46f
fix test
KirillZubov Mar 19, 2024
c163e28
fix gpu test
KirillZubov Mar 19, 2024
b52c4c2
fix tests
KirillZubov Mar 19, 2024
b829265
remove dep NeuralOperator
KirillZubov Mar 20, 2024
9870858
update tests
KirillZubov Mar 20, 2024
c5e7348
fix
KirillZubov Mar 20, 2024
cefb901
fix
KirillZubov Mar 20, 2024
10169cb
fix
KirillZubov Mar 20, 2024
2f2be69
fine tuning
KirillZubov Mar 21, 2024
c38a78f
fine tunnning update
KirillZubov Mar 22, 2024
1e38676
fix
KirillZubov Mar 22, 2024
8a98880
Merge branch 'master' into pino_ode
KirillZubov Apr 8, 2024
2cc1d1f
implement DeepONet, refactor pinoode
KirillZubov Apr 25, 2024
f873541
pure PINO with DeepOnet
KirillZubov May 2, 2024
0b4de35
support case right side eq depends on 'u'
KirillZubov May 6, 2024
68e38d5
example with data loss
KirillZubov May 6, 2024
345040b
support GridTraining
KirillZubov May 8, 2024
60005c8
update doc
KirillZubov May 8, 2024
fbef5c7
PINOPhi immutable
KirillZubov Jun 12, 2024
0fb7d23
Update src/pino_ode_solve.jl
KirillZubov Jun 12, 2024
b6639ea
Update src/pino_ode_solve.jl
KirillZubov Jun 12, 2024
cc1fada
remove SomeStrategy
KirillZubov Jun 12, 2024
bb0863b
rename operator_loss to inital_condition_loss
KirillZubov Jun 12, 2024
ea7c638
mutable PINOPhi
KirillZubov Jun 14, 2024
09f4891
support u0 is param
KirillZubov Jun 14, 2024
188ceec
update
KirillZubov Jun 17, 2024
eb005c6
update multiple parameters task
KirillZubov Jun 18, 2024
d754e30
add ParametricFunction
KirillZubov Jun 19, 2024
214b178
vector outputs and multiple parameters
KirillZubov Jun 20, 2024
7d81063
clear code, rm ParametricFunction
KirillZubov Jun 21, 2024
f5e2b06
begin migrate LuxNeuralOperators and add QuasiRandomTraining
KirillZubov Jun 21, 2024
30a5134
migrate to LuxNeuralOperators.DeepOnet
KirillZubov Jun 25, 2024
2818f34
add StochasticTraining
KirillZubov Jun 26, 2024
8895693
add interpolation
KirillZubov Jun 27, 2024
06ce517
update doc
KirillZubov Jun 27, 2024
7953468
output vector
KirillZubov Jul 1, 2024
f6f0405
Update docs/src/manual/pino_ode.md
KirillZubov Jul 2, 2024
e16b168
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
596a80a
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
445dfd0
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
d7b7a36
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
fe40cfd
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
e45011e
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 2, 2024
7fe3244
Update test/runtests.jl
KirillZubov Jul 2, 2024
839e1d5
revert "vector output"
KirillZubov Jul 2, 2024
53a52f8
fix
KirillZubov Jul 2, 2024
a39bb69
Add Unregistered LuxNeuralOperators
KirillZubov Jul 2, 2024
2983c18
CI.yml fix
KirillZubov Jul 2, 2024
885aa29
revert Add Unregistered LuxNeuralOperators
KirillZubov Jul 2, 2024
38a872e
fix
KirillZubov Jul 2, 2024
28447b6
add unregistered LuxNeuralOperators attempt 2
KirillZubov Jul 2, 2024
1292900
undo "add unregistered LuxNeuralOperators attempt 2"
KirillZubov Jul 2, 2024
10e08c2
FNO first shot
KirillZubov Jul 12, 2024
61fa804
fix
KirillZubov Jul 12, 2024
99e8e2d
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 17, 2024
c01c0ad
Update .github/workflows/CI.yml
KirillZubov Jul 17, 2024
1a8fd17
Update src/NeuralPDE.jl
KirillZubov Jul 17, 2024
c5bf9ca
Update Project.toml
KirillZubov Jul 17, 2024
c5d2501
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 17, 2024
fc38b50
Update docs/src/tutorials/pino_ode.md
KirillZubov Jul 17, 2024
09e5555
Update src/pino_ode_solve.jl
KirillZubov Jul 17, 2024
8ae7b0f
Update src/pino_ode_solve.jl
KirillZubov Jul 17, 2024
15e6124
vector output
KirillZubov Jul 22, 2024
efd2a40
wip FourierNeuralOperator
KirillZubov Jul 23, 2024
89fc2d9
update
KirillZubov Aug 14, 2024
5ab8a6a
update
KirillZubov Aug 16, 2024
433e529
update
KirillZubov Aug 20, 2024
c54d62f
update
KirillZubov Aug 26, 2024
d99062d
update
KirillZubov Aug 27, 2024
0bc28da
remove all unnecessary
KirillZubov Aug 29, 2024
2b45d10
update
KirillZubov Aug 30, 2024
d5b9d6d
update
KirillZubov Aug 30, 2024
7059e37
Merge branch 'master' into pino_ode
KirillZubov Aug 30, 2024
c4092b5
imutable
KirillZubov Aug 30, 2024
3ae25ce
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
908ca8a
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
13f955b
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
3813b1b
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
5be5db1
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
d0665ec
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
5295ceb
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
cc67c31
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
cf5221c
Update test/runtests.jl
KirillZubov Sep 2, 2024
a3dbfc6
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 2, 2024
0cdd8dc
Update src/pino_ode_solve.jl
KirillZubov Sep 5, 2024
bc32c45
Update src/pino_ode_solve.jl
KirillZubov Sep 5, 2024
6c49dd8
Update src/pino_ode_solve.jl
KirillZubov Sep 5, 2024
ff04565
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 5, 2024
2366840
Update docs/src/tutorials/pino_ode.md
KirillZubov Sep 5, 2024
b1209ab
lux v1.0
KirillZubov Sep 19, 2024
9fa2162
Merge branch 'master' into pino_ode
KirillZubov Sep 24, 2024
86677d0
Update Project.toml
KirillZubov Sep 24, 2024
6813c5d
lux 1
KirillZubov Sep 24, 2024
786035f
support Chain
KirillZubov Sep 24, 2024
5428b01
Update test/PINO_ode_tests.jl
KirillZubov Sep 24, 2024
60c3995
vector input Chain
KirillZubov Sep 25, 2024
90dfdb0
Update test/PINO_ode_tests.jl
KirillZubov Sep 27, 2024
2d7be46
Update test/PINO_ode_tests.jl
KirillZubov Sep 27, 2024
e342a7d
Update test/PINO_ode_tests.jl
KirillZubov Sep 27, 2024
000d8b5
input vector chain
KirillZubov Sep 28, 2024
5e9a025
support chain with StochasticTraining
KirillZubov Oct 1, 2024
e0cb528
NeuralOperators v0.5.0 in project
KirillZubov Oct 1, 2024
f1b3a36
wip output vector
KirillZubov Oct 2, 2024
ca24323
output vector
KirillZubov Oct 3, 2024
f74cf63
update Project.toml
KirillZubov Oct 3, 2024
885a72e
update
KirillZubov Oct 4, 2024
3baa5a6
update Project.toml
KirillZubov Oct 9, 2024
de008b3
undo
KirillZubov Oct 9, 2024
356083b
update runtests.jl
KirillZubov Oct 10, 2024
f59f225
update NeuralPDE
KirillZubov Oct 10, 2024
526fd0c
update Doc
KirillZubov Oct 10, 2024
d3ae594
update
KirillZubov Oct 10, 2024
1925aaf
Merge branch 'master' into pino_ode
KirillZubov Oct 17, 2024
6dd9e38
update
KirillZubov Oct 17, 2024
6748feb
update
KirillZubov Oct 17, 2024
2967a4a
Merge branch 'master' into pino_ode
KirillZubov Oct 18, 2024
36226f9
update
KirillZubov Oct 18, 2024
22f5144
PINOODETestSetup
KirillZubov Oct 18, 2024
18e45d2
update
KirillZubov Oct 18, 2024
aef24e3
update
KirillZubov Oct 18, 2024
501496e
update Project.toml docs
KirillZubov Oct 28, 2024
e8ac7f5
add sol(t)
KirillZubov Oct 28, 2024
11c67da
sol(t)
KirillZubov Oct 28, 2024
35d8346
update
KirillZubov Oct 29, 2024
c5a456a
Interpolation
KirillZubov Oct 29, 2024
b6609f9
fix
KirillZubov Oct 29, 2024
0617f42
update PINOODEInterpolation
KirillZubov Oct 30, 2024
9c857a5
update
KirillZubov Oct 30, 2024
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -26,6 +26,7 @@ jobs:
- Logging
- Forward
- DGM
- ODEPINO
Member: Can you name it PINOODE for consistency?

- NNODE
- NeuralAdapter
- IntegroDiff
2 changes: 1 addition & 1 deletion Project.toml
@@ -97,4 +97,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
4 changes: 3 additions & 1 deletion docs/pages.jl
@@ -3,6 +3,7 @@ pages = ["index.md",
"Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md",
"PINNs DAEs" => "tutorials/dae.md",
"Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md",
"Physics informed Neural Operator ODEs" => "tutorials/pino_ode.md",
"Deep Galerkin Method" => "tutorials/dgm.md" #"examples/nnrode_example.md", # currently incorrect
],
"PDE PINN Tutorials" => Any[
@@ -31,6 +32,7 @@ pages = ["index.md",
"manual/training_strategies.md",
"manual/adaptive_losses.md",
"manual/logging.md",
"manual/neural_adapters.md"],
"manual/neural_adapters.md",
"manual/pino_ode.md"],
"Developer Documentation" => Any["developer/debugging.md"]
]
11 changes: 11 additions & 0 deletions docs/src/manual/pino_ode.md
@@ -0,0 +1,11 @@
# Physics-Informed Neural Operator for solving ODEs

```@docs
PINOODE
```

```@docs
DeepONet
```


66 changes: 66 additions & 0 deletions docs/src/tutorials/pino_ode.md
@@ -0,0 +1,66 @@
# Physics-Informed Neural Operator for ODE Solvers

This tutorial provides an example of how to use the Physics Informed Neural Operator (PINO) for solving a family of parametric ordinary differential equations (ODEs).

## Operator Learning for a Family of Parametric ODEs

In this section, we will define a parametric ODE and solve it using a PINO. The PINO will be trained to learn the mapping from the parameters of the ODE to its solution.

```@example pino
using Test
using OrdinaryDiffEq, OptimizationOptimisers
using Lux
using Statistics, Random
using NeuralPDE

equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)

# Define the architecture of the neural network that will be used as the PINO.
branch = Lux.Chain(
    Lux.Dense(1, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast),
    Lux.Dense(10, 10))
trunk = Lux.Chain(
    Lux.Dense(1, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast))
deeponet = NeuralPDE.DeepONet(branch, trunk; linear = nothing)

bounds = (p = [0.1f0, pi],)
Member: Having more than one parameter would be more illustrative of its use case. Right now the branch and trunk both have size-one inputs, which makes it potentially confusing for the user to see how to modify this demo toward a case with more parameters.

Member Author: OK, I will do that. It will take some time to implement this feature for many parameters.

db = (bounds.p[2] - bounds.p[1]) / 50
dt = (tspan[2] - tspan[1]) / 40
strategy = NeuralPDE.GridTraining([db, dt])
Member: Don't use grid training.

Member Author: OK, I will also support QuasiRandomTraining for PINO ODE and use it in the doc example.

Member Author: StochasticTraining*
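# Per the review above, a sampling-based strategy could be used instead of GridTraining.
# A hedged sketch (assuming StochasticTraining takes a number of sample points, as it
# does for NeuralPDE's other solvers); this line is not part of the original example:
# strategy = NeuralPDE.StochasticTraining(128)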

opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
predict = sol.u
Member: sol.original needs to be explained.

```

Now let's compare the prediction from the learned operator with the ground-truth solution, which is obtained from the analytic solution of the parametric ODE.
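For this problem the ground truth has a closed form, obtained by integrating the right-hand side; it matches the `ground_analytic` function used below:

```math
u'(t) = \cos(p t), \qquad u(0) = u_0 \quad \Longrightarrow \quad u(t) = u_0 + \frac{\sin(p t)}{p}.
```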

```@example pino
using Plots
# Compute the ground truth solution for each parameter
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p_ = bounds.p[1]:strategy.dx[1]:bounds.p[2]
Member: Not using grid training will make this more compelling, since it should predict at new parameters, not the ones trained on.

Member Author: Yes, I agree.

p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

# Plot the predicted solution and the ground truth solution as a filled contour plot
# sol.u[1, :, :] represents the predicted solution for each parameter value and time
plot(predict[1, :, :], linetype = :contourf)
plot!(ground_solution[1, :, :], linetype = :contourf)
Member: You don't show how to generate the solution at new parameters, which is the key to the PINO interface.

Member Author: By new parameters, do you mean another mesh that wasn't used for training but lies within the same parameter bounds?

Member: Yes.

```
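As the review thread above notes, the point of the operator interface is that the trained operator can be evaluated at parameter values it was not trained on. The sketch below shows one way this could look using the DeepONet call convention from `src/neural_operators.jl`; it is not the confirmed `PINOODE` API, and `sol.original` (mentioned in the review) and the exact field holding the trained weights are assumptions, so the uncertain calls are left commented out:

```julia
# A sketch, not the confirmed PINOODE interface: evaluate the learned operator
# at parameter values that were not part of the training grid.
p_new = reshape(collect(range(0.2f0, 2.5f0, length = 25)), 1, 25, 1)  # unseen parameters within bounds
t_new = reshape(collect(range(tspan[1], tspan[2], length = 40)), 1, 1, 40)

# Assumption: the trained weights/states are reachable from the solution object,
# e.g. via sol.original; the field names below are hypothetical.
# θ_trained = sol.original.u
# u_new, _ = deeponet((branch = p_new, trunk = t_new), θ_trained, st)
# plot(u_new[1, :, :], linetype = :contourf)
```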

```@example pino
# 'i' is the index of the parameter 'p' in the dataset
i = 20
# 'predict' is the predicted solution from the PINO model
plot(predict[1, i, :], label = "Predicted")
# 'ground' is the ground truth solution
plot!(ground_solution[1, i, :], label = "Ground truth")
```
6 changes: 5 additions & 1 deletion src/NeuralPDE.jl
@@ -33,6 +33,7 @@ using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using ChainRulesCore: @non_differentiable
#using NeuralOperators

RuntimeGeneratedFunctions.init(@__MODULE__)

@@ -47,6 +48,8 @@ include("adaptive_losses.jl")
include("ode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("neural_operators.jl")
include("pino_ode_solve.jl")
include("transform_inf_integral.jl")
include("discretize.jl")
include("neural_adapter.jl")
@@ -55,7 +58,8 @@ include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")

export NNODE, NNDAE,

export NNODE, NNDAE, PINOODE, DeepONet, SomeStrategy #TODO remove SomeStrategy
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
101 changes: 101 additions & 0 deletions src/neural_operators.jl
@@ -0,0 +1,101 @@
abstract type NeuralOperator <: Lux.AbstractExplicitLayer end

"""
DeepONet(branch,trunk)
"""

"""
DeepONet(branch, trunk; linear = nothing)

`DeepONet` is a neural operator architecture used here for physics-informed solving of parametric ODEs.

DeepONet is composed of two separate neural networks, referred to as the "branch" and the "trunk". The branch net takes as input a function evaluated at a collection of fixed locations within some bounds (here, the ODE parameters) and returns a feature embedding. The trunk net takes the continuous coordinates (here, time) as input and also outputs a feature embedding. The final output of the DeepONet is obtained by merging the outputs of the branch and trunk networks via a dot product.
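In symbols, with branch embedding components ``b_k`` and trunk embedding components ``\tau_k`` sharing a common width ``q``, the merge implemented below computes

```math
\mathcal{G}_\theta(p)(t) \approx \sum_{k=1}^{q} b_k(p)\,\tau_k(t),
```

where the branch input is the ODE parameter ``p`` and the trunk input is the time ``t``.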

## Positional Arguments
* `branch`: A branch neural network.
* `trunk`: A trunk neural network.

## Keyword Arguments
* `linear`: A linear layer to apply to the output of the branch and trunk networks.

## Example

```julia
using NeuralPDE, Lux, Random

branch = Lux.Chain(
    Lux.Dense(1, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast),
    Lux.Dense(10, 10))
trunk = Lux.Chain(
    Lux.Dense(1, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast),
    Lux.Dense(10, 10, Lux.tanh_fast))
linear = Lux.Chain(Lux.Dense(10, 1))

deeponet = DeepONet(branch, trunk; linear = linear)

a = rand(1, 50, 40)
b = rand(1, 1, 40)
x = (branch = a, trunk = b)
θ, st = Lux.setup(Random.default_rng(), deeponet)
y, st = deeponet(x, θ, st)
```

## References
* Lu Lu, Pengzhan Jin, George Em Karniadakis "DeepONet: Learning nonlinear operators for identifying differential equations based on the universal approximation theorem of operators"
* Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets"
"""
struct DeepONet{L <: Union{Nothing, Lux.AbstractExplicitLayer}} <: NeuralOperator
Member: This should really be living in NeuralOperators.jl. cc @avik-pal

Member: Yes, agreed.

Member Author: I just implemented it here for tests because NeuralOperators.jl was outdated and couldn't be used. Yes, agreed, it should be relocated to NeuralOperators.jl.

Member: Should it be NeuralOperators.jl or LuxNeuralOperators.jl? I see SciML/NeuralOperators.jl#5 implementing DeepONet.

Member Author (@KirillZubov, Jun 12, 2024): Yes, Lux. I didn't know there was already one like this.

Member Author: I can do this task: move DeepONet from here to LuxNeuralOperators.

Contributor: Basic NO and DeepONet are now in LuxNeuralOperators.jl. @KirillZubov

    branch::Lux.AbstractExplicitLayer
    trunk::Lux.AbstractExplicitLayer
    linear::L
end

function DeepONet(branch, trunk; linear = nothing)
    DeepONet(branch, trunk, linear)
end

function Lux.setup(rng::AbstractRNG, l::DeepONet)
    branch, trunk, linear = l.branch, l.trunk, l.linear
    θ_branch, st_branch = Lux.setup(rng, branch)
    θ_trunk, st_trunk = Lux.setup(rng, trunk)
    θ = (branch = θ_branch, trunk = θ_trunk)
    st = (branch = st_branch, trunk = st_trunk)
    if linear !== nothing
        θ_liner, st_liner = Lux.setup(rng, linear)
        θ = (θ..., liner = θ_liner)
        st = (st..., liner = st_liner)
    end
    θ, st
end

Lux.initialstates(::AbstractRNG, ::DeepONet) = NamedTuple()
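# States for the branch/trunk (and optional linear) sub-networks are assembled in the
# `Lux.setup` overload above, so `initialstates` itself contributes no extra state.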

@inline function (f::DeepONet)(x::NamedTuple, θ, st::NamedTuple)
    x_branch, x_trunk = x.branch, x.trunk
    branch, trunk = f.branch, f.trunk
    st_branch, st_trunk = st.branch, st.trunk
    θ_branch, θ_trunk = θ.branch, θ.trunk
    out_b, st_b = branch(x_branch, θ_branch, st_branch)
    out_t, st_t = trunk(x_trunk, θ_trunk, st_trunk)
    if f.linear !== nothing
        linear = f.linear
        θ_liner, st_liner = θ.liner, st.liner
        # out = sum(out_b .* out_t, dims = 1)
        out_ = out_b .* out_t
        out, st_liner = linear(out_, θ_liner, st_liner)
        out = sum(out, dims = 1)
        return out, (branch = st_b, trunk = st_t, liner = st_liner)
    else
        out = sum(out_b .* out_t, dims = 1)
        return out, (branch = st_b, trunk = st_t)
    end
end