{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from AbstractModels.SpikingModel import SpikingModel\n",
    "from AbstractModels.util.decode import decode_mean\n",
    "\n",
    "from SNN.Encoders import CopyEncoder\n",
    "from SNN.models.classification import ITLIFResNet19\n",
    "\n",
    "from torchvision.datasets import CIFAR100, CIFAR10\n",
    "\n",
    "import torch\n",
    "\n",
    "from torchvision.transforms import v2 as transforms\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from SNN.util.energy_consumption import approximate_energy_consumption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "\n",
    "# Seed every random number generator used in this notebook (PyTorch, CUDA,\n",
    "# NumPy) with the same value, in a fixed order.\n",
    "for seed_rng in (\n",
    "    torch.manual_seed,\n",
    "    torch.cuda.manual_seed,\n",
    "    torch.cuda.manual_seed_all,\n",
    "    np.random.seed,\n",
    "):\n",
    "    seed_rng(seed)\n",
    "\n",
    "# Restrict cuDNN to deterministic algorithms.\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network under test: the encoder is built with seq_length=2 and the outputs\n",
    "# are decoded with decode_mean; 10 output classes to match CIFAR-10.\n",
    "model = ITLIFResNet19(encoder=CopyEncoder(seq_length=2), decoder=decode_mean, num_classes=10)\n",
    "\n",
    "# CIFAR-10 preprocessing: to image, to scaled float32, then per-channel normalisation.\n",
    "# For CIFAR-100 the normalisation statistics would instead be\n",
    "#   mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]\n",
    "# (presumably together with the CIFAR100 dataset class and num_classes=100 - confirm).\n",
    "cifar10_preprocess = transforms.Compose([\n",
    "    transforms.ToImage(),\n",
    "    transforms.ToDtype(torch.float32, scale=True),\n",
    "    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),\n",
    "])\n",
    "\n",
    "# Test split (train=False), stored under ../../data/ and downloaded on demand.\n",
    "dataset = CIFAR10(root='../../data/', train=False, download=True, transform=cifar10_preprocess)\n",
    "\n",
    "# Batches of 16 samples, drawn in shuffled order.\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholders to fill in before running the cells below. The weights are\n",
    "# loaded from f'{DIR}{LNM_weights}/{LNM_weights}', so a non-empty DIR must\n",
    "# end with a path separator.\n",
    "DIR, LNM_weights = '', ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_spike_rate(model: SpikingModel, dataloader, device) -> np.ndarray:\n",
    "    \"\"\"Run ``model`` over every batch of ``dataloader`` and return its spike rates.\n",
    "\n",
    "    Args:\n",
    "        model: Spiking model; it is moved to ``device`` and put in eval mode.\n",
    "        dataloader: Iterable of ``(inputs, targets)`` batches; the targets are ignored.\n",
    "        device: Device that the model and every input batch are moved to.\n",
    "\n",
    "    Returns:\n",
    "        ``model.get_spike_rate()`` converted to a NumPy array and divided by 100.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Forward passes only: disabling autograd avoids building graphs that are\n",
    "    # never back-propagated through, which saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for data, _ in dataloader:\n",
    "            model(data.float().to(device))\n",
    "\n",
    "    # NOTE(review): the division by 100 presumably turns percentages into\n",
    "    # fractions - confirm against SpikingModel.get_spike_rate.\n",
    "    return np.array(model.get_spike_rate()) / 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint location: a folder named after the weights, directly under DIR,\n",
    "# containing an entry of the same name.\n",
    "weights_path = f'{DIR}{LNM_weights}/{LNM_weights}'\n",
    "# NOTE(review): the meaning of the two None arguments is not visible in this\n",
    "# notebook - check load_by_path.\n",
    "model.load_by_path(weights_path, None, None)\n",
    "\n",
    "# Spike rates of the loaded model over the whole evaluation dataloader.\n",
    "lif_spike_rate = get_spike_rate(model, dataloader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the energy consumption from the spike rates measured above.\n",
    "lif_energy = approximate_energy_consumption(\n",
    "    model,\n",
    "    dataloader,\n",
    "    2,  # NOTE(review): presumably the number of time steps (the encoder uses seq_length=2) - confirm\n",
    "    lif_spike_rate,\n",
    ")\n",
    "print(f'LIF energy: {lif_energy} uJ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
