{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc974d42-e1f8-4978-883d-837d66dfc4b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.linalg\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from scipy.interpolate import RegularGridInterpolator\n",
    "from scipy.stats import binned_statistic_2d\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "from utils import *\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aafbe6d4-c782-480a-b198-293dbdd84415",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System and training parameters\n",
    "K = 2           # System dimension\n",
    "M = 1           # Control output size\n",
    "C = 1           # Input dimension\n",
    "dt = 1          # Time step\n",
    "low_d = False   # Use low-dimensional RNN (hidden size = system.k instead of 100)\n",
    "opt = 'SGD'     # Optimizer type\n",
    "lr = 1e-2       # Learning rate\n",
    "SEED = 0        # RNG seed so the run is reproducible\n",
    "\n",
    "# Seed the global torch/numpy RNGs; presumably utils draws weight init and training\n",
    "# batches from them -- TODO confirm\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Initialize system (k-th order integrator)\n",
    "system = k_integrator_torch(k=K, dt=dt, c=C, m=M, noise=0.0, clamp=False)\n",
    "x_target = np.zeros(system.k)   # regulate every state dimension to zero\n",
    "\n",
    "# Controller setup\n",
    "hidden_size = system.k if low_d else 100\n",
    "phi = 'tanh'\n",
    "save_model = 1   # forwarded to train_rnn; the next cell loads checkpoints from models/full_rnn_non_linear/closed_loop/\n",
    "\n",
    "controller = ct_rnn_controller(\n",
    "    input_size=system.c,\n",
    "    hidden_size=hidden_size,\n",
    "    output_size=system.m,\n",
    "    phi=phi,\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=True,\n",
    "    train_Whh=True,\n",
    "    train_Who=True,\n",
    "    small_scale=True,\n",
    "    Wih_start_big=False,\n",
    "    g=0.1,\n",
    "    dt=dt,\n",
    "    tau=1,\n",
    "    tracking_task=False,\n",
    "    RL=False\n",
    ")\n",
    "\n",
    "# Train the closed-loop controller; checkpoints are presumably written under\n",
    "# models/<path>/<net_num>/ (matches the load paths used by the next cell)\n",
    "net_name = 'closed_loop'\n",
    "train_kwargs = dict(\n",
    "    net_num=net_name,\n",
    "    path='full_rnn_non_linear',\n",
    "    controller=controller,\n",
    "    system=system,\n",
    "    x_target=x_target,\n",
    "    teacher=None,\n",
    "    LR=lr,\n",
    "    num_epochs=1001,\n",
    "    opt=opt,\n",
    "    w_grad_clip=True,\n",
    "    batch_size=100,\n",
    "    num_steps=50,\n",
    "    dt=dt,\n",
    "    reg_U=0.005,\n",
    "    save_model=save_model,\n",
    "    log=False,\n",
    ")\n",
    "loss = train_rnn(**train_kwargs)\n",
    "\n",
    "# Training curve on a natural-log scale\n",
    "loss_fig, loss_ax = plt.subplots()\n",
    "loss_ax.plot(np.log(loss))\n",
    "loss_ax.set_title(net_name)\n",
    "loss_ax.set_xlabel('Epochs')\n",
    "loss_ax.set_ylabel('Log Loss')\n",
    "loss_fig.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Simulate the trained controller from one random initial condition in [-1, 1]^k\n",
    "system_state, control_u, hidden_state, state_loss, control_loss = simulate_rnn(\n",
    "    controller=controller,\n",
    "    system=system,\n",
    "    k=system.k,\n",
    "    num_steps=50,\n",
    "    x_target=np.zeros(system.k),\n",
    "    batch_size=1,\n",
    "    init_con=torch.rand(system.k, 1) * 2 - 1\n",
    ")\n",
    "\n",
    "# Plot each system state dimension and the control output side by side\n",
    "palette = sns.color_palette('tab10')\n",
    "fig, ax = plt.subplots(1, system.k + 1, figsize=(14, 3))\n",
    "\n",
    "for i in range(system.k):\n",
    "    ax[i].plot(system_state[:, i].squeeze(), color=palette[i])\n",
    "    ax[i].axhline(y=x_target[i], color='k', ls='--', label='Target')\n",
    "    ax[i].set_ylabel(f'$x_{i}$', size=14)\n",
    "\n",
    "# Control signal(s) of the single simulated trajectory (batch index 0)\n",
    "for j in range(system.m):\n",
    "    ax[system.k].plot(control_u[:, 0, j], color=palette[system.k + j])\n",
    "ax[system.k].set_ylabel('Control', size=14)\n",
    "\n",
    "# Shared axis styling for every panel\n",
    "for panel in ax:\n",
    "    panel.set_xlabel('Time Steps', size=14)\n",
    "    panel.tick_params(axis='both', which='major', labelsize=12)\n",
    "\n",
    "# The 'Target' axhline labels are only displayed through a legend\n",
    "ax[0].legend(frameon=False)\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Total state cost of the simulated trajectory\n",
    "print(f'State cost: {np.sum(state_loss):.4f}')\n",
    "\n",
    "# Eigenvalue spectrum of the trained recurrent weight matrix\n",
    "e_vals, _ = np.linalg.eig(controller.Whh.detach().numpy())\n",
    "plt.scatter(e_vals.real, e_vals.imag)\n",
    "# Raw string: '\\m' is an invalid escape sequence (SyntaxWarning on Python 3.12+)\n",
    "plt.title(r'Eigenvalue Spectrum of $\\mathbf{W}_{hh}$')\n",
    "plt.xlabel('Re(λ)')\n",
    "plt.ylabel('Im(λ)')\n",
    "plt.axis('equal')\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3339ccb-e85a-4092-98d8-2da3762b5942",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System and training parameters (same values as the closed-loop cell)\n",
    "K = 2           # System dimension\n",
    "M = 1           # Control output size\n",
    "C = 1           # Input dimension\n",
    "dt = 1          # Time step\n",
    "low_d = False   # Use low-dimensional controller\n",
    "opt = 'SGD'     # Optimizer type\n",
    "lr = 1e-2       # Learning rate\n",
    "SEED = 0        # RNG seed so the run is reproducible\n",
    "\n",
    "# Seed the global torch/numpy RNGs; presumably utils draws training noise and batches\n",
    "# from them -- TODO confirm\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Initialize system\n",
    "system = k_integrator_torch(k=K, dt=dt, c=C, m=M, noise=0.0, clamp=False)\n",
    "x_target = np.zeros(system.k)\n",
    "\n",
    "# Set hidden size based on low_d flag; it must match the closed-loop checkpoints loaded\n",
    "# later in this cell, otherwise load_state_dict fails on a shape mismatch\n",
    "hidden_size = system.k if low_d else 100\n",
    "phi = 'tanh'\n",
    "save_model = 1\n",
    "\n",
    "# Define student controller (to be trained)\n",
    "controller2 = ct_rnn_controller(\n",
    "    input_size=system.c,\n",
    "    hidden_size=hidden_size,\n",
    "    output_size=system.m,\n",
    "    phi=phi,\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=True,\n",
    "    train_Whh=True,\n",
    "    train_Who=True,\n",
    "    small_scale=True,\n",
    "    Wih_start_big=False,\n",
    "    g=0.1,\n",
    "    dt=dt,\n",
    "    tau=1,\n",
    "    tracking_task=False,\n",
    "    RL=False\n",
    ")\n",
    "\n",
    "# Optional initialization for low-dimensional controller: one-hot input weights on unit 0,\n",
    "# output weights of 0.01 on every unit except unit 0.\n",
    "# NOTE(review): controller2.load_state_dict(...) further down in this cell loads a full state\n",
    "# dict that replaces these Wih/Who values, so this init currently has no effect -- move it\n",
    "# after that load if it is meant to take effect.\n",
    "if low_d:\n",
    "    vector1 = np.zeros(hidden_size)\n",
    "    vector1[0] = 1.0\n",
    "    vector2 = np.full(hidden_size, 0.01)\n",
    "    vector2[0] = 0.0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        controller2.Wih.copy_(torch.Tensor(vector1.reshape(1, hidden_size)))\n",
    "        controller2.Who.copy_(torch.Tensor(vector2.reshape(hidden_size, 1)))\n",
    "\n",
    "# Define teacher controller with the same constructor flags as the closed-loop model whose\n",
    "# checkpoint it loads (including tracking_task=False, which the student also uses)\n",
    "controller_teacher = ct_rnn_controller(\n",
    "    input_size=system.c,\n",
    "    hidden_size=hidden_size,\n",
    "    output_size=system.m,\n",
    "    phi=phi,\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=True,\n",
    "    train_Whh=True,\n",
    "    train_Who=True,\n",
    "    small_scale=True,\n",
    "    Wih_start_big=False,\n",
    "    g=0.1,\n",
    "    dt=dt,\n",
    "    tau=1,\n",
    "    tracking_task=False,\n",
    "    RL=False\n",
    ")\n",
    "\n",
    "# Student starts from the closed-loop net's epoch-0 weights; the teacher gets its epoch-1000\n",
    "# weights (presumably the last epoch of the num_epochs=1001 closed-loop run)\n",
    "closed_loop_dir = 'models/full_rnn_non_linear/closed_loop'\n",
    "controller2.load_state_dict(torch.load(f'{closed_loop_dir}/epoch_0.pth')['model_state'])\n",
    "controller_teacher.load_state_dict(torch.load(f'{closed_loop_dir}/epoch_1000.pth')['model_state'])\n",
    "\n",
    "# Train student controller via teacher-student setup (open-loop)\n",
    "net_name = 'open_loop'\n",
    "loss2 = train_rnn_teacher_student(\n",
    "    net_num=net_name,\n",
    "    path='full_rnn_non_linear',\n",
    "    controller_teacher=controller_teacher,\n",
    "    controller=controller2,\n",
    "    system=system,\n",
    "    x_target=x_target,\n",
    "    white_noise=1,\n",
    "    LR=lr,\n",
    "    num_epochs=1001,\n",
    "    opt=opt,\n",
    "    w_grad_clip=True,\n",
    "    batch_size=100,\n",
    "    num_steps=50,\n",
    "    dt=dt,\n",
    "    reg_U=0,\n",
    "    save_model=save_model,\n",
    "    log=False\n",
    ")\n",
    "\n",
    "# Plot training loss (linear scale)\n",
    "plt.plot(loss2)\n",
    "plt.title(net_name)\n",
    "plt.xlabel('Epochs')\n",
    "plt.ylabel('Loss')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Simulate the student controller from one random initial condition in [-1, 1]^k\n",
    "system_state, control_u, hidden_state, state_loss, control_loss = simulate_rnn(\n",
    "    controller=controller2,\n",
    "    system=system,\n",
    "    k=system.k,\n",
    "    num_steps=50,\n",
    "    x_target=np.zeros(system.k),\n",
    "    batch_size=1,\n",
    "    init_con=torch.rand(system.k, 1) * 2 - 1\n",
    ")\n",
    "\n",
    "# Plot each system state dimension and the control output side by side\n",
    "palette = sns.color_palette('tab10')\n",
    "fig, ax = plt.subplots(1, system.k + 1, figsize=(14, 3))\n",
    "\n",
    "for i in range(system.k):\n",
    "    ax[i].plot(system_state[:, i].squeeze(), color=palette[i])\n",
    "    ax[i].axhline(y=x_target[i], color='k', ls='--', label='Target')\n",
    "    ax[i].set_ylabel(f'$x_{i}$', size=14)\n",
    "\n",
    "# Control signal(s) of the single simulated trajectory (batch index 0)\n",
    "for j in range(system.m):\n",
    "    ax[system.k].plot(control_u[:, 0, j], color=palette[system.k + j])\n",
    "ax[system.k].set_ylabel('Control', size=14)\n",
    "\n",
    "# Shared axis styling for every panel\n",
    "for panel in ax:\n",
    "    panel.set_xlabel('Time Steps', size=14)\n",
    "    panel.tick_params(axis='both', which='major', labelsize=12)\n",
    "\n",
    "# Explain the dashed target lines\n",
    "ax[0].legend(frameon=False)\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Total state cost of the simulated trajectory\n",
    "print(f'State cost: {np.sum(state_loss):.4f}')\n",
    "\n",
    "# Eigenvalue spectrum of the trained student's recurrent weight matrix\n",
    "e_vals, _ = np.linalg.eig(controller2.Whh.detach().numpy())\n",
    "plt.scatter(e_vals.real, e_vals.imag)\n",
    "# Raw string: '\\m' is an invalid escape sequence (SyntaxWarning on Python 3.12+)\n",
    "plt.title(r'Eigenvalue Spectrum of $\\mathbf{W}_{hh}$')\n",
    "plt.xlabel('Re(λ)')\n",
    "plt.ylabel('Im(λ)')\n",
    "plt.axis('equal')\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14d102c6-3ba0-4bd7-8039-7da781c45430",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directories written by the closed-loop and open-loop training cells\n",
    "all_paths = [\n",
    "    'models/full_rnn_non_linear/closed_loop',\n",
    "    'models/full_rnn_non_linear/open_loop',\n",
    "]\n",
    "names = ['Closed-loop', 'Open-loop']\n",
    "colors = [sns.color_palette('deep')[0], sns.color_palette('deep')[3]]\n",
    "\n",
    "n_epochs = 1000    # evaluates epoch_0 ... epoch_999 in each directory\n",
    "num_steps = int(50/dt)\n",
    "batch_size = 10\n",
    "DROP = 1           # initial time steps excluded from the gain regression\n",
    "\n",
    "# Fixed test set of initial conditions, uniform in [-1, 1]^k. A dedicated generator makes it\n",
    "# reproducible without touching the global torch RNG state.\n",
    "test_gen = torch.Generator().manual_seed(0)\n",
    "x_t = torch.rand(system.k, batch_size, generator=test_gen) * 2 - 1\n",
    "\n",
    "all_loss = np.zeros((len(all_paths), n_epochs))\n",
    "all_k = np.zeros((len(all_paths), n_epochs, system.k))\n",
    "\n",
    "# hidden_size, phi and dt come from the training cells above\n",
    "for n, path in enumerate(all_paths):\n",
    "    for ep in range(n_epochs):\n",
    "        controller = ct_rnn_controller(system.c, hidden_size, system.m, phi=phi, dt=dt)\n",
    "        controller.load_state_dict(torch.load(f'{path}/epoch_{ep}.pth')['model_state'])\n",
    "\n",
    "        system_state, control_u, hidden_state, state_loss, control_loss = simulate_rnn(\n",
    "            controller,\n",
    "            system,\n",
    "            k=system.k,\n",
    "            num_steps=num_steps,\n",
    "            x_target=np.zeros(system.k),\n",
    "            batch_size=batch_size,\n",
    "            init_con=x_t,\n",
    "        )\n",
    "\n",
    "        all_loss[n, ep] = np.mean(state_loss.squeeze())\n",
    "\n",
    "        # Fit the effective linear feedback law u = k x (no intercept). Assumes system_state is\n",
    "        # (time, k, batch), consistent with the system_state[:, i] indexing in the trajectory\n",
    "        # plots above; flattening control_u assumes a single control output (m == 1).\n",
    "        XX = system_state[DROP:].swapaxes(1, 2).reshape(-1, system.k)\n",
    "        y = control_u[DROP:].flatten()\n",
    "        reg = LinearRegression(fit_intercept=False)\n",
    "        reg.fit(XX, y)\n",
    "        all_k[n, ep] = reg.coef_\n",
    "\n",
    "# Test loss over training for both training schemes\n",
    "for i in range(len(all_paths)):\n",
    "    plt.plot(all_loss[i, :], color=colors[i], label=names[i], lw=3)\n",
    "\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Test loss')\n",
    "plt.title('Test loss during training')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "sns.despine()\n",
    "plt.show()\n",
    "\n",
    "# Reference double-integrator dynamics x_{t+1} = A x_t + B u_t\n",
    "# (assumed to match k_integrator_torch with k=2 -- TODO confirm)\n",
    "dt = 1\n",
    "x_target = np.array([0.0, 0.0])\n",
    "A = np.array([[1, dt],\n",
    "              [0, 1]])\n",
    "B = np.array([[0],\n",
    "              [dt]])\n",
    "\n",
    "# Discrete-time LQR gain with state cost Q and zero control cost R, kept for reference.\n",
    "# Stored as K_LQR so the system-dimension constant K from the training cells is not clobbered.\n",
    "# With this gain the LQR law is u = -K_LQR x, whereas the sweep below applies u = K x.\n",
    "Q = np.eye(2)*1\n",
    "R = np.array([[1]]) * 0\n",
    "P = scipy.linalg.solve_discrete_are(A, B, Q, R)\n",
    "K_LQR = np.linalg.inv(B.T @ P @ B + R) @ (B.T @ P @ A)\n",
    "\n",
    "# Loss landscape: log of the summed squared state error under u = K x for a grid of gains\n",
    "# K = [k0, k1], starting from the fixed state [1, 1]\n",
    "all_k0 = np.linspace(-1.2, 0.5, 50)\n",
    "all_k1 = np.linspace(-1.2, 0.5, 50)\n",
    "num_steps = int(50/dt)\n",
    "noise_std = 0.0                  # process-noise scale (0 = deterministic)\n",
    "x_init = np.array([1.0, 1.0])    # fixed initial position and velocity\n",
    "\n",
    "kk0 = []\n",
    "kk1 = []\n",
    "zzz = []\n",
    "\n",
    "for k0 in all_k0:\n",
    "    for k1 in all_k1:\n",
    "        K_gain = np.array([[k0, k1]])\n",
    "        x = x_init.copy()\n",
    "        state_loss_total = 0.0\n",
    "        control_loss_total = 0.0\n",
    "\n",
    "        for t in range(num_steps):\n",
    "            # Linear state feedback with the swept gain\n",
    "            u = K_gain @ (x - x_target)\n",
    "            w_t = np.random.randn(2) * noise_std\n",
    "            x = A @ x + B.flatten() * u + w_t * np.sqrt(dt)\n",
    "\n",
    "            state_loss_total += np.sum((x - x_target)**2) * dt\n",
    "            control_loss_total += u.item()**2 * dt\n",
    "\n",
    "        kk0.append(k0)\n",
    "        kk1.append(k1)\n",
    "        # Only the state cost enters the landscape; add control_loss_total to include effort\n",
    "        zzz.append(np.log(state_loss_total))\n",
    "\n",
    "def sim_p_reduce(P, num_steps=5, init_con_x=(1, 1)):\n",
    "    \"\"\"Iterate the linear map x_{t+1} = P x_t from an initial state.\n",
    "\n",
    "    Args:\n",
    "        P: square matrix applied at every step (e.g. a closed-loop matrix such as p_red).\n",
    "        num_steps: number of applications of P.\n",
    "        init_con_x: initial state x_0; its length must match P.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (num_steps + 1, len(init_con_x)) whose row t is x_t.\n",
    "    \"\"\"\n",
    "    x = np.asarray(init_con_x)\n",
    "    traj = [x]\n",
    "    for _ in range(num_steps):\n",
    "        x = P @ x\n",
    "        traj.append(x)\n",
    "    return np.array(traj)\n",
    "\n",
    "# Stability of the reduced closed-loop map x_{t+1} = (A + B K) x_t over a grid of gains K = [a, b]\n",
    "all_a = np.linspace(-1.2, .5, 200)\n",
    "all_b = np.linspace(-1.2, .5, 200)\n",
    "\n",
    "all_stability = []\n",
    "\n",
    "for a in all_a:\n",
    "    for b in all_b:\n",
    "        # A + B K with A = [[1, dt], [0, 1]], B = [[0], [dt]], K = [a, b]\n",
    "        p_red = np.array([[1, dt],\n",
    "                          [dt*a, 1 + dt*b]])\n",
    "\n",
    "        eigvals, _ = np.linalg.eig(p_red)\n",
    "\n",
    "        # 2 = stable (all |λ| < 1); 1 = unstable with a complex eigenvalue outside the unit\n",
    "        # circle (oscillatory); 0 = otherwise (real divergence, or marginal |λ| == 1)\n",
    "        if np.all(np.abs(eigvals) < 1):\n",
    "            stability_label = 2\n",
    "        else:\n",
    "            complex_mask = (np.abs(eigvals) > 1) & (np.imag(eigvals) != 0)\n",
    "            stability_label = 1 if np.any(complex_mask) else 0\n",
    "\n",
    "        all_stability.append(stability_label)\n",
    "\n",
    "# all_stability was filled with a as the outer loop: reshape to (a, b), then transpose so\n",
    "# rows follow the y axis (b) and columns the x axis (a)\n",
    "Z = np.array(all_stability).reshape(len(all_a), len(all_b)).T\n",
    "\n",
    "# Same fill order (k0 outer, k1 inner) for the loss landscape\n",
    "landscape = np.array(zzz).reshape(len(all_k0), len(all_k1)).T\n",
    "\n",
    "# Pixel-edge extent derived from the swept grid so every cell is centred on its gain value\n",
    "# and the regressed gain trajectories plotted on top share the same axes\n",
    "half0 = (all_k0[1] - all_k0[0]) / 2\n",
    "half1 = (all_k1[1] - all_k1[0]) / 2\n",
    "extent = [all_k0[0] - half0, all_k0[-1] + half0, all_k1[0] - half1, all_k1[-1] + half1]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "im = ax.imshow(landscape, extent=extent, origin='lower', aspect='auto', cmap='mako_r')\n",
    "fig.colorbar(im, ax=ax, label='Log(Loss)')\n",
    "\n",
    "# Boundary of the stable region (label 2) of the reduced closed-loop map\n",
    "ax.contour(all_a, all_b, Z, levels=[1.5], colors='w', linestyles='--', linewidths=1.5)\n",
    "\n",
    "# Trajectories of the regressed feedback gains over training\n",
    "for i in range(len(all_paths)):\n",
    "    ax.plot(all_k[i, :, 0], all_k[i, :, 1], lw=2, linestyle='-', marker='.',\n",
    "            color=colors[i], label=names[i])\n",
    "\n",
    "ax.set_xlabel('K1')\n",
    "ax.set_ylabel('K2')\n",
    "ax.set_title('Loss landscape over feedback gains')\n",
    "ax.legend()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
