{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e8377d-d4bf-4961-8221-5ef5a84c635c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from copy import deepcopy\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "from utils import *\n",
    "from matplotlib.cm import ScalarMappable\n",
    "from matplotlib.colors import ListedColormap, Normalize\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1209bc9f-fa98-428b-92e3-f381dc52cb8f",
   "metadata": {},
   "source": [
    "### Train a closed-loop low-rank RNN on the double-integrator task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fff3df0-edfb-4386-9b4f-a59999438ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "# Seed the torch and numpy RNGs so Restart & Run All reproduces the same run\n",
    "# (assumes the utils helpers draw their randomness from these two generators).\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "system = k_integrator_torch(k=2, dt=1, c=1, m=1)\n",
    "model = P_Model(N=N, g=0.0, rank=1, system=system)  # rank-1 closed-loop RNN\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "num_epochs = 1000\n",
    "loss, model_his_parameters, grad = train_model_p(\n",
    "    model,\n",
    "    optimizer,\n",
    "    teacher=None,\n",
    "    white_noise=None,\n",
    "    w_grad_clip=True,\n",
    "    num_epochs=num_epochs,\n",
    "    batch_size=100,\n",
    "    num_steps=50,\n",
    "    clamp=False,\n",
    ")\n",
    "\n",
    "\n",
    "# Rebuild the full closed-loop matrix P for every training snapshot. The loop variable is\n",
    "# `snapshot` (not `model`), so `model` keeps referring to the model passed to train_model_p.\n",
    "# P stays bound to the last snapshot's matrix, which the trajectory plot below rolls out.\n",
    "all_P = []\n",
    "for ep in range(num_epochs):\n",
    "    snapshot = model_his_parameters[ep]\n",
    "    P = create_p_matrix_from_low_rank(snapshot.system, snapshot.N, snapshot.M, snapshot.Z,\n",
    "                                      snapshot.U, snapshot.V, snapshot.W_random)\n",
    "    all_P.append(P.detach().numpy())\n",
    "all_P = np.array(all_P)\n",
    "\n",
    "# A snapshot is stable when every eigenvalue of P lies strictly inside the unit circle.\n",
    "eig, _ = np.linalg.eig(all_P)\n",
    "stab = np.all(np.abs(eig) < 1, axis=1).astype(int)\n",
    "\n",
    "# ep_of_stab: first epoch from which P stays stable until the end of training, i.e. one past\n",
    "# the last unstable epoch. If training ends unstable there is no such epoch, so use\n",
    "# num_epochs + 1 as a past-the-end sentinel instead of leaving the name undefined.\n",
    "unstable_eps = np.flatnonzero(stab == 0)\n",
    "if stab.size == 0 or stab[-1] == 0:\n",
    "    ep_of_stab = num_epochs + 1\n",
    "    print(f'P is not stable at the end of training; ep_of_stab set to {ep_of_stab}')\n",
    "else:\n",
    "    ep_of_stab = int(unstable_eps[-1]) + 1 if unstable_eps.size else 0\n",
    "\n",
    "# Eigenvalues of the recurrent block of P (first two rows/columns, the system state x, removed).\n",
    "all_eig_whh, _ = np.linalg.eig(all_P[:, 2:, 2:])\n",
    "\n",
    "# Training curve on a log scale, with the stabilization epoch marked.\n",
    "fig_loss, ax_loss = plt.subplots()\n",
    "ax_loss.plot(np.log(loss))\n",
    "ax_loss.axvline(x=ep_of_stab, ls='--', color='k', label='stability')\n",
    "ax_loss.legend()\n",
    "ax_loss.set_xlabel('Epoch')\n",
    "ax_loss.set_ylabel('log(loss)')\n",
    "\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "plt.show()\n",
    "\n",
    "# Simulate the final closed-loop matrix P (last epoch of the loop above) and plot x_1, x_2 and u.\n",
    "tt, uu = simulate_p(P, 50, N)\n",
    "states = torch.stack(tt).squeeze().detach().numpy()\n",
    "controls = torch.stack(uu).squeeze().detach().numpy()\n",
    "colors = sns.color_palette('deep')\n",
    "\n",
    "fig, ax = plt.subplots(1, 3, figsize=(12, 3))\n",
    "for i, state_ax in enumerate(ax[:2]):\n",
    "    state_ax.plot(states[:, i], color=colors[i])\n",
    "    state_ax.axhline(y=0, color='k', ls='--', label='Target')\n",
    "    state_ax.set_xlabel('Time Steps', size=14)\n",
    "    state_ax.set_ylabel(f'$x_{i+1}$', size=14)\n",
    "    state_ax.tick_params(axis='both', which='major', labelsize=12)\n",
    "\n",
    "ax[2].plot(controls, color=colors[2])\n",
    "ax[2].set_xlabel('Time Steps', size=14)\n",
    "ax[2].set_ylabel('$u$', size=14)\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bbc595e-fa96-45a2-b610-f52764e726fc",
   "metadata": {},
   "source": [
    "### Stage-1: Effective loss using geometric sum $\\lambda_1 \\approx 1 + \\sqrt{\\sigma_{zm}}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "339d07b1-ddbf-4d9a-904d-6ce545083289",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_steps = 100\n",
    "\n",
    "# stage 1\n",
    "def run_traj(P, num_steps=50):\n",
    "    \"\"\"Roll out the closed-loop linear map h <- P @ h from a fixed initial condition.\n",
    "\n",
    "    The two system coordinates start at x = (1, -1); every remaining coordinate of h\n",
    "    (the network units) starts at zero. The hidden size is read from P itself, so the\n",
    "    helper does not depend on a global `model`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    P : torch.Tensor\n",
    "        Square closed-loop matrix whose first two coordinates are the system state\n",
    "        (presumably float32, as produced by create_p_matrix_from_low_rank).\n",
    "    num_steps : int\n",
    "        Number of recorded states; the initial state is the first one.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        The two system coordinates at each step, shape (num_steps, 2) for num_steps > 1\n",
    "        (squeeze drops the step axis when num_steps == 1); an empty (0, 2) array when\n",
    "        num_steps <= 0.\n",
    "    \"\"\"\n",
    "    if num_steps <= 0:\n",
    "        return np.empty((0, 2))\n",
    "    x0 = np.array([[1]])\n",
    "    x1 = np.array([[-1]])\n",
    "    h0 = np.zeros((P.shape[1] - 2, 1))\n",
    "    h = torch.tensor(np.vstack((x0, x1, h0)), dtype=torch.float32)\n",
    "    tt = []\n",
    "    for _ in range(num_steps):\n",
    "        tt.append(h[:2])\n",
    "        h = P @ h\n",
    "    return torch.stack(tt).squeeze().detach().numpy()\n",
    "\n",
    "T = num_steps + 1  # number of terms in the theoretical geometric sum\n",
    "\n",
    "\n",
    "def theory_loss(zm_values, T=T):\n",
    "    \"\"\"Geometric-sum prediction 0.5 * (1 - r**T) / (1 - r), with r = |1 + sqrt(sigma_zm)|**2.\n",
    "\n",
    "    The square root is taken in the complex plane, so a negative sigma_zm gives\n",
    "    r = |1 + i*sqrt(|sigma_zm|)|**2 = 1 + |sigma_zm|. Works elementwise on arrays.\n",
    "    \"\"\"\n",
    "    r = np.abs(1 + np.sqrt(np.asarray(zm_values) + 0j)) ** 2\n",
    "    return 0.5 * (1 - r**T) / (1 - r)\n",
    "\n",
    "\n",
    "zm_pos = np.linspace(0.0001, 0.2, 100)\n",
    "zm_neg = np.linspace(-0.5, -0.0001, 100)\n",
    "\n",
    "# Empirical loss for each sigma_zm: build a fresh model (presumably `over` sets z^T m),\n",
    "# roll out its closed-loop matrix and sum 0.5 * |x|^2 over the trajectory.\n",
    "# Positive values are swept first, then negative ones, in one module-level loop, so\n",
    "# `model` and `P` remain notebook-level names like in the other cells.\n",
    "loss_sim = {'pos': [], 'neg': []}\n",
    "for key, zm_values in (('pos', zm_pos), ('neg', zm_neg)):\n",
    "    for cur_zm in zm_values:\n",
    "        model = P_Model(N=N, g=0.0, system=system, over=cur_zm)\n",
    "        P = create_p_matrix_from_low_rank(model.system, model.N, model.M, model.Z, model.U, model.V, model.W_random)\n",
    "        traj = run_traj(P, num_steps=num_steps)\n",
    "        loss_sim[key].append(np.sum(0.5 * (traj[:, 0]**2 + traj[:, 1]**2)))\n",
    "loss_sim_pos, loss_sim_neg = loss_sim['pos'], loss_sim['neg']\n",
    "\n",
    "# Theory curves, vectorised over sigma_zm.\n",
    "loss_the_pos = list(theory_loss(zm_pos))\n",
    "loss_the_neg = list(theory_loss(zm_neg))\n",
    "\n",
    "emp_color = sns.color_palette('mako_r')[0]\n",
    "fig_s1, ax_s1 = plt.subplots()\n",
    "ax_s1.plot(zm_pos, np.log(loss_sim_pos), color=emp_color, lw=4, label=r'$\\text{Empirical}$')\n",
    "ax_s1.plot(zm_pos, np.log(loss_the_pos), '--k', lw=4, label=r'$\\text{Theory}$')\n",
    "ax_s1.plot(zm_neg, np.log(loss_sim_neg), color=emp_color, lw=4)\n",
    "ax_s1.plot(zm_neg, np.log(loss_the_neg), '--k', lw=4)\n",
    "\n",
    "ax_s1.legend(fontsize=18, loc='upper left', frameon=False)\n",
    "\n",
    "ax_s1.set_xlabel(r'$ \\sigma_{zm} $', size=30)\n",
    "# The curves are np.log(loss), so label the axis as log-loss.\n",
    "ax_s1.set_ylabel(r'$\\log(\\text{Loss})$', size=22)\n",
    "\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef05d796-e041-4b0b-b90c-c3b6b97d2c82",
   "metadata": {},
   "source": [
    "### Stage-2: Effective loss with $\\alpha$ hyperparameter controlling the weighting of short- and long-term contributions\n",
    "\n",
    "We simulate optimization in a reduced model using a closed-form approximation of $\\lambda_3$ as a function of the order parameters.  \n",
    "The effective loss combines a short-term surrogate $\\mathcal{L}_0$ and a long-term surrogate $\\mathcal{L}_\\infty^T$, weighted by $\\alpha$; the weight $\\alpha$ plays a role analogous to the episode length $T$ in the full model.  \n",
    "Varying $\\alpha$ reveals eigenvalue trajectories that reflect different trade-offs between early episode control and long-term stability."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b882af8-9492-40d8-8629-263055093fe0",
   "metadata": {},
   "source": [
    "#### Helper to compute $\\mathcal{L}_0$ using low-order matrix powers $P^n$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7461010d-bc16-41d6-8d63-c673a6288f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive L_0 symbolically: the sum of squared entries of the top-left 2x2 block of P_eff**n\n",
    "# for a small number of steps n (e.g. 1, 2, 3, 4).\n",
    "import sympy\n",
    "\n",
    "# Order parameters: vu = v^T u, zm = z^T m, vm = v^T m, zu = z^T u\n",
    "vu, zm, vm, zu = sympy.symbols('vu zm vm zu', complex=True)\n",
    "\n",
    "# Reduced 4x4 closed-loop matrix expressed in the order parameters\n",
    "P_eff = sympy.Matrix([\n",
    "    [1, 1, 0,  0],\n",
    "    [0, 1, zm, zu],\n",
    "    [1, 1, 0,  0],\n",
    "    [0, 0, vm, vu],\n",
    "])\n",
    "\n",
    "n_steps = 3\n",
    "P_eff = P_eff**n_steps\n",
    "expr = sympy.simplify(P_eff[0, 0]**2 + P_eff[1, 0]**2 + P_eff[0, 1]**2 + P_eff[1, 1]**2)\n",
    "\n",
    "# For n = 3 this must equal the loss_0 closed form hard-coded in ReduceModel.forward below;\n",
    "# fail loudly if the two ever drift apart.\n",
    "if n_steps == 3:\n",
    "    loss_0_closed_form = (zm + 1)**2 + (zm + 3)**2 + (vm*zu + 2*zm)**2 + (vm*zu + 3*zm + 1)**2\n",
    "    assert sympy.expand(expr - loss_0_closed_form) == 0, 'loss_0 closed form does not match P_eff**3'\n",
    "expr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9cc8a85-3deb-4eee-bcaa-6891d09b6f91",
   "metadata": {},
   "source": [
    "#### Helper to verify the perturbative approximation around $\\lambda_3$ for computing $\\mathcal{L}_\\infty$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8c1740-7579-4e06-af07-8974333a9545",
   "metadata": {},
   "outputs": [],
   "source": [
    "def quadSolver(a, b, c, tol=1e-18):\n",
    "    \"\"\"Return the root(s) of a*x**2 + b*x + c = 0 for real or complex coefficients.\n",
    "\n",
    "    Coefficients with abs(coef) <= tol are treated as zero. The square root of a real\n",
    "    discriminant is taken with np.emath.sqrt, so a negative discriminant yields the\n",
    "    complex-conjugate pair instead of NaN.\n",
    "\n",
    "    Returns [(-b - sqrt(D)) / (2a), (-b + sqrt(D)) / (2a)] for a quadratic, [-c / b] for\n",
    "    a linear equation, and None (after printing a message) when a and b both vanish.\n",
    "    \"\"\"\n",
    "    if abs(a) <= tol and abs(b) <= tol:\n",
    "        print('Invalid equation.' if c != 0 else 'Trivial identity: 0 = 0')\n",
    "        return None\n",
    "    if abs(a) <= tol:\n",
    "        return [-c / b]\n",
    "    D = b**2 - 4 * a * c\n",
    "    sqrt_D = np.sqrt(D) if np.iscomplexobj(D) else np.emath.sqrt(D)\n",
    "    return [(-b - sqrt_D) / (2 * a), (-b + sqrt_D) / (2 * a)]\n",
    "\n",
    "def compute_eps_order(m, z, u, v, lam, order=2):\n",
    "    \"\"\"Perturbative correction eps such that lam + eps approximates a root of\n",
    "    p(x) = x**3 - (vu + 2) x**2 + (2 vu - zm + 1) x + (vu zm - vu - vm zu),\n",
    "    with vu = v @ u, zm = z @ m, vm = v @ m, zu = z @ u.\n",
    "\n",
    "    Expanding p(lam + eps) = c + b eps + a eps**2 + eps**3 gives c = p(lam), b = p'(lam)\n",
    "    and a = p''(lam) / 2. With order=2 the cubic term is dropped and the first root\n",
    "    returned by quadSolver is used (the single root if the truncation is linear);\n",
    "    any other order returns the Newton step -c / b (0 when b == 0).\n",
    "\n",
    "    Raises ValueError when order=2 and a and b both vanish (no root to pick).\n",
    "    \"\"\"\n",
    "    vu, zm, vm, zu = v @ u, z @ m, v @ m, z @ u\n",
    "    a = 3 * lam - (vu + 2)\n",
    "    b = 3 * lam**2 - 2 * lam * (vu + 2) + (2 * vu - zm + 1)\n",
    "    c = lam**3 - (vu + 2) * lam**2 + (2 * vu - zm + 1) * lam + vu * zm - vu - vm * zu\n",
    "    if order == 2:\n",
    "        roots = quadSolver(a, b, c)\n",
    "        if roots is None:\n",
    "            raise ValueError('degenerate quadratic truncation: a and b are both zero')\n",
    "        # roots[0] is the (-b - sqrt(D)) root for a quadratic and the only root otherwise.\n",
    "        return roots[0]\n",
    "    return -c / b if b != 0 else 0\n",
    "\n",
    "# Epochs to analyse: from epoch 10 up to (excluding) the stabilization epoch, clamped to the\n",
    "# recorded history so a past-the-end ep_of_stab (num_epochs + 1) cannot index beyond all_P.\n",
    "scan_ep = np.arange(10, min(ep_of_stab, len(all_P)), 1)\n",
    "\n",
    "# === Eigenvalue evolution ===\n",
    "# One row per scanned epoch: eigenvalue magnitudes of P sorted in decreasing order.\n",
    "eig_empirical = []\n",
    "for ep in scan_ep:\n",
    "    eigvals, _ = np.linalg.eig(all_P[ep])\n",
    "    eig_empirical.append(np.sort(np.abs(eigvals))[::-1])\n",
    "# Keep a 2-D shape even when no epoch is scanned, so the column indexing below still works.\n",
    "eig_empirical = np.array(eig_empirical) if eig_empirical else np.empty((0, all_P.shape[1]))\n",
    "\n",
    "# === Init figure ===\n",
    "# Panels 0-3: magnitudes of the four largest eigenvalues of P per epoch; panel 4 is filled below.\n",
    "fig, ax = plt.subplots(1, 5, figsize=(19, 4))\n",
    "for i in range(4):\n",
    "    # Raw f-string: '\\l' is not a valid escape sequence (SyntaxWarning on Python 3.12+).\n",
    "    ax[i].set_title(rf'$| \\lambda_{{{i+1}}} |$', size=20)\n",
    "    ax[i].set_xlabel('Epoch')\n",
    "    ax[i].plot(scan_ep, eig_empirical[:, i], label='Empirical')\n",
    "    ax[i].axvline(x=ep_of_stab, ls='--', color='r')\n",
    "\n",
    "# === Roots of characteristic poly ===\n",
    "# Per scanned epoch: the numeric roots of\n",
    "#   x^3 - (vu + 2) x^2 + (2 vu - zm + 1) x + (vu zm - vu - vm zu)\n",
    "# and the theoretical approximations of lambda_1 and lambda_3.\n",
    "poly_root1, poly_root2, poly_root3 = [], [], []\n",
    "root1_theory = []\n",
    "root3_theory = []\n",
    "for ep in scan_ep:\n",
    "    model = model_his_parameters[ep]\n",
    "    m, z, u, v = (p.detach().numpy().flatten() for p in (model.M, model.Z, model.U, model.V))\n",
    "    sig_vu, sig_zm, sig_vm, sig_zu = v @ u, z @ m, v @ m, z @ u\n",
    "\n",
    "    roots = np.roots([1, -sig_vu - 2, 2 * v @ u - sig_zm + 1, sig_vu * sig_zm - sig_vu - sig_vm * sig_zu])\n",
    "    for bucket, root in zip((poly_root1, poly_root2, poly_root3), roots):\n",
    "        bucket.append(root)\n",
    "\n",
    "    lam = 1 + np.sqrt(sig_zm + 0j)\n",
    "    root1_theory.append(lam + compute_eps_order(m, z, u, v, lam))\n",
    "    root3_theory.append(sig_vu + (sig_vm * sig_zu) / (sig_vu**2 - 2 * v @ u - sig_zm + 1))\n",
    "\n",
    "# Optional overlay of the numeric roots:\n",
    "# ax[0].plot(scan_ep, np.abs(poly_root1), '--', color='g', label='Empirical')\n",
    "# ax[1].plot(scan_ep, np.abs(poly_root2), '--', color='g', label='Empirical')\n",
    "# ax[2].plot(scan_ep, np.abs(poly_root3), '--', color='g', label='Empirical')\n",
    "\n",
    "# |lambda_2| reuses the lambda_1 theory curve (assumes the two form a conjugate pair).\n",
    "for k, theory in ((0, root1_theory), (1, root1_theory), (2, root3_theory)):\n",
    "    ax[k].plot(scan_ep, np.abs(theory), '--', color='k', label='Theory')\n",
    "for k in range(3):\n",
    "    ax[k].legend()\n",
    "\n",
    "# === u^T v across the scanned epochs (only used by the optional overlay) ===\n",
    "def _flat(param):\n",
    "    \"\"\"Detach a torch parameter and return it as a flattened numpy array.\"\"\"\n",
    "    return param.detach().numpy().flatten()\n",
    "\n",
    "inner_uv = [_flat(model_his_parameters[ep].U) @ _flat(model_his_parameters[ep].V) for ep in scan_ep]\n",
    "# Optional overlay: ax[2].plot(scan_ep, np.abs(inner_uv), '--', color='gray', label=r'$u^{\\top}v$', alpha=0.4)\n",
    "ax[2].legend()\n",
    "\n",
    "# === Center axis line ===\n",
    "# Panel 3 shows |lambda_4|; zoom onto a narrow band around zero.\n",
    "ax[3].axhline(y=0, ls='--', color='k')\n",
    "ax[3].set_ylim(-0.01, 0.01)\n",
    "\n",
    "# === Analytic vs Empirical |lambda_1|^2 ===\n",
    "# Empirical: |.|^2 of the second-largest-magnitude eigenvalue of P (equal to |lambda_1|^2\n",
    "# when the two leading eigenvalues are a conjugate pair).\n",
    "# Theory: by Vieta, lambda_1 * lambda_2 = c + b * lambda_3 + lambda_3**2 for the cubic\n",
    "# x^3 + b x^2 + c x + d, evaluated at the closed-form lambda_3 (s_theory) or at\n",
    "# lambda_3 ~ u^T v (s_theory_uv, used only by the optional overlay).\n",
    "s_empirical, s_theory, s_theory_uv = [], [], []\n",
    "for ep in scan_ep:\n",
    "    model = model_his_parameters[ep]\n",
    "    m, z, u, v = (p.detach().numpy().flatten() for p in (model.M, model.Z, model.U, model.V))\n",
    "    eigvals, _ = np.linalg.eig(all_P[ep])\n",
    "    by_magnitude = np.argsort(np.abs(eigvals))[::-1]\n",
    "    s_empirical.append(np.abs(eigvals[by_magnitude[1]]**2))\n",
    "\n",
    "    b = -v @ u - 2\n",
    "    c = 2 * v @ u - z @ m + 1\n",
    "    lam3_approx = v @ u + ((v @ m) * (z @ u)) / ((v @ u)**2 - 2 * v @ u - z @ m + 1)\n",
    "    s_theory.append(c + b * lam3_approx + lam3_approx**2)\n",
    "    s_theory_uv.append(c + b * (u @ v) + (u @ v)**2)\n",
    "\n",
    "ax[4].set_title(r'$|\\lambda_1|^2 \\quad (\\mathcal{L}_\\infty)$', size=20)\n",
    "ax[4].plot(scan_ep, s_empirical, label='Empirical')\n",
    "ax[4].plot(scan_ep, s_theory, '--', color='k', label='Theory')\n",
    "# Optional overlay: ax[4].plot(scan_ep, s_theory_uv, '--', color='gray', label=r'$c + b u^{\\top}v + (u^{\\top}v)^2$', alpha=0.4)\n",
    "ax[4].axvline(x=ep_of_stab, ls='--', color='red')\n",
    "ax[4].set_xlabel('Epoch')\n",
    "ax[4].legend(fontsize='small')\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17376414-fafc-41ed-80fa-4f79d01a0d86",
   "metadata": {},
   "source": [
    "#### Optimize effective loss $\\alpha \\mathcal{L}_\\infty^T + (1{-}\\alpha)\\mathcal{L}_0$ across varying $\\alpha$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44755ac3-3393-47bc-835d-d9d146a02d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphas = [0, 0.2, 0.3, 0.35, 0.4, 1.0]\n",
    "# One palette colour per alpha, so extending `alphas` cannot index past the palette.\n",
    "palette = sns.color_palette('mako_r', n_colors=len(alphas))\n",
    "norm = Normalize(vmin=min(alphas), vmax=max(alphas))\n",
    "cmap = plt.get_cmap('mako_r')\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 5))  # the alpha sweep below draws into this axes\n",
    "\n",
    "# stage 2\n",
    "class ReduceModel(nn.Module):\n",
    "    \"\"\"Reduced model whose only trainable parameters are the four order parameters.\n",
    "\n",
    "    vu, zm, vm, zu stand for v^T u, z^T m, v^T m and z^T u. forward() returns\n",
    "    alpha * L_inf + (1 - alpha) * L_0, where\n",
    "      * L_inf = (c + b * lambda_3 + lambda_3**2) ** T with b = -vu - 2, c = 2 vu - zm + 1\n",
    "        and lambda_3 given by its closed-form approximation; the base equals\n",
    "        lambda_1 * lambda_2 of the characteristic cubic (Vieta), and\n",
    "      * L_0 is the sum of squared entries of the top-left 2x2 block of P_eff**3.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vu=0, zm=0.01, vm=0, zu=0, alpha=0.01, T=10):\n",
    "        super().__init__()\n",
    "        self.vu = nn.Parameter(torch.tensor(vu, dtype=torch.float32))\n",
    "        self.zm = nn.Parameter(torch.tensor(zm, dtype=torch.float32))\n",
    "        self.vm = nn.Parameter(torch.tensor(vm, dtype=torch.float32))\n",
    "        self.zu = nn.Parameter(torch.tensor(zu, dtype=torch.float32))\n",
    "        self.alpha = alpha  # weight of the long-term term L_inf\n",
    "        self.T = T          # exponent applied to the long-term term\n",
    "\n",
    "    def forward(self):\n",
    "        vu, zm, vm, zu = self.vu, self.zm, self.vm, self.zu\n",
    "        lambda_3 = vu + (vm * zu) / (vu**2 - 2*vu - zm + 1)\n",
    "        loss_inf = (2*vu - zm + 1 + (-vu-2)*lambda_3 + lambda_3**2)**self.T\n",
    "        loss_0 = (zm+1)**2 + (zm+3)**2 + (vm*zu + 2*zm)**2 + (vm*zu + 3*zm + 1)**2\n",
    "        return self.alpha * loss_inf + (1 - self.alpha) * loss_0\n",
    "\n",
    "\n",
    "def optimize_and_track(vu, zm, vm, zu, alpha, T, lr, color, ax=None, num_iters=200):\n",
    "    \"\"\"Train a ReduceModel with clipped SGD and plot the path of one root of its cubic.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    vu, zm, vm, zu : float\n",
    "        Initial order parameters.\n",
    "    alpha, T : passed to ReduceModel.\n",
    "    lr : float\n",
    "        SGD learning rate; the gradient norm is clipped to 1 at every step.\n",
    "    color : matplotlib colour of the path and of its end marker.\n",
    "    ax : matplotlib Axes, optional\n",
    "        Axes to draw into; defaults to the current axes (plt.gca()).\n",
    "    num_iters : int, default 200\n",
    "        Number of SGD steps.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        np.roots(...)[0] of x^3 - (vu + 2) x^2 + (2 vu - zm + 1) x + (vu zm - vu - vm zu)\n",
    "        for the parameters recorded before each step (num_iters entries). The drawn path\n",
    "        starts at a crimson dot and ends at a dot of `color`.\n",
    "    \"\"\"\n",
    "    if num_iters < 1:\n",
    "        raise ValueError('num_iters must be at least 1')\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    model = ReduceModel(vu, zm, vm, zu, alpha, T)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "    # Record the four scalars before each update; they are all the root computation needs,\n",
    "    # so a full deepcopy of the module per step is unnecessary.\n",
    "    history = []\n",
    "    for _ in range(num_iters):\n",
    "        _loss = model()\n",
    "        history.append((model.vu.item(), model.zm.item(), model.vm.item(), model.zu.item()))\n",
    "        optimizer.zero_grad()\n",
    "        _loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "    z = []\n",
    "    for s_vu, s_zm, s_vm, s_zu in history:\n",
    "        coeffs = [1, -s_vu - 2, 2*s_vu - s_zm + 1, s_vu*s_zm - s_vu - s_vm*s_zu]\n",
    "        z.append(np.roots(coeffs)[0])\n",
    "\n",
    "    ax.plot(np.real(z), np.imag(z), color=color, lw=4)\n",
    "    ax.scatter(z[0].real, z[0].imag, marker='o', color='crimson', s=100, edgecolor='k', linewidths=2, zorder=30)\n",
    "    ax.scatter(z[-1].real, z[-1].imag, s=50, color=color, edgecolor='k', linewidths=2, zorder=30)\n",
    "    return z\n",
    "\n",
    "\n",
    "# Start every alpha run from the order parameters of the full model at epoch 10.\n",
    "model = model_his_parameters[10]\n",
    "m, z, u, v = (p.detach().numpy().flatten() for p in (model.M, model.Z, model.U, model.V))\n",
    "vu, zm, vm, zu = v @ u, z @ m, v @ m, z @ u\n",
    "\n",
    "for i in range(len(alphas)):\n",
    "    optimize_and_track(vu, zm, vm, zu, alpha=alphas[i], T=20, lr=0.001, color=palette[i])\n",
    "\n",
    "# === Formatting ===\n",
    "ax.axhline(0, color='grey', lw=1)\n",
    "ax.axvline(0, color='grey', lw=1)\n",
    "theta = np.linspace(0, 2*np.pi, 100)\n",
    "ax.plot(np.cos(theta), np.sin(theta), color='grey', ls='--', lw=2.5)  # unit circle (stability boundary)\n",
    "\n",
    "ax.set_xlim(0.95, 1.02)\n",
    "ax.set_ylim(0.04, 0.3)\n",
    "ax.set_title(r'$\\text{eig}(\\mathbfit{P})$', size=22)\n",
    "ax.set_ylabel(r'$\\text{Imag.}$', size=22)\n",
    "ax.set_xlabel(r'$\\text{Real}$', size=22)\n",
    "ax.set_yticks([0.13, 0.25])\n",
    "\n",
    "# === Colorbar ===\n",
    "# Each trajectory is drawn with palette[i] for alphas[i], so build a discrete colorbar from those\n",
    "# same colours; a continuous mako_r bar normalised by alpha value would not match palette[i].\n",
    "alpha_cmap = ListedColormap(palette[:len(alphas)])\n",
    "alpha_norm = Normalize(vmin=-0.5, vmax=len(alphas) - 0.5)\n",
    "sm = ScalarMappable(cmap=alpha_cmap, norm=alpha_norm)\n",
    "sm.set_array([])\n",
    "cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)\n",
    "cbar.set_ticks(list(range(len(alphas))))\n",
    "cbar.set_ticklabels([f'{a:g}' for a in alphas])\n",
    "cbar.set_label(r'$\\alpha$', size=18)\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f65f23b-86b1-4d9d-8610-7b06ee3a25c0",
   "metadata": {},
   "source": [
    "### Stage-3: Gradient descent on the order parameters fit to the full high-dimensional model\n",
    "\n",
    "At the stabilization epoch, we extract the effective parameters from the high-dimensional model:\n",
    "\n",
    "$$\n",
    "\\sigma_{vu} = v^\\top u \\quad\n",
    "\\sigma_{zm} = z^\\top m \\quad\n",
    "\\sigma_{vm} = v^\\top m \\quad\n",
    "\\sigma_{zu} = z^\\top u \n",
    "$$\n",
    "\n",
    "We then perform gradient descent on the order parameters, minimizing the loss over batch simulations with the low-dimensional approximation $P_{\\text{eff}}$.\n",
    "Loss and eigenvalues are compared to assess how well the reduced model captures full-model learning dynamics in Stage 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c51c5751-339c-46bd-a3e3-2bdea04c48ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the reduced-model fit at the full model's stabilization epoch.\n",
    "ST = ep_of_stab\n",
    "# ep_of_stab may be a past-the-end value (num_epochs + 1) when the full model never\n",
    "# stabilizes; fail with a clear message instead of an IndexError on the history lookup.\n",
    "if not 0 <= ST < len(model_his_parameters):\n",
    "    raise ValueError(f'ep_of_stab={ST} is outside the recorded history '\n",
    "                     f'(0..{len(model_his_parameters) - 1}); did the full model stabilize?')\n",
    "\n",
    "model = model_his_parameters[ST]\n",
    "m, z, u, v = (p.detach().numpy().flatten() for p in (model.M, model.Z, model.U, model.V))\n",
    "# Initialise the reduced model with the order parameters measured at epoch ST.\n",
    "model = P_Model_eff(init_sig_zm=z@m, init_sig_zu=z@u, init_sig_vm=v@m, init_sig_vu=v@u)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "num_epochs = 1000\n",
    "batch_size = 100\n",
    "eff_loss, eff_model_his_parameters, eff_grad = train_model_p_eff(\n",
    "    model, optimizer, teacher=None, white_noise=None, beta=0,\n",
    "    w_grad_clip=1, num_epochs=num_epochs, batch_size=batch_size, num_steps=50, clamp=False)\n",
    "\n",
    "# Closed-loop matrix of the reduced model at every epoch of the Stage-3 fit.\n",
    "eff_all_P = np.array([create_p_effective(eff_model_his_parameters[ep]) for ep in range(num_epochs)])\n",
    "eig, _ = np.linalg.eig(eff_all_P)\n",
    "stab = np.all(np.abs(eig) < 1, axis=1).astype(int)\n",
    "\n",
    "# First epoch from which the reduced model stays stable until the end (one past the last\n",
    "# unstable epoch); num_epochs + 1 if the fit ends unstable.\n",
    "unstable_eps = np.flatnonzero(stab == 0)\n",
    "if stab.size == 0 or stab[-1] == 0:\n",
    "    ep_of_stab_eff = num_epochs + 1\n",
    "    print(f'P_eff is not stable at the end of the fit; ep_of_stab_eff set to {ep_of_stab_eff}')\n",
    "else:\n",
    "    ep_of_stab_eff = int(unstable_eps[-1]) + 1 if unstable_eps.size else 0\n",
    "\n",
    "\n",
    "fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# Left: full-model loss from epoch ST onward vs. the reduced-model fit started at ST.\n",
    "ax0.plot(loss[ST:], label='Empirical', lw=2)\n",
    "ax0.plot(eff_loss, ls='--', color='k', label='Theory', lw=2)\n",
    "ax0.set_xlabel('Epoch', size=22)\n",
    "ax0.set_ylabel('Loss', size=22)\n",
    "ax0.set_yscale('log')\n",
    "ax0.set_xscale('log')\n",
    "ax0.legend(fontsize=22)\n",
    "\n",
    "# Right: eigenvalues of the final closed-loop matrices. Only the last epoch is plotted,\n",
    "# so decompose just that matrix instead of the whole training history.\n",
    "# NOTE(review): the full model has len(all_P) - 1 - ST epochs after ST while the reduced\n",
    "# model ran num_epochs; confirm that comparing the two final epochs is intended.\n",
    "eig_full, _ = np.linalg.eig(all_P[-1])\n",
    "ax1.scatter(eig_full.real, eig_full.imag, marker='o', s=100, label='Empirical')\n",
    "\n",
    "eig_eff, _ = np.linalg.eig(np.asarray(eff_all_P)[-1])\n",
    "ax1.scatter(eig_eff.real, eig_eff.imag, color='k', marker='x', s=100, label='Theory')\n",
    "ax1.axhline(0, color='grey', lw=1)\n",
    "ax1.axvline(0, color='grey', lw=1)\n",
    "theta = np.linspace(0, 2*np.pi, 100)\n",
    "ax1.plot(np.cos(theta), np.sin(theta), color='grey', ls='--', lw=2.5)  # unit circle\n",
    "\n",
    "ax1.set_title(r'$\\text{eig}(\\mathbfit{P})$', size=22)\n",
    "ax1.set_ylabel(r'$\\text{Imag.}$', size=22)\n",
    "ax1.set_xlabel(r'$\\text{Real}$', size=22)\n",
    "ax1.legend(fontsize=22, loc='upper left')\n",
    "ax1.axis('equal')\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
