{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930f56f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from tqdm import trange\n",
    "\n",
    "# Put the notebook's parent directory on the import path (presumably where the\n",
    "# project-local `models` module lives). os.path.dirname is separator-aware,\n",
    "# unlike splitting on \"/\", so this also works on Windows.\n",
    "sys.path.insert(1, os.path.dirname(os.path.abspath(\"\")))\n",
    "\n",
    "import models\n",
    "\n",
    "from importlib import reload\n",
    "\n",
    "reload(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5973e7c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network size: visible plus hidden neurons.\n",
    "n_vis_neurons = 3\n",
    "n_hid_neurons = 2\n",
    "n_neurons = n_vis_neurons + n_hid_neurons\n",
    "\n",
    "# Number of slices along the last axis of the generative conv kernel.\n",
    "kernel_size = 3\n",
    "\n",
    "# Simulation sizes passed to `POGLM.sample`: time bins and samples per split.\n",
    "n_time_bins = 100\n",
    "n_samples_train = 40\n",
    "n_samples_test = 20\n",
    "\n",
    "# Number of independently seeded ground-truth models to simulate.\n",
    "n_trials = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a77b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each trial seeds torch with the trial index, builds a fresh ground-truth\n",
    "# POGLM, overwrites its generative conv parameters with N(0, 0.5^2) draws,\n",
    "# and simulates independent train / test data from it.\n",
    "records = []\n",
    "for trial in trange(n_trials):\n",
    "    torch.manual_seed(trial)\n",
    "    poglm = models.POGLM(\n",
    "        n_vis_neurons=n_vis_neurons,\n",
    "        n_hid_neurons=n_hid_neurons,\n",
    "        kernel_size=kernel_size,\n",
    "    )\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Write into the existing parameters in place. Assigning to\n",
    "        # `poglm.conv_generative.data` would only attach a plain attribute to\n",
    "        # the module and leave its weights at their default initialisation.\n",
    "        poglm.conv_generative.weight.copy_(\n",
    "            torch.randn(poglm.conv_generative.weight.shape) * 0.5\n",
    "        )\n",
    "        poglm.conv_generative.bias.copy_(\n",
    "            torch.randn(poglm.conv_generative.bias.shape) * 0.5\n",
    "        )\n",
    "\n",
    "        # No autograd graph is needed for simulation; the stored tensors should\n",
    "        # therefore stay detached and be directly plottable / picklable.\n",
    "        rate_train, y_train = poglm.sample(n_time_bins, n_samples_train)\n",
    "        rate_test, y_test = poglm.sample(n_time_bins, n_samples_test)\n",
    "\n",
    "    records.append(\n",
    "        {\n",
    "            \"model\": poglm.state_dict(),\n",
    "            \"y_train\": y_train,\n",
    "            \"rate_train\": rate_train,\n",
    "            \"y_test\": y_test,\n",
    "            \"rate_test\": rate_test,\n",
    "        }\n",
    "    )\n",
    "\n",
    "# Build the frame once from the per-trial records (one row per trial).\n",
    "df = pd.DataFrame(\n",
    "    records,\n",
    "    index=np.arange(n_trials),\n",
    "    columns=[\"model\", \"y_train\", \"rate_train\", \"y_test\", \"rate_test\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bfd6438",
   "metadata": {},
   "outputs": [],
   "source": [
    "trial = 1\n",
    "sample = 2\n",
    "\n",
    "rate_train = df.at[trial, \"rate_train\"]\n",
    "y_train = df.at[trial, \"y_train\"]\n",
    "\n",
    "# One row per neuron, overlaying the sampled rate and the sampled y for a\n",
    "# single training sample. Indexing assumes a (sample, time, neuron) layout --\n",
    "# TODO confirm against POGLM.sample.\n",
    "fig, axs = plt.subplots(\n",
    "    n_neurons,\n",
    "    1,\n",
    "    figsize=(8, n_neurons),\n",
    "    layout=\"constrained\",\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "for i, ax in enumerate(axs[:, 0]):\n",
    "    ax.plot(rate_train[sample, :, i].detach().numpy(), color=\"C0\", label=\"rate\")\n",
    "    ax.plot(y_train[sample, :, i].detach().numpy(), color=\"C1\", label=\"y\")\n",
    "    ax.set_title(f\"Neuron {i+1}\")\n",
    "axs[0, 0].legend(loc=\"upper right\")\n",
    "axs[-1, 0].set_xlabel(\"Time bin\")\n",
    "fig.suptitle(f\"Trial {trial}, training sample {sample}\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a6fdfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generative kernel of the `trial` plotted above, read from its saved state\n",
    "# dict. The loop variable `poglm` holds the last trial's model, which is not\n",
    "# necessarily `trial`. Assumes the state-dict key mirrors the attribute path\n",
    "# `conv_generative.weight` -- TODO confirm if POGLM nests the layer.\n",
    "weight = df.at[trial, \"model\"][\"conv_generative.weight\"].detach()\n",
    "\n",
    "# Symmetric colour limits keep the diverging colormap centred on 0 without\n",
    "# clipping any weight.\n",
    "vlim = weight.abs().max().item()\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    1, kernel_size, figsize=(12, 4), layout=\"constrained\", squeeze=False\n",
    ")\n",
    "for k, ax in enumerate(axs[0]):\n",
    "    im = ax.matshow(\n",
    "        weight[:, :, k].numpy(),\n",
    "        cmap=\"seismic\",\n",
    "        vmin=-vlim,\n",
    "        vmax=vlim,\n",
    "    )\n",
    "    ax.set_title(f\"Kernel slice {k+1}\")\n",
    "fig.colorbar(im, ax=axs, shrink=0.8);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "679b3cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-trial state dicts and simulated tensors. Reading this back\n",
    "# with pd.read_pickle unpickles arbitrary objects, so only load trusted copies.\n",
    "pd.to_pickle(df, \"data.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308c16f6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "csai0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
