{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "\n",
    "from source.utils.seeding import fix_seeds\n",
    "from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD\n",
    "from source.networks.resnet import get_resnet18_d\n",
    "from utils import load_test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tin resnet18\n"
     ]
    }
   ],
   "source": [
    "seed = 42\n",
    "\n",
    "# `n_classes` is index-aligned with `dataset_names`\n",
    "dataset_names = [\"cifar10\", \"cifar100\", \"svhn\", \"tin\", \"lsun\"]\n",
    "n_classes = [10, 100, 10, 200, 10]\n",
    "assert len(dataset_names) == len(n_classes), \"dataset_names and n_classes must be index-aligned\"\n",
    "models = [\"resnet18\"]\n",
    "\n",
    "train_dataset = dataset_names[3]    # select dataset\n",
    "model = models[0]                   # select model\n",
    "\n",
    "p_drop = 0.2                        # dropout probability\n",
    "n_models = 10                       # forward passes per network: 1 without dropout + (n_models - 1) with dropout\n",
    "\n",
    "# infer number of classes from dataset\n",
    "n_class = n_classes[dataset_names.index(train_dataset)]\n",
    "\n",
    "# fall back to the CPU when no GPU is available so the notebook still runs\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "batch_size = 2048\n",
    "\n",
    "# directory the models are loaded from and the evaluation outputs are written to\n",
    "path = os.path.join(RESULTS_PATH, f\"{train_dataset}_{model}_dropout{p_drop}_seed{seed}\")\n",
    "\n",
    "print(train_dataset, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(networks: List, ds: Dataset, n_models: int = n_models, seed: int = seed):\n",
    "    \"\"\"Collect softmax outputs of every network on `ds`, without and with dropout.\n",
    "\n",
    "    For every batch, each network is run `n_models` times: once fully in eval mode and\n",
    "    `n_models - 1` times in train mode with its batch-norm layers kept in eval mode.\n",
    "    `fix_seeds(seed)` is called before the passes of every batch. The train/eval mode\n",
    "    every submodule had before the call is restored afterwards.\n",
    "\n",
    "    Uses the notebook-level `batch_size` and `device`. The defaults of `n_models` and\n",
    "    `seed` are bound to the notebook-level values at the time this cell is executed.\n",
    "\n",
    "    Args:\n",
    "        networks: non-empty list of modules to evaluate.\n",
    "        ds: dataset yielding `(x, y)` pairs.\n",
    "        n_models: number of forward passes per network and batch (must be >= 1).\n",
    "        seed: value passed to `fix_seeds` before the passes of each batch.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(probits, ys)`. `probits` stacks the per-pass softmax outputs along dim 1\n",
    "        and the per-network results along dim 1 again, i.e. (samples, networks, n_models,\n",
    "        classes) assuming every network returns (batch, classes) logits. `ys` holds the\n",
    "        labels of `ds` in dataset order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `networks` is empty or `n_models` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if len(networks) == 0:\n",
    "        raise ValueError(\"`networks` must contain at least one model\")\n",
    "    if n_models < 1:\n",
    "        raise ValueError(\"`n_models` must be at least 1\")\n",
    "\n",
    "    probits, ys = list(), None\n",
    "    for network in tqdm(networks):\n",
    "        # remember the mode of every submodule so it can be restored afterwards\n",
    "        previous_modes = [(m, m.training) for m in network.modules()]\n",
    "        # batch-norm layers always stay in eval mode; look them up once per network\n",
    "        bn_layers = [m for m in network.modules()\n",
    "                     if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]\n",
    "\n",
    "        probits_, ys_ = list(), list()\n",
    "        try:\n",
    "            for x, y in DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False):\n",
    "                x = x.to(device)\n",
    "\n",
    "                # for consistent dropout masks\n",
    "                fix_seeds(seed)\n",
    "\n",
    "                probits__ = list()\n",
    "                for n in range(n_models):\n",
    "\n",
    "                    # first model is normal model, others are dropout models\n",
    "                    if n == 0:\n",
    "                        network.eval()\n",
    "                    else:\n",
    "                        network.train()\n",
    "                        # put batchnorm in eval mode\n",
    "                        for m in bn_layers:\n",
    "                            m.eval()\n",
    "\n",
    "                    probits__.append(torch.softmax(network(x), dim=1).cpu())\n",
    "                probits_.append(torch.stack(probits__, dim=1))\n",
    "                ys_.append(y.cpu())\n",
    "        finally:\n",
    "            # do not leave the caller's networks in dropout (train) mode\n",
    "            for m, was_training in previous_modes:\n",
    "                m.training = was_training\n",
    "\n",
    "        probits.append(torch.cat(probits_, dim=0))\n",
    "        # the loader is not shuffled, so the labels are identical for every network\n",
    "        if ys is None:\n",
    "            ys = torch.cat(ys_, dim=0)\n",
    "    return torch.stack(probits, dim=1), ys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models: 5it [00:00,  7.22it/s]\n"
     ]
    }
   ],
   "source": [
    "# load networks: one network per checkpoint in `<path>/models`, in sorted file order\n",
    "models_dir = os.path.join(path, \"models\")\n",
    "model_files = sorted(glob.glob(os.path.join(models_dir, \"*.pt\")))\n",
    "\n",
    "# fail early with a clear message instead of evaluating an empty ensemble later on\n",
    "if len(model_files) == 0:\n",
    "    raise FileNotFoundError(f\"No model checkpoints (*.pt) found in {models_dir}\")\n",
    "\n",
    "networks = list()\n",
    "# iterating the list directly lets tqdm know the total number of checkpoints\n",
    "for model_file in tqdm(model_files, desc=\"Loading models\"):\n",
    "\n",
    "    if model == \"resnet18\":\n",
    "        network = get_resnet18_d(num_classes=n_class, p_drop=p_drop)\n",
    "    else:\n",
    "        raise NotImplementedError(\"Model not supported\")\n",
    "\n",
    "    # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources\n",
    "    network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "    network.to(device)\n",
    "    # a fresh network is built in every iteration, so it can be stored without copying\n",
    "    networks.append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run all loaded networks on the dataset returned by `load_test_dataset` for every\n",
    "# dataset name and store the outputs under `path`\n",
    "for dataset_name in dataset_names:\n",
    "    print(f\"> Evaluating on {dataset_name}\")\n",
    "\n",
    "    # where the results of this dataset are written to\n",
    "    probits_file = os.path.join(path, f\"{dataset_name}_probits.pt\")\n",
    "    ys_file = os.path.join(path, f\"{dataset_name}_ys.pt\")\n",
    "\n",
    "    # evaluate\n",
    "    dataset = load_test_dataset(dataset_name)\n",
    "    probits, ys = evaluate(networks, dataset)\n",
    "\n",
    "    # probabilities are saved in half precision, labels as returned\n",
    "    torch.save(probits.to(torch.float16), probits_file)\n",
    "    torch.save(ys, ys_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Evaluating on fgsm\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:06<00:00,  6.04s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.06s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.02s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.01s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000, 5, 10, 200])\n",
      "> Evaluating on linfpgd\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:06<00:00,  6.02s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.03s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.09s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.05s/it]\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000, 5, 10, 200])\n"
     ]
    }
   ],
   "source": [
    "# evaluate on adversarial samples if available\n",
    "adv_path = os.path.join(path, \"adversarial_examples\")\n",
    "\n",
    "runs = 5    #! number of runs\n",
    "\n",
    "if os.path.exists(adv_path):\n",
    "    # get all adversarial datasets, expected to be named `<attack>_<runid>.pt`\n",
    "    adv_datasets = sorted(glob.glob(os.path.join(adv_path, \"*.pt\")))\n",
    "\n",
    "    # map attack name -> run ids; files that do not follow the naming scheme\n",
    "    # (e.g. the `<attack>_probits.pt` files written below) are skipped\n",
    "    adv_ds_dict = dict()\n",
    "    for adv_ds in adv_datasets:\n",
    "        # splitext removes the extension; str.strip(\".pt\") would strip characters, not the suffix\n",
    "        name_parts = os.path.splitext(os.path.basename(adv_ds))[0].split(\"_\")\n",
    "        if len(name_parts) == 2 and name_parts[1].isdecimal():\n",
    "            adv_ds_dict.setdefault(name_parts[0], list()).append(int(name_parts[1]))\n",
    "\n",
    "    # every run uses its own contiguous, equally sized slice of the loaded networks\n",
    "    nets_per_run = len(networks) // runs\n",
    "\n",
    "    # normalization statistics, reshaped to broadcast over the channel dimension\n",
    "    # NOTE(review): CIFAR statistics are applied regardless of `train_dataset` - confirm this matches training\n",
    "    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)\n",
    "    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)\n",
    "\n",
    "    for adv_atk in adv_ds_dict.keys():\n",
    "        print(f\"> Evaluating on {adv_atk}\")\n",
    "\n",
    "        probits = list()\n",
    "\n",
    "        # numeric order keeps the concatenated results in the same order as `networks`\n",
    "        for runid in sorted(adv_ds_dict[adv_atk]):\n",
    "            run_networks = networks[runid * nets_per_run:(runid + 1) * nets_per_run]\n",
    "            if len(run_networks) == 0:\n",
    "                raise ValueError(\n",
    "                    f\"No networks available for run {runid} of {adv_atk}: \"\n",
    "                    f\"{len(networks)} networks loaded, runs = {runs}\"\n",
    "                )\n",
    "\n",
    "            # load adversarial samples and map to [0, 1]\n",
    "            images = torch.load(os.path.join(adv_path, f\"{adv_atk}_{runid}.pt\"), map_location=\"cpu\").to(torch.float32) / 255\n",
    "\n",
    "            # apply cifar normalization\n",
    "            images = (images - mean) / std\n",
    "\n",
    "            # labels are not needed here, so zeros are used as placeholders\n",
    "            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())\n",
    "\n",
    "            # evaluate\n",
    "            probits_, _ = evaluate(run_networks, dataset)\n",
    "            probits.append(probits_)\n",
    "        # join the per-run results along the network dimension\n",
    "        probits = torch.cat(probits, dim=1)\n",
    "        print(probits.shape)\n",
    "        torch.save(probits.to(torch.float16), os.path.join(adv_path, f\"{adv_atk}_probits.pt\"))\n",
    "else:\n",
    "    print(\"No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
