{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.animation import ImageMagickWriter\n",
    "from IPython.display import display, clear_output\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "from adaptive_latents import datasets\n",
    "from adaptive_latents.plotting_functions import AnimationManager\n",
    "from sim_stim import make_srs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the simulated-stimulation comparison once on the Zong22 dataset (this is the slow step).\n",
    "rng = np.random.default_rng(0)\n",
    "d = datasets.Zong22Dataset()\n",
    "data = d.neural_data\n",
    "\n",
    "srs = make_srs(data, rng, comparison_preset='visualization', n_runs=1, show_tqdm=True)\n",
    "sr = srs['learning from stim'][0]\n",
    "\n",
    "# Zoom on the i-th intended stimulation: a window from l time units before it to r after it.\n",
    "i = 40\n",
    "l = 1\n",
    "r = 5.1\n",
    "ax_n = 0  # axis that receives the latent-space trajectory\n",
    "arrow_indices = [17, 50]  # sample positions within the window at which to draw direction arrows\n",
    "\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10,4), sharex=False, sharey=False, layout='constrained')\n",
    "\n",
    "# Faint background: the whole latent trajectory from t=30 onward.\n",
    "latents = sr.log['latents'].slice_by_time(slice(30,None))\n",
    "axs[ax_n].plot(latents[:, 0], latents[:, 1], alpha=.1, color='k')\n",
    "\n",
    "center_t = sr.log['stim_intended_samples'].t[i]\n",
    "window = slice(center_t-l, center_t+r)\n",
    "latents = sr.log['latents'].slice_by_time(window)\n",
    "line = axs[ax_n].plot(latents[:, 0], latents[:, 1])\n",
    "# Red dots: latents at each intended stim time in the window, shifted back by latents.dt.\n",
    "stim_s = sr.log['stim_intended_samples'].slice_by_time(window).t - latents.dt\n",
    "latents_s = latents.slice_by_time(stim_s).reshape((-1, latents.shape[1]))\n",
    "axs[ax_n].plot(latents_s[:, 0], latents_s[:, 1], '.', color='r')\n",
    "axs[ax_n].set(title=f'latent trajectory around stim {i}', xlabel='latent 0', ylabel='latent 1')\n",
    "\n",
    "# Direction arrows on the highlighted segment, on the same axis and in the same colour as it.\n",
    "# Indices past the end of a short window are skipped instead of raising IndexError.\n",
    "for arrow_index in arrow_indices:\n",
    "    if arrow_index + 1 >= latents.shape[0]:\n",
    "        continue\n",
    "    axs[ax_n].annotate('',\n",
    "                       xytext=(latents[arrow_index, 0], latents[arrow_index, 1]),\n",
    "                       xy=(latents[arrow_index+1, 0], latents[arrow_index+1, 1]),\n",
    "                       arrowprops=dict(arrowstyle=\"simple\", color=line[0].get_color()),\n",
    "                       size=11\n",
    "                       )\n",
    "\n",
    "# u comes from the stim designer's log for event i; its largest-|u| entries pick high_d columns.\n",
    "u = sr.stim_designer.log[i]['u']\n",
    "n_nonzero = np.count_nonzero(u)  # same count as np.linalg.norm(u, ord=0), but already an int\n",
    "idx = np.argsort(np.abs(u))[::-1]\n",
    "print('nonzero entries in u:', n_nonzero)\n",
    "\n",
    "high_d = sr.log['high_d_with_stim'].slice_by_time(window)\n",
    "axs[1].plot(high_d.t, high_d[:, idx[:n_nonzero]])\n",
    "axs[1].set(title=f'high_d_with_stim, {n_nonzero} largest-|u| columns', xlabel='time')\n",
    "# Vertical lines at the same shifted stim times used for the red dots.\n",
    "for stim_t in stim_s:\n",
    "    axs[1].axvline(stim_t, color='r')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_stim_window(ax, sr, i, l, r):\n",
    "    \"\"\"Plot the latent trajectory of run `sr` around its i-th intended stimulation.\n",
    "\n",
    "    Draws the trajectory from t=30 onward faintly, the window [t_i - l, t_i + r] in\n",
    "    colour, and red dots at the window's latents sampled at each intended stim time\n",
    "    shifted back by latents.dt. Uses only locals, so re-running this cell does not\n",
    "    overwrite notebook globals such as i, l, r or latents.\n",
    "\n",
    "    Returns the latents inside the window.\n",
    "    \"\"\"\n",
    "    background = sr.log['latents'].slice_by_time(slice(30, None))\n",
    "    ax.plot(background[:, 0], background[:, 1], alpha=.1, color='k')\n",
    "\n",
    "    center_t = sr.log['stim_intended_samples'].t[i]\n",
    "    window = slice(center_t - l, center_t + r)\n",
    "    window_latents = sr.log['latents'].slice_by_time(window)\n",
    "    ax.plot(window_latents[:, 0], window_latents[:, 1])\n",
    "\n",
    "    stim_times = sr.log['stim_intended_samples'].slice_by_time(window).t - window_latents.dt\n",
    "    stim_latents = window_latents.slice_by_time(stim_times).reshape((-1, window_latents.shape[1]))\n",
    "    ax.plot(stim_latents[:, 0], stim_latents[:, 1], '.', color='r')\n",
    "    ax.set(title=f'latent trajectory around stim {i}', xlabel='latent 0', ylabel='latent 1')\n",
    "    return window_latents\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "plot_stim_window(ax, sr, i=40, l=1, r=5.1);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render an animated GIF. Each frame shows the whole trajectory from t=30 onward faintly,\n",
    "# the last TRAIL_LENGTH time units in bold, and red dots at latents sampled at the\n",
    "# intended stim times inside that trail (shifted back by latents.dt).\n",
    "GIF_PATH = 'zong_stim.gif'\n",
    "T_START, T_END, N_FRAMES = 300, 311, 200+1\n",
    "TRAIL_LENGTH = 2\n",
    "FPS = 20\n",
    "\n",
    "latents = sr.log['latents'].slice_by_time(slice(30,None))\n",
    "stim_samples = sr.log['stim_intended_samples']\n",
    "\n",
    "fig, axs = plt.subplots(1, 1, figsize=(10,10), layout='constrained', squeeze=False)\n",
    "ax = axs[0,0]\n",
    "fig.patch.set_alpha(0.)  # transparent figure background\n",
    "\n",
    "# For WebM output instead, use matplotlib.animation.FFMpegWriter(codec='libvpx-vp9', fps=FPS, bitrate=-1)\n",
    "# with a .webm path.\n",
    "movie_writer = ImageMagickWriter(fps=FPS, bitrate=-1)\n",
    "\n",
    "# saving() calls setup() on entry and finish() in a finally block (and disables\n",
    "# savefig.bbox='tight'), so the writer is shut down even if drawing a frame raises.\n",
    "with movie_writer.saving(fig, GIF_PATH, dpi=100):\n",
    "    for t in tqdm(np.linspace(T_START, T_END, N_FRAMES)):\n",
    "        ax.cla()\n",
    "        ax.patch.set_alpha(0.)  # cla() rebuilds the axes patch, so re-apply transparency each frame\n",
    "\n",
    "        ax.plot(latents[:, 0], latents[:, 1], alpha=.5, color='k', lw=1)\n",
    "\n",
    "        sl = slice(t-TRAIL_LENGTH, t)\n",
    "        sub_latents = latents.slice_by_time(sl)\n",
    "        ax.plot(sub_latents[:, 0], sub_latents[:, 1], color='k', lw=6)\n",
    "\n",
    "        stim_s = stim_samples.slice_by_time(sl).t - latents.dt\n",
    "        latents_s = latents.slice_by_time(stim_s).reshape((-1, latents.shape[1]))\n",
    "        ax.scatter(latents_s[:, 0], latents_s[:, 1], s=75, color='r', zorder=1000)  # dots drawn above the lines\n",
    "\n",
    "        movie_writer.grab_frame()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
