{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d37e5428",
   "metadata": {},
   "source": [
    "In this Jupyter notebook, we will train a machine-learned finite-volume (FV) solver to solve the 1D advection equation at reduced resolution. Our objective is to study the difference in performance between a solver trained using an unrolled loss function and one trained using a standard time-derivative (stationary) loss function.\n",
    "\n",
    "We are also interested in investigating the difference in $\\frac{d\\boldsymbol{a}}{dt}$ between each model, to see what unrolling the loss function actually does to the time-derivative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59ad011",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup paths\n",
    "# The defaults point at the original author's checkout; set the ADVECTION_BASEDIR and/or\n",
    "# ADVECTION_READWRITEDIR environment variables to run the notebook elsewhere.\n",
    "import os\n",
    "import sys\n",
    "_default_dir = '/Users/nickm/thesis/icml2023paper/1d_advection'\n",
    "basedir = os.environ.get('ADVECTION_BASEDIR', _default_dir)\n",
    "readwritedir = os.environ.get('ADVECTION_READWRITEDIR', basedir)\n",
    "\n",
    "# make the project's sub-packages (core, simulate, ml) importable\n",
    "for _subdir in ('core', 'simulate', 'ml'):\n",
    "    _path = os.path.join(basedir, _subdir)\n",
    "    if _path not in sys.path:  # keep the cell idempotent when it is re-run\n",
    "        sys.path.append(_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "316a93ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import external packages\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as onp\n",
    "from jax import config, vmap\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "import xarray\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae413a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import internal packages\n",
    "from flux import Flux\n",
    "from initialconditions import get_a0, get_initial_condition_fn, get_a\n",
    "from simparams import CoreParams, CoreParamsDG, SimulationParams\n",
    "from legendre import generate_legendre\n",
    "from simulations import AdvectionFVSim, AdvectionDGSim\n",
    "from trajectory import get_trajectory_fn, get_inner_fn\n",
    "from trainingutils import save_training_data, save_training_data_unroll\n",
    "from mlparams import TrainingParams, StencilParams, TrainingParamsUnroll\n",
    "from model import LearnedStencil\n",
    "from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, \n",
    "                           compute_losses_no_model, init_params, save_training_params, load_training_params, \n",
    "                          save_training_data_unroll, get_batch_fn_unroll, get_loss_fn_unroll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2932d0ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "\n",
    "def _num_coeffs(core_params):\n",
    "    \"\"\"Number of polynomial coefficients per cell: 1 for FV (order is None), order + 1 for DG.\"\"\"\n",
    "    if core_params.order is None:\n",
    "        return 1\n",
    "    return core_params.order + 1\n",
    "\n",
    "\n",
    "def _cell_sample_points(nx, Lx, p):\n",
    "    \"\"\"Return (xs, dx): plotting sample points of shape (NPLOT, nx) and the cell width dx.\n",
    "\n",
    "    NPLOT points are spaced uniformly across each cell [j*dx, (j+1)*dx]; higher-order\n",
    "    expansions get more points. Only p <= 4 is supported by the lookup table.\n",
    "    \"\"\"\n",
    "    n_plot = [2, 2, 5, 7][p - 1]\n",
    "    dx = Lx / nx\n",
    "    xjs = jnp.arange(nx) * Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, n_plot)[:, None]\n",
    "    return xs, dx\n",
    "\n",
    "\n",
    "def _eval_cells(a, xs, dx, p):\n",
    "    \"\"\"Evaluate per-cell Legendre coefficients `a` (shape (nx, p)) at the points `xs`.\n",
    "\n",
    "    `xs` has shape (NPLOT, nx) as returned by `_cell_sample_points`; the result has\n",
    "    shape (nx, NPLOT).\n",
    "    \"\"\"\n",
    "    nx = a.shape[0]\n",
    "    leg_poly = generate_legendre(p)\n",
    "    vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "\n",
    "    def evalf(x, a_j, j):\n",
    "        # map x onto the reference coordinate xi of cell j (xi = -1 at the left edge, +1 at the right)\n",
    "        x_j = dx * (0.5 + j)\n",
    "        xi = (x - x_j) / (0.5 * dx)\n",
    "        poly_eval = vmap_polyval(leg_poly, xi)  # (NPLOT, p) array\n",
    "        return jnp.sum(poly_eval * a_j, axis=-1)\n",
    "\n",
    "    return vmap(evalf, (1, 0, 0), 1)(xs, a, jnp.arange(nx)).T\n",
    "\n",
    "\n",
    "def plot_fv(a, core_params, color=\"blue\"):\n",
    "    \"\"\"Plot one FV state (cell averages, shape (nx,)) as a piecewise-constant curve.\"\"\"\n",
    "    plot_dg(a[..., None], core_params, color=color)\n",
    "\n",
    "\n",
    "def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    \"\"\"Plot an FV trajectory (shape (outer_steps, nx)), one panel per saved time.\"\"\"\n",
    "    plot_dg_trajectory(trajectory[..., None], core_params, t_inner, color=color)\n",
    "\n",
    "\n",
    "def plot_dg(a, core_params, color='blue'):\n",
    "    \"\"\"Plot one DG state given its Legendre coefficients `a` of shape (nx, p).\"\"\"\n",
    "    p = _num_coeffs(core_params)\n",
    "    nx = a.shape[0]\n",
    "    xs, dx = _cell_sample_points(nx, core_params.Lx, p)\n",
    "    a_plot = _eval_cells(a, xs, dx, p).reshape(-1)\n",
    "    data = xarray.DataArray(a_plot, coords={'x': xs.T.reshape(-1)})\n",
    "    data.plot(color=color)\n",
    "\n",
    "\n",
    "def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    \"\"\"Plot a DG trajectory (shape (outer_steps, nx, p)), one panel per saved time.\"\"\"\n",
    "    p = _num_coeffs(core_params)\n",
    "    outer_steps, nx = trajectory.shape[0], trajectory.shape[1]\n",
    "    xs, dx = _cell_sample_points(nx, core_params.Lx, p)\n",
    "    to_plot_repr = vmap(lambda a: _eval_cells(a, xs, dx, p))\n",
    "    trajectory_plot = to_plot_repr(trajectory).reshape(outer_steps, -1)\n",
    "    coords = {\n",
    "        'x': xs.T.reshape(-1),\n",
    "        'time': t_inner * jnp.arange(outer_steps)\n",
    "    }\n",
    "    xarray.DataArray(trajectory_plot, dims=[\"time\", \"x\"], coords=coords).plot(\n",
    "        col='time', col_wrap=5, color=color)\n",
    "\n",
    "\n",
    "def plot_multiple_fv_trajectories(trajectories, core_params, t_inner):\n",
    "    \"\"\"Overlay several FV trajectories (each shape (outer_steps, nx)) on shared per-time panels.\"\"\"\n",
    "    plot_multiple_dg_trajectories([trajectory[..., None] for trajectory in trajectories], core_params, t_inner)\n",
    "\n",
    "\n",
    "def plot_multiple_dg_trajectories(trajectories, core_params, t_inner):\n",
    "    \"\"\"Overlay several DG trajectories on shared per-time panels, one line (hue) per trajectory.\n",
    "\n",
    "    Every trajectory must have the same shape as the first one.\n",
    "    \"\"\"\n",
    "    outer_steps = trajectories[0].shape[0]\n",
    "    nx = trajectories[0].shape[1]\n",
    "    p = _num_coeffs(core_params)\n",
    "    xs, dx = _cell_sample_points(nx, core_params.Lx, p)\n",
    "    to_plot_repr = vmap(lambda a: _eval_cells(a, xs, dx, p))\n",
    "    trajectory_plots = [to_plot_repr(trajectory).reshape(outer_steps, -1) for trajectory in trajectories]\n",
    "    coords = {\n",
    "        'x': xs.T.reshape(-1),\n",
    "        'time': t_inner * jnp.arange(outer_steps)\n",
    "    }\n",
    "    xarray.DataArray(trajectory_plots, dims=[\"stack\", \"time\", \"x\"], coords=coords).plot.line(\n",
    "        col='time', hue=\"stack\", col_wrap=5)\n",
    "\n",
    "\n",
    "def get_core_params(order, flux='upwind'):\n",
    "    \"\"\"Return CoreParams (FV) for order 0 and CoreParamsDG otherwise, with domain length Lx = 1.0.\"\"\"\n",
    "    Lx = 1.0\n",
    "    if order == 0:\n",
    "        return CoreParams(Lx, flux)\n",
    "    return CoreParamsDG(Lx, flux, order)\n",
    "\n",
    "\n",
    "def get_sim_params(name = \"test\", cfl_safety=0.3, rk='ssp_rk3'):\n",
    "    \"\"\"Build SimulationParams using the notebook-level basedir and readwritedir.\"\"\"\n",
    "    return SimulationParams(name, basedir, readwritedir, cfl_safety, rk)\n",
    "\n",
    "\n",
    "def get_training_params(n_data, train_id=\"test\", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='sgd'):\n",
    "    \"\"\"Build TrainingParams (used for the stationary-loss models).\"\"\"\n",
    "    return TrainingParams(n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)\n",
    "\n",
    "\n",
    "def get_training_params_unroll(n_data, train_id=\"test\", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='sgd', n_unroll=2):\n",
    "    \"\"\"Build TrainingParamsUnroll (used for the unrolled-loss models); n_unroll is passed first.\"\"\"\n",
    "    return TrainingParamsUnroll(n_unroll, n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)\n",
    "\n",
    "\n",
    "def get_stencil_params(kernel_size = 3, kernel_out = 4, stencil_width=4, depth = 3, width = 16):\n",
    "    \"\"\"Build StencilParams describing the learned-stencil network.\"\"\"\n",
    "    return StencilParams(kernel_size, kernel_out, stencil_width, depth, width)\n",
    "\n",
    "\n",
    "def l2_norm_trajectory(trajectory):\n",
    "    \"\"\"Mean of squares over axis 1 for each saved step (a squared RMS, not the L2 norm itself).\n",
    "\n",
    "    For an FV trajectory of shape (outer_steps, nx) the result has shape (outer_steps,).\n",
    "    \"\"\"\n",
    "    return jnp.mean(trajectory**2, axis=1)\n",
    "\n",
    "\n",
    "def get_model(core_params, stencil_params):\n",
    "    \"\"\"Instantiate LearnedStencil; `features` holds depth - 1 copies of the layer width.\"\"\"\n",
    "    p = _num_coeffs(core_params)\n",
    "    features = [stencil_params.width] * (stencil_params.depth - 1)\n",
    "    return LearnedStencil(features, stencil_params.kernel_size, stencil_params.kernel_out, stencil_params.stencil_width, p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a764bc88",
   "metadata": {},
   "source": [
    "## Stationary Loss\n",
    "\n",
    "##### Training Loop\n",
    "\n",
    "First, we will generate the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dd99d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "init_description = 'sum_sin'\n",
    "kwargs_init = {'min_num_modes': 2, 'max_num_modes': 6, 'min_k': 0, 'max_k': 3, 'amplitude_max': 1.0}\n",
    "kwargs_sim = {'name' : \"advection_test2\", 'cfl_safety' : 0.1, 'rk' : 'ssp_rk3'}\n",
    "kwargs_train_FV = {'train_id': \"advection_test2\", 'batch_size' : 8, 'optimizer': 'adam', 'num_epochs' : 1}\n",
    "kwargs_stencil = {'kernel_size' : 3, 'kernel_out' : 4, 'stencil_width' : 4, 'depth' : 3, 'width' : 16}\n",
    "n_runs = 10\n",
    "t_inner_train = 0.01\n",
    "# outer steps per training run, so outer_steps_train * t_inner_train ~= 1.0\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "fv_flux_baseline = 'muscl' # learning a correction to the MUSCL scheme\n",
    "# reference data is simulated at nx_exact (presumably coarse-grained to each nx in nxs)\n",
    "nx_exact = 256\n",
    "nxs = [16, 32, 64, 128, 256]\n",
    "# NOTE(review): every learning rate is 0.0, so train_model presumably leaves the initial\n",
    "# parameters unchanged and the 'trained' models below are untrained. This matches the\n",
    "# 'average loss before training' message in the loss cell; set non-zero rates to train.\n",
    "learning_rate_list = [0.0e-3, 0.0e-3, 0.0e-3, 0.0e-3, 0.0e-3]\n",
    "assert len(nxs) == len(learning_rate_list)\n",
    "key = jax.random.PRNGKey(12)\n",
    "\n",
    "# setup\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "# presumably one training sample per (run, outer step) pair\n",
    "n_data = n_runs * outer_steps_train\n",
    "training_params_list = [get_training_params(n_data, **kwargs_train_FV, learning_rate = lr) for lr in learning_rate_list]\n",
    "stencil_params = get_stencil_params(**kwargs_stencil)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "init_fn = lambda key: get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "model = get_model(core_params, stencil_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89a824db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save training data\n",
    "# Uses key = PRNGKey(12) from the hyperparameter cell, presumably to draw the n_runs initial conditions via init_fn.\n",
    "save_training_data(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24d0c4f5",
   "metadata": {},
   "source": [
    "Next, we initialize the model parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84cb8f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared initial network parameters: every nx below starts training from this same i_params.\n",
    "# NOTE(review): this key is not split, so the training loop below reuses PRNGKey(42) as well.\n",
    "key = jax.random.PRNGKey(42)\n",
    "i_params = init_params(key, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbdd25e8",
   "metadata": {},
   "source": [
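    "Next, we run a training loop for each value of nx, starting every resolution from the same initial parameters `i_params` and using that resolution's entry of `learning_rate_list`. The learning rate undergoes a prespecified decay."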
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc39c297",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one stationary-loss model per resolution nx and save its parameters and losses.\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(nx)\n",
    "    training_params = training_params_list[i]\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn(model, core_params)\n",
    "    # NOTE(review): `key` is still PRNGKey(42), the same key used for init_params above\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65b7f887",
   "metadata": {},
   "source": [
    "Next, we load and plot the losses for each nx to check that the model trained properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff423c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the logged training losses for each nx (y axis clipped to [0, 20]).\n",
    "print(\"Equation is 1D Advection\")\n",
    "for i, nx in enumerate(nxs):\n",
    "    losses, _ = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    plt.plot(losses, label=nx)\n",
    "    # NOTE(review): 'before training' only holds while learning_rate_list is all zeros\n",
    "    print(\"nx is {}, average loss before training is {}\".format(nx, jnp.mean(losses)))\n",
    "plt.ylim([0,20])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db884533",
   "metadata": {},
   "source": [
    "Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a76a307e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Qualitative check: MUSCL baseline vs. learned model on one random initial condition, per nx.\n",
    "# pick a key that gives something nice\n",
    "key = jax.random.PRNGKey(18)\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(\"nx is {}\".format(nx))\n",
    "\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "\n",
    "    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "    a0 = get_a0(f_init, core_params, nx)\n",
    "    t_inner = 1.0\n",
    "    outer_steps = 10\n",
    "\n",
    "    # learned model (no global stabilization)\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    trajectory_model = trajectory_fn_model(a0)\n",
    "\n",
    "    # baseline solver without a learned correction (MUSCL flux)\n",
    "    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "    trajectory = trajectory_fn(a0)\n",
    "\n",
    "    # stack 0 = MUSCL baseline, stack 1 = learned model\n",
    "    plot_multiple_fv_trajectories([trajectory, trajectory_model], core_params, t_inner)\n",
    "    plt.show()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(l2_norm_trajectory(trajectory), label=\"MUSCL\")\n",
    "    ax.plot(l2_norm_trajectory(trajectory_model), label=\"ML\")\n",
    "    ax.set_xlabel(\"outer step\")\n",
    "    ax.set_ylabel(\"mean of $a^2$\")\n",
    "    ax.set_title(\"nx = {}\".format(nx))\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c40a6233",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stability check: count long rollouts (t_inner * outer_steps = 100) of the learned model that end in NaN.\n",
    "N = 50\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "\n",
    "    key = jax.random.PRNGKey(10) # new key, same initial key for each nx\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 10.0\n",
    "    outer_steps = 10\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    num_nan = 0\n",
    "\n",
    "    for n in range(N):\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        # only the last saved state is checked (presumably a NaN, once produced, propagates)\n",
    "        num_nan += int(jnp.isnan(trajectory_model[-1]).any())\n",
    "        key, _ = jax.random.split(key)\n",
    "\n",
    "    print(\"nx is {}, num_nan is {} out of {}\".format(nx, num_nan, N))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8685092d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same stability check as the previous cell, but with global stabilization (GS) enabled.\n",
    "N = 25\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    key = jax.random.PRNGKey(10)\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 10.0\n",
    "    outer_steps = 10\n",
    "    # G(f, u)[j] = u[j+1] - u[j] with periodic wrap (jnp.roll); the first argument f is unused.\n",
    "    G = lambda f, u: (jnp.roll(u, -1) - u)\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, global_stabilization = True, epsilon_gs = 0.0, G=G, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    \n",
    "    num_nan = 0\n",
    "    \n",
    "    for n in range(N):\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        num_nan += jnp.isnan(trajectory_model[-1]).any()\n",
    "        key, _ = jax.random.split(key)\n",
    "        \n",
    "    print(\"nx is {}, num_nan is {} out of {}\".format(nx, num_nan, N))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91bc56dd",
   "metadata": {},
   "source": [
    "With global stabilization enabled, the NaNs are eliminated."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd0d61e3",
   "metadata": {},
   "source": [
    "### Demonstrate that Global Stabilization Doesn't Degrade Accuracy\n",
    "\n",
    "We want to compare three different numerical algorithms for solving the 1D advection equation: (a) MUSCL, (b) machine learned (ML), and (c) machine learned with global stabilization (ML + GS)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ddc64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy: normalized MSE vs. the exact solution for MUSCL, ML and ML + GS,\n",
    "# averaged over N random initial conditions, for each nx.\n",
    "N = 50\n",
    "\n",
    "mse_muscl = []\n",
    "mse_ml = []\n",
    "mse_mlgs = []\n",
    "\n",
    "def normalized_mse(traj, traj_exact):\n",
    "    \"\"\"MSE of `traj` vs `traj_exact`, each saved step normalized by the mean square of the exact state.\n",
    "\n",
    "    Both arguments have shape (outer_steps, nx); returns a scalar.\n",
    "    \"\"\"\n",
    "    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])\n",
    "\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "\n",
    "    key = jax.random.PRNGKey(10)\n",
    "\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 0.1\n",
    "    outer_steps = 10\n",
    "\n",
    "    mse_muscl_nx = 0.0\n",
    "    mse_ml_nx = 0.0\n",
    "    mse_mlgs_nx = 0.0\n",
    "\n",
    "    # MUSCL\n",
    "    inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)\n",
    "\n",
    "    # Model without GS\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, global_stabilization = False, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    # Model with GS\n",
    "    sim_model_gs = AdvectionFVSim(core_params, sim_params, global_stabilization = True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "\n",
    "    for n in range(N):\n",
    "\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "        trajectory_muscl = trajectory_fn_muscl(a0)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "\n",
    "        # Exact trajectory: fill every row the simulated trajectory has; a row left at zero\n",
    "        # would make the normalization in normalized_mse divide by zero. The loop variable\n",
    "        # is `step` so it does not shadow the outer `n`.\n",
    "        # NOTE(review): assumes row k of a trajectory is the state at t = k * t_inner.\n",
    "        exact_trajectory = onp.zeros((trajectory_muscl.shape[0], nx))\n",
    "        for step in range(exact_trajectory.shape[0]):\n",
    "            exact_trajectory[step] = get_a(f_init, step * t_inner, core_params, nx)\n",
    "\n",
    "        mse_muscl_nx += normalized_mse(trajectory_muscl, exact_trajectory) / N\n",
    "        mse_ml_nx += normalized_mse(trajectory_model, exact_trajectory) / N\n",
    "        mse_mlgs_nx += normalized_mse(trajectory_model_gs, exact_trajectory) / N\n",
    "\n",
    "        key, _ = jax.random.split(key)\n",
    "\n",
    "    mse_muscl.append(mse_muscl_nx)\n",
    "    mse_ml.append(mse_ml_nx)\n",
    "    mse_mlgs.append(mse_mlgs_nx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f9b418",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "print(mse_muscl)\n",
    "print(mse_ml)\n",
    "print(mse_mlgs)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))\n",
    "axs.spines['top'].set_visible(False)\n",
    "axs.spines['right'].set_visible(False)\n",
    "linewidth = 3\n",
    "\n",
    "mses = [mse_ml, mse_mlgs, mse_muscl]\n",
    "labels = [\"ML\", \"ML (Stabilized)\", \"MUSCL\"]\n",
    "colors = [\"blue\", \"red\", \"purple\", \"green\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"solid\", \"solid\"]\n",
    "\n",
    "for k, mse in enumerate(mses):\n",
    "    axs.loglog(nxs, mse, label = labels[k], color=colors[k], linewidth=linewidth, linestyle=linestyles[k])\n",
    "\n",
    "# one tick per evaluated resolution (a hard-coded [64, 32, 16] hid nx = 128 and 256)\n",
    "axs.set_xticks(nxs)\n",
    "axs.set_xticklabels([\"N={}\".format(nx) for nx in nxs], fontsize=18)\n",
    "axs.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])\n",
    "axs.set_yticklabels([\"$10^{-5}$\", \"$10^{-4}$\", \"$10^{-3}$\", \"$10^{-2}$\", \"$10^{-1}$\", \"$10^0$\"], fontsize=18)\n",
    "axs.minorticks_off()\n",
    "axs.set_ylabel(\"Normalized MSE\", fontsize=18)\n",
    "# NOTE(review): the MSE cell uses t_inner = 0.1 and outer_steps = 10; confirm the final time matches this label\n",
    "axs.text(0.3, 0.95, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')\n",
    "\n",
    "\n",
    "handles = []\n",
    "for k, mse in enumerate(mses):\n",
    "    handles.append(\n",
    "        mlines.Line2D(\n",
    "            [],\n",
    "            [],\n",
    "            color=colors[k],\n",
    "            linewidth=linewidth,\n",
    "            label=labels[k],\n",
    "            linestyle=linestyles[k]\n",
    "        )\n",
    "    )\n",
    "axs.legend(handles=handles,loc=(0.655,0.45) , prop={'size': 15}, frameon=False)\n",
    "axs.set_ylim([2.5e-6, 1e0+1e-1])\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "668941e8",
   "metadata": {},
   "source": [
    "## Unrolled Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce8aac98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): these settings differ from the stationary-loss section (kernel_size 5 vs 3,\n",
    "# amplitude_max 0.5 vs 1.0, n_runs 200 vs 10, t_inner_train 0.006 vs 0.01), so the two\n",
    "# sections are not a like-for-like comparison of the loss functions alone.\n",
    "# This cell also overwrites the globals nxs, core_params, sim_params, sim and model.\n",
    "n_unroll = 8\n",
    "init_description = 'sum_sin'\n",
    "kwargs_init = {'min_num_modes': 2, 'max_num_modes': 6, 'min_k': 0, 'max_k': 3, 'amplitude_max': 0.5}\n",
    "kwargs_stencil = {'kernel_size' : 5, 'kernel_out' : 4, 'stencil_width' : 4, 'depth' : 3, 'width' : 16}\n",
    "kwargs_sim_unroll = {'name' : \"unroll_test\", 'cfl_safety' : 0.1, 'rk' : 'ssp_rk3'}\n",
    "kwargs_train_FV_unroll = {'train_id': \"unroll_train\", 'batch_size' : 8, 'optimizer': 'adam', 'num_epochs' : 4, 'n_unroll' : n_unroll}\n",
    "fv_flux_baseline = 'muscl' # learning a correction to the MUSCL scheme\n",
    "\n",
    "n_runs = 200\n",
    "t_inner_train = 0.006\n",
    "# int() truncates 1/0.006 = 166.67 to 166 outer steps (final time ~0.996)\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "nx_exact = 256\n",
    "nxs = [16, 32]\n",
    "learning_rate_list_unroll = [3e-3, 3e-3]\n",
    "key = jax.random.PRNGKey(12)\n",
    "\n",
    "# setup\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim_unroll)\n",
    "# per-nx time step dt = cfl_safety * dx, assuming unit advection speed\n",
    "dts = [sim_params.cfl_safety * (core_params.Lx / nx) for nx in nxs] # advection, c = 1\n",
    "n_data = n_runs * outer_steps_train\n",
    "training_params_list_unroll = [get_training_params_unroll(n_data, **kwargs_train_FV_unroll, learning_rate = lr) for lr in learning_rate_list_unroll]\n",
    "stencil_params = get_stencil_params(**kwargs_stencil)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "init_fn = lambda key: get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "model = get_model(core_params, stencil_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756cc338",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save training data\n",
    "# Also passes n_unroll and the per-nx time steps dts (presumably to store n_unroll-step target sequences).\n",
    "save_training_data_unroll(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs, n_unroll, dts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8812a92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared initial parameters for every nx. The seed (42) matches the stationary section, but kernel_size\n",
    "# differs (5 vs 3), so the resulting parameters are not the same.\n",
    "key = jax.random.PRNGKey(42)\n",
    "i_params = init_params(key, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0344324b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one unrolled-loss model per nx; dt = dts[i] is forwarded to train_model.\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(nx)\n",
    "    training_params = training_params_list_unroll[i]\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    batch_fn = get_batch_fn_unroll(core_params, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn_unroll(model, core_params, sim_params, n_unroll)\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn, dt = dts[i])\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a091ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss curves of the unrolled-loss models, one line per nx (y axis clipped to [0, 0.005]).\n",
    "for i, nx in enumerate(nxs):\n",
    "    losses, _ = load_training_params(nx, sim_params, training_params_list_unroll[i], model)\n",
    "    # summarize rather than dumping the whole loss history into the output\n",
    "    print(\"nx is {}, mean loss is {}, final loss is {}\".format(nx, jnp.mean(losses), losses[-1]))\n",
    "    plt.plot(losses, label=nx)\n",
    "plt.ylim([0,0.005])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31522045",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Qualitative check: MUSCL baseline vs. unrolled-loss model on one random initial condition, per nx.\n",
    "# pick a key that gives something nice\n",
    "key = jax.random.PRNGKey(19)\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(\"nx is {}\".format(nx))\n",
    "\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list_unroll[i], model)\n",
    "\n",
    "    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "    a0 = get_a0(f_init, core_params, nx)\n",
    "    t_inner = 1.0\n",
    "    outer_steps = 10\n",
    "\n",
    "    # learned model (no global stabilization)\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    trajectory_model = trajectory_fn_model(a0)\n",
    "\n",
    "    # baseline solver without a learned correction (MUSCL flux)\n",
    "    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "    trajectory = trajectory_fn(a0)\n",
    "\n",
    "    # stack 0 = MUSCL baseline, stack 1 = learned model\n",
    "    plot_multiple_fv_trajectories([trajectory, trajectory_model], core_params, t_inner)\n",
    "    plt.show()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(l2_norm_trajectory(trajectory), label=\"MUSCL\")\n",
    "    ax.plot(l2_norm_trajectory(trajectory_model), label=\"ML (unrolled)\")\n",
    "    ax.set_xlabel(\"outer step\")\n",
    "    ax.set_ylabel(\"mean of $a^2$\")\n",
    "    ax.set_title(\"nx = {}\".format(nx))\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63137a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stability check: count long rollouts (t_inner * outer_steps = 100) of the unrolled-loss model that end in NaN.\n",
    "N = 50\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "\n",
    "    key = jax.random.PRNGKey(10) # new key, same initial key for each nx\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list_unroll[i], model)\n",
    "    t_inner = 10.0\n",
    "    outer_steps = 10\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    num_nan = 0\n",
    "\n",
    "    for n in range(N):\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        # only the last saved state is checked (presumably a NaN, once produced, propagates)\n",
    "        num_nan += int(jnp.isnan(trajectory_model[-1]).any())\n",
    "        key, _ = jax.random.split(key)\n",
    "\n",
    "    print(\"nx is {}, num_nan is {} out of {}\".format(nx, num_nan, N))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f316872e",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 50\n",
    "\n",
    "mse_muscl = []\n",
    "mse_ml = []\n",
    "mse_mlgs = []\n",
    "\n",
    "# NOTE(review): identical redefinition of normalized_mse from the stationary-loss section.\n",
    "def normalized_mse(traj, traj_exact):\n",
    "    \"\"\"MSE of `traj` vs `traj_exact`, each saved step normalized by the mean square of the exact state.\"\"\"\n",
    "    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])\n",
    "\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    key = jax.random.PRNGKey(10)\n",
    "    \n",
    "    _, params = load_training_params(nx, sim_params, training_params_list_unroll[i], model)\n",
    "    t_inner = 0.1\n",
    "    outer_steps = 10\n",
    "    \n",
    "    mse_muscl_nx = 0.0\n",
    "    mse_ml_nx = 0.0\n",
    "    mse_mlgs_nx = 0.0\n",
    "    \n",
    "    \n",
    "    # MUSCL\n",
    "    inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)\n",
    "\n",
    "    # Model without GS\n",
    "    sim_model = AdvectionFVSim(core_params, sim_params, global_stabilization = False, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    # Model with GS\n",
    "    sim_model_gs = AdvectionFVSim(core_params, sim_params, global_stabilization = True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "    \n",
    "    for n in range(N):\n",
    "        \n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        \n",
    "        trajectory_muscl = trajectory_fn_muscl(a0)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "        \n",
    "        \n",
    "        # Exact trajectory\n",
    "        exact_trajectory = onp.zeros((trajectory_muscl.shape[0],nx))\n",
    "        for n in range(outer_steps):\n",
    "            t = n * t_inner\n",
    "            exact_trajectory[n] = get_a(f_init, t, core_params, nx)\n",
    "        \n",
    "        mse_muscl_nx += normalized_mse(trajectory_muscl, exact_trajectory) / N\n",
    "        mse_ml_nx += normalized_mse(trajectory_model, exact_trajectory) / N\n",
    "        mse_mlgs_nx += normalized_mse(trajectory_model_gs, exact_trajectory) / N\n",
    "        \n",
    "        key, _ = jax.random.split(key)\n",
    "        \n",
    "    mse_muscl.append(mse_muscl_nx)\n",
    "    mse_ml.append(mse_ml_nx)\n",
    "    mse_mlgs.append(mse_mlgs_nx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adcecc26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "\n",
    "print(mse_muscl)\n",
    "print(mse_ml)\n",
    "print(mse_mlgs)\n",
    "\n",
    "# Log-log plot of normalized MSE vs. nx for each solver, drawn on one explicit Axes.\n",
    "fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))\n",
    "axs.spines['top'].set_visible(False)\n",
    "axs.spines['right'].set_visible(False)\n",
    "linewidth = 3\n",
    "\n",
    "mses = [mse_ml, mse_mlgs, mse_muscl]\n",
    "labels = [\"ML\", \"ML (Stabilized)\", \"MUSCL\"]\n",
    "colors = [\"blue\", \"red\", \"purple\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"solid\"]\n",
    "\n",
    "for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):\n",
    "    axs.loglog(nxs, mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)\n",
    "\n",
    "axs.set_xticks([64, 32, 16])\n",
    "axs.set_xticklabels([\"N=64\", \"N=32\", \"N=16\"], fontsize=18)\n",
    "axs.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])\n",
    "axs.set_yticklabels([\"$10^{-5}$\", \"$10^{-4}$\", \"$10^{-3}$\", \"$10^{-2}$\", \"$10^{-1}$\", \"$10^0$\"], fontsize=18)\n",
    "axs.minorticks_off()\n",
    "axs.set_ylabel(\"Normalized MSE\", fontsize=18)\n",
    "# NOTE(review): this label should match t_inner * outer_steps used when computing the MSEs.\n",
    "axs.text(0.3, 0.95, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')\n",
    "\n",
    "# Legend entries built from the same labels/colors/linestyles as the plotted curves.\n",
    "handles = [\n",
    "    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)\n",
    "    for label, color, linestyle in zip(labels, colors, linestyles)\n",
    "]\n",
    "axs.legend(handles=handles, loc=(0.655, 0.45), prop={'size': 15}, frameon=False)\n",
    "axs.set_ylim([2.5e-6, 1e0 + 1e-1])\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66ac544d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
