{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from AbstractModels.util.decode import decode_mean\n",
    "\n",
    "from SNN.Encoders import CopyEncoder\n",
    "from SNN.models.classification import ITLIFResNet34\n",
    "from SNN.Layers import LNM\n",
    "from SNN.Layers import NeuronConfig\n",
    "from SNN.LearnableMembrane import LearnableMembrane\n",
    "\n",
    "import torch\n",
    "from torch import distributed as dist\n",
    "\n",
    "from torchvision.transforms import v2 as transforms\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the torch (CPU, current CUDA device, all CUDA devices) and NumPy RNGs,\n",
    "# then force deterministic cuDNN kernels so runs are reproducible.\n",
    "seed = 42\n",
    "\n",
    "for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all, np.random.seed):\n",
    "    seed_fn(seed)\n",
    "\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spiking ResNet-34 with 1000 output classes (ImageNet-1k).\n",
    "model = ITLIFResNet34(\n",
    "    encoder=CopyEncoder(seq_length=4),\n",
    "    decoder=decode_mean,\n",
    "    num_classes=1000\n",
    ")\n",
    "\n",
    "# Single-process (world_size=1) NCCL process group on GPU 0, required by the DDP wrapper below.\n",
    "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
    "os.environ[\"MASTER_PORT\"] = \"12355\"\n",
    "\n",
    "torch.cuda.set_device(0)\n",
    "# init_process_group raises if the default group already exists; the guard keeps\n",
    "# this cell re-runnable within the same kernel.\n",
    "if not dist.is_initialized():\n",
    "    dist.init_process_group(\"nccl\", world_size=1, rank=0)\n",
    "rank = dist.get_rank()\n",
    "\n",
    "model.cuda()\n",
    "\n",
    "# NOTE(review): presumably wrapped in DDP so parameter names match a checkpoint saved\n",
    "# during DDP training (\"module.\" prefix) -- TODO confirm against load_by_path.\n",
    "model.model = torch.nn.parallel.DistributedDataParallel(model.model, device_ids=[rank])\n",
    "\n",
    "# Streaming: examples are fetched lazily instead of downloading the whole split.\n",
    "dataset = load_dataset(\"imagenet-1k\", split='validation', streaming=True)\n",
    "\n",
    "# Standard ImageNet evaluation preprocessing (resize, center crop, ImageNet mean/std).\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToImage(),\n",
    "    transforms.ToDtype(torch.float, scale=True),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "def apply_transform(example):\n",
    "    \"\"\"Replace example['image'] with its preprocessed 3-channel tensor and return the example.\n",
    "\n",
    "    ImageNet contains some non-RGB images (e.g. grayscale); converting to RGB first\n",
    "    guarantees the 3 channels that the Normalize statistics and batch collation expect.\n",
    "    Assumes the image column is decoded to PIL images (the `datasets` default).\n",
    "    \"\"\"\n",
    "    example['image'] = transform(example['image'].convert('RGB'))\n",
    "    return example\n",
    "\n",
    "dataset = dataset.map(apply_transform)\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill these in before running: the checkpoint is loaded from DIR/LNM_weights/LNM_weights.\n",
    "DIR = ''\n",
    "\n",
    "LNM_weights = ''\n",
    "\n",
    "# Fail with a clear message instead of handing load_by_path a meaningless path.\n",
    "if not LNM_weights:\n",
    "    raise ValueError(\"Set LNM_weights (and optionally DIR) to the trained checkpoint before loading.\")\n",
    "\n",
    "# os.path.join works whether or not DIR ends with a path separator.\n",
    "# NOTE(review): the two None arguments are passed through to load_by_path unchanged;\n",
    "# their meaning is defined there -- TODO document.\n",
    "model.load_by_path(os.path.join(DIR, LNM_weights, LNM_weights), None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the LearnableMembrane and LNM submodules of the network, in traversal order.\n",
    "membrane_updates = []\n",
    "neuron_models = []\n",
    "for submodule in model.modules():\n",
    "    if isinstance(submodule, LearnableMembrane):\n",
    "        membrane_updates.append(submodule)\n",
    "    if isinstance(submodule, LNM):\n",
    "        neuron_models.append(submodule)\n",
    "\n",
    "# For each membrane: store its plot data (presumably (x, y) samples of the curve,\n",
    "# drawn in the next cell) and print its learned parameters.\n",
    "plots = []\n",
    "for membrane in membrane_updates:\n",
    "    plots.append(membrane.plot())\n",
    "    membrane.print_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the learned membrane-update curves, one panel per LearnableMembrane.\n",
    "n_plots = len(plots)\n",
    "cols = 6\n",
    "# Rows grow with the number of layers instead of a fixed 6x6 (which fails for > 36 layers).\n",
    "rows = max(1, int(np.ceil(n_plots / cols)))\n",
    "fig, axes = plt.subplots(rows, cols, figsize=(12, 10), squeeze=False)\n",
    "flat_axes = axes.flatten()\n",
    "\n",
    "for i, (ax, plot) in enumerate(zip(flat_axes, plots)):\n",
    "    ax.plot(plot[0], plot[1])\n",
    "    ax.set_xticks([-1.0, 0.0, 1.0])\n",
    "    ax.set_title(f'Layer {i + 1}')\n",
    "    if i % cols == 0:  # leftmost column carries the y-axis label\n",
    "        ax.set_ylabel(r'$f_{\\theta}(u(t))$', fontsize=12)\n",
    "    if i + cols >= n_plots:  # no populated panel below, so this is the bottom of its column\n",
    "        ax.set_xlabel(r'$u(t)$', fontsize=12)\n",
    "\n",
    "# Hide the empty trailing panels rather than drawing blank axes.\n",
    "for ax in flat_axes[n_plots:]:\n",
    "    ax.set_visible(False)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig(\"./membrane_updates.pdf\", dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
