{
 "cells": [
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 403: Forbidden\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9912422/9912422 [00:01<00:00, 9127065.96it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 403: Forbidden\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 28881/28881 [00:00<00:00, 167985.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 403: Forbidden\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1648877/1648877 [00:00<00:00, 3238324.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 403: Forbidden\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4542/4542 [00:00<00:00, 4827807.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from torch import rand\n",
    "from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential\n",
    "\n",
    "from backpack import backpack, extend\n",
    "from backpack.extensions import (\n",
    "    GGNMP,\n",
    "    HMP,\n",
    "    KFAC,\n",
    "    KFLR,\n",
    "    KFRA,\n",
    "    PCHMP,\n",
    "    BatchDiagGGNExact,\n",
    "    BatchDiagGGNMC,\n",
    "    BatchDiagHessian,\n",
    "    BatchGrad,\n",
    "    BatchL2Grad,\n",
    "    DiagGGNExact,\n",
    "    DiagGGNMC,\n",
    "    DiagHessian,\n",
    "    SqrtGGNExact,\n",
    "    SqrtGGNMC,\n",
    "    SumGradSquared,\n",
    "    Variance,\n",
    ")\n",
    "from backpack.utils.examples import load_one_batch_mnist\n",
    "\n",
    "X, y = load_one_batch_mnist(batch_size=512)\n",
    "\n",
    "model = Sequential(Flatten(), Linear(784, 10))\n",
    "lossfunc = CrossEntropyLoss()\n",
    "\n",
    "model = extend(model)\n",
    "lossfunc = extend(lossfunc)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T18:50:10.057651Z",
     "start_time": "2024-07-18T18:50:01.667943Z"
    }
   },
   "id": "97444720fe980ccc",
   "execution_count": 34
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.functional as F\n",
    "import torch"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "195cc19e34fb07e2"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def hessian(x, logits):\n",
    "    batch_size, d = x.shape  # Shape: [batch_size, d]\n",
    "    num_classes = logits.shape[1]  # Number of classes\n",
    "    dC = num_classes * d  # Total number of parameters in the flattened gradient\n",
    "    p = F.softmax(logits, dim=1)  # Shape: [batch_size, num_classes]\n",
    "\n",
    "    # Compute p_k(1-p_k) for diagonal blocks and -p_k*p_l for off-diagonal blocks\n",
    "    # Diagonal part\n",
    "    p_diag = p * (1 - p)  # Shape: [batch_size, num_classes]\n",
    "    # Off-diagonal part\n",
    "    p_off_diag = -p.unsqueeze(2) * p.unsqueeze(1)  # Shape: [batch_size, num_classes, num_classes]\n",
    "\n",
    "    # Fill the diagonal part in off-diagonal tensor\n",
    "    indices = torch.arange(num_classes)\n",
    "    p_off_diag[:, indices, indices] = p_diag\n",
    "    # Outer product of x\n",
    "    X_outer = torch.einsum('bi,bj->bij', x, x)  # Shape: [batch_size, d, d]\n",
    "\n",
    "    H2 = torch.einsum('bkl,bij->bklij', p_off_diag, X_outer)\n",
    "    H2 = H2.sum(0).reshape(dC, dC)  # Shape: [dC, dC]\n",
    "\n",
    "    H2 /= batch_size\n",
    "    # breakpoint()\n",
    "    # H2 /= dC\n",
    "    H2 /= num_classes\n",
    "    return H2"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:13:19.549293Z",
     "start_time": "2024-07-18T23:13:19.545201Z"
    }
   },
   "id": "9021e26366c02fd5",
   "execution_count": 126
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def hessian_diagonal(x, logits):\n",
    "    batch_size, d = x.shape  # Shape: [batch_size, d]\n",
    "    num_classes = logits.shape[1]  # Number of classes\n",
    "    dC = num_classes * d  # Total number of parameters in the flattened gradient\n",
    "    p = F.softmax(logits, dim=1)  # Shape: [batch_size, num_classes]\n",
    "\n",
    "    # Compute p_k(1-p_k) for diagonal blocks\n",
    "    p_diag = p * (1 - p)  # Shape: [batch_size, num_classes]\n",
    "\n",
    "    # Outer product of x, but only considering the diagonal part\n",
    "    x_squared = x ** 2  # Shape: [batch_size, d]\n",
    "\n",
    "    # Compute the diagonal of the Hessian matrix\n",
    "    H2_diag = torch.einsum('bk,bi->bki', p_diag, x_squared)  # Shape: [batch_size, num_classes, d]\n",
    "\n",
    "    # Sum across the batch dimension\n",
    "    H2_diag = H2_diag.sum(0)  # Shape: [num_classes, d]\n",
    "\n",
    "    # Normalize the result\n",
    "    H2_diag /= batch_size\n",
    "    H2_diag /= num_classes\n",
    "\n",
    "    # Reshape the result to match the diagonal of the Hessian matrix\n",
    "    H2_diag_flat = H2_diag.flatten()  # Shape: [dC]\n",
    "\n",
    "    return H2_diag_flat\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:13:19.861589Z",
     "start_time": "2024-07-18T23:13:19.858157Z"
    }
   },
   "id": "86641559eb7b5882",
   "execution_count": 127
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def hessian_diag_backpack(x, y, logits, model, loss_fn):\n",
    "    model.zero_grad()\n",
    "    loss = loss_fn(logits, y)\n",
    "    \n",
    "    \n",
    "    with backpack(DiagHessian()):\n",
    "        loss.backward()\n",
    "    # loss.backward()\n",
    "    \n",
    "    hessian_diag = []\n",
    "    for param in model.parameters():\n",
    "        hessian_diag.append(param.diag_h.flatten())\n",
    "    \n",
    "    hessian_diag = torch.cat(hessian_diag)/logits.shape[1]\n",
    "    return hessian_diag"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:25:50.058915Z",
     "start_time": "2024-07-18T23:25:50.055701Z"
    }
   },
   "id": "8b01284569083dae",
   "execution_count": 149
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0454, 0.0181, 0.0749, 0.0616, 0.0376, 0.0587, 0.0660, 0.0224, 0.0712,\n",
      "        0.0523, 0.0382, 0.0547])\n"
     ]
    }
   ],
   "source": [
    "\n",
    "class Algorithm:\n",
    "    def __init__(self, classifier):\n",
    "        self.classifier = classifier\n",
    "\n",
    "    def hessian_diag_backpack(self, logits, y, model, loss_fn):\n",
    "        # Ensure that the model and loss function are extended\n",
    "\n",
    "        with backpack(DiagHessian()):\n",
    "            loss.backward()\n",
    "\n",
    "        hessian_diag = []\n",
    "        for name, param in model.named_parameters():\n",
    "            if hasattr(param, 'diag_h'):\n",
    "                hessian_diag.append(param.diag_h.flatten())\n",
    "            else:\n",
    "                raise AttributeError(f\"Parameter {name} has no attribute 'diag_h'\")\n",
    "\n",
    "        hessian_diag = torch.cat(hessian_diag) / logits.shape[1]\n",
    "        return hessian_diag\n",
    "\n",
    "# Extend each layer of the model with BackPACK and check extension\n",
    "def extend_model(model):\n",
    "    for module in model.modules():\n",
    "        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.BatchNorm1d)):\n",
    "            extend(module)\n",
    "            print(f\"Extended module: {module}\")  # Debugging print\n",
    "            # Check if the extension is applied\n",
    "            dummy_input = torch.randn(1, module.in_features) if hasattr(module, 'in_features') else None\n",
    "            if dummy_input is not None:\n",
    "                dummy_output = module(dummy_input)\n",
    "                dummy_loss = dummy_output.sum()\n",
    "                with backpack(DiagHessian()):\n",
    "                    dummy_loss.backward()\n",
    "                if hasattr(module.weight, 'diag_h'):\n",
    "                    print(f\"{module}: has 'diag_h' attribute.\")\n",
    "                else:\n",
    "                    print(f\"{module}: missing 'diag_h' attribute.\")\n",
    "    return model\n",
    "\n",
    "# Example Classifier definition\n",
    "def Classifier(in_features, out_features, is_nonlinear=False):\n",
    "    if is_nonlinear:\n",
    "        return nn.Sequential(\n",
    "            nn.Linear(in_features, in_features // 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(in_features // 2, in_features // 4),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(in_features // 4, out_features))\n",
    "    else:\n",
    "        return nn.Linear(in_features, out_features, bias=False)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# # Generate random data\n",
    "# # Test the functions\n",
    "input_dim = 3\n",
    "output_dim = 4\n",
    "batch_size = 5\n",
    "is_nonlinear = False\n",
    "# classifier = Classifier(input_dim, output_dim, is_nonlinear)\n",
    "# classifier  = nn.Linear(input_dim, output_dim, bias=False)\n",
    "classifier = Classifier(input_dim, output_dim, is_nonlinear)\n",
    "classifier = extend(extend(classifier))\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "loss_fn = extend(loss_fn)\n",
    "x = torch.randn(batch_size, input_dim)\n",
    "y = torch.randint(0, output_dim, (batch_size,))\n",
    "logits = classifier(x)\n",
    "\n",
    "\n",
    "\n",
    "# Call the function\n",
    "hessian_diag = hessian_diag_backpack(x, y, logits, classifier,loss_fn)\n",
    "print(hessian_diag)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:41:33.979156Z",
     "start_time": "2024-07-18T23:41:33.963329Z"
    }
   },
   "id": "335e1bdf25e36349",
   "execution_count": 171
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[168], line 4\u001B[0m\n\u001B[1;32m      2\u001B[0m loss \u001B[38;5;241m=\u001B[39m loss_fn(logits, y)\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m backpack(DiagHessian()):\n\u001B[0;32m----> 4\u001B[0m     \u001B[43mloss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m      6\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m name, param \u001B[38;5;129;01min\u001B[39;00m classifier\u001B[38;5;241m.\u001B[39mnamed_parameters():\n\u001B[1;32m      7\u001B[0m     \u001B[38;5;28mprint\u001B[39m(name)\n",
      "File \u001B[0;32m~/anaconda3/envs/fishr/lib/python3.9/site-packages/torch/_tensor.py:525\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    515\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    516\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    517\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    518\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    523\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    524\u001B[0m     )\n\u001B[0;32m--> 525\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    526\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    527\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/anaconda3/envs/fishr/lib/python3.9/site-packages/torch/autograd/__init__.py:267\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    262\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    264\u001B[0m \u001B[38;5;66;03m# The reason we repeat the same comment below is that\u001B[39;00m\n\u001B[1;32m    265\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    266\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 267\u001B[0m \u001B[43m_engine_run_backward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    268\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    269\u001B[0m \u001B[43m    \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    270\u001B[0m \u001B[43m    \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    271\u001B[0m \u001B[43m    \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    272\u001B[0m \u001B[43m    \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    273\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\n\u001B[1;32m    274\u001B[0m \u001B[43m    \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\n\u001B[1;32m    275\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/anaconda3/envs/fishr/lib/python3.9/site-packages/torch/autograd/graph.py:744\u001B[0m, in \u001B[0;36m_engine_run_backward\u001B[0;34m(t_outputs, *args, **kwargs)\u001B[0m\n\u001B[1;32m    742\u001B[0m     unregister_hooks \u001B[38;5;241m=\u001B[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001B[1;32m    743\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 744\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    745\u001B[0m \u001B[43m        \u001B[49m\u001B[43mt_outputs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\n\u001B[1;32m    746\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m  \u001B[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001B[39;00m\n\u001B[1;32m    747\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m    748\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m attach_logging_hooks:\n",
      "\u001B[0;31mRuntimeError\u001B[0m: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."
     ]
    }
   ],
   "source": [
    "classifier.eval()\n",
    "loss = loss_fn(logits, y)\n",
    "with backpack(DiagHessian()):\n",
    "    loss.backward()\n",
    "\n",
    "for name, param in classifier.named_parameters():\n",
    "    print(name)\n",
    "    print(param)\n",
    "    print(param.diag_h)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:37:11.584873Z",
     "start_time": "2024-07-18T23:37:11.556295Z"
    }
   },
   "id": "578ef0d338337363",
   "execution_count": 168
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full Hessian diagonal:\n",
      "tensor([ 0.0409, -0.0037, -0.0067, -0.0024, -0.0026,  0.0033,  0.0022,  0.0166,\n",
      "         0.0013, -0.0068, -0.0110,  0.0294], grad_fn=<DiagonalBackward0_copy>)\n",
      "Diagonal function output:\n",
      "tensor([0.0409, 0.0259, 0.0220, 0.0197, 0.0125, 0.0128, 0.0478, 0.0216, 0.0221,\n",
      "        0.0451, 0.0315, 0.0294], grad_fn=<ViewBackward0>)\n",
      "BackPACK Hessian diagonal output:\n",
      "tensor([0.0409, 0.0259, 0.0220, 0.0197, 0.0125, 0.0128, 0.0478, 0.0216, 0.0221,\n",
      "        0.0451, 0.0315, 0.0294, 0.0428, 0.0255, 0.0522, 0.0546])\n",
      "Diagonals match (original vs diagonal function): False\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (12) must match the size of tensor b (16) at non-singleton dimension 0",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[143], line 33\u001B[0m\n\u001B[1;32m     31\u001B[0m \u001B[38;5;28mprint\u001B[39m(H_backpack_diag)\n\u001B[1;32m     32\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDiagonals match (original vs diagonal function):\u001B[39m\u001B[38;5;124m\"\u001B[39m, torch\u001B[38;5;241m.\u001B[39mallclose(H_full_diag, H_diag))\n\u001B[0;32m---> 33\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDiagonals match (original vs BackPACK):\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mallclose\u001B[49m\u001B[43m(\u001B[49m\u001B[43mH_full_diag\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mH_backpack_diag\u001B[49m\u001B[43m)\u001B[49m)\n",
      "\u001B[0;31mRuntimeError\u001B[0m: The size of tensor a (12) must match the size of tensor b (16) at non-singleton dimension 0"
     ]
    }
   ],
   "source": [
    "# Compute Hessian diagonal using original function and BackPACK\n",
    "H_full = hessian(x, logits)\n",
    "H_diag = hessian_diagonal(x, logits)\n",
    "H_backpack_diag = hessian_diag_backpack(x, y, logits, model, lossfunc)\n",
    "\n",
    "# Extract the diagonal from the full Hessian\n",
    "H_full_diag = H_full.diag()\n",
    "\n",
    "# Print and compare the results\n",
    "print(\"Full Hessian diagonal:\")\n",
    "print(H_full_diag)\n",
    "print(\"Diagonal function output:\")\n",
    "print(H_diag)\n",
    "print(\"BackPACK Hessian diagonal output:\")\n",
    "print(H_backpack_diag)\n",
    "print(\"Diagonals match (original vs diagonal function):\", torch.allclose(H_full_diag, H_diag))\n",
    "print(\"Diagonals match (original vs BackPACK):\", torch.allclose(H_full_diag, H_backpack_diag))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-18T23:19:59.624912Z",
     "start_time": "2024-07-18T23:19:59.596780Z"
    }
   },
   "id": "e674d904275e9c3b",
   "execution_count": 143
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "tensor(0.1421, grad_fn=<LinalgVectorNormBackward0>)"
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(H_full, p='fro')\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T04:14:07.115784Z",
     "start_time": "2024-07-19T04:14:07.105431Z"
    }
   },
   "id": "617e20986e4b1253",
   "execution_count": 180
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "tensor(0.0554, grad_fn=<LinalgVectorNormBackward0>)"
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(H_full_diag, p='fro')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T04:14:08.660301Z",
     "start_time": "2024-07-19T04:14:08.657132Z"
    }
   },
   "id": "4be9cf62920fedb5",
   "execution_count": 181
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([12, 12])"
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "H_full.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T04:13:44.328741Z",
     "start_time": "2024-07-19T04:13:44.325917Z"
    }
   },
   "id": "b912557eafa907d9",
   "execution_count": 175
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([12])"
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "H_full_diag.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T04:13:57.150492Z",
     "start_time": "2024-07-19T04:13:57.147549Z"
    }
   },
   "id": "9c919395822cdbc3",
   "execution_count": 177
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "c95028e867e0c4b1"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
