{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "\n",
    "from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD\n",
    "from source.networks.resnet import get_resnet18\n",
    "from source.networks.densenet import get_densenet169\n",
    "from source.networks.regnet import get_regnet_y_800mf\n",
    "from source.networks.cnn import SimpleCNN\n",
    "from utils import load_test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "svhn regnet\n"
     ]
    }
   ],
   "source": [
    "seed = 42\n",
    "\n",
    "dataset_names = [\"cifar10\", \"cifar100\", \"svhn\", \"tin\", \"lsun\"]\n",
    "n_classes = [10, 100, 10, 200, 10]\n",
    "models = [\"resnet18\", \"densenet169\", \"regnet\", \"cnn\"]\n",
    "\n",
    "train_dataset = dataset_names[2]    # select dataset\n",
    "model = models[2]                   # select model\n",
    "\n",
    "# infer number of classes from dataset\n",
    "n_class = n_classes[dataset_names.index(train_dataset)]\n",
    "\n",
    "device = \"cuda:0\"\n",
    "batch_size = 2048 \n",
    "\n",
    "print(train_dataset, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(networks: List, ds: Dataset):\n",
    "    \"\"\"Run every network over `ds` and collect softmax outputs and labels.\n",
    "\n",
    "    Relies on the notebook-level `batch_size` and `device` settings.\n",
    "\n",
    "    Args:\n",
    "        networks: non-empty list of modules; each one is put into eval mode.\n",
    "        ds: dataset yielding `(x, y)` pairs.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(probits, ys)`: `probits` holds the per-network softmax outputs\n",
    "        stacked along dim 1, `ys` the labels concatenated in dataset order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `networks` is empty.\n",
    "    \"\"\"\n",
    "    if len(networks) == 0:\n",
    "        raise ValueError(\"evaluate() needs at least one network\")\n",
    "\n",
    "    # shuffle=False keeps the sample order identical for every network,\n",
    "    # so one loader can be shared instead of being rebuilt per network\n",
    "    loader = DataLoader(ds, batch_size = batch_size, shuffle=False, drop_last=False)\n",
    "\n",
    "    probits, ys = list(), None\n",
    "    for network in tqdm(networks):\n",
    "        network.eval()\n",
    "        probits_, ys_ = list(), list()\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "\n",
    "            # call the module itself (not .forward) so registered hooks are honoured\n",
    "            probits_.append(torch.softmax(network(x), dim=1).cpu())\n",
    "            if ys is None:\n",
    "                ys_.append(y.cpu())\n",
    "        probits.append(torch.cat(probits_, dim=0))\n",
    "        if ys is None:\n",
    "            # labels are the same for every network, so they are gathered on the first pass only\n",
    "            ys = torch.cat(ys_, dim=0)\n",
    "    return torch.stack(probits, dim=1), ys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models:   0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading models: 100%|██████████| 50/50 [00:08<00:00,  6.13it/s]\n"
     ]
    }
   ],
   "source": [
    "# directory holding the trained models and all evaluation outputs of this configuration\n",
    "path = os.path.join(RESULTS_PATH, f\"{train_dataset}_{model}_seed{seed}\")\n",
    "\n",
    "# load networks\n",
    "networks = list()\n",
    "model_files = glob.glob(os.path.join(path, \"models\", \"*.pt\"))\n",
    "for m in tqdm(range(len(model_files)), desc=\"Loading models\"):\n",
    "    # checkpoints are addressed by index (model_0.pt, model_1.pt, ...) so the position\n",
    "    # of a network in `networks` does not depend on the order returned by glob\n",
    "    model_file = os.path.join(path, \"models\", f\"model_{m}.pt\")\n",
    "\n",
    "    if model == \"resnet18\":\n",
    "        network = get_resnet18(num_classes=n_class)\n",
    "    elif model == \"densenet169\":\n",
    "        network = get_densenet169(num_classes=n_class)\n",
    "    elif model == \"regnet\":\n",
    "        network = get_regnet_y_800mf(num_classes=n_class)\n",
    "    elif model == \"cnn\":\n",
    "        network = SimpleCNN(n_class)\n",
    "    else:\n",
    "        # without this branch `network` would be unbound (or stale from a previous iteration)\n",
    "        raise ValueError(f\"Unknown model: {model}\")\n",
    "\n",
    "    # NOTE: torch.load unpickles the file, only load checkpoints from trusted sources\n",
    "    network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "    network.to(device)\n",
    "    # a fresh network is built in every iteration, so it is stored directly without copying\n",
    "    networks.append(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run all loaded networks on the test split of every dataset and persist the outputs\n",
    "for dataset_name in dataset_names:\n",
    "    print(f\"> Evaluating on {dataset_name}\")\n",
    "\n",
    "    dataset = load_test_dataset(dataset_name)\n",
    "    probits, ys = evaluate(networks, dataset)\n",
    "\n",
    "    # outputs are written next to the models; probabilities are stored in half precision\n",
    "    probits_file = os.path.join(path, f\"{dataset_name}_probits.pt\")\n",
    "    ys_file = os.path.join(path, f\"{dataset_name}_ys.pt\")\n",
    "    torch.save(probits.to(torch.float16), probits_file)\n",
    "    torch.save(ys, ys_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.\n"
     ]
    }
   ],
   "source": [
    "# evaluate on adversarial samples if available\n",
    "adv_path = os.path.join(path, \"adversarial_examples\")\n",
    "\n",
    "runs = 5    #! number of runs\n",
    "\n",
    "if os.path.exists(adv_path):\n",
    "    # get all adversarial datasets\n",
    "    adv_datasets = sorted(glob.glob(os.path.join(adv_path, \"*.pt\")))\n",
    "\n",
    "    # group the run ids by attack name; files are expected to be named `<attack>_<runid>.pt`\n",
    "    adv_ds_dict = dict()\n",
    "    for adv_dataset in adv_datasets:\n",
    "        # splitext removes the extension as a suffix (str.strip would remove characters instead)\n",
    "        adv_name = os.path.splitext(os.path.basename(adv_dataset))[0]\n",
    "        if \"probits\" in adv_name:\n",
    "            # result file written by an earlier execution of this cell, not an input\n",
    "            continue\n",
    "        fields = adv_name.split(\"_\")\n",
    "        adv_ds_dict.setdefault(fields[0], list()).append(int(fields[1]))\n",
    "\n",
    "    # each run owns a contiguous slice of `networks` of this size\n",
    "    networks_per_run = len(networks) // runs\n",
    "\n",
    "    # normalization statistics, reshaped to (1, 3, 1, 1) to broadcast against the image batch\n",
    "    # NOTE(review): CIFAR statistics are used regardless of `train_dataset` - confirm this is intended\n",
    "    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)\n",
    "    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)\n",
    "\n",
    "    for adv_atk, runids in adv_ds_dict.items():\n",
    "        print(f\"> Evaluating on {adv_atk}\")\n",
    "\n",
    "        probits = list()\n",
    "\n",
    "        # numeric order, so that run 10 is not processed before run 2\n",
    "        for runid in sorted(runids):\n",
    "            # load adversarial samples and map to [0, 1]\n",
    "            images = torch.load(os.path.join(adv_path, f\"{adv_atk}_{runid}.pt\"), map_location=\"cpu\").to(torch.float32) / 255\n",
    "\n",
    "            # apply cifar normalization\n",
    "            images = (images - mean) / std\n",
    "\n",
    "            # evaluate() expects (x, y) pairs; the labels are dummies and are discarded below\n",
    "            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())\n",
    "\n",
    "            # evaluate only the networks that belong to this run\n",
    "            run_networks = networks[runid * networks_per_run:(runid + 1) * networks_per_run]\n",
    "            probits_, _ = evaluate(run_networks, dataset)\n",
    "            probits.append(probits_)\n",
    "        probits = torch.cat(probits, dim=1)\n",
    "        print(probits.shape)\n",
    "        torch.save(probits.to(torch.float16), os.path.join(adv_path, f\"{adv_atk}_probits.pt\"))\n",
    "else:\n",
    "    print(\"No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
