{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import optax\n",
    "import pandas as pd\n",
    "import pdequinox as pdeqx\n",
    "from matplotlib.animation import FuncAnimation\n",
    "\n",
    "import src.implicax.implicax as implicax\n",
    "from src.BTCS_Stepper import dataloader\n",
    "from src.implicax.implicax.utilities import rollout\n",
    "from src.prdp import should_refine, numpy_ewma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the devices JAX can see (the last expression is displayed as the cell output)\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of grid points per spatial dimension of the reference fields:\n",
    "# the initial conditions below are drawn with shape (1, N_REF, N_REF)\n",
    "N_REF = 97"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate random initial conditions - smoothed, normalized, divergence-free"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_SAMPLES = 205\n",
    "\n",
    "\n",
    "def _draw_noise_pair(sample_idx):\n",
    "    \"\"\"Draw white-noise x- and y-velocity fields, seeded by the sample index.\"\"\"\n",
    "    key_x, key_y = jax.random.split(jax.random.PRNGKey(sample_idx))\n",
    "    field_shape = (1, N_REF, N_REF)\n",
    "    return jax.random.normal(key_x, field_shape), jax.random.normal(key_y, field_shape)\n",
    "\n",
    "\n",
    "def _standardize(fields):\n",
    "    \"\"\"Remove the spatial mean of every sample, then rescale to unit spatial std.\"\"\"\n",
    "    centered = fields - jnp.mean(fields, axis=(-1,-2), keepdims=True)\n",
    "    return centered / jnp.std(centered, axis=(-1,-2), keepdims=True)\n",
    "\n",
    "\n",
    "noise_x, noise_y = zip(*(_draw_noise_pair(i) for i in range(NUM_SAMPLES)))\n",
    "ic_x_set = jnp.stack(noise_x)\n",
    "ic_y_set = jnp.stack(noise_y)\n",
    "\n",
    "# Smooth the noise by pushing it through a heat solver (= low-pass filter),\n",
    "# then make every sample zero-mean with a magnitude of around 1\n",
    "heat_stepper = implicax.Heat2d(1.0, N_REF, 1.0, nu=3e-3)\n",
    "smooth = jax.vmap(heat_stepper)\n",
    "ic_x_set = _standardize(smooth(ic_x_set))\n",
    "ic_y_set = _standardize(smooth(ic_y_set))\n",
    "\n",
    "# The NS simulator state has three channels (velocity-x, velocity-y, pressure);\n",
    "# the pressure channel starts at zero.\n",
    "ic_pressure_set = jnp.zeros_like(ic_x_set)\n",
    "ic_set = jnp.concatenate((ic_x_set, ic_y_set, ic_pressure_set), axis=-3)\n",
    "\n",
    "# Instantiate the NS simulator and use it to make the ICs divergence-free\n",
    "# (the original note says Re=1000 - TODO confirm against nu=1e-4)\n",
    "ns_simulator = implicax.NavierStokes(1.0, N_REF, 0.1, nu=1e-4, maxiter_picard=1)\n",
    "ic_set = jax.vmap(ns_simulator.make_incompressible)(ic_set)\n",
    "\n",
    "assert ic_set.shape == (NUM_SAMPLES, 3, N_REF, N_REF)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TIMESTEPS = 5\n",
    "# Build the rollout function once, then batch it over all initial conditions\n",
    "rollout_fn = rollout(ns_simulator, NUM_TIMESTEPS, include_init=True)\n",
    "trj_set = jax.vmap(rollout_fn)(ic_set)\n",
    "\n",
    "# sanity check: the initial state is included, hence NUM_TIMESTEPS + 1 snapshots\n",
    "expected_trj_shape = (NUM_SAMPLES, NUM_TIMESTEPS + 1, 3, N_REF, N_REF)\n",
    "assert trj_set.shape == expected_trj_shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downsample for the source (coarse) grid: keep every second point, starting at index 1\n",
    "trj_set_source = trj_set[:, :, :, 1::2, 1::2]\n",
    "\n",
    "# A `1::2` slice of an axis of length N keeps exactly N // 2 points. Integer\n",
    "# division keeps the expected shape integral and is also right for even N_REF.\n",
    "NDOF_SOURCE = N_REF // 2\n",
    "assert trj_set_source.shape == (NUM_SAMPLES, NUM_TIMESTEPS + 1, 3, NDOF_SOURCE, NDOF_SOURCE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train:Validation split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first NUM_TRAIN trajectories are used for training, the rest for validation\n",
    "NUM_TRAIN = 200\n",
    "\n",
    "# Guard against an empty split (an empty validation set would only show up as NaN metrics)\n",
    "assert 0 < NUM_TRAIN < trj_set_source.shape[0], \"need at least one training and one validation trajectory\"\n",
    "\n",
    "train_set = trj_set_source[:NUM_TRAIN]\n",
    "val_set = trj_set_source[NUM_TRAIN:]\n",
    "assert train_set.shape[0] + val_set.shape[0] == trj_set_source.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning NS"
   ]
  },
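  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss and validation loss for correction learning\n",
    "\n",
    "$$\n",
    "u^{[0]} \\xrightarrow{\\mathcal{P}_s} \\tilde{u}^{[1]} \\xrightarrow{\\mathbb{C}_{\\theta}} u^{[1]} \\xrightarrow{\\mathcal{P}_s} \\tilde{u}^{[2]} \\xrightarrow{\\mathbb{C}_{\\theta}} u^{[2]}\n",
    "$$\n",
    "\n",
    "The corrector is applied additively: $u^{[t]} = \\tilde{u}^{[t]} + \\mathbb{C}_{\\theta}(\\tilde{u}^{[t]})$.\n",
    "\n",
    "The source solver $\\mathcal{P}_s$ is coarser than the reference solver $\\mathcal{P}_r$, as it uses half the number of points per dimension => the corrector $\\mathbb{C}_{\\theta}$ learns to resolve the fine physics.\n"
   ]
  },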
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NS solver on the downsampled grid (NDOF_SOURCE points per dimension).\n",
    "# It is the default `coarse_sim_1` / `coarse_sim_2` of loss_fn, val_loss and update_fn.\n",
    "# presumably maxiter_linsolve=120 with restart=8 acts as the fully converged linear solve - TODO confirm\n",
    "ns_sim_halfspace = implicax.NavierStokes(1.0, NDOF_SOURCE, 0.1, nu=1e-4, maxiter_picard=1, restart=8, maxiter_linsolve=120)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss fn for correction learning\n",
    "@eqx.filter_jit\n",
    "def loss_fn(model, data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):\n",
    "    \"\"\"\n",
    "    Two-step unrolled loss for correction learning.\n",
    "\n",
    "    Each step advances the state with a coarse simulator and then adds the\n",
    "    model output as a correction. The loss is the sum of the mean squared\n",
    "    errors of the corrected states after the first and the second step.\n",
    "\n",
    "    Args:\n",
    "        model: the corrector, applied per sample via jax.vmap\n",
    "        data: trajectories; data[:, k] is the batch of states at time step k (k = 0, 1, 2 are read)\n",
    "        coarse_sim_1: simulator used for the first step\n",
    "        coarse_sim_2: simulator used for the second step\n",
    "    \"\"\"\n",
    "    print(\"compiling loss_fn\")\n",
    "\n",
    "    def corrected_step(simulator, state):\n",
    "        # the model only predicts the correction, which is then added to the coarse solution\n",
    "        coarse = jax.vmap(simulator)(state)\n",
    "        return coarse + jax.vmap(model)(coarse)\n",
    "\n",
    "    state_1 = corrected_step(coarse_sim_1, data[:,0])\n",
    "    state_2 = corrected_step(coarse_sim_2, state_1)\n",
    "\n",
    "    mse_step_1 = jnp.mean((data[:,1] - state_1)**2)\n",
    "    mse_step_2 = jnp.mean((data[:,2] - state_2)**2)\n",
    "    return mse_step_2 + mse_step_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "@eqx.filter_jit\n",
    "def val_loss(model, test_data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):\n",
    "    \"\"\"Normalized MSE of a 5-step corrected rollout, reported after 1, 2 and 5 steps.\n",
    "\n",
    "    Args:\n",
    "        model: the model to evaluate\n",
    "        test_data: the test data, with shape (n_samples, n_steps, n_channels, N, N);\n",
    "            time steps 0, 1, 2 and 5 are read\n",
    "        coarse_sim_1: simulator used for steps 1, 3 and 5\n",
    "        coarse_sim_2: simulator used for steps 2 and 4\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (3, 3): rows are the 1-, 2- and 5-step errors, columns the channels.\n",
    "    \"\"\"\n",
    "    print(\"compiling val_loss function\")\n",
    "    report_steps = (1, 2, 5)\n",
    "    simulators = (coarse_sim_1, coarse_sim_2)\n",
    "\n",
    "    state = test_data[:, 0]\n",
    "    corrected = {}\n",
    "    for step in range(1, 6):\n",
    "        # odd steps use coarse_sim_1, even steps use coarse_sim_2\n",
    "        coarse = jax.vmap(simulators[(step - 1) % 2])(state)\n",
    "        state = jax.vmap(model)(coarse) + coarse\n",
    "        corrected[step] = state\n",
    "\n",
    "    def normed_mse(pred, target):\n",
    "        # squared error norm over space, relative to the squared norm of the target,\n",
    "        # averaged over the samples -> one value per channel\n",
    "        errors = jnp.linalg.norm(pred - target, axis=(-2, -1))\n",
    "        norms = jnp.linalg.norm(target, axis=(-2, -1))\n",
    "        return jnp.mean((errors**2 / norms**2), axis=0)\n",
    "\n",
    "    return jnp.vstack([normed_mse(corrected[k], test_data[:, k]) for k in report_steps]) # shape (3, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EPOCHS = 100\n",
    "# optimizer = optax.adam(1e-3)\n",
    "# optimizer = optax.adam(optax.exponential_decay(1e-3, 100, 0.85))\n",
    "# Adam with a cosine-decayed learning rate, starting at 1e-3 and decaying over N_EPOCHS*8 optimizer steps.\n",
    "# presumably 8 = 200 training samples / batch size 25, i.e. the steps per epoch - TODO confirm against dataloader\n",
    "optimizer = optax.adam(optax.cosine_decay_schedule(1e-3, N_EPOCHS*8, 0.1))\n",
    "\n",
    "@eqx.filter_jit\n",
    "def update_fn(model, opt_state, data, coarse_sim_1=ns_sim_halfspace, coarse_sim_2=ns_sim_halfspace):\n",
    "    \"\"\"Perform one optimizer step on a batch.\n",
    "\n",
    "    Args:\n",
    "        model: the corrector to update\n",
    "        opt_state: state of the module-level `optimizer`\n",
    "        data: batch of trajectories, forwarded to loss_fn\n",
    "        coarse_sim_1, coarse_sim_2: simulators forwarded to loss_fn\n",
    "\n",
    "    Returns:\n",
    "        (updated model, updated optimizer state, loss evaluated before the update)\n",
    "    \"\"\"\n",
    "    print(\"compiling update_fn\")\n",
    "    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, data, coarse_sim_1, coarse_sim_2)\n",
    "    updates, new_state = optimizer.update(grad, opt_state, model)\n",
    "    new_model = eqx.apply_updates(model, updates)\n",
    "    return new_model, new_state, loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning (test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize model\n",
    "model_resnet = pdeqx.arch.ClassicResNet(\n",
    "    num_spatial_dims=2, in_channels=3, out_channels=3,\n",
    "    hidden_channels=64, num_blocks=3,\n",
    "    key=jax.random.PRNGKey(92),\n",
    ")\n",
    "\n",
    "# initialize optimizer on the array leaves of the model\n",
    "opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))\n",
    "\n",
    "# initialize metrics: entry 0 belongs to the untrained model\n",
    "loss_hist = [loss_fn(model_resnet, train_set)]\n",
    "rel_error_hist = [val_loss(model_resnet, val_set)]\n",
    "\n",
    "# training loop\n",
    "BATCH_SIZE = 25\n",
    "shuffle_key = jax.random.PRNGKey(42)\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "    # fresh shuffling key for every epoch\n",
    "    shuffle_key, epoch_key = jax.random.split(shuffle_key)\n",
    "    batch_losses = []\n",
    "    for minibatch in dataloader(train_set, batch_size=BATCH_SIZE, key=epoch_key):\n",
    "        model_resnet, opt_state, batch_loss = update_fn(model_resnet, opt_state, minibatch)\n",
    "        batch_losses.append(batch_loss)\n",
    "\n",
    "    # one entry per epoch: minibatch-averaged training loss and validation errors\n",
    "    loss_hist.append(np.mean(np.array(batch_losses)))\n",
    "    rel_error_hist.append(val_loss(model_resnet, val_set))\n",
    "    print(f\"Epoch {epoch+1}/{N_EPOCHS}: Loss = {loss_hist[-1]}, rel. error 1-step = {rel_error_hist[-1][0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: one entry for the untrained model plus one per epoch\n",
    "len(rel_error_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "# eqx.tree_serialise_leaves(f\"results/saved_models/ns_spatial_correction_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.eqx\", model_resnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rel_error_hist = np.array(rel_error_hist)\n",
    "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
    "\n",
    "# panel 0: training loss\n",
    "loss_ax = axs[0]\n",
    "loss_ax.set_title(\"Training loss (minibatch avg.)\")\n",
    "loss_ax.plot(loss_hist)\n",
    "loss_ax.set_yscale(\"log\")\n",
    "loss_ax.set_ylim(1e-3, 1)\n",
    "\n",
    "# panels 1-3: validation nMSE after 1, 2 and 5 steps (rows 0, 1, 2 of val_loss's output)\n",
    "for row, (err_ax, horizon) in enumerate(zip(axs[1:], (1, 2, 5))):\n",
    "    err_ax.set_title(f\"{horizon}-step nMSE on val. set\")\n",
    "    for channel, channel_name in enumerate((\"velocity-x\", \"velocity-y\")):\n",
    "        err_ax.plot(rel_error_hist[:, row, channel], label=channel_name)\n",
    "    err_ax.legend()\n",
    "    err_ax.set_ylim(1e-3, 1)\n",
    "\n",
    "for ax in axs:\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.grid(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Different constant inner iterations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds to run; each seed is turned into a PRNG key for model init and batch shuffling.\n",
    "# The commented-out line below holds a longer list of seeds.\n",
    "SEED_LIST = [1]\n",
    "# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep over a fixed linear-solver iteration budget: for every seed, one model is\n",
    "# trained per entry of MAX_ITER_LIST, always starting from the same initialization.\n",
    "MAX_ITER_LIST    = [1,2,3,4,5,6,8,10,15,20,40,50,60, 80, 120]\n",
    "RESTART = 8\n",
    "SAVE_RESULTS = False\n",
    "N_EPOCHS = 100\n",
    "\n",
    "for seed_count, seed in enumerate(SEED_LIST):\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    key, model_init_key = jax.random.split(key)\n",
    "    \n",
    "    # metrics of all runs of this seed (one entry per max_iter)\n",
    "    losses_all_n = []\n",
    "    errors_all_n = []\n",
    "    \n",
    "    for maxiter_count, max_iter in enumerate(MAX_ITER_LIST):\n",
    "        \n",
    "        print(f\"\\nTRAINING WITH: seed={seed} ({seed_count+1} of {len(SEED_LIST)}), max_iter={max_iter} ({maxiter_count+1} of {len(MAX_ITER_LIST)}), restart={RESTART}\\n\")\n",
    "        \n",
    "        # initialize the incompletely converged solver\n",
    "        ns_sim_incomplete_halfspace = implicax.NavierStokes(1.0, NDOF_SOURCE, 0.1, nu=1e-4, maxiter_picard=1,\n",
    "                                                            maxiter_linsolve=max_iter, restart=RESTART)\n",
    "\n",
    "        # initialize model (same key for every max_iter -> identical initial weights)\n",
    "        model_resnet = pdeqx.arch.ClassicResNet(\n",
    "            2, 3, 3,\n",
    "            hidden_channels=64, \n",
    "            num_blocks=3,\n",
    "            key=model_init_key,\n",
    "        )\n",
    "\n",
    "        # initialize optimizer\n",
    "        opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))\n",
    "\n",
    "        # initialize metrics\n",
    "        # NOTE(review): val_loss runs with its default simulators (ns_sim_halfspace), not the\n",
    "        # incomplete solver used for training - presumably intentional, confirm.\n",
    "        loss_hist = [loss_fn(model_resnet, train_set, coarse_sim_2=ns_sim_incomplete_halfspace, coarse_sim_1=ns_sim_incomplete_halfspace)]\n",
    "        rel_error_hist = [val_loss(model_resnet, val_set)]\n",
    "\n",
    "        # training loop\n",
    "        BATCH_SIZE = 25\n",
    "        key, shuffle_key = jax.random.split(key)\n",
    "        for epoch in range(N_EPOCHS):\n",
    "            shuffle_key, subkey = jax.random.split(shuffle_key)\n",
    "            loss_mini_batch = []\n",
    "            for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=subkey):\n",
    "                model_resnet, opt_state, loss = update_fn(model_resnet, opt_state, batch, \n",
    "                                                          coarse_sim_2=ns_sim_incomplete_halfspace, \n",
    "                                                          coarse_sim_1=ns_sim_incomplete_halfspace)\n",
    "                loss_mini_batch.append(loss)\n",
    "            \n",
    "            loss_hist.append(np.mean(np.array(loss_mini_batch)))\n",
    "            rel_error_hist.append(val_loss(model_resnet, val_set))\n",
    "            print(f\"Epoch {epoch+1}/{N_EPOCHS}: Loss = {loss_hist[-1]}, 1-step rel. error = {rel_error_hist[-1][0]}\")\n",
    "\n",
    "        losses_all_n.append(loss_hist)\n",
    "        errors_all_n.append(np.array(rel_error_hist))\n",
    "\n",
    "\n",
    "    # save results\n",
    "    if SAVE_RESULTS:\n",
    "        losses_all_n = np.array(losses_all_n)\n",
    "        errors_all_n = np.array(errors_all_n) # shape (len(MAX_ITER_LIST), N_EPOCHS+1, 3, 3)\n",
    "        df = pd.DataFrame({\n",
    "            \"max_restart\": [RESTART] * len(MAX_ITER_LIST),\n",
    "            \"max_iter\": MAX_ITER_LIST,\n",
    "            \"losses\": list(losses_all_n),\n",
    "            \"1-step errors\": list(errors_all_n[:,:,0]),\n",
    "            \"2-step errors\": list(errors_all_n[:,:,1]),\n",
    "            \"5-step errors\": list(errors_all_n[:,:,2]),\n",
    "            \"seed\": seed\n",
    "        })\n",
    "        file_name = f\"results/navier_stokes/maxiter_constant__seed_{seed}.pkl\"\n",
    "        # create the output folder first, so a finished sweep is not lost to a missing directory\n",
    "        Path(file_name).parent.mkdir(parents=True, exist_ok=True)\n",
    "        df.to_pickle(file_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(2, 3, figsize=(16, 8))\n",
    "\n",
    "# (axis, row of val_loss's output, channel, title) for every nMSE panel\n",
    "error_panels = (\n",
    "    (axs[0,1], 0, 0, \"x-velocity 1-step nMSE\"),\n",
    "    (axs[0,2], 1, 0, \"x-velocity 2-step nMSE\"),\n",
    "    (axs[1,1], 0, 1, \"y-velocity 1-step nMSE\"),\n",
    "    (axs[1,2], 1, 1, \"y-velocity 2-step nMSE\"),\n",
    ")\n",
    "\n",
    "# one curve per inner-iteration budget in every panel\n",
    "loss_ax = axs[0,0]\n",
    "loss_ax.set_title(\"Training loss (minibatch avg.)\")\n",
    "for run_idx, max_iter in enumerate(MAX_ITER_LIST):\n",
    "    loss_ax.plot(losses_all_n[run_idx], label=f\"{max_iter}\")\n",
    "    for err_ax, step_row, channel, _ in error_panels:\n",
    "        err_ax.plot(errors_all_n[run_idx][:, step_row, channel], label=f\"{max_iter}\")\n",
    "loss_ax.set_ylim(1e-5, 1)\n",
    "\n",
    "for ax in axs.flatten():\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.grid(True)\n",
    "\n",
    "for err_ax, _, _, panel_title in error_panels:\n",
    "    err_ax.set_title(panel_title)\n",
    "    err_ax.set_ylim(1e-3, 1)\n",
    "\n",
    "# shared legend, placed to the right of the top-right panel\n",
    "axs[0,-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', title=\"inner iters\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PRDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PRDP schedule: training starts with N_MIN inner (linear-solver) iterations and\n",
    "# adds N_STEP iterations every time should_refine(...) returns True\n",
    "N_MIN, N_STEP = 1, 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PRDP training: start with N_MIN inner iterations and raise the budget by N_STEP\n",
    "# whenever should_refine(...) asks for it, based on the 5-step x-velocity validation error.\n",
    "RESTART = 8  # GMRES\n",
    "SAVE_RESULTS = False\n",
    "\n",
    "# SEED_LIST = [1, 2, 25, 50, 1000, 1337, 2668, 3999, 12345, 54321]\n",
    "SEED_LIST = [1]\n",
    "\n",
    "for seed_count, seed in enumerate(SEED_LIST):\n",
    "    print(f\"Training with seed={seed} ({seed_count+1} of {len(SEED_LIST)})\")\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    \n",
    "    # init model to be trained\n",
    "    key, model_init_key = jax.random.split(key)\n",
    "    model_resnet = pdeqx.arch.ClassicResNet(\n",
    "        2, 3, 3,\n",
    "        hidden_channels=64, \n",
    "        num_blocks=3,\n",
    "        key=model_init_key,\n",
    "    )\n",
    "\n",
    "    # init coarse-convergence solver\n",
    "    n_inner = N_MIN\n",
    "    ns_sim_incomplete_halfspace = implicax.NavierStokes(\n",
    "        1.0, NDOF_SOURCE, 0.1, nu=1e-4,\n",
    "        maxiter_picard=1,\n",
    "        maxiter_linsolve=n_inner,\n",
    "        restart=RESTART\n",
    "    )\n",
    "\n",
    "    # initialize optimizer\n",
    "    opt_state = optimizer.init(eqx.filter(model_resnet, eqx.is_array))\n",
    "\n",
    "    # initialize metrics\n",
    "    loss_hist_prdp = [loss_fn(model_resnet, train_set, coarse_sim_2=ns_sim_incomplete_halfspace, coarse_sim_1=ns_sim_incomplete_halfspace)]\n",
    "    error_hist_prdp = [val_loss(model_resnet, val_set)]\n",
    "    n_inner_hist_prdp = [np.nan] # n_inner has no meaning at the zeroth epoch\n",
    "\n",
    "    # initialize PRDP's Nmax checkpoint error\n",
    "    should_refine.error_checkpoint = 100\n",
    "\n",
    "    # training loop\n",
    "    BATCH_SIZE = 25\n",
    "    key, shuffle_key = jax.random.split(key)\n",
    "    start_time = time.process_time()\n",
    "    \n",
    "    for epoch in range(N_EPOCHS):\n",
    "        shuffle_key, subkey = jax.random.split(shuffle_key)\n",
    "        loss_mini_batch = []\n",
    "        for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=subkey):\n",
    "            model_resnet, opt_state, loss = update_fn(model_resnet, opt_state, batch, \n",
    "                                                      coarse_sim_2=ns_sim_incomplete_halfspace, \n",
    "                                                      coarse_sim_1=ns_sim_incomplete_halfspace)\n",
    "            loss_mini_batch.append(loss)\n",
    "        \n",
    "        loss_hist_prdp.append(jnp.mean(jnp.array(loss_mini_batch)))\n",
    "        error_hist_prdp.append(val_loss(model_resnet, val_set))\n",
    "        print(f\"Epoch {epoch+1}/{N_EPOCHS}, n_inner = {n_inner}, Loss = {loss_hist_prdp[-1]}, 1-step rel. error 'x' = {error_hist_prdp[-1][0][0]}\")\n",
    "        n_inner_hist_prdp.append(n_inner)\n",
    "\n",
    "        # PRDP: refine the solver when the validation-error history calls for it\n",
    "        if should_refine(\n",
    "            np.array(error_hist_prdp)[:,2,0], # [all epochs][which-step][which-channel]\n",
    "            stepping_threshold=0.95,\n",
    "        ): \n",
    "            n_inner += N_STEP\n",
    "            # the iteration budget is fixed at construction, so a new solver is built\n",
    "            ns_sim_incomplete_halfspace = implicax.NavierStokes(\n",
    "                1.0, NDOF_SOURCE, 0.1, nu=1e-4, \n",
    "                maxiter_picard=1,\n",
    "                maxiter_linsolve=n_inner,\n",
    "                restart=RESTART\n",
    "            )\n",
    "\n",
    "    # report the CPU time of the training loop (start_time was previously never read)\n",
    "    print(f\"Training took {time.process_time() - start_time:.1f} s of CPU time\")\n",
    "\n",
    "    # SAVE\n",
    "    loss_hist_prdp = np.array(loss_hist_prdp)\n",
    "    error_hist_prdp = np.array(error_hist_prdp)\n",
    "\n",
    "    if SAVE_RESULTS:\n",
    "        df = pd.DataFrame({\n",
    "            \"losses\": [loss_hist_prdp],\n",
    "            \"1-step errors\": [error_hist_prdp[:,0]],\n",
    "            \"2-step errors\": [error_hist_prdp[:,1]],\n",
    "            \"5-step errors\": [error_hist_prdp[:,2]],\n",
    "            \"n_inner\": [n_inner_hist_prdp],\n",
    "            \"max_restart\": RESTART,\n",
    "            \"seed\": seed,\n",
    "            \"max_iter\": \"PRDP\",\n",
    "            \"auto_using\": \"fivesteperror\"\n",
    "        })\n",
    "        file_name = f\"results/ns_spatial_sep27_time/maxiter_auto_using_fivesteperror__seed_{seed}.pkl\"\n",
    "        # create the output folder first, so a finished run is not lost to a missing directory\n",
    "        Path(file_name).parent.mkdir(parents=True, exist_ok=True)\n",
    "        df.to_pickle(file_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_hist_prdp = np.array(error_hist_prdp)\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(20, 5))\n",
    "\n",
    "axs[0].set_title(\"Training loss (minibatch avg.)\")\n",
    "axs[0].plot(loss_hist_prdp)\n",
    "axs[0].set_yscale(\"log\"); axs[0].grid()\n",
    "axs[0].set_ylim(1e-3, 1)\n",
    "\n",
    "axs[1].set_title(\"X-velocity nMSE (val. set)\")\n",
    "axs[1].plot(error_hist_prdp[:,0,0], label=\"1-step nMSE | x-vel\")\n",
    "axs[1].plot(error_hist_prdp[:,1,0], label=\"2-step nMSE | x-vel\", color=\"orange\")\n",
    "axs[1].plot(error_hist_prdp[:,2,0], label=\"5-step nMSE | x-vel\", color=\"green\")\n",
    "\n",
    "# smoothed 5-step error and the error level of the untrained model (epoch 0)\n",
    "EMA_WINDOW = 6\n",
    "axs[1].plot(numpy_ewma(error_hist_prdp[:,2,0], EMA_WINDOW), label=\"EMA\")\n",
    "axs[1].axhline(y = error_hist_prdp[0,2,0], color='grey', linestyle=':')\n",
    "\n",
    "# build the legend only after all labelled curves exist, otherwise \"EMA\" is missing from it\n",
    "axs[1].legend(); axs[1].set_yscale(\"log\"); axs[1].grid()\n",
    "# axs[1].set_ylim(5e-3, 1)\n",
    "\n",
    "axs[2].set_title(\"GMRES maxiter\")\n",
    "# address the axes explicitly instead of relying on pyplot's current axes\n",
    "axs[2].plot(n_inner_hist_prdp); axs[2].grid()\n",
    "# axs[2].set_ylim(0, 17)\n",
    "# axs[1].set_xlim(96, 100)\n",
    "# axs[1].set_ylim(1e-3, 1e-1)\n",
    "\n",
    "# from matplotlib.ticker import MaxNLocator\n",
    "# axs[2].yaxis.set_major_locator(MaxNLocator(integer=True))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_fresh",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
