{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: third-party first, then the project package (ArrayWithTime imported once).\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from adaptive_latents import ArrayWithTime, StimRegressor\n",
    "from adaptive_latents.input_sources.kalman_filter import StreamingKalmanFilter\n",
    "from adaptive_latents.input_sources.lds_simulation import LDS\n",
    "from adaptive_latents.regressions import BaseMultiKernelRegressor\n",
    "\n",
    "# Unseeded default generator; the parameter-sweep cell rebinds `rng` with a fixed seed.\n",
    "rng = np.random.default_rng()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation settings shared by the sweep and the single-run cells below.\n",
    "(\n",
    "    n_rotations,         # first positional argument of LDS.run_nest_dynamical_system\n",
    "    noise_variance,      # passed as `noise=` to the simulator\n",
    "    stims_per_rotation,  # passed as `stims_per_rotation=`\n",
    "    stim_magnitude,      # passed as `stim_magnitude=`\n",
    ") = (100, 0.05, 2, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the first and third kernel length scales of the stim regressor as powers of ten.\n",
    "# The [1:2] slices currently keep a single value from each grid.\n",
    "aa = np.linspace(-2, 5, 7)[1:2]\n",
    "bb = np.linspace(-2, 5, 8)[1:2]\n",
    "cc = [None]  # placeholder third sweep axis; `c` is not passed to the regressor\n",
    "n_repeats = 3\n",
    "\n",
    "rng = np.random.default_rng(5)\n",
    "\n",
    "# Nested results, indexed [repeat][a][b][c]:\n",
    "#   traces -> column 2 of the logged prediction error over the whole run\n",
    "#   mse_s  -> (pre-flip, post-flip) mean squared error of that column at stim times\n",
    "mse_s = []\n",
    "traces = []\n",
    "n_runs = len(aa) * len(bb) * len(cc) * n_repeats\n",
    "\n",
    "with tqdm(total=n_runs) as pbar:\n",
    "    for _ in range(n_repeats):\n",
    "        _, Y, stim = LDS.run_nest_dynamical_system(n_rotations, stims_per_rotation=stims_per_rotation, stim_magnitude=stim_magnitude, rng=rng, u_function='curvy flips', noise=noise_variance)\n",
    "        repeat_mse, repeat_traces = [], []\n",
    "        mse_s.append(repeat_mse)\n",
    "        traces.append(repeat_traces)\n",
    "        for a in aa:\n",
    "            a_mse, a_traces = [], []\n",
    "            repeat_mse.append(a_mse)\n",
    "            repeat_traces.append(a_traces)\n",
    "            for b in bb:\n",
    "                b_mse, b_traces = [], []\n",
    "                a_mse.append(b_mse)\n",
    "                a_traces.append(b_traces)\n",
    "                for c in cc:\n",
    "                    sr = StimRegressor(\n",
    "                        autoreg=StreamingKalmanFilter(),\n",
    "                        stim_reg=BaseMultiKernelRegressor(length_scales=[10**a, 1, 10**b], maxlen=500),\n",
    "                        log_level=2,\n",
    "                        check_dt=True,\n",
    "                    )\n",
    "                    sr.offline_run_on([(Y, 'X'), (stim, 'stim')], convinient_return=False, show_tqdm=False)\n",
    "                    sr.finalize_log(stim)\n",
    "\n",
    "                    # Split point: timestamp of the middle stim sample, treated as the flip time.\n",
    "                    flip_time = stim.t[stim.shape[0] // 2]\n",
    "                    logged_error = sr.log['pred_error']\n",
    "                    pred_error = logged_error[:, 2]\n",
    "                    stim_times = stim.t[np.squeeze(stim) == 1]\n",
    "                    stim_pred_error = logged_error.slice_by_time(stim_times)[:, 2]\n",
    "\n",
    "                    pre_flip_error = np.nanmean(stim_pred_error.slice_by_time(slice(None, flip_time)) ** 2)\n",
    "                    post_flip_error = np.nanmean(stim_pred_error.slice_by_time(slice(flip_time, None)) ** 2)\n",
    "\n",
    "                    b_traces.append(pred_error)\n",
    "                    b_mse.append((pre_flip_error, post_flip_error))\n",
    "                    pbar.update(1)\n",
    "\n",
    "# Shape (n_repeats, len(aa), len(bb), len(cc), 2); last axis is (pre-flip, post-flip).\n",
    "all_mse_s = np.array(mse_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Actual (non-exponent) length scales covered by the sweep.\n",
    "np.power(10, aa), np.power(10, bb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the single (repeat, a, b, c) run with the lowest post-flip MSE (last axis index 1).\n",
    "post_flip_mse = all_mse_s[..., 1]\n",
    "idx = np.unravel_index(np.nanargmin(post_flip_mse), post_flip_mse.shape)\n",
    "print(idx)\n",
    "print([values[i] for values, i in zip([aa, bb, cc], idx[1:])])\n",
    "\n",
    "fig, axs = plt.subplots(nrows=2, figsize=(10, 5), layout='constrained')\n",
    "pred_error = traces[idx[0]][idx[1]][idx[2]][idx[3]]\n",
    "axs[0].plot(pred_error)\n",
    "axs[0].set_title('prediction error of the best run')\n",
    "\n",
    "# `flip_time` is left over from the last iteration of the sweep cell.\n",
    "axs[1].plot(pred_error.slice_by_time(slice(flip_time, None)), label='post flip')\n",
    "axs[1].plot(pred_error.slice_by_time(slice(None, flip_time)), label='pre flip')\n",
    "axs[1].legend()\n",
    "\n",
    "# Prediction error of the same (a, b, c) configuration in every repeat.\n",
    "# A plain comprehension is used instead of np.array(traces, dtype=object): numpy only keeps\n",
    "# each trace as one object when traces differ in length; equal lengths become scalar cells.\n",
    "fig, ax = plt.subplots()\n",
    "for trace in [repeat[idx[1]][idx[2]][idx[3]] for repeat in traces]:\n",
    "    ax.plot(trace.t, trace)\n",
    "ax.set_xlabel('time')\n",
    "ax.set_ylabel('prediction error')\n",
    "ax.set_title('best configuration, all repeats');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from scipy.signal import savgol_filter\n",
    "\n",
    "# Time windows, in the time units of the traces (the x-axis is labelled in rotations).\n",
    "threshold_window = (15, 45)  # max of the smoothed MSE here defines the 'trained' threshold\n",
    "initial_window = (0, 30)     # searched for the first crossing after the start\n",
    "switch_t = 50                # assumed switch time; recovery is measured from here\n",
    "recovery_window = (switch_t, 80)\n",
    "\n",
    "\n",
    "def first_time_below(series, threshold):\n",
    "    \"\"\"Return the first timestamp at which `series` is below `threshold`, or None if it never is.\"\"\"\n",
    "    below = np.nonzero(series < threshold)[0]\n",
    "    return series.t[below[0]] if below.size else None\n",
    "\n",
    "\n",
    "# One trace per run (with the single-configuration sweep above: one per repeat).\n",
    "# Flattened with a comprehension because np.squeeze(np.array(traces, dtype=object)) only\n",
    "# yields trace objects when the traces differ in length, and is 0-d for a single repeat.\n",
    "run_traces = [trace for repeat in traces for a_row in repeat for b_row in a_row for trace in b_row]\n",
    "lens = [len(x) for x in run_traces]\n",
    "t = run_traces[int(np.argmax(lens))].t\n",
    "maxlen = max(lens)\n",
    "# Left-pad with NaN so every trace is aligned on its final sample.\n",
    "trs = np.array([np.hstack([np.full(maxlen - len(x), np.nan), x]) for x in run_traces])\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "mean_mse = np.nanmean(trs**2, axis=0)\n",
    "smoothed_mean_mse = ArrayWithTime(savgol_filter(mean_mse, 40, 1), t)\n",
    "ax.plot(t, mean_mse, label='raw')\n",
    "ax.plot(t, smoothed_mean_mse, label='smoothed')\n",
    "ax.set_xlabel('time (number of rotations)')\n",
    "ax.set_ylabel('averaged MSE')\n",
    "ax.legend()\n",
    "\n",
    "threshold = float(smoothed_mean_mse.slice_by_time(slice(*threshold_window)).max())\n",
    "ax.axhline(threshold, color='k', linestyle='--')\n",
    "\n",
    "initial_t = first_time_below(smoothed_mean_mse.slice_by_time(slice(*initial_window)), threshold)\n",
    "if initial_t is None:\n",
    "    print('smoothed MSE never dropped below the threshold during initial training')\n",
    "else:\n",
    "    ax.axvline(initial_t, color='k', alpha=.5)\n",
    "    print(f'initial training in {initial_t :.1f} samples')\n",
    "\n",
    "recovered_t = first_time_below(smoothed_mean_mse.slice_by_time(slice(*recovery_window)), threshold)\n",
    "if recovered_t is None:\n",
    "    print('smoothed MSE never recovered below the threshold after the switch')\n",
    "else:\n",
    "    ax.axvline(recovered_t, color='k', alpha=.5)\n",
    "    print(f'recovered in {(recovered_t - switch_t) :.1f} samples')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over repeats of the (pre-flip, post-flip) MSE on the (a, b) grid, for the first c value.\n",
    "mean_mse_grid = all_mse_s[:, :, :, 0, :].mean(axis=0)\n",
    "log_mse_grid = np.log(mean_mse_grid)\n",
    "# Shared, NaN-aware colour limits so both panels are comparable even if some cell is NaN\n",
    "# (e.g. np.nanmean over an empty slice).\n",
    "vmin, vmax = np.nanmin(log_mse_grid), np.nanmax(log_mse_grid)\n",
    "\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(12, 5), layout='constrained')\n",
    "titles = ['first half error (pre switch)', 'second half error (post switch)']\n",
    "for phase, (panel, title) in enumerate(zip(axs, titles)):\n",
    "    cm = panel.pcolormesh(bb, aa, log_mse_grid[..., phase], shading='nearest', vmin=vmin, vmax=vmax)\n",
    "    # aa and bb are base-10 exponents: the regressor is built with 10**a and 10**b.\n",
    "    panel.set_ylabel('log10(length_scale)')\n",
    "    panel.set_xlabel('log10(time_scale)')\n",
    "    panel.set_title(title)\n",
    "    panel.set_xticks(bb)\n",
    "    panel.set_yticks(aa)\n",
    "\n",
    "fig.colorbar(mappable=cm, ax=axs, label='log(MSE)');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean post-flip MSE (last axis index 1) on the (a, b) grid, for the first c value.\n",
    "# Indexed explicitly rather than with np.squeeze: squeezing a 1x1 grid gives a 0-d array,\n",
    "# which matshow cannot draw and which would make `idx` an empty tuple.\n",
    "to_min = all_mse_s.mean(axis=0)[:, :, 0, 1]\n",
    "fig, ax = plt.subplots()\n",
    "ax.matshow(to_min)\n",
    "ax.set_xlabel('b index')\n",
    "ax.set_ylabel('a index')\n",
    "\n",
    "idx = np.unravel_index(np.nanargmin(to_min), to_min.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each swept parameter, plot the best (minimum over the other parameters) mean\n",
    "# pre-flip MSE (last axis index 0) at each of its values.\n",
    "to_dice = all_mse_s[..., 0].mean(axis=0)\n",
    "\n",
    "fig, axs = plt.subplots(ncols=3, figsize=(15, 5), layout='constrained')\n",
    "for i, (name, values) in enumerate(zip(['a', 'b', 'c'], [aa, bb, cc])):\n",
    "    other_axes = tuple(axis for axis in range(to_dice.ndim) if axis != i)\n",
    "    bests = np.nanmin(to_dice, axis=other_axes)\n",
    "    # Plot against positions: `cc` holds None, which cannot be placed on a numeric axis.\n",
    "    positions = np.arange(len(values))\n",
    "    tick_labels = [f'{v:.3g}' if isinstance(v, (int, float, np.number)) else str(v) for v in values]\n",
    "    # Markers keep single-value sweeps visible (a one-point line draws nothing).\n",
    "    axs[i].plot(positions, bests, marker='o')\n",
    "    axs[i].set_xticks(positions, tick_labels)\n",
    "    axs[i].set_xlabel(name)\n",
    "    axs[i].set_ylabel('best mean pre-flip MSE')\n",
    "    print(f'best {name}: {values[int(np.nanargmin(bests))]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single run with hand-picked kernel length scales.\n",
    "# NOTE(review): this uses u_function='curvy', not the 'curvy flips' used in the sweep, so the\n",
    "# midpoint split below may not correspond to an actual flip; confirm this is intended.\n",
    "rng = np.random.default_rng(0)  # seeded so the run is reproducible\n",
    "_, Y, stim = LDS.run_nest_dynamical_system(n_rotations, stims_per_rotation=stims_per_rotation, stim_magnitude=stim_magnitude, rng=rng, u_function='curvy', noise=noise_variance)\n",
    "\n",
    "# NOTE(review): 10**15 is far outside the swept exponent range for the third length scale; confirm.\n",
    "sr = StimRegressor(\n",
    "    autoreg=StreamingKalmanFilter(),\n",
    "    stim_reg=BaseMultiKernelRegressor(length_scales=[10**-.43, 1, 10**15], maxlen=500),\n",
    "    log_level=2,\n",
    "    check_dt=True,\n",
    ")\n",
    "\n",
    "sr.offline_run_on([(Y, 'X'), (stim, 'stim')], convinient_return=False, show_tqdm=False)\n",
    "sr.finalize_log(stim)\n",
    "\n",
    "# Same error summary as the sweep: column 2 of the logged prediction error,\n",
    "# split at the timestamp of the middle stim sample.\n",
    "flip_time = stim.t[stim.shape[0] // 2]\n",
    "pred_error = sr.log['pred_error'][:, 2]\n",
    "stim_pred_error = sr.log['pred_error'].slice_by_time(stim.t[np.squeeze(stim) == 1])[:, 2]\n",
    "\n",
    "pre_flip_error = np.nanmean(stim_pred_error.slice_by_time(slice(None, flip_time))**2)\n",
    "post_flip_error = np.nanmean(stim_pred_error.slice_by_time(slice(flip_time, None))**2)\n",
    "print(f'pre-flip MSE: {pre_flip_error:.4g}, post-flip MSE: {post_flip_error:.4g}')\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(pred_error)\n",
    "ax.set_xlabel('sample index')\n",
    "ax.set_ylabel('prediction error')\n",
    "ax.set_ylim([-5, 5]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check: shape of the neural data exposed by Odoherty21Dataset.\n",
    "from adaptive_latents import datasets\n",
    "\n",
    "odoherty_dataset = datasets.Odoherty21Dataset()\n",
    "print(odoherty_dataset.neural_data.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
