{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This Notebook provides a minimal example for using LFP to train a simple LeNet on MNIST.\n",
    "\n",
    "For more complex examples, refer to the experiment notebooks in ./nbs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lweber/.cache/pypoetry/virtualenvs/lfp-KukTaqIE-py3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as tnn\n",
    "import torcheval.metrics\n",
    "import torchvision.datasets as tvisiondata\n",
    "import torchvision.transforms as T\n",
    "from tqdm import tqdm\n",
    "\n",
    "from lfprop.propagation import (\n",
    "    propagator_lxt as propagator,\n",
    ")  # LFP propagator. Alternatively, use propagator_zennit\n",
    "from lfprop.rewards import reward_functions as rewards  # Reward Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "savepath = \"./minimal-example-data\"\n",
    "os.makedirs(savepath, exist_ok=True)\n",
    "\n",
    "batch_size = 128\n",
    "n_channels = 1\n",
    "n_outputs = 10\n",
    "\n",
    "lr = 0.1\n",
    "momentum = 0.9\n",
    "epochs = 10\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])\n",
    "training_data = tvisiondata.MNIST(\n",
    "    root=savepath,\n",
    "    transform=transform,\n",
    "    download=True,\n",
    "    train=True,\n",
    ")\n",
    "\n",
    "validation_data = tvisiondata.MNIST(\n",
    "    root=savepath,\n",
    "    transform=transform,\n",
    "    download=True,\n",
    "    train=False,\n",
    ")\n",
    "\n",
    "training_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)\n",
    "validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LeNet(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (3): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))\n",
       "    (4): ReLU()\n",
       "    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=256, out_features=120, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Dropout(p=0.5, inplace=False)\n",
       "    (3): Linear(in_features=120, out_features=84, bias=True)\n",
       "    (4): ReLU()\n",
       "    (5): Dropout(p=0.5, inplace=False)\n",
       "    (6): Linear(in_features=84, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class LeNet(tnn.Module):\n",
    "    \"\"\"\n",
    "    Small LeNet\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):\n",
    "        super().__init__()\n",
    "\n",
    "        # Feature extractor\n",
    "        self.features = tnn.Sequential(\n",
    "            tnn.Conv2d(n_channels, 16, 5),\n",
    "            activation(),\n",
    "            tnn.MaxPool2d(2, 2),\n",
    "            tnn.Conv2d(16, 16, 5),\n",
    "            activation(),\n",
    "            tnn.MaxPool2d(2, 2),\n",
    "        )\n",
    "\n",
    "        # Classifier\n",
    "        self.classifier = tnn.Sequential(\n",
    "            tnn.Linear(256 if n_channels == 1 else 400, 120),\n",
    "            activation(),\n",
    "            tnn.Dropout(),\n",
    "            tnn.Linear(120, 84),\n",
    "            activation(),\n",
    "            tnn.Dropout(),\n",
    "            tnn.Linear(84, n_outputs),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        forwards input through network\n",
    "        \"\"\"\n",
    "\n",
    "        # Forward through network\n",
    "        x = self.features(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.classifier(x)\n",
    "\n",
    "        # Return output\n",
    "        return x\n",
    "\n",
    "\n",
    "model = LeNet(\n",
    "    n_channels=n_channels,\n",
    "    n_outputs=n_outputs,\n",
    "    activation=torch.nn.ReLU,\n",
    ")\n",
    "\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up LFP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the LFP Composite (cf. \"composites\" in zennit or lxt).\n",
    "# This call is the same whether the lxt or zennit backend is used (propagator_lxt and propagator_zennit).\n",
    "# Currently, only LFP-Epsilon is implemented. More composites may be added in the future.\n",
    "propagation_composite = propagator.LFPEpsilonComposite()\n",
    "\n",
    "# Initialize the Reward Function.\n",
    "# Here we use the Reward Function suggested in the LFP-Paper, but check out other reward functions in ./lfp/rewards/reward_functions.py\n",
    "reward_func = rewards.SoftmaxLossReward(device)\n",
    "\n",
    "# LFP writes its updates into the .grad attribute of the model parameters, and can thus utilize standard torch optimizers\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Up Simple Evaluation using torcheval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(loader):\n",
    "    \"\"\"\n",
    "    Evaluates the model on a single dataset\n",
    "    \"\"\"\n",
    "    eval_metrics = {\n",
    "        \"reward\": torcheval.metrics.Mean(device=device),\n",
    "        \"accuracy\": torcheval.metrics.MulticlassAccuracy(average=\"micro\", num_classes=10, k=1, device=device),\n",
    "    }\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # Iterate over Data Loader\n",
    "    for index, (inputs, labels) in enumerate(loader):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = torch.tensor(labels).to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Get model predictions\n",
    "            outputs = model(inputs)\n",
    "\n",
    "        with torch.set_grad_enabled(True):\n",
    "            # Get rewards\n",
    "            reward = reward_func(outputs, labels)\n",
    "\n",
    "        for k, v in eval_metrics.items():\n",
    "            if k == \"reward\":\n",
    "                eval_metrics[k].update(reward)\n",
    "            else:\n",
    "                eval_metrics[k].update(outputs, labels)\n",
    "\n",
    "    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}\n",
    "\n",
    "    # Return evaluation\n",
    "    return return_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/469 [00:00<?, ?it/s]/tmp/ipykernel_126343/2783185813.py:51: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  labels = torch.tensor(labels).to(device)\n",
      "/home/lweber/.cache/pypoetry/virtualenvs/lfp-KukTaqIE-py3.11/lib/python3.11/site-packages/lxt/core.py:345: UserWarning: This functionality is not yet fully tested. Please check the model after removing the composite.\n",
      "  warn(\"This functionality is not yet fully tested. Please check the model after removing the composite.\")\n",
      "100%|██████████| 469/469 [00:18<00:00, 25.01it/s]\n",
      "/tmp/ipykernel_126343/1047920224.py:17: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  labels = torch.tensor(labels).to(device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10: (Train Reward) -0.00; (Train Accuracy) 0.94; (Val Reward) 0.00; (Val Accuracy) 0.95\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 25.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 25.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/10: (Train Reward) 0.00; (Train Accuracy) 0.96; (Val Reward) 0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 25.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 25.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/10: (Train Reward) -0.00; (Train Accuracy) 0.95; (Val Reward) 0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 24.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:18<00:00, 25.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:17<00:00, 26.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/10: (Train Reward) -0.00; (Train Accuracy) 0.95; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:17<00:00, 26.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 469/469 [00:17<00:00, 27.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) 0.00; (Val Accuracy) 0.96\n"
     ]
    }
   ],
   "source": [
    "def lfp_step(inputs, labels):\n",
    "    \"\"\"\n",
    "    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.\n",
    "    \"\"\"\n",
    "    # Set Model to training mode\n",
    "    model.train()\n",
    "\n",
    "    with torch.enable_grad():\n",
    "        # Zero Optimizer\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # This applies LFP Hooks/Functions (which depends on whether lxt or zennit backend is used)\n",
    "        with propagation_composite.context(model) as modified:\n",
    "            inputs = inputs.detach().requires_grad_(True)\n",
    "            outputs = modified(inputs)\n",
    "\n",
    "            # Calculate reward\n",
    "            # Do like this to avoid tensors being kept in memory\n",
    "            reward = torch.from_numpy(reward_func(outputs, labels).detach().cpu().numpy()).to(device)\n",
    "\n",
    "            # Calculate LFP and write into .feedback attribute of parameters\n",
    "            torch.autograd.grad((outputs,), (inputs,), grad_outputs=(reward,), retain_graph=False)[0]\n",
    "\n",
    "            # Write LFP Values into .grad attributes. Note the negative sign: LFP requires maximization instead of minimization like gradient descent\n",
    "            for name, param in model.named_parameters():\n",
    "                param.grad = -param.feedback\n",
    "\n",
    "            # Update Clipping. Training may become unstable otherwise, especially in small models with large learning rates.\n",
    "            # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)\n",
    "\n",
    "            # Optimization step\n",
    "            optimizer.step()\n",
    "\n",
    "    # Set Model back to eval mode\n",
    "    model.eval()\n",
    "\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(epochs):\n",
    "    with tqdm(total=len(training_loader)) as pbar:\n",
    "        # Iterate over Data Loader\n",
    "        for index, (inputs, labels) in enumerate(training_loader):\n",
    "            inputs = inputs.to(device)\n",
    "            labels = torch.tensor(labels).to(device)\n",
    "\n",
    "            # Perform Update Step\n",
    "            lfp_step(inputs, labels)\n",
    "\n",
    "            # Update Progress Bar\n",
    "            pbar.update(1)\n",
    "\n",
    "    # Evaluate and print performance after every epoch\n",
    "    eval_stats_train = eval_model(training_loader)\n",
    "    eval_stats_val = eval_model(validation_loader)\n",
    "    print(\n",
    "        \"Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}\".format(\n",
    "            epoch + 1,\n",
    "            epochs,\n",
    "            float(np.mean(eval_stats_train[\"reward\"])),\n",
    "            float(eval_stats_train[\"accuracy\"]),\n",
    "            float(np.mean(eval_stats_val[\"reward\"])),\n",
    "            float(eval_stats_val[\"accuracy\"]),\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lfp-KukTaqIE-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
