TP 6 Operator learning - DeepONet
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import toeplitz
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
In this TP, we seek to approximate the solution of the equation $$ -u''(x) = f(x) \qquad x\in(0,1) $$ with Dirichlet boundary conditions: $u(0)=u(1)=0$.
To do so, we will consider the DeepONet and FNO architectures, which take as input the function $f$ (a discretization of this function) and return an approximation of $u$. This equation models a string attached at the points $0$ and $1$ on which a force $f$ is exerted.
I) Numerical solution of the equation.
We can obtain $u$ with a finite difference scheme. The idea is to approximate $u$ at the points $$ x_j^N = \frac{j}{N+1} \in[0,1],\qquad j\in\{0,\dots,N+1\}, $$ and, with step size $h = \frac{1}{N+1}$, to approximate the equation by $$ \frac{1}{h^2}\left(-u(x_{j-1}^N)+2u(x_j^N)-u(x_{j+1}^N)\right) + o(1)= f(x_j^N)\qquad j\in\{1,\dots,N\}, $$ which leads to the linear system $$ (P_N):\quad A^N U^N = F^N $$ with $$ A^N = \frac{1}{h^2}\begin{pmatrix} 2 & -1 & 0 & 0 &\dots& 0 \\ -1& 2 & -1 & 0&\dots &0 \\ 0& -1 &2 &-1&\dots &0 \\ \vdots & & \ddots &\ddots&\ddots&\vdots \\ 0& \dots & 0&-1&2& -1 \\ 0&\dots&\dots&0&-1&2 \end{pmatrix}\qquad F^N = \begin{pmatrix} f(x_1^N)\\ \vdots \\ f(x_N^N) \end{pmatrix} $$ where $U^N \in \mathbb{R}^N$ is the unknown approximating $u$ at the points $x^N_j$: $$ U^N =\begin{pmatrix} U^N_1\\ \vdots \\ U^N_N \end{pmatrix} \simeq \begin{pmatrix} u(x_1^N)\\ \vdots \\ u(x_N^N) \end{pmatrix}. $$ Here, $U^N_0 = U^N_{N+1}=0$.
An implementation of a finite difference solver for our equation is given below. We will test it on a simple case to check that everything works, namely the function $f \equiv 1$, for which $$ u(x) = \frac{x(1-x)}{2},\qquad x\in [0,1].$$
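As a quick check of this test case, differentiating twice gives
$$ u(x)=\frac{x-x^2}{2},\qquad u'(x)=\frac{1-2x}{2},\qquad u''(x)=-1, $$
so that $-u''(x) = 1 = f(x)$ and $u(0)=u(1)=0$, as required.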
# Implementation of the solver for -u"=f using finite differences
def solve_poisson(f, x):
    """
    Solve -u'' = f
    Parameters:
        x: Discretized points in the domain [0, 1]
        f: Discretization of the source term f in the interior domain.
    Returns:
        u: Numerical solution to -u'' = f.
    """
    # number of interior points
    N = len(x[1:-1])
    # Step size
    h = 1 / (N + 1)
    # Construct the Toeplitz matrix for the second derivative
    A = (np.diag([2]*N) + np.diag([-1]*(N-1), -1) + np.diag([-1]*(N-1), 1)) / h**2
    # Solve the linear system
    u_interior = np.linalg.solve(A, f)  # Only interior points
    # Add boundary conditions (u(0) = u(1) = 0) to the solution
    u = np.zeros(N + 2)
    u[1:-1] = u_interior
    return u
Let us test our code.
N = 100 # number of discretization points
x = np.linspace(0, 1, N+2) # discretization of [0,1] with N interior points
f = np.ones(N) # discretisation of the function constant equal to 1
u_expl = 0.5 * x * (1-x) # explicit solution to -u"=f
u = solve_poisson(f, x) # solve the Poisson equation with our solver
Let us compare the computed solution with the explicit one.
plt.plot(x, u, label = 'simulation')
plt.plot(x, u_expl, label = 'exacte')
plt.xlabel(r'$x$', fontsize=18)
plt.ylabel(r'$u$', fontsize=18)
plt.title('solution to $-u\'\' = f$, with $f\equiv 1$', fontsize=18)
plt.legend(fontsize=12)
plt.show()
As expected, the explicit solution is reproduced very accurately.
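To quantify the agreement, we can look at the maximum pointwise error (a quick check using the arrays u and u_expl computed above; since the exact solution is a quadratic polynomial, the truncation error of the scheme vanishes for this particular $f$, so the error should be close to machine precision).
# maximum pointwise error between the finite-difference and the exact solution
err_max = np.max(np.abs(u - u_expl))
print(f"max error: {err_max:.2e}")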
II) Building the training data from the solver.
Let us start by writing a function that generates a function $f$ at random. We will consider Gaussian trajectories, simulated with numpy.random.multivariate_normal, with zero mean and covariance function
$$
\mathbb{E}[f(x)f(x')] = e^{-|x-x'|^2/(2\sigma^2)},
$$
known as the RBF (radial basis function) kernel (scaled by a constant factor 0.3 in the code below).
# define the covariance function for trajectories over the grid x
def covariance(x, sigma):
    X1, X2 = np.meshgrid(x, x)
    return 0.3 * np.exp(-0.5*(X1-X2)**2/sigma**2)

# define the mean for the gaussian trajectories
f_mean = np.zeros(len(x))
# define the covariance matrix for the gaussian trajectories
f_cov = covariance(x, 0.1)
# generate n_samples trajectories for illustration
n_samples = 5
y_samples = np.random.multivariate_normal(f_mean, f_cov, n_samples)

# plot the trajectories
plt.figure(figsize=(8, 5))
for i in range(n_samples):
    plt.plot(x, y_samples[i])
plt.title("trajectories of a gaussian process")
plt.xlabel("x", fontsize=18)
plt.ylabel("f(x)", fontsize=18)
plt.show()
Let us now build the dataset.
# dataset size
n_data = 70
# draw the f functions at random
f_data_tmp = np.random.multivariate_normal(f_mean, f_cov, n_data)
u_data = np.zeros((n_data, len(x)))  # initialization for the u's data set
for j in range(n_data):
    # to solve the poisson equation we remove the points at x=0 and x=1
    u_data[j,:] = solve_poisson(f_data_tmp[j,1:-1], x)
Finally, we put the dataset in the right format (as single-column matrices and as torch.tensor).
# for the grid points: one row per (function, grid point) pair
x_data = np.tile(x, (n_data, 1))
x_data = torch.tensor(x_data, dtype=torch.float32).reshape(-1, 1)
# for the f's functions: each f_j is repeated once per grid point so that the
# rows of f_data line up with those of x_data and u_data
f_data = torch.tensor(np.repeat(f_data_tmp, len(x), axis=0), dtype=torch.float32)
# for the u's functions
u_data = torch.tensor(u_data, dtype=torch.float32).reshape(-1, 1)
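As a quick sanity check on this formatting (using the variables defined above), each of the three tensors should have one row per (function, grid point) pair:
# expected shapes: f_data (n_data*(N+2), N+2), x_data and u_data (n_data*(N+2), 1)
print(f_data.shape, x_data.shape, u_data.shape)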
III) Training the DeepONet
Let us start by introducing a class defining a DeepONet in which the Dirichlet boundary conditions are already hard-coded.
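In the implementation below, the network represents the solution operator as
$$
u_\theta(f)(x) = x(1-x)\Big(\sum_{k=1}^{p} b_k(f)\, t_k(x) + b_0\Big),
$$
where the notation $b_k$, $t_k$, $p$ is introduced here: $b(f)=(b_1(f),\dots,b_p(f))$ is the output of the branch net (fed with the values of $f$ on the grid), $t(x)=(t_1(x),\dots,t_p(x))$ is the output of the trunk net, $b_0$ is a trainable bias, $p$ is the common output size of the two networks, and the prefactor $x(1-x)$ enforces $u_\theta(f)(0)=u_\theta(f)(1)=0$.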
Question: Define a DeepONet class.
class MLP(nn.Module):
    def __init__(self, layers, activations):
        """
        Parameters:
        - layers: List[int], the number of neurons in each layer (including input and output layers).
        - activations: List[nn.Module], the activation functions for each layer.
          The length of activations should be len(layers) - 2 (one for each hidden layer).
        """
        super().__init__()
        if len(activations) != len(layers) - 2:
            raise ValueError("The number of activation functions must match the number of hidden layers.")
        # Build the network
        self.layers = nn.Sequential()
        for i in range(len(layers) - 1):
            self.layers.add_module(f"linear_{i}", nn.Linear(layers[i], layers[i + 1]))
            if i < len(activations):  # Apply activation to all but the last layer
                self.layers.add_module(f"activation_{i}", activations[i])

    def forward(self, x):
        return self.layers(x)
class DeepONet(nn.Module):
    def __init__(self, layers_branch, activations_branch, layers_trunk, activations_trunk):
        super().__init__()
        if len(activations_branch) != len(layers_branch) - 2:
            raise ValueError("For the branch net, the number of activation functions must match the number of hidden layers.")
        if len(activations_trunk) != len(layers_trunk) - 2:
            raise ValueError("For the trunk net, the number of activation functions must match the number of hidden layers.")
        if layers_branch[-1] != layers_trunk[-1]:
            raise ValueError("The output size of the branch and trunk net must match.")
        # Build the branch net
        self.layers_branch = MLP(layers_branch, activations_branch)
        # Build the trunk net
        self.layers_trunk = MLP(layers_trunk, activations_trunk)
        # Add a bias term for the final combination
        self.bias = nn.Parameter(torch.tensor(1., requires_grad=True))

    def forward(self, f, x):
        branch_output = self.layers_branch(f)
        trunk_output = self.layers_trunk(x)
        # Combine branch and trunk features with a dot product plus bias
        output = torch.sum(branch_output * trunk_output, dim=-1, keepdim=True) + self.bias
        output = x * (1 - x) * output  # hard-coded Dirichlet boundary conditions
        return output
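Before training, a quick smoke test (the small architecture and the names model_try, f_try, x_try are only illustrative) checks that the forward pass returns one value of $u$ per grid point:
# untrained DeepONet applied to one of the sampled trajectories over the grid
f_try = torch.tensor(y_samples[0], dtype=torch.float32)       # shape (N+2,)
x_try = torch.tensor(x, dtype=torch.float32).reshape(-1, 1)   # shape (N+2, 1)
model_try = DeepONet([N+2, 32, 32], [nn.Tanh()], [1, 32, 32], [nn.Tanh()])
print(model_try(f_try, x_try).shape)  # expected: torch.Size([N+2, 1])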
Question: Using the previously generated dataset, train a DeepONet.
def train_DON(model, optimizer, loss_fct, f_data, x_data, u_data, epochs=1000):
    for epoch in range(epochs):
        optimizer.zero_grad()
        # data part
        u = model(f_data, x_data)
        loss = loss_fct(u, u_data)
        loss.backward()
        optimizer.step()
        if epoch % 1000 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")
    print(f"Epoch {epoch}, Loss: {loss.item()}")
# loss function
loss_fct = nn.MSELoss(reduction='sum')
# deeponet architecture
layers_branch = [N+2, 32, 32, 32]
activations_branch = [nn.Tanh(), nn.Tanh()]
layers_trunk = [1, 32, 32, 32]
activations_trunk = [nn.Tanh(), nn.Tanh()]
# define the deeponet
model_don = DeepONet(layers_branch, activations_branch, layers_trunk, activations_trunk)
# print the model
print(model_don)
DeepONet(
  (layers_branch): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=102, out_features=32, bias=True)
      (activation_0): Tanh()
      (linear_1): Linear(in_features=32, out_features=32, bias=True)
      (activation_1): Tanh()
      (linear_2): Linear(in_features=32, out_features=32, bias=True)
    )
  )
  (layers_trunk): MLP(
    (layers): Sequential(
      (linear_0): Linear(in_features=1, out_features=32, bias=True)
      (activation_0): Tanh()
      (linear_1): Linear(in_features=32, out_features=32, bias=True)
      (activation_1): Tanh()
      (linear_2): Linear(in_features=32, out_features=32, bias=True)
    )
  )
)
# learning parameters
learning_rate = 0.001
epochs = 20000
# optimization method
optimizer_don = torch.optim.Adam(model_don.parameters(), lr=learning_rate)
# train the model
train_DON(model_don, optimizer_don, loss_fct, f_data, x_data, u_data, epochs)
Epoch 0, Loss: 229.49928283691406
Epoch 1000, Loss: 0.05638007074594498
Epoch 2000, Loss: 0.017969917505979538
Epoch 3000, Loss: 0.016588397324085236
Epoch 4000, Loss: 0.014387266710400581
Epoch 5000, Loss: 0.005724779795855284
Epoch 6000, Loss: 0.004946759901940823
Epoch 7000, Loss: 0.0039475602097809315
Epoch 8000, Loss: 0.002361342776566744
Epoch 9000, Loss: 0.0015101548051461577
Epoch 10000, Loss: 0.0020293891429901123
Epoch 11000, Loss: 0.0008600761648267508
Epoch 12000, Loss: 0.0005873643094673753
Epoch 13000, Loss: 0.0006012881640344858
Epoch 14000, Loss: 0.005301603116095066
Epoch 15000, Loss: 0.005529624875634909
Epoch 16000, Loss: 0.001972277881577611
Epoch 17000, Loss: 0.008807550184428692
Epoch 18000, Loss: 0.000430085085099563
Epoch 19000, Loss: 0.000499415909871459
Epoch 19999, Loss: 0.002584925852715969
IV) Generalization
Question: Graphically illustrate the generalization capabilities of the trained DeepONet.
To illustrate the generalization capabilities graphically, we generate new functions $f$ and compare the $u$ obtained with the numerical solver to the one predicted by the DeepONet we just trained.
We will draw n_test curves on the same figure.
# number of curves on the graphic
n_test = 3
# grid as torch.tensor
x_test = torch.tensor(x, dtype=torch.float32).reshape(-1, 1)
# new set of input functions
f_test = np.random.multivariate_normal(f_mean, f_cov, n_test)
f_test_torch = torch.tensor(f_test, dtype=torch.float32)
# compute the u's using the solver
u_test = np.zeros((n_test, N+2))
for j in range(n_test):
    u_test[j,:] = solve_poisson(f_test[j,1:-1], x)

# compute the predictions and make the plot
for j in range(n_test):
    u_pred = model_don(f_test_torch[j,:], x_test).detach()
    plt.plot(x, u_test[j,:], 'o', label='test')
    plt.plot(x, u_pred, label='prediction', lw=3)
plt.legend()
plt.show()
Question: From a set of test functions $f_1,\dots, f_{n_{test}}$ drawn at random as above, write a function that computes
$$ err_{test} = \frac{1}{n_{test}} \sum_{j=1}^{n_{test}} \sup_{x\in[0,1]} |u_j(x)-u^\ast_j(x)|, $$ that is, the mean uniform-norm error between the solutions computed with the numerical solver and those predicted by the DeepONet.
def test_error(model, n_test, x):
    # generate the new test inputs f_j
    f_test = np.random.multivariate_normal(f_mean, f_cov, n_test)
    # conversion to torch tensor
    f_test_torch = torch.tensor(f_test, dtype=torch.float32)
    x_test = torch.tensor(x, dtype=torch.float32).reshape(-1, 1)
    # initialisation of the mean err
    err = 0.
    for j in range(n_test):
        # solution provided by the numerical solver
        u_test = solve_poisson(f_test[j,1:-1], x)
        # prediction provided by the deeponet
        u_tmp = lambda x: model(f_test_torch[j,:], x)
        u_pred = np.array([u_tmp(x).item() for x in x_test])
        # update the error
        err += np.max(np.abs(u_pred - u_test))
    return err/n_test
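The point-by-point evaluation above is simple but slow. As a possible variant (a sketch reusing f_mean, f_cov and solve_poisson from the previous cells), each prediction can instead be obtained in a single forward pass over the whole grid, as in the plots of the previous section; the name test_error_vectorized is ours.
def test_error_vectorized(model, n_test, x):
    # same test protocol as test_error, but each prediction is obtained
    # in one forward pass over all grid points at once
    f_test = np.random.multivariate_normal(f_mean, f_cov, n_test)
    f_test_torch = torch.tensor(f_test, dtype=torch.float32)
    x_test = torch.tensor(x, dtype=torch.float32).reshape(-1, 1)
    err = 0.
    with torch.no_grad():
        for j in range(n_test):
            u_test = solve_poisson(f_test[j, 1:-1], x)
            u_pred = model(f_test_torch[j, :], x_test).squeeze(-1).numpy()
            err += np.max(np.abs(u_pred - u_test))
    return err / n_test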
# number of test samples
n_test = 300
# estimation of the mean error
err_don = test_error(model_don, n_test, x)
err_don
0.0023194578910540053