{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d37e5428",
   "metadata": {},
   "source": [
    "In this Jupyter notebook, we train machine-learned finite volume (FV) solvers for the 1D advection equation at reduced resolution. Our objective is to train several ML models and plot their accuracy relative to other solvers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59ad011",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d11308",
   "metadata": {},
   "outputs": [],
   "source": [
    "############\n",
    "# TO RUN, EDIT HERE\n",
    "############\n",
    "\n",
    "# Outer-level directory of the repository; its core/, simulate/ and ml/ folders\n",
    "# are appended to sys.path in the next cell.\n",
    "basedir = ... # should be '/path/to/outer_level_directory/'\n",
    "\n",
    "############\n",
    "# DONE EDITS \n",
    "############\n",
    "\n",
    "# Fail fast if the placeholder above was not replaced: otherwise the paths below\n",
    "# silently format as 'Ellipsis/core' and the project imports fail with an unrelated error.\n",
    "if basedir is ...:\n",
    "    raise ValueError(\"Set `basedir` to the outer-level directory of the repository before running.\")\n",
    "\n",
    "readwritedir = basedir # should be same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79852112",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project's source folders importable.\n",
    "for _subdir in ('core', 'simulate', 'ml'):\n",
    "    sys.path.append('{}/{}'.format(basedir, _subdir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "316a93ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import external packages\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as onp\n",
    "from jax import config, vmap\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "import xarray\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae413a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import internal packages\n",
    "from flux import Flux\n",
    "from initialconditions import get_a0, get_initial_condition_fn, get_a\n",
    "from simparams import CoreParams, SimulationParams\n",
    "from helper import generate_legendre\n",
    "from simulations import AdvectionFVSim\n",
    "from trajectory import get_trajectory_fn, get_inner_fn\n",
    "from trainingutils import save_training_data\n",
    "from mlparams import TrainingParams, ModelParams\n",
    "from model import LearnedFluxOutput\n",
    "from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, \n",
    "                           compute_losses_no_model, init_params, save_training_params, load_training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2932d0ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "\n",
    "def plot_fv(a, core_params, color=\"blue\"):\n",
    "    \"\"\"Plot a finite-volume state by handing it to plot_dg with a trailing length-1 axis.\"\"\"\n",
    "    a_with_coeff_axis = a[..., None]\n",
    "    plot_dg(a_with_coeff_axis, core_params, color=color)\n",
    "    \n",
    "def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    \"\"\"Plot a finite-volume trajectory by handing it to plot_dg_trajectory with a trailing length-1 axis.\"\"\"\n",
    "    trajectory_with_coeff_axis = trajectory[..., None]\n",
    "    plot_dg_trajectory(trajectory_with_coeff_axis, core_params, t_inner, color=color)\n",
    "    \n",
    "def plot_dg(a, core_params, color='blue'):\n",
    "    \"\"\"Plot a single state as a curve over x.\n",
    "\n",
    "    `a` is indexed as [cell, coefficient]: a.shape[0] cells, each carrying\n",
    "    a.shape[-1] coefficients of the basis returned by generate_legendre.\n",
    "    Every cell is sampled at a few points across its width (the domain has\n",
    "    length core_params.Lx) and the samples are drawn with xarray in `color`.\n",
    "    \"\"\"\n",
    "    n_coeffs = a.shape[-1]\n",
    "    n_cells = a.shape[0]\n",
    "    # Samples per cell, looked up by the number of coefficients (table has entries for 1..4).\n",
    "    samples_per_cell = [2,2,5,7][n_coeffs-1]\n",
    "    dx = core_params.Lx / n_cells\n",
    "    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells\n",
    "    sample_xs = left_edges[None, :] + jnp.linspace(0.0, dx, samples_per_cell)[:, None]\n",
    "\n",
    "    def eval_in_cell(x, coeffs, j, dx, leg_poly):\n",
    "        # Map x to the reference coordinate of cell j (centre x_j, half-width dx/2),\n",
    "        # evaluate every basis polynomial there and combine them with the coefficients.\n",
    "        x_j = dx * (0.5 + j)\n",
    "        xi = (x - x_j) / (0.5 * dx)\n",
    "        basis_at_xi = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)\n",
    "        return jnp.sum(basis_at_xi * coeffs, axis=-1)\n",
    "\n",
    "    # Vectorise over cells: column j of sample_xs goes with row j of `a`.\n",
    "    eval_all_cells = vmap(eval_in_cell, (1, 0, 0, None, None), 1)\n",
    "    values = eval_all_cells(sample_xs, a, jnp.arange(n_cells), dx, generate_legendre(n_coeffs))\n",
    "\n",
    "    # Flatten cell-by-cell so that x increases monotonically along the curve.\n",
    "    values_flat = values.T.reshape(-1)\n",
    "    xs_flat = sample_xs.T.reshape(-1)\n",
    "    curve = xarray.DataArray(values_flat, coords={'x': xs_flat})\n",
    "    curve.plot(color=color)\n",
    "\n",
    "def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    \"\"\"Plot every snapshot of a trajectory in its own panel (five panels per row).\n",
    "\n",
    "    `trajectory` is indexed as [time, cell, coefficient]; snapshot k is labelled\n",
    "    with the time t_inner * k.\n",
    "    \"\"\"\n",
    "    n_steps = trajectory.shape[0]\n",
    "    n_cells = trajectory.shape[1]\n",
    "    n_coeffs = trajectory.shape[-1]\n",
    "    # Samples per cell, looked up by the number of coefficients (table has entries for 1..4).\n",
    "    samples_per_cell = [2,2,5,7][n_coeffs-1]\n",
    "    dx = core_params.Lx / n_cells\n",
    "    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells\n",
    "    sample_xs = left_edges[None, :] + jnp.linspace(0.0, dx, samples_per_cell)[:, None]\n",
    "    \n",
    "    def sample_snapshot(a):\n",
    "        # Evaluate one snapshot at all sample points; the result is indexed [cell, sample].\n",
    "        def eval_in_cell(x, coeffs, j, dx, leg_poly):\n",
    "            x_j = dx * (0.5 + j)\n",
    "            xi = (x - x_j) / (0.5 * dx)\n",
    "            basis_at_xi = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)\n",
    "            return jnp.sum(basis_at_xi * coeffs, axis=-1)\n",
    "\n",
    "        eval_all_cells = vmap(eval_in_cell, (1, 0, 0, None, None), 1)\n",
    "        return eval_all_cells(sample_xs, a, jnp.arange(n_cells), dx, generate_legendre(n_coeffs)).T\n",
    "\n",
    "    # Vectorise over the leading (time) axis, then flatten each snapshot cell-by-cell.\n",
    "    sampled = vmap(sample_snapshot)(trajectory).reshape(n_steps, -1)\n",
    "\n",
    "    coords = {\n",
    "        'x': sample_xs.T.reshape(-1),\n",
    "        'time': t_inner * jnp.arange(n_steps)\n",
    "    }\n",
    "    xarray.DataArray(sampled, dims=[\"time\", \"x\"], coords=coords).plot(\n",
    "        col='time', col_wrap=5, color=color)\n",
    "    \n",
    "def plot_multiple_fv_trajectories(trajectories, core_params, t_inner, ylim = [-1.5,1.5]):\n",
    "    \"\"\"Overlay several finite-volume trajectories via plot_multiple_dg_trajectories.\n",
    "\n",
    "    Each trajectory gets a trailing length-1 axis before being passed on.\n",
    "    \"\"\"\n",
    "    dg_trajectories = [trajectory[..., None] for trajectory in trajectories]\n",
    "    plot_multiple_dg_trajectories(dg_trajectories, core_params, t_inner, ylim = ylim)\n",
    "\n",
    "def plot_multiple_dg_trajectories(trajectories, core_params, t_inner, ylim = [-1.5,1.5]):\n",
    "    \"\"\"Overlay several trajectories, one panel per snapshot and one line per trajectory.\n",
    "\n",
    "    Each trajectory is indexed as [time, cell, coefficient]. The number of\n",
    "    snapshots, cells and coefficients is read from the first trajectory, and\n",
    "    snapshot k is labelled with the time t_inner * k. `ylim` sets the y-range\n",
    "    of the panels.\n",
    "    \"\"\"\n",
    "    reference = trajectories[0]\n",
    "    n_steps = reference.shape[0]\n",
    "    n_cells = reference.shape[1]\n",
    "    n_coeffs = reference.shape[-1]\n",
    "    # Samples per cell, looked up by the number of coefficients (table has entries for 1..4).\n",
    "    samples_per_cell = [2,2,5,7][n_coeffs-1]\n",
    "    dx = core_params.Lx / n_cells\n",
    "    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells\n",
    "    sample_xs = left_edges[None, :] + jnp.linspace(0.0, dx, samples_per_cell)[:, None]\n",
    "    \n",
    "    def sample_snapshot(a):\n",
    "        # Evaluate one snapshot at all sample points; the result is indexed [cell, sample].\n",
    "        def eval_in_cell(x, coeffs, j, dx, leg_poly):\n",
    "            x_j = dx * (0.5 + j)\n",
    "            xi = (x - x_j) / (0.5 * dx)\n",
    "            basis_at_xi = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)\n",
    "            return jnp.sum(basis_at_xi * coeffs, axis=-1)\n",
    "\n",
    "        eval_all_cells = vmap(eval_in_cell, (1, 0, 0, None, None), 1)\n",
    "        return eval_all_cells(sample_xs, a, jnp.arange(n_cells), dx, generate_legendre(n_coeffs)).T\n",
    "\n",
    "    sample_trajectory = vmap(sample_snapshot)\n",
    "    sampled_trajectories = [\n",
    "        sample_trajectory(trajectory).reshape(n_steps, -1) for trajectory in trajectories\n",
    "    ]\n",
    "        \n",
    "    coords = {\n",
    "        'x': sample_xs.T.reshape(-1),\n",
    "        'time': t_inner * jnp.arange(n_steps)\n",
    "    }\n",
    "    xarray.DataArray(sampled_trajectories, dims=[\"stack\", \"time\", \"x\"], coords=coords).plot.line(\n",
    "        col='time', hue=\"stack\", col_wrap=5, ylim=ylim)\n",
    "    \n",
    "\n",
    "def get_core_params(flux='upwind'):\n",
    "    \"\"\"Build CoreParams for a domain of length Lx = 1.0 with the named flux.\"\"\"\n",
    "    return CoreParams(1.0, flux)\n",
    "\n",
    "def get_sim_params(name = \"test\", cfl_safety=0.3, rk='ssp_rk3'):\n",
    "    \"\"\"Build SimulationParams from the arguments and the notebook-level basedir / readwritedir.\"\"\"\n",
    "    sim_args = (name, basedir, readwritedir, cfl_safety, rk)\n",
    "    return SimulationParams(*sim_args)\n",
    "\n",
    "def get_training_params(n_data, train_id=\"test\", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='sgd'):\n",
    "    \"\"\"Build TrainingParams from the given settings.\n",
    "\n",
    "    The positional order expected by TrainingParams differs from this\n",
    "    function's signature, so the arguments are reordered here.\n",
    "    \"\"\"\n",
    "    training_args = (n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)\n",
    "    return TrainingParams(*training_args)\n",
    "\n",
    "def get_model_params(kernel_size = 3, kernel_out = 4,  depth = 3, width = 16):\n",
    "    \"\"\"Build ModelParams from the kernel size, kernel output size, depth and width.\"\"\"\n",
    "    model_args = (kernel_size, kernel_out, depth, width)\n",
    "    return ModelParams(*model_args)\n",
    "\n",
    "\n",
    "def l2_norm_trajectory(trajectory):\n",
    "    \"\"\"Return the mean of trajectory**2 along axis 1, one value per snapshot.\n",
    "\n",
    "    Despite the name, no square root is taken: this is the mean square.\n",
    "    \"\"\"\n",
    "    squared = trajectory**2\n",
    "    return jnp.mean(squared, axis=1)\n",
    "    \n",
    "def get_model(core_params, model_params):\n",
    "    \"\"\"Build the LearnedFluxOutput network described by model_params.\n",
    "\n",
    "    The feature list holds model_params.width repeated (depth - 1) times.\n",
    "    core_params is accepted but not used here.\n",
    "    \"\"\"\n",
    "    layer_widths = [model_params.width] * (model_params.depth - 1)\n",
    "    return LearnedFluxOutput(layer_widths, model_params.kernel_size, model_params.kernel_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a764bc88",
   "metadata": {},
   "source": [
    "### Finite Volume\n",
    "\n",
    "##### Training Loop\n",
    "\n",
    "First, we will generate the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dd99d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "\n",
    "# Random initial conditions, used for both training data and evaluation.\n",
    "init_description = 'sum_sin'\n",
    "kwargs_init = dict(min_num_modes=1, max_num_modes=6, min_k=1, max_k=4, amplitude_max=1.0)\n",
    "\n",
    "# Keyword arguments unpacked into get_sim_params / get_model_params.\n",
    "kwargs_sim = dict(name=\"paper_test\", cfl_safety=0.3, rk='ssp_rk3')\n",
    "kwargs_model = dict(kernel_size=3, kernel_out=4, depth=3, width=24)\n",
    "\n",
    "# Amount of training data and optimisation schedule.\n",
    "n_runs = 100\n",
    "t_inner_train = 0.02\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "BS = 32 # batch size\n",
    "NE = 200 # num epochs\n",
    "\n",
    "# Resolutions; each entry of nxs is trained with the learning rate at the same index.\n",
    "nx_exact = 512\n",
    "nxs = [8, 16, 32, 64]\n",
    "learning_rate_list = [1e-3, 1e-3, 1e-3, 1e-3]\n",
    "assert len(nxs) == len(learning_rate_list)\n",
    "\n",
    "key = jax.random.PRNGKey(12)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a705b68d",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Setup for Generating Training Data\n",
    "core_params_muscl = get_core_params(flux='muscl')\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params_muscl, sim_params)\n",
    "\n",
    "# Dataset size later handed to TrainingParams: number of runs times outer steps per run.\n",
    "n_data = n_runs * outer_steps_train\n",
    "\n",
    "def init_fn(key):\n",
    "    \"\"\"Return the initial-condition function selected by init_description for this PRNG key.\"\"\"\n",
    "    return get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89a824db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the training data with the MUSCL solver at nx_exact for every resolution in nxs.\n",
    "save_training_data(\n",
    "    key, init_fn, core_params_muscl, sim_params, sim,\n",
    "    t_inner_train, outer_steps_train, n_runs, nx_exact, nxs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a4b4d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Setup for training models\n",
    "\n",
    "# One network architecture and one set of initial weights shared by all three flux variants.\n",
    "model_params = get_model_params(**kwargs_model)\n",
    "model = get_model(core_params_muscl, model_params)\n",
    "key = jax.random.PRNGKey(42)\n",
    "i_params = init_params(key, model)\n",
    "\n",
    "# Variant 1: learned flux\n",
    "core_params_learned = get_core_params(flux='learned')\n",
    "kwargs_train_FV = dict(train_id=\"flux_predicting\", batch_size=BS, optimizer='adam', num_epochs=NE)\n",
    "training_params_list_learned = [\n",
    "    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)\n",
    "    for lr in learning_rate_list\n",
    "]\n",
    "\n",
    "# Variant 2: learned limiter\n",
    "core_params_limiter = get_core_params(flux='learnedlimiter')\n",
    "kwargs_train_FV = dict(train_id=\"flux_limited\", batch_size=BS, optimizer='adam', num_epochs=NE)\n",
    "training_params_list_limited = [\n",
    "    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)\n",
    "    for lr in learning_rate_list\n",
    "]\n",
    "\n",
    "# Variant 3: learned combination\n",
    "core_params_combo = get_core_params(flux='combination_learned')\n",
    "kwargs_train_FV = dict(train_id=\"combo_learned\", batch_size=BS, optimizer='adam', num_epochs=NE)\n",
    "training_params_list_combo = [\n",
    "    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)\n",
    "    for lr in learning_rate_list\n",
    "]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbdd25e8",
   "metadata": {},
   "source": [
    "Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc39c297",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### First, train original ML Model\n",
    "\n",
    "# Every resolution starts from the same initial weights and PRNG key.\n",
    "for nx, training_params in zip(nxs, training_params_list_learned):\n",
    "    print(nx)\n",
    "    batch_fn = get_batch_fn(core_params_learned, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn(model, core_params_learned)\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2645aee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Second, train flux-limited model\n",
    "\n",
    "# Every resolution starts from the same initial weights and PRNG key.\n",
    "for nx, training_params in zip(nxs, training_params_list_limited):\n",
    "    print(nx)\n",
    "    batch_fn = get_batch_fn(core_params_limiter, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn(model, core_params_limiter)\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ed966f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Third, train combination model\n",
    "\n",
    "# Every resolution starts from the same initial weights and PRNG key.\n",
    "for nx, training_params in zip(nxs, training_params_list_combo):\n",
    "    print(nx)\n",
    "    batch_fn = get_batch_fn(core_params_combo, sim_params, training_params, nx)\n",
    "    loss_fn = get_loss_fn(model, core_params_combo)\n",
    "    idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)\n",
    "    save_training_params(nx, sim_params, training_params, params, losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65b7f887",
   "metadata": {},
   "source": [
    "Next, we load and plot the training losses for each nx to check that every model trained properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff423c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (legend prefix, per-nx training params) for each trained variant, in plotting order.\n",
    "loss_curves = (\n",
    "    (\"learned\", training_params_list_learned),\n",
    "    (\"limited\", training_params_list_limited),\n",
    "    (\"combination\", training_params_list_combo),\n",
    ")\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    for curve_label, training_params_list in loss_curves:\n",
    "        losses, _ = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "        plt.plot(losses, label=\"{} {}\".format(curve_label, nx))\n",
    "        # Only the losses of the plain learned flux are printed.\n",
    "        if curve_label == \"learned\":\n",
    "            print(losses)\n",
    "    \n",
    "plt.ylim([0,100])\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db884533",
   "metadata": {},
   "source": [
    "Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a76a307e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Qualitative check: for every resolution, evolve one random initial condition with\n",
    "# each solver, overlay the trajectories on the exact solution, and plot the mean\n",
    "# square of each trajectory over time.\n",
    "key = jax.random.PRNGKey(19)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "t_inner = 0.25\n",
    "outer_steps = 5\n",
    "\n",
    "# (name, flux, per-nx training params for learned fluxes or None, extra AdvectionFVSim kwargs)\n",
    "solver_specs = [\n",
    "    ('centered', 'centered', None, {}),\n",
    "    ('upwind', 'upwind', None, {}),\n",
    "    ('muscl', 'muscl', None, {}),\n",
    "    ('learned', 'learned', training_params_list_learned, {}),\n",
    "    # The combination flux must load the weights trained for it (train_id 'combo_learned');\n",
    "    # it previously loaded the weights of the plain learned flux.\n",
    "    ('combination', 'combination_learned', training_params_list_combo, {}),\n",
    "    ('limiter', 'learnedlimiter', training_params_list_limited, {}),\n",
    "    ('invariant_learned', 'learned', training_params_list_learned, {'global_stabilization': True}),\n",
    "]\n",
    "# Order of the curves in the plots; the exact solution always comes first.\n",
    "plot_order = ['learned', 'invariant_learned', 'upwind', 'centered', 'muscl', 'limiter', 'combination']\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    print(\"nx is {}\".format(nx))\n",
    "    \n",
    "    # The same key is used for every nx, so all resolutions see the same initial condition.\n",
    "    f_init = get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)\n",
    "    a0 = get_a0(f_init, core_params_muscl, nx)\n",
    "    \n",
    "    ########\n",
    "    # Exact trajectory\n",
    "    ########\n",
    "    trajectory_exact = onp.zeros((outer_steps, nx))\n",
    "    for k in range(outer_steps):\n",
    "        trajectory_exact[k] = get_a(f_init, k * t_inner, core_params_muscl, nx)\n",
    "    trajectory_exact = jnp.asarray(trajectory_exact)\n",
    "    \n",
    "    ########\n",
    "    # Numerical trajectories, one per solver\n",
    "    ########\n",
    "    trajectories_by_solver = {}\n",
    "    for solver_name, flux, training_params_list, sim_kwargs in solver_specs:\n",
    "        core_params = get_core_params(flux=flux)\n",
    "        if training_params_list is None:\n",
    "            sim = AdvectionFVSim(core_params, sim_params)\n",
    "        else:\n",
    "            _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "            sim = AdvectionFVSim(core_params, sim_params, model=model, params=params, **sim_kwargs)\n",
    "        inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "        trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "        trajectories_by_solver[solver_name] = trajectory_fn(a0)\n",
    "    \n",
    "    trajectories = [trajectory_exact] + [trajectories_by_solver[name] for name in plot_order]\n",
    "    # Only core_params.Lx is read by the plotting helper, and it is the same for every flux.\n",
    "    plot_multiple_fv_trajectories(trajectories, core_params, t_inner)\n",
    "    plt.show()\n",
    "    \n",
    "    for trajectory in trajectories:\n",
    "        plt.plot(l2_norm_trajectory(trajectory))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd0d61e3",
   "metadata": {},
   "source": [
    "### Demonstrate that Global Stabilization Doesn't Degrade Accuracy\n",
    "\n",
    "We want to compare four different numerical algorithms for solving the 1D advection equation: (a) MUSCL, (b) machine learned (ML), (c) machine learned with global stabilization, and (d) machine learned with an MC limiter. For reference, the cells below also evaluate the centered, upwind, and ML upwind-biased fluxes.\n",
    "\n",
    "Make sure to use \"diff_lrs\" for the params."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ddc64",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 25  # number of random initial conditions averaged per resolution\n",
    "\n",
    "# Column order of `mses`; the plotting cell below labels the columns in this order.\n",
    "# (name, flux, per-nx training params for learned fluxes or None, extra AdvectionFVSim kwargs)\n",
    "mse_solver_specs = [\n",
    "    ('centered', 'centered', None, {}),\n",
    "    ('upwind', 'upwind', None, {}),\n",
    "    ('muscl', 'muscl', None, {}),\n",
    "    ('learned', 'learned', training_params_list_learned, {}),\n",
    "    # The combination flux must load the weights trained for it (train_id 'combo_learned');\n",
    "    # it previously loaded the weights of the plain learned flux.\n",
    "    ('combination', 'combination_learned', training_params_list_combo, {}),\n",
    "    ('limiter', 'learnedlimiter', training_params_list_limited, {}),\n",
    "    ('invariant_learned', 'learned', training_params_list_learned, {'global_stabilization': True}),\n",
    "]\n",
    "\n",
    "mses = onp.zeros((len(nxs), len(mse_solver_specs)))\n",
    "\n",
    "def normalized_mse(traj, traj_exact):\n",
    "    \"\"\"Mean over all entries of the squared error, each time step being divided\n",
    "    by the mean square of the exact solution at that step.\n",
    "\n",
    "    traj_exact must be 2-D (indexed [time, x]); traj must have a matching shape.\n",
    "    \"\"\"\n",
    "    assert len(traj_exact.shape) == 2\n",
    "    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])\n",
    "\n",
    "\n",
    "t_inner = 0.1\n",
    "outer_steps = 11\n",
    "core_params_exact = get_core_params(flux='muscl')\n",
    "\n",
    "for i, nx in enumerate(nxs):\n",
    "    \n",
    "    print(nx)\n",
    "    \n",
    "    # Restart from the same key for every nx so all resolutions see the same initial conditions.\n",
    "    key = jax.random.PRNGKey(10)\n",
    "    \n",
    "    # The solvers do not depend on the initial condition, so the weights are loaded\n",
    "    # and the trajectory functions built once per nx instead of once per sample.\n",
    "    trajectory_fns = []\n",
    "    for solver_name, flux, training_params_list, sim_kwargs in mse_solver_specs:\n",
    "        core_params = get_core_params(flux=flux)\n",
    "        if training_params_list is None:\n",
    "            sim = AdvectionFVSim(core_params, sim_params)\n",
    "        else:\n",
    "            _, params = load_training_params(nx, sim_params, training_params_list[i], model)\n",
    "            sim = AdvectionFVSim(core_params, sim_params, model=model, params=params, **sim_kwargs)\n",
    "        inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "        trajectory_fns.append(get_trajectory_fn(inner_fn, outer_steps))\n",
    "    \n",
    "    for _ in range(N):\n",
    "        \n",
    "        ########\n",
    "        # Generate Exact Data\n",
    "        ########\n",
    "        f_init = get_initial_condition_fn(core_params_exact, init_description, key=key, **kwargs_init)\n",
    "        trajectory_exact = onp.zeros((outer_steps, nx))\n",
    "        for k in range(outer_steps):\n",
    "            trajectory_exact[k] = get_a(f_init, k * t_inner, core_params_exact, nx)\n",
    "        trajectory_exact = jnp.asarray(trajectory_exact)\n",
    "\n",
    "        # Initial conditions\n",
    "        a0 = get_a(f_init, 0, core_params_exact, nx)\n",
    "        \n",
    "        ########\n",
    "        # Accumulate the running average of each solver's error\n",
    "        ########\n",
    "        for col, trajectory_fn in enumerate(trajectory_fns):\n",
    "            mses[i, col] += normalized_mse(trajectory_fn(a0), trajectory_exact) / N\n",
    "        \n",
    "        key, _ = jax.random.split(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f9b418",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "print(mses.shape)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(9, 3))\n",
    "for side in ('top', 'right'):\n",
    "    axs.spines[side].set_visible(False)\n",
    "linewidth = 2\n",
    "\n",
    "# One entry per column of `mses`: classical solvers and the plain ML flux use solid\n",
    "# lines, the remaining ML variants use dashed lines.\n",
    "labels = [\"Centered\", \"Upwind\", \"MUSCL MC Limiter\", \"ML\", \"ML Upwind-Biased\", \"ML MC Limiter\", \"ML Invariant-Preserving\"]\n",
    "colors = [\"blue\", \"red\", \"green\", \"black\", \"red\", \"green\", \"black\"]\n",
    "markers = ['P', 'o', '^', 's', 'o', '^', 's']\n",
    "linestyles = [\"solid\", \"solid\", \"solid\", \"solid\", \"--\", \"--\", \"--\"]\n",
    "\n",
    "# Line properties shared by each curve and its legend entry.\n",
    "line_styles = [\n",
    "    dict(label=label, color=color, linewidth=linewidth, linestyle=linestyle, marker=marker, markersize=8)\n",
    "    for label, color, linestyle, marker in zip(labels, colors, linestyles, markers)\n",
    "]\n",
    "\n",
    "for k, line_style in enumerate(line_styles):\n",
    "    axs.loglog(nxs, mses[:,k], **line_style)\n",
    "\n",
    "axs.set_xticks([64, 32, 16, 8])\n",
    "axs.set_xticklabels([\"N=64\", \"N=32\", \"N=16\", \"N=8\"], fontsize=16)\n",
    "axs.set_yticks([1e-4, 1e-3, 1e-2, 1e-1, 1e0])\n",
    "axs.set_yticklabels([\"$10^{-4}$\", \"$10^{-3}$\", \"$10^{-2}$\", \"$10^{-1}$\", \"$10^0$\"], fontsize=18)\n",
    "axs.minorticks_off()\n",
    "axs.set_ylabel(\"Normalized MSE\", fontsize=18)\n",
    "\n",
    "# The legend is built from stand-alone line artists and placed to the right of the axes.\n",
    "handles = [mlines.Line2D([], [], **line_style) for line_style in line_styles]\n",
    "axs.legend(handles=handles,loc=(0.97,0.05) , prop={'size': 16}, frameon=False)\n",
    "axs.set_ylim([2.5e-5, 2e0+3e-1])\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig('ICLR_invariant_preserving_mse_vs_nx.png')\n",
    "plt.savefig('ICLR_invariant_preserving_mse_vs_nx.eps')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69aeebc0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
