{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NM9ZLM6Gr73r"
      },
      "source": [
        "# Preparation"
      ],
      "id": "NM9ZLM6Gr73r"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download & Extract Adaptive Exposures"
      ],
      "metadata": {
        "id": "WMTrhdZi4nIM"
      },
      "id": "WMTrhdZi4nIM"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RUTv-4TqoNTp"
      },
      "outputs": [],
      "source": [
        "import gdown\n",
        "url = \"https://drive.google.com/drive/folders/1zXEwDiB1EUhG4-n4x3LYhbffOKYCbiEc?usp=sharing\"\n",
        "gdown.download_folder(url, quiet=True, use_cookies=False)"
      ],
      "id": "RUTv-4TqoNTp"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download OOD Datasets"
      ],
      "metadata": {
        "id": "K_SchgWV4qQ9"
      },
      "id": "K_SchgWV4qQ9"
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "!mkdir data\n",
        "# Places\n",
        "!wget https://dl.dropboxusercontent.com/s/3pwqsyv33f6if3z/val_256.tar\n",
        "!tar -xf val_256.tar -C ./data\n",
        "%cd data\n",
        "!wget https://dl.dropboxusercontent.com/s/gaf1ygpdnkhzyjo/places365_val.txt\n",
        "!wget https://dl.dropboxusercontent.com/s/enr71zpolzi1xzm/categories_places365.txt\n",
        "%cd ..\n",
        "# COIL\n",
        "!mkdir data/coil\n",
        "!wget http://www.cs.columbia.edu/CAVE/databases/SLAM_coil-20_coil-100/coil-100/coil-100.zip\n",
        "!unzip coil-100.zip -d ./data\n",
        "!mkdir data/coil\n",
        "!cp -r data/coil-100 data/coil\n",
        "# LSUN\n",
        "!wget https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz\n",
        "!tar -xf LSUN_resize.tar.gz -C ./data\n",
        "%cd data\n",
        "# iSUN\n",
        "!wget https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz\n",
        "!tar -xf iSUN.tar.gz\n",
        "# Birds\n",
        "!wget https://www.dropbox.com/s/yc6kz6ld56q836c/images.tgz\n",
        "!tar -xf images.tgz\n",
        "# Flowers\n",
        "!wget https://dl.dropboxusercontent.com/s/hbt8e7wjiplryoo/102flowers.tgz\n",
        "!tar -xf 102flowers.tgz\n",
        "!mv jpg flowers\n",
        "!mkdir flowers/fld\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "# Source and destination folder paths\n",
        "src_folder = './flowers'\n",
        "dst_folder = './flowers/fld'\n",
        "\n",
        "# Copy all files from the source folder to the destination folder\n",
        "for filename in os.listdir(src_folder):\n",
        "    # Construct the full file paths\n",
        "    src_file = os.path.join(src_folder, filename)\n",
        "    dst_file = os.path.join(dst_folder, filename)\n",
        "\n",
        "    # Copy the file to the destination folder if it's a file (not a folder)\n",
        "    if os.path.isfile(src_file):\n",
        "        shutil.copy(src_file, dst_file)\n",
        "# Tiny Image Net\n",
        "!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip\n",
        "!unzip tiny-imagenet-200.zip\n",
        "%cd .."
      ],
      "metadata": {
        "id": "MrM6BzbH4tH0"
      },
      "id": "MrM6BzbH4tH0",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use CUDA"
      ],
      "metadata": {
        "id": "Ejrqp7GUG-vz"
      },
      "id": "Ejrqp7GUG-vz"
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "Yvbj5ntSHAHJ"
      },
      "id": "Yvbj5ntSHAHJ",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "print('Using device:', device)"
      ],
      "metadata": {
        "id": "JYl4z9zPHCYC"
      },
      "id": "JYl4z9zPHCYC",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3fda854f"
      },
      "source": [
        "# Configurations"
      ],
      "id": "3fda854f"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dc34608a"
      },
      "outputs": [],
      "source": [
        "from torchvision import transforms\n",
        "\n",
        "in_dataset = ['cifar10', 'cifar100'][0]\n",
        "\n",
        "epochs = 10\n",
        "optim = \"adam\"\n",
        "\n",
        "lr = 0.001\n",
        "\n",
        "avialable_datasets = ['cifar10', 'cifar100', 'mnist', 'places', 'coil', 'LSUN', 'iSUN', 'flowers', 'birds', 'tiny_imagenet']\n",
        "out_dataset = avialable_datasets[1]\n",
        "\n",
        "attack_eps = 8/255\n",
        "attack_steps = 10\n",
        "attack_alpha = 2.5 * attack_eps / attack_steps\n",
        "num_classes = {\n",
        "    'cifar10': 10,\n",
        "    'cifar100': 20\n",
        "}[in_dataset]\n",
        "all_num_classes = num_classes\n",
        "\n",
        "batch_size = 128"
      ],
      "id": "dc34608a"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "44925dca"
      },
      "source": [
        "# Model"
      ],
      "id": "44925dca"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "86692533"
      },
      "source": [
        "## Wide Resnet"
      ],
      "id": "86692533"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3960f82e"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):\n",
        "        super(BasicBlock, self).__init__()\n",
        "        self.bn1 = nn.BatchNorm2d(in_planes)\n",
        "        self.relu1 = nn.ReLU(inplace=True)\n",
        "        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
        "                               padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(out_planes)\n",
        "        self.relu2 = nn.ReLU(inplace=True)\n",
        "        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,\n",
        "                               padding=1, bias=False)\n",
        "        self.droprate = dropRate\n",
        "        self.equalInOut = (in_planes == out_planes)\n",
        "        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,\n",
        "                               padding=0, bias=False) or None\n",
        "    def forward(self, x):\n",
        "        if not self.equalInOut:\n",
        "            x = self.relu1(self.bn1(x))\n",
        "        else:\n",
        "            out = self.relu1(self.bn1(x))\n",
        "        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n",
        "        if self.droprate > 0:\n",
        "            out = torch.nn.functional.dropout(out, p=self.droprate, training=self.training)\n",
        "        out = self.conv2(out)\n",
        "        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n",
        "\n",
        "class NetworkBlock(nn.Module):\n",
        "    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):\n",
        "        super(NetworkBlock, self).__init__()\n",
        "        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)\n",
        "    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):\n",
        "        layers = []\n",
        "        for i in range(int(nb_layers)):\n",
        "            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))\n",
        "        return nn.Sequential(*layers)\n",
        "    def forward(self, x):\n",
        "        return self.layer(x)\n",
        "\n",
        "class WideResNet(nn.Module):\n",
        "    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):\n",
        "        super(WideResNet, self).__init__()\n",
        "        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]\n",
        "        assert((depth - 4) % 6 == 0)\n",
        "        n = (depth - 4) / 6\n",
        "        block = BasicBlock\n",
        "        # 1st conv before any network block\n",
        "        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,\n",
        "                               padding=1, bias=False)\n",
        "        # 1st block\n",
        "        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)\n",
        "        # 2nd block\n",
        "        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)\n",
        "        # 3rd block\n",
        "        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)\n",
        "        # global average pooling and classifier\n",
        "        self.bn1 = nn.BatchNorm2d(nChannels[3])\n",
        "        self.relu = nn.ReLU(inplace=True)\n",
        "        self.fc = nn.Linear(nChannels[3], num_classes)\n",
        "        self.nChannels = nChannels[3]\n",
        "\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Conv2d):\n",
        "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
        "            elif isinstance(m, nn.BatchNorm2d):\n",
        "                m.weight.data.fill_(1)\n",
        "                m.bias.data.zero_()\n",
        "            elif isinstance(m, nn.Linear):\n",
        "                m.bias.data.zero_()\n",
        "    def forward(self, x):\n",
        "        out = self.conv1(x)\n",
        "        out = self.block1(out)\n",
        "        out = self.block2(out)\n",
        "        out = self.block3(out)\n",
        "        out = self.relu(self.bn1(out))\n",
        "        out = torch.nn.functional.avg_pool2d(out, 8)\n",
        "        out = out.view(-1, self.nChannels)\n",
        "        return self.fc(out)"
      ],
      "id": "3960f82e"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16e9cfbc"
      },
      "source": [
        "# Attack"
      ],
      "id": "16e9cfbc"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d0cdf735"
      },
      "source": [
        "## Base Attack"
      ],
      "id": "d0cdf735"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4a9862a"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import logging\n",
        "from collections import OrderedDict\n",
        "from collections.abc import Iterable\n",
        "\n",
        "import torch\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "\n",
        "\n",
        "def wrapper_method(func):\n",
        "    def wrapper_func(self, *args, **kwargs):\n",
        "        result = func(self, *args, **kwargs)\n",
        "        for atk in self.__dict__.get('_attacks').values():\n",
        "            eval(\"atk.\"+func.__name__+\"(*args, **kwargs)\")\n",
        "        return result\n",
        "    return wrapper_func\n",
        "\n",
        "\n",
        "class Attack(object):\n",
        "    r\"\"\"\n",
        "    Base class for all attacks.\n",
        "\n",
        "    .. note::\n",
        "        It automatically set device to the device where given model is.\n",
        "        It basically changes training mode to eval during attack process.\n",
        "        To change this, please see `set_model_training_mode`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, name, model):\n",
        "        r\"\"\"\n",
        "        Initializes internal attack state.\n",
        "\n",
        "        Arguments:\n",
        "            name (str): name of attack.\n",
        "            model (torch.nn.Module): model to attack.\n",
        "        \"\"\"\n",
        "\n",
        "        self.attack = name\n",
        "        self._attacks = OrderedDict()\n",
        "\n",
        "        self.set_model(model)\n",
        "        self.device = next(model.parameters()).device\n",
        "\n",
        "        # Controls attack mode.\n",
        "        self.attack_mode = 'default'\n",
        "        self.supported_mode = ['default']\n",
        "        self.targeted = False\n",
        "        self._target_map_function = None\n",
        "\n",
        "        # Controls when normalization is used.\n",
        "        self.normalization_used = {}\n",
        "        self._normalization_applied = False\n",
        "        self._set_auto_normalization_used(model)\n",
        "\n",
        "        # Controls model mode during attack.\n",
        "        self._model_training = False\n",
        "        self._batchnorm_training = False\n",
        "        self._dropout_training = False\n",
        "\n",
        "    def forward(self, inputs, labels=None, *args, **kwargs):\n",
        "        r\"\"\"\n",
        "        It defines the computation performed at every call.\n",
        "        Should be overridden by all subclasses.\n",
        "        \"\"\"\n",
        "        raise NotImplementedError\n",
        "\n",
        "    def _check_inputs(self, images):\n",
        "        tol = 1e-4\n",
        "        if self._normalization_applied:\n",
        "            images = self.inverse_normalize(images)\n",
        "        if torch.max(images) > 1+tol or torch.min(images) < 0-tol:\n",
        "            raise ValueError('Input must have a range [0, 1] (max: {}, min: {})'.format(\n",
        "                torch.max(images), torch.min(images)))\n",
        "        return images\n",
        "\n",
        "    def _check_outputs(self, images):\n",
        "        if self._normalization_applied:\n",
        "            images = self.normalize(images)\n",
        "        return images\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_model(self, model):\n",
        "        self.model = model\n",
        "        self.model_name = model.__class__.__name__\n",
        "\n",
        "    def get_logits(self, inputs, labels=None, *args, **kwargs):\n",
        "        if self._normalization_applied:\n",
        "            inputs = self.normalize(inputs)\n",
        "        logits = self.model(inputs)\n",
        "        return logits\n",
        "\n",
        "    @wrapper_method\n",
        "    def _set_normalization_applied(self, flag):\n",
        "        self._normalization_applied = flag\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_device(self, device):\n",
        "        self.device = device\n",
        "\n",
        "    @wrapper_method\n",
        "    def _set_auto_normalization_used(self, model):\n",
        "        if model.__class__.__name__ == 'RobModel':\n",
        "            mean = getattr(model, 'mean', None)\n",
        "            std = getattr(model, 'std', None)\n",
        "            if (mean is not None) and (std is not None):\n",
        "                if isinstance(mean, torch.Tensor):\n",
        "                    mean = mean.cpu().numpy()\n",
        "                if isinstance(std, torch.Tensor):\n",
        "                    std = std.cpu().numpy()\n",
        "                if (mean != 0).all() or (std != 1).all():\n",
        "                    self.set_normalization_used(mean, std)\n",
        "    #                 logging.info(\"Normalization automatically loaded from `model.mean` and `model.std`.\")\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_normalization_used(self, mean, std):\n",
        "        n_channels = len(mean)\n",
        "        mean = torch.tensor(mean).reshape(1, n_channels, 1, 1)\n",
        "        std = torch.tensor(std).reshape(1, n_channels, 1, 1)\n",
        "        self.normalization_used['mean'] = mean\n",
        "        self.normalization_used['std'] = std\n",
        "        self._normalization_applied = True\n",
        "\n",
        "    def normalize(self, inputs):\n",
        "        mean = self.normalization_used['mean'].to(inputs.device)\n",
        "        std = self.normalization_used['std'].to(inputs.device)\n",
        "        return (inputs - mean) / std\n",
        "\n",
        "    def inverse_normalize(self, inputs):\n",
        "        mean = self.normalization_used['mean'].to(inputs.device)\n",
        "        std = self.normalization_used['std'].to(inputs.device)\n",
        "        return inputs*std + mean\n",
        "\n",
        "    def get_mode(self):\n",
        "        r\"\"\"\n",
        "        Get attack mode.\n",
        "\n",
        "        \"\"\"\n",
        "        return self.attack_mode\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_mode_default(self):\n",
        "        r\"\"\"\n",
        "        Set attack mode as default mode.\n",
        "\n",
        "        \"\"\"\n",
        "        self.attack_mode = 'default'\n",
        "        self.targeted = False\n",
        "        print(\"Attack mode is changed to 'default.'\")\n",
        "\n",
        "    @wrapper_method\n",
        "    def _set_mode_targeted(self, mode, quiet):\n",
        "        if \"targeted\" not in self.supported_mode:\n",
        "            raise ValueError(\"Targeted mode is not supported.\")\n",
        "        self.targeted = True\n",
        "        self.attack_mode = mode\n",
        "        if not quiet:\n",
        "            print(\"Attack mode is changed to '%s'.\" % mode)\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_mode_targeted_by_function(self, target_map_function, quiet=False):\n",
        "        r\"\"\"\n",
        "        Set attack mode as targeted.\n",
        "\n",
        "        Arguments:\n",
        "            target_map_function (function): Label mapping function.\n",
        "                e.g. lambda inputs, labels:(labels+1)%10.\n",
        "                None for using input labels as targeted labels. (Default)\n",
        "\n",
        "        \"\"\"\n",
        "        self._set_mode_targeted('targeted(custom)', quiet)\n",
        "        self._target_map_function = target_map_function\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_mode_targeted_random(self, quiet=False):\n",
        "        r\"\"\"\n",
        "        Set attack mode as targeted with random labels.\n",
        "\n",
        "        Arguments:\n",
        "            num_classses (str): number of classes.\n",
        "\n",
        "        \"\"\"\n",
        "        self._set_mode_targeted('targeted(random)', quiet)\n",
        "        self._target_map_function = self.get_random_target_label\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):\n",
        "        r\"\"\"\n",
        "        Set attack mode as targeted with least likely labels.\n",
        "\n",
        "        Arguments:\n",
        "            kth_min (str): label with the k-th smallest probability used as target labels. (Default: 1)\n",
        "\n",
        "        \"\"\"\n",
        "        self._set_mode_targeted('targeted(least-likely)', quiet)\n",
        "        assert (kth_min > 0)\n",
        "        self._kth_min = kth_min\n",
        "        self._target_map_function = self.get_least_likely_label\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_mode_targeted_by_label(self, quiet=False):\n",
        "        r\"\"\"\n",
        "        Set attack mode as targeted.\n",
        "\n",
        "        .. note::\n",
        "            Use user-supplied labels as target labels.\n",
        "        \"\"\"\n",
        "        self._set_mode_targeted('targeted(label)', quiet)\n",
        "        self._target_map_function = 'function is a string'\n",
        "\n",
        "    @wrapper_method\n",
        "    def set_model_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False):\n",
        "        r\"\"\"\n",
        "        Set training mode during attack process.\n",
        "\n",
        "        Arguments:\n",
        "            model_training (bool): True for using training mode for the entire model during attack process.\n",
        "            batchnorm_training (bool): True for using training mode for batchnorms during attack process.\n",
        "            dropout_training (bool): True for using training mode for dropouts during attack process.\n",
        "\n",
        "        .. note::\n",
        "            For RNN-based models, we cannot calculate gradients with eval mode.\n",
        "            Thus, it should be changed to the training mode during the attack.\n",
        "        \"\"\"\n",
        "        self._model_training = model_training\n",
        "        self._batchnorm_training = batchnorm_training\n",
        "        self._dropout_training = dropout_training\n",
        "\n",
        "    @wrapper_method\n",
        "    def _change_model_mode(self, given_training):\n",
        "        if self._model_training:\n",
        "            self.model.train()\n",
        "            for _, m in self.model.named_modules():\n",
        "                if not self._batchnorm_training:\n",
        "                    if 'BatchNorm' in m.__class__.__name__:\n",
        "                        m = m.eval()\n",
        "                if not self._dropout_training:\n",
        "                    if 'Dropout' in m.__class__.__name__:\n",
        "                        m = m.eval()\n",
        "        else:\n",
        "            self.model.eval()\n",
        "\n",
        "    @wrapper_method\n",
        "    def _recover_model_mode(self, given_training):\n",
        "        if given_training:\n",
        "            self.model.train()\n",
        "\n",
        "    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False,\n",
        "             save_predictions=False, save_clean_inputs=False, save_type='float'):\n",
        "        r\"\"\"\n",
        "        Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader.\n",
        "\n",
        "        Arguments:\n",
        "            save_path (str): save_path.\n",
        "            data_loader (torch.utils.data.DataLoader): data loader.\n",
        "            verbose (bool): True for displaying detailed information. (Default: True)\n",
        "            return_verbose (bool): True for returning detailed information. (Default: False)\n",
        "            save_predictions (bool): True for saving predicted labels (Default: False)\n",
        "            save_clean_inputs (bool): True for saving clean inputs (Default: False)\n",
        "\n",
        "        \"\"\"\n",
        "        if save_path is not None:\n",
        "            adv_input_list = []\n",
        "            label_list = []\n",
        "            if save_predictions:\n",
        "                pred_list = []\n",
        "            if save_clean_inputs:\n",
        "                input_list = []\n",
        "\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        l2_distance = []\n",
        "\n",
        "        total_batch = len(data_loader)\n",
        "        given_training = self.model.training\n",
        "\n",
        "        for step, (inputs, labels) in enumerate(data_loader):\n",
        "            start = time.time()\n",
        "            adv_inputs = self.__call__(inputs, labels)\n",
        "            batch_size = len(inputs)\n",
        "\n",
        "            if verbose or return_verbose:\n",
        "                with torch.no_grad():\n",
        "                    outputs = self.get_output_with_eval_nograd(adv_inputs)\n",
        "\n",
        "                    # Calculate robust accuracy\n",
        "                    _, pred = torch.max(outputs.data, 1)\n",
        "                    total += labels.size(0)\n",
        "                    right_idx = (pred == labels.to(self.device))\n",
        "                    correct += right_idx.sum()\n",
        "                    rob_acc = 100 * float(correct) / total\n",
        "\n",
        "                    # Calculate l2 distance\n",
        "                    delta = (adv_inputs - inputs.to(self.device)).view(batch_size, -1)  # nopep8\n",
        "                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))  # nopep8\n",
        "                    l2 = torch.cat(l2_distance).mean().item()\n",
        "\n",
        "                    # Calculate time computation\n",
        "                    progress = (step+1)/total_batch*100\n",
        "                    end = time.time()\n",
        "                    elapsed_time = end-start\n",
        "\n",
        "                    if verbose:\n",
        "                        self._save_print(progress, rob_acc, l2, elapsed_time, end='\\r')  # nopep8\n",
        "\n",
        "            if save_path is not None:\n",
        "                adv_input_list.append(adv_inputs.detach().cpu())\n",
        "                label_list.append(labels.detach().cpu())\n",
        "\n",
        "                adv_input_list_cat = torch.cat(adv_input_list, 0)\n",
        "                label_list_cat = torch.cat(label_list, 0)\n",
        "\n",
        "                save_dict = {'adv_inputs': adv_input_list_cat, 'labels': label_list_cat}  # nopep8\n",
        "\n",
        "                if save_predictions:\n",
        "                    pred_list.append(pred.detach().cpu())\n",
        "                    pred_list_cat = torch.cat(pred_list, 0)\n",
        "                    save_dict['preds'] = pred_list_cat\n",
        "\n",
        "                if save_clean_inputs:\n",
        "                    input_list.append(inputs.detach().cpu())\n",
        "                    input_list_cat = torch.cat(input_list, 0)\n",
        "                    save_dict['clean_inputs'] = input_list_cat\n",
        "\n",
        "                if self.normalization_used is not None:\n",
        "                    save_dict['adv_inputs'] = self.inverse_normalize(save_dict['adv_inputs'])  # nopep8\n",
        "                    if save_clean_inputs:\n",
        "                        save_dict['clean_inputs'] = self.inverse_normalize(save_dict['clean_inputs'])  # nopep8\n",
        "\n",
        "                if save_type == 'int':\n",
        "                    save_dict['adv_inputs'] = self.to_type(save_dict['adv_inputs'], 'int')  # nopep8\n",
        "                    if save_clean_inputs:\n",
        "                        save_dict['clean_inputs'] = self.to_type(save_dict['clean_inputs'], 'int')  # nopep8\n",
        "\n",
        "                save_dict['save_type'] = save_type\n",
        "                torch.save(save_dict, save_path)\n",
        "\n",
        "        # To avoid erasing the printed information.\n",
        "        if verbose:\n",
        "            self._save_print(progress, rob_acc, l2, elapsed_time, end='\\n')\n",
        "\n",
        "        if given_training:\n",
        "            self.model.train()\n",
        "\n",
        "        if return_verbose:\n",
        "            return rob_acc, l2, elapsed_time\n",
        "\n",
        "    @staticmethod\n",
        "    def to_type(inputs, type):\n",
        "        r\"\"\"\n",
        "        Return inputs as int if float is given.\n",
        "        \"\"\"\n",
        "        if type == 'int':\n",
        "            if isinstance(inputs, torch.FloatTensor) or isinstance(inputs, torch.cuda.FloatTensor):\n",
        "                return (inputs*255).type(torch.uint8)\n",
        "        elif type == 'float':\n",
        "            if isinstance(inputs, torch.ByteTensor) or isinstance(inputs, torch.cuda.ByteTensor):\n",
        "                return inputs.float()/255\n",
        "        else:\n",
        "            raise ValueError(\n",
        "                type + \" is not a valid type. [Options: float, int]\")\n",
        "        return inputs\n",
        "\n",
        "    @staticmethod\n",
        "    def _save_print(progress, rob_acc, l2, elapsed_time, end):\n",
        "        print('- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \\t'\n",
        "              % (progress, rob_acc, l2, elapsed_time), end=end)\n",
        "\n",
        "    @staticmethod\n",
        "    def load(load_path, batch_size=128, shuffle=False, normalize=None,\n",
        "             load_predictions=False, load_clean_inputs=False):\n",
        "        save_dict = torch.load(load_path)\n",
        "        keys = ['adv_inputs', 'labels']\n",
        "\n",
        "        if load_predictions:\n",
        "            keys.append('preds')\n",
        "        if load_clean_inputs:\n",
        "            keys.append('clean_inputs')\n",
        "\n",
        "        if save_dict['save_type'] == 'int':\n",
        "            save_dict['adv_inputs'] = save_dict['adv_inputs'].float()/255\n",
        "            if load_clean_inputs:\n",
        "                save_dict['clean_inputs'] = save_dict['clean_inputs'].float() / 255  # nopep8\n",
        "\n",
        "        if normalize is not None:\n",
        "            n_channels = len(normalize['mean'])\n",
        "            mean = torch.tensor(normalize['mean']).reshape(1, n_channels, 1, 1)\n",
        "            std = torch.tensor(normalize['std']).reshape(1, n_channels, 1, 1)\n",
        "            save_dict['adv_inputs'] = (save_dict['adv_inputs'] - mean) / std\n",
        "            if load_clean_inputs:\n",
        "                save_dict['clean_inputs'] = (save_dict['clean_inputs'] - mean) / std  # nopep8\n",
        "\n",
        "        adv_data = TensorDataset(*[save_dict[key] for key in keys])\n",
        "        adv_loader = DataLoader(\n",
        "            adv_data, batch_size=batch_size, shuffle=shuffle)\n",
        "        print(\"Data is loaded in the following order: [%s]\" % (\", \".join(keys)))  # nopep8\n",
        "        return adv_loader\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_output_with_eval_nograd(self, inputs):\n",
        "        given_training = self.model.training\n",
        "        if given_training:\n",
        "            self.model.eval()\n",
        "        outputs = self.get_logits(inputs)\n",
        "        if given_training:\n",
        "            self.model.train()\n",
        "        return outputs\n",
        "\n",
        "    def get_target_label(self, inputs, labels=None):\n",
        "        r\"\"\"\n",
        "        Return the target labels for a targeted attack.\n",
        "        In 'targeted(label)' mode the given `labels` are returned unchanged;\n",
        "        otherwise `self._target_map_function(inputs, labels)` computes them.\n",
        "        Raises ValueError if `_target_map_function` is None.\n",
        "        \"\"\"\n",
        "        target_map = self._target_map_function\n",
        "        if target_map is None:\n",
        "            raise ValueError(\n",
        "                'target_map_function is not initialized by set_mode_targeted.')\n",
        "        if self.attack_mode != 'targeted(label)':\n",
        "            return target_map(inputs, labels)\n",
        "        return labels\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_least_likely_label(self, inputs, labels=None):\n",
        "        r\"\"\"\n",
        "        For every sample, return the class with the `self._kth_min`-th smallest\n",
        "        logit among all classes except the sample's label. When `labels` is\n",
        "        None the model's top-1 predictions are used as the labels.\n",
        "        Returns a long tensor on `self.device`.\n",
        "        \"\"\"\n",
        "        logits = self.get_output_with_eval_nograd(inputs)\n",
        "        if labels is None:\n",
        "            _, labels = torch.max(logits, dim=1)\n",
        "        class_ids = range(logits.shape[-1])\n",
        "\n",
        "        target_labels = torch.zeros_like(labels)\n",
        "        for row in range(labels.shape[0]):\n",
        "            # Candidate classes: everything except this sample's label.\n",
        "            candidates = list(class_ids)\n",
        "            candidates.remove(labels[row])\n",
        "            # kthvalue's index refers to a position inside `candidates`.\n",
        "            _, pos = torch.kthvalue(logits[row][candidates], self._kth_min)\n",
        "            target_labels[row] = candidates[pos]\n",
        "\n",
        "        return target_labels.long().to(self.device)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_random_target_label(self, inputs, labels=None):\n",
        "        r\"\"\"\n",
        "        For every sample, return a class drawn uniformly (via `torch.rand`) from\n",
        "        all classes except the sample's label. When `labels` is None the\n",
        "        model's top-1 predictions are used as the labels.\n",
        "        Returns a long tensor on `self.device`.\n",
        "        \"\"\"\n",
        "        logits = self.get_output_with_eval_nograd(inputs)\n",
        "        if labels is None:\n",
        "            _, labels = torch.max(logits, dim=1)\n",
        "        class_ids = range(logits.shape[-1])\n",
        "\n",
        "        target_labels = torch.zeros_like(labels)\n",
        "        for row in range(labels.shape[0]):\n",
        "            candidates = list(class_ids)\n",
        "            candidates.remove(labels[row])\n",
        "            pos = (len(candidates)*torch.rand([1])).long().to(self.device)\n",
        "            target_labels[row] = candidates[pos]\n",
        "\n",
        "        return target_labels.long().to(self.device)\n",
        "\n",
        "    def __call__(self, images, labels=None, *args, **kwargs):\n",
        "        r\"\"\"\n",
        "        Generate adversarial examples for `images`.\n",
        "        Switches the model mode with `_change_model_mode`, validates the inputs,\n",
        "        runs `forward`, validates the outputs, and restores the mode with\n",
        "        `_recover_model_mode`. The mode is restored even if any step raises.\n",
        "        \"\"\"\n",
        "        given_training = self.model.training\n",
        "        self._change_model_mode(given_training)\n",
        "        try:\n",
        "            images = self._check_inputs(images)\n",
        "            adv_images = self.forward(images, labels, *args, **kwargs)\n",
        "            adv_images = self._check_outputs(adv_images)\n",
        "        finally:\n",
        "            self._recover_model_mode(given_training)\n",
        "        return adv_images\n",
        "\n",
        "    def __repr__(self):\n",
        "        r\"\"\"\n",
        "        Return `attack_name(key=value, ...)` built from the instance attributes,\n",
        "        excluding 'model', 'attack', 'supported_mode' and any name starting\n",
        "        with '_', plus `attack_mode` and whether any normalization is in use.\n",
        "        \"\"\"\n",
        "        info = self.__dict__.copy()\n",
        "\n",
        "        hidden = ['model', 'attack', 'supported_mode']\n",
        "        hidden.extend(key for key in info if key[0] == \"_\")\n",
        "        for key in hidden:\n",
        "            del info[key]\n",
        "\n",
        "        info['attack_mode'] = self.attack_mode\n",
        "        info['normalization_used'] = len(self.normalization_used) > 0\n",
        "\n",
        "        fields = ', '.join('{}={}'.format(key, val) for key, val in info.items())\n",
        "        return self.attack + \"(\" + fields + \")\"\n",
        "\n",
        "    def __setattr__(self, name, value):\n",
        "        r\"\"\"\n",
        "        Set the attribute, then register every Attack reachable from `value`\n",
        "        (directly, or nested inside lists/dicts -- keys and values) in\n",
        "        `self._attacks` under `name.N`, together with that attack's own\n",
        "        registered sub-attacks under `name.subname`.\n",
        "        \"\"\"\n",
        "        object.__setattr__(self, name, value)\n",
        "\n",
        "        attacks = self.__dict__.get('_attacks')\n",
        "\n",
        "        # Get all Attack objects in (possibly nested) iterable items.\n",
        "        # `seen` holds ids of visited objects to stop cycles. It is created per\n",
        "        # call: a shared mutable default grew on every assignment, compared\n",
        "        # items with `==` (ambiguous for tensors) and skipped containers that\n",
        "        # only compared equal to ones seen in earlier, unrelated assignments.\n",
        "        def get_all_values(items, seen=None):\n",
        "            if seen is None:\n",
        "                seen = set()\n",
        "            if id(items) not in seen:\n",
        "                seen.add(id(items))\n",
        "                if isinstance(items, list) or isinstance(items, dict):\n",
        "                    if isinstance(items, dict):\n",
        "                        items = (list(items.keys())+list(items.values()))\n",
        "                    for item in items:\n",
        "                        yield from get_all_values(item, seen)\n",
        "                else:\n",
        "                    if isinstance(items, Attack):\n",
        "                        yield items\n",
        "            else:\n",
        "                if isinstance(items, Attack):\n",
        "                    yield items\n",
        "\n",
        "        for num, attack in enumerate(get_all_values(value)):\n",
        "            attacks[name+\".\"+str(num)] = attack\n",
        "            for subname, subvalue in attack.__dict__.get('_attacks').items():\n",
        "                attacks[name+\".\"+subname] = subvalue"
      ],
      "id": "a4a9862a"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36ce6a09"
      },
      "source": [
        "## PGD"
      ],
      "id": "36ce6a09"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1f6aaa80"
      },
      "outputs": [],
      "source": [
        "\n",
        "class PGD_CLS(Attack):\n",
        "    r\"\"\"\n",
        "    Linf PGD attack from 'Towards Deep Learning Models Resistant to Adversarial\n",
        "    Attacks' [https://arxiv.org/abs/1706.06083].\n",
        "\n",
        "    Arguments:\n",
        "        model (nn.Module): model to attack.\n",
        "        eps (float): maximum Linf perturbation. (Default: 8/255)\n",
        "        alpha (float): step size per iteration. (Default: 2/255)\n",
        "        steps (int): number of iterations. (Default: 10)\n",
        "        random_start (bool): start from a uniformly random point in the eps-ball. (Default: True)\n",
        "\n",
        "    Shape:\n",
        "        - images: (N, C, H, W) with values in [0, 1].\n",
        "        - labels: (N) integer class labels.\n",
        "        - output: (N, C, H, W).\n",
        "\n",
        "    Examples::\n",
        "        >>> attack = PGD_CLS(model, eps=8/255, alpha=2/255, steps=10, random_start=True)\n",
        "        >>> adv_images = attack(images, labels)\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True):\n",
        "        super().__init__(\"PGD\", model)\n",
        "        self.eps = eps\n",
        "        self.alpha = alpha\n",
        "        self.steps = steps\n",
        "        self.random_start = random_start\n",
        "        self.supported_mode = ['default', 'targeted']\n",
        "\n",
        "    def forward(self, images, labels):\n",
        "        r\"\"\"\n",
        "        Overridden. Runs `steps` signed-gradient steps of size `alpha` on the\n",
        "        cross-entropy loss, projecting back into the eps-ball around `images`\n",
        "        and into [0, 1] after each step.\n",
        "        \"\"\"\n",
        "        images = images.clone().detach().to(self.device)\n",
        "        labels = labels.clone().detach().to(self.device)\n",
        "\n",
        "        if self.targeted:\n",
        "            target_labels = self.get_target_label(images, labels)\n",
        "\n",
        "        ce_loss = nn.CrossEntropyLoss()\n",
        "        adv = images.clone().detach()\n",
        "\n",
        "        if self.random_start:\n",
        "            noise = torch.empty_like(adv).uniform_(-self.eps, self.eps)\n",
        "            adv = torch.clamp(adv + noise, min=0, max=1).detach()\n",
        "\n",
        "        for _ in range(self.steps):\n",
        "            adv.requires_grad = True\n",
        "            logits = self.get_logits(adv)\n",
        "\n",
        "            # Untargeted: increase the loss on the true labels.\n",
        "            # Targeted: decrease the loss on the target labels.\n",
        "            if self.targeted:\n",
        "                cost = -ce_loss(logits, target_labels)\n",
        "            else:\n",
        "                cost = ce_loss(logits, labels)\n",
        "\n",
        "            grad, = torch.autograd.grad(cost, adv,\n",
        "                                        retain_graph=False, create_graph=False)\n",
        "\n",
        "            stepped = adv.detach() + self.alpha * grad.sign()\n",
        "            delta = torch.clamp(stepped - images, min=-self.eps, max=self.eps)\n",
        "            adv = torch.clamp(images + delta, min=0, max=1).detach()\n",
        "\n",
        "        return adv"
      ],
      "id": "1f6aaa80"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "636ff745"
      },
      "outputs": [],
      "source": [
        "\n",
        "class PGD_TEST(Attack):\n",
        "    r\"\"\"\n",
        "    Linf PGD attack (after 'Towards Deep Learning Models Resistant to\n",
        "    Adversarial Attacks' [https://arxiv.org/abs/1706.06083]) for a classifier\n",
        "    whose extra output `num_classes` is the outlier class.\n",
        "\n",
        "    The objective is the cross-entropy toward the outlier class, weighted -1\n",
        "    for samples labelled in-distribution and +1 for samples labelled\n",
        "    `num_classes`; ascending it pushes in-distribution inputs toward the\n",
        "    outlier class and outlier inputs away from it.\n",
        "\n",
        "    Arguments:\n",
        "        model (nn.Module): model to attack.\n",
        "        eps (float): maximum Linf perturbation. (Default: 8/255)\n",
        "        alpha (float): step size per iteration. (Default: 2/255)\n",
        "        steps (int): number of iterations. (Default: 10)\n",
        "        random_start (bool): start from a uniformly random point in the eps-ball. (Default: True)\n",
        "        num_classes (int): index of the outlier class. (Default: 10)\n",
        "\n",
        "    Shape:\n",
        "        - images: (N, C, H, W) with values in [0, 1].\n",
        "        - labels: (N) integer labels; `num_classes` marks outliers.\n",
        "        - output: (N, C, H, W).\n",
        "\n",
        "    Examples::\n",
        "        >>> attack = PGD_TEST(model, eps=8/255, alpha=2/255, steps=10, num_classes=10)\n",
        "        >>> adv_images = attack(images, labels)\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True, num_classes=10):\n",
        "        super().__init__(\"PGD\", model)\n",
        "        self.eps = eps\n",
        "        self.alpha = alpha\n",
        "        self.steps = steps\n",
        "        self.random_start = random_start\n",
        "        self.supported_mode = ['default', 'targeted']\n",
        "        self.num_classes = num_classes\n",
        "\n",
        "    def forward(self, images, labels):\n",
        "        r\"\"\"\n",
        "        Overridden. Runs `steps` signed-gradient ascent steps on the weighted\n",
        "        outlier-class loss, projecting into the eps-ball around `images` and\n",
        "        into [0, 1] after each step.\n",
        "        \"\"\"\n",
        "        images = images.clone().detach().to(self.device)\n",
        "        labels = labels.clone().detach().to(self.device)\n",
        "\n",
        "        # +1 for samples labelled as outliers, -1 for in-distribution samples.\n",
        "        ones = torch.ones_like(labels)\n",
        "        multipliers = torch.where(labels == self.num_classes, ones, -ones)\n",
        "\n",
        "        # Loop-invariant parts of the objective.\n",
        "        outlier_targets = torch.full_like(labels, self.num_classes)\n",
        "        per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "        adv = images.clone().detach()\n",
        "        if self.random_start:\n",
        "            noise = torch.empty_like(adv).uniform_(-self.eps, self.eps)\n",
        "            adv = torch.clamp(adv + noise, min=0, max=1).detach()\n",
        "\n",
        "        for _ in range(self.steps):\n",
        "            adv.requires_grad = True\n",
        "            logits = self.get_logits(adv)\n",
        "            cost = torch.mean(per_sample_ce(logits, outlier_targets) * multipliers)\n",
        "\n",
        "            grad, = torch.autograd.grad(cost, adv,\n",
        "                                        retain_graph=False, create_graph=False)\n",
        "\n",
        "            stepped = adv.detach() + self.alpha * grad.sign()\n",
        "            delta = torch.clamp(stepped - images, min=-self.eps, max=self.eps)\n",
        "            adv = torch.clamp(images + delta, min=0, max=1).detach()\n",
        "\n",
        "        return adv\n"
      ],
      "id": "636ff745"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9b84987e"
      },
      "source": [
        "# Data"
      ],
      "id": "9b84987e"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "47dd7587"
      },
      "source": [
        "## Datasets"
      ],
      "id": "47dd7587"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ce3ce7bc"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torchvision import transforms\n",
        "from glob import glob\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from torchvision.transforms import functional as F\n",
        "import random\n",
        "\n",
        "class AdaptiveOutliers(torch.utils.data.Dataset):\n",
        "    \"\"\"Outlier-exposure images loaded from a .npy file; every target is 0.\n",
        "\n",
        "    Each array item is converted with `F.to_pil_image` and passed through\n",
        "    `transform` once, here in the constructor, so random augmentations in\n",
        "    `transform` are drawn a single time per image, not per epoch.\n",
        "    NOTE(review): `F.to_pil_image` treats a 3-D tensor as (C, H, W); confirm\n",
        "    the .npy file is stored channels-first.\n",
        "\n",
        "    Args:\n",
        "        filepath: path of the .npy image array.\n",
        "        size: unused.\n",
        "        transform: transform applied to each PIL image.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, filepath='./Generated Outliers/cifar10-glide-ood/cifar10_glide_ood.npy', size=32, transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()])):\n",
        "        images = torch.from_numpy(np.load(filepath))\n",
        "        self.data = [transform(F.to_pil_image(image)) for image in images]\n",
        "        self.targets = np.zeros(len(self.data))\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        return self.data[index], int(self.targets[index])\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "    \n",
        "class MergedDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"In-distribution and out-of-distribution samples in one dataset.\n",
        "\n",
        "    In-distribution items keep their targets; every out-of-distribution item\n",
        "    gets the label `num_classes`. Each item is converted with\n",
        "    `F.to_pil_image`, resized with `transforms.Resize(size)` and returned as a\n",
        "    tensor.\n",
        "\n",
        "    Args:\n",
        "        in_dataset: provides `.data` and `.targets`.\n",
        "        out_dataset: provides `.data`; `len(out_dataset)` OOD labels are created.\n",
        "        num_classes: label assigned to out-of-distribution items.\n",
        "        root: unused.\n",
        "        size: argument to `transforms.Resize`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_dataset, out_dataset, num_classes, root='.', size=32):\n",
        "        self.data = list(in_dataset.data) + list(out_dataset.data)\n",
        "        ood_targets = [num_classes] * len(out_dataset)\n",
        "        self.targets = list(in_dataset.targets) + ood_targets\n",
        "        self.transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        image, target = self.data[index], self.targets[index]\n",
        "        if self.transform:\n",
        "            image = self.transform(F.to_pil_image(image))\n",
        "        return image, int(target)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)"
      ],
      "id": "ce3ce7bc"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## OOD Datasets Loader"
      ],
      "metadata": {
        "id": "3GGrWbpv4Muc"
      },
      "id": "3GGrWbpv4Muc"
    },
    {
      "cell_type": "code",
      "source": [
        "def get_out_testing_datasets(out_name):\n",
        "    \"\"\"Build the OOD test dataset(s) named by `out_name`.\n",
        "\n",
        "    Args:\n",
        "        out_name: one of 'mnist', 'tiny_imagenet', 'places', 'LSUN', 'iSUN',\n",
        "            'birds', 'flowers', 'coil' -- or a list/tuple of such names.\n",
        "\n",
        "    Returns:\n",
        "        For a single name, the torchvision dataset. For a list/tuple, the pair\n",
        "        `(names, datasets)` in matching order, so that\n",
        "        `get_out_testing_datasets([name])[1][0]` is the dataset for `name`.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if a name is not recognised.\n",
        "    \"\"\"\n",
        "    if isinstance(out_name, (list, tuple)):\n",
        "        names = list(out_name)\n",
        "        return names, [get_out_testing_datasets(name) for name in names]\n",
        "\n",
        "    if out_name == 'mnist':\n",
        "        # Grayscale digits: repeat the single channel to get 3-channel 32x32 inputs.\n",
        "        return torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Resize(32),\n",
        "            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),\n",
        "        ]))\n",
        "    elif out_name == 'tiny_imagenet':\n",
        "        return torchvision.datasets.ImageFolder(root='data/tiny-imagenet-200/test', transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(32)]))\n",
        "    elif out_name == 'places':\n",
        "        return torchvision.datasets.Places365(root='data/', split='val', small=True, download=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(32)]))\n",
        "    elif out_name == 'LSUN':\n",
        "        return torchvision.datasets.ImageFolder(root='data/LSUN_resize/', transform=transforms.ToTensor())\n",
        "    elif out_name == 'iSUN':\n",
        "        return torchvision.datasets.ImageFolder(root='data/iSUN/', transform=transforms.ToTensor())\n",
        "    elif out_name == 'birds':\n",
        "        # NOTE(review): `bird_loader` must be defined elsewhere in the notebook.\n",
        "        return torchvision.datasets.ImageFolder(root='data/images/', loader=bird_loader, transform=transforms.ToTensor())\n",
        "    elif out_name == 'flowers':\n",
        "        # NOTE(review): `flower_loader` must be defined elsewhere in the notebook.\n",
        "        return torchvision.datasets.ImageFolder(root='data/flowers/', loader=flower_loader, transform=transforms.ToTensor())\n",
        "    elif out_name == 'coil':\n",
        "        return torchvision.datasets.ImageFolder(root='data/coil/', transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(32)]))\n",
        "    else:\n",
        "        raise ValueError(\"Invalid OOD Dataset\")"
      ],
      "metadata": {
        "id": "qwxSJfEd4P1M"
      },
      "id": "qwxSJfEd4P1M",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CIFAR100 superclasses"
      ],
      "metadata": {
        "id": "Nn1KV2er4RwT"
      },
      "id": "Nn1KV2er4RwT"
    },
    {
      "cell_type": "code",
      "source": [
        "def sparse2coarse(targets):\n",
        "    \"\"\"Convert Pytorch CIFAR100 sparse targets to coarse targets.\n",
        "    Usage:\n",
        "        trainset = torchvision.datasets.CIFAR100(path)\n",
        "        trainset.targets = sparse2coarse(trainset.targets)\n",
        "    \"\"\"\n",
        "    coarse_labels = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3,\n",
        "                              3, 14, 9, 18, 7, 11, 3, 9, 7, 11,\n",
        "                              6, 11, 5, 10, 7, 6, 13, 15, 3, 15,\n",
        "                              0, 11, 1, 10, 12, 14, 16, 9, 11, 5,\n",
        "                              5, 19, 8, 8, 15, 13, 14, 17, 18, 10,\n",
        "                              16, 4, 17, 4, 2, 0, 17, 4, 18, 17,\n",
        "                              10, 3, 2, 12, 12, 16, 12, 1, 9, 19,\n",
        "                              2, 10, 0, 1, 16, 12, 9, 13, 15, 13,\n",
        "                              16, 19, 2, 4, 6, 19, 5, 5, 8, 19,\n",
        "                              18, 1, 2, 15, 6, 0, 17, 8, 14, 13])\n",
        "    return coarse_labels[targets]"
      ],
      "metadata": {
        "id": "Jgd5odFa4WIz"
      },
      "id": "Jgd5odFa4WIz",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EctvGJQBNNhr"
      },
      "source": [
        "## DataLoaders"
      ],
      "id": "EctvGJQBNNhr"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c18e8a01"
      },
      "outputs": [],
      "source": [
        "import torchvision\n",
        "from torchvision.datasets import CIFAR10, CIFAR100, MNIST\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import transforms\n",
        "import random\n",
        "\n",
        "# Name -> torchvision dataset class, looked up with `in_dataset` / `out_dataset`.\n",
        "dataset_class = dict(\n",
        "    cifar10=CIFAR10,\n",
        "    cifar100=CIFAR100,\n",
        "    mnist=MNIST,\n",
        ")\n",
        "\n",
        "def get_loaders():\n",
        "    \"\"\"Build the shuffled train and test DataLoaders.\n",
        "\n",
        "    Uses the notebook-level settings `in_dataset`, `out_dataset`,\n",
        "    `num_classes` and `batch_size`.\n",
        "\n",
        "    Train: the `in_dataset` train split merged with the adaptive outliers\n",
        "    (augmented by `exposure_transformations`). Test: the `in_dataset` test\n",
        "    split merged with the `out_dataset` test set. Outliers are labelled\n",
        "    `num_classes`; for 'cifar100' the fine labels become the 20 superclasses.\n",
        "\n",
        "    Returns:\n",
        "        (trainloader, testloader)\n",
        "    \"\"\"\n",
        "    in_cls = dataset_class[in_dataset]\n",
        "    use_superclasses = in_dataset == 'cifar100'\n",
        "\n",
        "    # --- training data ---\n",
        "    in_train_dataset = in_cls(root='.', download=True, train=True)\n",
        "    if use_superclasses:\n",
        "        in_train_dataset.targets = sparse2coarse(in_train_dataset.targets)\n",
        "\n",
        "    exposure_transformations = transforms.Compose([\n",
        "        transforms.Resize([32, 32]),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.RandomGrayscale(),\n",
        "        transforms.RandomChoice([\n",
        "            transforms.RandomApply([transforms.RandomAffine(90, translate=(0.15, 0.15), scale=(0.85, 1), shear=None)], p=0.6),\n",
        "            transforms.RandomApply([transforms.RandomAffine(0, translate=None, scale=(0.5, 0.75), shear=30)], p=0.6),\n",
        "            transforms.RandomApply([transforms.AutoAugment()], p=0.9),\n",
        "        ]),\n",
        "        transforms.ToTensor(),\n",
        "    ])\n",
        "    out_train_dataset = AdaptiveOutliers(transform=exposure_transformations)\n",
        "    train_dataset = MergedDataset(in_train_dataset, out_train_dataset, num_classes=num_classes)\n",
        "\n",
        "    # --- test data ---\n",
        "    in_test_dataset = in_cls(root='.', download=True, train=False)\n",
        "    in_test_dataset.data = list(in_test_dataset.data)\n",
        "    if use_superclasses:\n",
        "        in_test_dataset.targets = sparse2coarse(in_test_dataset.targets)\n",
        "\n",
        "    if out_dataset in ['cifar10', 'cifar100', 'mnist']:\n",
        "        if out_dataset == 'mnist':\n",
        "            out_transform = transforms.Compose([transforms.Grayscale(3), transforms.Resize(32), transforms.ToTensor()])\n",
        "        else:\n",
        "            out_transform = transforms.ToTensor()\n",
        "        out_test_dataset = dataset_class[out_dataset](root='.', download=True, train=False, transform=out_transform)\n",
        "    else:\n",
        "        out_test_dataset = get_out_testing_datasets([out_dataset])[1][0]\n",
        "\n",
        "    # Materialise the transformed OOD images so MergedDataset can read `.data`.\n",
        "    out_test_dataset.data = [image for image, _ in out_test_dataset]\n",
        "\n",
        "    test_dataset = MergedDataset(in_test_dataset, out_test_dataset, num_classes=num_classes)\n",
        "\n",
        "    trainloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)\n",
        "    testloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)\n",
        "\n",
        "    print(\"Length of In train dataset:\", len(in_train_dataset))\n",
        "    print(\"Length of Out train dataset:\", len(out_train_dataset))\n",
        "    print(\"Length of In test dataset:\", len(in_test_dataset))\n",
        "    print(\"Length of Out test dataset:\", len(out_test_dataset))\n",
        "    print(f\"Length of train dataset: {len(train_dataset)}\")\n",
        "    print(f\"Length of test dataset: {len(test_dataset)}\")\n",
        "\n",
        "    return trainloader, testloader"
      ],
      "id": "c18e8a01"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "96b77c1b"
      },
      "source": [
        "# Utils"
      ],
      "id": "96b77c1b"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3c92884e"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "from sklearn.metrics import roc_auc_score, accuracy_score\n",
        "\n",
        "def _evaluate_ood(model, test_loader, device, num_classes, attack=None):\n",
        "    \"\"\"Shared evaluation loop behind `auc_softmax` and `auc_softmax_adversarial`.\n",
        "\n",
        "    Runs `model` in eval mode over `test_loader`; when `attack` is given each\n",
        "    batch is first replaced by `attack(data, target)`. Returns `(auc, accuracy)`:\n",
        "      - auc: ROC-AUC of the softmax probability of class `num_classes` used as\n",
        "        the OOD score, against the binary labels `target == num_classes`;\n",
        "      - accuracy: fraction of samples whose argmax prediction equals their\n",
        "        (num_classes + 1)-way target, the same definition as the training accuracy.\n",
        "    The model's original train/eval mode is restored afterwards.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    anomaly_scores, preds, ood_labels, true_labels = [], [], [], []\n",
        "    try:\n",
        "        with tqdm(test_loader, unit=\"batch\") as tepoch:\n",
        "            torch.cuda.empty_cache()\n",
        "            for data, target in tepoch:\n",
        "                data, target = data.to(device), target.to(device)\n",
        "                if attack is not None:\n",
        "                    # The attack needs input gradients, so it runs outside no_grad.\n",
        "                    data = attack(data, target)\n",
        "                with torch.no_grad():\n",
        "                    output = model(data)\n",
        "\n",
        "                # No .squeeze(): a final batch of one sample must keep its batch dim.\n",
        "                preds += output.argmax(dim=1).cpu().tolist()\n",
        "                anomaly_scores += torch.softmax(output, dim=1)[:, num_classes].cpu().tolist()\n",
        "                true_labels += target.cpu().tolist()\n",
        "                ood_labels += (target == num_classes).long().cpu().tolist()\n",
        "    finally:\n",
        "        model.train(was_training)\n",
        "\n",
        "    auc = roc_auc_score(ood_labels, anomaly_scores)\n",
        "    # Compare predictions with the full labels; the binary OOD labels are not class ids.\n",
        "    accuracy = accuracy_score(true_labels, preds, normalize=True)\n",
        "    return auc, accuracy\n",
        "\n",
        "def auc_softmax_adversarial(model, test_loader, test_attack, epoch:int, device, num_classes):\n",
        "    \"\"\"ROC-AUC and accuracy on `test_loader` inputs perturbed by `test_attack`.\n",
        "\n",
        "    `epoch` is not used. See `_evaluate_ood` for the metric definitions.\n",
        "    \"\"\"\n",
        "    return _evaluate_ood(model, test_loader, device, num_classes, attack=test_attack)\n",
        "\n",
        "def auc_softmax(model, test_loader, epoch:int, device, num_classes):\n",
        "    \"\"\"ROC-AUC and accuracy on the clean `test_loader` inputs.\n",
        "\n",
        "    `epoch` is not used. See `_evaluate_ood` for the metric definitions.\n",
        "    \"\"\"\n",
        "    return _evaluate_ood(model, test_loader, device, num_classes)\n",
        "\n",
        "lr_schedule = lambda learning_rate, t, max_epochs: np.interp([t], [0, max_epochs // 3, max_epochs * 2 // 3, max_epochs], [learning_rate, learning_rate/10, learning_rate / 100, 0])[0]\n",
        "    \n",
        "    \n",
        "def run(model, train_attack, test_attack, trainloader, testloader, test_step:int, max_epochs:int, device, loss_threshold=1e-3, num_classes=10, lr=0.01, optim=None):\n",
        "    \"\"\"Adversarially train `model` and periodically evaluate clean/adversarial AUC.\n",
        "\n",
        "    Each epoch runs `train_one_epoch` and saves `./{epoch}_model_{in_dataset}.pt`.\n",
        "    Every `test_step` epochs (starting with epoch 0) the clean and `test_attack`\n",
        "    AUC/accuracy on `testloader` are computed and printed. Training stops early\n",
        "    once the epoch's training loss is below `loss_threshold`; a final checkpoint\n",
        "    is saved to `./last_model_{in_dataset}.pt`.\n",
        "\n",
        "    Args:\n",
        "        optim: 'adam' or 'sgd' (both use weight decay 5e-4; SGD uses momentum 0.9).\n",
        "\n",
        "    Returns:\n",
        "        (clean_aucs, adv_aucs): one entry per evaluated epoch.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `optim` is neither 'adam' nor 'sgd'.\n",
        "    \"\"\"\n",
        "    if optim == \"adam\":\n",
        "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)\n",
        "    elif optim == \"sgd\":\n",
        "        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)\n",
        "    else:\n",
        "        # Fail fast instead of an UnboundLocalError on `optimizer` later.\n",
        "        raise ValueError(f\"Unsupported optimizer {optim!r}; expected 'adam' or 'sgd'\")\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    init_epoch = 0\n",
        "\n",
        "    clean_aucs = []\n",
        "    adv_aucs = []\n",
        "\n",
        "    print(f'Starting Run from epoch {init_epoch}')\n",
        "\n",
        "    train_loss = 0\n",
        "\n",
        "    for epoch in range(init_epoch, max_epochs):\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "        print(f'====== Starting Training on epoch {epoch}')\n",
        "        train_accuracy, train_loss = train_one_epoch(epoch=epoch,\n",
        "                                                     max_epochs=max_epochs,\n",
        "                                                     model=model,\n",
        "                                                     optimizer=optimizer,\n",
        "                                                     criterion=criterion,\n",
        "                                                     trainloader=trainloader,\n",
        "                                                     train_attack=train_attack,\n",
        "                                                     lr=lr,\n",
        "                                                     device=device)\n",
        "\n",
        "        print(\"train accuracy is \", train_accuracy)\n",
        "        print(\"train loss is \", train_loss)\n",
        "\n",
        "        save_model_checkpoint(model, train_loss, f'./{epoch}_model_{in_dataset}.pt', optimizer)\n",
        "\n",
        "        if epoch % test_step == 0:\n",
        "            print(f'AUC & Accuracy Vanila - Started...')\n",
        "            clean_auc, clean_accuracy = auc_softmax(model=model, epoch=epoch, test_loader=testloader, device=device, num_classes=num_classes)\n",
        "            print(f'AUC Vanila - score on epoch {epoch} is: {clean_auc * 100}')\n",
        "            print(f'Accuracy Vanila -  score on epoch {epoch} is: {clean_accuracy * 100}')\n",
        "\n",
        "            attack_name = 'PGD-10'\n",
        "            print(f'AUC & Accuracy Adversarial - {attack_name} - Started...')\n",
        "            adv_auc, adv_accuracy = auc_softmax_adversarial(model=model, epoch=epoch, test_loader=testloader, test_attack=test_attack, device=device, num_classes=num_classes)\n",
        "            print(f'AUC Adversairal {attack_name} - score on epoch {epoch} is: {adv_auc * 100}')\n",
        "            print(f'Accuracy Adversairal {attack_name} -  score on epoch {epoch} is: {adv_accuracy * 100}')\n",
        "\n",
        "            # Record scores only when they were computed this epoch: appending\n",
        "            # outside this branch repeated stale values on non-test epochs and\n",
        "            # dropped the last evaluation when training stopped early.\n",
        "            clean_aucs.append(clean_auc)\n",
        "            adv_aucs.append(adv_auc)\n",
        "\n",
        "        if train_loss < loss_threshold:\n",
        "            break\n",
        "\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "    save_model_checkpoint(model, train_loss, f'./last_model_{in_dataset}.pt', optimizer)\n",
        "\n",
        "    return clean_aucs, adv_aucs\n",
        "\n",
        "\n",
        "\n",
        "def train_one_epoch(epoch, max_epochs, model, optimizer, criterion, trainloader, train_attack, lr, device):\n",
        "    \"\"\"Run one epoch of adversarial training; return (accuracy, mean loss).\n",
        "\n",
        "    Each batch is perturbed with ``train_attack`` and the model is updated on\n",
        "    the perturbed batch only. Before every step the learning rate of\n",
        "    ``optimizer.param_groups[0]`` is set by ``lr_schedule`` evaluated at the\n",
        "    fractional epoch ``epoch + (i + 1) / len(trainloader)``.\n",
        "\n",
        "    Args:\n",
        "        epoch: zero-based index of the current epoch (displayed 1-based).\n",
        "        max_epochs: total number of epochs, forwarded to ``lr_schedule``.\n",
        "        model: network being trained; switched to train mode here.\n",
        "        optimizer: optimizer stepped once per batch.\n",
        "        criterion: loss called as ``criterion(logits, target)``.\n",
        "        trainloader: iterable of ``(data, target)`` batches.\n",
        "        train_attack: callable ``(data, target) -> perturbed data``.\n",
        "        lr: base learning rate passed to ``lr_schedule``.\n",
        "        device: device that inputs and labels are moved to.\n",
        "\n",
        "    Returns:\n",
        "        ``(accuracy, loss)``: ``accuracy_score`` over all adversarial batches,\n",
        "        and the sum of ``loss.item() * batch_size`` divided by the number of\n",
        "        samples (the per-sample mean if ``criterion`` averages over the batch).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``trainloader`` yields no batches.\n",
        "    \"\"\"\n",
        "    preds = []\n",
        "    true_labels = []\n",
        "    running_loss = 0\n",
        "    correct = 0\n",
        "\n",
        "    model.train()\n",
        "    with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
        "        torch.cuda.empty_cache()\n",
        "        for i, (data, target) in enumerate(tepoch):\n",
        "            tepoch.set_description(f\"Epoch {epoch + 1}/{max_epochs}\")\n",
        "            updated_lr = lr_schedule(learning_rate=lr, t=epoch + (i + 1) / len(tepoch), max_epochs=max_epochs)\n",
        "            optimizer.param_groups[0].update(lr=updated_lr)\n",
        "\n",
        "            # Honor `device` for labels too (was a hard-coded .cuda(), which\n",
        "            # mismatched `data` whenever device was not the default GPU).\n",
        "            data = data.to(device)\n",
        "            target = target.to(device=device, dtype=torch.long)\n",
        "\n",
        "            # Adversarial attack on every batch\n",
        "            data = train_attack(data, target)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            output = model(data)\n",
        "            loss = criterion(output, target)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            # argmax(dim=1) stays 1-D for a batch of one; the previous\n",
        "            # keepdim=True + squeeze() gave a 0-d tensor whose tolist() is a\n",
        "            # plain int, which made `preds += ...` raise TypeError.\n",
        "            batch_preds = output.detach().argmax(dim=1)\n",
        "            correct += (batch_preds == target).sum().item()\n",
        "            preds += batch_preds.cpu().tolist()\n",
        "            true_labels += target.cpu().tolist()\n",
        "\n",
        "            running_loss += loss.item() * data.size(0)\n",
        "\n",
        "            # Running counters instead of re-comparing every past prediction\n",
        "            # on each batch (that was O(n^2) over the epoch).\n",
        "            tepoch.set_postfix(loss=running_loss / len(preds), accuracy=100. * correct / len(preds))\n",
        "\n",
        "    if not preds:\n",
        "        raise ValueError('trainloader yielded no batches')\n",
        "\n",
        "    return accuracy_score(true_labels, preds, normalize=True), \\\n",
        "            running_loss / len(preds)\n",
        "\n",
        "\n",
        "def save_model_checkpoint(model, loss, path, optimizer, epoch=None):\n",
        "    \"\"\"Save model/optimizer state dicts and ``loss`` to ``path`` via ``torch.save``.\n",
        "\n",
        "    Args:\n",
        "        model: module whose ``state_dict()`` is stored under 'model_state_dict'.\n",
        "        loss: value stored under 'loss'.\n",
        "        path: destination file path.\n",
        "        optimizer: optimizer whose ``state_dict()`` is stored under 'optimizer_state_dict'.\n",
        "        epoch: optional; stored under 'epoch' when not None, so a loader that\n",
        "            reads that key can recover it.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if building or writing the checkpoint fails; the original\n",
        "            exception is chained as ``__cause__``.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        checkpoint = {\n",
        "                'model_state_dict': model.state_dict(),\n",
        "                'optimizer_state_dict': optimizer.state_dict(),\n",
        "                'loss': loss,\n",
        "        }\n",
        "        if epoch is not None:\n",
        "            checkpoint['epoch'] = epoch\n",
        "        torch.save(checkpoint, path)\n",
        "    except Exception as e:\n",
        "        # `except Exception` (not bare) lets KeyboardInterrupt through, and\n",
        "        # `from e` keeps the real cause visible in the traceback.\n",
        "        raise ValueError('Saving model checkpoint failed!') from e\n",
        "\n",
        "def load_model_checkpoint(model, optimizer, path):\n",
        "    \"\"\"Restore model and optimizer state from a checkpoint file.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, optimizer, epoch, loss)`` on success. ``epoch`` and ``loss``\n",
        "        are None when the checkpoint has no such entry; checkpoints saved\n",
        "        without an 'epoch' key (as ``save_model_checkpoint`` produces by\n",
        "        default) previously made every load fail. Returns None if the file\n",
        "        cannot be loaded or lacks the state dicts; a warning reports why\n",
        "        instead of failing silently.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        checkpoint = torch.load(path)\n",
        "        model.load_state_dict(checkpoint['model_state_dict'])\n",
        "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "    except Exception as e:\n",
        "        import warnings\n",
        "        warnings.warn(f'Loading model checkpoint from {path} failed: {e!r}')\n",
        "        return None\n",
        "    return model, optimizer, checkpoint.get('epoch'), checkpoint.get('loss')\n"
      ],
      "id": "3c92884e"
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Visualization"
      ],
      "metadata": {
        "id": "mFnChyAKI_Ba"
      },
      "id": "mFnChyAKI_Ba"
    },
    {
      "cell_type": "code",
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "def visualize_samples(dataloader, n, title=\"Sample\"):\n",
        "    \"\"\"Plot up to n*n normal (label 0) and n*n abnormal (label != 0) images side by side.\n",
        "\n",
        "    Reads ``dataloader`` until both grids are full or it is exhausted.\n",
        "    Single-channel images are repeated to three channels before tiling with\n",
        "    ``torchvision.utils.make_grid``. A class with no samples gets a blank panel\n",
        "    with a note rather than crashing ``make_grid`` on an empty list.\n",
        "\n",
        "    Args:\n",
        "        dataloader: iterable of ``(images, labels)`` batches; images are\n",
        "            assumed channel-first CPU tensors in a range imshow accepts -- confirm.\n",
        "        n: grid side length; at most n*n images are shown per class.\n",
        "        title: figure title.\n",
        "    \"\"\"\n",
        "    limit = n * n\n",
        "    normal_samples = []\n",
        "    abnormal_samples = []\n",
        "\n",
        "    def to_3_channels(image):\n",
        "        if image.shape[0] == 1:\n",
        "            return image.repeat(3, 1, 1)\n",
        "        return image\n",
        "\n",
        "    def is_full():\n",
        "        return len(normal_samples) == limit and len(abnormal_samples) == limit\n",
        "\n",
        "    for images, labels in dataloader:\n",
        "        for image, label in zip(images, labels):\n",
        "            if len(normal_samples) < limit and label == 0:\n",
        "                normal_samples.append(to_3_channels(image))\n",
        "            elif len(abnormal_samples) < limit and label != 0:\n",
        "                abnormal_samples.append(to_3_channels(image))\n",
        "            if is_full():\n",
        "                break\n",
        "        if is_full():\n",
        "            break\n",
        "\n",
        "    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))\n",
        "    fig.patch.set_alpha(0)\n",
        "    fig.suptitle(title, fontsize=16)\n",
        "\n",
        "    panels = ((axs[0], normal_samples, 'Normal'), (axs[1], abnormal_samples, 'Abnormal'))\n",
        "    for ax, samples, panel_title in panels:\n",
        "        if samples:\n",
        "            grid = torchvision.utils.make_grid(samples, nrow=n)\n",
        "            ax.imshow(grid.permute(1, 2, 0))\n",
        "        else:\n",
        "            ax.text(0.5, 0.5, 'No samples', ha='center', va='center')\n",
        "        ax.set_title(panel_title, fontsize=14)\n",
        "        ax.axis('off')\n",
        "\n",
        "    plt.show()"
      ],
      "metadata": {
        "id": "6_VE6ssFJCDy"
      },
      "id": "6_VE6ssFJCDy",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6d657bcc"
      },
      "source": [
        "# Training"
      ],
      "id": "6d657bcc"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b044bb07"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Choose the device once, before the model is moved, so the model and every\n",
        "# later use of `device` agree. Previously the model went to a hard-coded\n",
        "# 'cuda:0' and `device` was only made CPU-aware afterwards.\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "trainloader, testloader = get_loaders()\n",
        "model = WideResNet(40, num_classes+1, 4,  dropRate=0.0).to(device)\n",
        "\n",
        "# 10-step attacks: one used for training, one passed to run() for evaluation.\n",
        "train_attack1 = PGD_CLS(model, eps=attack_eps, steps=10, alpha=attack_alpha)\n",
        "test_attack = PGD_TEST(model, eps=attack_eps, steps=10, alpha=attack_alpha, num_classes=num_classes)\n",
        "\n",
        "clean_aucs, adv_aucs = run(model, train_attack1, test_attack, trainloader, testloader, 1, epochs, device, loss_threshold=1e-3, num_classes=num_classes,lr=lr, optim=optim)"
      ],
      "id": "b044bb07"
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Testing"
      ],
      "metadata": {
        "id": "mjTkKpnYCZvT"
      },
      "id": "mjTkKpnYCZvT"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4204da52"
      },
      "outputs": [],
      "source": [
        "# Final evaluation uses a stronger attack (more PGD steps) than training.\n",
        "TEST_ATTACK_STEPS = 100\n",
        "attack = PGD_TEST(model, eps=attack_eps, steps=TEST_ATTACK_STEPS, alpha=attack_alpha, num_classes=num_classes)\n",
        "attack_name = f'PGD-{TEST_ATTACK_STEPS}'\n",
        "\n",
        "adv_sm = 0\n",
        "clean_sm = 0\n",
        "n_evaluated = 0\n",
        "\n",
        "for out_dataset in avialable_datasets:\n",
        "\n",
        "    if out_dataset == in_dataset:\n",
        "        continue\n",
        "\n",
        "    # NOTE(review): get_loaders() takes no arguments, so it presumably reads the\n",
        "    # global `out_dataset` set by this loop -- confirm.\n",
        "    trainloader, testloader = get_loaders()\n",
        "\n",
        "    visualize_samples(testloader, 8, out_dataset)\n",
        "\n",
        "    clean_auc, clean_accuracy  = auc_softmax(model=model, epoch=epochs, test_loader=testloader, device=device, num_classes=num_classes)\n",
        "    print(f\"Clean AUC for (In={in_dataset}) and (Out={out_dataset}) is {int(clean_auc * 10000)/100}\")\n",
        "    adv_auc, adv_accuracy = auc_softmax_adversarial(model=model, epoch=epochs, test_loader=testloader, test_attack=attack, device=device, num_classes=num_classes)\n",
        "    # Label reflects the real step count (was hard-coded 'PGD-10' for a 100-step attack).\n",
        "    print(f\"{attack_name} Adversarial AUC for (In={in_dataset}) and (Out={out_dataset}) is {int(adv_auc * 10000) / 100}\")\n",
        "    adv_sm += adv_auc\n",
        "    clean_sm += clean_auc\n",
        "    n_evaluated += 1\n",
        "\n",
        "# Average over the datasets actually evaluated; `len(avialable_datasets) - 1`\n",
        "# was wrong whenever in_dataset was absent from (or repeated in) the list.\n",
        "if n_evaluated:\n",
        "    print(\"Average mean of clean AUC:\", clean_sm / n_evaluated)\n",
        "    print(\"Average mean of Adversarial AUC:\", adv_sm / n_evaluated)\n",
        "else:\n",
        "    print(\"No out-of-distribution datasets were evaluated.\")"
      ],
      "id": "4204da52"
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "WMTrhdZi4nIM",
        "K_SchgWV4qQ9",
        "Ejrqp7GUG-vz",
        "3fda854f",
        "44925dca",
        "86692533",
        "16e9cfbc",
        "d0cdf735",
        "47dd7587",
        "3GGrWbpv4Muc",
        "Nn1KV2er4RwT",
        "EctvGJQBNNhr",
        "mFnChyAKI_Ba"
      ],
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.12"
    },
    "papermill": {
      "default_parameters": {},
      "duration": 15695.269558,
      "end_time": "2023-05-14T17:24:48.838846",
      "environment_variables": {},
      "exception": null,
      "input_path": "__notebook__.ipynb",
      "output_path": "__notebook__.ipynb",
      "parameters": {},
      "start_time": "2023-05-14T13:03:13.569288",
      "version": "2.4.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}