{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from adaptive_latents import ArrayWithTime\n",
    "from adaptive_latents.input_sources import LDS\n",
    "import matplotlib.pyplot as plt\n",
    "from adaptive_latents import VJF, Bubblewrap, StreamingKalmanFilter\n",
    "from adaptive_latents.regressions import BaseKNearestNeighborRegressor\n",
    "from adaptive_latents.stim_regressor import StimRegressor\n",
    "import numpy as np\n",
    "import itertools\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate the nested dynamical system with a seeded RNG so reruns are reproducible.\n",
    "# X is the trajectory plotted below and stim the stimulus signal; the first return value is unused.\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "# NOTE(review): the 1/pi offset presumably keeps the rotation period incommensurate with the sampling,\n",
    "# and early_shift presumably nudges stim timestamps just before X's -- confirm against LDS.\n",
    "_, X, stim = LDS.run_nest_dynamical_system(500, transitions_per_rotation=60 + 1 / np.pi, u_function='curvy', rng=rng, early_shift=1e-8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Two views of the simulated 3D trajectory: from above (elev=90) and from the side (elev=0).\n",
    "fig, axs = plt.subplots(ncols=2, subplot_kw={'projection': '3d'})\n",
    "for ax in axs:\n",
    "    ax.plot(X[:,0], X[:,1], X[:,2])\n",
    "    ax.axis('equal')\n",
    "    ax.set_xlabel('d0')\n",
    "    ax.set_ylabel('d1')\n",
    "    ax.set_zlabel('d2')\n",
    "axs[0].view_init(elev=90, azim=0, roll=0)\n",
    "axs[0].set_title('top view')\n",
    "axs[1].view_init(elev=0, azim=0, roll=0)\n",
    "axs[1].set_title('side view')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check before the runs below; presumably raises if StimRegressor breaks the library's API contract.\n",
    "StimRegressor().test_if_api_compatible();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query points: one per sample after the first, timed midway between the stim and X timestamps.\n",
    "# NOTE(review): the query value n_steps is presumably the prediction horizon; assumes stim and X have equal length -- confirm.\n",
    "n_steps = 1\n",
    "qX = ArrayWithTime(n_steps * np.ones((stim.shape[0] - 1, 1)), (stim.t[1:] + X.t[1:])/2)\n",
    "\n",
    "def run_stim_regressor(autoreg, attempt_correction=True, k=10, maxlen=200):\n",
    "    \"\"\"Wrap `autoreg` in a StimRegressor with a k-NN stim model and run it offline on (stim, X, qX).\n",
    "\n",
    "    Returns (model, output) so the fitted model can be inspected in later cells.\n",
    "    \"\"\"\n",
    "    stim_reg = BaseKNearestNeighborRegressor(k=k, maxlen=maxlen)\n",
    "    model = StimRegressor(autoreg=autoreg, stim_reg=stim_reg, attempt_correction=attempt_correction)\n",
    "    # `convinient_return` (sic) is the keyword name offline_run_on accepts.\n",
    "    output = model.offline_run_on([stim, X, qX], show_tqdm=True, convinient_return=2)\n",
    "    return model, output\n",
    "\n",
    "s1, o1 = run_stim_regressor(StreamingKalmanFilter())\n",
    "s2, o2 = run_stim_regressor(Bubblewrap())\n",
    "s3, o3 = run_stim_regressor(VJF())\n",
    "s4, o4 = run_stim_regressor(StreamingKalmanFilter(), attempt_correction=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "_, axs = plt.subplots(squeeze=False)\n",
    "\n",
    "# (predictions, label) pairs, plotted in this order so the legend lines up with the traces.\n",
    "runs = [(o4, 'kalman filter (no correction)'), (o3, 'vjf'), (o2, 'bubblewrap'), (o1, 'kalman filter')]\n",
    "\n",
    "test_dim = 0  # state dimension to compare\n",
    "on_stim = 0  # compare on samples where stim == on_stim (0 -> non-stim samples)\n",
    "selection = (stim.t[stim.flatten() == on_stim], test_dim)\n",
    "for o, name in runs:\n",
    "    error = ArrayWithTime.subtract_aligned_indices(o, X)\n",
    "    to_plot = error.slice_by_time(*selection, all_axes=True)\n",
    "    true = X.slice_by_time(*selection, all_axes=True)\n",
    "\n",
    "    # Spread of the prediction error, normalized by the spread of the true signal on the same samples.\n",
    "    print(f\"{name: <30} {np.nanstd(to_plot)/ np.nanstd(true):.2f}\")\n",
    "    axs[0,0].plot(to_plot.t, to_plot, '.')\n",
    "\n",
    "axs[0,0].legend([label for _o, label in runs])\n",
    "axs[0,0].set_title(f'comparison of d{test_dim} on {\"stim\" if on_stim else \"non-stim\"} samples')\n",
    "axs[0,0].set_ylabel('prediction error')\n",
    "axs[0,0].set_xlabel('time (a.u., technically rotations)');\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "# NOTE: the qt backend opens an interactive window (needs a display); switch to inline for headless runs.\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "\n",
    "XX, YY = np.meshgrid(np.linspace(-10,10,14), np.linspace(-10,10,12))\n",
    "\n",
    "def S(state, s):\n",
    "    \"\"\"Return s * x / ||(x, y)|| for x, y = state[0], state[1] (s times the cosine of the x-y angle).\n",
    "\n",
    "    Works on scalars or elementwise on equally-shaped arrays.\n",
    "    NOTE(review): presumably the ground-truth stim effect of the 'curvy' u_function -- confirm against LDS.\n",
    "    \"\"\"\n",
    "    return s * state[0] / np.hypot(state[0], state[1])\n",
    "\n",
    "# Both grid sizes are even, so no grid point lies at the origin and the division in S is safe.\n",
    "Z = S([XX, YY, None], 1)\n",
    "\n",
    "ax.plot_surface(XX, YY, Z, zorder=10)\n",
    "# NOTE(review): column 5 of the k-NN history is presumably the observed stim effect -- confirm the history layout.\n",
    "ax.plot(s1.stim_reg.history[:,0], s1.stim_reg.history[:,1], s1.stim_reg.history[:,5], '.', zorder=10);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "# Evaluate s1's fitted k-NN stim model on a grid in the z=0 plane and plot its prediction as a surface.\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "\n",
    "n_points = 100  # grid resolution per axis; predict() is called n_points**2 times\n",
    "\n",
    "XX, YY = np.meshgrid(np.linspace(-10,10,n_points), np.linspace(-10,10,n_points))\n",
    "Z = np.zeros_like(XX)\n",
    "\n",
    "# To plot the residual against the analytic surface instead, subtract S([x, y, None], 1) from the previous cell.\n",
    "for idx in np.ndindex(XX.shape):\n",
    "    # NOTE(review): assumes the last predicted component is the stim effect -- confirm.\n",
    "    Z[idx] = s1.stim_reg.predict(np.array([XX[idx], YY[idx], 0]))[-1]\n",
    "\n",
    "ax.plot_surface(XX, YY, Z, zorder=10)\n",
    "print(Z.mean())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
     "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3"
   }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
