{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "04e3cebb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-25T22:19:09.287250Z",
     "start_time": "2023-09-25T22:15:23.222612Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'inference' from '/storage/coda1/p-awu36/0/cli726/vis/pohawkes/inference.py'>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from tqdm import trange\n",
    "\n",
    "import model, utils, inference\n",
    "\n",
    "from importlib import reload\n",
    "# Re-import the project-local modules so edits made on disk take effect without restarting the kernel.\n",
    "reload(model)\n",
    "reload(utils)\n",
    "# The trailing ';' stops the cell from displaying the module repr, which would otherwise save an\n",
    "# absolute, user-specific filesystem path into the notebook output.\n",
    "reload(inference);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dabf85cf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-25T22:19:30.262392Z",
     "start_time": "2023-09-25T22:19:30.258606Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by every trial generated below.\n",
    "decay = 5\n",
    "dt = 0.05\n",
    "window_size = 5\n",
    "n_neurons = 5\n",
    "n_vis_neurons = 3\n",
    "# Basis handed to model.POGLM when each generative model is built.\n",
    "basis = utils.exp_basis(decay, window_size, dt*window_size)\n",
    "T = 5\n",
    "# round() before int(): T/dt is a float quotient and a bare int() truncates, so a quotient that lands a\n",
    "# hair below a whole number (e.g. 0.3/0.1 -> 2.9999999999999996) would silently drop one bin.\n",
    "n_time_bins = int(round(T/dt))\n",
    "n_samples = 60 # total sequences sampled per trial (train + test together)\n",
    "n_samples_train = 40 # the first n_samples_train sequences are train; the remaining n_samples - n_samples_train are test\n",
    "n_trials = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "19c6b120",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-25T22:28:34.110727Z",
     "start_time": "2023-09-25T22:28:32.126091Z"
    }
   },
   "outputs": [],
   "source": [
    "# One row per trial, one column per stored artefact; the frame starts empty and is filled cell by cell.\n",
    "df = pd.DataFrame(\n",
    "    index=np.arange(n_trials),\n",
    "    columns=[\n",
    "        'gen_model',\n",
    "        'spikes_list_train',\n",
    "        'convolved_spikes_list_train',\n",
    "        'firing_rates_list_train',\n",
    "        'spikes_list_test',\n",
    "        'convolved_spikes_list_test',\n",
    "        'firing_rates_list_test',\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "274cd348",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-25T22:28:35.558599Z",
     "start_time": "2023-09-25T22:28:35.295621Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 41.05it/s]\n"
     ]
    }
   ],
   "source": [
    "for trial in trange(n_trials):\n",
    "    # Seed torch with the trial index so the random draws below are repeatable per trial.\n",
    "    torch.manual_seed(trial)\n",
    "\n",
    "    # Build a fresh generative model and overwrite its parameters with standard-normal draws scaled by 0.1.\n",
    "    gen_model = model.POGLM(n_neurons, n_vis_neurons, basis)\n",
    "    with torch.no_grad():\n",
    "        gen_model.linear.weight.data = 0.1 * torch.randn(n_neurons, n_neurons)\n",
    "        # NOTE(review): this is a plain attribute assignment, unlike the .data write above; confirm that\n",
    "        # POGLM actually reads `bias` and that state_dict() captures it.\n",
    "        gen_model.bias = 0.1 * torch.randn(n_neurons)\n",
    "\n",
    "    # Record the parameters that produced this trial's data.\n",
    "    df.at[trial, 'gen_model'] = gen_model.state_dict()\n",
    "\n",
    "    # Sample once, then split each returned sequence at n_samples_train: head -> train, tail -> test.\n",
    "    spikes_list, convolved_spikes_list, firing_rates_list = gen_model.sample(n_time_bins, n_samples)\n",
    "\n",
    "    df.at[trial, 'spikes_list_train'] = spikes_list[:n_samples_train]\n",
    "    df.at[trial, 'spikes_list_test'] = spikes_list[n_samples_train:]\n",
    "\n",
    "    df.at[trial, 'convolved_spikes_list_train'] = convolved_spikes_list[:n_samples_train]\n",
    "    df.at[trial, 'convolved_spikes_list_test'] = convolved_spikes_list[n_samples_train:]\n",
    "\n",
    "    df.at[trial, 'firing_rates_list_train'] = firing_rates_list[:n_samples_train]\n",
    "    df.at[trial, 'firing_rates_list_test'] = firing_rates_list[n_samples_train:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "883c02b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-25T22:28:37.734301Z",
     "start_time": "2023-09-25T22:28:37.662962Z"
    }
   },
   "outputs": [],
   "source": [
    "# Write the per-trial frame (model parameters plus train/test data) to disk in the working directory.\n",
    "df.to_pickle(path='data.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c035188",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-cli726]",
   "language": "python",
   "name": "conda-env-.conda-cli726-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
