{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f28edf0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import time\n",
    "\n",
    "from FIT_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b607667f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get cpu or gpu device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using {device} device\")\n",
    "\n",
    "# Seed the RNGs: the Rademacher probes are drawn with torch and the DataLoader\n",
    "# shuffles with torch, so fixing the seed makes the convergence statistics\n",
    "# reproducible across kernel restarts.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49e560fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Grab the required dataset - can be changed as required:\n",
    "\n",
    "def get_imnet_loaders(train_batch_size=200, imsize=224, root='path/to/dataset', num_workers=4, shuffle=True):\n",
    "    ''' Build a DataLoader over an ImageFolder dataset with ImageNet-style preprocessing.\n",
    "    Args:\n",
    "        train_batch_size - number of images per batch\n",
    "        imsize - images are resized (shorter side) and center-cropped to imsize x imsize\n",
    "        root - ImageFolder root directory; the default is a placeholder to replace\n",
    "        num_workers - number of DataLoader worker processes\n",
    "        shuffle - reshuffle the samples on every pass over the loader\n",
    "    Returns:\n",
    "        train_loader - DataLoader yielding (images, labels) batches\n",
    "    '''\n",
    "    # Per-channel normalisation with the usual ImageNet mean/std.\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.Resize(imsize),\n",
    "        transforms.CenterCrop(imsize),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "    dataset = torchvision.datasets.ImageFolder(root=root, transform=preprocess)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        dataset,\n",
    "        batch_size=train_batch_size,\n",
    "        shuffle=shuffle,\n",
    "        num_workers=num_workers)\n",
    "\n",
    "    return train_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edc53c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmarking_Hess(model, device, params, criterion, data_loader, iterations, datapoints_per_iteration):\n",
    "    ''' Used to generate the convergence statistics for the Hessian trace (Hutchinson estimator).\n",
    "\n",
    "    Each estimate draws one Rademacher probe v, averages the Hessian-vector products Hv over\n",
    "    enough batches to cover datapoints_per_iteration samples (re-iterating data_loader when it\n",
    "    runs out), and records v^T H v separately for every parameter tensor.\n",
    "\n",
    "    Args:\n",
    "        model - network; its parameters with a truthy `collect` attribute are differentiated\n",
    "        device - device the inputs and probe vectors are moved to\n",
    "        params - list of accumulated model parameters to generate placeholders; must match the\n",
    "                 `collect` parameters of model in order and shape\n",
    "        criterion - loss function used to compute the Hessian\n",
    "        data_loader - iterable of (inputs, labels) batches, re-iterated as often as needed\n",
    "        iterations - total number of Estimator iterations to sum over (must be >= 1)\n",
    "        datapoints_per_iteration - minimum number of samples averaged into each estimate\n",
    "    Returns:\n",
    "        H_average - computed Hessian trace (mean over all estimates, one entry per parameter tensor)\n",
    "        estimator_accumulation - accumulation of each individual hutchinson estimator\n",
    "        estimator_mean_accumulation - accumulation of the Hutchinson trace estimator over iterations\n",
    "    Raises:\n",
    "        ValueError - if iterations < 1 or data_loader yields no batches\n",
    "    '''\n",
    "    if iterations < 1:\n",
    "        raise ValueError('iterations must be >= 1')\n",
    "\n",
    "    def rademacher():\n",
    "        # Random +/-1 probe with the same shapes as params.\n",
    "        v = [torch.randint_like(p, high=2, device=device) for p in params]\n",
    "        for v_i in v:\n",
    "            v_i[v_i == 0] = -1\n",
    "        return v\n",
    "\n",
    "    def zeros():\n",
    "        return [torch.zeros(p.size(), device=device) for p in params]\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # The set of differentiated parameters does not change between batches.\n",
    "    paramsH = [p for p in model.parameters() if p.collect]\n",
    "\n",
    "    estimator_accumulation = []\n",
    "    estimator_mean_accumulation = []\n",
    "\n",
    "    iteration = 0\n",
    "    iteration_batch = 0\n",
    "    datapoints_seen = 0\n",
    "\n",
    "    v = rademacher()\n",
    "    # Allocated once, outside the epoch loop: an estimate that straddles the end of\n",
    "    # data_loader keeps its partial sum, consistent with iteration_batch.\n",
    "    THv = zeros()\n",
    "\n",
    "    while iteration < iterations:\n",
    "        batches_this_pass = 0\n",
    "\n",
    "        for data in data_loader:\n",
    "            batches_this_pass += 1\n",
    "            inputs, labels = data[0].to(device), data[1].to(device)\n",
    "\n",
    "            loss = criterion(model(inputs), labels)\n",
    "\n",
    "            # Double backprop via autograd.grad rather than loss.backward(create_graph=True),\n",
    "            # which stores graph-carrying gradients in .grad and creates a reference cycle.\n",
    "            grads = torch.autograd.grad(loss, paramsH, create_graph=True)\n",
    "            Hv = torch.autograd.grad(grads, paramsH, grad_outputs=v)\n",
    "\n",
    "            THv = [THv_ + Hv_ for THv_, Hv_ in zip(THv, Hv)]\n",
    "            iteration_batch += 1\n",
    "            datapoints_seen += inputs.size(0)\n",
    "\n",
    "            if datapoints_seen >= datapoints_per_iteration:\n",
    "                THv = [THv_ / float(iteration_batch) for THv_ in THv]  # normalise to the number of batches\n",
    "\n",
    "                # Hutchinson estimator v^T H v, one value per parameter tensor.\n",
    "                vHv = torch.stack([torch.sum(x * y) for x, y in zip(THv, v)])\n",
    "                estimator_accumulation.append(vHv.detach().cpu().numpy())\n",
    "\n",
    "                H_average = np.mean(estimator_accumulation, axis=0)\n",
    "                estimator_mean_accumulation.append(H_average)\n",
    "\n",
    "                print(f'Iteration {iteration}')\n",
    "\n",
    "                # Reset the hutchinson estimator variables for the next estimate.\n",
    "                v = rademacher()\n",
    "                THv = zeros()\n",
    "                iteration_batch = 0\n",
    "                datapoints_seen = 0\n",
    "                iteration += 1\n",
    "\n",
    "                if iteration >= iterations:\n",
    "                    break\n",
    "\n",
    "        if batches_this_pass == 0:\n",
    "            raise ValueError('data_loader yielded no batches')\n",
    "\n",
    "    return H_average, estimator_accumulation, estimator_mean_accumulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516177a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmarking_Fish(model, device, params, criterion, data_loader, iterations):\n",
    "    ''' Used to generate the convergence statistics for the Empirical Fisher\n",
    "\n",
    "    Every batch contributes batch_size * g**2 (g = gradient of the batch loss) for each\n",
    "    parameter tensor; the running mean over all batches processed so far is recorded after\n",
    "    each batch. data_loader is re-iterated until `iterations` batches have been used.\n",
    "\n",
    "    Args:\n",
    "        model - network; its parameters with a truthy `collect` attribute are differentiated\n",
    "        device - device the inputs and accumulators are placed on\n",
    "        params - list of accumulated model parameters to generate placeholders; must match the\n",
    "                 `collect` parameters of model in order and shape\n",
    "        criterion - loss function used to compute the EF\n",
    "        data_loader - iterable of (inputs, labels) batches\n",
    "        iterations - total number of Estimator iterations to sum over (one per batch, must be >= 1)\n",
    "    Returns:\n",
    "        F_average - computed EF trace (running mean after the last batch, one entry per parameter tensor)\n",
    "        estimator_accumulation - accumulation of each individual EF estimator\n",
    "        estimator_mean_accumulation - accumulation of the EF trace estimator over iterations\n",
    "    Raises:\n",
    "        ValueError - if iterations < 1 or data_loader yields no batches\n",
    "    '''\n",
    "    if iterations < 1:\n",
    "        raise ValueError('iterations must be >= 1')\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # The set of differentiated parameters does not change between batches.\n",
    "    paramsH = [p for p in model.parameters() if p.collect]\n",
    "\n",
    "    estimator_accumulation = []\n",
    "    estimator_mean_accumulation = []\n",
    "\n",
    "    iteration = 0\n",
    "    total_processed = 0\n",
    "    # Allocated once, outside the epoch loop, so the running sum stays consistent with\n",
    "    # total_processed when data_loader is re-iterated.\n",
    "    TFv = [torch.zeros(p.size(), device=device) for p in params]\n",
    "\n",
    "    while iteration < iterations:\n",
    "        batches_this_pass = 0\n",
    "\n",
    "        for data in data_loader:\n",
    "            batches_this_pass += 1\n",
    "            inputs, labels = data[0].to(device), data[1].to(device)\n",
    "            batch_size = inputs.size(0)\n",
    "\n",
    "            loss = criterion(model(inputs), labels)\n",
    "\n",
    "            grads = torch.autograd.grad(loss, paramsH, allow_unused=True)\n",
    "            # Parameters that do not influence the loss get an explicit zero gradient.\n",
    "            grads = [torch.zeros_like(p) if g is None else g for g, p in zip(grads, paramsH)]\n",
    "\n",
    "            # Fisher Accumulation\n",
    "            G2 = [batch_size * g * g for g in grads]\n",
    "            TFv = [TFv_ + G2_ for TFv_, G2_ in zip(TFv, G2)]\n",
    "            total_processed += 1\n",
    "\n",
    "            # Per-batch estimate, one value per parameter tensor.\n",
    "            indiv = torch.stack([torch.sum(x) for x in G2])\n",
    "            estimator_accumulation.append(indiv.detach().cpu().numpy())\n",
    "\n",
    "            # Running mean over every batch processed so far.\n",
    "            F_average = torch.stack([torch.sum(TFv_ / float(total_processed)) for TFv_ in TFv]).detach().cpu().numpy()\n",
    "            estimator_mean_accumulation.append(F_average)\n",
    "\n",
    "            print(f'Iteration {iteration}')\n",
    "\n",
    "            iteration += 1\n",
    "            if iteration >= iterations:\n",
    "                break\n",
    "\n",
    "        if batches_this_pass == 0:\n",
    "            raise ValueError('data_loader yielded no batches')\n",
    "\n",
    "    return F_average, estimator_accumulation, estimator_mean_accumulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30de90b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "## experiment setup\n",
    "batch_size = 32\n",
    "# Single source of truth for the input resolution: the loader crops to it and the\n",
    "# same value goes into the shape tuple passed to FIT (presumably its (C, H, W)\n",
    "# input shape - confirm in FIT_utils).\n",
    "imsize = 224\n",
    "\n",
    "train_loader = get_imnet_loaders(batch_size, imsize)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "model = models.resnet18(pretrained=True)\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "fit_computer = FIT(model, device, (3, imsize, imsize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e2e743a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Hutchinson estimates of the per-tensor Hessian trace, one batch per estimate.\n",
    "H, Ha, Hma = benchmarking_Hess(\n",
    "    model,\n",
    "    device,\n",
    "    fit_computer.params,\n",
    "    criterion,\n",
    "    train_loader,\n",
    "    iterations=200,\n",
    "    datapoints_per_iteration=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "367f78e6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Empirical Fisher trace per tensor, one batch per iteration.\n",
    "F, Fa, Fma = benchmarking_Fish(\n",
    "    model,\n",
    "    device,\n",
    "    fit_computer.params,\n",
    "    criterion,\n",
    "    train_loader,\n",
    "    iterations=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9cfc202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spread of the individual Hutchinson estimates relative to their mean, pooled over all tensors.\n",
    "Ha_arr = np.asarray(Ha)\n",
    "Ha_center = Ha_arr.mean(axis=0)\n",
    "mean_normed = (Ha_arr - Ha_center) / Ha_center\n",
    "print(f'Hessian variance: {np.var(mean_normed)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f613e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spread of the per-batch EF estimates relative to their mean, pooled over all tensors.\n",
    "Fa_arr = np.asarray(Fa)\n",
    "Fa_center = Fa_arr.mean(axis=0)\n",
    "mean_normed = (Fa_arr - Fa_center) / Fa_center\n",
    "print(f'EF variance: {np.var(mean_normed)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daf82109",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per parameter tensor; dividing by fit_computer.param_nums is assumed to\n",
    "# normalise each trace by that tensor's parameter count (confirm in FIT_utils).\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('Trace')\n",
    "ax.plot(H / fit_computer.param_nums, label='H')\n",
    "ax.plot(F / fit_computer.param_nums, label='EF')\n",
    "ax.set_xlabel('parameter tensor index')\n",
    "ax.set_ylabel('trace / parameter count')\n",
    "ax.legend()\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e07ef974",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per parameter tensor: running mean of the Hutchinson estimates.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('Hessian Convergence')\n",
    "ax.plot(Hma)\n",
    "ax.set_xlabel('estimate')\n",
    "ax.set_ylabel('running mean of trace estimate')\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94a5eda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per parameter tensor: running mean of the per-batch EF trace.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title('EF Convergence')\n",
    "ax.plot(Fma)\n",
    "ax.set_xlabel('batch')\n",
    "ax.set_ylabel('running mean of EF trace')\n",
    "ax.set_yscale('log')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
