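"""
Solve a 2-D reaction-diffusion system of lambda-omega type on a square
domain with homogeneous Neumann boundary conditions (method of lines + RK45).

The solver is adapted and simplified from
    https://github.com/dfloryan/neural-manifold-dynamics/tree/main/reactDifCode
The solution setup is based on
    https://arxiv.org/pdf/2108.05928
"""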

import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp
from tqdm import tqdm

def apply_bc(u):
    """Impose homogeneous Neumann conditions on the edges of a 2-D field.

    Every edge value is overwritten with the second-order one-sided
    extrapolation ``u_edge = (4*u_1 - u_2) / 3`` from the two nearest nodes,
    which makes the one-sided first derivative at the edge vanish.
    Rows 0 and -1 are set first (interior columns only); columns 0 and -1
    are then set for every row, so the corners are extrapolated from the
    freshly updated edge rows.

    Parameters
    ----------
    u : ndarray
        2-D field indexed as ``u[row, col]``; modified in place.

    Returns
    -------
    ndarray
        The same array object ``u``.
    """
    near, far = 4 / 3, 1 / 3
    inner = slice(1, -1)

    # First/last row, interior columns only.
    u[0, inner] = near * u[1, inner] - far * u[2, inner]
    u[-1, inner] = near * u[-2, inner] - far * u[-3, inner]

    # First/last column over all rows (corners included).
    u[:, 0] = near * u[:, 1] - far * u[:, 2]
    u[:, -1] = near * u[:, -2] - far * u[:, -3]
    return u

def react_dif_neumann(t, z, beta, d1, d2, h, n):
    """Right-hand side of the semi-discretised reaction-diffusion system.

    With ``r2 = u**2 + v**2`` the interior nodes evolve as::

        du/dt = (1 - r2) * u + beta * r2 * v + d1 * lap(u)
        dv/dt = -beta * r2 * u + (1 - r2) * v + d2 * lap(v)

    where ``lap`` is the 5-point finite-difference Laplacian with uniform
    spacing ``h``. Boundary nodes get a zero time derivative; their values
    are reconstructed from the interior by ``apply_bc`` before the
    Laplacian is evaluated.

    Parameters
    ----------
    t : float
        Time. Unused (the system is autonomous) but required by the
        ``solve_ivp`` calling convention.
    z : array_like, shape (2*n*n,)
        State vector: flattened ``u`` followed by flattened ``v``, each an
        ``(n, n)`` grid in row-major order. Not modified.
    beta : float
        Coupling coefficient of the ``r2`` cross terms.
    d1, d2 : float
        Diffusion coefficients of ``u`` and ``v``.
    h : float
        Grid spacing, used for both array axes.
    n : int
        Number of grid points per direction.

    Returns
    -------
    ndarray, shape (2*n*n,)
        Time derivative of ``z`` in the same layout.
    """
    # Work on copies. Reshaping a slice of an ndarray yields a *view*, and
    # apply_bc writes in place, so without the copy this function would
    # silently overwrite the edge entries of the solver's own state vector.
    u = apply_bc(np.reshape(z[:n*n], (n, n)).copy())
    v = apply_bc(np.reshape(z[n*n:], (n, n)).copy())

    # Interior nodes: the only ones with a non-zero time derivative.
    uint = u[1:-1, 1:-1]
    vint = v[1:-1, 1:-1]

    # Second differences along axis 0 and axis 1; their sum is the Laplacian.
    uxx = (u[2:, 1:-1] - 2 * uint + u[:-2, 1:-1]) / h**2
    uyy = (u[1:-1, 2:] - 2 * uint + u[1:-1, :-2]) / h**2
    vxx = (v[2:, 1:-1] - 2 * vint + v[:-2, 1:-1]) / h**2
    vyy = (v[1:-1, 2:] - 2 * vint + v[1:-1, :-2]) / h**2

    tmp = uint**2 + vint**2
    dudt = np.zeros((n, n))
    dvdt = np.zeros((n, n))
    dudt[1:-1, 1:-1] = (1 - tmp) * uint + beta * tmp * vint + d1 * (uxx + uyy)
    dvdt[1:-1, 1:-1] = -beta * tmp * uint + (1 - tmp) * vint + d2 * (vxx + vyy)

    return np.concatenate([dudt.ravel(), dvdt.ravel()])

# Toggles for plotting and saving the results.
ifplt = True
ifsav = True

# Input parameters
beta = 1.0   # coupling coefficient of the r^2 cross terms
d1 = 0.1     # diffusion coefficient of u
d2 = 0.1     # diffusion coefficient of v
n = 101      # grid points per direction
L = 20.0     # side length of the square domain [-L/2, L/2]^2
T = 510      # length of the recorded time window
DT = 0.05    # sampling interval of the recorded solution
SKP = 10     # initial interval that is integrated but not recorded
t_span = (0, T+SKP)
# Build the sample times from an integer count instead of a float-step
# np.arange: with a non-integer step the arange length can come out one too
# long through rounding, and the extra sample would overshoot t_span.
Nt = int(round(T / DT))
t_eval = SKP + DT * np.arange(Nt)

# Construct spatial grid. meshgrid's default 'xy' indexing puts x along
# axis 1 and y along axis 0 of the 2-D arrays.
x = np.linspace(-L/2, L/2, n)
y = np.linspace(-L/2, L/2, n)
h = x[1] - x[0]
x, y = np.meshgrid(x, y)

# Set initial conditions: fields whose phase (theta - r) winds with radius.
r = np.sqrt(x**2 + y**2)
theta = np.angle(x + 1j * y)
u = np.tanh(r * np.cos(theta - r))
v = np.tanh(r * np.sin(theta - r))
z0 = np.concatenate([u.ravel(), v.ravel()])

# Solve the ODE system produced by the spatial discretisation.
sol = solve_ivp(
    react_dif_neumann, t_span, z0,
    t_eval=t_eval, method='RK45', args=(beta, d1, d2, h, n))
if not sol.success:
    # Otherwise sol.y would hold fewer than Nt snapshots and everything below
    # would fail with confusing shape/index errors.
    raise RuntimeError(f"solve_ivp failed: {sol.message}")

# Extract results. These are *views* into sol.y: row i is the flattened
# snapshot at t_eval[i], and writing through them also updates sol.y (which
# is what gets saved below).
u = sol.y[:n*n, :].T
v = sol.y[n*n:, :].T

# The RHS gives boundary nodes a zero time derivative, so their stored values
# are not evolved; re-impose the homogeneous Neumann conditions on every
# snapshot. The explicit write-back keeps this correct even when reshape has
# to copy the strided row.
for i in tqdm(range(Nt)):
    u[i, :] = apply_bc(u[i].reshape(n, n)).ravel()
    v[i, :] = apply_bc(v[i].reshape(n, n)).ravel()

if ifplt:
    # Four snapshots spaced dN samples apart.
    # NOTE(review): the "- Nt" offset maps these negative indices to
    # 0, dN, 2*dN, 3*dN, i.e. the *first* recorded snapshots; confirm that
    # (rather than the last ones) is intended. The origin of 144 is not
    # documented here.
    dN = 144//4
    tdx = np.arange(4)*dN - Nt
    fig, ax = plt.subplots(nrows=4, ncols=2, sharex=True, sharey=True, figsize=(4,10))
    for _i, _t in enumerate(tdx):
        ax[_i,0].contourf(x, y, u[_t].reshape(n, n))
        ax[_i,1].contourf(x, y, v[_t].reshape(n, n))

    # Time history at a single grid node. The flat index 5000 is hard-coded
    # and requires n*n > 5000.
    plt.figure()
    plt.plot(t_eval, u[:,5000])
    plt.plot(t_eval, v[:,5000])

if ifsav:
    dat = {
        'x' : x,
        'y' : y,
        't' : t_eval-t_eval[0],
        # shape (Nt, 2*n*n): [u | v] per snapshot, with the BCs applied
        # above through the u/v views.
        'udata' : sol.y.T,
        'dt' : DT
    }
    # Use a context manager so the file is flushed and closed deterministically.
    with open('rddata.pkl', 'wb') as fh:
        pickle.dump(dat, fh)

plt.show()