{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Subset, DataLoader\n",
    "\n",
    "from laplace import Laplace\n",
    "\n",
    "from source.constants import RESULTS_PATH\n",
    "from source.networks.resnet import get_resnet18\n",
    "from source.data.cifar10_c import get_cifar10_c, corruptions\n",
    "from utils import load_train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "resnet18\n"
     ]
    }
   ],
   "source": [
    "# `seed` selects the run directory and also seeds the global RNGs below.\n",
    "seed = 42\n",
    "\n",
    "# Seed torch and numpy so the shuffled training loader and the Laplace\n",
    "# posterior samples drawn later are reproducible after a kernel restart.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "n_class = 10\n",
    "models = [\"resnet18\"]\n",
    "\n",
    "model = models[0]                   # select model\n",
    "\n",
    "# Prefer the first GPU, but fall back to the CPU so the notebook still runs\n",
    "# on machines without CUDA.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "batch_size = 256\n",
    "num_workers = 4\n",
    "\n",
    "# Predictions kept per network: the network's own prediction plus\n",
    "# (n_models - 1) Laplace posterior samples.\n",
    "n_models = 10\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models: 5it [00:14,  2.83s/it]\n"
     ]
    }
   ],
   "source": [
    "# Run directory of the Laplace experiment; corruption results go in a sub-folder.\n",
    "path = os.path.join(RESULTS_PATH, f\"cifar10_{model}_seed{seed}_laplace\")\n",
    "\n",
    "results_path = os.path.join(path, \"corruptions\")\n",
    "os.makedirs(results_path, exist_ok=True)\n",
    "\n",
    "# Validate the architecture once, up front. Without this an unknown name would\n",
    "# surface as a NameError (or silently reuse a stale `network`) inside the loop.\n",
    "if model != \"resnet18\":\n",
    "    raise ValueError(f\"Unsupported model: {model!r}\")\n",
    "\n",
    "# Sort the checkpoints so the order of the loaded networks is deterministic.\n",
    "models_dir = os.path.join(path, \"models\")\n",
    "model_files = sorted(glob.glob(os.path.join(models_dir, \"*.pt\")))\n",
    "if not model_files:\n",
    "    # Fail here rather than much later, when the empty result lists are stacked.\n",
    "    raise FileNotFoundError(f\"No '*.pt' checkpoints found in {models_dir}\")\n",
    "\n",
    "# load networks\n",
    "networks = list()\n",
    "# tqdm wraps the list itself (not an enumerate generator) so the bar knows the total.\n",
    "for model_file in tqdm(model_files, desc=\"Loading models\"):\n",
    "    # A fresh network is built for every checkpoint, so no defensive copy is needed.\n",
    "    network = get_resnet18(num_classes=n_class)\n",
    "    network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "    network.to(device)\n",
    "    networks.append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "41667 8333\n"
     ]
    }
   ],
   "source": [
    "# Full CIFAR-10 training set; the second return value is not used in this notebook.\n",
    "full_train, _ = load_train_dataset(\"cifar10\")\n",
    "\n",
    "# `val_inds.pt` is read from the run directory (presumably the validation split\n",
    "# used during training -- confirm); every other index is treated as training data.\n",
    "val_inds = torch.load(os.path.join(path, \"val_inds.pt\"))\n",
    "all_inds = np.arange(len(full_train))\n",
    "train_inds = np.delete(all_inds, val_inds)\n",
    "\n",
    "print(len(train_inds), len(val_inds))\n",
    "\n",
    "# Only the training subset is wrapped in a loader here.\n",
    "train_ds = Subset(full_train, indices=train_inds)\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The corruption loop below covers 15 corruption indices at severities 1..5.\n",
    "n_corruptions, n_severities = 15, 5\n",
    "\n",
    "# One slot per (corruption, severity) pair, indexed by `cs`; each slot collects\n",
    "# one prediction tensor per network. Pre-allocating avoids relying on the\n",
    "# append order to create the slots.\n",
    "ds_probits = [list() for _ in range(n_corruptions * n_severities)]\n",
    "\n",
    "for n, network in enumerate(networks):\n",
    "    print(\"Laplace Approximation for model\", n)\n",
    "\n",
    "    # NOTE(review): train mode is kept from the original code; confirm whether\n",
    "    # `la.fit` depends on it or switches the model to eval mode itself.\n",
    "    network.train()\n",
    "    # define laplace approximation\n",
    "    la = Laplace(\n",
    "        network,\n",
    "        likelihood='classification',\n",
    "        subset_of_weights='last_layer',\n",
    "        hessian_structure='kron',\n",
    "    )\n",
    "\n",
    "    la.fit(train_loader)\n",
    "    la.optimize_prior_precision(method=\"marglik\")\n",
    "    print(la.prior_precision, torch.log(la.prior_precision))\n",
    "\n",
    "    # Switch both the raw network and the wrapped Laplace model to eval mode\n",
    "    # before predicting on the corrupted test sets.\n",
    "    network.eval()\n",
    "    la.model.eval()\n",
    "\n",
    "    for c in tqdm(range(n_corruptions)):\n",
    "        for s in range(1, n_severities + 1):\n",
    "            # Flat index of the (corruption, severity) pair.\n",
    "            cs = c * n_severities + (s - 1)\n",
    "\n",
    "            dataset = get_cifar10_c(corruption=c, severity=s)\n",
    "            test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "\n",
    "            # predict\n",
    "            probits_ = list()\n",
    "            for x, _ in test_loader:\n",
    "                x = x.to(device)\n",
    "                # Pure inference: skip building an autograd graph for the network's\n",
    "                # own forward pass. The singleton axis at dim 1 lets this prediction\n",
    "                # be concatenated with the posterior samples along that axis.\n",
    "                with torch.no_grad():\n",
    "                    probs_model = torch.softmax(network(x), dim=1).cpu().unsqueeze(1)\n",
    "                # n_models - 1 posterior samples; the first two axes are swapped\n",
    "                # (presumably from (samples, batch, classes) -- confirm).\n",
    "                probs = la._nn_predictive_samples(x, n_samples=n_models - 1).permute(1,0,2).cpu()\n",
    "                probits_.append(torch.cat([probs_model, probs], dim=1))\n",
    "            # Concatenate all batches of this dataset and file the result under its slot.\n",
    "            ds_probits[cs].append(torch.cat(probits_, dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Walk the flat (corruption, severity) index and write one file per pair.\n",
    "for cs in range(15 * 5):\n",
    "    c, s = divmod(cs, 5)\n",
    "    s += 1  # severities are 1-based in the file names\n",
    "    dataset_name = f\"{corruptions[c]}_{s}\"\n",
    "    # Stack the per-network tensors along a new axis at dim 1.\n",
    "    probits = torch.stack(ds_probits[cs], dim=1)\n",
    "    out_file = os.path.join(results_path, f\"{dataset_name}_probits.pt\")\n",
    "    # Stored in half precision.\n",
    "    torch.save(probits.to(torch.float16), out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
