{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from AbstractModels.SpikingModel import SpikingModel\n",
    "from AbstractModels.util.decode import decode_mean\n",
    "\n",
    "from SNN.Encoders import IdentityEncoder\n",
    "from SNN.models.classification import LNMResNet19\n",
    "\n",
    "from Datasets.CIFAR10DVS import CIFAR10DVS\n",
    "\n",
    "import torch\n",
    "\n",
    "from torchvision.transforms import v2 as transforms\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from SNN.util.energy_consumption import approximate_energy_consumption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "np.random.seed(seed)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LNMResNet19(\n",
    "    encoder=IdentityEncoder(seq_length=10),\n",
    "    decoder=decode_mean,\n",
    "    num_classes=10\n",
    ")\n",
    "\n",
    "dataset = CIFAR10DVS(\n",
    "    root='../../data/',\n",
    "    train=False,\n",
    "    transform=True\n",
    ")\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DIR = ''\n",
    "LNM_weights = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_spike_rate(model: SpikingModel, dataloader, device) -> list:\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    for data, _ in dataloader:\n",
    "        data = data.float().to(device)\n",
    "        model(data)\n",
    "    \n",
    "    return np.array(model.get_spike_rate()) / 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_by_path(f'{DIR}{LNM_weights}/{LNM_weights}', None, None)\n",
    "lnm_spike_rate = get_spike_rate(model, dataloader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lnm_energy = approximate_energy_consumption(model, dataloader, 10, lnm_spike_rate)\n",
    "print(f'LNM energy: {lnm_energy} uJ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
