{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d480f15e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/Users/nickm/thesis/icml2023paper/1d_advection/core')\n",
    "sys.path.append('/Users/nickm/thesis/icml2023paper/1d_advection/simulate')\n",
    "sys.path.append('/Users/nickm/thesis/icml2023paper/1d_advection/ml')\n",
    "\n",
    "basedir = '/Users/nickm/thesis/icml2023paper/1d_advection'\n",
    "readwritedir = '/Users/nickm/thesis/icml2023paper/1d_advection'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5086ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import config, vmap\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "import xarray\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d9d17b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from initialconditions import get_a0, get_initial_condition_fn\n",
    "from simparams import CoreParams, CoreParamsDG, SimulationParams\n",
    "from legendre import generate_legendre\n",
    "from simulations import AdvectionFVSim, AdvectionDGSim\n",
    "from trajectory import get_trajectory_fn, get_inner_fn\n",
    "from trainingutils import save_training_data\n",
    "from mlparams import TrainingParams, StencilParams\n",
    "from model import LearnedStencil\n",
    "from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, \n",
    "                           compute_losses_no_model, init_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd2e2e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs_init = {'min_num_modes': 1, 'max_num_modes': 6, 'min_k': 1, 'max_k': 4, 'amplitude_max': 1.0}\n",
    "kwargs_sim = {'name' : \"test\", 'cfl_safety' : 0.3, 'rk' : 'ssp_rk3', 'gs' : False}\n",
    "kwargs_train_FV = {'train_id': \"test\", 'batch_size' : 8, 'optimizer': 'adam', 'learning_rate' : 1e-3,  'num_epochs' : 10}\n",
    "kwargs_train_DG = {'train_id': \"test\", 'batch_size' : 8, 'optimizer': 'adam', 'learning_rate' : 1e-5, 'num_epochs' : 10}\n",
    "kwargs_stencil = {'kernel_size' : 3, 'kernel_out' : 4, 'stencil_width' : 4, 'depth' : 3, 'width' : 16}\n",
    "n_runs = 50\n",
    "t_inner_train = 0.02\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "fv_flux_baseline = 'muscl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1237af61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_fv(a, core_params):\n",
    "    plot_dg(a[...,None], core_params)\n",
    "    \n",
    "def plot_fv_trajectory(trajectory, core_params, t_inner):\n",
    "    plot_dg_trajectory(trajectory[...,None], core_params, t_inner)\n",
    "    \n",
    "def plot_dg(a, core_params, color='blue'):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    def evalf(x, a, j, dx, leg_poly):\n",
    "        x_j = dx * (0.5 + j)\n",
    "        xi = (x - x_j) / (0.5 * dx)\n",
    "        vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "        poly_eval = vmap_polyval(leg_poly, xi)  # nx, p array\n",
    "        return jnp.sum(poly_eval * a, axis=-1)\n",
    "\n",
    "    NPLOT = [2,2,5,7][p-1]\n",
    "    nx = a.shape[0]\n",
    "    dx = core_params.Lx / nx\n",
    "    xjs = jnp.arange(nx) * core_params.Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]\n",
    "    vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)\n",
    "\n",
    "    a_plot = vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p))\n",
    "    a_plot = a_plot.T.reshape(-1)\n",
    "    xs = xs.T.reshape(-1)\n",
    "    coords = {('x'): xs}\n",
    "    data = xarray.DataArray(a_plot, coords=coords)\n",
    "    data.plot(color=color)\n",
    "\n",
    "def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    NPLOT = [2,2,5,7][p-1]\n",
    "    nx = trajectory.shape[1]\n",
    "    dx = core_params.Lx / nx\n",
    "    xjs = jnp.arange(nx) * core_params.Lx / nx\n",
    "    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]\n",
    "    \n",
    "    def get_plot_repr(a):\n",
    "        def evalf(x, a, j, dx, leg_poly):\n",
    "            x_j = dx * (0.5 + j)\n",
    "            xi = (x - x_j) / (0.5 * dx)\n",
    "            vmap_polyval = vmap(jnp.polyval, (0, None), -1)\n",
    "            poly_eval = vmap_polyval(leg_poly, xi)  # nx, p array\n",
    "            return jnp.sum(poly_eval * a, axis=-1)\n",
    "\n",
    "        vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)\n",
    "        return vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p)).T\n",
    "\n",
    "    get_trajectory_plot_repr = vmap(get_plot_repr)\n",
    "    trajectory_plot = get_trajectory_plot_repr(trajectory)\n",
    "\n",
    "    outer_steps = trajectory.shape[0]\n",
    "    \n",
    "    trajectory_plot = trajectory_plot.reshape(outer_steps, -1)\n",
    "    xs = xs.T.reshape(-1)\n",
    "    coords = {\n",
    "        'x': xs,\n",
    "        'time': t_inner * jnp.arange(outer_steps)\n",
    "    }\n",
    "    xarray.DataArray(trajectory_plot, dims=[\"time\", \"x\"], coords=coords).plot(\n",
    "        col='time', col_wrap=5)\n",
    "    \n",
    "\n",
    "def get_core_params(order, flux='upwind'):\n",
    "    Lx = 1.0\n",
    "    if order == 0:\n",
    "        return CoreParams(Lx, flux)\n",
    "    else:\n",
    "        return CoreParamsDG(Lx, flux, order)\n",
    "\n",
    "def get_sim_params(name = \"test\", cfl_safety=0.3, rk='ssp_rk3', gs=False):\n",
    "    return SimulationParams(name, basedir, readwritedir, cfl_safety, rk, gs)\n",
    "\n",
    "def get_training_params(n_data, train_id=\"test\", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='sgd'):\n",
    "    return TrainingParams(n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)\n",
    "\n",
    "def get_stencil_params(kernel_size = 3, kernel_out = 4, stencil_width=4, depth = 3, width = 16):\n",
    "    return StencilParams(kernel_size, kernel_out, stencil_width, depth, width)\n",
    "\n",
    "def get_model(core_params, stencil_params):\n",
    "    if core_params.order is None:\n",
    "        p = 1\n",
    "    else:\n",
    "        p = core_params.order + 1\n",
    "    features = [stencil_params.width for _ in range(stencil_params.depth - 1)]\n",
    "    return LearnedStencil(features, stencil_params.kernel_size, stencil_params.kernel_out, stencil_params.stencil_width, p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd4f6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### test IC for FV\n",
    "nx = 6\n",
    "key = jax.random.PRNGKey(15)\n",
    "core_params = get_core_params(0)\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "plot_fv(a0, core_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da1a4d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### test simulate FV\n",
    "\n",
    "nx = 16\n",
    "key = jax.random.PRNGKey(18)\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "t_inner = 1.0\n",
    "outer_steps = 10\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "\n",
    "trajectory = trajectory_fn(a0)\n",
    "plot_fv_trajectory(trajectory, core_params, t_inner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ee70d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "### generate data FV (1 timestep at a time)\n",
    "\n",
    "nx_exact = 128\n",
    "nxs = 8, 16, 32, 64\n",
    "\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "init_fn = lambda key: get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14450246",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(12)\n",
    "save_training_data(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71841ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "### train FV model (1 timestep at a time)\n",
    "\n",
    "key = jax.random.PRNGKey(42)\n",
    "n_data = n_runs * outer_steps_train\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "\n",
    "training_params = get_training_params(n_data, **kwargs_train_FV)\n",
    "stencil_params = get_stencil_params(**kwargs_stencil)\n",
    "model = get_model(core_params, stencil_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f68b22a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx = 16\n",
    "idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)\n",
    "loss_fn = get_loss_fn(model, core_params)\n",
    "params = init_params(key, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0cf42a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses, params = train_model(model, params, training_params, key, idx_fn, batch_fn, loss_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e88b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "#plt.plot(jnp.mean(losses.reshape(10,-1),axis=0))\n",
    "plt.plot(losses)\n",
    "plt.ylim([0.0,1.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810eb23",
   "metadata": {},
   "outputs": [],
   "source": [
    "### test success of FV model\n",
    "\n",
    "# First, simulate sin wave to get a rough sense\n",
    "\n",
    "nx = 16\n",
    "key = jax.random.PRNGKey(18)\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "t_inner = 10.0\n",
    "outer_steps = 50\n",
    "core_params = get_core_params(0, flux=fv_flux_baseline)\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params, sim_params, model=model, params=params)\n",
    "inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "\n",
    "trajectory = trajectory_fn(a0)\n",
    "plot_fv_trajectory(trajectory, core_params, t_inner)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3ecc246",
   "metadata": {},
   "source": [
    "## Discontinuous Galerkin (DG) simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d940350",
   "metadata": {},
   "outputs": [],
   "source": [
    "### test IC for DG\n",
    "nx = 8\n",
    "order = 2\n",
    "key = jax.random.PRNGKey(15)\n",
    "core_params = get_core_params(order)\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "plot_dg(a0, core_params, color='blue')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "543547a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### test simulate DG\n",
    "\n",
    "nx = 16\n",
    "order = 2\n",
    "key = jax.random.PRNGKey(15)\n",
    "core_params = get_core_params(order)\n",
    "\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "\n",
    "t_inner = 10.0\n",
    "outer_steps = 10\n",
    "core_params = get_core_params(order, flux='upwind')\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionDGSim(core_params, sim_params)\n",
    "inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "\n",
    "trajectory = trajectory_fn(a0)\n",
    "plot_dg_trajectory(trajectory, core_params, t_inner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77aa6fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "### generate data DG (1 timestep at a time)\n",
    "nx_exact = 32\n",
    "nxs = 4, 6, 8, 16\n",
    "order_exact = 2\n",
    "\n",
    "core_params = get_core_params(order_exact, flux='upwind')\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionDGSim(core_params, sim_params)\n",
    "init_fn = lambda key: get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "260b1a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(12)\n",
    "save_training_data(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63abf7f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "### train DG model (1 timestep at a time)\n",
    "\n",
    "t_inner_train = 0.02\n",
    "outer_steps_train = int(1.0/t_inner_train)\n",
    "key = jax.random.PRNGKey(42)\n",
    "order_exact = 2\n",
    "n_data = n_runs * outer_steps_train\n",
    "core_params = get_core_params(order_exact, flux='upwind')\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionFVSim(core_params, sim_params)\n",
    "\n",
    "training_params = get_training_params(n_data, **kwargs_train_DG)\n",
    "stencil_params = get_stencil_params(**kwargs_stencil)\n",
    "model = get_model(core_params, stencil_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f5a9f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx = 8\n",
    "idx_fn = lambda key: get_idx_gen(key, training_params)\n",
    "batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)\n",
    "loss_fn = get_loss_fn(model, core_params)\n",
    "params = init_params(key, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "520141c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses, params = train_model(model, params, training_params, key, idx_fn, batch_fn, loss_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c799859e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(losses)\n",
    "plt.ylim([0,1.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a1d3e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "### test success of DG model\n",
    "\n",
    "# First, simulate sin wave to get a rough sense\n",
    "\n",
    "nx = 8\n",
    "key = jax.random.PRNGKey(10)\n",
    "f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key, **kwargs_init)\n",
    "a0 = get_a0(f_init, core_params, nx)\n",
    "\n",
    "t_inner = 1.0\n",
    "outer_steps = 30\n",
    "core_params = get_core_params(order_exact, flux='upwind')\n",
    "sim_params = get_sim_params(**kwargs_sim)\n",
    "sim = AdvectionDGSim(core_params, sim_params, model=model, params=params)\n",
    "inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)\n",
    "trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)\n",
    "\n",
    "trajectory = trajectory_fn(a0)\n",
    "plot_dg_trajectory(trajectory, core_params, t_inner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc95044",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
