{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import Subset, DataLoader, TensorDataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "from laplace import Laplace\n",
    "\n",
    "from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD\n",
    "from source.networks.resnet import get_resnet18\n",
    "from source.networks.densenet import get_densenet169\n",
    "from utils import load_train_dataset, load_test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "svhn resnet18\n"
     ]
    }
   ],
   "source": [
    "# seed of the original runs; it is part of the results directory names built below\n",
    "seed = 42\n",
    "\n",
    "# n_classes[i] is the number of classes of dataset_names[i]\n",
    "dataset_names = [\"cifar10\", \"cifar100\", \"svhn\", \"tin\", \"lsun\"]\n",
    "n_classes = [10, 100, 10, 200, 10]\n",
    "assert len(dataset_names) == len(n_classes), \"dataset_names and n_classes must be aligned\"\n",
    "models = [\"resnet18\", \"densenet169\"]\n",
    "\n",
    "dataset_name = dataset_names[2]    # select dataset\n",
    "model = models[0]                   # select model\n",
    "\n",
    "# infer number of classes from dataset\n",
    "n_class = n_classes[dataset_names.index(dataset_name)]\n",
    "\n",
    "# number of model checkpoints to copy from the original results directory\n",
    "n_models = 5\n",
    "\n",
    "# fall back to the CPU when no CUDA device is available instead of failing on the first .to(device)\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "batch_size = 256\n",
    "num_workers = 4\n",
    "\n",
    "# number of predictions per input: one from the trained weights plus n_samples - 1 posterior samples\n",
    "n_samples = 10\n",
    "\n",
    "print(dataset_name, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create laplace results directory (including its models sub-directory) next to the original one\n",
    "orig_path = os.path.join(RESULTS_PATH, f\"{dataset_name}_{model}_seed{seed}\")\n",
    "new_path = os.path.join(RESULTS_PATH, f\"{dataset_name}_{model}_seed{seed}_laplace\")\n",
    "os.makedirs(os.path.join(new_path, \"models\"), exist_ok=True)\n",
    "\n",
    "# checkpoints of the original runs, sorted once so that run_id -> file is deterministic\n",
    "orig_models_dir = os.path.join(orig_path, \"models\")\n",
    "orig_model_files = sorted(glob.glob(os.path.join(orig_models_dir, \"*.pt\")))\n",
    "if len(orig_model_files) < n_models:\n",
    "    raise FileNotFoundError(\n",
    "        f\"Expected at least {n_models} checkpoints in {orig_models_dir}, found {len(orig_model_files)}\"\n",
    "    )\n",
    "\n",
    "# copy models from original directory to new directory\n",
    "# shutil.copy is used instead of shelling out to `cp`: it handles paths with spaces or shell\n",
    "# metacharacters, is portable, and raises on failure instead of failing silently\n",
    "for run_id, model_file in enumerate(orig_model_files[:n_models]):\n",
    "    print(\"moving\", model_file)\n",
    "    shutil.copy(model_file, os.path.join(new_path, \"models\", f\"model_{run_id}.pt\"))\n",
    "shutil.copy(os.path.join(orig_path, \"val_inds.pt\"), os.path.join(new_path, \"val_inds.pt\"))\n",
    "\n",
    "# from here on, work with the copies in the laplace results directory\n",
    "model_files = glob.glob(os.path.join(new_path, \"models\", \"*.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models: 5it [00:00,  7.99it/s]\n"
     ]
    }
   ],
   "source": [
    "# load networks: one freshly constructed network per checkpoint, in sorted file order\n",
    "networks = list()\n",
    "# iterate over the sorted list directly so that tqdm knows the total number of checkpoints\n",
    "for model_file in tqdm(sorted(model_files), desc=\"Loading models\"):\n",
    "\n",
    "    if model == \"resnet18\":\n",
    "        network = get_resnet18(num_classes=n_class)\n",
    "    elif model == \"densenet169\":\n",
    "        network = get_densenet169(num_classes=n_class)\n",
    "    else:\n",
    "        raise ValueError(\"Model not implemented\")\n",
    "\n",
    "    network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "    network.to(device)\n",
    "    # a new network object is built in every iteration, so it can be stored without copying it\n",
    "    networks.append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# full training split; the second return value of load_train_dataset is not used here\n",
    "full_train, _ = load_train_dataset(dataset_name)\n",
    "\n",
    "# indices stored as val_inds.pt in the original results directory\n",
    "val_inds_file = os.path.join(orig_path, \"val_inds.pt\")\n",
    "val_inds = torch.load(val_inds_file)\n",
    "\n",
    "# every index of the training split that is not listed in val_inds\n",
    "all_inds = np.arange(len(full_train))\n",
    "train_inds = np.delete(all_inds, val_inds)\n",
    "\n",
    "print(len(train_inds), len(val_inds))\n",
    "\n",
    "# only the remaining training samples are wrapped in a loader; it is passed to la.fit later on\n",
    "train_ds = Subset(full_train, indices=train_inds)\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(la, network, test_loader):\n",
    "    \"\"\"Collect predictions of the trained network and of posterior samples for a data loader.\n",
    "\n",
    "    For every batch, the softmax output of `network` is combined with `n_samples - 1`\n",
    "    predictions obtained from `la._nn_predictive_samples`, so every input gets `n_samples`\n",
    "    predictions along dimension 1. `device` and `n_samples` are read from the notebook globals.\n",
    "\n",
    "    Args:\n",
    "        la: fitted Laplace approximation providing `_nn_predictive_samples`.\n",
    "        network: the network the Laplace approximation was fitted on.\n",
    "        test_loader: iterable yielding `(x, y)` batches.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(probits_, ys_)`: the per-batch predictions and labels, each concatenated\n",
    "        along dimension 0 and moved to the CPU.\n",
    "    \"\"\"\n",
    "    # predict\n",
    "    probits_, ys_ = list(), list()\n",
    "    # gradients are never used here; without no_grad the deterministic forward pass would\n",
    "    # build an autograd graph for every batch only to discard it again\n",
    "    with torch.no_grad():\n",
    "        for x, y in tqdm(test_loader):\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            # prediction of the trained weights, with a singleton sample dimension inserted\n",
    "            probs_model = torch.softmax(network(x), dim=1).detach().cpu().unsqueeze(1)\n",
    "            # NOTE(review): _nn_predictive_samples is a private laplace API; the permute assumes\n",
    "            # it returns (samples, batch, classes) probabilities - confirm for the installed version\n",
    "            probs = la._nn_predictive_samples(x, n_samples=n_samples - 1).permute(1, 0, 2).cpu()\n",
    "            probits_.append(torch.cat([probs_model, probs], dim=1))\n",
    "            ys_.append(y.cpu())\n",
    "    probits_ = torch.cat(probits_, dim=0)\n",
    "    ys_ = torch.cat(ys_, dim=0)\n",
    "    return probits_, ys_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adv_path = os.path.join(new_path, \"adversarial_examples\")\n",
    "\n",
    "# Discover the adversarial attacks once, before the loop over the models.\n",
    "# Files are expected to be named \"<attack>_<run_id>.pt\"; the \"<attack>_probits.pt\" files written\n",
    "# at the end of this cell contribute their attack name but no run id.\n",
    "adv_available = os.path.exists(adv_path)\n",
    "adv_ds_dict = dict()\n",
    "if adv_available:\n",
    "    for adv_file in sorted(glob.glob(os.path.join(adv_path, \"*.pt\"))):\n",
    "        # splitext removes the \".pt\" suffix; str.strip(\".pt\") would strip characters, not a suffix\n",
    "        adv_stem = os.path.splitext(os.path.basename(adv_file))[0]\n",
    "        adv_parts = adv_stem.split(\"_\")\n",
    "        adv_run_ids = adv_ds_dict.setdefault(adv_parts[0], list())\n",
    "        if \"probits\" not in adv_stem:\n",
    "            adv_run_ids.append(int(adv_parts[1]))\n",
    "\n",
    "# one list of per-model predictions for every test dataset / adversarial attack\n",
    "ds_probits = [list() for _ in dataset_names]\n",
    "atk_probits = [list() for _ in adv_ds_dict]\n",
    "\n",
    "for run_id, network in enumerate(networks):\n",
    "    print(\"Laplace Approximation for model\", run_id)\n",
    "\n",
    "    network.train()\n",
    "    # define laplace approximation\n",
    "    la = Laplace(network,\n",
    "                likelihood='classification',\n",
    "                subset_of_weights='last_layer',\n",
    "                hessian_structure='kron')\n",
    "\n",
    "    la.fit(train_loader)\n",
    "    la.optimize_prior_precision(method=\"marglik\")\n",
    "    print(la.prior_precision, torch.log(la.prior_precision))\n",
    "\n",
    "    network.eval()\n",
    "    la.model.eval()\n",
    "\n",
    "    # evaluate on test datasets\n",
    "    # the loop variable is deliberately not called dataset_name: that would overwrite the\n",
    "    # dataset selected in the configuration cell\n",
    "    for d, test_dataset_name in enumerate(dataset_names):\n",
    "        print(f\"> Evaluating on {test_dataset_name}\")\n",
    "\n",
    "        dataset = load_test_dataset(test_dataset_name)\n",
    "        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "\n",
    "        probits_, ys_ = evaluate(la, network, test_loader)\n",
    "        ds_probits[d].append(probits_)\n",
    "\n",
    "        # save labels per dataset (the loader is not shuffled, so once is enough)\n",
    "        if run_id == 0:\n",
    "            torch.save(ys_, os.path.join(new_path, f\"{test_dataset_name}_ys.pt\"))\n",
    "\n",
    "    # Evaluate adversarial examples if available\n",
    "    if adv_available:\n",
    "        for a, adv_atk in enumerate(adv_ds_dict.keys()):\n",
    "            print(f\"> Evaluating on {adv_atk}\")\n",
    "\n",
    "            # load adversarial samples and map to [0, 1]\n",
    "            images = torch.load(os.path.join(adv_path, f\"{adv_atk}_{run_id}.pt\"), map_location=\"cpu\").to(torch.float32) / 255\n",
    "\n",
    "            # apply cifar normalization\n",
    "            # NOTE(review): the CIFAR statistics are applied irrespective of the selected dataset -\n",
    "            # confirm this matches the preprocessing the models were trained with\n",
    "            images = (images - torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)) / torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)\n",
    "\n",
    "            # the labels are placeholders (all zeros); only the predictions are kept\n",
    "            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())\n",
    "            test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "\n",
    "            # evaluate\n",
    "            probits_, _ = evaluate(la, network, test_loader)\n",
    "            atk_probits[a].append(probits_)\n",
    "    else:\n",
    "        print(\"No adversarial examples found for dataset\", dataset_name)\n",
    "\n",
    "# stack the per-model predictions along a new dimension 1 and store them in half precision\n",
    "for d, test_dataset_name in enumerate(dataset_names):\n",
    "    probits = torch.stack(ds_probits[d], dim=1)\n",
    "    print(probits.shape)\n",
    "    torch.save(probits.to(torch.float16), os.path.join(new_path, f\"{test_dataset_name}_probits.pt\"))\n",
    "\n",
    "if len(atk_probits) > 0:\n",
    "    for a, adv_atk in enumerate(adv_ds_dict.keys()):\n",
    "        probits = torch.stack(atk_probits[a], dim=1)\n",
    "        print(probits.shape)\n",
    "        torch.save(probits.to(torch.float16), os.path.join(adv_path, f\"{adv_atk}_probits.pt\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
