{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f101b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable\n",
    "import pandas as pd\n",
    "from scipy.io import loadmat\n",
    "from scipy.stats import binned_statistic\n",
    "\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import model, utils, inference\n",
    "\n",
    "from importlib import reload\n",
    "# Reload every local project module (not just model) so edits to their source\n",
    "# files take effect without restarting the kernel.\n",
    "for _module in (utils, model, inference):\n",
    "    reload(_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ac6c35c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(\n",
    "    path: str = 'SpTimesRGC.mat',\n",
    "    n_time_bins: int = 20 * 60 * 120,  # 20 min at ~120 Hz (nominally 119.9820 Hz)\n",
    "    n_neurons: int = 27,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"Load per-neuron spike times from a .mat file and bin them into spike counts.\n",
    "\n",
    "    Args:\n",
    "        path: .mat file holding a `SpTimes` variable with one entry per neuron;\n",
    "            column 0 of each entry holds that neuron's spike times.\n",
    "        n_time_bins: number of unit-width bins, with edges 0, 1, ..., n_time_bins.\n",
    "        n_neurons: number of `SpTimes` entries to load, starting from index 0.\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape (n_time_bins, n_neurons) with the spike count per bin.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file holds fewer than n_neurons entries.\n",
    "\n",
    "    NOTE(review): bins have unit width, so spike times are presumably in\n",
    "    stimulus-frame units rather than seconds -- confirm against the data source.\n",
    "    \"\"\"\n",
    "    sp_times = loadmat(path, squeeze_me=False, struct_as_record=False)['SpTimes'][0]\n",
    "    if n_neurons > len(sp_times):\n",
    "        raise ValueError(f'requested {n_neurons} neurons but {path} holds only {len(sp_times)}')\n",
    "\n",
    "    # Edges 0, 1, ..., n_time_bins. np.histogram (like binned_statistic's 'count')\n",
    "    # treats every bin as half-open except the last, which includes its right edge.\n",
    "    bin_edges = np.arange(n_time_bins + 1, dtype=float)\n",
    "    spikes = np.zeros((n_time_bins, n_neurons))\n",
    "    for i in range(n_neurons):\n",
    "        spikes[:, i] = np.histogram(sp_times[i][:, 0], bins=bin_edges)[0]\n",
    "    return spikes\n",
    "\n",
    "spikes = load_data()\n",
    "\n",
    "## hyper-parameters\n",
    "decay = 5\n",
    "# NOTE(review): load_data bins at ~120 Hz (i.e. dt = 1/120); confirm dt = 0.05 is intended.\n",
    "dt = 0.05\n",
    "window_size = 1\n",
    "n_vis_neurons = spikes.shape[1]\n",
    "n_neurons = n_vis_neurons  # every neuron is observed: no hidden neurons\n",
    "basis = utils.exp_basis(decay, window_size, dt*window_size)\n",
    "\n",
    "# Cut the recording into consecutive, non-overlapping trials of TRIAL_LEN bins.\n",
    "# The first N_TRAIN_BINS bins (960 trials) are used for training; the remaining\n",
    "# bins (480 trials for the default 144000-bin recording) are held out for testing.\n",
    "TRIAL_LEN = 100\n",
    "N_TRAIN_BINS = 96000\n",
    "assert N_TRAIN_BINS % TRIAL_LEN == 0 and (len(spikes) - N_TRAIN_BINS) % TRIAL_LEN == 0\n",
    "vis_spikes_list_train = torch.from_numpy(spikes[:N_TRAIN_BINS].reshape(-1, TRIAL_LEN, n_vis_neurons)).to(torch.float32)\n",
    "vis_spikes_list_test = torch.from_numpy(spikes[N_TRAIN_BINS:].reshape(-1, TRIAL_LEN, n_vis_neurons)).to(torch.float32)\n",
    "convolved_vis_spikes_list_train = utils.convolve_spikes_with_basis(vis_spikes_list_train, basis)\n",
    "convolved_vis_spikes_list_test = utils.convolve_spikes_with_basis(vis_spikes_list_test, basis)\n",
    "train_dataset = TensorDataset(vis_spikes_list_train, convolved_vis_spikes_list_train)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "# Seed torch so any random initialisation inside model.POGLM is reproducible.\n",
    "torch.manual_seed(0)\n",
    "inf_model = model.POGLM(n_neurons, n_vis_neurons, basis)\n",
    "# Start the linear layer from all-zero weights and bias. Zeroing in place keeps each\n",
    "# parameter's own shape, dtype and device instead of swapping in a new tensor.\n",
    "nn.init.zeros_(inf_model.linear.weight)\n",
    "nn.init.zeros_(inf_model.linear.bias)\n",
    "\n",
    "inf_optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)\n",
    "\n",
    "\n",
    "n_epochs = 10\n",
    "print_freq = 1\n",
    "\n",
    "epoch_loss_list = torch.zeros(n_epochs)\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for batch_spikes, batch_convolved_spikes in train_dataloader:\n",
    "        batch_size = batch_spikes.shape[0]\n",
    "        # Loss = mean over the batch's trials of the negative complete-data\n",
    "        # log-likelihood; complete_log_likelihood is called one trial at a time.\n",
    "        # (Trial tensors get their own names so the global `spikes` recording is\n",
    "        # not overwritten by the loop.)\n",
    "        loss = 0\n",
    "        for sample in range(batch_size):\n",
    "            trial_spikes = batch_spikes[sample]\n",
    "            trial_convolved_spikes = batch_convolved_spikes[sample]\n",
    "\n",
    "            # Columns [:n_vis_neurons] are visible, the rest hidden (none here).\n",
    "            # The hidden slices get a leading size-1 axis -- presumably the sample\n",
    "            # axis complete_log_likelihood expects; TODO confirm.\n",
    "            hid_spikes_list = trial_spikes[None, :, n_vis_neurons:]\n",
    "            convolved_hid_spikes_list = trial_convolved_spikes[None, :, n_vis_neurons:]\n",
    "            vis_spikes = trial_spikes[:, :n_vis_neurons]\n",
    "            convolved_vis_spikes = trial_convolved_spikes[:, :n_vis_neurons]\n",
    "            loss -= inf_model.complete_log_likelihood(hid_spikes_list, convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)[0]\n",
    "\n",
    "        loss /= batch_size\n",
    "        inf_optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        inf_optimizer.step()\n",
    "\n",
    "        epoch_loss_list[epoch] += loss.item()\n",
    "    # Average of the per-batch losses over the epoch.\n",
    "    epoch_loss_list[epoch] /= len(train_dataloader)\n",
    "\n",
    "    if epoch % print_freq == 0:\n",
    "        print(epoch, epoch_loss_list[epoch].item(), flush=True)\n",
    "\n",
    "            \n",
    "def evaluate_rgc_0(inf_model, spikes_list, convolved_spikes_list, seed: int = 0, n_vis_neurons: int | None = None):\n",
    "    \"\"\"Score a model on held-out trials by its complete-data log-likelihood.\n",
    "\n",
    "    The complete-data log-likelihood is reported in the 'marginal log-likelihood'\n",
    "    column (the two coincide when there are no hidden neurons); 'ELBO' is always NaN.\n",
    "\n",
    "    Args:\n",
    "        inf_model: object exposing complete_log_likelihood(hid_spikes_list,\n",
    "            convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes); the first\n",
    "            element of its result is presumably a single-element tensor.\n",
    "        spikes_list: trials stacked along dim 0, each indexed as (time, neuron).\n",
    "        convolved_spikes_list: basis-convolved counterpart of spikes_list, same layout.\n",
    "        seed: value passed to torch.manual_seed before scoring.\n",
    "        n_vis_neurons: number of leading neurons treated as visible; defaults to all\n",
    "            of them (spikes_list.shape[-1]), i.e. no hidden neurons.\n",
    "\n",
    "    Returns:\n",
    "        Float DataFrame with one row per trial and columns\n",
    "        'marginal log-likelihood' and 'ELBO'.\n",
    "    \"\"\"\n",
    "    if n_vis_neurons is None:\n",
    "        n_vis_neurons = spikes_list.shape[-1]\n",
    "    n_samples = spikes_list.shape[0]\n",
    "    log_likelihoods = []\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for sample in range(n_samples):\n",
    "            spikes = spikes_list[sample]\n",
    "            convolved_spikes = convolved_spikes_list[sample]\n",
    "\n",
    "            hid_spikes_list = spikes[None, :, n_vis_neurons:]\n",
    "            convolved_hid_spikes_list = convolved_spikes[None, :, n_vis_neurons:]\n",
    "            vis_spikes = spikes[:, :n_vis_neurons]\n",
    "            convolved_vis_spikes = convolved_spikes[:, :n_vis_neurons]\n",
    "            log_likelihood = inf_model.complete_log_likelihood(hid_spikes_list, convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)[0]\n",
    "            # Store a plain float, not a tensor, so the frame has a numeric dtype\n",
    "            # and reductions such as .mean() behave as expected.\n",
    "            log_likelihoods.append(float(log_likelihood))\n",
    "\n",
    "    return pd.DataFrame(\n",
    "        {'marginal log-likelihood': log_likelihoods, 'ELBO': [np.nan] * n_samples},\n",
    "        index=np.arange(n_samples),\n",
    "        dtype=float,\n",
    "    )\n",
    "\n",
    "\n",
    "# Average the per-trial test scores into a single summary row and save it.\n",
    "df = evaluate_rgc_0(inf_model, vis_spikes_list_test, convolved_vis_spikes_list_test).mean().to_frame().T\n",
    "df['time'] = np.nan\n",
    "# to_csv does not create missing parent directories, so make sure csv/ exists.\n",
    "os.makedirs('csv', exist_ok=True)\n",
    "df.to_csv('csv/rgc_0.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-cli726]",
   "language": "python",
   "name": "conda-env-.conda-cli726-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
