{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from AbstractModels.util.decode import decode_mean\n",
    "\n",
    "from SNN.Encoders import CopyEncoder\n",
    "from SNN.models.classification import ITLIFResNet19\n",
    "from SNN.Layers import LNM\n",
    "from SNN.Layers import NeuronConfig\n",
    "from SNN.LearnableMembrane import LearnableMembrane\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "np.random.seed(seed)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constructor arguments for the network: a CopyEncoder built with seq_length=2,\n",
    "# decode_mean passed as the decoder, and 100 output classes.\n",
    "model_kwargs = dict(\n",
    "    encoder=CopyEncoder(seq_length=2),\n",
    "    decoder=decode_mean,\n",
    "    num_classes=100,\n",
    ")\n",
    "model = ITLIFResNet19(**model_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint location. Both values are empty placeholders here (the resulting\n",
    "# path is just '/'); presumably they are meant to be filled in before running.\n",
    "# DIR is concatenated as-is, so it must carry its own trailing separator.\n",
    "DIR = ''\n",
    "LNM_weights = ''\n",
    "\n",
    "# The weights name appears twice: once as the folder and once as the entry inside it.\n",
    "weights_path = '{0}{1}/{1}'.format(DIR, LNM_weights)\n",
    "model.load_by_path(weights_path, None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the module tree once, then pick out the two layer types of interest.\n",
    "all_modules = list(model.modules())\n",
    "membrane_updates = [m for m in all_modules if isinstance(m, LearnableMembrane)]\n",
    "neuron_models = [m for m in all_modules if isinstance(m, LNM)]\n",
    "\n",
    "# For every learnable membrane: keep whatever plot() returns, then print its parameters.\n",
    "plots = []\n",
    "for membrane_update in membrane_updates:\n",
    "    curve = membrane_update.plot()\n",
    "    plots.append(curve)\n",
    "    membrane_update.print_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Panel grid: the column count is fixed and the row count is derived from the\n",
    "# number of curves, so any number of layers fits (18 curves give a 3 x 6 grid).\n",
    "cols = 6\n",
    "rows = max(1, -(-len(plots) // cols))  # ceiling division; at least one row\n",
    "\n",
    "# squeeze=False keeps `axes` two-dimensional even when there is only one row.\n",
    "# The figure height scales with the row count (5 inches for three rows).\n",
    "fig, axes = plt.subplots(rows, cols, figsize=(12, 5 * rows / 3), squeeze=False)\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    if i >= len(plots):\n",
    "        # No curve for this slot: hide the panel instead of showing an empty frame.\n",
    "        ax.set_visible(False)\n",
    "        continue\n",
    "    plot = plots[i]\n",
    "    ax.plot(plot[0], plot[1])\n",
    "    ax.set_xticks([-1.0, 0.0, 1.0])\n",
    "    ax.set_title(f'Layer {i + 1}')\n",
    "    # y label only on the first column\n",
    "    if i % cols == 0:\n",
    "        ax.set_ylabel(r'$f_{\\theta}(u(t))$', fontsize=12)\n",
    "    # x label only on the bottom-most populated panel of each column\n",
    "    if i + cols >= len(plots):\n",
    "        ax.set_xlabel(r'$u(t)$', fontsize=12)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig(\"./membrane_updates.pdf\", dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
