{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from adaptive_latents import ArrayWithTime, Bubblewrap\n",
    "import adaptive_latents as al\n",
    "import pandas as pd\n",
    "import itertools\n",
    "\n",
    "from adaptive_latents.transformer import DecoupledTransformer, StreamingTransformer, Concatenator\n",
    "from adaptive_latents.regressions import BaseVanillaOnlineRegressor, BaseKNearestNeighborRegressor, OnlineRegressor\n",
    "\n",
    "\n",
    "rng = np.random.default_rng()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## StimRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "## Trivial manual example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "X = np.array([\n",
    "    [0,0], \n",
    "    [1,1], \n",
    "    [2,2], \n",
    "    [3,3], \n",
    "    [4,4], \n",
    "    [5,7], # stim\n",
    "    [6,8], \n",
    "    [7,11], # stim\n",
    "    [8,12], \n",
    "    [9,15], # stim\n",
    "    [10,16],\n",
    "])[:,1:2]\n",
    "stim = np.zeros(shape=(X.shape[0], 1))\n",
    "stim[5] = 1\n",
    "stim[7] = 1\n",
    "stim[9] = 1\n",
    "\n",
    "\n",
    "X = al.ArrayWithTime(X, np.arange(X.shape[0]))\n",
    "stim = al.ArrayWithTime(stim, np.arange(stim.shape[0]) - 0.001)\n",
    "\n",
    "s = StimRegressor(input_streams={1:'X', 0:'stim'}, spatial_stim_response=True)\n",
    "s.offline_run_on([stim, X]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def operation(x):\n",
    "    return al.ArrayWithTime.from_list(x, drop_early_nans=True, squeeze_type='to_2d')[:, slice(-1,None)]\n",
    "\n",
    "pred = operation(s.predictions)\n",
    "p1 = operation(s.auto_pred)\n",
    "p2 = operation(s.stim_pred[1:])\n",
    "true = operation(X.slice(-pred.shape[0],None))\n",
    "st = stim.slice(2,None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df = pd.DataFrame(np.hstack([st, p1, p2, pred, true]))\n",
    "d = true.shape[-1]\n",
    "df.columns = ['stim'] + ['p1']*d + ['p2']*d + ['pred']*d + ['true']*d\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(squeeze=False)\n",
    "axs[0,0].plot(X.t, X[:,-1], '.-')\n",
    "# axs[0,0].plot(stim.t, stim,'.')\n",
    "axs[0,0].plot(pred.t, pred, '.-')\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "## Nest example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "dt = 0.05\n",
    "\n",
    "A = np.array([\n",
    "    [np.cos(dt),  -np.sin(dt), 0],\n",
    "    [np.sin(dt),   np.cos(dt), 0],\n",
    "    [         0,            0, .99]\n",
    "])\n",
    "\n",
    "def C(x,y):\n",
    "    return y * x[0] / np.linalg.norm(x[:2])\n",
    "\n",
    "ts = np.arange(0, 500*2.1*np.pi, dt)\n",
    "\n",
    "stim = (ts * 0).reshape(-1, 1)\n",
    "\n",
    "stim[np.random.default_rng(0).choice(stim.shape[0], size=(ts.max()*.1).astype(int), replace=False)] = 1\n",
    "\n",
    "X_true = np.zeros((ts.size, 3))\n",
    "X_true[0,1] = 3\n",
    "\n",
    "for i, t in enumerate(ts):\n",
    "    if i == 0:\n",
    "        continue\n",
    "    X_true[i] = A @ X_true[i-1]\n",
    "    X_true[i,2] += C(X_true[i-1], stim[i])\n",
    "    X_true[i] += rng.normal(0, 0.01, X_true[i].shape)\n",
    "\n",
    "X = X_true + rng.normal(0, 0.01, X_true.shape)\n",
    "\n",
    "X = ArrayWithTime(X, ts)\n",
    "stim = ArrayWithTime(stim, ts-1e-8)\n",
    "\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "ax.plot(X[:,0], X[:,1], X[:,2])\n",
    "ax.axis('equal')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = StimRegressor(input_streams={0:'X', 1:'stim'}, spatial_stim_response=True)\n",
    "s1.offline_run_on([X, stim], show_tqdm=True)\n",
    "\n",
    "s2 = StimRegressor(input_streams={0:'X', 1:'stim'}, spatial_stim_response=False)\n",
    "s2.offline_run_on([X, stim], show_tqdm=True);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "_, axs = plt.subplots(nrows=2,squeeze=False, sharex=True)\n",
    "\n",
    "for i,s in enumerate([s2, s1]):\n",
    "    pred = operation(s.predictions)\n",
    "    axs[i,0].plot(X.t, X[:,-1])\n",
    "    axs[i,0].plot(pred.t, pred[:,-1])\n",
    "    axs[i,0].legend(['true system', 'predicted'])\n",
    "    axs[i,0].set_xlim([2511.5,2513])\n",
    "    axs[i,0].set_ylim([-1,2])\n",
    "    \n",
    "axs[0,0].set_title('no stim regression')\n",
    "axs[1,0].set_title('with stim regression')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "_, axs = plt.subplots(squeeze=False)\n",
    "\n",
    "for s in [s2,s1]:\n",
    "    pred = operation(s.predictions)\n",
    "    sl = stim[-pred.shape[0]:,0]==1\n",
    "    error = (pred - X[-pred.shape[0]:])\n",
    "    axs[0,0].plot(error.t[sl], error[sl,-1], '.')\n",
    "axs[0,0].legend(['no stim regression', 'with stim regression'])\n",
    "axs[0,0].set_title('Error comparison for timepoints with stimulation')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-6,6,10), np.linspace(-6,6,10))\n",
    "Z = 0 * X\n",
    "\n",
    "for i_x, i_y in itertools.product(range(10), range(10)):\n",
    "    Z[i_x, i_y] = C([X[i_x,i_y], Y[i_x,i_y], None], 1)\n",
    "\n",
    "ax.plot_surface(X, Y, Z, zorder=10)\n",
    "plt.plot(s1.stim_reg.history[:,0], s1.stim_reg.history[:,1], s1.stim_reg.history[:,5], '.', zorder=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.unique(stim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
