{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and global configuration\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import optax\n",
    "import equinox as eqx\n",
    "import matplotlib.pyplot as plt\n",
    "# NOTE(review): cm, Normalize, t, MaxNLocator, re and glob appear unused in this notebook\n",
    "import matplotlib.cm as cm\n",
    "from matplotlib.colors import Normalize\n",
    "from scipy.stats import t\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "import re\n",
    "import numpy as np\n",
    "# NOTE(review): pandas (`pd`) is used further below when SAVE_RESULTS is True,\n",
    "# but it is not imported here -- make sure it is imported before enabling saving.\n",
    "# import pandas as pd\n",
    "# import seaborn as sns\n",
    "import glob\n",
    "from time import process_time\n",
    "\n",
    "from src.BTCS_Stepper import BTCS_Stepper, RandomTruncatedFourierSeries, rollout, dataloader\n",
    "from src.prdp import should_refine\n",
    "\n",
    "# plt.rcParams['figure.dpi'] = 200\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "\n",
    "# IPython autoreload: pick up edits to imported modules (e.g. src/) without restarting the kernel\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the devices available to JAX in this kernel\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_INITIAL_CONDITIONS = 200    # training set size\n",
    "N_TEST_SAMPLES = 5              # validation set size\n",
    "\n",
    "BATCH_SIZE = 25\n",
    "ic_generator = RandomTruncatedFourierSeries(domain_extent=1.0, num_modes=5)\n",
    "\n",
    "SOLVER_NAME = \"jacobi\" # \"jacobi\" or \"SD\"\n",
    "U_INIT_STRING = \"zeroinit\" # \"randinit\" or \"zeroinit\" or \"onesinit\"\n",
    "\n",
    "# Solver-specific problem size, iterative linear solver and default epoch budget\n",
    "if SOLVER_NAME == \"jacobi\":\n",
    "    N_DOF = 30\n",
    "    btcs_stepper = BTCS_Stepper(num_points=N_DOF)\n",
    "    LINSOLVER = btcs_stepper.jacobi_dynamic\n",
    "    N_EPOCHS = 100\n",
    "elif SOLVER_NAME == \"SD\":\n",
    "    N_DOF = 50\n",
    "    btcs_stepper = BTCS_Stepper(num_points=N_DOF)\n",
    "    LINSOLVER = btcs_stepper.sd_dynamic\n",
    "    N_EPOCHS = 200\n",
    "else:\n",
    "    # fail fast instead of a NameError on N_DOF / LINSOLVER further down\n",
    "    raise ValueError(f\"Invalid value for SOLVER_NAME: {SOLVER_NAME}\")\n",
    "\n",
    "# Initial guess u_init handed to the iterative linear solver\n",
    "if U_INIT_STRING == \"zeroinit\":\n",
    "    U_INIT = jnp.zeros(N_DOF)\n",
    "elif U_INIT_STRING == \"randinit\":\n",
    "    U_INIT = jax.random.normal(jax.random.PRNGKey(0), shape=(N_DOF,))\n",
    "elif U_INIT_STRING == \"onesinit\":\n",
    "    U_INIT = jnp.ones(N_DOF)\n",
    "else:\n",
    "    raise ValueError(f\"Invalid value for U_INIT_STRING: {U_INIT_STRING}\")\n",
    "\n",
    "# N_DOF interior points of [0, 1]; the two boundary points are dropped\n",
    "grid = jnp.linspace(0, 1, N_DOF+2)[1:-1]\n",
    "\n",
    "@eqx.filter_jit\n",
    "def val_loss(m, val_data):\n",
    "    \"\"\"Normalized one- and two-step prediction errors on the validation set.\n",
    "\n",
    "    The model is rolled out for two steps from each initial condition val_data[:, 0].\n",
    "    For step k, the squared L2 error against val_data[:, k] is divided by the squared\n",
    "    L2 norm of val_data[:, k] (the reference state at the same step) and averaged\n",
    "    over samples.\n",
    "\n",
    "    Args:\n",
    "        m: the model to evaluate\n",
    "        val_data: the reference data, with shape (n_samples, n_steps, n_dof), n_steps >= 3\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (2,): [normalized MSE of step 1, normalized MSE of step 2].\n",
    "    \"\"\"\n",
    "    print(\"compiling val_loss()\")\n",
    "    val_ic_set = val_data[:, 0]\n",
    "    pred_trajectories = jax.vmap(rollout(m, 2, include_init=True))(val_ic_set)\n",
    "    pred_1_errors = jnp.linalg.norm(pred_trajectories[:, 1] - val_data[:, 1], axis=1) # normed over n_dof\n",
    "    pred_2_errors = jnp.linalg.norm(pred_trajectories[:, 2] - val_data[:, 2], axis=1) # normed over n_dof\n",
    "\n",
    "    # normalize by the reference state of the SAME step (indices 1 and 2), not the step before it\n",
    "    data_1_norms = jnp.linalg.norm(val_data[:, 1], axis=1) # norm over n_dof\n",
    "    data_2_norms = jnp.linalg.norm(val_data[:, 2], axis=1) # norm over n_dof\n",
    "\n",
    "    pred_1_mse_normalized = jnp.mean(pred_1_errors**2 / data_1_norms**2, axis=0) # mean over all samples\n",
    "    pred_2_mse_normalized = jnp.mean(pred_2_errors**2 / data_2_norms**2, axis=0) # mean over all samples\n",
    "\n",
    "    return jnp.hstack((pred_1_mse_normalized, pred_2_mse_normalized))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate training and validation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Training inputs: one initial condition per PRNG sub-key (seed 1337) ----\n",
    "key = jax.random.PRNGKey(1337)\n",
    "ic_keys = jax.random.split(key, NUM_INITIAL_CONDITIONS)\n",
    "ic_funs = jax.vmap(ic_generator)(ic_keys)              # batched IC functions, one per key\n",
    "ic_set = jax.vmap(lambda ic_fn: ic_fn(grid))(ic_funs)  # each one evaluated on the interior grid\n",
    "\n",
    "# ---- Training targets: 2-step reference rollout of the BTCS stepper, initial state included ----\n",
    "train_set = jax.vmap(\n",
    "    rollout(btcs_stepper, 2, include_init=True)\n",
    ")(ic_set)\n",
    "\n",
    "# ---- Validation inputs: same procedure with an independent seed (1338) ----\n",
    "key = jax.random.PRNGKey(1338)\n",
    "test_ic_keys = jax.random.split(key, N_TEST_SAMPLES)\n",
    "test_ic_funs = jax.vmap(ic_generator)(test_ic_keys)\n",
    "test_ic_set = jax.vmap(lambda ic_fn: ic_fn(grid))(test_ic_funs)\n",
    "\n",
    "# ---- Validation targets: 100-step reference rollout (val_loss only reads the first 3 states) ----\n",
    "val_data_trjs = jax.vmap(\n",
    "    rollout(btcs_stepper, 100, include_init=True)\n",
    ")(test_ic_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear-solver residual history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_residuum_hist(state, n_inner):\n",
    "    \"\"\"Relative residual norm of the inner linear solver at every iteration.\n",
    "\n",
    "    Args:\n",
    "        state: vector of shape (n_dof,) handed to the solver; presumably the right-hand\n",
    "            side of the BTCS system (its norm is used as the normalizer) -- TODO confirm\n",
    "        n_inner: number of solver iterations to record\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n_inner+1,): residual norm per iteration divided by ||state||.\n",
    "    \"\"\"\n",
    "    res_2 = btcs_stepper.residuum_history(state, SOLVER_NAME, n_inner) # (n_iter+1, n_dof)\n",
    "    rel_residuum_hist = jnp.linalg.norm(res_2, axis=1) / jnp.linalg.norm(state)\n",
    "    return rel_residuum_hist # (n_iter+1,)\n",
    "\n",
    "N_RES_ITER = 50  # number of solver iterations shown in the residual plot\n",
    "\n",
    "# Residual histories for the step-1 states of the training set\n",
    "# (alternative: use the initial conditions instead)\n",
    "# res_hist_all = jax.vmap(relative_residuum_hist, in_axes=(0, None))(ic_set, N_RES_ITER) # (n_samples, n_iter+1)\n",
    "pred_1_set = train_set[:, 1]\n",
    "res_hist_all = jax.vmap(relative_residuum_hist, in_axes=(0, None))(pred_1_set, N_RES_ITER) # (n_samples, n_iter+1)\n",
    "res_hist_mean = jnp.mean(res_hist_all, axis=0)\n",
    "res_hist_std = jnp.std(res_hist_all, axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3,3))\n",
    "ax.plot(res_hist_mean, label=\"mean\")\n",
    "ax.fill_between(range(res_hist_mean.shape[0]), res_hist_mean - res_hist_std, res_hist_mean + res_hist_std, alpha=0.2, label=\"std\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"# iterations\")\n",
    "ax.set_ylabel(\"Primal relative residual\")\n",
    "ax.set_title(f\"Heat 1D, Residuals, solver={SOLVER_NAME}\\n Average over {pred_1_set.shape[0]} step-1 training states\")\n",
    "ax.legend()  # the mean/std labels above are only displayed through a legend\n",
    "\n",
    "ax.minorticks_on()\n",
    "ax.grid(which='both', axis='x', linestyle='--', linewidth=0.5)\n",
    "ax.grid(which='major', axis='both', linestyle='-', linewidth=1.0)\n",
    "\n",
    "# fig.savefig(f\"figures/heat_1d__primal_residuum__{SOLVER_NAME}.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a custom VJP for the linear solver (adjoint solve instead of unrolling)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@eqx.filter_custom_vjp\n",
    "def linsolver(state, n_iter, u_init):\n",
    "    \"\"\"Run LINSOLVER forward; the backward pass solves an adjoint system with the same\n",
    "    solver instead of differentiating through its iterations (see linsolver_bwd).\"\"\"\n",
    "    return LINSOLVER(state, n_iter, u_init)\n",
    "\n",
    "@linsolver.def_fwd\n",
    "def linsolver_fwd(perturbed, state, n_iter, u_init):\n",
    "    # Same computation as the primal; nothing has to be kept for the backward pass.\n",
    "    return LINSOLVER(state, n_iter, u_init), None\n",
    "\n",
    "@linsolver.def_bwd\n",
    "def linsolver_bwd(res, g, perturbed, state, n_iter, u_init):\n",
    "    # Cotangent w.r.t. `state`: solve A^T v = g with the forward solver.\n",
    "    # NOTE(review): reusing the forward solver assumes A is symmetric (A^T = A) -- confirm for BTCS_Stepper.\n",
    "    print(\"using custom vjp\")\n",
    "    return LINSOLVER(g, n_iter, u_init)  # g: gradient of the loss w.r.t. the solver output, shape (n_dof,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "SOLVER_JACOBIANS = \"implicit\" # \"unrolled\" or \"implicit\"\n",
    "\n",
    "# Choose how gradients flow through the inner linear solve\n",
    "match SOLVER_JACOBIANS:\n",
    "    case \"unrolled\":\n",
    "        inner_solve = LINSOLVER  # plain autodiff through the solver's iterations\n",
    "    case \"implicit\":\n",
    "        inner_solve = linsolver  # custom VJP defined above (adjoint solve)\n",
    "    case _:\n",
    "        raise ValueError(f\"Invalid value for SOLVER_JACOBIANS: {SOLVER_JACOBIANS}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Outer loop: optimizer, loss and update step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer: Adam whose learning rate starts at 1e-3 and is multiplied by 0.94 every\n",
    "# 100 optimizer steps (applied smoothly, since staircase defaults to False)\n",
    "optimizer = optax.adam(optax.exponential_decay(init_value=1e-3, transition_steps=100, decay_rate=0.94))\n",
    "\n",
    "# Loss Function\n",
    "@eqx.filter_jit\n",
    "def loss_fn(model, data, inner_iterations):\n",
    "    \"\"\"Two-step training loss: the model predicts step 1, the inner linear solver then gives step 2.\n",
    "\n",
    "    Args:\n",
    "        model: network mapping one state of shape (n_dof,) to the step-1 prediction\n",
    "        data: batch of reference trajectories, shape (batch_size, n_steps, n_dof) with n_steps >= 3\n",
    "        inner_iterations: number of iterations of the inner linear solver; the training loops\n",
    "            pass a Python int, so filter_jit treats it as static and compiles once per value\n",
    "\n",
    "    Returns:\n",
    "        Scalar MSE between the step-2 prediction and data[:, 2], averaged over batch and space.\n",
    "    \"\"\"\n",
    "    print(\"compiling loss_fn\")\n",
    "    ic = data[:, 0]\n",
    "    target = data[:, 2]\n",
    "    prediction_1 = jax.vmap(model)(ic) # batched forward pass # (batch_size, n_dof)\n",
    "    # The step-1 prediction is the input state of the inner solve; U_INIT is its initial guess\n",
    "    prediction_2 = jax.vmap(inner_solve, in_axes=(0, None, None))(prediction_1, inner_iterations, U_INIT)\n",
    "    return jnp.mean((prediction_2 - target)**2) # MSE over batches as well as space\n",
    "\n",
    "@eqx.filter_jit\n",
    "def update_fn(model, state, data, inner_iterations):\n",
    "    \"\"\"Take one optimizer step on a batch and return (new_model, new_opt_state, loss).\"\"\"\n",
    "    print(\"compiling update_fn\")\n",
    "    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, data, inner_iterations)\n",
    "    updates, new_state = optimizer.update(grad, state, model)\n",
    "    new_model = eqx.apply_updates(model, updates)\n",
    "    return new_model, new_state, loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline: fixed number of inner iterations, for every seed in `SEED_LIST` and every value in `N_INNER_LIST`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE_RESULTS = False\n",
    "# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]\n",
    "SEED_LIST = [1]\n",
    "N_INNER_LIST = [1,2,3,4,5,6,10,15,20,25]\n",
    "\n",
    "N_EPOCHS = 100  # NOTE: overrides the solver-specific N_EPOCHS from the configuration cell\n",
    "print(f\"SOLVER NAME: {SOLVER_NAME} , SOLVER JACOBIAN = {SOLVER_JACOBIANS}, U_INIT = {U_INIT[:2]}...{U_INIT[-2:]}, N_DOF = {N_DOF}\")\n",
    "\n",
    "for seed_count, seed in enumerate(SEED_LIST):\n",
    "    \n",
    "    print(f\"Training with seed {seed} ({seed_count+1} of {len(SEED_LIST)})\")\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    key, model_init_key = jax.random.split(key)\n",
    "    \n",
    "    # init metrics: one entry per value in N_INNER_LIST\n",
    "    losses_all_n = []\n",
    "    errors_all_n = []\n",
    "    time_all_n = []\n",
    "    \n",
    "    # Loop over n_inner\n",
    "    for n_inner in N_INNER_LIST:\n",
    "        print(f\"\\nTraining with {n_inner} inner iterations\\n\")\n",
    "        \n",
    "        # initialize model (same init key for every n_inner, so the runs differ only in n_inner)\n",
    "        model_MLP = eqx.nn.MLP(\n",
    "            in_size=N_DOF, out_size=N_DOF, \n",
    "            width_size=64, depth=3, \n",
    "            activation=jax.nn.relu, \n",
    "            key=model_init_key)\n",
    "        \n",
    "        # initialize optimizer\n",
    "        opt_state = optimizer.init(eqx.filter(model_MLP, eqx.is_array))\n",
    "\n",
    "        # init metrics (entry 0 is the untrained model)\n",
    "        loss_history = [loss_fn(model_MLP, train_set, n_inner)]\n",
    "        error_history = [val_loss(model_MLP, val_data_trjs)]\n",
    "        time_history = [0.0]  # cumulative CPU time (process_time) spent in training epochs\n",
    "        train_time = 0.0\n",
    "\n",
    "        # Training Loop\n",
    "        key, shuffle_key = jax.random.split(key)\n",
    "        \n",
    "        for epoch in range(N_EPOCHS):\n",
    "            shuffle_key, subkey = jax.random.split(shuffle_key)\n",
    "            loss_mini_batch = []\n",
    "            epoch_start = process_time()  # the first epoch also includes compiling update_fn for this n_inner\n",
    "            for batch in dataloader(train_set, key=subkey, batch_size=BATCH_SIZE):\n",
    "                model_MLP, opt_state, loss = update_fn(model_MLP, opt_state, batch, n_inner)\n",
    "                loss_mini_batch.append(loss)\n",
    "            \n",
    "            # np.mean copies the losses to the host, so the epoch's device work is done before the timer stops\n",
    "            loss_history.append(np.mean(loss_mini_batch))\n",
    "            train_time += process_time() - epoch_start\n",
    "            time_history.append(train_time)\n",
    "            error_history.append(val_loss(model_MLP, val_data_trjs))\n",
    "            \n",
    "            print(f\"Epoch {epoch+1}/{N_EPOCHS}, loss: {loss_history[-1]}, rel error: {error_history[-1]}\")\n",
    "        \n",
    "        losses_all_n.append(loss_history)\n",
    "        errors_all_n.append(np.array(error_history))\n",
    "        time_all_n.append(time_history)\n",
    "    \n",
    "    losses_all_n = np.array(losses_all_n) # shape (len(N_INNER_LIST), N_EPOCHS+1)\n",
    "    errors_all_n = np.array(errors_all_n) # shape (len(N_INNER_LIST), N_EPOCHS+1, 2)\n",
    "\n",
    "    # save results\n",
    "    if SAVE_RESULTS:\n",
    "        # pandas is not imported in the first cell, so import it where it is needed\n",
    "        import pandas as pd\n",
    "        from pathlib import Path\n",
    "        df = pd.DataFrame({\n",
    "            \"max_iter\": N_INNER_LIST,\n",
    "            \"losses\": list(losses_all_n),\n",
    "            \"1-step errors\": list(errors_all_n[:,:,0]),\n",
    "            \"2-step errors\": list(errors_all_n[:,:,1]),\n",
    "            \"time\": time_all_n,\n",
    "            \"seed\": seed,\n",
    "        })\n",
    "        file_name = f\"results/heat_1d_sep29_{SOLVER_NAME}_{SOLVER_JACOBIANS}_time/maxiter_constant__seed_{seed}.pkl\"\n",
    "        Path(file_name).parent.mkdir(parents=True, exist_ok=True)  # to_pickle does not create missing directories\n",
    "        df.to_pickle(file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PRDP: adaptively increase the number of inner iterations, for every seed in `SEED_LIST`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_MIN, N_STEP = 1,1  # initial number of inner iterations and increment applied on each refinement\n",
    "\n",
    "# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]\n",
    "SEED_LIST = [1]\n",
    "N_EPOCHS = 100  # NOTE: overrides the solver-specific N_EPOCHS from the configuration cell\n",
    "SAVE_RESULTS = False\n",
    "print(f\"PRDP: SOLVER_NAME: {SOLVER_NAME} , SOLVER_JACOBIAN = {SOLVER_JACOBIANS}, U_INIT = {U_INIT[:2]}...{U_INIT[-2:]}\")\n",
    "\n",
    "for seed_count, seed in enumerate(SEED_LIST):\n",
    "    print(f\"Training with seed {seed} ({seed_count+1}/{len(SEED_LIST)})\")\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    \n",
    "    # init model to be trained\n",
    "    key, model_init_key = jax.random.split(key)\n",
    "    model_mlp_prdp = eqx.nn.MLP(\n",
    "        in_size=N_DOF, out_size=N_DOF, \n",
    "        width_size=64, depth=3, \n",
    "        activation=jax.nn.relu, \n",
    "        key=model_init_key\n",
    "    )\n",
    "    \n",
    "    # initialize optimizer\n",
    "    opt_state = optimizer.init(eqx.filter(model_mlp_prdp, eqx.is_array))\n",
    "\n",
    "    # initialize metrics (entry 0 is the untrained model)\n",
    "    n_inner_tracker = N_MIN\n",
    "    loss_hist_prdp = [loss_fn(model_mlp_prdp, train_set, n_inner_tracker)]\n",
    "    error_hist_prdp = [val_loss(model_mlp_prdp, val_data_trjs)]\n",
    "    n_inner_hist_prdp = [np.nan] # no value at zeroth epoch, but need same list length as loss_hist\n",
    "    time_hist_prdp = [0.0]  # cumulative CPU time (process_time) spent in training epochs\n",
    "    train_time = 0.0\n",
    "\n",
    "    # initialize PRDP's Nmax checkpoint error\n",
    "    should_refine.error_checkpoint = 100\n",
    "\n",
    "    # Training Loop\n",
    "    key, shuffle_key = jax.random.split(key)\n",
    "    for epoch in range(N_EPOCHS):\n",
    "        shuffle_key, subkey = jax.random.split(shuffle_key)\n",
    "        loss_mini_batch = []\n",
    "        epoch_start = process_time()  # includes compiling update_fn whenever n_inner_tracker takes a new value\n",
    "        for batch in dataloader(train_set, key=subkey, batch_size=BATCH_SIZE):\n",
    "            model_mlp_prdp, opt_state, loss = update_fn(model_mlp_prdp, opt_state, batch, \n",
    "                                                        n_inner_tracker)\n",
    "            loss_mini_batch.append(loss)\n",
    "        \n",
    "        # np.mean copies the losses to the host, so the epoch's device work is done before the timer stops\n",
    "        loss_hist_prdp.append(np.mean(loss_mini_batch))\n",
    "        train_time += process_time() - epoch_start\n",
    "        time_hist_prdp.append(train_time)\n",
    "        error_hist_prdp.append(val_loss(model_mlp_prdp, val_data_trjs))\n",
    "        n_inner_hist_prdp.append(n_inner_tracker)\n",
    "        \n",
    "        print(f\"Epoch {epoch+1}/{N_EPOCHS}, n_inner: {n_inner_tracker}, loss: {loss_hist_prdp[-1]}, error: {error_hist_prdp[-1]}\")\n",
    "\n",
    "        # PRDP: increase the number of inner iterations when should_refine says so\n",
    "        # NOTE(review): 0.98, 0.9 and 8 are positional thresholds of src.prdp.should_refine -- see there for their meaning\n",
    "        if should_refine(np.array(error_hist_prdp)[:, 1],  # [:,1] is the two-step error history\n",
    "                         0.98, 0.9, 8): \n",
    "            n_inner_tracker += N_STEP\n",
    "    \n",
    "    # SAVE\n",
    "    loss_hist_prdp = np.array(loss_hist_prdp)\n",
    "    error_hist_prdp = np.array(error_hist_prdp)\n",
    "\n",
    "    if SAVE_RESULTS:\n",
    "        # pandas is not imported in the first cell, so import it where it is needed\n",
    "        import pandas as pd\n",
    "        from pathlib import Path\n",
    "        df = pd.DataFrame({\n",
    "            \"losses\": [loss_hist_prdp],\n",
    "            \"1-step errors\": [error_hist_prdp[:,0]],\n",
    "            \"2-step errors\": [error_hist_prdp[:,1]],\n",
    "            \"n_inner\": [n_inner_hist_prdp],\n",
    "            \"time\": [time_hist_prdp],\n",
    "            \"max_iter\": \"PRDP\",\n",
    "            \"auto_using\": \"two-step-error\",\n",
    "            \"seed\": seed,\n",
    "        })\n",
    "        file_name = f\"results/heat_1d_sep29_{SOLVER_NAME}_{SOLVER_JACOBIANS}_time/maxiter_auto__seed_{seed}.pkl\"\n",
    "        Path(file_name).parent.mkdir(parents=True, exist_ok=True)  # to_pickle does not create missing directories\n",
    "        df.to_pickle(file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-step validation error of the last PRDP seed trained above\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(error_hist_prdp[:, 1])\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"epoch (0 = before training)\")\n",
    "ax.set_ylabel(\"two-step normalized MSE (validation)\")\n",
    "ax.set_title(f\"Heat 1D, PRDP, solver={SOLVER_NAME}, seed={seed}\")\n",
    "ax.grid(True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
