{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import logging\n",
    "\n",
    "# Suppress warnings that might reveal local file paths for privacy\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
    "warnings.filterwarnings(\"ignore\", message=\".*IProgress not found.*\")\n",
    "\n",
    "print(\"Privacy protection enabled: Warnings suppressed.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ODE Model Demonstration: ODEbench & Ground Truth Examples\n",
    "\n",
    "This notebook demonstrates the model's ability to learn and predict vector fields for various dynamical systems from **ODEbench**, comparing predictions against **Ground Truth** equations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from model_lib import OdeonEval\n",
    "from scipy.integrate import solve_ivp\n",
    "\n",
    "# Set device\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = Path(\"model/checkpoints\")\n",
    "\n",
    "# Every later cell calls `evaluator`, so a failed load has to stop the run here.\n",
    "# Swallowing the error would only resurface it later as an unrelated NameError.\n",
    "try:\n",
    "    evaluator = OdeonEval(model_path)\n",
    "except Exception as e:\n",
    "    # Keep the short, readable message, then re-raise so the traceback is preserved.\n",
    "    print(f\"Error loading model: {e}\")\n",
    "    raise\n",
    "else:\n",
    "    print(\"Model loaded successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Helper Functions\n",
    "\n",
    "We define helpers to generate context data (trajectories) and visualize the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_context_from_ode(ode_func, y0, t_span, n_points=50):\n",
    "    \"\"\"Generates a trajectory to serve as context for the model.\"\"\"\n",
    "    t_eval = np.linspace(t_span[0], t_span[1], n_points)\n",
    "    sol = solve_ivp(ode_func, t_span, y0, t_eval=t_eval)\n",
    "    \n",
    "    # Pad with zero if 2D to match model's 3D expectation\n",
    "    traj_2d = sol.y.T\n",
    "    if traj_2d.shape[1] == 2:\n",
    "        traj_3d = np.hstack([traj_2d, np.zeros((n_points, 1))])\n",
    "    else:\n",
    "        traj_3d = traj_2d\n",
    "        \n",
    "    traj_torch = torch.from_numpy(traj_3d).float().unsqueeze(0).unsqueeze(0) # (B, T, N, D)\n",
    "    times_torch = torch.from_numpy(sol.t).float().view(1, 1, n_points, 1)    # (B, T, N, 1)\n",
    "    \n",
    "    return traj_torch, times_torch, sol.y.T\n",
    "\n",
    "def plot_comparison(title, locations, true_drift, pred_drift, gt_traj=None, pred_traj=None):\n",
    "    \"\"\"Draw the true and the predicted vector field in two side-by-side panels.\n",
    "\n",
    "    Left panel: `true_drift` as arrows at `locations`, plus `gt_traj` when given.\n",
    "    Right panel: `pred_drift` as arrows at `locations`, plus `pred_traj` and a marker\n",
    "    on its first point when given. Only the first two columns of each array are drawn.\n",
    "    Both panels share axis limits derived from `locations` with a 0.5 margin.\n",
    "    \"\"\"\n",
    "    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "    grid_x, grid_y = locations[:, 0], locations[:, 1]\n",
    "\n",
    "    # Left panel: ground-truth field and trajectory.\n",
    "    ax_true.quiver(grid_x, grid_y, true_drift[:, 0], true_drift[:, 1],\n",
    "                   color='blue', alpha=0.3, label='GT Field')\n",
    "    if gt_traj is not None:\n",
    "        ax_true.plot(gt_traj[:, 0], gt_traj[:, 1], 'k--', label='GT Trajectory', linewidth=2)\n",
    "    ax_true.set_title(f\"{title} - Ground Truth\")\n",
    "    ax_true.legend()\n",
    "\n",
    "    # Right panel: model field, integrated model trajectory and its start point.\n",
    "    ax_pred.quiver(grid_x, grid_y, pred_drift[:, 0], pred_drift[:, 1],\n",
    "                   color='green', alpha=0.3, label='Model Field')\n",
    "    if pred_traj is not None:\n",
    "        ax_pred.plot(pred_traj[:, 0], pred_traj[:, 1], 'b-', label='Model Prediction', linewidth=2)\n",
    "        ax_pred.scatter(pred_traj[0, 0], pred_traj[0, 1], color='blue', s=40, label='Start')\n",
    "    ax_pred.set_title(f\"{title} - Prediction\")\n",
    "    ax_pred.legend()\n",
    "\n",
    "    # Identical labels, grid and limits on both panels so they compare directly.\n",
    "    x_limits = (grid_x.min() - 0.5, grid_x.max() + 0.5)\n",
    "    y_limits = (grid_y.min() - 0.5, grid_y.max() + 0.5)\n",
    "    for panel in (ax_true, ax_pred):\n",
    "        panel.set_xlabel(\"x\")\n",
    "        panel.set_ylabel(\"y\")\n",
    "        panel.grid(True)\n",
    "        panel.set_xlim(*x_limits)\n",
    "        panel.set_ylim(*y_limits)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "def run_example(name, ode_func, y0, t_span, grid_range, grid_points=10, horizon_fraction=0.5):\n",
    "    \"\"\"Run one context -> prediction -> plot cycle for a 2-D ODE system.\n",
    "\n",
    "    Args:\n",
    "        name: Title prefix used on both plot panels.\n",
    "        ode_func: Ground-truth right-hand side `f(t, y)` of a 2-D system.\n",
    "        y0: Initial state `[x0, y0]`.\n",
    "        t_span: `[t_start, t_end]` interval used to generate the context trajectory.\n",
    "        grid_range: `[(x_min, x_max), (y_min, y_max)]` extent of the query grid.\n",
    "        grid_points: Grid nodes per axis; the default 10 gives 100 query locations.\n",
    "        horizon_fraction: Fraction of `t_span`, measured from its start, over which the\n",
    "            true and the learned system are integrated for the trajectory comparison.\n",
    "\n",
    "    Uses the notebook-level `evaluator` and `device` objects.\n",
    "    \"\"\"\n",
    "    # 1. Generate Context\n",
    "    c_traj, c_times, _ = get_context_from_ode(ode_func, y0, t_span)\n",
    "    # NOTE(review): the context tensors stay on the CPU while the query tensors below\n",
    "    # are moved to `device` - confirm `evaluator.predict` accepts that mix under CUDA.\n",
    "\n",
    "    # 2. Setup Query Grid (third coordinate is a zero pad)\n",
    "    xx, yy = np.meshgrid(np.linspace(*grid_range[0], grid_points),\n",
    "                         np.linspace(*grid_range[1], grid_points))\n",
    "    flat_x, flat_y = xx.flatten(), yy.flatten()\n",
    "    locs = np.stack([flat_x, flat_y, np.zeros_like(flat_x)], axis=-1)\n",
    "    locs_torch = torch.from_numpy(locs).float().to(device).unsqueeze(0)  # (1, grid_points**2, 3)\n",
    "\n",
    "    # 3. Ground Truth Drift (the example systems ignore t, so t=0 is passed)\n",
    "    true_drift = np.array([ode_func(0, loc[:2]) for loc in locs])\n",
    "\n",
    "    # 4. Model Prediction\n",
    "    with torch.no_grad():\n",
    "        pred_drift = evaluator.predict(c_traj, c_times, locs_torch)[0].cpu().numpy()\n",
    "\n",
    "    # 5. Integration Comparison over the leading part of the context interval\n",
    "    t_start, t_end = t_span[0], t_span[1]\n",
    "    t_int = [t_start, t_start + (t_end - t_start) * horizon_fraction]\n",
    "    t_eval_int = np.linspace(t_int[0], t_int[1], 50)\n",
    "\n",
    "    gt_sol = solve_ivp(ode_func, t_int, y0, t_eval=t_eval_int)\n",
    "\n",
    "    def pred_ode_func(t, y):\n",
    "        # Query the model at a single zero-padded point and keep the two real components.\n",
    "        y_query = torch.tensor([[[y[0], y[1], 0.0]]], dtype=torch.float32).to(device)\n",
    "        with torch.no_grad():\n",
    "            drift = evaluator.predict(c_traj, c_times, y_query)[0, 0, :2].cpu().numpy()\n",
    "        return drift\n",
    "\n",
    "    pred_sol = solve_ivp(pred_ode_func, t_int, y0, t_eval=t_eval_int)\n",
    "\n",
    "    # 6. Visualize\n",
    "    plot_comparison(name, locs, true_drift, pred_drift,\n",
    "                    gt_traj=gt_sol.y.T, pred_traj=pred_sol.y.T)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Example 1: Spiral (Linear System)\n",
    "\n",
    "Equation: $\\dot{x} = -0.1x + 2y$, $\\dot{y} = -2x - 0.1y$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spiral_ode(t, y):\n",
    "    x, y = y\n",
    "    return [-0.1*x + 2.0*y, -2.0*x - 0.1*y]\n",
    "\n",
    "run_example(\n",
    "    name=\"Spiral\",\n",
    "    ode_func=spiral_ode,\n",
    "    y0=[1.0, 0.0],\n",
    "    t_span=[0, 10],\n",
    "    grid_range=[(-1.5, 1.5), (-1.5, 1.5)],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Example 2: Van der Pol Oscillator (ODEbench ID: 130)\n",
    "\n",
    "Standard non-linear oscillator used as a benchmark for ODE solvers and learners.\n",
    "\n",
    "Equation: $\\dot{x} = y$, $\\dot{y} = -x + \\mu(1-x^2)y$, with $\\mu = 1$ in the code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vdp_ode(t, y):\n",
    "    mu = 1.0\n",
    "    x, y = y\n",
    "    return [y, -x + mu * (1 - x**2) * y]\n",
    "\n",
    "run_example(\n",
    "    name=\"Van der Pol\",\n",
    "    ode_func=vdp_ode,\n",
    "    y0=[0.5, 0.0],\n",
    "    t_span=[0, 20],\n",
    "    grid_range=[(-2.5, 2.5), (-3, 3)],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Example 3: Lotka-Volterra (Predator-Prey)\n",
    "\n",
    "A classic model of predator-prey population dynamics: $x$ is the prey population and $y$ the predator population.\n",
    "\n",
    "Equation: $\\dot{x} = x(1.84 - 1.45y)$, $\\dot{y} = -y(3.0 - 1.62x)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lotka_volterra_ode(t, y):\n",
    "    x, y = y\n",
    "    return [x * (1.84 - 1.45 * y), -y * (3.0 - 1.62 * x)]\n",
    "\n",
    "run_example(\n",
    "    name=\"Lotka-Volterra\",\n",
    "    ode_func=lotka_volterra_ode,\n",
    "    y0=[2.0, 1.0],\n",
    "    t_span=[0, 5],\n",
    "    grid_range=[(0, 4), (0, 3)],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Example 4: FitzHugh-Nagumo Model\n",
    "\n",
    "Simplified model of neuron activation.\n",
    "\n",
    "Equation: $\\dot{x} = 3(x - x^3/3 + y)$, $\\dot{y} = (0.2 - 3x - 0.2y)/3$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fhn_ode(t, y):\n",
    "    x1, x2 = y[0], y[1]\n",
    "    dx1dt = 3 * (x1 - (x1**3) / 3 + x2)\n",
    "    dx2dt = (0.2 - 3 * x1 - 0.2 * x2) / 3\n",
    "    return [dx1dt, dx2dt]\n",
    "\n",
    "run_example(\n",
    "    name=\"FitzHugh-Nagumo\",\n",
    "    ode_func=fhn_ode,\n",
    "    y0=[-1.0, 1.0],\n",
    "    t_span=[0, 10],\n",
    "    grid_range=[(-2.5, 2.5), (-2, 2)],\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}