{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow the instructions in the `README.md` at the root of the project code to install the `conda` environment `sgdchaotic`. \n",
    "Then, install pytorch-hessian-eigenthings by running in your shell\n",
    "\n",
    "```conda activate sgdchaotic```\n",
    "\n",
    "followed by\n",
    "\n",
    "```pip install --upgrade git+https://github.com/noahgolmant/pytorch-hessian-eigenthings.git@master#egg=hessian-eigenthings```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Replace <HOME> with the path to the code root folder (parent of experiments/, o2grad/ and notebooks/)\n",
    "sys.path.append('<HOME>/experiments')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from hessian_eigenthings import compute_hessian_eigenthings\n",
    "import hydra\n",
    "from hydra.core.global_hydra import GlobalHydra\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, Dataset, TensorDataset\n",
    "from torch.nn.functional import one_hot\n",
    "from torchsummary import summary\n",
    "from o2grad.backprop.o2model import O2Model\n",
    "from scipy.linalg import eigh\n",
    "\n",
    "\n",
    "from models.torch import get_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GlobalHydra.instance().clear()\n",
    "hydra.initialize('../experiments/config')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eigenthings(m, x, y):\n",
    "    m.clear_cache()\n",
    "    m.zero_grad()\n",
    "    y_hat = m.forward(x)\n",
    "    loss = m.criterion(y_hat, y)\n",
    "    loss.backward()\n",
    "    hessian = m.get_hessian()\n",
    "    eigvecs, eigvals = eigh(hessian)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparison: MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp_config = hydra.compose(config_name='model/mlp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp = get_model(mlp_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary(mlp, (16, 16), device='cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare artificial inputs and the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = torch.rand(8, 16, 16), torch.arange(0, 8) % 10\n",
    "x, y = x.to('cuda:0'), y.to('cuda:0')\n",
    "x.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "mlp = mlp.to('cuda:0')\n",
    "o2mlp = O2Model(mlp, criterion)\n",
    "o2mlp.disable_progressbar()\n",
    "o2mlp.enable_o2backprop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "get_eigenthings(o2mlp, x, y)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp_ = get_model(mlp_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_eigenthings = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TupleDataset(Dataset):\n",
    "    \n",
    "    def __init__(self, x: torch.Tensor, y: torch.Tensor):\n",
    "        assert x.shape[0] == y.shape[0]\n",
    "        self.n = x.shape[0]\n",
    "        self.x = x.reshape(self.n, -1)\n",
    "        self.y = y.reshape(self.n, -1)\n",
    "        self.x_shape = x.shape\n",
    "        self.y_shape = y.shape\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "        \n",
    "    def __getitem__(self, idx):\n",
    "        item_x = self.x[idx, :].reshape(1, *self.x_shape[1:])\n",
    "        item_y = self.y[idx, :].reshape(1, *self.y_shape[1:])\n",
    "        return (item_x, item_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _collate_fn(inputs):\n",
    "    x_in = [x[0] for x in inputs]\n",
    "    y_in = [y[1] for y in inputs]\n",
    "    x_out = torch.cat(x_in, dim=0)\n",
    "    y_out = torch.cat(y_in, dim=0).reshape(-1)\n",
    "    return x_out, y_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TupleDataset(x, y)    \n",
    "dataloader = DataLoader(dataset, batch_size=8, collate_fn=_collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "eigvals, eigvecs = compute_hessian_eigenthings(\n",
    "    mlp_, \n",
    "    dataloader,\n",
    "    nn.CrossEntropyLoss(),\n",
    "    mode='lanczos',\n",
    "    num_eigenthings=num_eigenthings\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that this result covers only 1000 eigenpairs, whereas the full spectrum of the model has 5350."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparison: CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn_config = hydra.compose(config_name='model/cnn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn = get_model(cnn_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary(cnn, (1, 16, 16), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place inputs and targets on the same explicit device ('cuda:0'); a bare\n",
    "# 'cuda' resolves to the current device, which need not be device 0.\n",
    "x, y = torch.rand(4, 1, 16, 16), torch.arange(0, 4) % 10\n",
    "x, y = x.to('cuda:0'), y.to('cuda:0')\n",
    "x.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "cnn = cnn.to('cuda:0')\n",
    "o2cnn = O2Model(cnn, criterion)\n",
    "o2cnn.disable_progressbar()\n",
    "o2cnn.enable_o2backprop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "get_eigenthings(o2cnn, x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn_ = get_model(cnn_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place inputs and targets on the same explicit device ('cuda:0'); a bare\n",
    "# 'cuda' resolves to the current device, which need not be device 0.\n",
    "x, y = torch.rand(4, 1, 16, 16), torch.arange(0, 4) % 10\n",
    "x, y = x.to('cuda:0'), y.to('cuda:0')\n",
    "x.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TupleDataset(x, y)\n",
    "dataloader = DataLoader(dataset, batch_size=4, collate_fn=_collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "eigvals, eigvecs = compute_hessian_eigenthings(\n",
    "    cnn_, \n",
    "    dataloader, \n",
    "    nn.CrossEntropyLoss(), \n",
    "    mode='lanczos',\n",
    "    num_eigenthings=1156\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('test_demos')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2840db3c517012e2cd12fb69127985a8118c878abd095bfc35895e29dad92dbb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
