{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax\n",
    "from matplotlib import pyplot as plt\n",
    "from adaptive_latents.input_sources.autoregressor import AR_K\n",
    "from tqdm.notebook import tqdm\n",
    "import pandas as pd\n",
    "from typing import Literal\n",
    "\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "import adaptive_latents as al\n",
    "rng = np.random.default_rng()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "d = al.datasets.Naumann24uDataset(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = d.get_rectangular_block(70)\n",
    "\n",
    "targets = np.sort(np.unique(d.opto_stimulations.target_neuron))\n",
    "opto_stims = np.zeros((a.shape[0], targets.size))\n",
    "for idx, row in d.opto_stimulations.iterrows():\n",
    "    assert (a.t == row.time).sum() == 1\n",
    "    neuron_index = np.nonzero((targets == row.target_neuron))[0][0]\n",
    "    opto_stims[(a.t == row.time), neuron_index] = 1\n",
    "\n",
    "\n",
    "angles = np.sort(np.unique(d.visual_stimuli.l_angle))\n",
    "visual_stimuli = np.zeros((a.shape[0], angles.size))\n",
    "for idx, row in d.visual_stimuli.iterrows():\n",
    "    assert (a.t == row.time).sum() == 1\n",
    "    angle_index = np.nonzero((angles == row.l_angle))[0][0]\n",
    "    visual_stimuli[(a.t == row.time), angle_index] = 1\n",
    "\n",
    "stims = al.ArrayWithTime(np.hstack([opto_stims, visual_stimuli]), a.t)\n",
    "\n",
    "ar = AR_K(k=20, rank_limit=None)\n",
    "ar.fit(a, stims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(ar):\n",
    "    window_time = 13 # seconds\n",
    "    n_steps = int(window_time//a.dt) + 1\n",
    "    errors = []\n",
    "\n",
    "    for targets_index in range(12):\n",
    "        for idx, start_t in enumerate(d.opto_stimulations[d.opto_stimulations.target_neuron == targets[targets_index]].time):\n",
    "            trial = a.slice_by_time(start_t, start_t + window_time)\n",
    "\n",
    "            pre_trial = a.slice_by_time(start_t - ar.k * a.dt - 1, start_t)\n",
    "            new_stims = np.zeros((n_steps+ar.k, stims.shape[1]))\n",
    "            new_stims[ar.k, targets_index] = 1\n",
    "            starting_state = pre_trial[-ar.k:]\n",
    "            prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)\n",
    "\n",
    "            errors.append(trial - prediction)\n",
    "    return ((np.array(errors)**2).mean())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ks = np.arange(1, 30)\n",
    "k_results = []\n",
    "\n",
    "for k in tqdm(ks):\n",
    "    ar = AR_K(k=k, rank_limit=None)\n",
    "    ar.fit(a, stims)\n",
    "    k_results.append(evaluate(ar))\n",
    "\n",
    "\n",
    "best_k = ks[np.argmin(k_results)]\n",
    "ar = AR_K(k=best_k, rank_limit=None)\n",
    "ar.fit(a, stims)\n",
    "full_rank_baseline = evaluate(ar)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig,ax = plt.subplots()\n",
    "\n",
    "\n",
    "ax.plot(ks, k_results)\n",
    "ax.axhline(y=full_rank_baseline, color='k')\n",
    "ax.set_xlabel('number of autoregression steps')\n",
    "ax.set_ylabel('mse')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "rank_limits = np.arange(1, 30)\n",
    "rl_results = []\n",
    "\n",
    "for rl in tqdm(rank_limits):\n",
    "    ar = AR_K(k=best_k, rank_limit=rl)\n",
    "    ar.fit(a, stims)\n",
    "    rl_results.append(evaluate(ar))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig,ax = plt.subplots()\n",
    "ax.plot(rank_limits, rl_results)\n",
    "\n",
    "ax.axhline(y=full_rank_baseline, color='k')\n",
    "ax.set_xlabel('rank constraint')\n",
    "ax.set_ylabel('mse')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "ar = AR_K(k=30, rank_limit=None)\n",
    "ar.fit(a, stims)\n",
    "\n",
    "fig, axs = plt.subplots(nrows=5, ncols=2, figsize=(10, 10), tight_layout=True, sharey=True)\n",
    "window_time = 13 # seconds\n",
    "n_steps = int(window_time//a.dt) + 1\n",
    "targets_index = 7\n",
    "errors = []\n",
    "truths = []\n",
    "\n",
    "for idx, start_t in enumerate(d.opto_stimulations[d.opto_stimulations.target_neuron == targets[targets_index]].time):\n",
    "    trial = a.slice_by_time(start_t, start_t + window_time) \n",
    "\n",
    "    pre_trial = a.slice_by_time(start_t - ar.k * a.dt - 1, start_t)\n",
    "    new_stims = np.zeros((n_steps+ar.k, stims.shape[1]))\n",
    "    new_stims[ar.k, targets_index] = 1\n",
    "    starting_state = pre_trial[-ar.k:]\n",
    "    prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)\n",
    "\n",
    "\n",
    "    axs[idx,0].plot(trial)\n",
    "    axs[idx,1].plot(prediction)\n",
    "\n",
    "    truths.append(trial)\n",
    "    errors.append(prediction - trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "window_time = 13\n",
    "\n",
    "_groups = Literal['opto', 'vis', 'rand']\n",
    "def get_responses(n=None, group:_groups='opto', neuron=0, estimated=False):\n",
    "    if group == 'opto':\n",
    "        times = d.opto_stimulations[d.opto_stimulations.target_neuron == targets[n]].time\n",
    "    elif group == 'vis':\n",
    "        times = d.visual_stimuli[d.visual_stimuli.l_angle == angles[n]].time\n",
    "    elif group == 'rand':\n",
    "        times = rng.uniform(low=a.t[0] + ar.k * a.dt, high=a.t[-1] - window_time, size=5)\n",
    "    else:\n",
    "        raise ValueError(f'Unknown group {group}')\n",
    "\n",
    "    n_steps = np.floor(window_time/a.dt).astype(int)\n",
    "    if estimated:\n",
    "        ret = []\n",
    "        for t in times:\n",
    "            try:\n",
    "                idx = int(np.nonzero(a.t == t)[0][0])\n",
    "            except IndexError:\n",
    "                idx = np.argmin(np.abs(a.t - t))\n",
    "            pre_trial = a.slice(idx-ar.k, idx)\n",
    "            new_stims = stims.slice(idx-ar.k, int(idx + window_time//a.dt))\n",
    "            starting_state = pre_trial[-ar.k:]\n",
    "            prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)\n",
    "            ret.append(prediction[:,neuron])\n",
    "        ret = np.column_stack(ret)\n",
    "    else:\n",
    "        ret = []\n",
    "        for t in times:\n",
    "            ret.append(a.slice_by_time(t, t + window_time)[:n_steps,neuron])\n",
    "        ret = np.column_stack(ret)\n",
    "        \n",
    "    return ret\n",
    "\n",
    "def modulation_statistic(x):\n",
    "    differences = []\n",
    "    for split in range(6,20):\n",
    "        differences.append(x[split:].mean() - x[:split].mean())\n",
    "    return max(differences)\n",
    "\n",
    "\n",
    "def get_modulation(group:_groups='opto', neuron=0, estimated=False):\n",
    "    tgts = {'opto':targets, 'vis':angles}.get(group)\n",
    "    def f(g):\n",
    "        return max([modulation_statistic(get_responses(n=i, group=g, neuron=neuron, estimated=estimated)) for i in range(len(tgts))])\n",
    "    stat = f(group)\n",
    "    null_samples = []\n",
    "    for _ in range(200):\n",
    "        null_samples.append(f('rand'))\n",
    "    return (np.array(null_samples) < stat).mean()\n",
    "\n",
    "get_modulation(group='opto', neuron=31, estimated=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "modulations = []\n",
    "for neuron in tqdm(range(a.shape[1])):\n",
    "    modulations.append([])\n",
    "    for group in ['opto', 'vis']:\n",
    "        for estimated in [True, False]:\n",
    "            m = get_modulation(group=group, neuron=neuron, estimated=estimated)\n",
    "            modulations[-1].append(m)\n",
    "\n",
    "modulations = np.array(modulations)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({\n",
    "    'opto_est': modulations[:, 0],\n",
    "    'opto_real': modulations[:, 1],\n",
    "    'vis_est': modulations[:, 2],\n",
    "    'vis_real': modulations[:, 3],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "plt.hist(df.opto_est, bins=20);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10,5))\n",
    "axs[0].scatter(df.opto_real, df.vis_real)\n",
    "axs[0].set_xlabel('modulation to opto stimuli')\n",
    "axs[0].set_ylabel('modulation to visual stimuli')\n",
    "axs[0].set_title(\"'real' modulations\")\n",
    "\n",
    "axs[1].scatter(df.opto_est, df.vis_est)\n",
    "axs[1].set_xlabel('modulation to opto stimuli')\n",
    "axs[1].set_title('modulations from simulations')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "fig, ax = plt.subplots(figsize=(5,5))\n",
    "\n",
    "ax.scatter(df.opto_real, df.vis_real)\n",
    "ax.scatter(df.opto_est, df.vis_est)\n",
    "ax.plot(np.vstack([df.opto_real,df.opto_est]), np.vstack([df.vis_real,df.vis_est]), color='k', alpha=.1);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15,3))\n",
    "ax.plot(a.t, a[:,np.argsort(df.opto_real)[:5]])\n",
    "for t in d.opto_stimulations.time:\n",
    "    ax.axvline(t, c='k')\n",
    "\n",
    "for t in d.visual_stimuli.time:\n",
    "    ax.axvline(t, c='gray')\n",
    "\n",
    "ax.set_xlim([890, 1510])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_errors(group:_groups='opto'):\n",
    "    l = []\n",
    "    tgts = {'opto':targets, 'vis':angles, 'rand': range(3*max(len(targets), len(angles)))}.get(group)\n",
    "\n",
    "    for neuron in range(a.shape[1]):\n",
    "        l.append([])\n",
    "        for n in range(len(tgts)):\n",
    "            estimate = get_responses(n, group=group, neuron=neuron, estimated=True) \n",
    "            observed = get_responses(n, group=group, neuron=neuron, estimated=False)\n",
    "            denominator = max(((observed.mean() - observed)**2).mean(), 0.01)\n",
    "            l[-1].append(((estimate - observed)**2).mean()/denominator)\n",
    "        \n",
    "    l = np.array(l)\n",
    "    l = np.log(l.mean(axis=1))\n",
    "    return l\n",
    "\n",
    "df['opto_errors'] = get_errors(group='opto')\n",
    "df['vis_errors']= get_errors(group='vis')\n",
    "df['rand_errors'] = get_errors(group='rand')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(df.rand_errors, df.opto_errors)\n",
    "plt.xlabel('error over random times')\n",
    "plt.ylabel('error for stim trials')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(df.opto_real, df.vis_real)\n",
    "plt.xlabel('modulation for optogenetic stimulations')\n",
    "plt.ylabel('prediction error for random times')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
