{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "from torch.linalg import slogdet\n",
    "\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import adaptive_latents\n",
    "from adaptive_latents import StreamingKalmanFilter, Bubblewrap, ArrayWithTime, Pipeline, proSVD, Tee, VJF, CenteringTransformer\n",
    "from adaptive_latents.regressions import BaseKNearestNeighborRegressor\n",
    "from adaptive_latents.stim_regressor import StimRegressor\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "rng = np.random.default_rng(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "d = adaptive_latents.datasets.Odoherty21Dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def high_d_S(low_d_point, high_d_stim):\n",
    "    return high_d_stim\n",
    "\n",
    "def S(low_d_point, high_d_stim, pro):\n",
    "    return pro.transform(high_d_S(low_d_point, high_d_stim)[None,:])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_experiment(input_arrays, decay_rate = .9, stim_scale=1, rng=None, s_inputs_to_evaluate_on=None):\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng(0)\n",
    "    centerer = CenteringTransformer()\n",
    "\n",
    "    pro = proSVD(k=10)\n",
    "\n",
    "    sr = StimRegressor(\n",
    "        autoreg=StreamingKalmanFilter(),\n",
    "        stim_reg=BaseKNearestNeighborRegressor(k=1, maxlen=1000),\n",
    "        attempt_correction=True\n",
    "    )\n",
    "\n",
    "    stims = []\n",
    "    predictions = []\n",
    "    latents = []\n",
    "    dt_X = []\n",
    "    s_hat_evals = []\n",
    "    s_eval = []\n",
    "\n",
    "    s_inputs_evaluated_on = []\n",
    "\n",
    "    to_add = np.zeros(input_arrays[0].shape[1])\n",
    "\n",
    "    for input_array in input_arrays:\n",
    "        pbar = tqdm(total=round(input_array.t[-1],2))\n",
    "        for data in Pipeline().streaming_run_on(input_array):\n",
    "\n",
    "            latent_location = pro.transform(centerer.transform(data))\n",
    "\n",
    "            stim = np.zeros(data.shape[1])\n",
    "            if rng.random() < .05 and pro.is_initialized:\n",
    "                stim = pro.Q[:,0] * stim_scale\n",
    "                to_add += high_d_S(latent_location, stim)\n",
    "\n",
    "            if s_inputs_to_evaluate_on == 'record' and np.any(stim):\n",
    "                s_inputs_evaluated_on.append(copy.deepcopy((latent_location, stim, pro)))\n",
    "                s_eval.append(S(latent_location, stim, pro))\n",
    "\n",
    "            data = data + to_add\n",
    "            to_add = to_add * decay_rate\n",
    "\n",
    "            data = centerer.partial_fit_transform(data)\n",
    "            data = pro.partial_fit_transform(data)\n",
    "            latents.append(data)\n",
    "\n",
    "            qX = ArrayWithTime([[1]], data.t)\n",
    "            sr.partial_fit_transform(np.array([[stim]]), stream='stim')\n",
    "            prediction = sr.partial_fit_transform(qX, stream='dt_X')\n",
    "            sr.partial_fit_transform(data, stream='X')\n",
    "\n",
    "\n",
    "\n",
    "            stims.append(ArrayWithTime(stim, data.t))\n",
    "            predictions.append(prediction)\n",
    "            dt_X.append(qX)\n",
    "\n",
    "            if not isinstance(s_inputs_to_evaluate_on, str) and s_inputs_to_evaluate_on is not None and np.any(stim):\n",
    "                point_evals = []\n",
    "                for latent_location, stim, _ in s_inputs_to_evaluate_on:\n",
    "                    stim_reg_input = np.hstack([latent_location.flatten(), stim.flatten()])\n",
    "                    diff = sr.stim_reg.predict(stim_reg_input)\n",
    "                    point_evals.append(diff.flatten())\n",
    "                s_hat_evals.append(ArrayWithTime(point_evals,data.t))\n",
    "\n",
    "                if not s_eval:\n",
    "                    for inputs in s_inputs_to_evaluate_on:\n",
    "                        s_eval.append(S(*inputs))\n",
    "\n",
    "\n",
    "            if not sr.autoreg.parameter_fitting:\n",
    "                sr.autoreg.toggle_parameter_fitting(True)\n",
    "\n",
    "            pbar.update(round(data.t,2) - pbar.n)\n",
    "        sr.autoreg.toggle_parameter_fitting(False)\n",
    "\n",
    "\n",
    "    stims = ArrayWithTime.from_list(stims)\n",
    "    predictions = ArrayWithTime.from_list(predictions, drop_early_nans=True, squeeze_type='to_2d')\n",
    "    latents = ArrayWithTime.from_list(latents, squeeze_type='to_2d')\n",
    "    s_hat_evals = ArrayWithTime.from_list(s_hat_evals, drop_early_nans=True)\n",
    "    s_eval = np.squeeze(s_eval)\n",
    "    return (predictions, latents, stims), (s_hat_evals, s_eval, s_inputs_evaluated_on)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "stim_scale = 20\n",
    "(record_predictions, record_latents, record_stims), (record_s_hat_evals, record_s_eval, record_s_inputs_evaluated_on) = do_experiment([d.neural_data], stim_scale=stim_scale, s_inputs_to_evaluate_on='record')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "(eval_predictions, eval_latents, eval_stims), (eval_s_hat_evals, eval_s_eval, eval_s_inputs_evaluated_on) = do_experiment([d.neural_data], stim_scale=stim_scale, s_inputs_to_evaluate_on=record_s_inputs_evaluated_on)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "assert np.array_equal(record_predictions, eval_predictions, equal_nan=True)\n",
    "assert np.array_equal(record_latents, eval_latents, equal_nan=True)\n",
    "assert np.array_equal(record_stims, eval_stims, equal_nan=True)\n",
    "assert np.array_equal(record_s_eval, eval_s_eval, equal_nan=True)\n",
    "assert record_s_hat_evals.size == 0\n",
    "assert len(eval_s_inputs_evaluated_on) == 0\n",
    "predictions, latents, stims, s_hat_evals, s_eval = eval_predictions, eval_latents, eval_stims, eval_s_hat_evals, record_s_eval\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"evaluated on {s_hat_evals.shape[0]} timepoints; {s_hat_evals.shape[1]} test points with {s_hat_evals.shape[2]}-d returns\")\n",
    "error = np.linalg.norm(s_hat_evals - s_eval, axis=2).mean(axis=1)\n",
    "error = ArrayWithTime.from_transformed_data(error, s_hat_evals)\n",
    "pred_zero_error = np.linalg.norm(s_hat_evals*0 - s_eval, axis=2).mean(axis=1)[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "rig, ax = plt.subplots()\n",
    "ax.scatter(error.t, error, s=2)\n",
    "ax.axhline(pred_zero_error, linestyle='--', color='k')\n",
    "ax.set_ylim([0, ax.get_ylim()[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "plt.matshow(stims[stims.any(axis=1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "i = 0\n",
    "ax.plot(latents.t, latents[:,i])\n",
    "ax.plot(predictions.t, predictions[:,0,i])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
