{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c879f336-e027-4056-8f88-a2835de260b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os \n",
    "from pathlib import Path\n",
    "\n",
    "from utils import *\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a62c6e-6589-437d-b32e-fe5cc5d2618c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------------------\n",
    "# Generate references (r_x, r_y)\n",
    "# ----------------------------------------\n",
    "def generate_sum_of_sinusoids(t, amps, freqs, phases):\n",
    "    r = np.zeros_like(t)\n",
    "    for a, f, phi in zip(amps, freqs, phases):\n",
    "        r += a * np.cos(2 * np.pi * f * t + phi)\n",
    "    return r\n",
    "\n",
    "def generate_sum_of_sinusoids_batch(t, amps, freqs, phases):\n",
    "    all_r = np.zeros((phases.shape[0],len(t)))\n",
    "    for i,cur_phases in enumerate(phases):\n",
    "        r = np.zeros_like(t)\n",
    "        for a, f, phi in zip(amps, freqs, cur_phases):\n",
    "            r += a * np.cos(2 * np.pi * f * t + phi)\n",
    "        all_r[i] = r\n",
    "    return all_r\n",
    "\n",
    "def ramp_function(t, ramp_duration=5.0):\n",
    "    return np.clip(t / ramp_duration, 0, 1)\n",
    "\n",
    "total_time = 30.0    # total time\n",
    "sampling_rate = 10 # 10   # Hz\n",
    "n_samples = int(total_time * sampling_rate)\n",
    "\n",
    "t = np.linspace(0, total_time, n_samples, endpoint=False)\n",
    "\n",
    "# For the x-axis\n",
    "amps_x = np.array([2.31, 2.31,])\n",
    "freqs_x = np.array([0.10,  0.30,])\n",
    "\n",
    "# For the y-axis\n",
    "amps_y = np.array([2.31, 2.31,])\n",
    "freqs_y = np.array([0.20, 0.40,])\n",
    "\n",
    "# ---------------------------------------------\n",
    "# Instantiate RNN controller and the plant\n",
    "# ---------------------------------------------\n",
    "dt = 0.1\n",
    "g  = 0.2\n",
    "rnn = ct_rnn_controller(\n",
    "    input_size=4,\n",
    "    hidden_size=100,  \n",
    "    output_size=2,\n",
    "    phi='linear',\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=False,     \n",
    "    train_Whh=True,      \n",
    "    train_Who=True,      \n",
    "    small_scale=False,   \n",
    "    Wih_start_big=False, \n",
    "    g=g, \n",
    "    dt=dt, \n",
    "    tracking_task=True,\n",
    ")\n",
    "\n",
    "plant = plant_2D_torch(dt=dt, noise=0.00, clamp=True)\n",
    "params_to_train = [p for p in rnn.parameters() if p.requires_grad]\n",
    "optimizer = optim.Adam(params_to_train, lr=1e-3)\n",
    "w_grad_clip = 0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d87fd0-48a8-44f6-878a-0b0bcfefc8a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------\n",
    "# Training Loop\n",
    "# ------------------------------\n",
    "num_epochs = 5_001\n",
    "batch_size = 100\n",
    "train_losses = []\n",
    "\n",
    "save_model = 1\n",
    "net_name = 'closed_4_freq_linear_adam'\n",
    "model_dir = Path(f'models/{net_name}')\n",
    "\n",
    "if save_model:\n",
    "    model_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    # Initialize hidden state and system state\n",
    "    h = torch.zeros((batch_size, rnn.hidden_size), dtype=torch.float32, requires_grad=True)\n",
    "    x_current = torch.zeros((batch_size, 4), dtype=torch.float32, requires_grad=True)  # [x, vx, y, vy]\n",
    "\n",
    "    # Generate random phases and sinusoidal reference\n",
    "    phases_x = np.random.uniform(-np.pi, np.pi, size=(batch_size, len(amps_x)))\n",
    "    phases_y = np.random.uniform(-np.pi, np.pi, size=(batch_size, len(amps_y)))\n",
    "\n",
    "    r_x_full = generate_sum_of_sinusoids_batch(t, amps_x, freqs_x, phases_x)\n",
    "    r_y_full = generate_sum_of_sinusoids_batch(t, amps_y, freqs_y, phases_y)\n",
    "\n",
    "    ramp = ramp_function(t, ramp_duration=1.0)\n",
    "    r_x = r_x_full * ramp\n",
    "    r_y = r_y_full * ramp\n",
    "\n",
    "    r_x_torch = torch.tensor(r_x, dtype=torch.float32)\n",
    "    r_y_torch = torch.tensor(r_y, dtype=torch.float32)\n",
    "\n",
    "    # Rollout and accumulate loss\n",
    "    total_loss = 0.0\n",
    "    for i in range(n_samples):\n",
    "        input_rnn = torch.stack((x_current[:, 0], x_current[:, 2], r_x_torch[:, i], r_y_torch[:, i]), dim=1)\n",
    "        u, h, _ = rnn(input_rnn, h)\n",
    "        x_next = plant.step(x_current.T, u.T)\n",
    "\n",
    "        mse_x = torch.mean((x_next[0, :] - r_x_torch[:, i]) ** 2)\n",
    "        mse_y = torch.mean((x_next[2, :] - r_y_torch[:, i]) ** 2)\n",
    "        total_loss += (mse_x + mse_y) * dt\n",
    "\n",
    "        x_current = x_next.T\n",
    "\n",
    "    loss = total_loss / n_samples\n",
    "\n",
    "    # Backward and optimize\n",
    "    if epoch > 0:\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        if w_grad_clip:\n",
    "            torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "    train_losses.append(loss.item())\n",
    "\n",
    "    if (epoch + 1) % 500 == 0:\n",
    "        print(f\"[Epoch {epoch + 1}/{num_epochs}] Loss = {loss.item():.4f}\")\n",
    "\n",
    "    # Save model checkpoint\n",
    "    if save_model:\n",
    "        torch.save({\n",
    "            'epoch': epoch,\n",
    "            'model_state': rnn.state_dict(),\n",
    "            'loss': loss.item()\n",
    "        }, model_dir / f'epoch_{epoch}.pth')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c42650-1549-4ef7-95c6-bebe59d72d34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------\n",
    "# Load Training Losses\n",
    "# ------------------------------\n",
    "num_epochs = 5_000\n",
    "net_name = 'closed_4_freq_linear_adam'\n",
    "epochs = np.arange(0, num_epochs)\n",
    "losses = []\n",
    "\n",
    "for ep in epochs:\n",
    "    checkpoint = torch.load(f'models/{net_name}/epoch_{ep}.pth')\n",
    "    losses.append(checkpoint['loss'])\n",
    "\n",
    "losses = np.array(losses)\n",
    "\n",
    "# ------------------------------\n",
    "# Plot Training Loss\n",
    "# ------------------------------\n",
    "plt.figure(figsize=(5, 4))\n",
    "plt.plot(np.log(losses), label='Training loss', color='royalblue', lw=3)\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Log Loss')\n",
    "plt.title('Training Loss')\n",
    "plt.legend()\n",
    "\n",
    "# Mark selected epochs\n",
    "highlight_epochs = [0,200,700,1800,4500]\n",
    "for ep in highlight_epochs:\n",
    "    plt.axvline(x=ep, ls='--', color='k', lw=1)\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "phases_x = np.random.uniform(-np.pi, np.pi, size=len(amps_x))\n",
    "phases_y = np.random.uniform(-np.pi, np.pi, size=len(amps_y))\n",
    "\n",
    "# Define 4 stimulus conditions: single x or y frequency at a time\n",
    "configs = [\n",
    "    {\"a_x\": amps_x[:1], \"f_x\": freqs_x[:1], \"a_y\": np.zeros(1), \"f_y\": np.zeros(1)},\n",
    "    {\"a_x\": np.zeros(1), \"f_x\": np.zeros(1), \"a_y\": amps_y[:1], \"f_y\": freqs_y[:1]},\n",
    "    {\"a_x\": amps_x[1:], \"f_x\": freqs_x[1:], \"a_y\": np.zeros(1), \"f_y\": np.zeros(1)},\n",
    "    {\"a_x\": np.zeros(1), \"f_x\": np.zeros(1), \"a_y\": amps_y[1:], \"f_y\": freqs_y[1:]}\n",
    "]\n",
    "\n",
    "colors = {'ref': sns.color_palette('deep', 4)[0], 'rnn': 'k'}\n",
    "net_name = 'closed_4_freq_linear_adam'\n",
    "\n",
    "for ep in highlight_epochs:\n",
    "    fig, axes = plt.subplots(1, 6, figsize=(16, 2.5))\n",
    "    ax0, ax1, ax2, ax3, ax4, ax5 = axes\n",
    "\n",
    "    # Load model at this epoch\n",
    "    rnn.load_state_dict(torch.load(f'models/{net_name}/epoch_{ep}.pth')['model_state'])\n",
    "\n",
    "    for i, config in enumerate(configs):\n",
    "        # Generate references\n",
    "        r_x_full = generate_sum_of_sinusoids(t, config[\"a_x\"], config[\"f_x\"], phases_x)\n",
    "        r_y_full = generate_sum_of_sinusoids(t, config[\"a_y\"], config[\"f_y\"], phases_y)\n",
    "        ramp = ramp_function(t, ramp_duration=1.0)\n",
    "        r_x = r_x_full * ramp\n",
    "        r_y = r_y_full * ramp\n",
    "        r_x_torch = torch.tensor(r_x, dtype=torch.float32)\n",
    "        r_y_torch = torch.tensor(r_y, dtype=torch.float32)\n",
    "\n",
    "        # Rollout trajectory\n",
    "        with torch.no_grad():\n",
    "            h = torch.zeros((1, rnn.hidden_size), dtype=torch.float32)\n",
    "            x = torch.zeros(4, dtype=torch.float32)\n",
    "            traj_x, traj_y = [], []\n",
    "\n",
    "            for n in range(n_samples):\n",
    "                input_t = torch.tensor([[x[0], x[2], r_x_torch[n], r_y_torch[n]]], dtype=torch.float32)\n",
    "                u, h, _ = rnn(input_t, h)\n",
    "                x = plant.step(x, u[0])\n",
    "                traj_x.append(x[0].item())\n",
    "                traj_y.append(x[2].item())\n",
    "\n",
    "        # Plotting\n",
    "        if i == 0:\n",
    "            ax0.set_title(f'Ep: {ep}')\n",
    "            ax0.plot(losses, color='k', lw=3)\n",
    "            ax0.set_yscale('log')\n",
    "            ax0.set_xlabel('Epoch')\n",
    "            ax0.set_ylabel('Loss')\n",
    "            ax0.axvline(x=ep, ls='--', color=colors['ref'])\n",
    "\n",
    "            eigvals, _ = np.linalg.eig(-np.eye(100) + rnn.Whh.detach().numpy())\n",
    "            ax1.scatter(eigvals.real, eigvals.imag, s=10)\n",
    "            theta = np.linspace(0, 2 * np.pi, 200)\n",
    "            ax1.plot(np.cos(theta) - 1, np.sin(theta), '--', color='black', lw=1.5, alpha=0.8)\n",
    "            ax1.axvline(1, ls='--', color='k')\n",
    "            ax1.axhline(0, color='black', lw=1.5)\n",
    "            ax1.axvline(0, color='black', lw=1.5)\n",
    "            ax1.set_xlim(-3, 3)\n",
    "            ax1.set_ylim(-3, 3)\n",
    "            ax1.set_xlabel('Real')\n",
    "            ax1.set_ylabel('Imag')\n",
    "            for omega in np.concatenate((freqs_x * 2 * np.pi, freqs_y * 2 * np.pi)):\n",
    "                ax1.axhline(omega, color='k', alpha=0.5, lw=0.5)\n",
    "                ax1.text(-2.5, omega, r'$\\omega=$' + f'{omega:.2f}', size=8)\n",
    "\n",
    "        # Select correct axis for each case\n",
    "        axis_map = {0: ax2, 1: ax3, 2: ax4, 3: ax5}\n",
    "        ax = axis_map[i]\n",
    "        if i in [0, 2]:  # x trajectory\n",
    "            ax.plot(t, r_x, color=colors['ref'], ls='-', label='Ref.')\n",
    "            ax.plot(t, traj_x, color=colors['rnn'], ls='--', label='RNN')\n",
    "            ax.set_ylabel(\"X\")\n",
    "            omega = config['f_x'][0] * 2 * np.pi\n",
    "        else:  # y trajectory\n",
    "            ax.plot(t, r_y, color=colors['ref'], ls='-', label='Ref.')\n",
    "            ax.plot(t, traj_y, color=colors['rnn'], ls='--', label='RNN')\n",
    "            ax.set_ylabel(\"Y\")\n",
    "            omega = config['f_y'][0] * 2 * np.pi\n",
    "\n",
    "        ax.set_xlabel(\"Time\")\n",
    "        ax.set_title(r'$\\omega=$' + f'{omega:.2f}')\n",
    "        if i == 0:\n",
    "            ax.legend()\n",
    "\n",
    "    sns.despine()\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "986be7d5-4f92-4673-a556-3881047857c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------\n",
    "# Config and Initialization\n",
    "# ------------------------------\n",
    "phases_x = np.random.uniform(-np.pi, np.pi, size=len(amps_x))\n",
    "phases_y = np.random.uniform(-np.pi, np.pi, size=len(amps_y))\n",
    "\n",
    "# Predefined stimulus configurations for 4 probe conditions\n",
    "configs = [\n",
    "    {\"a_x\": amps_x[:1], \"f_x\": freqs_x[:1], \"a_y\": np.zeros(1), \"f_y\": np.zeros(1)},\n",
    "    {\"a_x\": np.zeros(1), \"f_x\": np.zeros(1), \"a_y\": amps_y[:1], \"f_y\": freqs_y[:1]},\n",
    "    {\"a_x\": amps_x[1:], \"f_x\": freqs_x[1:], \"a_y\": np.zeros(1), \"f_y\": np.zeros(1)},\n",
    "    {\"a_x\": np.zeros(1), \"f_x\": np.zeros(1), \"a_y\": amps_y[1:], \"f_y\": freqs_y[1:]}\n",
    "]\n",
    "\n",
    "# Allocate memory to record training stats\n",
    "loss_matrix = np.zeros((num_epochs, 4))\n",
    "W_in_all = np.zeros((num_epochs, 4, 100))    # shape: [epoch, input_dim, hidden_dim]\n",
    "W_out_all = np.zeros((num_epochs, 100, 2))   # shape: [epoch, hidden_dim, output_dim]\n",
    "W_out_norm = np.zeros((num_epochs,))\n",
    "\n",
    "# ------------------------------\n",
    "# Evaluation over Training Epochs\n",
    "# ------------------------------\n",
    "for ep in range(num_epochs):\n",
    "\n",
    "    # Load model weights\n",
    "    checkpoint = torch.load(f'models/{net_name}/epoch_{ep}.pth')\n",
    "    rnn.load_state_dict(checkpoint['model_state'])\n",
    "\n",
    "    # Log input and output weights\n",
    "    W_in_all[ep] = rnn.Wih.detach().numpy()\n",
    "    W_out_all[ep] = rnn.Who.detach().numpy()\n",
    "    W_out_norm[ep] = np.linalg.norm(W_out_all[ep])\n",
    "\n",
    "    # Loop over each probe config\n",
    "    for i, config in enumerate(configs):\n",
    "        # Generate ramped reference trajectory\n",
    "        r_x = generate_sum_of_sinusoids(t, config[\"a_x\"], config[\"f_x\"], phases_x) * ramp_function(t, 1.0)\n",
    "        r_y = generate_sum_of_sinusoids(t, config[\"a_y\"], config[\"f_y\"], phases_y) * ramp_function(t, 1.0)\n",
    "        r_x_torch = torch.tensor(r_x, dtype=torch.float32)\n",
    "        r_y_torch = torch.tensor(r_y, dtype=torch.float32)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            h = torch.zeros((1, rnn.hidden_size), dtype=torch.float32)\n",
    "            x = torch.zeros(4, dtype=torch.float32)\n",
    "            episode_loss = []\n",
    "\n",
    "            for n in range(n_samples):\n",
    "                input_t = torch.tensor([[x[0], x[2], r_x_torch[n], r_y_torch[n]]], dtype=torch.float32)\n",
    "                u, h, _ = rnn(input_t, h)\n",
    "                x = plant.step(x, u[0])\n",
    "\n",
    "                err = (x[0] - r_x_torch[n])**2 + (x[2] - r_y_torch[n])**2\n",
    "                episode_loss.append(err.item())\n",
    "\n",
    "        loss_matrix[ep, i] = np.mean(episode_loss)\n",
    "\n",
    "# ------------------------------\n",
    "# Plotting Loss Curves by Frequency\n",
    "# ------------------------------\n",
    "omega_labels = [r'$\\omega=0.62$', r'$\\omega=0.94$', r'$\\omega=1.57$', r'$\\omega=2.20$']\n",
    "plt.figure(figsize=(6, 4))\n",
    "\n",
    "for i in range(4):\n",
    "    plt.plot(loss_matrix[:, i], label=omega_labels[i])\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('Tracking Error for Each Frequency')\n",
    "plt.legend()\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a207c7ca-4b82-4d87-9876-ffb9f4e6f0fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Setup ===\n",
    "dt = 0.1\n",
    "g = 0.2\n",
    "num_epochs = 5001\n",
    "batch_size = 100\n",
    "save_model = True\n",
    "net_name = 'open_4_freq_linear_adam'\n",
    "cutoff_interval = (num_epochs - 1) // 1000\n",
    "train_losses = []\n",
    "\n",
    "# === Define Student Network ===\n",
    "student_net = ct_rnn_controller(\n",
    "    input_size=4,\n",
    "    hidden_size=100,\n",
    "    output_size=2,\n",
    "    phi='linear',\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=True,\n",
    "    train_Whh=True,\n",
    "    train_Who=True,\n",
    "    small_scale=False,\n",
    "    Wih_start_big=False,\n",
    "    g=g,\n",
    "    dt=dt,\n",
    "    tracking_task=True,\n",
    ")\n",
    "student_net.load_state_dict(torch.load(f'models/closed_4_freq_linear_adam/epoch_0.pth')['model_state'])\n",
    "\n",
    "# === Define Teacher Network ===\n",
    "teacher_net = ct_rnn_controller(\n",
    "    input_size=4,\n",
    "    hidden_size=100,\n",
    "    output_size=2,\n",
    "    phi='linear',\n",
    "    manually_initialize=True,\n",
    "    with_bias=False,\n",
    "    train_Wih=False,  \n",
    "    train_Whh=True,\n",
    "    train_Who=True,\n",
    "    small_scale=False,\n",
    "    Wih_start_big=False,\n",
    "    g=g,\n",
    "    dt=dt,\n",
    "    tracking_task=True,\n",
    ")\n",
    "teacher_net.load_state_dict(torch.load(f'models/closed_4_freq_linear_adam/epoch_5000.pth')['model_state'])\n",
    "\n",
    "# === Environment ===\n",
    "plant = plant_2D_torch(dt=dt, noise=0.0, clamp=True)\n",
    "\n",
    "# === Optimizer ===\n",
    "optimizer = optim.Adam([p for p in student_net.parameters() if p.requires_grad], lr=1e-3)\n",
    "clip_grad = False\n",
    "\n",
    "# === Create model save directory ===\n",
    "model_dir = f'models/{net_name}'\n",
    "if save_model and not os.path.exists(model_dir):\n",
    "    os.makedirs(model_dir)\n",
    "\n",
    "# === Training Loop ===\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    h_student = torch.zeros(batch_size, student_net.hidden_size)\n",
    "    h_teacher = torch.zeros(batch_size, teacher_net.hidden_size)\n",
    "    x_current = torch.zeros(batch_size, 4, dtype=torch.float32, requires_grad=True)\n",
    "\n",
    "    # === Generate sinusoidal reference ===\n",
    "    phases_x = np.random.uniform(-np.pi, np.pi, size=(batch_size, len(amps_x)))\n",
    "    phases_y = np.random.uniform(-np.pi, np.pi, size=(batch_size, len(amps_y)))\n",
    "\n",
    "    r_x = generate_sum_of_sinusoids_batch(t, amps_x, freqs_x, phases_x) * ramp_function(t, ramp_duration=1.0)\n",
    "    r_y = generate_sum_of_sinusoids_batch(t, amps_y, freqs_y, phases_y) * ramp_function(t, ramp_duration=1.0)\n",
    "\n",
    "    r_x = torch.tensor(r_x, dtype=torch.float32)\n",
    "    r_y = torch.tensor(r_y, dtype=torch.float32)\n",
    "\n",
    "\n",
    "    total_loss = 0.0\n",
    "\n",
    "    # === Time loop ===\n",
    "    for i in range(n_samples):\n",
    "        input_rnn = torch.stack([\n",
    "            x_current[:, 0],   # x\n",
    "            x_current[:, 2],   # y\n",
    "            r_x[:, i],         # x_ref\n",
    "            r_y[:, i]          # y_ref\n",
    "        ], dim=-1)\n",
    "\n",
    "        u_teacher, h_teacher, _ = teacher_net(input_rnn, h_teacher)\n",
    "        u_student, h_student, _ = student_net(input_rnn, h_student)\n",
    "\n",
    "        loss_step = torch.mean((u_teacher - u_student) ** 2)\n",
    "        total_loss += loss_step * dt\n",
    "\n",
    "        # update environment\n",
    "        x_next = plant.step(x_current.T, u_teacher.T)\n",
    "        x_current = x_next.T\n",
    "\n",
    "    loss = total_loss / n_samples\n",
    "\n",
    "    # === Backward ===\n",
    "    if epoch > 0:\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        if clip_grad:\n",
    "            torch.nn.utils.clip_grad_norm_(student_net.parameters(), max_norm=1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "    # === Logging ===\n",
    "    train_losses.append(loss.item())\n",
    "    if (epoch + 1) % 500 == 0:\n",
    "        print(f\"[Epoch {epoch + 1}/{num_epochs}] Loss = {loss.item():.6f}\")\n",
    "\n",
    "    # === Save model checkpoint ===\n",
    "    if save_model:\n",
    "        torch.save({\n",
    "            'epoch': epoch,\n",
    "            'model_state': student_net.state_dict(),\n",
    "            'loss': loss.item()\n",
    "        }, f'{model_dir}/epoch_{epoch}.pth')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bbda1c8-08b5-4fa2-88c2-576f8e9d2b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Config ===\n",
    "net_name = 'open_4_freq_linear_adam'\n",
    "train_batch_size = 100\n",
    "test_batch_size = 10\n",
    "epochs = np.arange(num_epochs)\n",
    "\n",
    "# === Load training loss from checkpoints ===\n",
    "train_losses = [\n",
    "    torch.load(f'models/{net_name}/epoch_{ep}.pth')['loss']\n",
    "    for ep in epochs\n",
    "]\n",
    "\n",
    "# === Evaluate test loss over time ===\n",
    "test_losses = []\n",
    "\n",
    "for ep in epochs:\n",
    "    # Load model\n",
    "    rnn.load_state_dict(torch.load(f'models/{net_name}/epoch_{ep}.pth')['model_state'])\n",
    "\n",
    "    # Initialize hidden state and system state\n",
    "    h = torch.zeros(test_batch_size, rnn.hidden_size, dtype=torch.float32)\n",
    "    x = torch.zeros(test_batch_size, 4, dtype=torch.float32)  # [x, vx, y, vy]\n",
    "\n",
    "    # Generate reference trajectories\n",
    "    phases_x = np.random.uniform(-np.pi, np.pi, size=(test_batch_size, len(amps_x)))\n",
    "    phases_y = np.random.uniform(-np.pi, np.pi, size=(test_batch_size, len(amps_y)))\n",
    "    \n",
    "    r_x = generate_sum_of_sinusoids_batch(t, amps_x, freqs_x, phases_x)\n",
    "    r_y = generate_sum_of_sinusoids_batch(t, amps_y, freqs_y, phases_y)\n",
    "    ramp = ramp_function(t, ramp_duration=1.0)\n",
    "\n",
    "    r_x = torch.tensor(r_x * ramp, dtype=torch.float32)\n",
    "    r_y = torch.tensor(r_y * ramp, dtype=torch.float32)\n",
    "\n",
    "    loss = 0.0\n",
    "    for i in range(n_samples):\n",
    "        input_t = torch.stack((x[:, 0], x[:, 2], r_x[:, i], r_y[:, i]), dim=1)\n",
    "        u, h, _ = rnn(input_t, h)\n",
    "\n",
    "        x_next = plant.step(x.T, u.T)\n",
    "        err_x = (x_next[0, :] - r_x[:, i]) ** 2\n",
    "        err_y = (x_next[2, :] - r_y[:, i]) ** 2\n",
    "        loss += (err_x.mean() + err_y.mean()) * dt\n",
    "        x = x_next.T\n",
    "\n",
    "    test_losses.append((loss / n_samples).item())\n",
    "\n",
    "# === Plot losses ===\n",
    "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n",
    "\n",
    "ax[0].plot(np.log(train_losses), lw=2, color='royalblue')\n",
    "ax[0].set_title('Train Loss')\n",
    "ax[0].set_xlabel('Epoch')\n",
    "ax[0].set_ylabel('Log Loss')\n",
    "\n",
    "ax[1].plot(np.log(test_losses), lw=2, color='royalblue')\n",
    "ax[1].set_title('Test Loss')\n",
    "ax[1].set_xlabel('Epoch')\n",
    "ax[1].set_ylabel('Log Loss')\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
