{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import jax\n",
    "from jax import custom_jvp, jacfwd\n",
    "import jax.numpy.linalg as LA\n",
    "from jax import jvp\n",
    "from functools import partial\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import Normalize\n",
    "import matplotlib.cm as cm\n",
    "import numpy as np\n",
    "\n",
    "from src.poisson_problem_definition import *\n",
    "from src.linear_solvers_scan import forward_solve_jacobi, forward_solve_SD, forward_solve, outer_optimisation, outer_optimisation_dynamic_inner\n",
    "from src.poisson_utilities import plot_outer_history_2 #plot_contours, \n",
    "# plt.rcParams['figure.dpi'] = 200\n",
    "# plt.rcParams['savefig.dpi'] = 200\n",
    "\n",
    "# import plotly.io as pio\n",
    "# pio.templates.default = \"simple_white\"\n",
    "\n",
    "plotly_config = {\n",
    "  'toImageButtonOptions': {\n",
    "    'format': 'png', # one of png, svg, jpeg, webp\n",
    "    'filename': 'custom_image',\n",
    "    'height': 500,\n",
    "    'width': 800,\n",
    "    'scale': 1.5 # Multiply title/legend/axis/canvas sizes by this factor\n",
    "  }\n",
    "}\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SETTINGS FOR PLOTTING\n",
    "# Global matplotlib defaults shared by every figure in this notebook.\n",
    "WIDTH = 5.5   # default figure width, passed to figsize\n",
    "HEIGHT = 2.0  # default figure height, passed to figsize\n",
    "\n",
    "# `plt.rc(group, **kwargs)` changes nothing when no kwargs are given, so the bare\n",
    "# `plt.rc(\"text\")` call and the repeated font-size call were dropped.\n",
    "plt.rc(\"font\", family=\"sans-serif\", size=9)\n",
    "plt.rc(\"axes\", labelsize=6, titlesize=8)\n",
    "plt.rc(\"legend\", fontsize=6, title_fontsize=7)\n",
    "plt.rc(\"xtick\", labelsize=6.5)\n",
    "plt.rc(\"ytick\", labelsize=6.0)\n",
    "plt.rc(\"lines\", linewidth=1.0)\n",
    "plt.rc(\"figure\", figsize=(WIDTH, HEIGHT))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check whether jax is using gpu\n",
    "print(jax.devices())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "THETA_REF = 2.0\n",
    "rhs_fn = rhs_sine_fn\n",
    "# rhs_fn = rhs_step_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_DOF = 30  # N_DOF from slider\n",
    "\n",
    "# Uniform grid on [0, 1]; the two end points are dropped so that N_DOF nodes remain.\n",
    "x = jnp.linspace(0, 1, N_DOF + 2)[1:-1]\n",
    "print(f\"x shape = {x.shape}\")\n",
    "\n",
    "A = A_matrix(N_DOF)\n",
    "# Reference solution at THETA_REF, computed by `u_direct`.\n",
    "u_ref = u_direct(THETA_REF, x, A, rhs_fn)\n",
    "assert u_ref.shape == (N_DOF,)  # sanity check\n",
    "# plt.plot(u_ref); plt.title(f\"Exact solution, $\\\\theta$_ref = {THETA_REF}\"); plt.grid();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom JVP = linsolve(A, $\\dot{b}$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "SOLVER = \"jacobi\"  # \"jacobi\" or \"SD\"\n",
    "\n",
    "# Solver name -> iterative linear solver imported from src.linear_solvers_scan.\n",
    "_LINSOLVE_BY_NAME = {\n",
    "    \"jacobi\": forward_solve_jacobi,\n",
    "    \"SD\": forward_solve_SD,\n",
    "}\n",
    "if SOLVER not in _LINSOLVE_BY_NAME:\n",
    "    # Without this check a typo in SOLVER would leave LINSOLVE undefined, or silently\n",
    "    # keep a stale LINSOLVE from an earlier run of this cell in the same kernel.\n",
    "    raise ValueError(f\"Unknown SOLVER {SOLVER!r}; expected one of {sorted(_LINSOLVE_BY_NAME)}\")\n",
    "LINSOLVE = _LINSOLVE_BY_NAME[SOLVER]\n",
    "\n",
    "\n",
    "@partial(custom_jvp, nondiff_argnums=(1, 2))\n",
    "def u_solve(b, n_inner, u_init):\n",
    "    \"\"\"Run LINSOLVE on the system with matrix A and right-hand side `b`.\n",
    "\n",
    "    Only `b` is differentiable; `n_inner` and `u_init` are declared non-differentiable.\n",
    "    Returns whatever LINSOLVE returns (used in this notebook as an iterate history).\n",
    "    \"\"\"\n",
    "    return LINSOLVE(A, b, n_inner, u_init)\n",
    "\n",
    "\n",
    "@u_solve.defjvp\n",
    "def u_solve_jvp(n_inner, u_init, primals, tangents):\n",
    "    \"\"\"Custom JVP rule: the tangent is obtained by solving the same system with b_dot.\n",
    "\n",
    "    - primals = the input to the function whose JVP we are defining (b)\n",
    "    - tangents = the autodiff of the inputs to the function whose JVP we are defining (b),\n",
    "      w.r.t. the parameter against which we have requested the jacfwd (theta)\n",
    "\n",
    "    Note: the tangent solve is started from the same `u_init` as the primal solve.\n",
    "    \"\"\"\n",
    "    b, = primals  # (n_dof,)\n",
    "    b_dot, = tangents  # (n_dof, d); here (n_dof,)\n",
    "    primal_out = u_solve(b, n_inner, u_init)  # linsolve history for right-hand side b\n",
    "    tangent_out = LINSOLVE(A, b_dot, n_inner, u_init)  # linsolve history for right-hand side b_dot\n",
    "    return primal_out, tangent_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inner problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_INNER = 1000   # max n_inner from slider"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### u_init = zeros"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "U_INIT = jnp.zeros(N_DOF)\n",
    "\n",
    "\n",
    "def func(theta):\n",
    "    \"\"\"Output of the custom-JVP solver for the right-hand side evaluated at `theta`.\"\"\"\n",
    "    return u_solve(rhs_fn(theta, x), N_INNER, U_INIT)  # (N_INNER+1, n_dof)\n",
    "\n",
    "\n",
    "def _relative_error_per_iterate(history, reference):\n",
    "    \"\"\"Row-wise norm of `history - reference`, divided by the norm of `reference`.\"\"\"\n",
    "    return LA.norm(history - reference, axis=1) / LA.norm(reference)\n",
    "\n",
    "\n",
    "# Jacobians w.r.t. theta: reference (via u_direct), unrolled solver, implicit (custom JVP)\n",
    "reference_jacobian = jacfwd(u_direct)(THETA_REF, x, A, rhs_fn)\n",
    "unrolled_jacobian = jacfwd(forward_solve)(THETA_REF, A, rhs_fn, x, N_INNER, SOLVER, U_INIT, False)\n",
    "implicit_jacobian = jacfwd(func)(THETA_REF)\n",
    "\n",
    "# jacobian suboptimalities\n",
    "implicit_jac_rel_subopt_zerosinit = _relative_error_per_iterate(implicit_jacobian, reference_jacobian)\n",
    "unrolled_jac_rel_subopt_zerosinit = _relative_error_per_iterate(unrolled_jacobian, reference_jacobian)\n",
    "\n",
    "# primal suboptimality\n",
    "primal = func(THETA_REF)  # (N_INNER+1, n_dof)\n",
    "u_rel_subopt_zerosinit = _relative_error_per_iterate(primal, u_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"theta_ref = {THETA_REF}, SOLVER = {SOLVER}\")\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(WIDTH, HEIGHT*1.5))\n",
    "lastiter = N_INNER\n",
    "\n",
    "# Plot on the left subplot (ax1)\n",
    "ax1.plot(unrolled_jac_rel_subopt_zerosinit[:lastiter], \"-\",  color=\"tab:blue\",   label=\"unrolled, $\\\\text{primal}_\\\\text{init} = \\\\text{zeros}$\")\n",
    "ax1.plot(implicit_jac_rel_subopt_zerosinit[:lastiter], \"--\", color=\"tab:orange\", label=\"implicit, $\\\\text{tangent}_\\\\text{init} = \\\\text{zeros}$\")\n",
    "# ax1.plot( unrolled_jac_rel_subopt_onesinit[:lastiter], \"-\",  marker='^',  markevery = int(lastiter/4), color=\"tab:blue\",   label=\"unrolled, $\\\\text{primal}_\\\\text{init} =\\\\text{ones}$\")\n",
    "# ax1.plot( implicit_jac_rel_subopt_onesinit[:lastiter], \"--\", marker='^',  markevery = int(lastiter/6), color=\"tab:orange\", label=\"implicit, $\\\\text{tangent}_\\\\text{init} =\\\\text{ones}$\")\n",
    "# ax1.plot( unrolled_jac_rel_subopt_randinit[:lastiter], \"-\",  marker='x',  markevery = int(lastiter/3), color=\"tab:blue\",   label=\"unrolled, $\\\\text{primal}_\\\\text{init} =\\\\text{rand}$\")\n",
    "# ax1.plot( implicit_jac_rel_subopt_randinit[:lastiter], \"--\", marker='x',  markevery = int(lastiter/10), color=\"tab:orange\",label=\"implicit, $\\\\text{tangent}_\\\\text{init} = \\\\text{rand}$\")\n",
    "\n",
    "ax1.set_ylabel(\"Jacobian relative suboptimality\")\n",
    "ax1.set_xlabel(\"# iterations\")\n",
    "ax1.legend(fontsize=8); ax1.set_yscale(\"log\")\n",
    "\n",
    "# Plot on the right subplot (ax2)\n",
    "# ax2.set_title(f\"Primal relative subopt history\\n $\\\\theta$_ref = {THETA_REF}\")\n",
    "ax2.plot(u_rel_subopt_zerosinit[:lastiter],  label=\"$\\\\text{primal}_\\\\text{init} = \\\\text{zeros}$\")\n",
    "# ax2.plot(u_rel_subopt_onesinit[:lastiter], \"-^\", markevery=int(lastiter/5),   label=\"$\\\\text{primal}_\\\\text{init} = \\\\text{ones}$\")\n",
    "# ax2.plot(u_rel_subopt_randinit[:lastiter], \"-x\", markevery=int(lastiter/5),   label=\"$\\\\text{primal}_\\\\text{init} = \\\\text{rand}$\")\n",
    "ax2.set_ylabel(\"Iterate relative suboptimality\")\n",
    "ax2.set_xlabel(\"# iterations\")\n",
    "ax2.legend(fontsize=8); ax2.set_yscale(\"log\")\n",
    "\n",
    "for ax in ax1,ax2:\n",
    "    ax.grid(which='major', axis='both')\n",
    "    ax.grid(which='minor', axis='x', linestyle=':', linewidth='0.5')\n",
    "    ax.minorticks_on()\n",
    "\n",
    "fig.suptitle(f\"Poisson 1 param, solver={SOLVER}\")\n",
    "fig.tight_layout()\n",
    "# fig.savefig(f\"figures/poisson_1_param__{SOLVER}_suboptimalities.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## SANITY CHECK: IS U_SOLVE WORKING CORRECTLY\n",
    "plt.rcParams['figure.dpi'] = 100\n",
    "n_check_iters = 1000  # solver iterations used for this check\n",
    "u_direct_check = LA.solve(A, rhs_fn(THETA_REF, x))\n",
    "# [-1] picks the last entry of what u_solve returns, i.e. the final iterate\n",
    "u_iterative_check = u_solve(rhs_fn(THETA_REF, x), n_check_iters, jnp.zeros(N_DOF))[-1]\n",
    "\n",
    "plt.plot(u_direct_check, label=\"u_ref\")\n",
    "plt.plot(u_iterative_check, 'o', label=\"$u_{init} = \\\\text{zeros}$\")\n",
    "plt.title(\"Sanity check: is u_solve working correctly\")\n",
    "plt.legend()  # both curves are labelled; without this call the labels are never shown\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Residuum plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def primal_rel_residuum(primal):\n",
    "    \"\"\"Relative residual ||A u - b|| / ||b|| of a single iterate `primal`.\n",
    "\n",
    "    b is `rhs_fn(THETA_REF, x)`, the same right-hand side that is passed to the solver\n",
    "    call below. The denominator equals the residual norm of an all-zeros iterate.\n",
    "    \"\"\"\n",
    "    # Use the notebook-wide rhs_fn (not a hard-coded rhs_sine_fn) so the residuum stays\n",
    "    # consistent if the right-hand side is switched in the problem definition.\n",
    "    rhs = rhs_fn(THETA_REF, x)\n",
    "    res = LA.norm(A @ primal.T - rhs)\n",
    "    res_init = LA.norm(rhs)\n",
    "    return res / res_init\n",
    "\n",
    "U_INIT = jnp.zeros(N_DOF)\n",
    "N_RESIDUUM_ITERS = 800  # solver iterations shown in the residuum plot\n",
    "primal_hist = forward_solve(THETA_REF, A, rhs_fn, x, N_RESIDUUM_ITERS, SOLVER, U_INIT)\n",
    "# one residuum value per row of primal_hist\n",
    "residuum_hist = jax.vmap(primal_rel_residuum)(primal_hist)\n",
    "\n",
    "\n",
    "plt.figure(figsize=(WIDTH*0.3, HEIGHT))\n",
    "plt.plot(residuum_hist)#, \"-o\", markevery=int(N_INNER/5))\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(which='major', axis='both')\n",
    "plt.grid(which='minor', axis='x', linestyle=':', linewidth='0.5')\n",
    "plt.minorticks_on()\n",
    "plt.xlabel(\"# iterations\")\n",
    "plt.ylabel(\"Primal relative residuum\")\n",
    "plt.title(f\"Poisson 1 params | solver = {SOLVER}\\nResiduum @ $\\\\theta$ = {THETA_REF}\")\n",
    "# plt.savefig(f\"figures/poisson_1_param__{SOLVER}_primal_residuum.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Outer Problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.linear_solvers_scan import squareloss_fn\n",
    "THETA_INIT = 5.0        # from slider\n",
    "N_OUTER = 180           # from slider\n",
    "LR = 275                # from slider\n",
    "N_INNER_MAX = 600\n",
    "U_INIT = jnp.zeros(N_DOF) # jnp.ones(n_dof) # jax.random.normal(jax.random.PRNGKey(0), shape=x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss function "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DIFF_METHOD = \"implicit\"  # \"unrolled\" or \"implicit\"\n",
    "\n",
    "if DIFF_METHOD == \"unrolled\":\n",
    "    def loss_fn(theta):\n",
    "        \"\"\"squareloss_fn of the forward_solve output for `theta` against u_ref.\"\"\"\n",
    "        u_history = forward_solve(theta, A, rhs_fn, x, N_INNER_MAX, SOLVER, U_INIT, False)\n",
    "        return squareloss_fn(u_history, u_ref)\n",
    "elif DIFF_METHOD == \"implicit\":\n",
    "    @jax.jit\n",
    "    def loss_fn(theta):\n",
    "        \"\"\"squareloss_fn of the custom-JVP u_solve output for `theta` against u_ref.\"\"\"\n",
    "        rhs_eval = rhs_fn(theta, x)\n",
    "        u_history = u_solve(rhs_eval, N_INNER_MAX, U_INIT)\n",
    "        return squareloss_fn(u_history, u_ref)\n",
    "else:\n",
    "    # A typo in DIFF_METHOD would otherwise leave loss_fn undefined, or silently reuse\n",
    "    # a loss_fn left over from an earlier run in the same kernel.\n",
    "    raise ValueError(f\"Unknown DIFF_METHOD {DIFF_METHOD!r}; expected 'unrolled' or 'implicit'\")\n",
    "\n",
    "# test\n",
    "print(f\"loss at theta_init at refinement level 100 = {loss_fn(THETA_INIT)[101]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Different constant n_inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_hist, loss_hist = outer_optimisation(\n",
    "    loss_fn,\n",
    "    n_outer_iterations=N_OUTER,\n",
    "    outer_lr=LR,\n",
    "    init=THETA_INIT,\n",
    "    n_inner_iterations=N_INNER_MAX\n",
    ")\n",
    "assert theta_hist.shape == loss_hist.shape == (N_OUTER+1, N_INNER_MAX+1)\n",
    "\n",
    "theta_rel_error = jnp.abs(theta_hist - THETA_REF) / jnp.abs(THETA_REF)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Line plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_outer_history_2(\n",
    "    loss_hist,\n",
    "    theta_rel_error,\n",
    "    \"Loss\",\n",
    "    \"Parameter relative error\",\n",
    "    log_flag = True,\n",
    "    iterations_step = 50,\n",
    "    suptitle=f\"Poisson 1 param, Solver: {SOLVER}, Diff method: {DIFF_METHOD}\",\n",
    "    is_prdp_included=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Manually Scheduled n_inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_MIN, N_STEP, N_MAX = 25, None, N_INNER_MAX\n",
    "\n",
    "hist_dynamic = outer_optimisation_dynamic_inner(\n",
    "    loss_fn,\n",
    "    n_outer_iterations=N_OUTER,\n",
    "    lr = LR,\n",
    "    theta_init = THETA_INIT,\n",
    "    n_inner_init = N_MIN,\n",
    "    n_inner_max = N_MAX,\n",
    "    n_step = N_STEP, # change inside the function itself\n",
    "    min_theta_change=0\n",
    ")\n",
    "theta_hist_dynamic, loss_hist_dynamic, n_inner_hist_dynamic = hist_dynamic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_inner_hist_dynamic_clipped = jnp.clip(n_inner_hist_dynamic, 0, N_MAX)\n",
    "theta_rel_error_dynamic = jnp.abs(theta_hist_dynamic - THETA_REF) / jnp.abs(THETA_REF)\n",
    "\n",
    "theta_rel_error_all = jnp.hstack([theta_rel_error, theta_rel_error_dynamic.reshape(-1,1)])\n",
    "loss_hist_all       = jnp.hstack([loss_hist, loss_hist_dynamic.reshape(-1,1)])\n",
    "\n",
    "fig_outer = plot_outer_history_2(\n",
    "    loss_hist_all,\n",
    "    theta_rel_error_all,\n",
    "    \"Loss\",\n",
    "    \"Parameter relative error\",\n",
    "    log_flag = True,\n",
    "    iterations_step = 50,\n",
    "    suptitle=f\"Poisson 1 param, Solver: {SOLVER}, Diff = {DIFF_METHOD}\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig_outer.savefig(f\"figures/poisson_1_param_{SOLVER}_unrolled__training_manual_PRDP.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot for PR savings section in paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create color map\n",
    "norm = Normalize(vmin=1, vmax=N_INNER_MAX)\n",
    "# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the supported lookup.\n",
    "cmap = plt.get_cmap(\"plasma_r\")\n",
    "PRDP_COLOR = \"tab:green\"\n",
    "\n",
    "fig_pr_savings, axs = plt.subplots(1,2, figsize=(WIDTH*0.5, HEIGHT*0.9))\n",
    "\n",
    "# plot errors for all n_inner (every 50th column, coloured by its iteration count)\n",
    "for i in np.arange(0, N_INNER_MAX+1, 50):\n",
    "    color = cmap(norm(i+1))\n",
    "    axs[0].plot(theta_rel_error_all[:,i], linewidth=1.0, color=color)\n",
    "# plot errors for scheduled n_inner (stored as the last column)\n",
    "axs[0].plot(theta_rel_error_all[:, -1], linewidth=1.0, color=PRDP_COLOR)\n",
    "axs[0].set_yscale(\"log\")\n",
    "axs[0].set_xlabel(\"Outer iteration\")\n",
    "axs[0].set_ylabel(\"Parameter relative error\")\n",
    "axs[0].grid()\n",
    "axs[0].set_xlim(0, None)\n",
    "\n",
    "# plot n_inner vs outer iteration\n",
    "K_epsilon = 600\n",
    "n_outer_for_k_epsilon = 125\n",
    "converged_n_inner = np.full((n_outer_for_k_epsilon,), N_INNER_MAX)\n",
    "\n",
    "n_outer_for_prdp = 145\n",
    "prdp_n_inner = n_inner_hist_dynamic_clipped[:n_outer_for_prdp]\n",
    "\n",
    "axs[1].plot(converged_n_inner, label=\"Fully converged\", color=cmap(norm(K_epsilon)))\n",
    "axs[1].plot(prdp_n_inner, label=\"Scheduled\", color=PRDP_COLOR)\n",
    "# shade the gap between the two curves over the first n_outer_for_k_epsilon outer iterations\n",
    "axs[1].fill_between(np.arange(n_outer_for_k_epsilon), prdp_n_inner[:n_outer_for_k_epsilon], converged_n_inner[:n_outer_for_k_epsilon], color=PRDP_COLOR, alpha=0.2, label=\"PR savings\")\n",
    "axs[1].set_xlabel(\"Outer iteration\")\n",
    "axs[1].set_ylabel(\"# physics solver iterations\")\n",
    "axs[1].grid()\n",
    "axs[1].set_xlim(0, None)\n",
    "\n",
    "# Add colorbar and legend\n",
    "sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)\n",
    "sm.set_array([])\n",
    "cbar = fig_pr_savings.colorbar(\n",
    "    sm,\n",
    "    ax=axs,\n",
    "    orientation=\"horizontal\",\n",
    "    location='top',\n",
    "    shrink=0.4,\n",
    "    anchor=(0, 2)\n",
    ")\n",
    "cbar.set_label(\"# physics solver iterations\")\n",
    "cbar.set_ticks(ticks=[0, 300, 600])\n",
    "fig_pr_savings.legend(loc='lower left', bbox_to_anchor=(0.65, 0.73), frameon=False)\n",
    "\n",
    "fig_pr_savings.set_layout_engine(\"compressed\")\n",
    "# fig_pr_savings.savefig(f\"result_plots/pr_savings_poisson_1_param_{SOLVER}_unrolled.pdf\", bbox_inches='tight', pad_inches=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot n_inner vs. n_outer\n",
    "N_CONVERGENCE = 600        # inner iterations per outer step of the fully converged run\n",
    "N_OUTER_TOL = 125          # outer iterations counted for the fully converged run\n",
    "N_OUTER_SCHEDULED = 145    # outer iterations counted for the scheduled run\n",
    "n_convergence_array = jnp.full((N_OUTER_TOL,), N_CONVERGENCE)\n",
    "n_inner_hist_dynamic_tol = n_inner_hist_dynamic_clipped[:N_OUTER_SCHEDULED]\n",
    "\n",
    "# calculate savings: total inner iterations of each strategy\n",
    "converged_cost = jnp.sum(n_convergence_array)\n",
    "dynamic_cost = jnp.sum(n_inner_hist_dynamic_tol)\n",
    "print(f\"converged_cost = {converged_cost}, dynamic_cost = {dynamic_cost}\")\n",
    "print(f\"savings = {1 - dynamic_cost / converged_cost}\")\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "plt.plot(n_inner_hist_dynamic_tol, label=\"Scheduled (progressively refined)\", color=\"green\")\n",
    "# plt.axhline(N_CONVERGENCE, color=\"red\", linestyle=\"--\", label=\"Fully converged\")\n",
    "plt.plot(n_convergence_array, color=\"red\", linestyle=\"--\", label=\"Fully converged\")\n",
    "# shade only where the scheduled run uses fewer inner iterations than N_CONVERGENCE\n",
    "plt.fill_between(range(len(n_inner_hist_dynamic_tol)), n_inner_hist_dynamic_tol, N_CONVERGENCE, where=(n_inner_hist_dynamic_tol < N_CONVERGENCE), color='green', alpha=0.2, label = \"Progressive refinement (PR) savings\")\n",
    "plt.xlabel(\"# outer iterations\")\n",
    "plt.ylabel(\"# inner iterations\")\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.xlim(0, N_OUTER_SCHEDULED)\n",
    "# plt.savefig(f\"figures/poisson_1_param_{SOLVER}_unrolled__cost_manual_PRDP.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PRDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def should_refine(error_hist, stepping_threshold = 0.98, nmax_threshold = 0.9, error_window = 3):\n",
    "    \"\"\"Decide whether the number of inner solver iterations should be increased.\n",
    "\n",
    "    The error history counts as plateaued when the latest error divided by\n",
    "    `error_hist[-error_window]` exceeds `stepping_threshold`. On a plateau the latest\n",
    "    error is compared with the error stored at the previous refinement\n",
    "    (`should_refine.error_checkpoint`): refinement is requested when their ratio is\n",
    "    below `nmax_threshold` or above 1, and the checkpoint is then moved to the latest error.\n",
    "\n",
    "    Args:\n",
    "        error_hist: sequence of error values, oldest first.\n",
    "        stepping_threshold: plateau threshold on error_hist[-1] / error_hist[-error_window].\n",
    "        nmax_threshold: threshold on error_hist[-1] / checkpoint.\n",
    "        error_window: look-back used for the plateau test; must be greater than 1.\n",
    "\n",
    "    Returns:\n",
    "        True if the caller should refine, False otherwise. Always False while the\n",
    "        history has at most `error_window` entries.\n",
    "\n",
    "    State:\n",
    "        Reads and updates the function attribute `should_refine.error_checkpoint`.\n",
    "        If it has never been set, it is treated as +inf, so the first plateau\n",
    "        triggers a refinement instead of raising AttributeError.\n",
    "    \"\"\"\n",
    "    assert error_window > 1, \"error_window should be greater than 1\"\n",
    "\n",
    "    if len(error_hist) <= error_window:\n",
    "        return False\n",
    "\n",
    "    error_hist = np.asarray(error_hist)\n",
    "    latest_error = error_hist[-1]\n",
    "    error_ratio = latest_error / error_hist[-error_window]\n",
    "\n",
    "    # `not (ratio > threshold)` also covers a NaN ratio, exactly like the else-branch it replaces\n",
    "    if not error_ratio > stepping_threshold:\n",
    "        return False  # error history hasn't plateaued => continue using current N\n",
    "\n",
    "    checkpoint = getattr(should_refine, \"error_checkpoint\", np.inf)\n",
    "    checkpoint_ratio = latest_error / checkpoint\n",
    "\n",
    "    if checkpoint_ratio < nmax_threshold or checkpoint_ratio > 1:\n",
    "        should_refine.error_checkpoint = latest_error\n",
    "        return True\n",
    "    return False  # no improvement against checkpoint => reached Nmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick how the final iterate is computed; the jitted loss wrapper below is shared.\n",
    "if DIFF_METHOD == \"unrolled\":\n",
    "    def _final_iterate(theta, n_inner):\n",
    "        \"\"\"Last entry of the forward_solve output after `n_inner` iterations.\"\"\"\n",
    "        return forward_solve(theta, A, rhs_fn, x, n_inner, SOLVER, U_INIT, False)[-1]\n",
    "elif DIFF_METHOD == \"implicit\":\n",
    "    def _final_iterate(theta, n_inner):\n",
    "        \"\"\"Last entry of the custom-JVP u_solve output after `n_inner` iterations.\"\"\"\n",
    "        rhs_eval = rhs_fn(theta, x)\n",
    "        return u_solve(rhs_eval, n_inner, U_INIT)[-1]\n",
    "else:\n",
    "    # A typo in DIFF_METHOD would otherwise leave loss_fn_prdp undefined, or silently\n",
    "    # reuse a definition left over from an earlier run in the same kernel.\n",
    "    raise ValueError(f\"Unknown DIFF_METHOD {DIFF_METHOD!r}; expected 'unrolled' or 'implicit'\")\n",
    "\n",
    "\n",
    "@partial(jax.jit, static_argnames=['n_inner'])\n",
    "def loss_fn_prdp(theta, n_inner):\n",
    "    \"\"\"squareloss_fn of the final iterate against u_ref; `n_inner` is a static jit argument.\"\"\"\n",
    "    u = _final_iterate(theta, n_inner)\n",
    "    return squareloss_fn(u, u_ref)\n",
    "\n",
    "# test\n",
    "print(f\"loss at theta_init at refinement level 100 = {loss_fn_prdp(THETA_INIT, 101)}\")\n",
    "\n",
    "def update_fn(theta, n_inner, lr):\n",
    "    \"\"\"One gradient-descent step on theta; returns (loss at the new theta, new theta).\"\"\"\n",
    "    theta_new = theta - lr * jacfwd(loss_fn_prdp)(theta, n_inner)\n",
    "    loss = loss_fn_prdp(theta_new, n_inner)\n",
    "    return loss, theta_new\n",
    "\n",
    "def theta_error_fn(theta):\n",
    "    \"\"\"Relative error of `theta` with respect to THETA_REF.\"\"\"\n",
    "    return jnp.abs(theta - THETA_REF) / jnp.abs(THETA_REF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PRDP params\n",
    "N_INNER_INIT, N_INNER_STEP = 25, 10\n",
    "TAU_STEP, TAU_STOP = 0.92, 0.98\n",
    "ERROR_WINDOW = 2\n",
    "\n",
    "# initialize model to be trained\n",
    "theta = THETA_INIT\n",
    "\n",
    "# initialize physics refinement\n",
    "n_inner = N_INNER_INIT\n",
    "\n",
    "# initialize metrics\n",
    "loss_hist_prdp = [loss_fn_prdp(theta, n_inner)]\n",
    "theta_error_hist_prdp = [theta_error_fn(theta)]\n",
    "n_inner_hist_prdp = [N_INNER_INIT] \n",
    "\n",
    "# initialize PRDP's checkpoint error\n",
    "should_refine.error_checkpoint = 100\n",
    "\n",
    "# training loop\n",
    "for epoch in range(N_OUTER):\n",
    "    # update model\n",
    "    loss, theta = update_fn(theta, int(n_inner), LR)\n",
    "    loss_hist_prdp.append(loss)\n",
    "    theta_error_hist_prdp.append(theta_error_fn(theta))\n",
    "    n_inner_hist_prdp.append(int(n_inner))\n",
    "    print(f\"Epoch {epoch+1}/{N_OUTER}, n_inner  {int(n_inner)}, loss = {loss}, theta rel error = {theta_error_hist_prdp[-1]}\")\n",
    "\n",
    "    # PRDP: TAU_STEP is the plateau threshold and TAU_STOP the stop threshold.\n",
    "    # TAU_STOP used to be defined but never passed, so should_refine silently fell\n",
    "    # back to its default nmax_threshold.\n",
    "    if should_refine(\n",
    "        np.array(theta_error_hist_prdp),\n",
    "        stepping_threshold=TAU_STEP,\n",
    "        nmax_threshold=TAU_STOP,\n",
    "        error_window=ERROR_WINDOW,\n",
    "    ):\n",
    "        # refine the physics, capped at N_INNER_MAX\n",
    "        n_inner = min(n_inner + N_INNER_STEP, N_INNER_MAX)\n",
    "\n",
    "theta_error_hist_prdp = np.array(theta_error_hist_prdp)\n",
    "loss_hist_prdp = np.array(loss_hist_prdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the PRDP run as one extra (last) column next to the constant-n_inner histories.\n",
    "theta_rel_error_all = np.hstack([theta_rel_error, theta_error_hist_prdp.reshape(-1,1)])\n",
    "loss_history_all = np.hstack([loss_hist, loss_hist_prdp.reshape(-1,1)])\n",
    "\n",
    "fig_outer_hist = plot_outer_history_2(\n",
    "    loss_history_all, \n",
    "    theta_rel_error_all, \n",
    "    \"loss\", \n",
    "    \"parameter relative error\", \n",
    "    log_flag=True, \n",
    "    iterations_step=50,\n",
    "    # this notebook optimises a single parameter; the title used to say \"3 parameters\"\n",
    "    suptitle=f\"Poisson 1 param. Solver: {SOLVER}, Diff: {DIFF_METHOD}\", \n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### PRDP plot for paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# create color map\n",
    "norm = Normalize(vmin=1, vmax=N_INNER_MAX)\n",
    "# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the supported lookup.\n",
    "cmap = plt.get_cmap(\"plasma_r\")\n",
    "PRDP_COLOR = \"tab:green\"\n",
    "\n",
    "fig_pr_savings, axs = plt.subplots(1,2, figsize=(WIDTH*0.5, HEIGHT*0.9))\n",
    "\n",
    "# plot errors for all constant n_inner (every 50th column, coloured by its iteration count)\n",
    "for i in np.arange(0, N_INNER_MAX+1, 50):\n",
    "    color = cmap(norm(i+1))\n",
    "    axs[0].plot(theta_rel_error_all[:,i], linewidth=1.0, color=color)\n",
    "# plot errors for prdp n_inner (stored as the last column)\n",
    "axs[0].plot(theta_rel_error_all[:, -1], linewidth=1.0, color=PRDP_COLOR)\n",
    "axs[0].set_yscale(\"log\")\n",
    "axs[0].set_xlabel(\"Outer iteration\")\n",
    "axs[0].set_ylabel(\"Parameter relative error\")\n",
    "axs[0].grid()\n",
    "axs[0].set_xlim(0, None)\n",
    "\n",
    "# plot n_inner vs outer iteration\n",
    "K_epsilon = 600\n",
    "n_outer_for_k_epsilon = 125\n",
    "converged_n_inner = np.full((n_outer_for_k_epsilon,), N_INNER_MAX)\n",
    "cost_converged = K_epsilon * n_outer_for_k_epsilon\n",
    "print(f\"cost_converged = {cost_converged}\")\n",
    "\n",
    "n_outer_for_prdp = 170\n",
    "prdp_n_inner = n_inner_hist_prdp[:n_outer_for_prdp]\n",
    "cost_prdp = np.sum(prdp_n_inner)\n",
    "print(f\"cost_prdp = {cost_prdp}\")\n",
    "print(f\"savings = {1 - cost_prdp / cost_converged}\")\n",
    "\n",
    "axs[1].plot(converged_n_inner, label=\"Fully converged\", color=cmap(norm(K_epsilon)))\n",
    "axs[1].plot(prdp_n_inner, label=\"PRDP\", color=PRDP_COLOR)\n",
    "# shade the gap between the two curves over the first n_outer_for_k_epsilon outer iterations\n",
    "axs[1].fill_between(np.arange(n_outer_for_k_epsilon), prdp_n_inner[:n_outer_for_k_epsilon], converged_n_inner[:n_outer_for_k_epsilon], color=PRDP_COLOR, alpha=0.2, label=\"PR savings\")\n",
    "axs[1].set_xlabel(\"Outer iteration\")\n",
    "axs[1].set_ylabel(\"# physics solver iterations\")\n",
    "axs[1].grid()\n",
    "axs[1].set_xlim(0, None)\n",
    "axs[1].set_xticks([0, n_outer_for_k_epsilon, n_outer_for_prdp])\n",
    "axs[0].set_xticks([0, n_outer_for_k_epsilon, n_outer_for_prdp])\n",
    "\n",
    "# Add colorbar and legend\n",
    "sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)\n",
    "sm.set_array([])\n",
    "cbar = fig_pr_savings.colorbar(\n",
    "    sm,\n",
    "    ax=axs,\n",
    "    orientation=\"horizontal\",\n",
    "    location='top',\n",
    "    shrink=0.4,\n",
    "    anchor=(0, 2)\n",
    ")\n",
    "cbar.set_label(\"# physics solver iterations\")\n",
    "cbar.set_ticks(ticks=[0, 300, 600])\n",
    "fig_pr_savings.legend(loc='lower left', bbox_to_anchor=(0.65, 0.73), frameon=False)\n",
    "\n",
    "fig_pr_savings.set_layout_engine(\"compressed\")\n",
    "# fig_pr_savings.savefig(f\"result_plots/poisson_1_param_{SOLVER}_{DIFF_METHOD}.pdf\", bbox_inches='tight', pad_inches=0.01)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
