{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import adaptive_latents as al\n",
    "from adaptive_latents import NumpyTimedDataSource, Bubblewrap, AnimationManager\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "rng = np.random.default_rng()\n",
    "from tqdm import trange\n",
    "from scipy.integrate import solve_ivp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Van der Pol oscillator function\n",
    "def vdp(t, f, mu=1.0):\n",
    "    x, y = f\n",
    "    x_dot = y\n",
    "    y_dot = mu * (1 - x**2) * y - x\n",
    "    return [x_dot, y_dot]\n",
    "\n",
    "# Define the random projection function\n",
    "def random_proj(initial_dim: int, dim: int, seed=4):\n",
    "    rand = np.random.default_rng(seed)\n",
    "    t = rand.normal(0, 1, size=(dim, initial_dim))\n",
    "    return (t / np.sum(t, axis=0)).T\n",
    "\n",
    "# Define the stimulation function\n",
    "def stimulation(t, start_time, amplitude, decay_rate):\n",
    "    if t < start_time:\n",
    "        return 0\n",
    "    return amplitude * np.exp(-decay_rate * (t - start_time))\n",
    "\n",
    "# Define the data generation function with stimulation\n",
    "def gen_data_diffeq(f, projection, t, x0, dim, ivp_kwargs=None, proj_kwargs=None, noise=None, noise_kwargs=None, seed=41,\n",
    "                    stim_start_time=0, stim_amplitude=1.0, stim_decay_rate=0.1):\n",
    "    if ivp_kwargs is None:\n",
    "        ivp_kwargs = {}\n",
    "    if proj_kwargs is None:\n",
    "        proj_kwargs = {}\n",
    "    if noise_kwargs is None:\n",
    "        noise_kwargs = {}\n",
    "        \n",
    "    ivp = solve_ivp(f, t, x0, rtol=1e-6, **ivp_kwargs)\n",
    "    \n",
    "    y = ivp['y'].T\n",
    "    proj = projection(y.shape[1], dim, **proj_kwargs)\n",
    "    projed = y @ proj\n",
    "    \n",
    "    if noise is not None:\n",
    "        rand = np.random.default_rng(seed)\n",
    "        projed += getattr(rand, noise)(**noise_kwargs, size=projed.shape)\n",
    "    \n",
    "    # Apply stimulation to the projected data\n",
    "    stim = np.array([stimulation(ti, stim_start_time, stim_amplitude, stim_decay_rate) for ti in ivp['t']])\n",
    "    projed[:, 0] += stim  # Assuming the stimulation affects the first dimension\n",
    "    \n",
    "    return ivp['t'], y, projed\n",
    "\n",
    "# Define the function to generate datasets\n",
    "def make_dataset(f, x0, num_trajectories, num_dim, begin, end, noise, stim_timestep=600, stim_amplitude=1.0, stim_decay_rate=0.1):\n",
    "    xx = []\n",
    "    projeds = []\n",
    "    rng = np.random.RandomState(39)\n",
    "    \n",
    "    for _ in trange(num_trajectories):\n",
    "        t, x, projed = gen_data_diffeq(f, random_proj,\n",
    "                                       t=(0, 125), x0=x0 + 0.01 * rng.randn(*x0.shape), dim=num_dim, noise=\"normal\",\n",
    "                                       ivp_kwargs={'max_step': 0.05},\n",
    "                                       noise_kwargs={\"loc\": 0, \"scale\": noise},\n",
    "                                       stim_start_time=0, stim_amplitude=0, stim_decay_rate=0)\n",
    "        t = t[begin:end]\n",
    "        # Calculate the actual start time for stimulation\n",
    "        stim_start_time = t[stim_timestep - begin]\n",
    "        # Regenerate the data with the correct stimulation start time\n",
    "        t, x, projed = gen_data_diffeq(f, random_proj,\n",
    "                                       t=(0, 125), x0=x0 + 0.01 * rng.randn(*x0.shape), dim=num_dim, noise=\"normal\",\n",
    "                                       ivp_kwargs={'max_step': 0.05},\n",
    "                                       noise_kwargs={\"loc\": 0, \"scale\": noise},\n",
    "                                       stim_start_time=stim_start_time, stim_amplitude=stim_amplitude, stim_decay_rate=stim_decay_rate)\n",
    "        t = t[begin:end]\n",
    "        xx.append(x[begin:end])\n",
    "        projeds.append(projed[begin:end])\n",
    "    \n",
    "    xx = np.stack(xx, axis=0)\n",
    "    projeds = np.stack(projeds, axis=0)\n",
    "    \n",
    "    xs = xx\n",
    "    ys = projeds\n",
    "    us = np.zeros((xx.shape[0], xx.shape[1], 1))\n",
    "    \n",
    "    return xs, ys\n",
    "\n",
    "# Generate data using the Van der Pol oscillator with stimulation\n",
    "x0 = np.array([0.1, 0.1])\n",
    "num_trajectories = 1\n",
    "num_dim = 2\n",
    "begin, end = 500, 1100\n",
    "noise = 0.05\n",
    "stim_timestep = 900\n",
    "stim_amplitude = 1.0\n",
    "stim_decay_rate = 0.1\n",
    "\n",
    "xs, ys = make_dataset(vdp, x0, num_trajectories, num_dim, begin, end, noise, stim_timestep, stim_amplitude, stim_decay_rate)\n",
    "\n",
    "# Plot the results\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Find the index where stimulation starts\n",
    "stim_index = stim_timestep - begin\n",
    "\n",
    "# Plot the projected states\n",
    "plt.plot(ys[0, :stim_index, 0], ys[0, :stim_index, 1], label='Projected Data (Before Stimulation)')\n",
    "plt.plot(ys[0, stim_index:, 0], ys[0, stim_index:, 1], label='Projected Data (After Stimulation)', color='red')\n",
    "\n",
    "# Plot a green dot where the red plot starts\n",
    "plt.scatter(ys[0, stim_index, 0], ys[0, stim_index, 1], color='green', s=100, label='Stimulation Start')\n",
    "\n",
    "plt.xlabel('Dimension 1')\n",
    "plt.ylabel('Dimension 2')\n",
    "plt.title('Projected Van der Pol Oscillator States')\n",
    "plt.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "bw = Bubblewrap(M=100, num=20, B_thresh=-5, num_grad_q=4)\n",
    "al.Pipeline([bw])\n",
    "\n",
    "sources = [ys[0]]\n",
    "\n",
    "with al.plotting_functions.AnimationManager(n_rows=1, n_cols=2, fps=20) as am:\n",
    "    for idx, y in enumerate(ys[0]):\n",
    "        bw.partial_fit_transform([y], stream=0)\n",
    "        \n",
    "        if bw.is_initialized:\n",
    "            for ax in am.axs.flatten():\n",
    "                ax.cla()\n",
    "\n",
    "            am.axs[0,0].scatter(ys[0][:,0], ys[0][:,1], alpha=0)  # good for setting the x and y lims\n",
    "            \n",
    "            bw.show_bubbles_2d(am.axs[0,0])\n",
    "            am.axs[0,0].scatter(ys[0][:idx,0], ys[0][:idx,1], color='gray', s=5)\n",
    "            bw.show_nstep_pdf(am.axs[0,1], other_axis=am.axs[0,0], fig=am.fig, show_colorbar=False)\n",
    "\n",
    "            for ax in am.axs.flatten():\n",
    "                ax.axis(\"off\")\n",
    "            \n",
    "            am.grab_frame()\n",
    "        \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adaptive_latents",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
