{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides a minimal example of using Layer-wise Feedback Propagation (LFP) to train a simple MLP spiking neural network (SNN) on MNIST.\n",
    "\n",
    "For more complex examples, refer to the experiment notebooks in `./nbs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# snntorch is an optional dependency: fail early with installation instructions if it is missing\n",
    "try:\n",
    "    import snntorch as snn\n",
    "    from snntorch import utils as snnutils\n",
    "except ImportError as err:\n",
    "    print(\n",
    "        \"The SNN functionality of this package requires extra dependencies \",\n",
    "        \"which can be installed via pip install lfprop[snn] (or lfprop[full] for all dependencies).\",\n",
    "    )\n",
    "    # Chain explicitly so the traceback reports the failed import as the direct cause\n",
    "    raise ImportError(\"snntorch required; reinstall lfprop with option `snn` (pip install lfprop[snn])\") from err"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/berend/code/repos/layerwise-feedback-propagation/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torcheval.metrics\n",
    "import torchvision.datasets as tvisiondata\n",
    "import torchvision.transforms as T\n",
    "from tqdm import tqdm\n",
    "\n",
    "from lfprop.rewards import reward_functions as rewards  # Reward Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory used as the MNIST download/storage root\n",
    "savepath = \"./minimal-example-data\"\n",
    "os.makedirs(savepath, exist_ok=True)\n",
    "\n",
    "# Seed the numpy and torch RNGs so that repeated runs start from the same random state\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "batch_size = 128  # samples per batch for both data loaders\n",
    "n_channels = 784  # input size handed to get_model\n",
    "n_outputs = 10  # output size handed to get_model\n",
    "n_steps = 15  # forward passes (time steps) run per batch\n",
    "lr = 0.02  # SGD learning rate\n",
    "momentum = 0.9  # SGD momentum\n",
    "epochs = 3  # passes over the training data\n",
    "model_name = \"smalllifmlp\"  # architecture identifier handed to get_model\n",
    "lif_kwargs = {\"beta\": 0.9, \"minmem\": None}  # forwarded to get_model as extra keyword arguments\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert images to tensors, then shift/scale them with mean 0.5 and std 0.5\n",
    "transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "# Both splits share storage location, preprocessing and download behaviour; only `train` differs\n",
    "mnist_kwargs = {\"root\": savepath, \"transform\": transform, \"download\": True}\n",
    "training_data = tvisiondata.MNIST(train=True, **mnist_kwargs)\n",
    "validation_data = tvisiondata.MNIST(train=False, **mnist_kwargs)\n",
    "\n",
    "# [DEBUG] overfit to a small dataset\n",
    "# training_data = torch.utils.data.Subset(training_data, list(range(0, len(training_data) // 2)))\n",
    "# validation_data = torch.utils.data.Subset(validation_data, list(range(0, 10)) * 100)\n",
    "\n",
    "# Only the training batches are shuffled; validation keeps the dataset order\n",
    "training_loader = torch.utils.data.DataLoader(\n",
    "    training_data,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "validation_loader = torch.utils.data.DataLoader(\n",
    "    validation_data,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SmallLifMLP(\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=784, out_features=1000, bias=True)\n",
       "    (1): CustomLeaky()\n",
       "    (2): Linear(in_features=1000, out_features=10, bias=True)\n",
       "    (3): CustomLeaky()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from lfprop.model.spiking_networks import get_model\n",
    "\n",
    "# Build the network from the settings defined in the parameters cell\n",
    "model = get_model(\n",
    "    model_name=model_name,\n",
    "    n_channels=n_channels,\n",
    "    n_outputs=n_outputs,\n",
    "    device=device,\n",
    "    **lif_kwargs,\n",
    ")\n",
    "\n",
    "# Clear the model's internal state, move it to the target device and switch to eval mode;\n",
    "# the trailing expression makes the notebook display the module\n",
    "model.reset()\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up LFP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lfprop.propagation.propagator_snn import LRPRewardPropagator\n",
    "\n",
    "# Reward function that scores the recorded output against the labels\n",
    "reward_func = rewards.SnnCorrectClassRewardSpikesRateCoded(device)\n",
    "\n",
    "# SNN propagator that passes the reward backwards through the model\n",
    "snn_propagator = LRPRewardPropagator(model, norm_backward=True)\n",
    "\n",
    "# The LFP update ends up in the .grad attribute of the model parameters,\n",
    "# so a standard torch optimizer can apply it\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up Simple Evaluation using torcheval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(loader, n_steps: int = 15):\n",
    "    \"\"\"\n",
    "    Evaluates the global `model` on a single dataset.\n",
    "\n",
    "    Args:\n",
    "        loader: Iterable yielding `(inputs, labels)` batches, e.g. a torch DataLoader.\n",
    "        n_steps: Number of consecutive forward passes (time steps) run for every batch.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys \"reward\" and \"accuracy\", each holding the metric aggregated\n",
    "        over all batches as a numpy value.\n",
    "    \"\"\"\n",
    "    eval_metrics = {\n",
    "        \"reward\": torcheval.metrics.Mean(device=device),\n",
    "        # num_classes follows the configured output size instead of a hard-coded 10\n",
    "        \"accuracy\": torcheval.metrics.MulticlassAccuracy(\n",
    "            average=\"micro\", num_classes=n_outputs, k=1, device=device\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # Iterate over Data Loader\n",
    "    for inputs, labels in tqdm(loader, desc=\"Evaluating\"):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Start every batch from a fresh internal state, as lfp_step does during training;\n",
    "        # otherwise the state left behind by the previous batch carries over into this one\n",
    "        model.reset()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Run the model for n_steps and record its two outputs at every step\n",
    "            u_rec, spk_rec = [], []\n",
    "            for _ in range(n_steps):\n",
    "                spk_out, u_out = model(inputs)\n",
    "                u_rec.append(u_out)\n",
    "                spk_rec.append(spk_out)\n",
    "\n",
    "            # Stack the per-step records along a new leading (time) dimension\n",
    "            spikes = torch.stack(spk_rec, dim=0)\n",
    "            membrane_potential = torch.stack(u_rec, dim=0)\n",
    "\n",
    "            # Get rewards and predictions\n",
    "            reward = reward_func(spikes=spikes, potentials=membrane_potential, labels=labels)\n",
    "            outputs = reward_func.get_predictions(spikes=spikes, potentials=membrane_potential)\n",
    "\n",
    "        eval_metrics[\"reward\"].update(reward)\n",
    "        eval_metrics[\"accuracy\"].update(outputs, labels)\n",
    "\n",
    "    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}\n",
    "    # Leave the model without leftover state from the last evaluated batch\n",
    "    model.reset()\n",
    "    # Return evaluation\n",
    "    return return_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [01:16<00:00,  6.10it/s]\n",
      "Evaluating: 100%|██████████| 469/469 [00:15<00:00, 29.71it/s]\n",
      "Evaluating: 100%|██████████| 79/79 [00:02<00:00, 29.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3: (Train Reward) 0.00; (Train Accuracy) 0.87; (Val Reward) -0.00; (Val Accuracy) 0.88\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [02:17<00:00,  3.42it/s]\n",
      "Evaluating: 100%|██████████| 469/469 [00:15<00:00, 29.92it/s]\n",
      "Evaluating: 100%|██████████| 79/79 [00:02<00:00, 29.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/3: (Train Reward) -0.00; (Train Accuracy) 0.89; (Val Reward) -0.00; (Val Accuracy) 0.89\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [03:26<00:00,  2.27it/s]\n",
      "Evaluating: 100%|██████████| 469/469 [00:15<00:00, 29.62it/s]\n",
      "Evaluating: 100%|██████████| 79/79 [00:02<00:00, 29.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/3: (Train Reward) 0.00; (Train Accuracy) 0.89; (Val Reward) 0.00; (Val Accuracy) 0.90\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from lfprop.model.spiking_networks import clip_gradients\n",
    "\n",
    "\n",
    "def lfp_step(inputs, labels, n_steps: int = 15):\n",
    "    \"\"\"\n",
    "    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.\n",
    "\n",
    "    Args:\n",
    "        inputs: Batch of input samples; the training loop moves it to `device` before the call.\n",
    "        labels: Labels belonging to `inputs`, likewise moved to `device` by the training loop.\n",
    "        n_steps: Number of consecutive forward passes (time steps) run for this batch.\n",
    "    \"\"\"\n",
    "    # Set Model to training mode and start from a fresh internal state\n",
    "    model.train()\n",
    "    model.reset()\n",
    "\n",
    "    with torch.enable_grad():\n",
    "        # Zero Optimizer\n",
    "        optimizer.zero_grad()\n",
    "        # Run the model for n_steps and record its two outputs at every step\n",
    "        u_rec, spk_rec = [], []\n",
    "        for _ in range(n_steps):\n",
    "            spk_out, u_out = model(inputs)\n",
    "            u_rec.append(u_out)\n",
    "            spk_rec.append(spk_out)\n",
    "\n",
    "        # Stack the per-step records along a new leading (time) dimension\n",
    "        spikes = torch.stack(spk_rec, dim=0)\n",
    "        membrane_potential = torch.stack(u_rec, dim=0)\n",
    "\n",
    "        # Only the reward drives the update; predictions are not used here, so they are not computed\n",
    "        reward = reward_func(spikes=spikes, potentials=membrane_potential, labels=labels)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # go backwards through sequence and write reward into accumulated_feedback param attr\n",
    "        for step in range(n_steps):\n",
    "            snn_propagator.propagate(iteration_feedback=reward[-(step + 1)], iteration_idx=step)\n",
    "\n",
    "    for name, param in model.named_parameters():\n",
    "        if hasattr(param, \"accumulated_feedback\"):\n",
    "            # overwrite grad with lfp-signal (negated, since the optimizer performs a descent step)\n",
    "            param.grad = -param.accumulated_feedback\n",
    "        else:\n",
    "            # Best-effort: report the parameter instead of failing the whole step\n",
    "            print(f\"[lfp_step] parameter '{name}' has no accumulated_feedback; its .grad is left unchanged\")\n",
    "\n",
    "    # clip feedback to avoid exploding gradients\n",
    "    clip_gradients(model, True, 0.6)\n",
    "\n",
    "    # update parameters\n",
    "    optimizer.step()\n",
    "\n",
    "    # Set Model back to eval mode\n",
    "    snn_propagator.reset()  # deletes stored feedback [ ] unify this somehow\n",
    "    model.reset()  # necessary to free the internal state of the model\n",
    "    model.eval()\n",
    "\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(epochs):\n",
    "    # One LFP update per training batch; tqdm draws the progress bar over the loader\n",
    "    for inputs, labels in tqdm(training_loader):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "        lfp_step(inputs, labels, n_steps=n_steps)\n",
    "\n",
    "    # Evaluate and print performance after every epoch\n",
    "    eval_stats_train = eval_model(training_loader, n_steps=n_steps)\n",
    "    eval_stats_val = eval_model(validation_loader, n_steps=n_steps)\n",
    "    epoch_summary = (\n",
    "        float(np.mean(eval_stats_train[\"reward\"])),\n",
    "        float(eval_stats_train[\"accuracy\"]),\n",
    "        float(np.mean(eval_stats_val[\"reward\"])),\n",
    "        float(eval_stats_val[\"accuracy\"]),\n",
    "    )\n",
    "    print(\n",
    "        \"Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}\".format(\n",
    "            epoch + 1,\n",
    "            epochs,\n",
    "            *epoch_summary,\n",
    "        )\n",
    "    )\n",
    "\n",
    "# training takes approx. 5 min"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
