{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from AbstractModels.SpikingModel import SpikingModel\n",
    "from AbstractModels.util.decode import decode_mean\n",
    "\n",
    "from SNN.Encoders import CopyEncoder\n",
    "from SNN.models.classification import ITLIFResNet34\n",
    "from SNN.Layers import LNM\n",
    "from SNN.LearnableMembrane import LearnableMembrane\n",
    "\n",
    "import torch\n",
    "from torch import distributed as dist\n",
    "\n",
    "from torchvision.transforms import v2 as transforms\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os\n",
    "\n",
    "from SNN.util.energy_consumption import approximate_energy_consumption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG source used here so spike-rate / energy numbers are reproducible.\n",
    "seed = 42\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "np.random.seed(seed)\n",
    "# Use deterministic cuDNN kernels and disable cuDNN autotuning, which may select\n",
    "# different (possibly non-deterministic) algorithms from one run to the next.\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ITLIFResNet34(\n",
    "    encoder=CopyEncoder(seq_length=4),\n",
    "    decoder=decode_mean,\n",
    "    num_classes=1000\n",
    ")\n",
    "\n",
    "# Single-process process group on GPU 0. The model is wrapped in DDP below,\n",
    "# presumably so parameter names match checkpoints saved during DDP training - confirm.\n",
    "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
    "os.environ[\"MASTER_PORT\"] = \"12355\"\n",
    "\n",
    "torch.cuda.set_device(0)\n",
    "# init_process_group raises if the default group already exists; guard it so this\n",
    "# cell can be re-run in the same kernel.\n",
    "if not dist.is_initialized():\n",
    "    dist.init_process_group(\"nccl\", world_size=1, rank=0)\n",
    "rank = dist.get_rank()\n",
    "\n",
    "model.cuda()\n",
    "\n",
    "model.model = torch.nn.parallel.DistributedDataParallel(model.model, device_ids=[rank])\n",
    "\n",
    "# Evaluation preprocessing: shorter side resized to 256, centre crop to 224,\n",
    "# float in [0, 1], then per-channel normalisation with the usual ImageNet mean/std.\n",
    "imagenet_mean = [0.485, 0.456, 0.406]\n",
    "imagenet_std = [0.229, 0.224, 0.225]\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToImage(),\n",
    "        transforms.ToDtype(torch.float, scale=True),\n",
    "        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Stream the ImageNet-1k validation split instead of downloading it in full.\n",
    "dataset = load_dataset(\"imagenet-1k\", split='validation', streaming=True)\n",
    "\n",
    "def apply_transform(example):\n",
    "    \"\"\"Preprocess one dataset record in place and return it.\n",
    "\n",
    "    The image is converted to 3-channel RGB before `transform` (as torchvision's own\n",
    "    PIL loader does). ImageNet contains some greyscale / palette / CMYK files, which\n",
    "    would otherwise produce tensors whose channel count disagrees with the 3-element\n",
    "    Normalize statistics and with the other images in a batch.\n",
    "    Assumes example['image'] is a decoded PIL image (the HF Image feature default).\n",
    "    \"\"\"\n",
    "    example['image'] = transform(example['image'].convert('RGB'))\n",
    "    return example\n",
    "\n",
    "# Preprocess each streamed record on the fly, then batch 4 records at a time.\n",
    "dataset = dataset.map(apply_transform)\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory containing the checkpoint folder; include a trailing '/' or leave empty.\n",
    "# May be supplied via the SNN_CHECKPOINT_DIR environment variable instead of editing this cell.\n",
    "DIR = os.environ.get('SNN_CHECKPOINT_DIR', '')\n",
    "# Name of the LNM checkpoint; it is loaded from {DIR}{LNM_weights}/{LNM_weights}.\n",
    "# May be supplied via the SNN_LNM_WEIGHTS environment variable.\n",
    "LNM_weights = os.environ.get('SNN_LNM_WEIGHTS', '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_spike_rate(model: SpikingModel, dataloader, device, max_batches: int | None = None) -> np.ndarray:\n",
    "    \"\"\"Run `model` over `dataloader` and return the spike rates it recorded.\n",
    "\n",
    "    Args:\n",
    "        model: spiking model exposing `get_spike_rate()`.\n",
    "        dataloader: iterable of batches, each a mapping with an 'image' tensor.\n",
    "        device: device the model and the input images are moved to.\n",
    "        max_batches: if given, stop after this many batches (handy for quick runs\n",
    "            over a long stream); None processes every batch.\n",
    "\n",
    "    Returns:\n",
    "        `model.get_spike_rate()` converted to a NumPy array and divided by 100.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Only forward passes are needed; skipping autograd bookkeeping saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch in enumerate(dataloader):\n",
    "            if max_batches is not None and batch_idx >= max_batches:\n",
    "                break\n",
    "            images = batch['image'].float().to(device)\n",
    "            model(images)\n",
    "\n",
    "    # NOTE(review): the /100 presumably converts a percentage into a fraction - confirm\n",
    "    # against SpikingModel.get_spike_rate.\n",
    "    return np.array(model.get_spike_rate()) / 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when one is visible; otherwise fall back to the CPU.\n",
    "# NOTE(review): the distributed-setup cell already requires CUDA (set_device / nccl),\n",
    "# so the CPU branch is effectively unreachable in this notebook.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast with a clear message rather than an opaque error from load_by_path.\n",
    "assert LNM_weights, 'Set LNM_weights (and DIR) in the configuration cell before loading.'\n",
    "# os.path.join tolerates DIR with or without a trailing separator; for an empty DIR or\n",
    "# one ending in '/', it yields the same path as plain string concatenation.\n",
    "lnm_checkpoint_path = os.path.join(DIR, LNM_weights, LNM_weights)\n",
    "# NOTE(review): meaning of the two positional None arguments is defined by\n",
    "# SpikingModel.load_by_path - confirm.\n",
    "model.load_by_path(lnm_checkpoint_path, None, None)\n",
    "lnm_spike_rate = get_spike_rate(model, dataloader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate energy consumption from the spike rates measured in the previous cell.\n",
    "# NOTE(review): the meaning of the positional `10` is not visible in this notebook -\n",
    "# confirm against approximate_energy_consumption.\n",
    "lnm_energy = approximate_energy_consumption(\n",
    "    model,\n",
    "    dataloader,\n",
    "    10,\n",
    "    lnm_spike_rate,\n",
    ")\n",
    "print(f'LNM energy: {lnm_energy} uJ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
