{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from adaptive_latents import ArrayWithTime\n",
    "from adaptive_latents.input_sources import LDS\n",
    "from adaptive_latents import VJF, Bubblewrap, StreamingKalmanFilter\n",
    "from adaptive_latents.regressions import BaseKNearestNeighborRegressor\n",
    "from adaptive_latents.stim_regressor import StimRegressor\n",
    "import numpy as np\n",
    "from tqdm.notebook import trange\n",
    "import functools\n",
    "import copy\n",
    "\n",
    "rng = np.random.default_rng()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "transitions_per_rotation = 30 + 1/np.pi\n",
    "stims_per_rotation = 3\n",
    "n_rotations = 250\n",
    "initial_state = np.array([100,0,0], dtype=np.float64)\n",
    "assert stims_per_rotation < transitions_per_rotation/2\n",
    "\n",
    "\n",
    "lds = LDS.nest_lds(transitions_per_rotation=transitions_per_rotation, rng=rng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def S(state, stim):\n",
    "    # return stim* np.array([0,0, np.exp(-state @ state/60**2)]) * 3\n",
    "    # return stim* np.array([0,0, np.cos(state[0]/90)]) * 3\n",
    "    return stim*np.array([0, state[0]/np.linalg.norm(state[:2]), 0]) * 3\n",
    "    # return np.ones(3)*stim*10 + rng.normal(size=3)\n",
    "    # return stim*np.array([np.sin(state[0] / 5), np.sin(state[1]/5), np.sin(state[2]/5)]) * 10\n",
    "\n",
    "\n",
    "def U(lds, state, i, rng):\n",
    "    stim = float(rng.random() < stims_per_rotation/transitions_per_rotation)\n",
    "    u = S(state, stim)\n",
    "    return u\n",
    "\n",
    "alternate_histories = []\n",
    "for _ in trange(100):\n",
    "    X, _, _ = copy.deepcopy(lds).simulate(n_steps=int(n_rotations*transitions_per_rotation), initial_state=initial_state, U=U, rng=rng)\n",
    "    alternate_histories.append(X)\n",
    "alternate_histories = np.array(alternate_histories)\n",
    "originally_sampled_points = rng.choice(alternate_histories.reshape((-1,3)), size=1_000, replace=False)\n",
    "\n",
    "def evaluate(function):\n",
    "    ret = []\n",
    "    for point in originally_sampled_points:\n",
    "        ret.append(function(np.hstack([point, [1]])))\n",
    "    return np.array(ret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "sr = StimRegressor(autoreg=StreamingKalmanFilter(), stim_reg=BaseKNearestNeighborRegressor(k=20, maxlen=1000), attempt_correction=True)\n",
    "\n",
    "state = initial_state\n",
    "\n",
    "observations = []\n",
    "stims = []\n",
    "predictions = []\n",
    "dt_X = []\n",
    "\n",
    "s_hat_evals = []\n",
    "a_errors = []\n",
    "full_iteration = None\n",
    "\n",
    "\n",
    "n_steps = int(n_rotations * transitions_per_rotation)\n",
    "for i in trange(n_steps):\n",
    "    state, observation, delivered_stim = lds.simulate_step(state, rng, u_function=U, i=i, use_state_dynamics=i != 0, add_centers=False)\n",
    "\n",
    "    stim = (delivered_stim).any()\n",
    "\n",
    "    qX = ArrayWithTime([[1]], i)\n",
    "    sr.partial_fit_transform(np.array([[stim]]), stream='stim')\n",
    "    prediction = sr.partial_fit_transform(qX, stream='dt_X')\n",
    "    sr.partial_fit_transform(observation[None,:], stream='X')\n",
    "\n",
    "\n",
    "    observations.append(ArrayWithTime(observation, i))\n",
    "    stims.append(ArrayWithTime(stim, i))\n",
    "    predictions.append(prediction)\n",
    "    dt_X.append(qX)\n",
    "\n",
    "    if stim and (sr.autoreg.get_arbitrary_dynamics_parameter() is not None):\n",
    "        e = evaluate(sr.stim_reg.predict)\n",
    "        s_hat_evals.append(ArrayWithTime(e, i))\n",
    "        a_errors.append(ArrayWithTime((sr.autoreg.A - lds.A)**2, i))\n",
    "        if not np.isnan(sr.stim_reg.history).any() and full_iteration is None:\n",
    "            full_iteration = i\n",
    "\n",
    "observations = ArrayWithTime.from_list(observations, squeeze_type='to_2d')\n",
    "stims = ArrayWithTime.from_notime(stims)\n",
    "# predictions = ArrayWithTime.from_list(predictions, drop_early_nans=True)\n",
    "dt_X = ArrayWithTime.from_list(dt_X, squeeze_type='to_2d')\n",
    "\n",
    "s_hat_evals = ArrayWithTime.from_list(s_hat_evals, drop_early_nans=True)\n",
    "a_errors = ArrayWithTime.from_list(a_errors)\n",
    "\n",
    "s_eval = evaluate(functools.partial(S, stim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "fig, axs = plt.subplots(nrows=2, figsize=(10,5))\n",
    "ax = axs[0]\n",
    "\n",
    "s_hat_errors = s_hat_evals - s_eval\n",
    "\n",
    "mse_per_eval = (s_hat_errors**2).mean(axis=1)\n",
    "\n",
    "ax.plot(s_hat_evals.t, mse_per_eval, '.-')\n",
    "\n",
    "\n",
    "if (s_eval != 0).any():\n",
    "    zero_estimator_mse = ((0-s_eval)**2).mean(axis=0)\n",
    "    for idx, line in enumerate(zero_estimator_mse):\n",
    "        ax.axhline(line, color=f'C{idx}', alpha=.3)\n",
    "\n",
    "if full_iteration is not None:\n",
    "    ax.axvline(full_iteration, color='k', alpha=.3)\n",
    "\n",
    "ax = axs[1]\n",
    "ax.plot(a_errors.t, np.log(np.linalg.norm(a_errors, axis=(1,2))), 'C3--', alpha=.5);\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(ncols=2, figsize=(12,6), sharex=True, sharey=True)\n",
    "\n",
    "scatters = []\n",
    "\n",
    "c = [S(p, 1)[2] for p in originally_sampled_points]\n",
    "s = axs[0].scatter(originally_sampled_points[:,0], originally_sampled_points[:,1], c=c)\n",
    "scatters.append(s)\n",
    "\n",
    "s = axs[1].scatter(sr.stim_reg.history[:,0], sr.stim_reg.history[:,1], c=sr.stim_reg.history[:,5] )\n",
    "scatters.append(s)\n",
    "\n",
    "\n",
    "clim = np.array([s.get_clim() for s in scatters])\n",
    "clim = (clim[:,0].min(), clim[:,1].max())\n",
    "for s in scatters:\n",
    "    s.set_clim(clim)\n",
    "\n",
    "\n",
    "fig.colorbar(s, ax = axs);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(observations[:,0], observations[:,1], c=observations.t, cmap='viridis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "ax.plot(np.arange(alternate_histories.shape[1])/transitions_per_rotation, np.linalg.norm(alternate_histories, axis=-1).T)\n",
    "ax.set_xlabel('rotations')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
