{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Toy example for stimulation "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Van der Pol oscillator example\n",
    "\n",
    "This code is a modified version of the one in https://github.com/pearsonlab/Bubblewrap/blob/main/datagen.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.integrate import solve_ivp\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Van der Pol oscillator function\n",
    "def vdp(t, f, mu=1.0):\n",
    "    x, y = f\n",
    "    x_dot = y\n",
    "    y_dot = mu * (1 - x**2) * y - x\n",
    "    return [x_dot, y_dot]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the random projection function\n",
    "def random_proj(initial_dim: int, dim: int, seed=4):\n",
    "    rand = np.random.default_rng(seed)\n",
    "    t = rand.normal(0, 1, size=(dim, initial_dim))\n",
    "    return (t / np.sum(t, axis=0)).T\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the data generation function\n",
    "def gen_data_diffeq(f, projection, t, x0, dim, ivp_kwargs=None, proj_kwargs=None, noise=None, noise_kwargs=None, seed=41):\n",
    "    if ivp_kwargs is None:\n",
    "        ivp_kwargs = {}\n",
    "    if proj_kwargs is None:\n",
    "        proj_kwargs = {}\n",
    "    if noise_kwargs is None:\n",
    "        noise_kwargs = {}\n",
    "        \n",
    "    ivp = solve_ivp(f, t, x0, rtol=1e-6, **ivp_kwargs)\n",
    "    \n",
    "    y = ivp['y'].T\n",
    "    proj = projection(y.shape[1], dim, **proj_kwargs)\n",
    "    projed = y @ proj\n",
    "    \n",
    "    if noise is not None:\n",
    "        rand = np.random.default_rng(seed)\n",
    "        projed += getattr(rand, noise)(**noise_kwargs, size=projed.shape)\n",
    "    \n",
    "    return ivp['t'], y, projed\n",
    "\n",
    "# Define the function to generate datasets\n",
    "def make_dataset(f, x0, num_trajectories, num_dim, begin, end, noise):\n",
    "    xx = []\n",
    "    projeds = []\n",
    "    rng = np.random.RandomState(39)\n",
    "    \n",
    "    for _ in trange(num_trajectories):\n",
    "        t, x, projed = gen_data_diffeq(f, random_proj,\n",
    "                                       t=(0, 125), x0=x0 + 0.01 * rng.randn(*x0.shape), dim=num_dim, noise=\"normal\",\n",
    "                                       ivp_kwargs={'max_step': 0.05},\n",
    "                                       noise_kwargs={\"loc\": 0, \"scale\": noise},)\n",
    "        t = t[begin:end]\n",
    "        xx.append(x[begin:end])\n",
    "        projeds.append(projed[begin:end])\n",
    "    \n",
    "    xx = np.stack(xx, axis=0)\n",
    "    projeds = np.stack(projeds, axis=0)\n",
    "    \n",
    "    xs = xx\n",
    "    ys = projeds\n",
    "    us = np.zeros((xx.shape[0], xx.shape[1], 1))\n",
    "    \n",
    "    return xs, ys\n",
    "\n",
    "# Generate data using the Van der Pol oscillator\n",
    "x0 = np.array([0.1, 0.1])\n",
    "num_trajectories = 1\n",
    "num_dim = 2\n",
    "begin, end = 500, 1050\n",
    "noise = 0.05\n",
    "\n",
    "xs, ys = make_dataset(vdp, x0, num_trajectories, num_dim, begin, end, noise)\n",
    "\n",
    "# Plot the results\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "\"\"\"\n",
    "# Plot the original Van der Pol states\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(xs[0, :, 0], xs[0, :, 1], label='Van der Pol Oscillator')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y')\n",
    "plt.title('Van der Pol Oscillator States')\n",
    "plt.legend()\n",
    "\"\"\"\n",
    "# Plot the projected states\n",
    "plt.subplot(1, 1, 1)\n",
    "plt.plot(ys[0, :, 0], ys[0, :, 1], label='Projected Data')\n",
    "plt.xlabel('Dimension 1')\n",
    "plt.ylabel('Dimension 2')\n",
    "plt.title('Projected Van der Pol Oscillator States')\n",
    "plt.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the type of impulse stimulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a funtion for stimulation\n",
    "# an impulse function that starts at start_time and decays exponentialy with decay_rate\n",
    "def stimulation(t, start_time, amplitude, decay_rate):\n",
    "    if t < start_time:\n",
    "        return 0\n",
    "    return amplitude * np.exp(-decay_rate * (t - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Visualize the stimulation function\n",
    "# Parameters for the stimulation\n",
    "start_time = 5.0\n",
    "amplitude = 10.0\n",
    "decay_rate = 1.0\n",
    "time_duration = 20.0\n",
    "\n",
    "# Generate time points\n",
    "t_values = np.linspace(0, time_duration, 1000)\n",
    "\n",
    "# Compute stimulation values\n",
    "stim_values = [stimulation(t, start_time, amplitude, decay_rate) for t in t_values]\n",
    "\n",
    "# Plot the stimulation function\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(t_values, stim_values, label='Stimulation', color='red')\n",
    "plt.axvline(x=start_time, color='black', linestyle='--', label='Stimulation Start Time')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Stimulation Amplitude')\n",
    "plt.title('Stimulation Function')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement the stimulation within the generated data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the data generation function with stimulation\n",
    "def gen_data_diffeq(f, projection, t, x0, dim, ivp_kwargs=None, proj_kwargs=None, noise=None, noise_kwargs=None, seed=41,\n",
    "                    stim_start_time=0, stim_amplitude=1.0, stim_decay_rate=0.1):\n",
    "    if ivp_kwargs is None:\n",
    "        ivp_kwargs = {}\n",
    "    if proj_kwargs is None:\n",
    "        proj_kwargs = {}\n",
    "    if noise_kwargs is None:\n",
    "        noise_kwargs = {}\n",
    "        \n",
    "    ivp = solve_ivp(f, t, x0, rtol=1e-6, **ivp_kwargs)\n",
    "    \n",
    "    y = ivp['y'].T\n",
    "    proj = projection(y.shape[1], dim, **proj_kwargs)\n",
    "    projed = y @ proj\n",
    "    \n",
    "    if noise is not None:\n",
    "        rand = np.random.default_rng(seed)\n",
    "        projed += getattr(rand, noise)(**noise_kwargs, size=projed.shape)\n",
    "    \n",
    "    # Apply stimulation to the projected data\n",
    "    stim = np.array([stimulation(ti, stim_start_time, stim_amplitude, stim_decay_rate) for ti in ivp['t']])\n",
    "    projed[:, 0] += stim  # Assuming the stimulation affects the first dimension\n",
    "    \n",
    "    return ivp['t'], y, projed\n",
    "\n",
    "# Define the function to generate datasets\n",
    "def make_dataset(f, x0, num_trajectories, num_dim, begin, end, noise, stim_timestep=600, stim_amplitude=1.0, stim_decay_rate=0.1):\n",
    "    xx = []\n",
    "    projeds = []\n",
    "    rng = np.random.RandomState(39)\n",
    "    \n",
    "    for _ in trange(num_trajectories):\n",
    "        t, x, projed = gen_data_diffeq(f, random_proj,\n",
    "                                       t=(0, 125), x0=x0 + 0.01 * rng.randn(*x0.shape), dim=num_dim, noise=\"normal\",\n",
    "                                       ivp_kwargs={'max_step': 0.05},\n",
    "                                       noise_kwargs={\"loc\": 0, \"scale\": noise},\n",
    "                                       stim_start_time=0, stim_amplitude=0, stim_decay_rate=0)\n",
    "        t = t[begin:end]\n",
    "        # Calculate the actual start time for stimulation\n",
    "        stim_start_time = t[stim_timestep - begin]\n",
    "        # Regenerate the data with the correct stimulation start time\n",
    "        t, x, projed = gen_data_diffeq(f, random_proj,\n",
    "                                       t=(0, 125), x0=x0 + 0.01 * rng.randn(*x0.shape), dim=num_dim, noise=\"normal\",\n",
    "                                       ivp_kwargs={'max_step': 0.05},\n",
    "                                       noise_kwargs={\"loc\": 0, \"scale\": noise},\n",
    "                                       stim_start_time=stim_start_time, stim_amplitude=stim_amplitude, stim_decay_rate=stim_decay_rate)\n",
    "        t = t[begin:end]\n",
    "        xx.append(x[begin:end])\n",
    "        projeds.append(projed[begin:end])\n",
    "    \n",
    "    xx = np.stack(xx, axis=0)\n",
    "    projeds = np.stack(projeds, axis=0)\n",
    "    \n",
    "    xs = xx\n",
    "    ys = projeds\n",
    "    us = np.zeros((xx.shape[0], xx.shape[1], 1))\n",
    "    \n",
    "    return xs, ys\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate data using the Van der Pol oscillator with stimulation\n",
    "x0 = np.array([0.1, 0.1])\n",
    "num_trajectories = 1\n",
    "num_dim = 2\n",
    "begin, end = 500, 1100\n",
    "noise = 0.05\n",
    "stim_timestep = 900\n",
    "stim_amplitude = 1.0\n",
    "stim_decay_rate = 0.1\n",
    "\n",
    "xs, ys = make_dataset(vdp, x0, num_trajectories, num_dim, begin, end, noise, stim_timestep, stim_amplitude, stim_decay_rate)\n",
    "\n",
    "# Plot the results\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Find the index where stimulation starts\n",
    "stim_index = stim_timestep - begin\n",
    "\n",
    "# Plot the projected states\n",
    "plt.plot(ys[0, :stim_index, 0], ys[0, :stim_index, 1], label='Projected Data (Before Stimulation)')\n",
    "plt.plot(ys[0, stim_index:, 0], ys[0, stim_index:, 1], label='Projected Data (After Stimulation)', color='red')\n",
    "\n",
    "# Plot a green dot where the red plot starts\n",
    "plt.scatter(ys[0, stim_index, 0], ys[0, stim_index, 1], color='green', s=100, label='Stimulation Start')\n",
    "\n",
    "plt.xlabel('Dimension 1')\n",
    "plt.ylabel('Dimension 2')\n",
    "plt.title('Projected Van der Pol Oscillator States')\n",
    "plt.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate data using the Van der Pol oscillator with stimulation\n",
    "x0 = np.array([0.1, 0.1])\n",
    "num_trajectories = 1\n",
    "num_dim = 2\n",
    "begin, end = 500, 1100\n",
    "noise = 0.05\n",
    "stim_timestep = 900\n",
    "stim_amplitude = 1.0\n",
    "stim_decay_rate = .5\n",
    "\n",
    "xs, ys = make_dataset(vdp, x0, num_trajectories, num_dim, begin, end, noise, stim_timestep, stim_amplitude, stim_decay_rate)\n",
    "\n",
    "# Plot the results\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Find the index where stimulation starts\n",
    "stim_index = stim_timestep - begin\n",
    "\n",
    "# Plot the projected states\n",
    "plt.plot(ys[0, :stim_index, 0], ys[0, :stim_index, 1], label='Projected Data (Before Stimulation)')\n",
    "plt.plot(ys[0, stim_index:, 0], ys[0, stim_index:, 1], label='Projected Data (After Stimulation)', color='red')\n",
    "\n",
    "# Plot a green dot where the red plot starts\n",
    "plt.scatter(ys[0, stim_index, 0], ys[0, stim_index, 1], color='green', s=100, label='Stimulation Start')\n",
    "\n",
    "plt.xlabel('Dimension 1')\n",
    "plt.ylabel('Dimension 2')\n",
    "plt.title('Projected Van der Pol Oscillator States')\n",
    "plt.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adaptive_latents",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
