{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# IPython autoreload in mode 2: modules are re-imported before each cell runs,\n",
                "# so edits to imported project code are picked up without restarting the kernel\n",
                "%load_ext autoreload\n",
                "%autoreload 2"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import jax\n",
                "import jax.numpy as jnp\n",
                "import matplotlib.pyplot as plt\n",
                "from jax import jit, vmap, grad, jacfwd\n",
                "import jax.flatten_util\n",
                "from time import time\n",
                "import pandas as pd\n",
                "import numpy as np"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Run configuration: exactly one of the two setting groups below should be active.\n",
                "\n",
                "# High Speed Settings\n",
                "n_x = 1000  # number of sample points in space\n",
                "sub_sample = 150  # number of parameters to randomly sample\n",
                "dt = 5e-3  # time step for rk4 integrator\n",
                "\n",
                "\n",
                "# High Accuracy Settings\n",
                "# n_x = 10_000 # number of sample points in space\n",
                "# sub_sample = 800 # number of parameters to randomly sample\n",
                "# dt = 5e-3 # time step for rk4 integrator"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# set up time and space domain\n",
                "Tend = 4.0\n",
                "# round() rather than int(): int() truncates, so floating-point error in Tend/dt\n",
                "# (e.g. 0.3/0.1 == 2.9999999999999996) could silently drop the last time step\n",
                "n_steps = int(round(Tend / dt))\n",
                "t_eval = jnp.linspace(0.0, Tend, n_steps + 1)\n",
                "\n",
                "dim = 1  # input dimension handed to init_net\n",
                "A, B = 0, 2*jnp.pi  # end points of the spatial interval\n",
                "# column of n_x equally spaced points on [A, B], shape (n_x, 1)\n",
                "x_eval = jnp.expand_dims(jnp.linspace(A, B, n_x), axis=-1)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from rsng.dnn import build_nn, init_net\n",
                "\n",
                "# network hyperparameters handed to build_nn\n",
                "width, depth = 25, 7\n",
                "period = 2*jnp.pi\n",
                "\n",
                "# PRNG key used to initialise the network (and reused later by the integrator)\n",
                "key = jax.random.PRNGKey(1)\n",
                "\n",
                "net = build_nn(width, depth, period)\n",
                "# presumably: scalar network function, initial parameters, unravel helper -- see rsng.dnn\n",
                "u_scalar, theta_init, unravel = init_net(net, key, dim)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def gradsqz(f, *args, **kwargs):\n",
                "    \"\"\"Return a function that evaluates ``grad(f, *args, **kwargs)`` and squeezes the result.\"\"\"\n",
                "    def squeezed_grad(*fargs, **fkwargs):\n",
                "        df = grad(f, *args, **kwargs)\n",
                "        return jnp.squeeze(df(*fargs, **fkwargs))\n",
                "    return squeezed_grad\n",
                "\n",
                "\n",
                "# in_axes=(None, 0): theta is shared, the second argument is mapped over its leading axis\n",
                "\n",
                "# network evaluated on a batch of x points\n",
                "U = vmap(u_scalar, in_axes=(None, 0))\n",
                "\n",
                "# derivative with respect to theta (first argument) at each x point\n",
                "U_dtheta = vmap(grad(u_scalar), in_axes=(None, 0))\n",
                "\n",
                "# second derivative with respect to the second argument (x) at each x point\n",
                "U_ddx = vmap(\n",
                "    gradsqz(gradsqz(u_scalar, 1), 1),\n",
                "    in_axes=(None, 0),\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# load the parameters which fit the initial condition\n",
                "# NOTE(review): unpickling can execute arbitrary code -- only load trusted files\n",
                "theta_0_tree = pd.read_pickle('./rsng/data/theta_init_ac.pkl')\n",
                "# flatten the loaded parameter pytree into one 1-D vector\n",
                "theta_0 = jax.flatten_util.ravel_pytree(theta_0_tree)[0]\n",
                "\n",
                "# plot the initial condition\n",
                "fig, ax = plt.subplots()\n",
                "ax.plot(x_eval, U(theta_0, x_eval))\n",
                "ax.set_xlabel('x')\n",
                "ax.set_ylabel('u')\n",
                "ax.set_title('initial condition')\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def rhs(t, theta):\n",
                "    \"\"\"f(theta): Allen-Cahn dynamics evaluated on the grid ``x_eval``.\n",
                "\n",
                "    Returns 5e-3 * u_xx + u - u**3, with u = U(theta, x_eval) and\n",
                "    u_xx = U_ddx(theta, x_eval). The argument ``t`` is not used.\n",
                "    \"\"\"\n",
                "    diffusion = 5e-3\n",
                "    u_vals = U(theta, x_eval)\n",
                "    lap_vals = U_ddx(theta, x_eval)\n",
                "    return diffusion*lap_vals + u_vals - u_vals**3\n",
                "\n",
                "\n",
                "def rhs_reparamaterized(t, theta, key):\n",
                "    \"\"\"\n",
                "    Compute theta_dot from a randomly subsampled least-squares projection.\n",
                "\n",
                "    1. evaluate f(theta) and the Jacobian J(theta) of the network with\n",
                "       respect to its parameters on ``x_eval``\n",
                "    2. keep ``sub_sample`` columns of J, drawn without replacement with ``key``\n",
                "    3. solve the least-squares problem for the kept columns\n",
                "    4. scatter the solution into a zero vector of length len(theta)\n",
                "    \"\"\"\n",
                "    n_params = len(theta)\n",
                "    target = rhs(t, theta)\n",
                "    jac_full = U_dtheta(theta, x_eval)\n",
                "\n",
                "    # random column indices S_t and the matching Jacobian columns\n",
                "    cols = jax.random.choice(key, n_params, shape=(sub_sample,), replace=False)\n",
                "    jac_sub = jnp.take(jac_full, cols, axis=1)\n",
                "\n",
                "    # least-squares solve restricted to the sampled parameters\n",
                "    coeffs, *_ = jnp.linalg.lstsq(jac_sub, target, rcond=1e-4)\n",
                "\n",
                "    # back to the full parameter space; unsampled entries stay zero\n",
                "    return jnp.zeros(n_params).at[cols].set(coeffs)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def odeint_rk4(fn, y0, t, key):\n",
                "    \"\"\"Integrate dy/dt = fn(t, y, key) with classical RK4 over the time grid ``t``.\n",
                "\n",
                "    Adapted from: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/jax_cosmo/scipy/ode.py\n",
                "\n",
                "    One subkey is split off per step and shared by all four stages of that\n",
                "    step. Returns the stacked states, one per entry of ``t``. The scan starts\n",
                "    at t[0], so the first step has zero length.\n",
                "    \"\"\"\n",
                "    def step(state, t_next):\n",
                "        y_cur, t_cur, rng = state\n",
                "        rng, stage_key = jax.random.split(rng)\n",
                "        h = t_next - t_cur\n",
                "\n",
                "        k1 = fn(t_cur, y_cur, stage_key)\n",
                "        k2 = fn(t_cur + h / 2, y_cur + h * k1 / 2, stage_key)\n",
                "        k3 = fn(t_cur + h / 2, y_cur + h * k2 / 2, stage_key)\n",
                "        k4 = fn(t_next, y_cur + h * k3, stage_key)\n",
                "\n",
                "        y_next = y_cur + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)\n",
                "        return (y_next, t_next, rng), y_next\n",
                "\n",
                "    init = (y0, jnp.array(t[0]), key)\n",
                "    _, ys = jax.lax.scan(step, init, t)\n",
                "    return ys"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def integrate(y0, t):\n",
                "    \"\"\"Integrate the subsampled dynamics from ``y0`` over ``t``, seeded by the global ``key``.\"\"\"\n",
                "    return odeint_rk4(rhs_reparamaterized, y0, t, key)\n",
                "\n",
                "# here we separate compile time from integration time\n",
                "integrate_complied = jit(integrate).lower(theta_0, t_eval).compile()\n",
                "print('jit complied!')\n",
                "time_start = time()\n",
                "y = integrate_complied(theta_0, t_eval)\n",
                "# JAX dispatches work asynchronously: block until the result is ready so the\n",
                "# timer measures the integration itself, not just the dispatch of the call\n",
                "y.block_until_ready()\n",
                "time_end = time()\n",
                "print('done!')"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# evaluate the network on x_eval at every saved time step\n",
                "# NOTE: despite its name, theta_dot holds the solution values U(theta_i, x_eval),\n",
                "# one row per time step; the name is kept because the error cell reads it\n",
                "steps = len(t_eval)\n",
                "theta_dot = np.zeros((steps, len(x_eval)))\n",
                "for i in range(steps):\n",
                "    theta = y[i, :]\n",
                "    theta_dot[i] = jnp.squeeze(U(theta, x_eval))\n",
                "\n",
                "# rows are time steps, columns are spatial sample points\n",
                "fig, ax = plt.subplots()\n",
                "ax.imshow(theta_dot, aspect='auto')\n",
                "ax.set_xlabel('x index')\n",
                "ax.set_ylabel('time step')\n",
                "ax.set_title('sol')\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# import the submodule explicitly: a bare `import scipy` does not guarantee\n",
                "# that scipy.io is loaded on every SciPy version\n",
                "import scipy.io\n",
                "from scipy.interpolate import RegularGridInterpolator\n",
                "\n",
                "# evaluate the error against the true solution\n",
                "\n",
                "data = scipy.io.loadmat('./rsng/data/gt_ac_small.mat')\n",
                "t_true = np.float32(data['t'][0])\n",
                "x_true = np.float32(data['x'][0])\n",
                "usol = np.float32(data['Uvals'])\n",
                "\n",
                "# linear interpolant of the reference solution on its (t, x) grid;\n",
                "# bounds_error=True raises for query points outside that grid\n",
                "gt_f = RegularGridInterpolator(\n",
                "    (t_true, x_true), usol, method='linear', bounds_error=True)\n",
                "\n",
                "# every (t, x) pair of the evaluation grid, time-major ('ij' indexing)\n",
                "m_grids = np.meshgrid(t_eval, x_eval, indexing='ij')\n",
                "m_grids = [m.flatten() for m in m_grids]\n",
                "t_grid = np.array(m_grids, dtype=np.float32).T\n",
                "# reference values reshaped to (time steps, space points)\n",
                "true = gt_f(t_grid).reshape(len(t_eval), len(x_eval))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# relative l2 error of the network solution against the interpolated reference\n",
                "err_norm = np.linalg.norm(true - theta_dot)\n",
                "ref_norm = np.linalg.norm(true)\n",
                "rl = err_norm / ref_norm\n",
                "print(f'relative l2 error: {rl:.2e}')\n",
                "# wall-clock seconds of the compiled integration call\n",
                "print(f'Time: {time_end-time_start}')"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3 (ipykernel)",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.10.11"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
