! pip install -q optax equinox jax matplotlib


import jax.numpy as jnp
import equinox as eqx
import optax
import time

from jax import random, vmap, grad
from jax.lax import scan

import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'


class PiNN(eqx.Module):
    """Fully connected network with sine activations.

    The network maps a 1-D input vector to a scalar: every layer is an affine
    map, a sine is applied between consecutive layers (not after the last
    one), and the first component of the final layer's output is returned.
    """
    # Weight matrices, one per layer, each of shape (fan_in, fan_out).
    matrices: list
    # Bias vectors, one per layer, each of shape (fan_out,); initialised to zero.
    biases: list

    def __init__(self, N_features, N_layers, key, s0=10):
        """Build the layers and draw the initial weights.

        Args:
            N_features: sequence whose first entry is the input width, whose
                second entry is the hidden width and whose last entry is the
                output width.
            N_layers: number of affine layers (so ``N_layers - 1`` hidden
                layers of width ``N_features[1]``).
            key: JAX PRNG key used to draw the weight matrices.
            s0: factor by which the first-layer weight matrix is multiplied
                after initialisation.
        """
        # Layer widths: input, (N_layers - 1) hidden layers, output.
        features = [N_features[0]] + [N_features[1]] * (N_layers - 1) + [N_features[-1]]
        layer_shapes = list(zip(features[:-1], features[1:]))

        # N_layers + 1 subkeys are requested although only N_layers are
        # consumed; the count is kept as-is so the drawn weights do not change.
        subkeys = random.split(key, N_layers + 1)

        # Each matrix is uniform in [-1, 1) scaled by sqrt(6 / fan_in).
        matrices = [
            random.uniform(subkey, (f_in, f_out), minval=-1, maxval=1) * jnp.sqrt(6 / f_in)
            for (f_in, f_out), subkey in zip(layer_shapes, subkeys)
        ]
        matrices[0] = matrices[0] * s0

        self.matrices = matrices
        # Biases start at zero, so they need no random key.
        self.biases = [jnp.zeros((f_out,)) for _, f_out in layer_shapes]

    def __call__(self, x):
        """Evaluate the network at a single point ``x`` and return a scalar."""
        f = x @ self.matrices[0] + self.biases[0]
        for weight, bias in zip(self.matrices[1:], self.biases[1:]):
            f = jnp.sin(f) @ weight + bias
        # Only the first output feature is used, which makes the result a scalar.
        return f[0]

def residual_loss(model, x_in, x_bc, bc):
    """Mean squared Laplacian of ``model`` on ``x_in`` plus mean squared boundary misfit.

    Args:
        model: callable mapping a single point (1-D array) to a scalar.
        x_in: array of shape (n_points, dim) where the Laplacian is evaluated.
        x_bc: array of points where ``model`` is compared against ``bc``.
        bc: target values at ``x_bc``.

    Returns:
        ``mean(laplacian**2) + mean((model(x_bc) - bc)**2)``.
    """
    def second_derivative(axis):
        # Unmixed second derivative d^2 model / dx_axis^2 at one point.
        def first_derivative(point):
            return grad(model)(point)[axis]

        return lambda point: grad(first_derivative)(point)[axis]

    # Sum of the unmixed second derivatives over all coordinates.
    laplacian = sum(
        vmap(second_derivative(axis))(x_in) for axis in range(x_in.shape[1])
    )
    interior_term = jnp.mean(laplacian**2)
    boundary_term = jnp.mean((vmap(model)(x_bc) - bc)**2)
    return interior_term + boundary_term

# Wrapped version of residual_loss: a call returns (loss, grads), with the
# gradients taken with respect to the first argument (the model).
compute_loss_and_grads = eqx.filter_value_and_grad(residual_loss)

def make_step_scan(carry, ind, optim):
    """Perform one optimisation step; shaped to be used as a ``scan`` body.

    Args:
        carry: list ``[model, coords_in, coords_bc, bc, opt_state]``.
        ind: index array; ``ind[0]`` selects the interior mini-batch and
            ``ind[1]`` selects the boundary mini-batch (points and values).
        optim: optimiser object providing ``update``.

    Returns:
        The updated carry (same layout) and the loss of this step.
    """
    model, coords_in, coords_bc, bc, opt_state = carry
    interior_idx, boundary_idx = ind[0], ind[1]

    loss, grads = compute_loss_and_grads(
        model,
        coords_in[interior_idx],
        coords_bc[boundary_idx],
        bc[boundary_idx],
    )

    # The optimiser is given the array leaves of the model as current parameters.
    params = eqx.filter(model, eqx.is_array)
    updates, next_opt_state = optim.update(grads, opt_state, params)
    next_model = eqx.apply_updates(model, updates)

    return [next_model, coords_in, coords_bc, bc, next_opt_state], loss

def solution(x):
    """Reference solution: Re sin(1 / p(z)) with z = x[0] + i*x[1].

    ``p`` is the quartic whose roots are listed below; all four roots lie
    outside the unit square [0, 1]^2.
    """
    roots = (1.2 + 0.5j, -0.2 + 0.5j, 0.5 - 0.2j, 0.5 + 1.2j)
    z = x[0] + 1j*x[1]
    # Multiply the linear factors left to right, in the listed order.
    p = z - roots[0]
    for root in roots[1:]:
        p = p * (z - root)
    # NOTE: sin(1/p) is used here instead of 1/p as per reviewer's version.
    return jnp.real(jnp.sin(1/p))


# Preview of the point layout: an N_i x N_i grid strictly inside the unit
# square (red) and N_b equispaced points on each of its four sides (navy).
N_i = 64
x = jnp.linspace(0, 1, N_i+2)[1:-1]  # drop both end points -> interior nodes only
coord = jnp.stack(jnp.meshgrid(x, x), axis=2).reshape(-1, 2)

N_b = coord.shape[0] // 4  # points per side
x = jnp.linspace(0, 1, N_b)
bc_coords = jnp.concatenate(
    [
        jnp.stack(side, axis=1)
        for side in (
            (jnp.zeros((N_b,)), x),  # first coordinate = 0
            (x, jnp.zeros((N_b,))),  # second coordinate = 0
            (x, jnp.ones((N_b,))),   # second coordinate = 1
            (jnp.ones((N_b,)), x),   # first coordinate = 1
        )
    ],
    axis=0,
)

plt.plot(coord[..., 0], coord[..., 1], ".", color="red")
plt.plot(bc_coords[..., 0], bc_coords[..., 1], ".", color="navy");


# Experiment: train the PiNN on an N_i x N_i interior grid plus boundary
# points, then measure its accuracy against `solution` on a 128 x 128 grid.
N_i = 64
x = jnp.linspace(0, 1, N_i+2)[1:-1]  # interior nodes only (end points dropped)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

N_b = coord.shape[0] // 4  # points per side of the unit square
x = jnp.linspace(0, 1, N_b)
bc_coords = jnp.concatenate([
    jnp.stack([jnp.zeros((N_b,)), x], axis=1),
    jnp.stack([x, jnp.zeros((N_b,))], axis=1),
    jnp.stack([x, jnp.ones((N_b,))], axis=1),
    jnp.stack([jnp.ones((N_b,)), x], axis=1),
], axis=0)
bc_vals = vmap(solution)(bc_coords)

print(f"interior points {coord.shape[0]}")
print(f"boundary points {bc_coords.shape[0]}")

N_run = 200000      # number of optimisation steps
N_batch = 16*16     # mini-batch size (interior and boundary alike)
N_drop = 20000      # learning rate decays by a factor gamma_ per N_drop steps
learning_rate = 1e-3
gamma_= 0.5
key = random.PRNGKey(11)
# Use independent subkeys for weight initialisation and mini-batch sampling;
# a JAX key should not be consumed by two different random computations.
model_key, ind_key = random.split(key)

N_features = 32
N_layers = 4
model = PiNN([2, N_features, 1], N_layers, model_key)

# inds[:, 0] indexes interior points and inds[:, 1] indexes boundary points,
# so each row is bounded by the size of the array it indexes.
ind_bounds = jnp.array([coord.shape[0], bc_coords.shape[0]]).reshape(1, 2, 1)
inds = random.randint(ind_key, minval=0, maxval=ind_bounds, shape=(N_run, 2, N_batch))
sc = optax.exponential_decay(learning_rate, N_drop, gamma_)
optim = optax.lion(learning_rate=sc)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

carry = [model, coord, bc_coords, bc_vals, opt_state]

make_step_scan_ = lambda a, b: make_step_scan(a, b, optim)

# The measured time includes tracing/compilation of the scan.
start = time.perf_counter()
carry, loss = scan(make_step_scan_, carry, inds)
# JAX dispatches work asynchronously: wait for the scan output before
# stopping the clock, otherwise the training time can be under-reported.
loss.block_until_ready()
stop = time.perf_counter()
training_time = stop - start
model = carry[0]

plt.yscale("log")
plt.plot(loss)

# Evaluation grid (includes the boundary of the unit square).
N = 128
x = jnp.linspace(0, 1, N)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

sol_vals = vmap(solution)(coord)
prediction = vmap(model)(coord)
# Relative L2 error is normalised by the reference solution, not by the prediction.
error = jnp.linalg.norm(sol_vals - prediction) / jnp.linalg.norm(sol_vals)
# NOTE: mean of squared differences, instead of
# jnp.sqrt(jnp.mean(sol_vals - prediction)**2) as per reviewer's version.
rmse = jnp.sqrt(jnp.mean((sol_vals - prediction)**2))
print(f"training time {training_time}")
print(f"relative error {error}")
print(f"rmse {rmse}")


# Experiment: train the PiNN on an N_i x N_i interior grid plus boundary
# points, then measure its accuracy against `solution` on a 128 x 128 grid.
N_i = 32
x = jnp.linspace(0, 1, N_i+2)[1:-1]  # interior nodes only (end points dropped)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

N_b = coord.shape[0] // 4  # points per side of the unit square
x = jnp.linspace(0, 1, N_b)
bc_coords = jnp.concatenate([
    jnp.stack([jnp.zeros((N_b,)), x], axis=1),
    jnp.stack([x, jnp.zeros((N_b,))], axis=1),
    jnp.stack([x, jnp.ones((N_b,))], axis=1),
    jnp.stack([jnp.ones((N_b,)), x], axis=1),
], axis=0)
bc_vals = vmap(solution)(bc_coords)

print(f"interior points {coord.shape[0]}")
print(f"boundary points {bc_coords.shape[0]}")

N_run = 200000      # number of optimisation steps
N_batch = 16*16     # mini-batch size (interior and boundary alike)
N_drop = 20000      # learning rate decays by a factor gamma_ per N_drop steps
learning_rate = 1e-3
gamma_= 0.5
key = random.PRNGKey(11)
# Use independent subkeys for weight initialisation and mini-batch sampling;
# a JAX key should not be consumed by two different random computations.
model_key, ind_key = random.split(key)

N_features = 32
N_layers = 4
model = PiNN([2, N_features, 1], N_layers, model_key)

# inds[:, 0] indexes interior points and inds[:, 1] indexes boundary points,
# so each row is bounded by the size of the array it indexes.
ind_bounds = jnp.array([coord.shape[0], bc_coords.shape[0]]).reshape(1, 2, 1)
inds = random.randint(ind_key, minval=0, maxval=ind_bounds, shape=(N_run, 2, N_batch))
sc = optax.exponential_decay(learning_rate, N_drop, gamma_)
optim = optax.lion(learning_rate=sc)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

carry = [model, coord, bc_coords, bc_vals, opt_state]

make_step_scan_ = lambda a, b: make_step_scan(a, b, optim)

# The measured time includes tracing/compilation of the scan.
start = time.perf_counter()
carry, loss = scan(make_step_scan_, carry, inds)
# JAX dispatches work asynchronously: wait for the scan output before
# stopping the clock, otherwise the training time can be under-reported.
loss.block_until_ready()
stop = time.perf_counter()
training_time = stop - start
model = carry[0]

plt.yscale("log")
plt.plot(loss);

# Evaluation grid (includes the boundary of the unit square).
N = 128
x = jnp.linspace(0, 1, N)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

sol_vals = vmap(solution)(coord)
prediction = vmap(model)(coord)
# Relative L2 error is normalised by the reference solution, not by the prediction.
error = jnp.linalg.norm(sol_vals - prediction) / jnp.linalg.norm(sol_vals)
# NOTE: mean of squared differences, instead of
# jnp.sqrt(jnp.mean(sol_vals - prediction)**2) as per reviewer's version.
rmse = jnp.sqrt(jnp.mean((sol_vals - prediction)**2))
print(f"training time {training_time}")
print(f"relative error {error}")
print(f"rmse {rmse}")


# Experiment: train the PiNN on an N_i x N_i interior grid plus boundary
# points, then measure its accuracy against `solution` on a 128 x 128 grid.
N_i = 16
x = jnp.linspace(0, 1, N_i+2)[1:-1]  # interior nodes only (end points dropped)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

N_b = coord.shape[0] // 4  # points per side of the unit square
x = jnp.linspace(0, 1, N_b)
bc_coords = jnp.concatenate([
    jnp.stack([jnp.zeros((N_b,)), x], axis=1),
    jnp.stack([x, jnp.zeros((N_b,))], axis=1),
    jnp.stack([x, jnp.ones((N_b,))], axis=1),
    jnp.stack([jnp.ones((N_b,)), x], axis=1),
], axis=0)
bc_vals = vmap(solution)(bc_coords)

print(f"interior points {coord.shape[0]}")
print(f"boundary points {bc_coords.shape[0]}")

N_run = 200000      # number of optimisation steps
N_batch = 16*16     # mini-batch size (interior and boundary alike)
N_drop = 20000      # learning rate decays by a factor gamma_ per N_drop steps
learning_rate = 1e-3
gamma_= 0.5
key = random.PRNGKey(11)
# Use independent subkeys for weight initialisation and mini-batch sampling;
# a JAX key should not be consumed by two different random computations.
model_key, ind_key = random.split(key)

N_features = 32
N_layers = 4
model = PiNN([2, N_features, 1], N_layers, model_key)

# inds[:, 0] indexes interior points and inds[:, 1] indexes boundary points,
# so each row is bounded by the size of the array it indexes.
ind_bounds = jnp.array([coord.shape[0], bc_coords.shape[0]]).reshape(1, 2, 1)
inds = random.randint(ind_key, minval=0, maxval=ind_bounds, shape=(N_run, 2, N_batch))
sc = optax.exponential_decay(learning_rate, N_drop, gamma_)
optim = optax.lion(learning_rate=sc)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

carry = [model, coord, bc_coords, bc_vals, opt_state]

make_step_scan_ = lambda a, b: make_step_scan(a, b, optim)

# The measured time includes tracing/compilation of the scan.
start = time.perf_counter()
carry, loss = scan(make_step_scan_, carry, inds)
# JAX dispatches work asynchronously: wait for the scan output before
# stopping the clock, otherwise the training time can be under-reported.
loss.block_until_ready()
stop = time.perf_counter()
training_time = stop - start
model = carry[0]

plt.yscale("log")
plt.plot(loss)

# Evaluation grid (includes the boundary of the unit square).
N = 128
x = jnp.linspace(0, 1, N)
coord = jnp.stack(jnp.meshgrid(x, x), 2).reshape(-1, 2)

sol_vals = vmap(solution)(coord)
prediction = vmap(model)(coord)
# Relative L2 error is normalised by the reference solution, not by the prediction.
error = jnp.linalg.norm(sol_vals - prediction) / jnp.linalg.norm(sol_vals)
# NOTE: mean of squared differences, instead of
# jnp.sqrt(jnp.mean(sol_vals - prediction)**2) as per reviewer's version.
rmse = jnp.sqrt(jnp.mean((sol_vals - prediction)**2))
print(f"training time {training_time}")
print(f"relative error {error}")
print(f"rmse {rmse}")



