{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31bc80e8",
   "metadata": {},
   "source": [
    "# DO NOT CHANGE, RAN TO COMPLETION AS PART OF SUBMISSION\n",
    "\n",
    "\n",
    "In this Jupyter notebook, we will train a machine learned DG solver to solve the 1D advection equation at reduced resolution. Our objective is to study the frequency of instability, and to demonstrate that global stabilization eliminates this instability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf4d4b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup paths\n",
    "import sys\n",
    "basedir = '/Users/nickm/thesis/icml2023paper/1d_advection'\n",
    "readwritedir = '/Users/nickm/thesis/icml2023paper/1d_advection'\n",
    "\n",
    "sys.path.append('{}/core'.format(basedir))\n",
    "sys.path.append('{}/simulate'.format(basedir))\n",
    "sys.path.append('{}/ml'.format(basedir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "240d087b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import external packages\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as onp\n",
    "from jax import config, vmap\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "import xarray\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1499b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import internal packages\n",
    "from flux import Flux\n",
    "from initialconditions import get_a0, get_initial_condition_fn, get_a\n",
    "from simparams import CoreParams, CoreParamsDG, SimulationParams\n",
    "from legendre import generate_legendre\n",
    "from simulations import AdvectionDGSim\n",
    "from trajectory import get_trajectory_fn, get_inner_fn\n",
    "from trainingutils import save_training_data\n",
    "from mlparams import TrainingParams, StencilParams\n",
    "from model import LearnedStencil\n",
    "from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, \n",
    "                           compute_losses_no_model, init_params, save_training_params, load_training_params)\n",
    "from helper import convert_DG_representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614be22a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "\n",
    "def plot_fv(a, core_params, color=\"blue\"):\n",
    "    plot_dg(a[...,None], core_params, color=color)\n",
    "    \n",
    "def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    plot_dg_trajectory(trajectory[...,None], core_params, t_inner, color=color)\n",
    "    \n",
    "def plot_dg(a, core_params, color='blue'):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    def evalf(x, a, j, dx, leg_poly):\n",
    "        x_j = dx * (0.5 + j)\n",
    "        xi = (x - x_j) / (0.5 * dx)\n",
    "        vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "        poly_eval = vmap_polyval(leg_poly, xi)  # nx, p array\n",
    "        return jnp.sum(poly_eval * a, axis=-1)\n",
    "\n",
    "    NPLOT = [2,2,5,7][p-1]\n",
    "    nx = a.shape[0]\n",
    "    dx = core_params.Lx / nx\n",
    "    xjs = jnp.arange(nx) * core_params.Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]\n",
    "    vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)\n",
    "\n",
    "    a_plot = vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p))\n",
    "    a_plot = a_plot.T.reshape(-1)\n",
    "    xs = xs.T.reshape(-1)\n",
    "    coords = {('x'): xs}\n",
    "    data = xarray.DataArray(a_plot, coords=coords)\n",
    "    data.plot(color=color)\n",
    "\n",
    "def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    NPLOT = [2,2,5,7][p-1]\n",
    "    nx = trajectory.shape[1]\n",
    "    dx = core_params.Lx / nx\n",
    "    xjs = jnp.arange(nx) * core_params.Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]\n",
    "    \n",
    "    def get_plot_repr(a):\n",
    "        def evalf(x, a, j, dx, leg_poly):\n",
    "            x_j = dx * (0.5 + j)\n",
    "            xi = (x - x_j) / (0.5 * dx)\n",
    "            vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "            poly_eval = vmap_polyval(leg_poly, xi)  # nx, p array\n",
    "            return jnp.sum(poly_eval * a, axis=-1)\n",
    "\n",
    "        vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)\n",
    "        return vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p)).T\n",
    "\n",
    "    get_trajectory_plot_repr = vmap(get_plot_repr)\n",
    "    trajectory_plot = get_trajectory_plot_repr(trajectory)\n",
    "\n",
    "    outer_steps = trajectory.shape[0]\n",
    "    \n",
    "    trajectory_plot = trajectory_plot.reshape(outer_steps, -1)\n",
    "    xs = xs.T.reshape(-1)\n",
    "    coords = {\n",
    "        'x': xs,\n",
    "        'time': t_inner * jnp.arange(outer_steps)\n",
    "    }\n",
    "    xarray.DataArray(trajectory_plot, dims=[\"time\", \"x\"], coords=coords).plot(\n",
    "        col='time', col_wrap=5, color=color)\n",
    "    \n",
    "def plot_multiple_fv_trajectories(trajectories, core_params, t_inner):\n",
    "    plot_multiple_dg_trajectories([trajectory[..., None] for trajectory in trajectories], core_params, t_inner)\n",
    "\n",
    "def plot_multiple_dg_trajectories(trajectories, core_params, t_inner):\n",
    "    outer_steps = trajectories[0].shape[0]\n",
    "    nx = trajectories[0].shape[1]\n",
    "    \n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    NPLOT = [2,2,5,7][p-1]\n",
    "    dx = core_params.Lx / nx\n",
    "    xjs = jnp.arange(nx) * core_params.Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]\n",
    "    \n",
    "    def get_plot_repr(a):\n",
    "        def evalf(x, a, j, dx, leg_poly):\n",
    "            x_j = dx * (0.5 + j)\n",
    "            xi = (x - x_j) / (0.5 * dx)\n",
    "            vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "            poly_eval = vmap_polyval(leg_poly, xi)  # nx, p array\n",
    "            return jnp.sum(poly_eval * a, axis=-1)\n",
    "\n",
    "        vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)\n",
    "        return vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p)).T\n",
    "\n",
    "    get_trajectory_plot_repr = vmap(get_plot_repr)\n",
    "    trajectory_plots = []\n",
    "    for trajectory in trajectories:  \n",
    "        trajectory_plots.append(get_trajectory_plot_repr(trajectory).reshape(outer_steps, -1))\n",
    "        \n",
    "    xs = xs.T.reshape(-1)\n",
    "    coords = {\n",
    "        'x': xs,\n",
    "        'time': t_inner * jnp.arange(outer_steps)\n",
    "    }\n",
    "    xarray.DataArray(trajectory_plots, dims=[\"stack\", \"time\", \"x\"], coords=coords).plot.line(\n",
    "        col='time', hue=\"stack\", col_wrap=5)\n",
    "    \n",
    "\n",
    "def get_core_params(order, flux='upwind'):\n",
    "    Lx = 1.0\n",
    "    if order == 0:\n",
    "        return CoreParams(Lx, flux)\n",
    "    else:\n",
    "        return CoreParamsDG(Lx, flux, order)\n",
    "\n",
    "def get_sim_params(name = \"test\", cfl_safety=0.3, rk='ssp_rk3'):\n",
    "    return SimulationParams(name, basedir, readwritedir, cfl_safety, rk)\n",
    "\n",
    "def get_training_params(n_data, train_id=\"test\", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='sgd'):\n",
    "    return TrainingParams(n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)\n",
    "\n",
    "def get_stencil_params(kernel_size = 3, kernel_out = 4, stencil_width=4, depth = 3, width = 16):\n",
    "    return StencilParams(kernel_size, kernel_out, stencil_width, depth, width)\n",
    "\n",
    "\n",
    "def l2_norm_trajectory_fv(trajectory):\n",
    "    return (jnp.mean(trajectory**2, axis=1))\n",
    "\n",
    "def l2_norm_trajectory_dg(trajectory, p):\n",
    "    twokplusone = 2 * jnp.arange(0, p) + 1\n",
    "    return (jnp.mean(jnp.sum(trajectory**2 / twokplusone[None, :], axis=-1), axis=1))\n",
    "    \n",
    "def get_model(core_params, stencil_params):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    features = [stencil_params.width for _ in range(stencil_params.depth - 1)]\n",
    "    return LearnedStencil(features, stencil_params.kernel_size, stencil_params.kernel_out, stencil_params.stencil_width, p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbfb2989",
   "metadata": {},
   "source": [
    "### Discontinuous Galerkin\n",
    "\n",
    "##### Training Loop\n",
    "\n",
    "First, we will generate the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "907c073d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "init_description = 'sum_sin'\n",
    "kwargs_init = {'min_num_modes': 1, 'max_num_modes': 6, 'min_k': 0, 'max_k': 3, 'amplitude_max': 1.0}\n",
    "kwargs_sim = {'name' : \"dg_paper_data\", 'cfl_safety' : 0.3, 'rk' : 'ssp_rk3'}\n",
    "kwargs_train_DG = {'train_id': \"dg_paper_train\", 'batch_size' : 8, 'optimizer': 'adam', 'num_epochs' : 50}\n",
    "kwargs_stencil = {'kernel_size' : 3, 'kernel_out' : 4, 'stencil_width' : 4, 'depth' : 3, 'width' : 16}\n",
    "n_runs = 100\n",
    "t_inner_train = 0.02\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "dg_flux_baseline = 'upwind'\n",
    "nx_exact = 128\n",
    "nxs = [8, 16, 32]\n",
    "learning_rate_list = [1e-3, 1e-3, 1e-3]\n",
    "assert len(nxs) == len(learning_rate_list)\n",
    "key = jax.random.PRNGKey(12)\n",
    "\n",
    "p = 1\n",
    "\n",
    "# setup\n",
    "core_params = get_core_params(p, flux=dg_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "n_data = n_runs * outer_steps_train\n",
    "training_params_list = [get_training_params(n_data, **kwargs_train_DG, learning_rate = lr) for lr in learning_rate_list]\n",
    "stencil_params = get_stencil_params(**kwargs_stencil)\n",
    "sim = AdvectionDGSim(core_params, sim_params)\n",
    "init_fn = lambda key: get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "model = get_model(core_params, stencil_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1592f3ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save training data\n",
    "save_training_data(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d261c8c2",
   "metadata": {},
   "source": [
    "Next, we initialize the model parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e70430c",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(42)\n",
    "i_params = init_params(key, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6244e5b8",
   "metadata": {},
   "source": [
    "Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76667430",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, nx in enumerate(nxs):\n",
    "    print(nx)\n",
    "    training_params = training_params_list[i]\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn(model, core_params)\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "431b6dcd",
   "metadata": {},
   "source": [
    "Next, we load and plot the losses for each nx to check that the simulation trained properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "381a6f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, nx in enumerate(nxs):\n",
    "    losses, _ = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    plt.plot(losses, label=nx)\n",
    "    print(losses)\n",
    "plt.ylim([0,5])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90cba0ca",
   "metadata": {},
   "source": [
    "Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7aa4c5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick a key that gives something nice\n",
    "key = jax.random.PRNGKey(19)\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(\"nx is {}\".format(nx))\n",
    "    \n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    \n",
    "    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "    a0 = get_a0(f_init, core_params, nx)\n",
    "    t_inner = 1.0\n",
    "    outer_steps = 10\n",
    "    # with params\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    trajectory_model = trajectory_fn_model(a0)\n",
    "    #plot_dg_trajectory(trajectory_model, core_params, t_inner, color = plot_colors[i])\n",
    "    \n",
    "    \n",
    "    # with global stabilization\n",
    "    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "    trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "    \n",
    "\n",
    "    # without params\n",
    "    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "    trajectory = trajectory_fn(a0)\n",
    "    \n",
    "\n",
    "    plot_multiple_dg_trajectories([trajectory, trajectory_model], core_params, t_inner)\n",
    "    \n",
    "    plt.show()\n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory, p))\n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory_model, p))\n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory_model_gs, p))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee1eb0e8",
   "metadata": {},
   "source": [
    "We see from above that the baseline (red) has a large amount of numerical diffusion for small number of gridpoints, while is more accurate for more gridpoints. We also see that the machine learned model learns to accurately evolve the solution for nx > 8. So far, so good. Let's now look at a different initial condition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed3ae205",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick a key that gives something nice\n",
    "key = jax.random.PRNGKey(10)\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(\"nx is {}\".format(nx))\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    \n",
    "    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "    a0 = get_a0(f_init, core_params, nx)\n",
    "    t_inner = 1.0\n",
    "    outer_steps = 10\n",
    "    \n",
    "    # with params\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    trajectory_model = trajectory_fn_model(a0)\n",
    "\n",
    "    # without params\n",
    "    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "    trajectory = trajectory_fn(a0)\n",
    "    \n",
    "    # with gs\n",
    "    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "    trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "    \n",
    "    plot_multiple_dg_trajectories([trajectory, trajectory_model, trajectory_model_gs], core_params, t_inner)\n",
    "    plt.show()\n",
    "    \n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory, p))\n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory_model, p))\n",
    "    plt.plot(l2_norm_trajectory_dg(trajectory_model_gs, p))\n",
    "    #plt.plot(l2_norm_trajectory(trajectory_model_gs))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53686756",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    key = jax.random.PRNGKey(10) # new key, same initial key for each nx\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 10.0\n",
    "    outer_steps = 10\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    \n",
    "    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "    \n",
    "    num_nan = 0\n",
    "    \n",
    "    for n in range(N):\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        num_nan += jnp.isnan(trajectory_model[-1]).any()\n",
    "        key, _ = jax.random.split(key)\n",
    "        \n",
    "    print(\"nx is {}, num_nan is {} out of {}\".format(nx, num_nan, N))\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "268fa07a",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    key = jax.random.PRNGKey(10)\n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 10.0\n",
    "    outer_steps = 10\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization = True, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "    \n",
    "    num_nan = 0\n",
    "    \n",
    "    for n in range(N):\n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        num_nan += jnp.isnan(trajectory_model[-1]).any()\n",
    "        key, _ = jax.random.split(key)\n",
    "        \n",
    "    print(\"nx is {}, num_nan is {} out of {}\".format(nx, num_nan, N))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f88e558",
   "metadata": {},
   "source": [
    "Are the NaNs eliminated by global stabilization? (They should be.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fe1ba4f",
   "metadata": {},
   "source": [
    "### Demonstrate that Global Stabilization Doesn't Degrade Accuracy\n",
    "\n",
    "We want to compare 3 different numerical algorithms for solving the 1D advection equation. We compare: (a) Upwind (b) Machine Learned (ML) (c) Machine learned with global stabilization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dceefd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "\n",
    "mse_upwind = []\n",
    "mse_ml = []\n",
    "mse_mlgs = []\n",
    "\n",
    "def normalized_mse_fv(traj, traj_exact):\n",
    "    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])\n",
    "\n",
    "\n",
    "def normalized_mse_dg(traj, traj_exact, p):\n",
    "    twokplusone = 2 * jnp.arange(0, p+1) + 1\n",
    "    l2_normalization = jnp.mean(jnp.sum(traj_exact**2 / twokplusone[None, None, :], axis=-1), axis=-1)\n",
    "    l2 = jnp.sum((traj - traj_exact)**2 / twokplusone[None, None, :], axis=-1)\n",
    "    return jnp.mean(l2 / l2_normalization[:, None])\n",
    "\n",
    "vmap_convert_DG = vmap(convert_DG_representation, (0, None, None, None), 0)\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    key = jax.random.PRNGKey(10)\n",
    "    \n",
    "    _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "    t_inner = 0.1\n",
    "    outer_steps = 10\n",
    "    \n",
    "    mse_upwind_nx = 0.0\n",
    "    mse_ml_nx = 0.0\n",
    "    mse_mlgs_nx = 0.0\n",
    "    \n",
    "    # Upwind\n",
    "    inner_fn_upwind = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn_upwind = get_trajectory_fn(inner_fn_upwind, outer_steps)\n",
    "\n",
    "    # Model without GS\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization = False, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    # Model with GS\n",
    "    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization = True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "\n",
    "    \n",
    "    for n in range(N):\n",
    "        \n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        a0_exact = get_a0(f_init, core_params, nx_exact)\n",
    "        \n",
    "        trajectory_upwind = trajectory_fn_upwind(a0)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "        \n",
    "        trajectory_exact = trajectory_fn_upwind(a0_exact)\n",
    "        trajectory_exact_ds = vmap_convert_DG(trajectory_exact, p+1, nx, core_params.Lx)\n",
    "\n",
    "        \n",
    "        mse_upwind_nx += normalized_mse_dg(trajectory_upwind, trajectory_exact_ds, p) / N\n",
    "        mse_ml_nx += normalized_mse_dg(trajectory_model, trajectory_exact_ds, p) / N\n",
    "        gs = normalized_mse_dg(trajectory_model_gs, trajectory_exact_ds, p) / N\n",
    "        mse_mlgs_nx += gs\n",
    "        \n",
    "        key, _ = jax.random.split(key)\n",
    "        \n",
    "    mse_upwind.append(mse_upwind_nx)\n",
    "    mse_ml.append(mse_ml_nx)\n",
    "    mse_mlgs.append(mse_mlgs_nx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d49436",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "print(mse_upwind)\n",
    "print(mse_ml)\n",
    "print(mse_mlgs)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))\n",
    "axs.spines['top'].set_visible(False)\n",
    "axs.spines['right'].set_visible(False)\n",
    "linewidth = 3\n",
    "\n",
    "mses = [mse_ml, mse_mlgs, mse_upwind]\n",
    "labels = [\"DG ML\", \"DG ML\\n(Stabilized)\", \"DG (Upwind)\"]\n",
    "colors = [\"blue\", \"red\", \"purple\", \"green\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"solid\", \"solid\"]\n",
    "\n",
    "for k, mse in enumerate(mses):\n",
    "    plt.loglog(nxs, mse, label = labels[k], color=colors[k], linewidth=linewidth, linestyle=linestyles[k])\n",
    "\n",
    "axs.set_xticks([32, 16, 8])\n",
    "axs.set_xticklabels([\"N=32\", \"N=16\", \"N=8\"], fontsize=18)\n",
    "axs.set_yticks([1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1])\n",
    "axs.set_yticklabels([\"$10^{-8}$\", \"$10^{-7}$\", \"$10^{-6}$\", \"$10^{-5}$\", \"$10^{-4}$\", \"$10^{-3}$\", \"$10^{-2}$\", \"$10^{-1}$\"], fontsize=18)\n",
    "axs.minorticks_off()\n",
    "axs.set_ylabel(\"Normalized MSE\", fontsize=18)\n",
    "axs.text(0.2, 0.85, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')\n",
    "\n",
    "\n",
    "handles = []\n",
    "for k, mse in enumerate(mses):\n",
    "    handles.append(\n",
    "        mlines.Line2D(\n",
    "            [],\n",
    "            [],\n",
    "            color=colors[k],\n",
    "            linewidth=linewidth,\n",
    "            label=labels[k],\n",
    "            linestyle=linestyles[k]\n",
    "        )\n",
    "    )\n",
    "axs.legend(handles=handles, loc=(0.63,0.21), prop={'size': 15}, frameon=False)\n",
    "plt.ylim([3e-9, 1e-1+6e-2])\n",
    "fig.tight_layout()\n",
    "\n",
    "\n",
    "plt.savefig('mse_vs_nx_dg.png')\n",
    "plt.savefig('mse_vs_nx_dg.eps')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4ff50b5",
   "metadata": {},
   "source": [
    "### Demonstrate that Global Stabilization Improves Accuracy over Time\n",
    "\n",
    "For nx = 16, plot the accuracy of global stabilization vs ML on the y-axis, with time on the x-axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f105b30",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx = 8\n",
    "N = 100\n",
    "\n",
    "Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]\n",
    "\n",
    "def normalized_mse_fv(traj, traj_exact):\n",
    "    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])\n",
    "\n",
    "\n",
    "def normalized_mse_dg(traj, traj_exact, p):\n",
    "    twokplusone = 2 * jnp.arange(0, p+1) + 1\n",
    "    l2_normalization = jnp.mean(jnp.sum(traj_exact**2 / twokplusone[None, None, :], axis=-1), axis=-1)\n",
    "    l2 = jnp.sum((traj - traj_exact)**2 / twokplusone[None, None, :], axis=-1)\n",
    "    return jnp.mean(l2 / l2_normalization[:, None])\n",
    "\n",
    "mse_upwind_time = []\n",
    "mse_ml_time = []\n",
    "mse_mlgs_time = []\n",
    "\n",
    "_, params = load_training_params(nx, sim_params, training_params_list[0], model)\n",
    "\n",
    "\n",
    "for T in Ts:\n",
    "    \n",
    "    print(T)\n",
    "    \n",
    "    key = jax.random.PRNGKey(10)\n",
    "    \n",
    "    t_inner = 0.1\n",
    "    outer_steps = int(T / t_inner)\n",
    "    \n",
    "    # Upwind\n",
    "    inner_fn_upwind = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "    trajectory_fn_upwind = get_trajectory_fn(inner_fn_upwind, outer_steps)\n",
    "\n",
    "    # Model without GS\n",
    "    sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization = False, model=model, params=params)\n",
    "    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)\n",
    "    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)\n",
    "\n",
    "    # Model with GS\n",
    "    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization = True, model=model, params=params)\n",
    "    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)\n",
    "    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)\n",
    "\n",
    "    \n",
    "    mse_upwind_nx = 0.0\n",
    "    mse_ml_nx = 0.0\n",
    "    mse_mlgs_nx = 0.0\n",
    "    \n",
    "    for n in range(N):\n",
    "    \n",
    "        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)\n",
    "        a0 = get_a0(f_init, core_params, nx)\n",
    "        a0_exact = get_a0(f_init, core_params, nx_exact)\n",
    "        \n",
    "        trajectory_upwind = trajectory_fn_upwind(a0)\n",
    "        trajectory_model = trajectory_fn_model(a0)\n",
    "        trajectory_model_gs = trajectory_fn_model_gs(a0)\n",
    "        \n",
    "        # Exact trajectory\n",
    "        trajectory_exact = trajectory_fn_upwind(a0_exact)\n",
    "        trajectory_exact_ds = vmap_convert_DG(trajectory_exact, p+1, nx, core_params.Lx)\n",
    "        \n",
    "        mse_upwind_nx += normalized_mse_dg(trajectory_upwind, trajectory_exact_ds, p) / N\n",
    "        mse_ml_nx += normalized_mse_dg(trajectory_model, trajectory_exact_ds, p) / N\n",
    "        mse_mlgs_nx += normalized_mse_dg(trajectory_model_gs, trajectory_exact_ds, p) / N\n",
    "    \n",
    "        key, _ = jax.random.split(key)\n",
    "    \n",
    "    \n",
    "    mse_upwind_time.append(mse_upwind_nx)\n",
    "    mse_ml_time.append(mse_ml_nx)\n",
    "    mse_mlgs_time.append(mse_mlgs_nx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28142d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))\n",
    "axs.spines['top'].set_visible(False)\n",
    "axs.spines['right'].set_visible(False)\n",
    "linewidth = 3\n",
    "\n",
    "Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]\n",
    "\n",
    "mses = [mse_ml_time, mse_mlgs_time, mse_upwind_time]\n",
    "labels = [\"DG ML\", \"DG ML (Stabilized)\", \"DG Upwind\"]\n",
    "colors = [\"blue\", \"red\", \"purple\", \"green\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"solid\", \"solid\"]\n",
    "\n",
    "for k, mse in enumerate(mses):\n",
    "    plt.loglog(Ts, [jnp.nan_to_num(error, nan=1e7) for error in mse], label = labels[k], color=colors[k], linewidth=linewidth, linestyle=linestyles[k])\n",
    "\n",
    "Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]\n",
    "    \n",
    "axs.set_xticks(Ts)\n",
    "axs.set_xticklabels([\"t=1\", \"2\", \"5\", \"10\", \"20\", \"50\", \"t=100\"], fontsize=18)\n",
    "axs.set_yticks([1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])\n",
    "axs.set_yticklabels([\"$10^{-7}$\", \"$10^{-6}$\", \"$10^{-5}$\", \"$10^{-4}$\", \"$10^{-3}$\", \"$10^{-2}$\", \"$10^{-1}$\", \"$10^0$\"], fontsize=18)\n",
    "axs.minorticks_off()\n",
    "axs.set_ylabel(\"Normalized MSE\", fontsize=18)\n",
    "axs.text(0.15, 0.8, '$N=8$', transform=axs.transAxes, fontsize=18, verticalalignment='top')\n",
    "\n",
    "\n",
    "handles = []\n",
    "for k, mse in enumerate(mses):\n",
    "    handles.append(\n",
    "        mlines.Line2D(\n",
    "            [],\n",
    "            [],\n",
    "            color=colors[k],\n",
    "            linewidth=linewidth,\n",
    "            label=labels[k],\n",
    "            linestyle=linestyles[k]\n",
    "        )\n",
    "    )\n",
    "axs.legend(handles=handles,loc=(0.52,0.03) , prop={'size': 15}, frameon=False)\n",
    "plt.ylim([2.0e-6, 1e0-1e-1])\n",
    "fig.tight_layout()\n",
    "\n",
    "\n",
    "plt.savefig('mse_vs_time_dg.png')\n",
    "plt.savefig('mse_vs_time_dg.eps')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "121534e2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b063edf5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
