{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import adaptive_latents as al\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import NMF\n",
    "from adaptive_latents.prediction_regression_run import pred_reg_run, defaults_per_dataset\n",
    "import copy\n",
    "\n",
    "rng = np.random.default_rng()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset; both keyword values are handed straight to the constructor.\n",
    "d = al.datasets.Odoherty21Dataset(\n",
    "    bin_width=0.06,\n",
    "    drop_third_coord=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Work on a copy so the stimulation added later in this cell never touches d.neural_data.\n",
    "neural_data = d.neural_data.copy()\n",
    "\n",
    "# Principal directions: rows of vh from the SVD of the column-mean-centred data.\n",
    "centered_data = neural_data - neural_data.mean(axis=0)\n",
    "u, s, vh = np.linalg.svd(centered_data, full_matrices=False)\n",
    "\n",
    "# Pin random_state so the fitted components (used below as the 'nmf1'/'nmf2'\n",
    "# stimulus directions) do not depend on unseeded random state between runs.\n",
    "nmf = NMF(n_components=2, random_state=0)\n",
    "nmf.fit(neural_data)\n",
    "\n",
    "def make_pc_ok(pc):\n",
    "    return pc / np.abs(pc).max() * 5\n",
    "\n",
    "\n",
    "# Every stimulus shares one time course: 6 samples of exponential decay.\n",
    "response_time = np.arange(6)\n",
    "response_decay = np.exp(-response_time/1.5)\n",
    "\n",
    "# One direction per stimulus. Insertion order is significant: the dict is\n",
    "# enumerated further down to place and label the stimuli.\n",
    "response_directions = {}\n",
    "\n",
    "# scaled column-wise summary statistics of the (not yet stimulated) data\n",
    "response_directions['99th quantile'] = np.quantile(neural_data, q=.99, axis=0) * .5\n",
    "response_directions['mean'] = np.mean(neural_data, axis=0) * 5\n",
    "response_directions['std'] = np.std(neural_data, axis=0) * 2\n",
    "\n",
    "# rescaled rows of vh; note the number in a label is not a consistent index\n",
    "# ('pc1' is vh[0], whereas 'pc15' is vh[15])\n",
    "for label, row in (('pc1', 0), ('pc2', 1), ('pc15', 15), ('pc-1', -1)):\n",
    "    response_directions[label] = make_pc_ok(vh[row])\n",
    "\n",
    "# the two fitted NMF components, rescaled the same way\n",
    "response_directions['nmf1'] = make_pc_ok(nmf.components_[0])\n",
    "response_directions['nmf2'] = make_pc_ok(nmf.components_[1])\n",
    "\n",
    "# Batch onsets are given as times and converted to sample indices;\n",
    "# stim_batch_magnitudes holds the matching multiplier for each batch.\n",
    "stim_batch_times = [80, 120, 160]\n",
    "stim_batch_magnitudes = [2, 1, -1]\n",
    "stim_batch_samples = [neural_data.time_to_sample(t) for t in stim_batch_times]\n",
    "\n",
    "# stim_times[b][j] is the time stamp at which stimulus j of batch b begins.\n",
    "stim_times = []\n",
    "stim_spacing = len(response_time) * 2  # samples between consecutive stimulus onsets\n",
    "\n",
    "for batch_start, magnitude in zip(stim_batch_samples, stim_batch_magnitudes):\n",
    "    batch_onsets = []\n",
    "    stim_times.append(batch_onsets)\n",
    "    for i, response_direction in enumerate(response_directions.values()):\n",
    "        start = batch_start + i * stim_spacing\n",
    "        # outer product: decay profile (rows) times stimulus direction (columns)\n",
    "        response = response_decay[:, None] @ response_direction[None, :]\n",
    "\n",
    "        # add the scaled response in place over the len(response_time) rows starting at `start`\n",
    "        neural_data[start + response_time, :] += response * magnitude\n",
    "        batch_onsets.append(neural_data.t[start])\n",
    "\n",
    "\n",
    "# Show the stretch of the stimulated matrix around one batch, with one y tick per stimulus onset.\n",
    "fig, axs = plt.subplots()\n",
    "stim_batch = 0\n",
    "batch_time = stim_batch_times[stim_batch]\n",
    "cutout = neural_data.slice_by_time(batch_time-1, batch_time+6)\n",
    "\n",
    "axs.matshow(\n",
    "    cutout,\n",
    "    extent=[0, cutout.shape[1], cutout.t[-1], cutout.t[0]],  # y axis runs over the cutout's time stamps\n",
    "    vmin=0,\n",
    "    vmax=d.neural_data.max()*1.4,  # colour limit derived from the unmodified data\n",
    "    aspect='auto',\n",
    "    origin='upper',\n",
    ")\n",
    "axs.set(\n",
    "    xlabel=\"neuron #\",\n",
    "    ylabel=\"time\",\n",
    "    title=\"stimuli visualized in firing rate matrix\",\n",
    ")\n",
    "\n",
    "# label each onset of the selected batch with the name of its stimulus direction\n",
    "axs.set_yticks(stim_times[stim_batch])\n",
    "axs.set_yticklabels(list(response_directions.keys()));\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the stimulus directions into one matrix, one row per direction in dict order.\n",
    "direction_rows = list(response_directions.values())\n",
    "rarr = np.vstack(direction_rows)\n",
    "plt.matshow(rarr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two runs with identical settings; only the first positional argument differs.\n",
    "# d.behavioral_data[:,:0] selects zero columns of the behavioral data.\n",
    "shared_kwargs = dict(dim_red_method='sjpca', **defaults_per_dataset['odoherty21'])\n",
    "\n",
    "# run: the stimulated copy as first argument\n",
    "run = pred_reg_run(neural_data, d.behavioral_data[:,:0], d.neural_data, **shared_kwargs)\n",
    "# run2: the unmodified d.neural_data as first argument\n",
    "run2 = pred_reg_run(d.neural_data, d.behavioral_data[:,:0], d.neural_data, **shared_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "fig, axs = plt.subplots(nrows=3, layout='tight')\n",
    "\n",
    "latents: al.ArrayWithTime = run.dim_reduced_data\n",
    "\n",
    "# all stimulus onset times, batch after batch, in one flat list\n",
    "flat_stim_times = [onset for batch in stim_times for onset in batch]\n",
    "# the label sequence is repeated three times to cover every entry of flat_stim_times\n",
    "tick_labels = list(response_directions.keys()) * 3\n",
    "\n",
    "# Every panel plots the same traces; each one is zoomed onto a different batch.\n",
    "for i, ax in enumerate(axs):\n",
    "    ax.plot(latents.t, latents[:,:])\n",
    "    ax.set_xticks(flat_stim_times)\n",
    "    for t in flat_stim_times:\n",
    "        ax.axvline(t, color='k', linestyle='--', alpha=.5)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "    ax.set_xlim(np.array([-1, 7]) + stim_batch_times[i])\n",
    "    ax.set_title(f\"magnitude={stim_batch_magnitudes[i]}\")\n",
    "\n",
    "# Distance between each unit-normalised stimulus direction and the column space of Q,\n",
    "# scaled by 180/pi (presumably radians -> degrees; confirm against column_space_distance).\n",
    "subspace_basis = run.pipeline.steps[-3].Q\n",
    "e = {}\n",
    "for k, v in response_directions.items():\n",
    "    unit_column = v[:,None]/np.linalg.norm(v)\n",
    "    e[k] = al.utils.column_space_distance(unit_column, subspace_basis) * 180/np.pi\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The injected response lasts len(response_decay) samples, so fit exactly that many.\n",
    "# (A hard-coded 7 does not match the 6-sample decay profile.)\n",
    "window_width = len(response_decay)\n",
    "\n",
    "ws = []\n",
    "\n",
    "for i in range(3):\n",
    "    # onset sample of the second stimulus ('mean') in batch i\n",
    "    sss = latents.time_to_sample(stim_times[i][1])\n",
    "    # change of the latents relative to the sample just before onset;\n",
    "    # shape (window_width, n_latents)\n",
    "    derived_response = np.asarray(latents[sss:sss+window_width] - latents[sss-1])\n",
    "    # Solve response_decay[:, None] @ w ~= derived_response. The rows of both sides\n",
    "    # index time, so w has shape (1, n_latents): one weight per latent dimension.\n",
    "    # Passing the transpose instead would put latent dimensions on the row axis and\n",
    "    # only run at all when n_latents happens to equal len(response_decay).\n",
    "    w = np.linalg.lstsq(response_decay[:,None], derived_response, rcond=None)[0]\n",
    "    ws.append(w)\n",
    "\n",
    "# Draw on a fresh figure so the lines do not land in whichever axes is current.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(np.squeeze(ws).T)\n",
    "ax.set_xlabel('latent dimension')\n",
    "ax.set_ylabel('fitted response weight')\n",
    "ax.legend([f'magnitude={m}' for m in stim_batch_magnitudes]);\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One bar per stimulus direction; the bar height is the value stored for it in e.\n",
    "bar_heights = np.array(list(e.values()))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(e.keys(), bar_heights)\n",
    "ax.set_xlabel('method')\n",
    "ax.set_ylabel('angle from proSVD space')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
