{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a725b829-475c-47ca-9517-12b8f402e8ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "\n",
    "from scipy.interpolate import RegularGridInterpolator\n",
    "from scipy.stats import binned_statistic_2d\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "from utils import *\n",
    "\n",
    "# === Visualization ===\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import FancyArrowPatch, ConnectionPatch\n",
    "from matplotlib.collections import LineCollection\n",
    "from matplotlib.colors import ListedColormap, Normalize\n",
    "from matplotlib.ticker import LogLocator, MaxNLocator, NullLocator\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.colors as mcolors\n",
    "from matplotlib.path import Path\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "sns.set(\n",
    "    context='paper',\n",
    "    font_scale=2.2,\n",
    "    font=\"Arial\",\n",
    "    palette='deep',\n",
    "    style='ticks',\n",
    "    color_codes=True,\n",
    "    rc={\n",
    "        \"mathtext.fontset\": \"cm\",\n",
    "        \"axes.linewidth\": 2,\n",
    "        \"font.size\": 16,\n",
    "        \"figure.dpi\": 100,\n",
    "        \"text.usetex\": False,\n",
    "        \"lines.linewidth\": 2,\n",
    "        \"axes.labelpad\": 0,\n",
    "        \"xtick.direction\": \"in\",\n",
    "        \"ytick.direction\": \"in\",\n",
    "        \"xtick.major.size\": 6,\n",
    "        \"ytick.major.size\": 6,\n",
    "        \"xtick.major.width\": 3,\n",
    "        \"ytick.major.width\": 3,\n",
    "        \"xtick.minor.size\": 3,\n",
    "        \"ytick.minor.size\": 3,\n",
    "    }\n",
    ")\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd7b38c-0fa0-46bf-acac-9cf08acb4758",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ========== 1. Create grid and loss surface ==========\n",
    "x = np.linspace(-4, 4, 300)\n",
    "y = np.linspace(-3, 3, 300)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "\n",
    "# Define loss surface components\n",
    "central_peak = 1.4 * np.exp(-(((X + 0.2)**2 + (Y + 0.4)**2) / 1))\n",
    "valley_main = 1 - np.exp(-((Y - 0.1 + 0.15 * X)**2) / 1)\n",
    "valley_second = -np.exp(-((X - 2)**2) / 0.6) * np.exp(-((Y + 1 + 0.4 * X)**2) / 1)\n",
    "\n",
    "# Combine and normalize\n",
    "Z = central_peak + (valley_main + valley_second)\n",
    "Z = (Z - Z.min()) / (Z.max() - Z.min())\n",
    "\n",
    "# ========== 2. Define closed- and open-loop paths ==========\n",
    "x_closed = np.array([-3, -3, -2.7, -2, -1.5, -1, -0.7, -0.4, 0, 0.3, 0.8, 1.1, 1.3, 1.6, 1.8, 2])\n",
    "y_closed = np.array([2.1, 1.8, 1.2, 1.15, 1.14, 1.13, 1.05, 1.0, 0.9, 0.8, 0.7, 0.55, 0.35, -0.1, -0.5, -1.6])\n",
    "\n",
    "x_open = np.array([-3, -3, -2.7, -2.5, -2.3, -2, -1.5, -1, -0.7, -0.4, 0, 0.3, 0.8, 1.1, 1.3, 1.6, 2])\n",
    "y_open = np.array([2.1, 1.8, 1.2, 1.1, 1.0, 0.85, 0.65, 0.2, -0.1, -0.4, -0.6, -0.7, -0.8, -0.85, -0.9, -1.1, -1.6])\n",
    "\n",
    "# Interpolate Z values for path points\n",
    "interp_func = RegularGridInterpolator((x, y), Z.T)\n",
    "z_closed = interp_func(np.column_stack((x_closed, y_closed)))\n",
    "z_open = interp_func(np.column_stack((x_open, y_open)))\n",
    "\n",
    "# ========== 3. Plotting ==========\n",
    "fig = plt.figure(figsize=(7, 4.5))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "# -- Background --\n",
    "ax.set_facecolor('white')\n",
    "fig.patch.set_facecolor('white')\n",
    "\n",
    "# -- Surface plot --\n",
    "surf = ax.plot_surface(X, Y, Z, cmap='mako_r', alpha=0.8, edgecolor='none', zorder=0)\n",
    "\n",
    "# -- Trajectories --\n",
    "ax.plot(x_closed, y_closed, z_closed, color=sns.color_palette('deep')[0], lw=3, label=r'$\\text{Closed-Loop}$', zorder=20)\n",
    "ax.plot(x_open, y_open, z_open, color=sns.color_palette('deep')[3], lw=3, label=r'$\\text{Open-Loop}$', zorder=20)\n",
    "\n",
    "# -- Start/end points for closed-loop path --\n",
    "ax.plot(x_closed[[0, -1]], y_closed[[0, -1]], z_closed[[0, -1]], \n",
    "        color='k', marker='o', lw=0, ms=5, zorder=10)\n",
    "\n",
    "# ========== 4. Axis styling ==========\n",
    "ax.set_xlabel(r'$\\theta_2$', fontsize=30, labelpad=0)\n",
    "ax.set_ylabel(r'$\\theta_1$', fontsize=30, labelpad=0)\n",
    "ax.set_zlabel('')\n",
    "ax.set_zticks([]); ax.set_zticklabels([])\n",
    "ax.set_xticks([]); ax.set_xticklabels([])\n",
    "ax.set_yticks([]); ax.set_yticklabels([])\n",
    "ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))  # Hide z-axis line\n",
    "\n",
    "# Remove 3D pane shading\n",
    "for pane in [ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane]:\n",
    "    pane.set_edgecolor('w')\n",
    "    pane.set_alpha(0)\n",
    "\n",
    "# ========== 5. Annotations ==========\n",
    "# Start and end text\n",
    "ax.text(x_closed[0] - 0.3, y_closed[0] + 0.4, z_closed[0], r'$\\text{Start}$', fontsize=22, color='k', zorder=30)\n",
    "ax.text(x_closed[-1] + 0.1, y_closed[-1] - 0.2, z_closed[-1], r'$\\text{End}$', fontsize=22, color='k', zorder=30)\n",
    "\n",
    "# View settings\n",
    "ax.view_init(elev=80, azim=240)\n",
    "ax.grid(False)\n",
    "\n",
    "# Limits\n",
    "ax.set_xlim(x.min(), x.max())\n",
    "ax.set_ylim(y.min() + 0.2, y.max())\n",
    "\n",
    "# ========== 6. Export ==========\n",
    "plt.tight_layout()\n",
    "# plt.savefig('../figs/1_b.png', bbox_inches='tight', dpi=200)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef0ff057-f512-47a2-98ab-1cb34b9eafe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data \n",
    "with open('../data/non_linear_closed.pkl', 'rb') as f:\n",
    "    loss, test_loss = pickle.load(f)\n",
    "\n",
    "with open('../data/non_linear_open.pkl', 'rb') as f:\n",
    "    loss2, test_loss2 = pickle.load(f)\n",
    "\n",
    "# --- Setup figure and subplots ---\n",
    "fig, (ax, ax1) = plt.subplots(\n",
    "    2, 1,\n",
    "    figsize=(9, 8.5),\n",
    "    sharex=False,\n",
    "    sharey=True,\n",
    "    gridspec_kw={'hspace': .4}\n",
    ")\n",
    "\n",
    "# ========== Upper plot: Closed-loop test & train loss ==========\n",
    "# -- Plot test loss curve (multi-colored by stage) --\n",
    "x = np.arange(1000)\n",
    "y = test_loss\n",
    "\n",
    "colors = sns.color_palette('mako_r')\n",
    "cmap = ListedColormap([colors[0], colors[2], colors[4]])\n",
    "\n",
    "points = np.array([x, y]).T.reshape(-1, 1, 2)\n",
    "segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
    "stage_index = np.zeros_like(x[:-1])\n",
    "stage_index[:8] = 0\n",
    "stage_index[8:129] = 1\n",
    "stage_index[129:] = 2\n",
    "\n",
    "# lc = LineCollection(segments, cmap=cmap, norm=Normalize(0, 2), linewidth=3)\n",
    "# lc.set_array(stage_index)\n",
    "# ax.add_collection(lc)\n",
    "ax.plot(x, np.array(test_loss)[:1000], color=sns.color_palette('deep')[0], lw=3)\n",
    "\n",
    "# -- Plot training loss on twin axis --\n",
    "ax_tw = ax.twinx()\n",
    "ax_tw.plot(x, np.array(loss)[:1000], '--k', lw=3)\n",
    "\n",
    "# -- Background stage shading --\n",
    "ax.axvspan(0, 12, color=colors[0], alpha=0.2)\n",
    "ax.axvspan(12, 110, color=colors[2], alpha=0.2)\n",
    "ax.axvspan(110, 1000, color=colors[4], alpha=0.2)\n",
    "\n",
    "# -- Stage labels and lines --\n",
    "ax.axvline(x=12, color='grey', ls='--')\n",
    "ax.axvline(x=110, color='grey', ls='--')\n",
    "ax.text(1.5, 10**5.4, r'$\\text{Stage 1}$', size=28)\n",
    "ax.text(19.1, 10**5.4, r'$\\text{Stage 2}$', size=28)\n",
    "ax.text(180, 10**5.4, r'$\\text{Stage 3}$', size=28)\n",
    "\n",
    "# -- Axis scaling and limits --\n",
    "for axis in [ax, ax_tw]:\n",
    "    axis.set_xscale('log')\n",
    "    axis.set_yscale('log')\n",
    "    axis.set_xlim(0.7, 1005)\n",
    "    axis.set_ylim(1, 10**6.5)\n",
    "\n",
    "ax.yaxis.set_major_locator(LogLocator(base=10.0, numticks=1))\n",
    "ax_tw.yaxis.set_major_locator(LogLocator(base=10.0, numticks=1))\n",
    "ax_tw.yaxis.set_tick_params(labelleft=False)\n",
    "ax_tw.xaxis.set_minor_locator(NullLocator())\n",
    "ax.xaxis.set_minor_locator(NullLocator())\n",
    "\n",
    "# -- Legend --\n",
    "ax.legend(\n",
    "    handles=[\n",
    "        Line2D([0], [0], color='k', linestyle='--', lw=4),\n",
    "        Line2D([0], [0], color=sns.color_palette('deep')[0], linestyle='-', lw=4)\n",
    "    ],\n",
    "    labels=[r'$\\text{Train}$', r'$\\text{Test}$'],\n",
    "    loc='lower left',\n",
    "    frameon=False,\n",
    "    fontsize=27\n",
    ")\n",
    "\n",
    "# ========== Lower plot: Open-loop ==========\n",
    "ax1.plot(x, test_loss2, ls='-', color=sns.color_palette('deep')[3], lw=4, label=r'$\\text{Open-Loop}$')\n",
    "ax1.axvspan(0, 1000, color='grey', alpha=0.05)\n",
    "\n",
    "# -- Twin y-axis for training loss --\n",
    "ax1_tw = ax1.twinx()\n",
    "ax1_tw.plot(x, np.array(loss2[:1000]), ls='--', color='k', lw=3)\n",
    "ax1_tw.set_xscale('log')\n",
    "ax1_tw.set_yscale('log')\n",
    "ax1_tw.set_yticks([10**0, 10**2])\n",
    "ax1_tw.yaxis.set_minor_locator(NullLocator())\n",
    "\n",
    "# -- Labels --\n",
    "ax1.text(2000, 10**5.2, r'$\\text{Train}\\,\\,$' + r'$\\text{Loss}$', size=32, rotation=90)\n",
    "ax1.text(0.25, 10**5.2, r'$\\text{Test}\\,\\,$' + r'$\\text{Loss}$', size=32, rotation=90)\n",
    "\n",
    "# -- Axis config --\n",
    "ax1.set_xlim(-1, 1010)\n",
    "ax1.set_xlabel(r'$\\text{Epoch}$', size=28, labelpad=-8)\n",
    "ax.set_xlabel(r'$\\text{Epoch}$', size=28, labelpad=-8)\n",
    "ax.set_yticks([10, 10**5])\n",
    "ax_tw.set_yticks([10, 10**5])\n",
    "\n",
    "# -- Font sizes --\n",
    "for a in [ax, ax_tw, ax1, ax1_tw]:\n",
    "    a.tick_params(axis='both', labelsize=24)\n",
    "\n",
    "ax1.tick_params(axis='both', labelsize=26)\n",
    "\n",
    "# -- Legend --\n",
    "ax1.legend(\n",
    "    handles=[\n",
    "        Line2D([0], [0], color='k', linestyle='--', lw=4),\n",
    "        Line2D([0], [0], color=sns.color_palette('deep')[3], linestyle='-', lw=4)\n",
    "    ],\n",
    "    labels=[r'$\\text{Train}$', r'$\\text{Test} \\,\\text{(closed-loop)}$'],\n",
    "    loc='lower left',\n",
    "    frameon=False,\n",
    "    fontsize=27\n",
    ")\n",
    "\n",
    "# ========== Cleanup and display ==========\n",
    "sns.despine()\n",
    "ax.spines['right'].set_visible(True)\n",
    "ax1_tw.spines['right'].set_visible(True)\n",
    "\n",
    "for axx in [ax, ax1]:\n",
    "    for label in axx.get_xticklabels():\n",
    "        label.set_y(-0.01)\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('../figs/1c.pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b906472f-7a7a-4cd7-8e1f-d154b83ba6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ========== 1. Evaluate loss over LQR gain grid ==========\n",
    "\n",
    "dt = 1\n",
    "x_target = np.array([0.0, 0.0])\n",
    "A = np.array([[1, dt], [0, 1]])\n",
    "B = np.array([[0], [dt]])\n",
    "\n",
    "Q = np.eye(2)\n",
    "R = np.array([[1e-8]])  # Small R to avoid divide-by-zero in Riccati\n",
    "\n",
    "P = scipy.linalg.solve_discrete_are(A, B, Q, R)\n",
    "K = np.linalg.inv(B.T @ P @ B + R) @ (B.T @ P @ A)\n",
    "K_LQR = K.copy()\n",
    "\n",
    "all_k0 = np.linspace(-1.2, .6, 50)\n",
    "all_k1 = np.linspace(-1.2, .6, 50)\n",
    "\n",
    "kk0, kk1, zzz = [], [], []\n",
    "\n",
    "for k0 in all_k0:\n",
    "    for k1 in all_k1:\n",
    "        K[0][0] = k0\n",
    "        K[0][1] = k1\n",
    "\n",
    "        num_steps = int(50 / dt)\n",
    "        x = np.array([1.0, 1.0])\n",
    "        loss = 0.0\n",
    "\n",
    "        for _ in range(num_steps):\n",
    "            u = K @ (x - x_target)\n",
    "            x = A @ x + B.flatten() * u\n",
    "            loss += np.sum((x - x_target)**2)\n",
    "\n",
    "        kk0.append(k0)\n",
    "        kk1.append(k1)\n",
    "        zzz.append(np.log(loss / num_steps))\n",
    "\n",
    "# ========== 2. Load learned gains from saved RNN controllers ==========\n",
    "with open('../data/non_linear_all_k_eff.pkl', 'rb') as f:\n",
    "    all_k = pickle.load(f)\n",
    "\n",
    "\n",
    "# ========== 3. Stability classification over (k₁, k₂) ==========\n",
    "\n",
    "all_a = np.linspace(-1.2, 0.6, 200)\n",
    "all_b = np.linspace(-1.2, 0.6, 200)\n",
    "all_stability = []\n",
    "\n",
    "for a in all_a:\n",
    "    for b in all_b:\n",
    "        P_red = np.array([[1, 1], [a, 1 + b]])\n",
    "        eigvals = np.linalg.eigvals(P_red)\n",
    "\n",
    "        if np.all(np.abs(eigvals) < 1):\n",
    "            label = 2  # Stable\n",
    "        elif np.any((np.abs(eigvals) > 1) & (np.imag(eigvals) != 0)):\n",
    "            label = 1  # Unstable (oscillatory)\n",
    "        else:\n",
    "            label = 0  # Unstable (real)\n",
    "\n",
    "        all_stability.append(label)\n",
    "\n",
    "X, Y = np.meshgrid(all_a, all_b)\n",
    "Z = np.array(all_stability).reshape(len(all_b), len(all_a)).T\n",
    "\n",
    "# ========== 4. Plotting: Loss, Trajectories, Stability ==========\n",
    "fig, ax = plt.subplots(1, 3, figsize=(15, 4.8))\n",
    "\n",
    "# === Panel (a): Test Loss ===\n",
    "ax[0].plot(np.arange(1000), test_loss, color=sns.color_palette('deep')[0], lw=3, label=r'$\\text{Closed-Loop}$')\n",
    "ax[0].plot(np.arange(1000), test_loss2, color=sns.color_palette('deep')[3], lw=3, label=r'$\\text{Open-Loop}$')\n",
    "ax[0].set(xscale='log', yscale='log', xlabel=r'$\\text{Epoch}$', ylabel=r\"$\\text{Loss}$\")\n",
    "ax[0].set_xticks([1, 10, 100, 1000])\n",
    "ax[0].set_yticks([10, 1000, 100_000])\n",
    "ax[0].legend(fontsize=18, loc='lower left', frameon=False)\n",
    "ax[0].xaxis.set_minor_locator(NullLocator()); ax[0].yaxis.set_minor_locator(NullLocator())\n",
    "\n",
    "# === Panel (b): Loss + Trajectories ===\n",
    "im = ax[1].imshow(np.array(zzz).reshape(50, 50).T, extent=[-1.2, .6, -1.2, .6], origin=\"lower\",\n",
    "                  aspect=\"auto\", cmap=\"mako_r\", alpha=0.8)\n",
    "for i in range(len(all_k)):\n",
    "    ax[1].plot(all_k[i,:,0], all_k[i,:,1], lw=2, linestyle='-', color=sns.color_palette('deep')[i * 3])\n",
    "ax[1].contour(X, Y, Z, levels=[0.5, 1.5], colors='k', linewidths=2)\n",
    "\n",
    "# Region annotations\n",
    "ax[1].text(0.05, -0.02, r'$\\text{Start}$', fontsize=17, color='k', zorder=30)\n",
    "ax[1].text(-0.34, -1.1, r'$\\text{End}$', fontsize=17, color='k', zorder=30)\n",
    "ax[1].text(-0.95, 0.01, r\"$\\text{Unstable}$\"+'\\n'+r\"$\\text{(osc.)}$\", fontsize=19, weight='bold', color='k', zorder=30)\n",
    "ax[1].text(0.05, -0.9, r\"$\\text{Unstable}$\"+'\\n'+r\"$\\text{(real)}$\", fontsize=19, weight='bold', color='k', zorder=30)\n",
    "ax[1].text(-0.95, -1.1, r\"$\\text{Stable}$\", fontsize=19, weight='bold', color='k', zorder=30)\n",
    "ax[1].set(xlabel=r\"$k_1$\", ylabel=r\"$k_2$\", xticks=[-.9, 0, .5], yticks=[-.9, 0, .5])\n",
    "\n",
    "# === Panel (c): Zoom-In ===\n",
    "im = ax[2].imshow(np.array(zzz).reshape(50, 50).T, extent=[-1.2, .6, -1.2, .6], origin=\"lower\",\n",
    "                  aspect=\"auto\", cmap=\"mako_r\", alpha=0.8)\n",
    "cbar = fig.colorbar(im, ax=ax[2], shrink=0.9)\n",
    "cbar.set_label(r'$\\text{Loss (log)}$', fontsize=18); cbar.ax.tick_params(labelsize=16)\n",
    "for i in range(len(all_k)):\n",
    "    ax[2].plot(all_k[i,:,0], all_k[i,:,1], '.-', lw=1, ms=4, color=sns.color_palette('deep')[i * 3])\n",
    "ax[2].set(xlim=(-0.06, 0.02), ylim=(-0.06, 0.04), xlabel=r\"$k_1$\", xticks=[-.04, 0], yticks=[-.04, 0, .04])\n",
    "\n",
    "# === Zoom box and connectors ===\n",
    "x0, x1 = -0.06, 0.02; y0, y1 = -0.06, 0.04\n",
    "zoom_box = patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor='black', facecolor='none', linestyle='--')\n",
    "ax[1].add_patch(zoom_box)\n",
    "fig.add_artist(ConnectionPatch(xyA=(0, 1), coordsA=\"axes fraction\", xyB=(x0, y1), coordsB=\"data\",\n",
    "                               axesA=ax[2], axesB=ax[1], arrowstyle=\"-\", lw=1, linestyle='--', color=\"black\"))\n",
    "fig.add_artist(ConnectionPatch(xyA=(0, 0), coordsA=\"axes fraction\", xyB=(x0, y0), coordsB=\"data\",\n",
    "                               axesA=ax[2], axesB=ax[1], arrowstyle=\"-\", lw=1, linestyle='--', color=\"black\"))\n",
    "\n",
    "# === Arrows: trajectory movement ===\n",
    "ax[1].add_patch(FancyArrowPatch((-0.06, -0.3), (-.27, -.85), connectionstyle=\"arc3,rad=0.0\",\n",
    "                                arrowstyle=\"-|>\", lw=2, color=sns.color_palette('deep')[0], mutation_scale=15))\n",
    "ax[1].add_patch(FancyArrowPatch((-0.3, -.1), (-.5, -.85), connectionstyle=\"arc3,rad=0.6\",\n",
    "                                arrowstyle=\"-|>\", lw=2, color=sns.color_palette('deep')[3], mutation_scale=15))\n",
    "\n",
    "# === Markers and Annotations ===\n",
    "ax[0].scatter(24, test_loss[24], marker='o', color='silver', s=150, zorder=30, edgecolors='k', linewidths=1.5)\n",
    "ax[0].scatter(110, test_loss[110], marker='o', color='gold', s=150, zorder=30, edgecolors='k', linewidths=1.5)\n",
    "ax[2].scatter(all_k[0, 24, 0]-0.001, all_k[0, 24, 1], color='silver', s=150, zorder=30, edgecolors='k', linewidths=1.5)\n",
    "ax[2].scatter(all_k[0, 110, 0]-0.0006, all_k[0, 110, 1]+0.001, color='gold', s=150, zorder=30, edgecolors='k', linewidths=1.5)\n",
    "\n",
    "# === Stage text (loss panel) ===\n",
    "ax[0].annotate(r'$\\text{Stage 1}$', xy=(5, test_loss[5]+1e3), xytext=(2.1, test_loss[5]+1e5),\n",
    "               arrowprops=dict(arrowstyle='-|>', color='k', lw=1.5), fontsize=18, color='k', zorder=30)\n",
    "ax[0].annotate(r'$\\text{Stage 2}$', xy=(60, test_loss[50]+2), xytext=(19, test_loss[50]+10**3.8),\n",
    "               arrowprops=dict(arrowstyle='-|>', color='k', lw=2), fontsize=18, color='k', zorder=35)\n",
    "ax[0].annotate(r'$\\text{Stage 3}$', xy=(190, test_loss[190]), xytext=(185, test_loss[190]+10**2.2),\n",
    "               arrowprops=dict(arrowstyle='-|>', color='k', lw=2), fontsize=18, color='k', zorder=32)\n",
    "ax[0].annotate(r'$\\text{Peak}$', xy=(210, 10**5.6), xytext=(280, 10**5.35),\n",
    "               arrowprops=dict(arrowstyle='-', lw=0, color='k'), fontsize=18, color='k', zorder=30)\n",
    "\n",
    "# === Stage text (zoom panel) ===\n",
    "x1, y1 = all_k[0, 16]; ax[2].annotate(r'$\\text{Stage 1}$', xy=(x1+0.0022, y1), xytext=(x1-0.012, y1+0.015),\n",
    "                                     arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5),\n",
    "                                     fontsize=19, color='k', zorder=30)\n",
    "ax[2].annotate(r'$\\text{Stage 2}$', xy=(-0.024, -0.014), xytext=(-0.058, -0.017),\n",
    "               arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5), fontsize=19, color='k', zorder=30)\n",
    "ax[2].annotate(r'$\\text{(zig-zag)}$', xy=(-0.024, -0.014), xytext=(-0.058, -0.025),\n",
    "               arrowprops=dict(arrowstyle='-', lw=0, color='black'), fontsize=16, color='k', zorder=30)\n",
    "x2 = all_k[0, 120, 0]-0.0006; y2 = all_k[0, 120, 1]+0.001\n",
    "ax[2].annotate(r'$\\text{Stage 3}$', xy=(x2, y2-0.002), xytext=(x2+0.0008, y2-0.015),\n",
    "               arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5), fontsize=19, color='k', zorder=30)\n",
    "\n",
    "# === Subplot labels and lines ===\n",
    "for i, label in enumerate([r'($\\mathbf{a}$)', r'($\\mathbf{b}$)', r'($\\mathbf{c}$)']):\n",
    "    ax[i].text(-0.13, 1.15, label, transform=ax[i].transAxes, fontsize=24, va='top', ha='right', font=\"Arial\")\n",
    "\n",
    "ax[0].set_title(r'$\\text{Test Loss}$', size=22, pad=10)\n",
    "ax[1].set_title(r'$\\text{Closed-loop - Loss & Stability}$', size=22, pad=10)\n",
    "ax[2].set_title(r'$\\text{Zoom-In}$', size=22, pad=10)\n",
    "ax[2].axvline(x=0, color='k', lw=3)\n",
    "ax[2].plot([0, -0.06], [0, -0.06], color='k', lw=3)\n",
    "\n",
    "fig.align_xlabels(ax)\n",
    "plt.tight_layout()\n",
    "sns.despine(ax=ax[0])\n",
    "# plt.savefig('../figs/fig_2.pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e73313-1dc6-4205-a8ea-a43f2930b3bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Setup ===\n",
    "N = 100\n",
    "num_epochs = 1000\n",
    "batch_size = 100\n",
    "system = k_integrator_torch(k=2, dt=1, c=1, m=1)\n",
    "\n",
    "# === Load teacher model training history ===\n",
    "with open('../data/data_closed.pkl', 'rb') as f:\n",
    "    loss, model_his_parameters, grad = pickle.load(f)\n",
    "\n",
    "# === Compute test loss over epochs ===\n",
    "test_loss = np.zeros((1, num_epochs))\n",
    "x0 = np.random.uniform(-2, 2, (1, batch_size))\n",
    "x1 = np.random.uniform(-2, 2, (1, batch_size))\n",
    "h0 = np.zeros((model_his_parameters[0].N, batch_size))\n",
    "\n",
    "for ep in range(num_epochs):\n",
    "    model = model_his_parameters[ep]\n",
    "    P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "\n",
    "    h = torch.tensor(np.vstack((x0, x1, h0)), dtype=torch.float32)\n",
    "    traj = []\n",
    "    for _ in range(50):\n",
    "        traj.append(h[:2]**2)\n",
    "        h = P @ h\n",
    "\n",
    "    test_loss[0, ep] = torch.sum(torch.stack(traj), axis=1).mean().item()\n",
    "\n",
    "# === Eigenvalue tracking of full P matrix ===\n",
    "all_P = [\n",
    "    create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random).detach().numpy()\n",
    "    for model in model_his_parameters\n",
    "]\n",
    "eigvals = np.linalg.eig(np.array(all_P))[0]\n",
    "stab = np.all(np.abs(eigvals) < 1, axis=1).astype(int)\n",
    "\n",
    "# === Find stabilization epoch ===\n",
    "ep_of_stab = next((i for i in range(num_epochs) if stab[i] == 1 and np.all(stab[i:] == 1)), num_epochs + 1)\n",
    "\n",
    "# === Collect trajectories from key epochs ===\n",
    "selected_epochs = [8, 140, 999]\n",
    "all_traj = []\n",
    "\n",
    "for ep in selected_epochs:\n",
    "    model = model_his_parameters[ep]\n",
    "    P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "\n",
    "    num_of_s = 300 if ep < 200 else 50  # Longer trace for earlier epochs\n",
    "    x0 = np.array([1])\n",
    "    x1 = np.array([-1])\n",
    "    h0 = np.zeros((100, 1))\n",
    "\n",
    "    h = torch.tensor(np.vstack((x0, x1, h0)), dtype=torch.float32)\n",
    "    traj = []\n",
    "\n",
    "    for _ in range(num_of_s):\n",
    "        traj.append(h[:2])\n",
    "        h = P @ h\n",
    "\n",
    "    all_traj.append(torch.stack(traj).squeeze().detach().numpy()[:, 0])\n",
    "\n",
    "# === Log-transformed loss trajectory segments ===\n",
    "loss = test_loss[0]\n",
    "\n",
    "x1 = np.arange(0, 12, 2)\n",
    "y1 = np.log(loss[:12:2])\n",
    "\n",
    "x2 = np.arange(12, ep_of_stab, 2)\n",
    "y2 = np.log(loss[12:ep_of_stab:2])\n",
    "\n",
    "x3 = np.arange(ep_of_stab, num_epochs, 2)\n",
    "y3 = np.log(loss[ep_of_stab::2])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b424838-a49d-4319-9e17-c198599c28d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 5, figsize=(17, 4.3), width_ratios=[1, 1, 1, 1.3, 1])\n",
    "\n",
    "# === (a) Test Loss with stage annotations ===\n",
    "ax[0].plot(np.arange(9), test_loss[0, :9], color=sns.color_palette('deep')[0], lw=4)\n",
    "ax[0].plot(np.arange(8, 129), test_loss[0, 8:129], color=sns.color_palette('deep')[7], lw=4, alpha=0.1)\n",
    "ax[0].plot(np.arange(129, 1000), test_loss[0, 129:], color=sns.color_palette('deep')[7], lw=4, alpha=0.1)\n",
    "\n",
    "ax[0].axvspan(0, 8, color=sns.color_palette('mako_r')[0], alpha=0.4)\n",
    "ax[0].axvspan(8, 129, color='grey', alpha=0.1)\n",
    "ax[0].axvspan(129, 1000, color='grey', alpha=0.1)\n",
    "\n",
    "ax[0].axvline(x=8, color='grey', ls='--')\n",
    "ax[0].axvline(x=129, color='grey', ls='--')\n",
    "\n",
    "ax[0].set(title=r\"$\\text{Stage 1}$\", xlabel=r'$\\text{Epoch}$', ylabel=r\"$\\text{Loss}$\",\n",
    "          xscale='log', yscale='log', xlim=(-10, 1010), ylim=(0.11, 10**10))\n",
    "ax[0].set_xticks([1, 10, 100, 1000])\n",
    "ax[0].tick_params(axis='both', labelsize=16)\n",
    "\n",
    "# === (b) Trajectory ===\n",
    "ax[1].plot(all_traj[0], lw=3, color=sns.color_palette('mako_r')[0])\n",
    "ax[1].axhline(0, color='grey', linestyle='--', linewidth=2)\n",
    "ax[1].axvline(x=50, ls='--', color='crimson')\n",
    "ax[1].set(title=r\"$\\text{Trajectory}$\", xlabel=r\"$\\text{Time-Step}$\", ylabel=r\"$\\text{Position}$\")\n",
    "ax[1].set_xticks([50, 200])\n",
    "ax[1].set_yticks([-50, 0, 50])\n",
    "ax[1].tick_params(axis='both', which='major', width=2, length=6)\n",
    "\n",
    "# === (c) eig(P) spectrum ===\n",
    "eig, _ = np.linalg.eig(all_P[4])\n",
    "ax[2].scatter(eig.real, eig.imag, s=70, color=sns.color_palette('mako_r')[0], alpha=0.5)\n",
    "\n",
    "eig, _ = np.linalg.eig(all_P[10])\n",
    "ax[2].scatter(eig.real, eig.imag, s=70, color=sns.color_palette('mako_r')[0],\n",
    "              edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "eig, _ = np.linalg.eig(all_P[5:10])\n",
    "ax[2].scatter(eig.real, eig.imag, s=70, color=sns.color_palette('mako_r')[0], alpha=0.5)\n",
    "\n",
    "circle = np.linspace(0, 2 * np.pi, 100)\n",
    "ax[2].plot(np.cos(circle), np.sin(circle), color='grey', ls='--', lw=2.5)\n",
    "\n",
    "ax[2].axhline(0, color='grey', lw=1)\n",
    "ax[2].axvline(0, color='grey', lw=1)\n",
    "ax[2].spines['top'].set_visible(False)\n",
    "ax[2].spines['right'].set_visible(False)\n",
    "ax[2].set(title=r'$\\text{eig}(\\mathbfit{P})$', xlabel=r'$\\text{Real}$', ylabel=r'$\\text{Imag}.$',\n",
    "          xlim=(0.89, 1.11), ylim=(-0.33, 0.33))\n",
    "\n",
    "# Eig arrows\n",
    "arrow3a = FancyArrowPatch((1.085, .01), (1.005, .15), connectionstyle=\"arc3,rad=-0.35\",\n",
    "                          arrowstyle=\"-|>\", lw=2, color='black', mutation_scale=20, zorder=30)\n",
    "arrow3b = FancyArrowPatch((0.915, -.01), (.995, -.15), connectionstyle=\"arc3,rad=-0.35\",\n",
    "                          arrowstyle=\"-|>\", lw=2, color='black', mutation_scale=20, zorder=30)\n",
    "ax[2].add_patch(arrow3a)\n",
    "ax[2].add_patch(arrow3b)\n",
    "\n",
    "# === Trajectory simulation and loss ===\n",
    "def run_traj(P, num_steps=50):\n",
    "    x0 = np.array([[1]])\n",
    "    x1 = np.array([[-1]])\n",
    "    h0 = np.zeros((model.N, 1))\n",
    "    h = torch.tensor(np.vstack((x0, x1, h0)), dtype=torch.float32)\n",
    "\n",
    "    traj = []\n",
    "    for _ in range(num_steps):\n",
    "        traj.append(h[:2].clone())\n",
    "        h = P @ h\n",
    "\n",
    "    return torch.cat(traj, dim=1).T.detach().numpy()  \n",
    "\n",
    "\n",
    "num_steps = 100\n",
    "N = 100\n",
    "\n",
    "zm_pos = np.linspace(0.0001, 0.3, 100)\n",
    "loss_sim_pos, loss_the_pos = [], []\n",
    "\n",
    "for cur_zm in zm_pos:\n",
    "    model = P_Model(N=N, g=0.0, system=system, over=cur_zm)\n",
    "    P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "    traj = run_traj(P, num_steps)\n",
    "    loss = np.sum(0.5 * (traj[:, 0]**2 + traj[:, 1]**2))\n",
    "    loss_sim_pos.append(loss)\n",
    "\n",
    "    T = num_steps + 1\n",
    "    r = (np.abs(1 + np.sqrt(cur_zm + 0j))) ** 2\n",
    "    loss_the_pos.append(0.5 * (1 - r**T) / (1 - r))\n",
    "\n",
    "zm_neg = np.linspace(-0.5, -0.0001, 100)\n",
    "loss_sim_neg, loss_the_neg = [], []\n",
    "\n",
    "for cur_zm in zm_neg:\n",
    "    model = P_Model(N=N, g=0.0, system=system, over=cur_zm)\n",
    "    P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "    traj = run_traj(P, num_steps)\n",
    "    loss = np.sum(0.5 * (traj[:, 0]**2 + traj[:, 1]**2))\n",
    "    loss_sim_neg.append(loss)\n",
    "\n",
    "    T = num_steps + 1\n",
    "    r = (np.abs(1 + np.sqrt(cur_zm + 0j))) ** 2\n",
    "    loss_the_neg.append(0.5 * (1 - r**T) / (1 - r))\n",
    "\n",
    "# === (d) Plot theoretical vs simulation loss ===\n",
    "ax[3].plot(zm_pos, np.log(loss_sim_pos), color=sns.color_palette('mako_r')[0], lw=4, label=r'$\\text{Simulation}$')\n",
    "ax[3].plot(zm_pos, np.log(loss_the_pos), '--k', lw=4, label=r'$\\text{Theory}$')\n",
    "ax[3].plot(zm_neg, np.log(loss_sim_neg), color=sns.color_palette('mako_r')[0], lw=4)\n",
    "ax[3].plot(zm_neg, np.log(loss_the_neg), '--k', lw=4)\n",
    "\n",
    "ax[3].legend(fontsize=18, loc='upper left', frameon=False)\n",
    "ax[3].set(title=r'$\\text{Theoretical}$', xlabel=r'$ \\sigma_{zm} $', ylabel=r\"$\\text{Loss}$\")\n",
    "ax[3].set_xticks([-0.3, 0, 0.3])\n",
    "\n",
    "# === (e) Overlap vs epoch for different initial σ_zm ===\n",
    "zm_over_time = []\n",
    "all_zm_vals = np.linspace(-0.2, 0.2, 10)\n",
    "colors = sns.color_palette('mako_r', n_colors=len(all_zm_vals))\n",
    "\n",
    "for i, cur_zm in enumerate(all_zm_vals):\n",
    "    system = k_integrator_torch(k=2, dt=1, c=1, m=1)\n",
    "    model = P_Model(N=100, g=0.0, system=system, over=cur_zm)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "    loss, _model_his_parameters, grad = train_model_p(\n",
    "        model, optimizer, teacher=None, white_noise=None,\n",
    "        w_grad_clip=True, num_epochs=50, batch_size=100,\n",
    "        num_steps=50, clamp=False)\n",
    "\n",
    "    all_over = [(m.Z @ m.M).item() for m in _model_his_parameters]\n",
    "    ax[4].plot(all_over[::2], label=cur_zm, color=colors[i], lw=2)\n",
    "    ax[4].axhline(y=0, color='grey', ls='--', lw=2)\n",
    "    zm_over_time.append(all_over)\n",
    "\n",
    "ax[4].set(title=r'$\\text{Empirical}$', xlabel=r'$\\text{Epoch}$', ylabel=r'$ \\sigma_{zm} $')\n",
    "\n",
    "# === Subplot labels ===\n",
    "subplot_labels = [r'($\\mathbf{a}$)', r'($\\mathbf{b}$)', r'($\\mathbf{c}$)', r'($\\mathbf{d}$)', r'($\\mathbf{e}$)']\n",
    "for i, label in enumerate(subplot_labels):\n",
    "    ax[i].text(-0.2, 1.25, label, transform=ax[i].transAxes,\n",
    "               fontsize=24, fontweight='bold', va='top', ha='right', font=\"Arial\")\n",
    "\n",
    "# === Final display ===\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "# plt.savefig('../figs/fig_3.pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a160eee6-2dcc-45a2-a916-bafe89f766fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Load teacher model training history ===\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.\n",
    "# The pickle is unpacked as a (loss, model_his_parameters, grad) triple.\n",
    "with open('../data/data_closed.pkl', 'rb') as f:\n",
    "    loss, model_his_parameters, grad = pickle.load(f)\n",
    "\n",
    "# Dict holding the arrays 'eig_vs', 'eig_s', 'eig_b', 'eig_vb'; their real and\n",
    "# imaginary parts are plotted on ax[4] further down in this cell.\n",
    "with open('../data/eig_all_dict.pkl', 'rb') as f:\n",
    "    eig_all_dict = pickle.load(f)\n",
    "\n",
    "eig_vs = eig_all_dict['eig_vs']\n",
    "eig_s = eig_all_dict['eig_s']\n",
    "eig_b = eig_all_dict['eig_b']\n",
    "eig_vb = eig_all_dict['eig_vb']\n",
    "\n",
    "\n",
    "# One row of five equally wide panels.\n",
    "fig, ax = plt.subplots(1, 5, figsize=(17, 4.3), width_ratios=[1, 1, 1, 1, 1])\n",
    "\n",
    "# Shared palette; the plotting helpers defined below read it as a module-level name.\n",
    "palette = sns.color_palette('mako_r')\n",
    "# --------- Plot: Loss Curve --------- #\n",
    "def plot_loss_curve(ax, test_loss):\n",
    "    \"\"\"Plot row 0 of `test_loss` against epoch on log-log axes, highlighting epochs 8-129.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib Axes to draw on.\n",
    "    test_loss : 2-D array-like; only row 0 is used. The last segment pairs\n",
    "        test_loss[0, 129:] with np.arange(129, 1000), so row 0 must hold 1000 values.\n",
    "\n",
    "    Reads the module-level `palette` for the colour of the highlighted band.\n",
    "    \"\"\"\n",
    "    deep = sns.color_palette('deep')\n",
    "    faded = dict(color=deep[7], lw=4, alpha=0.1)\n",
    "\n",
    "    # (epochs, values, line style) -- only the middle segment is drawn at full strength.\n",
    "    segments = [\n",
    "        (np.arange(9), test_loss[0, :9], faded),\n",
    "        (np.arange(8, 129), test_loss[0, 8:129], dict(color=deep[0], lw=4)),\n",
    "        (np.arange(129, 1000), test_loss[0, 129:], faded),\n",
    "    ]\n",
    "    for epochs, values, line_style in segments:\n",
    "        ax.plot(epochs, values, **line_style)\n",
    "\n",
    "    # Background bands, in the same left-to-right order as the segments.\n",
    "    bands = [\n",
    "        (0, 8, dict(color='grey', alpha=0.1)),\n",
    "        (8, 129, dict(color=palette[2], alpha=0.4)),\n",
    "        (129, 1000, dict(color='grey', alpha=0.1)),\n",
    "    ]\n",
    "    for left, right, band_style in bands:\n",
    "        ax.axvspan(left, right, **band_style)\n",
    "    for boundary in (8, 129):\n",
    "        ax.axvline(boundary, color='grey', ls='--')\n",
    "\n",
    "    ax.set(title=r\"$\\text{Stage 2}$\", xscale='log', yscale='log',\n",
    "           xlim=(-10, 1010), ylim=(0.11, 1e10),\n",
    "           xlabel=r'$\\text{Epoch}$', ylabel=r\"$\\text{Loss}$\")\n",
    "    ax.set_xticks([1, 10, 100, 1000])\n",
    "    ax.tick_params(axis='both', labelsize=16)\n",
    "\n",
    "    # Title and axis labels are applied again, this time with explicit font sizes.\n",
    "    ax.set_title(r'$\\text{Stage 2} $', size=26)\n",
    "    ax.set_xlabel(r'$\\text{Epoch}$', size=22)\n",
    "    ax.set_ylabel(r'$\\text{Loss}$', size=22)\n",
    "\n",
    "# --------- Plot: Trajectory --------- #\n",
    "def plot_trajectory(ax, traj):\n",
    "    \"\"\"Plot `traj` against its index with a dashed zero line.\n",
    "\n",
    "    ax   : matplotlib Axes to draw on.\n",
    "    traj : sequence handed straight to `ax.plot`.\n",
    "\n",
    "    Reads the module-level `palette` for the curve colour.\n",
    "    \"\"\"\n",
    "    curve_color = palette[2]\n",
    "    ax.plot(traj, color=curve_color, lw=3)\n",
    "    ax.axhline(0, linestyle='--', color='grey', lw=2)\n",
    "\n",
    "    ax.set(title=r\"$\\text{Trajectory}$\", xlabel=r\"$\\text{Time-Step}$\", ylabel=r\"$\\text{Position}$\")\n",
    "    ax.set_yticks([-5, 0, 5])\n",
    "    ax.tick_params(axis='both', which='major', width=2, length=6)\n",
    "\n",
    "    # Same texts once more, now with explicit font sizes.\n",
    "    sized_text = (\n",
    "        (ax.set_title, r\"$\\text{Trajectory}$\", 26),\n",
    "        (ax.set_xlabel, r\"$\\text{Time-Step}$\", 22),\n",
    "        (ax.set_ylabel, r\"$\\text{Position}$\", 22),\n",
    "    )\n",
    "    for apply_text, text, font_size in sized_text:\n",
    "        apply_text(text, size=font_size)\n",
    "\n",
    "# --------- Plot: Eigenvalues with Zigzag Arrow --------- #\n",
    "def plot_eig_with_arrow(ax, all_P, ep_of_stab):\n",
    "    \"\"\"Scatter eigenvalues of selected `all_P` entries near the unit circle and add a zigzag arrow.\n",
    "\n",
    "    ax         : matplotlib Axes to draw on.\n",
    "    all_P      : indexable stack of square matrices; a slice of it is passed to\n",
    "                 np.linalg.eigvals as one batch, so it is presumably a NumPy array.\n",
    "    ep_of_stab : index of the entry drawn with a black outline; entries\n",
    "                 12 .. ep_of_stab - 3 are drawn faded.\n",
    "\n",
    "    Reads the module-level `palette`.\n",
    "    \"\"\"\n",
    "    faded = dict(s=50, alpha=0.2)\n",
    "\n",
    "    # Outlined marker for the entry at ep_of_stab, drawn on top (zorder=30).\n",
    "    final_eigs = np.linalg.eigvals(all_P[ep_of_stab])\n",
    "    ax.scatter(final_eigs.real, final_eigs.imag, s=50, color=palette[2],\n",
    "               edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "    # Entry 10, faded.\n",
    "    start_eigs = np.linalg.eigvals(all_P[10])\n",
    "    ax.scatter(start_eigs.real, start_eigs.imag, color=sns.color_palette('mako_r')[2], **faded)\n",
    "\n",
    "    # Entries in between, faded (batched eigenvalue computation).\n",
    "    between_eigs = np.linalg.eigvals(all_P[12:ep_of_stab - 2])\n",
    "    ax.scatter(between_eigs.real, between_eigs.imag, color=palette[2], **faded)\n",
    "\n",
    "    # Coordinate axes and the unit circle.\n",
    "    ax.axhline(0, color='grey', lw=1)\n",
    "    ax.axvline(0, color='grey', lw=1)\n",
    "    theta = np.linspace(0, 2*np.pi, 100)\n",
    "    ax.plot(np.cos(theta), np.sin(theta), color='grey', ls='--', lw=2.5)\n",
    "\n",
    "    # Hand-placed waypoints of the zigzag arrow, in data coordinates.\n",
    "    waypoints = [(0.999, 0.158), (0.998, 0.148), (0.997, 0.158), (0.996, 0.148),\n",
    "                 (0.995, 0.158), (0.994, 0.148), (0.993, 0.158), (0.992, 0.148),\n",
    "                 (0.991, 0.158), (0.990, 0.14985), (0.989, 0.158), (0.9863, 0.15)]\n",
    "    codes = [Path.MOVETO] + [Path.LINETO] * (len(waypoints) - 1)\n",
    "    ax.add_patch(FancyArrowPatch(path=Path(waypoints, codes), arrowstyle='-|>', color='black',\n",
    "                                 mutation_scale=15, lw=2, zorder=30))\n",
    "\n",
    "    ax.set(title=r'$\\text{eig}(\\mathbfit{P})$', xlim=(0.98, 1.005), ylim=(0.12, 0.2),\n",
    "           xlabel=r'$\\text{Real}$', ylabel=r'$\\text{Imag}.$')\n",
    "    ax.set_yticks([0.13, 0.2])\n",
    "    for side in ('top', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "\n",
    "    # Title and labels once more, with explicit font sizes.\n",
    "    ax.set_title(r'$\\text{eig}(\\mathbfit{P})$', size=26)\n",
    "    ax.set_ylabel(r'$\\text{Imag}.$', size=22)\n",
    "    ax.set_xlabel(r'$\\text{Real}$', size=22)\n",
    "\n",
    "# --------- Model Definition --------- #\n",
    "class ReduceModel(nn.Module):\n",
    "    \"\"\"Four trainable scalars (a, b, c, d) whose forward pass returns a scalar loss.\n",
    "\n",
    "    The loss is alpha * r + (1 - alpha) * n, where r and n are the polynomial\n",
    "    terms computed in `forward`; `alpha` is a fixed, non-trainable weight.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, a=0, b=0.01, c=0, d=0, alpha=0.01):\n",
    "        super().__init__()\n",
    "        # Registered in the order a, b, c, d so parameters() yields them in that order.\n",
    "        for name, init in (('a', a), ('b', b), ('c', c), ('d', d)):\n",
    "            setattr(self, name, nn.Parameter(torch.tensor(init, dtype=torch.float32)))\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self):\n",
    "        \"\"\"Return alpha * r + (1 - alpha) * n for the current parameter values.\"\"\"\n",
    "        a, b, c, d = self.a, self.b, self.c, self.d\n",
    "        cd = c * d\n",
    "        lam = a + cd / (a**2 - 2*a - b + 1)\n",
    "        r = (2*a - b + 1 + (-a - 2)*lam + lam**2) ** 10\n",
    "        n = (b + 1)**2 + (b + 3)**2 + (cd + 2*b)**2 + (cd + 3*b + 1)**2\n",
    "        weight = self.alpha\n",
    "        return weight * r + (1 - weight) * n\n",
    "\n",
    "# --------- Optimization Loop --------- #\n",
    "def optimize_and_track(a, b, c, d, alpha, lr, color, ax):\n",
    "    \"\"\"Run clipped SGD on a ReduceModel and plot the path of one root of its cubic.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a, b, c, d : initial values of the four ReduceModel parameters.\n",
    "    alpha      : loss mixing weight forwarded to ReduceModel; it also selects the\n",
    "                 number of SGD steps (300 for 0.48, 190 for 0.55, otherwise 240).\n",
    "    lr         : SGD learning rate.\n",
    "    color      : colour of the trajectory line and of its end marker.\n",
    "    ax         : matplotlib Axes to draw on.\n",
    "\n",
    "    Returns None; the function only draws on `ax`.\n",
    "    \"\"\"\n",
    "    model = ReduceModel(a, b, c, d, alpha)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "    # Hand-tuned step budgets for two specific alpha values.\n",
    "    if alpha == 0.48:\n",
    "        allep = 300\n",
    "    elif alpha == 0.55:\n",
    "        allep = 190\n",
    "    else:\n",
    "        allep = 240\n",
    "\n",
    "    # Only the four scalars are needed afterwards, so record them directly instead\n",
    "    # of deep-copying the whole module at every step.\n",
    "    param_history = []\n",
    "    for _ in range(allep):\n",
    "        loss = model()\n",
    "        # Snapshot taken before the update, so entry 0 is the initial point.\n",
    "        param_history.append((model.a.item(), model.b.item(), model.c.item(), model.d.item()))\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "    # For every snapshot, solve the cubic and keep the first root np.roots returns.\n",
    "    # NOTE(review): np.roots does not document a root ordering; index 0 assumes the\n",
    "    # tracked root stays first along the whole trajectory -- confirm on the figure.\n",
    "    z = []\n",
    "    for a_t, b_t, c_t, d_t in param_history:\n",
    "        coeffs = [1, -a_t - 2, 2*a_t - b_t + 1, a_t*b_t - a_t - c_t*d_t]\n",
    "        z.append(np.roots(coeffs)[0])\n",
    "\n",
    "    ax.plot(np.real(z), np.imag(z), color=color, lw=4)\n",
    "    # Crimson marker at the starting root, trajectory-coloured marker at the final one.\n",
    "    ax.scatter(z[0].real, z[0].imag, marker='o', color='crimson', s=100, edgecolor='k', linewidths=2, zorder=30)\n",
    "    ax.scatter(z[-1].real, z[-1].imag, s=50, color=color, edgecolor='k', linewidths=2, zorder=30)\n",
    "        \n",
    "\n",
    "# -------------------- Run Everything ------------------- #\n",
    "# test_loss, all_traj, all_P and ep_of_stab are not defined in this cell; they must\n",
    "# already exist in the kernel when it runs.\n",
    "plot_loss_curve(ax[0], test_loss)\n",
    "plot_trajectory(ax[1], all_traj[1])\n",
    "plot_eig_with_arrow(ax[2], all_P, ep_of_stab)\n",
    "\n",
    "# Params from model\n",
    "# Snapshot 10 of the loaded history is reduced to four scalar products:\n",
    "# a = V.U, b = Z.M, c = V.M, d = Z.U (each vector flattened first).\n",
    "model = model_his_parameters[10]\n",
    "m, z, u, v = model.M.detach().numpy().flatten(), model.Z.detach().numpy().flatten(), model.U.detach().numpy().flatten(), model.V.detach().numpy().flatten()\n",
    "a, b, c, d = v @ u, z @ m, v @ m, z @ u\n",
    "\n",
    "\n",
    "\n",
    "# Same starting point for every run; only alpha and the line colour change.\n",
    "optimize_and_track(a, b, c, d, alpha=0.0,   lr=.001, color=palette[0], ax=ax[3])\n",
    "optimize_and_track(a, b, c, d, alpha=0.48, lr=.001, color=palette[1], ax=ax[3])\n",
    "optimize_and_track(a, b, c, d, alpha=0.52, lr=.001, color=palette[2], ax=ax[3])\n",
    "optimize_and_track(a, b, c, d, alpha=0.55, lr=.001, color=palette[4], ax=ax[3])\n",
    "optimize_and_track(a, b, c, d, alpha=1.0,   lr=.001, color=palette[5], ax=ax[3])\n",
    "\n",
    "\n",
    "# --------- Labels --------- #\n",
    "# Panel letters (a)-(e), placed above the top-left corner in axes coordinates.\n",
    "subplot_labels = [r'($\\mathbf{a}$)', r'($\\mathbf{b}$)', r'($\\mathbf{c}$)', r'($\\mathbf{d}$)', r'($\\mathbf{e}$)']\n",
    "for i, label in enumerate(subplot_labels):\n",
    "    ax[i].text(-0.2, 1.25, label, transform=ax[i].transAxes,\n",
    "               fontsize=24, fontweight='bold', va='top', ha='right', font=\"Arial\")\n",
    "\n",
    "\n",
    "# Upward arrow on ax[3], labelled alpha=0 (data coordinates).\n",
    "arrow3 = FancyArrowPatch((1.002, 0.165), (1.002, 0.23), connectionstyle=\"arc3,rad=0\", arrowstyle=\"-|>\", lw=2.5,\n",
    "                         color=palette[0], mutation_scale=20, zorder=30)\n",
    "ax[3].add_patch(arrow3)\n",
    "ax[3].text(x=1.001,y=0.23,s=r'$\\alpha=0$',color='k',zorder=30)\n",
    "\n",
    "# Downward arrow on ax[3], labelled alpha=1; the name arrow3 is reused.\n",
    "arrow3 = FancyArrowPatch((1.002, 0.165), (1.002, 0.09), connectionstyle=\"arc3,rad=0\", arrowstyle=\"-|>\", lw=2.5,\n",
    "                         color=palette[5], mutation_scale=20, zorder=30)\n",
    "ax[3].add_patch(arrow3)\n",
    "ax[3].text(x=1.001,y=0.08,s=r'$\\alpha=1$',color='k',zorder=30)\n",
    "\n",
    "\n",
    "\n",
    "# --------- Final Touch --------- #\n",
    "# Coordinate axes, unit circle and zoomed limits for ax[3].\n",
    "ax[3].axhline(0, color='grey', lw=1)\n",
    "ax[3].axvline(0, color='grey', lw=1)\n",
    "ax[3].spines['top'].set_visible(False)\n",
    "ax[3].spines['right'].set_visible(False)\n",
    "ax[3].plot(np.cos(np.linspace(0, 2*np.pi, 100)),\n",
    "           np.sin(np.linspace(0, 2*np.pi, 100)), color='grey', ls='--', lw=2.5)\n",
    "ax[3].set_xlim(0.97, 1.01)\n",
    "ax[3].set_ylim(0.06, 0.25)\n",
    "\n",
    "# ax[3].set_title(r'$\\text{eig}(\\mathbfit{P})$')\n",
    "ax[3].set_ylabel(r'$\\text{Imag}.$',size=22)\n",
    "ax[3].set_xlabel(r'$\\text{Real}$',size=22)\n",
    "ax[3].set_yticks([0.13, 0.25])\n",
    "ax[3].set_title(r'$\\text{Theoretical}$',size=26)\n",
    "\n",
    "#############################################\n",
    "# Panel on ax[4]: column 0 of each loaded eigenvalue array is drawn as a line in\n",
    "# the complex plane, next to the eigenvalues computed here from all_P.\n",
    "# eig, _ = np.linalg.eig(vs_all_P[:108]) \n",
    "ax[4].plot(eig_vs.real[:,0], eig_vs.imag[:,0],color=sns.color_palette('mako_r')[0],lw=3) #s=50, \n",
    "\n",
    "\n",
    "\n",
    "# eig, _ = np.linalg.eig(s_all_P[:108]) \n",
    "ax[4].plot(eig_s.real[:,0], eig_s.imag[:,0],color=sns.color_palette('mako_r')[1],lw=3) #s=50, \n",
    "# Outlined marker for row 106 of eig_s.\n",
    "# NOTE(review): the row index is hard-coded as 106 below rather than using idx.\n",
    "for idx in [106]:\n",
    "        # eigvals = np.linalg.eigvals(s_all_P[idx])\n",
    "        eigvals = eig_s[106]\n",
    "        ax[4].scatter(eigvals.real, eigvals.imag, s=50, color=palette[1],\n",
    "                   edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "\n",
    "# Eigenvalues of all_P[10:ep_of_stab]; the eigenvectors returned by eig are discarded.\n",
    "eig, _ = np.linalg.eig(all_P[10:ep_of_stab]) \n",
    "ax[4].plot(eig.real[:,0], eig.imag[:,0], color=sns.color_palette('mako_r')[2],lw=3) #s=50, \n",
    "\n",
    "# Outlined marker for the entry at ep_of_stab.\n",
    "for idx in [ep_of_stab]:\n",
    "        eigvals = np.linalg.eigvals(all_P[idx])\n",
    "        ax[4].scatter(eigvals.real, eigvals.imag, s=50, color=palette[2],\n",
    "                   edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "# Crimson marker for entry 10 of all_P.\n",
    "eigvals = np.linalg.eigvals(all_P[10])\n",
    "ax[4].scatter(eigvals.real, eigvals.imag, marker='o',color='crimson',s=100,edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "    \n",
    "# eig, _ = np.linalg.eig(b_all_P[:111]) \n",
    "ax[4].plot(eig_b.real[:,0], eig_b.imag[:,0], color=sns.color_palette('mako_r')[4],lw=3)\n",
    "# Outlined marker for row 111 of eig_b.\n",
    "for idx in [111]:\n",
    "        # eigvals = np.linalg.eigvals(b_all_P[idx])\n",
    "        eigvals = eig_b[idx]\n",
    "        ax[4].scatter(eigvals.real, eigvals.imag, s=50, color=palette[4],\n",
    "                   edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "# eig, _ = np.linalg.eig(vb_all_P[:13]) \n",
    "ax[4].plot(eig_vb.real[:,0], eig_vb.imag[:,0],color=sns.color_palette('mako_r')[5],lw=3) #s=50, \n",
    "\n",
    "# Coordinate axes, unit circle and the same zoomed limits as ax[3].\n",
    "ax[4].axhline(0, color='grey', lw=1)\n",
    "ax[4].axvline(0, color='grey', lw=1)\n",
    "ax[4].spines['top'].set_visible(False)\n",
    "ax[4].spines['right'].set_visible(False)\n",
    "ax[4].plot(np.cos(np.linspace(0, 2*np.pi, 100)),\n",
    "           np.sin(np.linspace(0, 2*np.pi, 100)), color='grey', ls='--', lw=2.5)\n",
    "ax[4].set_xlim(0.97, 1.01)\n",
    "ax[4].set_ylim(0.06, 0.25)\n",
    "\n",
    "ax[4].set_title(r'$\\text{Empirical}$',size=26) \n",
    "ax[4].set_ylabel(r'$\\text{Imag}.$',size=22)\n",
    "ax[4].set_xlabel(r'$\\text{Real}$',size=22)\n",
    "ax[4].set_yticks([0.13, 0.25])\n",
    "\n",
    "\n",
    "from matplotlib.colors import Normalize\n",
    "from matplotlib.cm import ScalarMappable\n",
    "# Colorbar setup\n",
    "\n",
    "# --- Colorbar on ax[4] (labelled T) ---\n",
    "# The mappable is not tied to any plotted artist; it only feeds the colorbar.\n",
    "norm = Normalize(vmin=-10, vmax=120)\n",
    "cmap = sns.color_palette(\"mako_r\", as_cmap=True)\n",
    "sm = ScalarMappable(norm=norm, cmap=cmap)\n",
    "\n",
    "# Inset axes inside ax[4] (lower left, vertical)\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "\n",
    "cax = inset_axes(ax[4],\n",
    "                 width=\"6%\",  # width: 6% of parent_axes\n",
    "                 height=\"40%\",  # height: 40% of parent_axes\n",
    "                 loc='lower left',\n",
    "                 bbox_to_anchor=(0.03, 0.03, 1, 1),\n",
    "                 bbox_transform=ax[4].transAxes)\n",
    "\n",
    "cbar = plt.colorbar(sm, cax=cax, orientation='vertical')\n",
    "cbar.ax.tick_params(labelsize=14)\n",
    "cbar.set_ticks([10,   100])\n",
    "cbar.set_label(r\"$T$\", rotation=0, labelpad=-8, fontsize=22,va='center')\n",
    "#############################################\n",
    "\n",
    "\n",
    "# --- Colorbar on ax[3] (labelled alpha) ---\n",
    "# norm, cmap, sm, cax and cbar are rebound here; the ax[4] colorbar is already drawn.\n",
    "norm = Normalize(vmin=0, vmax=1)\n",
    "cmap = sns.color_palette(\"mako_r\", as_cmap=True)\n",
    "sm = ScalarMappable(norm=norm, cmap=cmap)\n",
    "\n",
    "# Inset axes inside ax[3] (lower left, vertical)\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "\n",
    "cax = inset_axes(ax[3],\n",
    "                 width=\"6%\",  # width: 6% of parent_axes\n",
    "                 height=\"40%\",  # height: 40% of parent_axes\n",
    "                 loc='lower left',\n",
    "                 bbox_to_anchor=(0.03, 0.03, 1, 1),\n",
    "                 bbox_transform=ax[3].transAxes)\n",
    "\n",
    "\n",
    "\n",
    "cbar = plt.colorbar(sm, cax=cax, orientation='vertical')\n",
    "cbar.ax.tick_params(labelsize=14)\n",
    "cbar.set_ticks([0.1,   0.8])\n",
    "cbar.set_label(r\"$\\alpha$\", rotation=0, labelpad=-8, fontsize=22,va='center')\n",
    "\n",
    "\n",
    "# Upward arrow on ax[4], labelled T -> 0 (data coordinates).\n",
    "arrow3 = FancyArrowPatch((1.002, 0.165), (1.002, 0.23), connectionstyle=\"arc3,rad=0\", arrowstyle=\"-|>\", lw=2.5,\n",
    "                         color=palette[0], mutation_scale=20, zorder=30)\n",
    "ax[4].add_patch(arrow3)\n",
    "ax[4].text(x=1.001,y=0.23,s=r'$T \\to 0$',color='k',zorder=30,size=14)\n",
    "\n",
    "# Downward arrow on ax[4], labelled T -> infinity.\n",
    "arrow3 = FancyArrowPatch((1.002, 0.165), (1.002, 0.09), connectionstyle=\"arc3,rad=0\", arrowstyle=\"-|>\", lw=2.5,\n",
    "                         color=palette[5], mutation_scale=20, zorder=30)\n",
    "ax[4].add_patch(arrow3)\n",
    "ax[4].text(x=1.001,y=0.075,s=r'$T \\to \\infty$',color='k',zorder=30,size=14)\n",
    "\n",
    "# Dashed vertical marker at x = 50 on the trajectory panel.\n",
    "ax[1].axvline(x=50,ls='--',color='crimson')\n",
    "\n",
    "ax[1].set_xticks([50,200])\n",
    "\n",
    "# tight_layout runs before despine, as in the other figure cells.\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "# plt.savefig('../figs/fig_4.pdf',bbox_inches=None)# \n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e1e714a-3162-4a48-8941-2e7af96fcdc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Beta sweep for the full model: each run starts from snapshot 130 of the saved\n",
    "# history, trains for num_epochs and records when the P matrix becomes stable.\n",
    "arr_betas = [0, .5, 2, 5]\n",
    "N = 100\n",
    "dic_h_d = {}\n",
    "system = k_integrator_torch(k=2,dt=1,c=1,m=1,)\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.\n",
    "with open('../data/data_closed.pkl', 'rb') as f:\n",
    "    org_loss, model_his_parameters, grad = pickle.load(f)\n",
    "\n",
    "# Taken before the loop: model_his_parameters is rebound by every training run below.\n",
    "init_model  = model_his_parameters[130]\n",
    "\n",
    "for beta in arr_betas:\n",
    "\n",
    "    model = P_Model(N=N, g=0, rank=1,  system=system) # Instantiate the model\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # Instantiate the optimizer\n",
    "\n",
    "    # Copy the starting weights in place; copy_ accepts a plain tensor, so no\n",
    "    # temporary nn.Parameter wrapper is needed.\n",
    "    with torch.no_grad():\n",
    "        model.U.copy_(init_model.U.detach())\n",
    "        model.V.copy_(init_model.V.detach())\n",
    "        model.M.copy_(init_model.M.detach())\n",
    "        model.Z.copy_(init_model.Z.detach())\n",
    "\n",
    "    num_epochs = 1000\n",
    "    loss,model_his_parameters, grad = train_model_p(model,\n",
    "                                              optimizer,\n",
    "                                              teacher=None,\n",
    "                                              white_noise=None,\n",
    "                                              beta=beta,\n",
    "                                              w_grad_clip=True,\n",
    "                                              num_epochs=num_epochs,\n",
    "                                              batch_size=100,\n",
    "                                              num_steps=50,\n",
    "                                              clamp=False)\n",
    "\n",
    "    # Rebuild the P matrix of every snapshot.\n",
    "    _all_P = []\n",
    "    for ep in range(num_epochs):\n",
    "        model = model_his_parameters[ep]\n",
    "        P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "        _all_P.append(P.detach().numpy())\n",
    "\n",
    "    # An epoch counts as stable when every eigenvalue lies strictly inside the unit circle.\n",
    "    eig,_ = np.linalg.eig(np.array(_all_P))\n",
    "    stab = np.all(np.abs(eig) < 1, axis=1).astype(int)\n",
    "\n",
    "    # First epoch from which every remaining epoch is stable; num_epochs + 1 marks a\n",
    "    # run that never settles. for/else is used because the search itself raises\n",
    "    # nothing: a try/except fallback can never fire, which would leave a stale\n",
    "    # ep_of_stab from an earlier run (or no value at all) in the results.\n",
    "    for i in range(num_epochs):\n",
    "        if stab[i:].all():\n",
    "            ep_of_stab = i\n",
    "            break\n",
    "    else:\n",
    "        ep_of_stab = num_epochs + 1\n",
    "\n",
    "    dic_h_d[beta] = [loss, model_his_parameters, grad, _all_P, ep_of_stab]\n",
    "\n",
    "        \n",
    "\n",
    "\n",
    "# Beta sweep for the effective model: same starting snapshot, reduced to the four\n",
    "# scalar products z@m, z@u, v@m, v@u that initialise P_Model_eff.\n",
    "dic_l_d = {}\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.\n",
    "with open('../data/data_closed.pkl', 'rb') as f:\n",
    "    loss, model_his_parameters, grad = pickle.load(f)\n",
    "\n",
    "# Taken before the loop: model_his_parameters is rebound by every training run below.\n",
    "init_model  = model_his_parameters[130]\n",
    "\n",
    "for beta in arr_betas:\n",
    "\n",
    "    model = init_model\n",
    "    m, z, u, v = model.M.detach().numpy().flatten(), model.Z.detach().numpy().flatten(), model.U.detach().numpy().flatten(), model.V.detach().numpy().flatten()\n",
    "    model = P_Model_eff(init_sig_zm=z@m, init_sig_zu=z@u, init_sig_vm=v@m, init_sig_vu=v@u)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # Instantiate the optimizer\n",
    "    num_epochs = 1000\n",
    "    batch_size = 100\n",
    "    loss, model_his_parameters, grad = train_model_p_eff(model, optimizer, teacher=None, white_noise=None, beta=beta,\n",
    "                                                   w_grad_clip=1, num_epochs=num_epochs, batch_size=batch_size, num_steps=50, clamp=False)\n",
    "\n",
    "    # Effective P matrix of every snapshot.\n",
    "    _all_P = []\n",
    "    for ep in range(num_epochs):\n",
    "        P = create_p_effective(model_his_parameters[ep])\n",
    "        _all_P.append(P)\n",
    "\n",
    "    # An epoch counts as stable when every eigenvalue lies strictly inside the unit circle.\n",
    "    eig,_ = np.linalg.eig(np.array(_all_P))\n",
    "    stab = np.all(np.abs(eig) < 1, axis=1).astype(int)\n",
    "\n",
    "    # First epoch from which every remaining epoch is stable; num_epochs + 1 marks a\n",
    "    # run that never settles. for/else is used because the search itself raises\n",
    "    # nothing: a try/except fallback can never fire, which would leave a stale\n",
    "    # ep_of_stab from an earlier run (or no value at all) in the results.\n",
    "    for i in range(num_epochs):\n",
    "        if stab[i:].all():\n",
    "            ep_of_stab = i\n",
    "            break\n",
    "    else:\n",
    "        ep_of_stab = num_epochs + 1\n",
    "\n",
    "    dic_l_d[beta] = [loss, model_his_parameters, grad, _all_P, ep_of_stab]\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2371fe4d-20ce-422c-a12a-dd3e4bb362a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four panels; the last one is twice as wide as the others.\n",
    "# test_loss and all_traj are not defined in this cell; they must already exist in the kernel.\n",
    "fig, ax = plt.subplots(1, 4, figsize=(17, 4.3), width_ratios=[1, 1, 1, 2])\n",
    "\n",
    "# === Panel (a): Test Loss ===\n",
    "# Row 0 of test_loss in three segments; only the last one (epoch 129 onwards) is\n",
    "# drawn at full strength. np.arange(129, 1000) requires row 0 to hold 1000 values.\n",
    "ax[0].plot(np.arange(9),     test_loss[0, :9],     color=sns.color_palette('deep')[7], lw=4, alpha=0.1)\n",
    "ax[0].plot(np.arange(8, 129), test_loss[0, 8:129], color=sns.color_palette('deep')[7], lw=4, alpha=0.1)\n",
    "ax[0].plot(np.arange(129, 1000), test_loss[0, 129:], color=sns.color_palette('deep')[0], lw=4)\n",
    "\n",
    "# Background bands matching the three segments.\n",
    "ax[0].axvspan(0, 8, color='grey', alpha=0.1)\n",
    "ax[0].axvspan(8, 129, color='grey', alpha=0.1)\n",
    "ax[0].axvspan(129, 1000, color=sns.color_palette('mako_r')[4], alpha=0.4)\n",
    "\n",
    "ax[0].axvline(x=8, color='grey', ls='--')\n",
    "ax[0].axvline(x=129, color='grey', ls='--')\n",
    "\n",
    "# NOTE(review): the left x-limit below is negative although the x-axis is logarithmic.\n",
    "ax[0].set_xscale('log')\n",
    "ax[0].set_yscale('log')\n",
    "ax[0].set_xlim(-10, 1010)\n",
    "ax[0].set_ylim(0.11, 1e10)\n",
    "ax[0].set_title(r\"$\\text{Stage 3}$\", size=26)\n",
    "ax[0].set_xlabel(r'$\\text{Epoch}$', size=22)\n",
    "ax[0].set_ylabel(r\"$\\text{Loss}$\", size=22)\n",
    "ax[0].set_xticks([1, 10, 100, 1000])\n",
    "ax[0].tick_params(axis='both', labelsize=16)\n",
    "\n",
    "# === Panel (b): Trajectory ===\n",
    "# Entry 2 of all_traj with a dashed zero line and a dashed vertical marker at x = 50.\n",
    "ax[1].plot(all_traj[2], lw=3, color=sns.color_palette('mako_r')[4])\n",
    "ax[1].axhline(0, color='grey', ls='--', lw=2)\n",
    "# ax[1].set_xticks([50, 200])\n",
    "ax[1].set_yticks([-3, 0, 3])\n",
    "ax[1].tick_params(axis='both', width=2, length=6)\n",
    "\n",
    "ax[1].set_title(r\"$\\text{Trajectory}$\", size=26)\n",
    "ax[1].set_xlabel(r\"$\\text{Time-Step}$\", size=22)\n",
    "ax[1].set_ylabel(r\"$\\text{Position}$\", size=22)\n",
    "ax[1].axvline(x=50, ls='--', color='crimson')\n",
    "\n",
    "# === Panel (c): eig(P) ===\n",
    "# all_P is not defined in this cell; it must already exist in the kernel. The\n",
    "# eigenvectors returned by np.linalg.eig are discarded throughout.\n",
    "# Plot earlier eigenvalues (light)\n",
    "eig, _ = np.linalg.eig(all_P[130])\n",
    "ax[2].scatter(eig.real, eig.imag, s=50, color=sns.color_palette('mako_r')[4], alpha=0.2)\n",
    "\n",
    "# Every 40th entry from 130 onwards: eigenvalue columns 0 and 1.\n",
    "eig, _ = np.linalg.eig(all_P[130:][::40])\n",
    "ax[2].scatter(eig.real[:, :2], eig.imag[:, :2], s=50, color=sns.color_palette('mako_r')[4], alpha=0.1)\n",
    "\n",
    "# Every 4th entry in 130..179: eigenvalue column 2.\n",
    "# NOTE(review): selecting by column assumes eig returns the eigenvalues in a\n",
    "# consistent order across entries -- confirm.\n",
    "eig, _ = np.linalg.eig(all_P[130:180][::4])\n",
    "ax[2].scatter(eig.real[:, 2], eig.imag[:, 2], s=50, color=sns.color_palette('mako_r')[4], alpha=0.2)\n",
    "\n",
    "# Final stable eigenvalues (highlighted)\n",
    "eig, _ = np.linalg.eig(all_P[-1])\n",
    "ax[2].scatter(eig.real, eig.imag, s=50, color=sns.color_palette('mako_r')[4], edgecolor='k', linewidths=2, zorder=30)\n",
    "\n",
    "# Circle for stability\n",
    "theta = np.linspace(0, 2 * np.pi, 100)\n",
    "ax[2].plot(np.cos(theta), np.sin(theta), color='grey', ls='--', lw=2.5, alpha=1)\n",
    "\n",
    "ax[2].axhline(0, color='grey', lw=1)\n",
    "ax[2].axvline(0, color='grey', lw=1)\n",
    "ax[2].spines['top'].set_visible(False)\n",
    "ax[2].spines['right'].set_visible(False)\n",
    "\n",
    "ax[2].set_xlim(-0.3, 1.2)\n",
    "ax[2].set_ylim(-1.05, 1.05)\n",
    "ax[2].set_title(r'$\\text{eig}(\\mathbfit{P})$', size=26)\n",
    "ax[2].set_xlabel(r'$\\text{Real}$', size=22)\n",
    "ax[2].set_ylabel(r'$\\text{Imag}.$', size=22)\n",
    "\n",
    "# Arrows showing movement in eig space\n",
    "# Hand-placed in data coordinates; the two curved arrows use zorder=0.\n",
    "ax[2].add_patch(FancyArrowPatch((0.1, 0.0), (0.65, 0.0), connectionstyle=\"arc3,rad=0\", arrowstyle=\"-|>\", lw=2, color='black', mutation_scale=20, zorder=40))\n",
    "ax[2].add_patch(FancyArrowPatch((0.95,  0.2), (0.5, 0.5), connectionstyle=\"arc3,rad=.25\", arrowstyle=\"-|>\", lw=2, color='black', mutation_scale=15, zorder=0))\n",
    "ax[2].add_patch(FancyArrowPatch((0.95, -0.2), (0.5, -0.5), connectionstyle=\"arc3,rad=-.25\", arrowstyle=\"-|>\", lw=2, color='black', mutation_scale=15, zorder=0))\n",
    "\n",
    "# === Panel (d): Loss comparison across β ===\n",
    "# NOTE(review): `colors`, `ScalarMappable` and `inset_axes` are not defined or\n",
    "# imported in this cell; they come from earlier cells, so those must run first.\n",
    "# `labels_plotted` is assigned but never read in this cell.\n",
    "skipy = 1\n",
    "labels_plotted = False\n",
    "\n",
    "# Element 0 of each dictionary entry is the loss history stored by the sweep cell.\n",
    "# The label= arguments are not shown: the legend below passes its own handles.\n",
    "for idx, beta in enumerate(arr_betas):\n",
    "    color = colors[[0, 1, 2, 4][idx]]\n",
    "    \n",
    "    ax[3].plot(dic_h_d[beta][0][::skipy], color=color, label=r'$P$' if idx > 0 else None)\n",
    "    ax[3].plot(dic_l_d[beta][0][::skipy], ls='--', color='k', lw=1.5, label=r'$P_{\\text{eff}}$' if idx > 0 else None)\n",
    "\n",
    "# Set log scales\n",
    "ax[3].set_xscale('log')\n",
    "ax[3].set_yscale('log')\n",
    "\n",
    "# Disable minor ticks\n",
    "ax[3].xaxis.set_minor_locator(NullLocator())\n",
    "ax[3].yaxis.set_minor_locator(NullLocator())\n",
    "\n",
    "# Axis labels and title\n",
    "ax[3].set_xlabel(r'$\\text{Epoch}$', size=22)\n",
    "ax[3].set_ylabel(r\"$\\text{Loss}$\", size=22)\n",
    "ax[3].set_title(r\"$\\text{Theoretical & Empirical}$\", size=26)\n",
    "\n",
    "# === Legend ===\n",
    "# Proxy handles: solid black for the dic_h_d curves, dashed black for the dic_l_d curves.\n",
    "ax[3].legend(\n",
    "    handles=[\n",
    "        Line2D([0], [0], color='k', linestyle='-', lw=2, zorder=30),\n",
    "        Line2D([0], [0], color='k', linestyle='--', lw=2, zorder=30)\n",
    "    ],\n",
    "    labels=[r'$\\text{Simulation}$', r'$\\text{Theory}$'],\n",
    "    loc='lower left',\n",
    "    frameon=False,\n",
    "    fontsize=18\n",
    ")\n",
    "\n",
    "# === Colorbar indicating β values ===\n",
    "# The mappable is not tied to the plotted lines; it only feeds the colorbar.\n",
    "norm = Normalize(vmin=0, vmax=5)\n",
    "cmap = sns.color_palette(\"mako_r\", as_cmap=True)\n",
    "sm = ScalarMappable(norm=norm, cmap=cmap)\n",
    "\n",
    "\n",
    "cax = inset_axes(\n",
    "    ax[3], width=\"30%\", height=\"7%\", loc='lower left',\n",
    "    bbox_to_anchor=(0.05, 0.37, 1, 1), bbox_transform=ax[3].transAxes\n",
    ")\n",
    "cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')\n",
    "cbar.ax.tick_params(labelsize=14)\n",
    "cbar.set_ticks([0, 5])\n",
    "cbar.set_label(r\"$\\beta$\", rotation=0, labelpad=-45, fontsize=20, va='center')\n",
    "\n",
    "# === Subplot labels (a), (b), ... ===\n",
    "subplot_labels = [r'($\\mathbf{a}$)', r'($\\mathbf{b}$)', r'($\\mathbf{c}$)', r'($\\mathbf{d}$)']\n",
    "for i, label in enumerate(subplot_labels):\n",
    "    ax[i].text(\n",
    "        -0.15, 1.1, label,\n",
    "        transform=ax[i].transAxes,\n",
    "        fontsize=24, fontweight='bold',\n",
    "        va='top', ha='right', font=\"Arial\"\n",
    "    )\n",
    "\n",
    "# === Layout and Display ===\n",
    "\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "# plt.savefig('../figs/fig_5.pdf',bbox_inches=None)# 'tight'\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da97b2c-cb68-4503-ab52-115104258c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time axis: n_samples points on [0, total_time), end point excluded.\n",
    "total_time = 30.0    # total time\n",
    "sampling_rate = 10 # 10   # Hz\n",
    "n_samples = int(total_time * sampling_rate)\n",
    "t = np.linspace(0, total_time, n_samples, endpoint=False)\n",
    "\n",
    "# Load the saved arrays\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.\n",
    "with open('../data/closed_loop_tracking_data.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "ref = data['ref']\n",
    "net_ot = data['net_ot']\n",
    "freqs_to_title = data['freqs_to_title']\n",
    "all_eigs = data['all_eigs']\n",
    "loss = data['loss']\n",
    "\n",
    "\n",
    "# `data` is rebound to the second file from here on.\n",
    "with open('../data/saved_arrays.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "r_x = data['r_x']\n",
    "r_y = data['r_y']\n",
    "states_x = data['states_x']\n",
    "states_y = data['states_y']\n",
    "test_losses = data['test_losses']\n",
    "# NOTE(review): this value is replaced a few lines below by the contents of\n",
    "# adam_all_loss.pkl before it is used anywhere in this cell -- confirm which one is intended.\n",
    "all_loss = data['2_frq_loss']\n",
    "\n",
    "\n",
    "file_path = \"../data/adam_all_loss.pkl\"\n",
    "with open(file_path, \"rb\") as f:  # Note the 'rb' for reading binary\n",
    "    all_loss = pickle.load(f)\n",
    "\n",
    "\n",
    "# Row window [S, E) of all_loss.\n",
    "S,E = 900,1100\n",
    "\n",
    "# Extract segments\n",
    "# Columns 0 and 1 of all_loss over that window.\n",
    "data_om1 = all_loss[S:E, 0]\n",
    "data_om2 = all_loss[S:E, 1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04feb680-5bb9-4c76-8088-844de378e13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Figure and GridSpec ===\n",
    "fig = plt.figure(figsize=(18, 3.6))\n",
    "gs = gridspec.GridSpec(\n",
    "    4, 5,\n",
    "    width_ratios=[0.7, 1.0, 0.3, 0.2, 0.7],\n",
    "    height_ratios=[1, 1, 1, 1],\n",
    "    wspace=0.28, hspace=0.62\n",
    ")\n",
    "\n",
    "# === (a) Tracking Panel ===\n",
    "ax_track = plt.subplot(gs[:, 0])\n",
    "start, end = 48, 125\n",
    "x, y = states_x[start:end], states_y[start:end]\n",
    "\n",
    "ax_track.plot(r_x[start:end], r_y[start:end], color='k', lw=5, label=r'$\\text{Ref.}$')\n",
    "ax_track.plot(x, y, linestyle='--', color=sns.color_palette('deep')[0], lw=3, label=r'$\\text{RNN}$')\n",
    "ax_track.scatter(r_x[start], r_y[start], c=sns.color_palette('deep')[7], s=100, edgecolors='k', zorder=30)\n",
    "ax_track.scatter(x[0], y[0], c=sns.color_palette('deep')[0], s=100, edgecolors='k', zorder=30)\n",
    "ax_track.scatter(x[-1], y[-1], c=sns.color_palette('deep')[3], s=100, edgecolors='k', zorder=30)\n",
    "ax_track.text(-0.3, 0.63, r'$\\text{Start}$')\n",
    "ax_track.text(2.2, 2.5, r'$\\text{End}$')\n",
    "\n",
    "ax_track.set(\n",
    "    xlabel=r\"$x$\", ylabel=r\"$y$\",\n",
    "    xticks=[-4, 0, 4], yticks=[-4, 0, 4],\n",
    "    xlim=[-5.2, 5.2], ylim=[-5.2, 5.2]\n",
    ")\n",
    "ax_track.set_title(r'$\\text{Tracking Task}$', fontsize=18, pad=20)\n",
    "ax_track.legend(fontsize=16, loc='lower left', frameon=False, ncols=2)\n",
    "ax_track.tick_params(axis='both', labelsize=18)\n",
    "sns.despine(ax=ax_track)\n",
    "\n",
    "# === (b) Loss Plot ===\n",
    "ax_loss = plt.subplot(gs[:, 1])\n",
    "ax_loss.plot(loss, color=sns.color_palette('deep')[0], lw=3, label=r'$\\text{Closed-Loop}$')\n",
    "ax_loss.plot(test_losses, color=sns.color_palette('deep')[3], lw=3, label=r'$\\text{Open-Loop}$')\n",
    "\n",
    "highlight_epochs = [230, 1200, 2600, 4500]\n",
    "colors_idx = [1, 2, 4, 9]\n",
    "omega_labels = [r'$\\omega_{1}$', r'$\\omega_{2}$', r'$\\omega_{3}$', r'$\\omega_{4}$']\n",
    "y_offsets = [0.27, 0.18, 0.1, 0.03]\n",
    "\n",
    "for ep, idx, lbl, offset in zip(highlight_epochs, colors_idx, omega_labels, y_offsets):\n",
    "    ax_loss.scatter(ep, loss[ep], color=sns.color_palette('deep')[idx], s=50, edgecolors='k', zorder=36)\n",
    "    ax_loss.text(ep - 100, loss[ep] + offset, lbl, fontsize=20, zorder=30)\n",
    "\n",
    "ax_loss.set(\n",
    "    xscale='log', yscale='log',\n",
    "    xlabel=r'$\\text{Epoch}$', ylabel=r\"$\\text{Loss}$\",\n",
    "    xticks=[1, 10, 100, 1000, 10000],\n",
    "    xlim=(0, 10000)\n",
    ")\n",
    "ax_loss.xaxis.set_minor_locator(NullLocator())\n",
    "ax_loss.yaxis.set_minor_locator(NullLocator())\n",
    "ax_loss.set_title(r'$\\text{Test Loss}$', fontsize=18, pad=20)\n",
    "ax_loss.legend(fontsize=16, loc='lower left', frameon=False)\n",
    "ax_loss.tick_params(axis='both', labelsize=16)\n",
    "sns.despine(ax=ax_loss)\n",
    "\n",
    "# === (c) Small Trajectories ===\n",
    "small_axes = [plt.subplot(gs[i, 2]) for i in range(4)]\n",
    "for i, ax in enumerate(small_axes):\n",
    "    ax.plot(t, ref[i], '--', color='k', lw=1.5)\n",
    "    ax.plot(t, net_ot[i], color=sns.color_palette('deep')[colors_idx[i]], lw=2)\n",
    "    ax.set_xlim(0, 30.1)\n",
    "    ax.set_xticks([] if i < 3 else [0, 30])\n",
    "    ax.set_yticks([-3, 3])\n",
    "    ax.set_ylabel(r'$x$' if i % 2 == 0 else r'$y$', fontsize=18, labelpad=-5)\n",
    "    if i == 3:\n",
    "        ax.set_xlabel(r'$\\text{Time-Step}$', fontsize=18, labelpad=9)\n",
    "    ax.text(0.45, 1.1, omega_labels[i], transform=ax.transAxes, fontsize=16)\n",
    "    ax.tick_params(axis='both', labelsize=14)\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "small_axes[0].set_title(r'$\\text{Trajectory}$', fontsize=18, pad=20)\n",
    "\n",
    "# === (d) Eigenvalue Panels ===\n",
    "tiny_axes = [plt.subplot(gs[i, 3]) for i in range(4)]\n",
    "for i, ax in enumerate(tiny_axes):\n",
    "    eig = all_eigs[i]\n",
    "    ax.scatter(eig.real, eig.imag, color=sns.color_palette('deep')[colors_idx[i]], s=14)\n",
    "    ax.axhline(0, color='k', lw=2)\n",
    "    ax.axvline(0, color='k', lw=2)\n",
    "    ax.axhline(freqs_to_title[i], color='k', alpha=0.8)\n",
    "    ax.set(\n",
    "        xlim=(-2, 2), ylim=(-1, 3),\n",
    "        yticks=[round(freqs_to_title[i], 2)],\n",
    "        yticklabels=[omega_labels[i]]\n",
    "    )\n",
    "    if i == 3:\n",
    "        ax.set_xticks([-1, 1])\n",
    "        ax.set_xlabel(r'$\\text{Re}$', fontsize=18, labelpad=12)\n",
    "    else:\n",
    "        ax.set_xticks([])\n",
    "    ax.set_ylabel(r'$\\text{Im}$', fontsize=18, labelpad=-2)\n",
    "    ax.tick_params(axis='both', labelsize=14)\n",
    "    sns.despine(ax=ax)\n",
    "\n",
    "tiny_axes[0].set_title(r'$\\text{eig}(\\mathbfit{W})$', fontsize=18, pad=20)\n",
    "\n",
    "# === (e) Zoom-in Plateau ===\n",
    "medium_axes = [plt.subplot(gs[0:2, 4]), plt.subplot(gs[2:4, 4])]\n",
    "\n",
    "# Left y-axis\n",
    "medium_axes[0].plot(np.arange(S, E), data_om2, color=sns.color_palette('deep')[2], label=r'$\\omega_2$')\n",
    "medium_axes[0].set_ylabel(r'$\\text{Loss}\\,\\,\\,\\omega_2$', fontsize=15, labelpad=11)\n",
    "\n",
    "# Right y-axis\n",
    "ax2 = medium_axes[0].twinx()\n",
    "ax2.plot(np.arange(S, E), data_om1, color=sns.color_palette('deep')[1], label=r'$\\omega_1$')\n",
    "ax2.set_ylabel(r'$\\text{Loss}\\,\\,\\,\\omega_1$', fontsize=15)\n",
    "\n",
    "# Styling\n",
    "medium_axes[0].set_xlabel(r'$\\text{Epoch}$', fontsize=15)\n",
    "medium_axes[0].set_title(r'$\\text{Zoom-In Plateau}$', fontsize=16)\n",
    "medium_axes[0].tick_params(axis='both', labelsize=14)\n",
    "ax2.tick_params(axis='both', labelsize=11)\n",
    "\n",
    "# Move both axes up\n",
    "pos1, pos2 = medium_axes[0].get_position(), ax2.get_position()\n",
    "new_bottom = pos1.y0 + 0.12\n",
    "new_height = pos1.height * 0.85\n",
    "medium_axes[0].set_position([pos1.x0, new_bottom, pos1.width, new_height])\n",
    "ax2.set_position([pos2.x0, new_bottom, pos2.width, new_height])\n",
    "sns.despine(ax=medium_axes[0])\n",
    "sns.despine(ax=ax2)\n",
    "medium_axes[0].spines['right'].set_visible(True)\n",
    "ax2.spines['right'].set_visible(True)\n",
    "\n",
    "# === (f) Frequency Dependent Learning ===\n",
    "medium_axes[1].plot(np.array(highlight_epochs) / 10000, '.-', ms=12, color=sns.color_palette('deep')[0], label=r'$\\text{RNN}$')\n",
    "medium_axes[1].plot(np.array([2.5, 4, 6, 9.5]) / 20, '.-', ms=12, color=sns.color_palette('deep')[2], label=r'$\\text{Data}$')\n",
    "medium_axes[1].set(\n",
    "    xticks=[0, 1, 2, 3],\n",
    "    xticklabels=omega_labels,\n",
    "    ylim=(-0.05, 0.5),\n",
    "    yticks=[0.1, 0.5],\n",
    "    xlabel=r'$\\text{Frequency}$',\n",
    "    ylabel=r'$\\text{Norm. Acq. Time}$'\n",
    ")\n",
    "medium_axes[1].tick_params(axis='both', labelsize=14)\n",
    "medium_axes[1].legend(fontsize=12, loc='lower right', frameon=False)\n",
    "medium_axes[1].set_title(r'$\\text{Freq. Dependent Learning}$', fontsize=15)\n",
    "sns.despine(ax=medium_axes[1])\n",
    "\n",
    "# === Subplot Labels ===\n",
    "labels = [r'($\\mathbf{a}$)', r'($\\mathbf{b}$)', r'($\\mathbf{c}$)', r'($\\mathbf{d}$)', r'($\\mathbf{e}$)', r'($\\mathbf{f}$)']\n",
    "positions = [\n",
    "    (ax_track, (-0.05, 1.2)),\n",
    "    (ax_loss, (-0.05, 1.2)),\n",
    "    (small_axes[0], (-0.2, 2.2)),\n",
    "    (tiny_axes[0], (-0.25, 2.2)),\n",
    "    (medium_axes[0], (-0.10, 1.295)),\n",
    "    (medium_axes[1], (-0.10, 1.4))\n",
    "]\n",
    "for ax, (x, y), label in zip(*zip(*positions), labels):\n",
    "    ax.text(x, y, label, transform=ax.transAxes, fontsize=20, fontweight='bold', va='top', ha='right', font=\"Arial\")\n",
    "\n",
    "# Arrow annotation pointing to (e)\n",
    "ax_loss.annotate(\n",
    "    r'($\\mathbf{e}$)', xy=(700, loss[700]), xytext=(650, loss[700] + 1.24),\n",
    "    arrowprops=dict(arrowstyle='-|>', color='black', lw=1.5),\n",
    "    fontsize=19, color='k', zorder=30\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('../figs/fig_6.pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
