{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import torchvision as tv\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from foolbox import PyTorchModel, accuracy, samples\n",
    "from foolbox.attacks import FGM, FGSM, L1PGD, L2PGD, LinfPGD, L2CarliniWagnerAttack\n",
    "from foolbox.criteria import Misclassification\n",
    "\n",
    "from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD\n",
    "from source.utils.train_utils import evaluate\n",
    "from source.utils.metrics import Accuracy\n",
    "from source.networks.resnet import get_resnet18, get_resnet34, get_resnet50\n",
    "from utils import load_test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. `seed` also selects the run directory loaded in the next cell.\n",
    "seed = 42\n",
    "# foolbox's LinfPGD starts from a random point by default; seeding torch makes the attack reproducible.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Number of classes per dataset, index-aligned with `dataset_names`.\n",
    "# Tiny ImageNet (\"tin\") has 200 classes; LSUN scene classification has 10.\n",
    "dataset_names = [\"cifar10\", \"cifar100\", \"svhn\", \"tin\", \"lsun\"]\n",
    "n_classes = [10, 100, 10, 200, 10]\n",
    "models = [\"resnet18\", \"resnet34\", \"resnet50\"]\n",
    "\n",
    "dataset_name = dataset_names[0]    # select dataset\n",
    "model = models[0]                   # select model\n",
    "\n",
    "# infer number of classes from dataset\n",
    "n_class = n_classes[dataset_names.index(dataset_name)]\n",
    "\n",
    "# Fall back to CPU so the notebook still runs on machines without a GPU.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "# NOTE(review): not used below -- the attack cell loads a fixed batch of 256 images.\n",
    "batch_size = 2048"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the checkpoint(s) of the selected dataset / model / seed run.\n",
    "path = os.path.join(RESULTS_PATH, f\"{dataset_name}_{model}_seed{seed}\")\n",
    "models_dir = os.path.join(path, \"models\")\n",
    "\n",
    "model_files = glob.glob(os.path.join(models_dir, \"*.pt\"))\n",
    "if not model_files:\n",
    "    raise FileNotFoundError(f\"No '*.pt' checkpoints found in {models_dir}\")\n",
    "# Take the lexicographically first checkpoint so the choice is deterministic.\n",
    "model_file = sorted(model_files)[0]\n",
    "\n",
    "print(model_file)\n",
    "\n",
    "# Dispatch on the model name; an unknown name fails loudly instead of leaving\n",
    "# `network` undefined (or stale from a previous execution of this cell).\n",
    "network_factories = {\n",
    "    \"resnet18\": get_resnet18,\n",
    "    \"resnet34\": get_resnet34,\n",
    "    \"resnet50\": get_resnet50,\n",
    "}\n",
    "if model not in network_factories:\n",
    "    raise ValueError(f\"Unknown model {model!r}; expected one of {sorted(network_factories)}\")\n",
    "network = network_factories[model](num_classes=n_class)\n",
    "\n",
    "# NOTE: torch.load unpickles the checkpoint -- only load files you trust.\n",
    "network.load_state_dict(torch.load(model_file, map_location=device))\n",
    "network.eval()\n",
    "network.to(device)\n",
    "\n",
    "# ToTensor only (no Normalize): mean/std normalization is passed to foolbox as\n",
    "# `preprocessing` in the attack cell, so attacks work in [0, 1] pixel space.\n",
    "dataset = load_test_dataset(dataset_name, custom_transform=tv.transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Clean Accuracy:  0.94140625\n"
     ]
    }
   ],
   "source": [
    "# foolbox applies the mean/std normalization itself, so attacks operate on [0, 1] pixels.\n",
    "# NOTE(review): CIFAR mean/std are used for every dataset -- confirm this matches training for svhn/tin/lsun.\n",
    "preprocessing = {\"mean\": CIFAR_MEAN, \"std\": CIFAR_STD, \"axis\": -3}\n",
    "fnet = PyTorchModel(network, bounds=(0, 1), preprocessing=preprocessing)\n",
    "\n",
    "# Only the first (unshuffled) batch of 256 test images is evaluated and attacked.\n",
    "# NOTE(review): this ignores `batch_size` from the config cell.\n",
    "test_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4)\n",
    "images, labels = (t.to(device) for t in next(iter(test_loader)))\n",
    "\n",
    "clean_acc = accuracy(fnet, images, labels)\n",
    "print(\"Clean Accuracy: \", clean_acc)\n",
    "# print(\"Clean Accuracy: \", evaluate(network, test_loader, Accuracy()))\n",
    "\n",
    "# L-infinity budgets in [0, 1] pixel units; eps = 0.0 doubles as a clean-accuracy sanity check.\n",
    "epsilons = [\n",
    "    0.0, 0.0002, 0.0005, 0.0008, 0.001, 0.0015, 0.002, 0.003,\n",
    "    0.01, 0.1, 0.3, 0.5, 1.0,\n",
    "]\n",
    "\n",
    "attack = LinfPGD()\n",
    "criterion = Misclassification(labels)\n",
    "raw_advs, clipped_advs, success = attack(fnet, images, epsilons=epsilons, criterion=criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "robust accuracy for perturbations with\n",
      "  Linf norm ≤ 0.0   : 94.1 %\n",
      "  Linf norm ≤ 0.0002: 93.4 %\n",
      "  Linf norm ≤ 0.0005: 91.8 %\n",
      "  Linf norm ≤ 0.0008: 89.8 %\n",
      "  Linf norm ≤ 0.001 : 88.3 %\n",
      "  Linf norm ≤ 0.0015: 83.6 %\n",
      "  Linf norm ≤ 0.002 : 79.7 %\n",
      "  Linf norm ≤ 0.003 : 68.0 %\n",
      "  Linf norm ≤ 0.01  : 14.8 %\n",
      "  Linf norm ≤ 0.1   :  0.0 %\n",
      "  Linf norm ≤ 0.3   :  0.0 %\n",
      "  Linf norm ≤ 0.5   :  0.0 %\n",
      "  Linf norm ≤ 1.0   :  0.0 %\n"
     ]
    }
   ],
   "source": [
    "# Robust accuracy per epsilon: share of the batch the attack did NOT manage to misclassify.\n",
    "robust_accuracy = 1.0 - success.float().mean(dim=-1)\n",
    "print(\"robust accuracy for perturbations with\")\n",
    "\n",
    "for idx, eps in enumerate(epsilons):\n",
    "    acc = robust_accuracy[idx]\n",
    "    print(f\"  Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([256, 3, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: adversarials for the first epsilon come back as a batch shaped like `images`.\n",
    "print(raw_advs[0].size())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
