{"cells":[{"cell_type":"markdown","metadata":{"id":"bU_TDFG-0jx4"},"source":["# Preparation"]},{"cell_type":"markdown","metadata":{"id":"R34pJ7an0pXz"},"source":["## Install Dependencies"]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:52:10.638316Z","iopub.status.busy":"2023-05-24T19:52:10.638024Z","iopub.status.idle":"2023-05-24T19:52:24.538756Z","shell.execute_reply":"2023-05-24T19:52:24.537506Z","shell.execute_reply.started":"2023-05-24T19:52:10.638291Z"},"id":"WJUVw7gG2ics","trusted":true},"outputs":[],"source":["%%capture\n","%pip install gdown"]},{"cell_type":"markdown","metadata":{"id":"qQFhV1tP6Gou"},"source":["## Download Generated Outliers"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:52:24.541460Z","iopub.status.busy":"2023-05-24T19:52:24.541167Z","iopub.status.idle":"2023-05-24T19:53:00.281121Z","shell.execute_reply":"2023-05-24T19:53:00.280073Z","shell.execute_reply.started":"2023-05-24T19:52:24.541434Z"},"id":"AbTeiWZf0s0L","trusted":true},"outputs":[],"source":["import gdown\n","url = \"https://drive.google.com/drive/folders/1zXEwDiB1EUhG4-n4x3LYhbffOKYCbiEc?usp=sharing\"\n","gdown.download_folder(url, quiet=True, use_cookies=False)"]},{"cell_type":"markdown","metadata":{"id":"xljjSpsRzADQ"},"source":["# Use 
CUDA"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:00.283183Z","iopub.status.busy":"2023-05-24T19:53:00.282345Z","iopub.status.idle":"2023-05-24T19:53:01.286738Z","shell.execute_reply":"2023-05-24T19:53:01.285642Z","shell.execute_reply.started":"2023-05-24T19:53:00.283143Z"},"id":"snXeoPIUzB9t","trusted":true},"outputs":[],"source":["!nvidia-smi"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:01.290684Z","iopub.status.busy":"2023-05-24T19:53:01.290300Z","iopub.status.idle":"2023-05-24T19:53:04.361114Z","shell.execute_reply":"2023-05-24T19:53:04.360138Z","shell.execute_reply.started":"2023-05-24T19:53:01.290647Z"},"id":"gC3uwQLMzDE9","trusted":true},"outputs":[],"source":["import torch\n","device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print('Using device:', device)"]},{"cell_type":"markdown","metadata":{"id":"fc44c55a","papermill":{"duration":0.00891,"end_time":"2023-04-23T16:48:37.807071","exception":false,"start_time":"2023-04-23T16:48:37.798161","status":"completed"},"tags":[]},"source":["# Configurations"]},{"cell_type":"code","execution_count":5,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-05-24T19:53:04.363748Z","iopub.status.busy":"2023-05-24T19:53:04.362731Z","iopub.status.idle":"2023-05-24T19:53:04.369593Z","shell.execute_reply":"2023-05-24T19:53:04.368497Z","shell.execute_reply.started":"2023-05-24T19:53:04.363708Z"},"id":"892b9031","papermill":{"duration":0.023102,"end_time":"2023-04-23T16:48:37.837607","exception":false,"start_time":"2023-04-23T16:48:37.814505","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["no_trials = 5\n","in_dataset = 'MNIST'\n","fakes_paths_reg = f'./Generated Outliers/{in_dataset.lower()}-glide-osr/*.npy'\n","attack_eps = 8/255\n","attack_steps 
= 10\n","attack_alpha = 2.5 * attack_eps / attack_steps\n","num_classes = 6\n","batch_size = 128"]},{"cell_type":"markdown","metadata":{"id":"f974ece4","papermill":{"duration":0.014191,"end_time":"2023-04-23T16:50:00.592373","exception":false,"start_time":"2023-04-23T16:50:00.578182","status":"completed"},"tags":[]},"source":["# Model"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.373071Z","iopub.status.busy":"2023-05-24T19:53:04.372491Z","iopub.status.idle":"2023-05-24T19:53:04.396165Z","shell.execute_reply":"2023-05-24T19:53:04.395520Z","shell.execute_reply.started":"2023-05-24T19:53:04.373039Z"},"id":"59cab483","papermill":{"duration":2.413818,"end_time":"2023-04-23T16:50:03.020519","exception":false,"start_time":"2023-04-23T16:50:00.606701","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import math\n","import torch\n","import torch.nn as nn\n","\n","\n","class BasicBlock(nn.Module):\n","    \"\"\"Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3, plus a skip connection.\"\"\"\n","    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):\n","        super(BasicBlock, self).__init__()\n","        self.bn1 = nn.BatchNorm2d(in_planes)\n","        self.relu1 = nn.ReLU(inplace=True)\n","        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n","                               padding=1, bias=False)\n","        self.bn2 = nn.BatchNorm2d(out_planes)\n","        self.relu2 = nn.ReLU(inplace=True)\n","        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,\n","                               padding=1, bias=False)\n","        self.droprate = dropRate\n","        self.equalInOut = (in_planes == out_planes)\n","        # A 1x1 projection on the skip path is only needed when the channel count changes.\n","        self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,\n","                                      padding=0, bias=False) if not self.equalInOut else None\n","    def forward(self, x):\n","        if not self.equalInOut:\n","            x = self.relu1(self.bn1(x))\n","        
else:\n","            out = self.relu1(self.bn1(x))\n","        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n","        if self.droprate > 0:\n","            out = torch.nn.functional.dropout(out, p=self.droprate, training=self.training)\n","        out = self.conv2(out)\n","        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n","\n","class NetworkBlock(nn.Module):\n","    \"\"\"Stack of nb_layers blocks; only the first one changes the channel count and applies `stride`.\"\"\"\n","    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):\n","        super(NetworkBlock, self).__init__()\n","        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)\n","    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):\n","        layers = []\n","        for i in range(int(nb_layers)):\n","            layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropRate))\n","        return nn.Sequential(*layers)\n","    def forward(self, x):\n","        return self.layer(x)\n","\n","class WideResNet(nn.Module):\n","    \"\"\"Wide ResNet: stem conv, three NetworkBlocks (strides 1, 2, 2), BN-ReLU, global average pooling, linear head.\n","\n","    depth must satisfy (depth - 4) % 6 == 0; in_channels (default 3) is the number of input image channels.\n","    \"\"\"\n","    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, in_channels=3):\n","        super(WideResNet, self).__init__()\n","        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]\n","        assert((depth - 4) % 6 == 0)\n","        n = (depth - 4) // 6\n","        block = BasicBlock\n","        # 1st conv before any network block\n","        self.conv1 = nn.Conv2d(in_channels, nChannels[0], kernel_size=3, stride=1,\n","                               padding=1, bias=False)\n","        # 1st block\n","        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)\n","        # 2nd block\n","        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)\n","        # 3rd block\n","        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)\n","        # global average pooling and classifier\n","        self.bn1 = 
nn.BatchNorm2d(nChannels[3])\n","        self.relu = nn.ReLU(inplace=True)\n","        self.fc = nn.Linear(nChannels[3], num_classes)\n","        self.nChannels = nChannels[3]\n","\n","        for m in self.modules():\n","            if isinstance(m, nn.Conv2d):\n","                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n","            elif isinstance(m, nn.BatchNorm2d):\n","                m.weight.data.fill_(1)\n","                m.bias.data.zero_()\n","            elif isinstance(m, nn.Linear):\n","                m.bias.data.zero_()\n","    def forward(self, x):\n","        out = self.conv1(x)\n","        out = self.block1(out)\n","        out = self.block2(out)\n","        out = self.block3(out)\n","        out = self.relu(self.bn1(out))\n","        # Pool over whatever spatial size remains (8x8 for 32x32 inputs, identical to avg_pool2d(out, 8) there);\n","        # a fixed 8x8 kernel errors on smaller maps and, on larger ones, view(-1, C) folded spatial cells into the batch.\n","        out = torch.nn.functional.adaptive_avg_pool2d(out, 1)\n","        out = torch.flatten(out, 1)\n","        return self.fc(out)"]},{"cell_type":"markdown","metadata":{"id":"64e94984","papermill":{"duration":0.033855,"end_time":"2023-04-23T16:50:03.092172","exception":false,"start_time":"2023-04-23T16:50:03.058317","status":"completed"},"tags":[]},"source":["# Attack"]},{"cell_type":"markdown","metadata":{"id":"2caba216","papermill":{"duration":0.013859,"end_time":"2023-04-23T16:50:03.121315","exception":false,"start_time":"2023-04-23T16:50:03.107456","status":"completed"},"tags":[]},"source":["## Base 
Attack"]},{"cell_type":"code","execution_count":7,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.398182Z","iopub.status.busy":"2023-05-24T19:53:04.397588Z","iopub.status.idle":"2023-05-24T19:53:04.470914Z","shell.execute_reply":"2023-05-24T19:53:04.470104Z","shell.execute_reply.started":"2023-05-24T19:53:04.398148Z"},"id":"073103a5","jupyter":{"source_hidden":true},"papermill":{"duration":0.081538,"end_time":"2023-04-23T16:50:03.216903","exception":false,"start_time":"2023-04-23T16:50:03.135365","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import time\n","import logging\n","from collections import OrderedDict\n","from collections.abc import Iterable\n","\n","import torch\n","from torch.utils.data import DataLoader, TensorDataset\n","\n","\n","def wrapper_method(func):\n","    \"\"\"Make an Attack method also run, with the same arguments, on every sub-attack in `self._attacks`.\"\"\"\n","    def wrapper_func(self, *args, **kwargs):\n","        result = func(self, *args, **kwargs)\n","        for atk in self.__dict__.get('_attacks').values():\n","            # getattr-based dispatch: same call as the former string eval, without executing generated code.\n","            getattr(atk, func.__name__)(*args, **kwargs)\n","        return result\n","    return wrapper_func\n","\n","\n","class Attack(object):\n","    r\"\"\"\n","    Base class for all attacks.\n","\n","    .. 
note::\n","        It automatically set device to the device where given model is.\n","        It basically changes training mode to eval during attack process.\n","        To change this, please see `set_model_training_mode`.\n","    \"\"\"\n","\n","    def __init__(self, name, model):\n","        r\"\"\"\n","        Initializes internal attack state.\n","\n","        Arguments:\n","            name (str): name of attack.\n","            model (torch.nn.Module): model to attack.\n","        \"\"\"\n","\n","        self.attack = name\n","        self._attacks = OrderedDict()\n","\n","        self.set_model(model)\n","        self.device = next(model.parameters()).device\n","\n","        # Controls attack mode.\n","        self.attack_mode = 'default'\n","        self.supported_mode = ['default']\n","        self.targeted = False\n","        self._target_map_function = None\n","\n","        # Controls when normalization is used.\n","        self.normalization_used = {}\n","        self._normalization_applied = False\n","        self._set_auto_normalization_used(model)\n","\n","        # Controls model mode during attack.\n","        self._model_training = False\n","        self._batchnorm_training = False\n","        self._dropout_training = False\n","\n","    def forward(self, inputs, labels=None, *args, **kwargs):\n","        r\"\"\"\n","        It defines the computation performed at every call.\n","        Should be overridden by all subclasses.\n","        \"\"\"\n","        raise NotImplementedError\n","\n","    def _check_inputs(self, images):\n","        tol = 1e-4\n","        if self._normalization_applied:\n","            images = self.inverse_normalize(images)\n","        if torch.max(images) > 1+tol or torch.min(images) < 0-tol:\n","            raise ValueError('Input must have a range [0, 1] (max: {}, min: {})'.format(\n","                torch.max(images), torch.min(images)))\n","        return images\n","\n","    def _check_outputs(self, images):\n"," 
       if self._normalization_applied:\n","            images = self.normalize(images)\n","        return images\n","\n","    @wrapper_method\n","    def set_model(self, model):\n","        self.model = model\n","        self.model_name = model.__class__.__name__\n","\n","    def get_logits(self, inputs, labels=None, *args, **kwargs):\n","        if self._normalization_applied:\n","            inputs = self.normalize(inputs)\n","        logits = self.model(inputs)\n","        return logits\n","\n","    @wrapper_method\n","    def _set_normalization_applied(self, flag):\n","        self._normalization_applied = flag\n","\n","    @wrapper_method\n","    def set_device(self, device):\n","        self.device = device\n","\n","    @wrapper_method\n","    def _set_auto_normalization_used(self, model):\n","        if model.__class__.__name__ == 'RobModel':\n","            mean = getattr(model, 'mean', None)\n","            std = getattr(model, 'std', None)\n","            if (mean is not None) and (std is not None):\n","                if isinstance(mean, torch.Tensor):\n","                    mean = mean.cpu().numpy()\n","                if isinstance(std, torch.Tensor):\n","                    std = std.cpu().numpy()\n","                # Normalization is a no-op only when every mean is 0 and every std is 1.\n","                if (mean != 0).any() or (std != 1).any():\n","                    self.set_normalization_used(mean, std)\n","    #                 logging.info(\"Normalization automatically loaded from `model.mean` and `model.std`.\")\n","\n","    @wrapper_method\n","    def set_normalization_used(self, mean, std):\n","        n_channels = len(mean)\n","        mean = torch.tensor(mean).reshape(1, n_channels, 1, 1)\n","        std = torch.tensor(std).reshape(1, n_channels, 1, 1)\n","        self.normalization_used['mean'] = mean\n","        self.normalization_used['std'] = std\n","        self._normalization_applied = True\n","\n","    def normalize(self, inputs):\n","        mean = self.normalization_used['mean'].to(inputs.device)\n","        std = 
self.normalization_used['std'].to(inputs.device)\n","        return (inputs - mean) / std\n","\n","    def inverse_normalize(self, inputs):\n","        mean = self.normalization_used['mean'].to(inputs.device)\n","        std = self.normalization_used['std'].to(inputs.device)\n","        return inputs*std + mean\n","\n","    def get_mode(self):\n","        r\"\"\"\n","        Get attack mode.\n","\n","        \"\"\"\n","        return self.attack_mode\n","\n","    @wrapper_method\n","    def set_mode_default(self):\n","        r\"\"\"\n","        Set attack mode as default mode.\n","\n","        \"\"\"\n","        self.attack_mode = 'default'\n","        self.targeted = False\n","        print(\"Attack mode is changed to 'default.'\")\n","\n","    @wrapper_method\n","    def _set_mode_targeted(self, mode, quiet):\n","        if \"targeted\" not in self.supported_mode:\n","            raise ValueError(\"Targeted mode is not supported.\")\n","        self.targeted = True\n","        self.attack_mode = mode\n","        if not quiet:\n","            print(\"Attack mode is changed to '%s'.\" % mode)\n","\n","    @wrapper_method\n","    def set_mode_targeted_by_function(self, target_map_function, quiet=False):\n","        r\"\"\"\n","        Set attack mode as targeted.\n","\n","        Arguments:\n","            target_map_function (function): Label mapping function.\n","                e.g. lambda inputs, labels:(labels+1)%10.\n","                None for using input labels as targeted labels. 
(Default)\n","\n","        \"\"\"\n","        self._set_mode_targeted('targeted(custom)', quiet)\n","        self._target_map_function = target_map_function\n","\n","    @wrapper_method\n","    def set_mode_targeted_random(self, quiet=False):\n","        r\"\"\"\n","        Set attack mode as targeted with random labels.\n","\n","        Arguments:\n","            quiet (bool): True to suppress the mode-change message. (Default: False)\n","\n","        \"\"\"\n","        self._set_mode_targeted('targeted(random)', quiet)\n","        self._target_map_function = self.get_random_target_label\n","\n","    @wrapper_method\n","    def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):\n","        r\"\"\"\n","        Set attack mode as targeted with least likely labels.\n","\n","        Arguments:\n","            kth_min (int): label with the k-th smallest probability used as target labels. (Default: 1)\n","            quiet (bool): True to suppress the mode-change message. (Default: False)\n","\n","        \"\"\"\n","        self._set_mode_targeted('targeted(least-likely)', quiet)\n","        assert (kth_min > 0)\n","        self._kth_min = kth_min\n","        self._target_map_function = self.get_least_likely_label\n","\n","    @wrapper_method\n","    def set_mode_targeted_by_label(self, quiet=False):\n","        r\"\"\"\n","        Set attack mode as targeted.\n","\n","        .. 
note::\n","            Use user-supplied labels as target labels.\n","        \"\"\"\n","        self._set_mode_targeted('targeted(label)', quiet)\n","        self._target_map_function = 'function is a string'\n","\n","    @wrapper_method\n","    def set_model_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False):\n","        r\"\"\"\n","        Set training mode during attack process.\n","\n","        Arguments:\n","            model_training (bool): True for using training mode for the entire model during attack process.\n","            batchnorm_training (bool): True for using training mode for batchnorms during attack process.\n","            dropout_training (bool): True for using training mode for dropouts during attack process.\n","\n","        .. note::\n","            For RNN-based models, we cannot calculate gradients with eval mode.\n","            Thus, it should be changed to the training mode during the attack.\n","        \"\"\"\n","        self._model_training = model_training\n","        self._batchnorm_training = batchnorm_training\n","        self._dropout_training = dropout_training\n","\n","    @wrapper_method\n","    def _change_model_mode(self, given_training):\n","        if self._model_training:\n","            self.model.train()\n","            for _, m in self.model.named_modules():\n","                if not self._batchnorm_training:\n","                    if 'BatchNorm' in m.__class__.__name__:\n","                        m = m.eval()\n","                if not self._dropout_training:\n","                    if 'Dropout' in m.__class__.__name__:\n","                        m = m.eval()\n","        else:\n","            self.model.eval()\n","\n","    @wrapper_method\n","    def _recover_model_mode(self, given_training):\n","        if given_training:\n","            self.model.train()\n","\n","    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False,\n","             
save_predictions=False, save_clean_inputs=False, save_type='float'):\n","        r\"\"\"\n","        Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader.\n","\n","        Arguments:\n","            save_path (str): save_path.\n","            data_loader (torch.utils.data.DataLoader): data loader.\n","            verbose (bool): True for displaying detailed information. (Default: True)\n","            return_verbose (bool): True for returning detailed information. (Default: False)\n","            save_predictions (bool): True for saving predicted labels (Default: False)\n","            save_clean_inputs (bool): True for saving clean inputs (Default: False)\n","\n","        \"\"\"\n","        if save_path is not None:\n","            adv_input_list = []\n","            label_list = []\n","            if save_predictions:\n","                pred_list = []\n","            if save_clean_inputs:\n","                input_list = []\n","\n","        correct = 0\n","        total = 0\n","        l2_distance = []\n","\n","        total_batch = len(data_loader)\n","        given_training = self.model.training\n","\n","        for step, (inputs, labels) in enumerate(data_loader):\n","            start = time.time()\n","            adv_inputs = self.__call__(inputs, labels)\n","            batch_size = len(inputs)\n","\n","            if verbose or return_verbose:\n","                with torch.no_grad():\n","                    outputs = self.get_output_with_eval_nograd(adv_inputs)\n","\n","                    # Calculate robust accuracy\n","                    _, pred = torch.max(outputs.data, 1)\n","                    total += labels.size(0)\n","                    right_idx = (pred == labels.to(self.device))\n","                    correct += right_idx.sum()\n","                    rob_acc = 100 * float(correct) / total\n","\n","                    # Calculate l2 distance\n","                    delta = (adv_inputs - 
inputs.to(self.device)).view(batch_size, -1)  # nopep8\n","                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))  # nopep8\n","                    l2 = torch.cat(l2_distance).mean().item()\n","\n","                    # Calculate time computation\n","                    progress = (step+1)/total_batch*100\n","                    end = time.time()\n","                    elapsed_time = end-start\n","\n","                    if verbose:\n","                        self._save_print(progress, rob_acc, l2, elapsed_time, end='\\r')  # nopep8\n","\n","            if save_path is not None:\n","                adv_input_list.append(adv_inputs.detach().cpu())\n","                label_list.append(labels.detach().cpu())\n","\n","                adv_input_list_cat = torch.cat(adv_input_list, 0)\n","                label_list_cat = torch.cat(label_list, 0)\n","\n","                save_dict = {'adv_inputs': adv_input_list_cat, 'labels': label_list_cat}  # nopep8\n","\n","                if save_predictions:\n","                    if not (verbose or return_verbose):\n","                        # pred is only computed above when reporting, so compute it here otherwise.\n","                        _, pred = torch.max(self.get_output_with_eval_nograd(adv_inputs).data, 1)\n","                    pred_list.append(pred.detach().cpu())\n","                    pred_list_cat = torch.cat(pred_list, 0)\n","                    save_dict['preds'] = pred_list_cat\n","\n","                if save_clean_inputs:\n","                    input_list.append(inputs.detach().cpu())\n","                    input_list_cat = torch.cat(input_list, 0)\n","                    save_dict['clean_inputs'] = input_list_cat\n","\n","                # normalization_used starts as {} and is never None, so the old `is not None` test always\n","                # called inverse_normalize (KeyError without normalization); only undo what __call__ applied.\n","                if self._normalization_applied:\n","                    save_dict['adv_inputs'] = self.inverse_normalize(save_dict['adv_inputs'])  # nopep8\n","                    if save_clean_inputs:\n","                        save_dict['clean_inputs'] = self.inverse_normalize(save_dict['clean_inputs'])  # nopep8\n","\n","                if save_type == 'int':\n","                    save_dict['adv_inputs'] = self.to_type(save_dict['adv_inputs'], 'int')  # nopep8\n","                    if 
save_clean_inputs:\n","                        save_dict['clean_inputs'] = self.to_type(save_dict['clean_inputs'], 'int')  # nopep8\n","\n","                save_dict['save_type'] = save_type\n","                torch.save(save_dict, save_path)\n","\n","        # To avoid erasing the printed information.\n","        if verbose:\n","            self._save_print(progress, rob_acc, l2, elapsed_time, end='\\n')\n","\n","        if given_training:\n","            self.model.train()\n","\n","        if return_verbose:\n","            return rob_acc, l2, elapsed_time\n","\n","    @staticmethod\n","    def to_type(inputs, type):\n","        r\"\"\"\n","        Return inputs as int if float is given.\n","        \"\"\"\n","        if type == 'int':\n","            if isinstance(inputs, torch.FloatTensor) or isinstance(inputs, torch.cuda.FloatTensor):\n","                return (inputs*255).type(torch.uint8)\n","        elif type == 'float':\n","            if isinstance(inputs, torch.ByteTensor) or isinstance(inputs, torch.cuda.ByteTensor):\n","                return inputs.float()/255\n","        else:\n","            raise ValueError(\n","                type + \" is not a valid type. 
[Options: float, int]\")\n","        return inputs\n","\n","    @staticmethod\n","    def _save_print(progress, rob_acc, l2, elapsed_time, end):\n","        print('- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \\t'\n","              % (progress, rob_acc, l2, elapsed_time), end=end)\n","\n","    @staticmethod\n","    def load(load_path, batch_size=128, shuffle=False, normalize=None,\n","             load_predictions=False, load_clean_inputs=False):\n","        save_dict = torch.load(load_path)\n","        keys = ['adv_inputs', 'labels']\n","\n","        if load_predictions:\n","            keys.append('preds')\n","        if load_clean_inputs:\n","            keys.append('clean_inputs')\n","\n","        if save_dict['save_type'] == 'int':\n","            save_dict['adv_inputs'] = save_dict['adv_inputs'].float()/255\n","            if load_clean_inputs:\n","                save_dict['clean_inputs'] = save_dict['clean_inputs'].float() / 255  # nopep8\n","\n","        if normalize is not None:\n","            n_channels = len(normalize['mean'])\n","            mean = torch.tensor(normalize['mean']).reshape(1, n_channels, 1, 1)\n","            std = torch.tensor(normalize['std']).reshape(1, n_channels, 1, 1)\n","            save_dict['adv_inputs'] = (save_dict['adv_inputs'] - mean) / std\n","            if load_clean_inputs:\n","                save_dict['clean_inputs'] = (save_dict['clean_inputs'] - mean) / std  # nopep8\n","\n","        adv_data = TensorDataset(*[save_dict[key] for key in keys])\n","        adv_loader = DataLoader(\n","            adv_data, batch_size=batch_size, shuffle=shuffle)\n","        print(\"Data is loaded in the following order: [%s]\" % (\", \".join(keys)))  # nopep8\n","        return adv_loader\n","\n","    @torch.no_grad()\n","    def get_output_with_eval_nograd(self, inputs):\n","        given_training = self.model.training\n","        if given_training:\n","            self.model.eval()\n","        
outputs = self.get_logits(inputs)\n","        if given_training:\n","            self.model.train()\n","        return outputs\n","\n","    def get_target_label(self, inputs, labels=None):\n","        r\"\"\"\n","        Function for changing the attack mode.\n","        Return input labels.\n","        \"\"\"\n","        if self._target_map_function is None:\n","            raise ValueError(\n","                'target_map_function is not initialized by set_mode_targeted.')\n","        if self.attack_mode == 'targeted(label)':\n","            target_labels = labels\n","        else:\n","            target_labels = self._target_map_function(inputs, labels)\n","        return target_labels\n","\n","    @torch.no_grad()\n","    def get_least_likely_label(self, inputs, labels=None):\n","        outputs = self.get_output_with_eval_nograd(inputs)\n","        if labels is None:\n","            _, labels = torch.max(outputs, dim=1)\n","        n_classses = outputs.shape[-1]\n","\n","        target_labels = torch.zeros_like(labels)\n","        for counter in range(labels.shape[0]):\n","            l = list(range(n_classses))\n","            l.remove(labels[counter])\n","            _, t = torch.kthvalue(outputs[counter][l], self._kth_min)\n","            target_labels[counter] = l[t]\n","\n","        return target_labels.long().to(self.device)\n","\n","    @torch.no_grad()\n","    def get_random_target_label(self, inputs, labels=None):\n","        outputs = self.get_output_with_eval_nograd(inputs)\n","        if labels is None:\n","            _, labels = torch.max(outputs, dim=1)\n","        n_classses = outputs.shape[-1]\n","\n","        target_labels = torch.zeros_like(labels)\n","        for counter in range(labels.shape[0]):\n","            l = list(range(n_classses))\n","            l.remove(labels[counter])\n","            t = (len(l)*torch.rand([1])).long().to(self.device)\n","            target_labels[counter] = l[t]\n","\n","        return 
target_labels.long().to(self.device)\n","\n","    def __call__(self, images, labels=None, *args, **kwargs):\n","        given_training = self.model.training\n","        self._change_model_mode(given_training)\n","        images = self._check_inputs(images)\n","        adv_images = self.forward(images, labels, *args, **kwargs)\n","        adv_images = self._check_outputs(adv_images)\n","        self._recover_model_mode(given_training)\n","        return adv_images\n","\n","    def __repr__(self):\n","        info = self.__dict__.copy()\n","\n","        del_keys = ['model', 'attack', 'supported_mode']\n","\n","        for key in info.keys():\n","            if key[0] == \"_\":\n","                del_keys.append(key)\n","\n","        for key in del_keys:\n","            del info[key]\n","\n","        info['attack_mode'] = self.attack_mode\n","        info['normalization_used'] = True if len(self.normalization_used) > 0 else False  # nopep8\n","\n","        return self.attack + \"(\" + ', '.join('{}={}'.format(key, val) for key, val in info.items()) + \")\"\n","\n","    def __setattr__(self, name, value):\n","        object.__setattr__(self, name, value)\n","\n","        attacks = self.__dict__.get('_attacks')\n","\n","        # Recursively yield Attack instances found in (possibly nested or cyclic) lists and dicts.\n","        # `stack` holds objects already visited during this call, compared by identity. It is created\n","        # per call rather than as a shared mutable default, which kept every assigned value alive\n","        # across all attacks and used == membership (ambiguous for tensors).\n","        def get_all_values(items, stack=None):\n","            if stack is None:\n","                stack = []\n","            if not any(items is seen for seen in stack):\n","                stack.append(items)\n","                if isinstance(items, list) or isinstance(items, dict):\n","                    if isinstance(items, dict):\n","                        items = (list(items.keys())+list(items.values()))\n","                    for item in items:\n","                        yield from get_all_values(item, stack)\n","                else:\n","                    if isinstance(items, Attack):\n","                        yield items\n","            else:\n","                if isinstance(items, Attack):\n","                    yield items\n","\n","        for num, value in 
enumerate(get_all_values(value)):\n","            attacks[name+\".\"+str(num)] = value\n","            for subname, subvalue in value.__dict__.get('_attacks').items():\n","                attacks[name+\".\"+subname] = subvalue"]},{"cell_type":"markdown","metadata":{"id":"0681d390","papermill":{"duration":0.014023,"end_time":"2023-04-23T16:50:03.245178","exception":false,"start_time":"2023-04-23T16:50:03.231155","status":"completed"},"tags":[]},"source":["## PGD"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.472867Z","iopub.status.busy":"2023-05-24T19:53:04.472484Z","iopub.status.idle":"2023-05-24T19:53:04.488770Z","shell.execute_reply":"2023-05-24T19:53:04.487833Z","shell.execute_reply.started":"2023-05-24T19:53:04.472833Z"},"id":"0f76f4a1","papermill":{"duration":0.029206,"end_time":"2023-04-23T16:50:03.332562","exception":false,"start_time":"2023-04-23T16:50:03.303356","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["\n","class PGD_TEST(Attack):\n","    r\"\"\"\n","    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'\n","    [https://arxiv.org/abs/1706.06083]\n","\n","    Distance Measure : Linf\n","\n","    Arguments:\n","        model (nn.Module): model to attack.\n","        eps (float): maximum perturbation. (Default: 8/255)\n","        alpha (float): step size. (Default: 2/255)\n","        steps (int): number of steps. (Default: 10)\n","        random_start (bool): using random initialization of delta. (Default: True)\n","        num_classes (int): logit index used as the outlier class (the model must output at least num_classes + 1 logits);\n","            samples labelled num_classes are pushed away from it and all other samples towards it. (Default: 10)\n","\n","    Shape:\n","        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. 
It must have a range [0, 1].\n","        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \\leq y_i \\leq` `number of labels`.\n","        - output: :math:`(N, C, H, W)`.\n","\n","    Examples::\n","        >>> attack = PGD_TEST(model, eps=8/255, alpha=2/255, steps=10, random_start=True, num_classes=10)\n","        >>> adv_images = attack(images, labels)\n","\n","    \"\"\"\n","\n","    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True, num_classes=10):\n","        super().__init__(\"PGD\", model)\n","        self.eps = eps\n","        self.alpha = alpha\n","        self.steps = steps\n","        self.random_start = random_start\n","        self.supported_mode = ['default', 'targeted']\n","        self.num_classes = num_classes\n","\n","    def forward(self, images, labels):\n","        r\"\"\"\n","        Overridden.\n","        \"\"\"\n","\n","        images = images.clone().detach().to(self.device)\n","        labels = labels.clone().detach().to(self.device)\n","\n","        # Every sample is scored against the outlier class (index num_classes).\n","        # multipliers: +1 for outlier samples (ascent pushes them away from that class),\n","        # -1 for all other samples (ascent pushes them towards it).\n","        ones = torch.ones_like(labels)\n","        multipliers = -1 * (ones - 2 * ones * (labels == self.num_classes))\n","        # Loop-invariant: built once instead of on every step.\n","        target_labels = torch.full_like(labels, self.num_classes)\n","        cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')\n","\n","        adv_images = images.clone().detach()\n","\n","        if self.random_start:\n","            # Starting at a uniformly random point\n","            adv_images = adv_images + \\\n","                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n","            adv_images = torch.clamp(adv_images, min=0, max=1).detach()\n","\n","        for _ in range(self.steps):\n","            adv_images.requires_grad = True\n","            outputs = self.get_logits(adv_images)\n","\n","            losses = cross_entropy_loss(outputs, target_labels)\n","\n","            cost = torch.mean(losses * multipliers)\n","\n","            # Update adversarial images\n","            grad = 
torch.autograd.grad(cost, adv_images,\n","                                       retain_graph=False, create_graph=False)[0]\n","\n","            adv_images = adv_images.detach() + self.alpha * grad.sign()\n","            delta = torch.clamp(adv_images - images,\n","                                min=-self.eps, max=self.eps)\n","            adv_images = torch.clamp(images + delta, min=0, max=1).detach()\n","\n","        return adv_images\n"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.491907Z","iopub.status.busy":"2023-05-24T19:53:04.491216Z","iopub.status.idle":"2023-05-24T19:53:04.505299Z","shell.execute_reply":"2023-05-24T19:53:04.504295Z","shell.execute_reply.started":"2023-05-24T19:53:04.491876Z"},"id":"d90c131e","papermill":{"duration":0.028675,"end_time":"2023-04-23T16:50:03.375215","exception":false,"start_time":"2023-04-23T16:50:03.346540","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["\n","class PGD_CLS(Attack):\n","    r\"\"\"\n","    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'\n","    [https://arxiv.org/abs/1706.06083]\n","    Distance Measure : Linf\n","    Arguments:\n","        model (nn.Module): model to attack.\n","        eps (float): maximum perturbation. (Default: 8/255)\n","        alpha (float): step size. (Default: 2/255)\n","        steps (int): number of steps. (Default: 10)\n","        random_start (bool): using random initialization of delta. (Default: True)\n","    Shape:\n","        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. 
It must have a range [0, 1].\n","        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \\leq y_i \\leq` `number of labels`.\n","        - output: :math:`(N, C, H, W)`.\n","    Examples::\n","        >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=10, random_start=True)\n","        >>> adv_images = attack(images, labels)\n","    \"\"\"\n","\n","    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True):\n","        super().__init__(\"PGD\", model)\n","        self.eps = eps\n","        self.alpha = alpha\n","        self.steps = steps\n","        self.random_start = random_start\n","        self.supported_mode = ['default', 'targeted']\n","\n","    def forward(self, images, labels):\n","        r\"\"\"\n","        Overridden.\n","        \"\"\"\n","\n","        images = images.clone().detach().to(self.device)\n","        labels = labels.clone().detach().to(self.device)\n","\n","        if self.targeted:\n","            target_labels = self.get_target_label(images, labels)\n","\n","        loss = nn.CrossEntropyLoss()\n","        adv_images = images.clone().detach()\n","\n","        if self.random_start:\n","            # Starting at a uniformly random point\n","            adv_images = adv_images + \\\n","                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)\n","            adv_images = torch.clamp(adv_images, min=0, max=1).detach()\n","\n","        for _ in range(self.steps):\n","            adv_images.requires_grad = True\n","            outputs = self.get_logits(adv_images)\n","\n","            # Calculate loss\n","            if self.targeted:\n","                cost = -loss(outputs, target_labels)\n","            else:\n","                cost = loss(outputs, labels)\n","\n","            # Update adversarial images\n","            grad = torch.autograd.grad(cost, adv_images,\n","                                       retain_graph=False, create_graph=False)[0]\n","\n","            
adv_images = adv_images.detach() + self.alpha*grad.sign()\n","            delta = torch.clamp(adv_images - images,\n","                                min=-self.eps, max=self.eps)\n","            adv_images = torch.clamp(images + delta, min=0, max=1).detach()\n","\n","        return  adv_images"]},{"cell_type":"markdown","metadata":{"id":"297a06e2","papermill":{"duration":0.014316,"end_time":"2023-04-23T16:50:03.403765","exception":false,"start_time":"2023-04-23T16:50:03.389449","status":"completed"},"tags":[]},"source":["# Data"]},{"cell_type":"markdown","metadata":{"id":"4QMJvQ247TrM"},"source":["## Datasets"]},{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.509564Z","iopub.status.busy":"2023-05-24T19:53:04.508919Z","iopub.status.idle":"2023-05-24T19:53:04.519802Z","shell.execute_reply":"2023-05-24T19:53:04.518720Z","shell.execute_reply.started":"2023-05-24T19:53:04.509528Z"},"id":"k_75a4Vr7aNY","trusted":true},"outputs":[],"source":["class ImageNetDataset(torch.utils.data.Dataset):\n","    def __init__(self, data, label, transform=None):\n","        self.data = data\n","        self.label = label\n","        self.transform = transform\n","\n","    def __len__(self):\n","        return len(self.data)\n","\n","    def __getitem__(self, idx):\n","        img = self.data[idx]\n","        if self.transform:\n","            img = self.transform(img)\n","        return img, self.label\n","\n","class FakesDataset(torch.utils.data.Dataset):\n","    def __init__(self, path, label, transform=None):\n","        samples = np.load(path)\n","        self.data = [Image.fromarray(sample[0]) for sample in samples]\n","        self.label = label\n","        self.transform = transform\n","\n","    def __len__(self):\n","        return len(self.data)\n","\n","    def __getitem__(self, idx):\n","        img = self.data[idx]\n","        if self.transform:\n","            img = self.transform(img)\n","        return img, 
self.label\n"]},{"cell_type":"markdown","metadata":{"id":"RnKwcR-t7VVb"},"source":["## DataLoader"]},{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:04.521850Z","iopub.status.busy":"2023-05-24T19:53:04.521457Z","iopub.status.idle":"2023-05-24T19:53:04.777306Z","shell.execute_reply":"2023-05-24T19:53:04.776392Z","shell.execute_reply.started":"2023-05-24T19:53:04.521820Z"},"id":"a144fab1","papermill":{"duration":0.041167,"end_time":"2023-04-23T16:50:03.526834","exception":false,"start_time":"2023-04-23T16:50:03.485667","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import os\n","import random\n","import torch\n","import torchvision\n","from torchvision.datasets import CIFAR100\n","from torchvision.datasets import CIFAR10\n","from torchvision import transforms\n","from torchvision.transforms import functional as F\n","from glob import glob\n","import numpy as np\n","from PIL import Image\n","\n","    \n","def get_selected_indices_from_path(fakes_path):\n","    \"\"\"Parse the in-distribution class indices encoded in an outlier file name.\n","\n","    The basename is split on spaces; the last character of the first token, the\n","    middle tokens and the first character of the last token are read as ints.\n","    NOTE(review): this only works for single-digit classes separated by single\n","    spaces (something like 'name[0 1 2 3 4 5].npy') - confirm the naming scheme.\n","    \"\"\"\n","    filename = os.path.basename(fakes_path)\n","    parts = filename.split(\" \")\n","    in_indices = [parts[0][-1]] + parts[1:-1] + [parts[-1][0]]\n","    in_indices = [int(x) for x in in_indices]\n","    return in_indices\n","\n","    \n","def get_loaders(fakes_path=None):\n","    \"\"\"Build the train/test loaders for one open-set trial.\n","\n","    The normal classes are parsed from `fakes_path`; they are relabelled 0..k-1\n","    and every other class becomes the outlier label `num_classes`. 
The train set is\n","    the normal training images plus the generated outliers stored in `fakes_path`;\n","    the test set keeps all classes. Relies on the globals `in_dataset`,\n","    `num_classes` and `batch_size` and on `DataLoader` being in scope.\n","\n","    Returns:\n","        (trainloader, testloader)\n","    \"\"\"\n","    if fakes_path is None:\n","        raise ValueError('get_loaders requires fakes_path (a .npy file of generated outliers)')\n","\n","    train_data_transform = [transforms.Resize(32), transforms.Grayscale(num_output_channels=3), transforms.ToTensor()]\n","    transform_fakes = [transforms.Resize(32), transforms.Grayscale(num_output_channels=3), transforms.ToTensor()]\n","\n","    # Look the dataset class up by name instead of eval()-ing a code string.\n","    dataset_cls = getattr(torchvision.datasets, in_dataset)\n","    train_dataset = dataset_cls(f\"./{in_dataset}\", train=True, download=True, transform=transforms.Compose(train_data_transform))\n","    test_dataset = dataset_cls(f\"./{in_dataset}\", train=False, download=True, transform=transforms.Compose(train_data_transform))\n","\n","    # Get Number of Unique Labels\n","    unique_train_labels = set([t.item() for t 
in train_dataset.targets])\n","    # Every class id seen in the original training labels (instead of assuming 10 classes).\n","    all_classes = sorted(unique_train_labels)\n","\n","    # Generated outliers all get the extra outlier label `num_classes`.\n","    fakes_label = num_classes\n","    exposure_train_data = FakesDataset(fakes_path, fakes_label, transforms.Compose(transform_fakes))\n","    \n","    # Pick in-distribution classes as indicated in the file path\n","    normal_classes = get_selected_indices_from_path(fakes_path)\n","    abnormal_classes = [c for c in all_classes if c not in normal_classes]\n","\n","    print(f'Normal Classes: {normal_classes}')\n","    print(f'Abnormal Classes: {abnormal_classes}')\n","\n","    # Map original labels to new ones: normal classes -> 0..len(normal_classes)-1,\n","    # every abnormal class -> num_classes (the outlier label).\n","    label_mapping = {c: i for i, c in enumerate(normal_classes)}\n","    label_mapping.update({c: num_classes for c in abnormal_classes})\n","\n","    # Relabel the train and test datasets\n","    train_dataset.targets = [label_mapping[target.item()] for target in train_dataset.targets]\n","    test_dataset.targets = [label_mapping[target.item()] for target in test_dataset.targets]\n","\n","    # Remove abnormal classes from the train dataset\n","    normal_indices = [i for i, target in enumerate(train_dataset.targets) if target != num_classes]\n","    train_dataset.data = train_dataset.data[normal_indices]\n","    train_dataset.targets = [target for target in train_dataset.targets if target != num_classes]\n","\n","    trainset = torch.utils.data.ConcatDataset([train_dataset, exposure_train_data])\n","    testset = test_dataset\n","\n","    trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)\n","    testloader = DataLoader(testset, shuffle=False, batch_size=batch_size//2)\n","\n","    # just for double check!\n","    from collections import 
Counter\n","\n","    def count_labels(dataset):\n","        return Counter(label for _, label in dataset)\n","\n","    # Count and print the number of samples for each label in each dataset\n","    train_label_counts = count_labels(trainset)\n","    test_label_counts = count_labels(testset)\n","\n","    print(\"Train dataset label counts:\")\n","    for label, count in train_label_counts.items():\n","        print(f\"Label {label}: {count}\")\n","\n","    print(\"Test dataset label counts:\")\n","    for label, count in test_label_counts.items():\n","        print(f\"Label {label}: {count}\")\n","    \n","    print(f\"Length of train dataset: {len(trainset)}\")\n","    print(f\"Length of test dataset: {len(testset)}\")\n","    \n","    return trainloader, testloader"]},{"cell_type":"markdown","metadata":{"id":"8a73c2c8","papermill":{"duration":0.013948,"end_time":"2023-04-23T16:50:03.667409","exception":false,"start_time":"2023-04-23T16:50:03.653461","status":"completed"},"tags":[]},"source":["# Utils"]},{"cell_type":"code","execution_count":15,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:48.000075Z","iopub.status.busy":"2023-05-24T19:53:47.999724Z","iopub.status.idle":"2023-05-24T19:53:48.030824Z","shell.execute_reply":"2023-05-24T19:53:48.029839Z","shell.execute_reply.started":"2023-05-24T19:53:48.000045Z"},"id":"25ec0fae","papermill":{"duration":0.665452,"end_time":"2023-04-23T16:50:04.346979","exception":false,"start_time":"2023-04-23T16:50:03.681527","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["from tqdm import tqdm\n","from sklearn.metrics import roc_auc_score\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import os\n","import csv\n","\n","def auc_softmax_adversarial(model, test_loader, test_attack,  device, num_classes):\n","\n","    is_train = model.training\n","    model.eval()\n","\n","    soft = torch.nn.Softmax(dim=1)\n","    anomaly_scores = []\n","    preds = []\n","    
test_labels = []\n","\n","    # The attack needs gradients, so only the scoring forward pass runs under no_grad.\n","    with tqdm(test_loader, unit=\"batch\") as tepoch:\n","        torch.cuda.empty_cache()\n","        for i, (data, target) in enumerate(tepoch):\n","            data, target = data.to(device), target.to(device)\n","\n","            adv_data = test_attack(data, target)\n","            with torch.no_grad():\n","                output = model(adv_data)\n","\n","            # argmax over dim=1 (no squeeze) keeps a 1-D result even for a batch of one.\n","            predictions = output.argmax(dim=1)\n","            preds += predictions.cpu().numpy().tolist()\n","\n","            # Anomaly score = softmax probability of the outlier logit (index num_classes).\n","            probs = soft(output)\n","            anomaly_scores += probs[:, num_classes].cpu().numpy().tolist()\n","\n","            target = target == num_classes\n","            test_labels += target.cpu().numpy().tolist()\n","\n","    auc = roc_auc_score(test_labels, anomaly_scores)\n","\n","    if is_train:\n","        model.train()\n","    else:\n","        model.eval()\n","\n","    return auc\n","\n","def auc_softmax(model, test_loader, device, num_classes):\n","    \"\"\"ROC-AUC of the outlier-class softmax probability on clean test data.\n","\n","    Samples labelled `num_classes` are the positives; the model's train/eval mode\n","    is restored before returning.\n","    \"\"\"\n","\n","    is_train = model.training\n","    model.eval()\n","\n","    soft = torch.nn.Softmax(dim=1)\n","    anomaly_scores = []\n","    preds = []\n","    test_labels = []\n","\n","    with torch.no_grad():\n","        with tqdm(test_loader, unit=\"batch\") as tepoch:\n","            torch.cuda.empty_cache()\n","            for i, (data, target) in enumerate(tepoch):\n","                data, target = data.to(device), target.to(device)\n","                output = model(data)\n","\n","                predictions = output.argmax(dim=1)\n","                
preds += predictions.cpu().numpy().tolist()\n","\n","                # No squeeze: probs stays (batch, classes) even when the batch has one sample.\n","                probs = soft(output)\n","                anomaly_scores += probs[:, num_classes].cpu().numpy().tolist()\n","\n","                target = target == num_classes\n","                test_labels += target.cpu().numpy().tolist()\n","\n","    auc = roc_auc_score(test_labels, anomaly_scores)\n","\n","    if is_train:\n","        model.train()\n","    
else:\n","        model.eval()\n","\n","    return auc\n","\n","\n","\n","lr_schedule = lambda learning_rate, t, max_epochs: np.interp([t], [0, max_epochs // 3, max_epochs * 2 // 3, max_epochs], [learning_rate, learning_rate/10, learning_rate / 100, 0])[0]\n","\n","def run(model, train_attack, test_attack, trainloader, testloader, test_step:int, max_epochs:int, device, loss_threshold=1e-3, num_classes=10, learning_rate=0.1):\n","    \"\"\"Adversarially train `model` and periodically report clean and adversarial AUC.\n","\n","    Uses SGD (momentum 0.9, weight decay 5e-4) with the piecewise-linear `lr_schedule`.\n","    Training stops early once the epoch's mean training loss drops below\n","    `loss_threshold`; AUC is evaluated every `test_step` epochs.\n","    \"\"\"\n","    \n","    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)\n","    \n","    criterion = nn.CrossEntropyLoss()\n","\n","    init_epoch = 0\n","\n","    print(f'Starting Run from epoch {init_epoch}')\n","    \n","    for epoch in range(init_epoch, max_epochs):\n","    \n","        torch.cuda.empty_cache()\n","\n","        print(f'====== Starting Training on epoch {epoch}')\n","        train_loss = train_one_epoch(epoch=epoch,\\\n","                                    max_epochs=max_epochs, \\\n","                                    model=model,\\\n","                                    optimizer=optimizer,\n","                                    criterion=criterion,\\\n","                                    trainloader=trainloader,\\\n","                                    train_attack=train_attack,\\\n","                                    lr=learning_rate,\\\n","                                    device=device)\n","        \n","        print(\"train loss is \", train_loss)\n","        \n","        if train_loss < loss_threshold:\n","            break\n","\n","        if epoch % test_step == 0 :\n","\n","            print(f'AUC Vanila - Started...')\n","            clean_auc  
= auc_softmax(model=model,  test_loader=testloader, device=device, num_classes=num_classes)\n","            print(f'AUC Vanila - score on epoch {epoch} is: {clean_auc * 100}')\n","\n","            # Name the attack after its real step count (the test attack is not necessarily 10 steps).\n","            attack_name = f'PGD-{getattr(test_attack, \"steps\", 10)}'\n","            print(f'AUC Adversarial - {attack_name} - Started...')\n","            adv_auc = 
auc_softmax_adversarial(model=model, test_loader=testloader, test_attack=test_attack, device=device, num_classes=num_classes)\n","            print(f'AUC Adversairal {attack_name} - score on epoch {epoch} is: {adv_auc * 100}')\n","\n","            torch.cuda.empty_cache()\n","\n","\n","def train_one_epoch(epoch, max_epochs, model, optimizer, criterion, trainloader, train_attack, lr, device):\n","    \"\"\"Run one epoch of PGD adversarial training and return the mean per-sample loss.\n","\n","    The learning rate is updated every batch from `lr_schedule`; every batch is\n","    replaced by `train_attack(data, target)` before the optimizer step.\n","    \"\"\"\n","\n","    preds = []\n","    true_labels = []\n","    running_loss = 0\n","    correct = 0\n","    accuracy = 0\n","\n","    model.train()\n","    with tqdm(trainloader, unit=\"batch\") as tepoch:\n","        torch.cuda.empty_cache()\n","        for i, (data, target) in enumerate(tepoch):\n","\n","            tepoch.set_description(f\"Epoch {epoch + 1}/{max_epochs}\")\n","            updated_lr = lr_schedule(learning_rate=lr, t=epoch + (i + 1) / len(tepoch), max_epochs=max_epochs)\n","            optimizer.param_groups[0].update(lr=updated_lr)\n","\n","            data, target = data.to(device), target.to(device)\n","            # Cast to int64 while staying on `device` (no hard-coded .cuda(), so CPU runs work).\n","            target = target.long()\n","\n","            # Adversarial attack on every batch\n","            data = train_attack(data, target)\n","\n","            # Zero gradients for every batch\n","            optimizer.zero_grad()\n","\n","            output = model(data)\n","\n","            # Compute the loss and its gradients\n","            loss = criterion(output, target)\n","            loss.backward()\n","\n","            # Adjust learning weights\n","            optimizer.step()\n","\n","            true_labels += target.detach().cpu().numpy().tolist()\n","\n","            predictions = output.argmax(dim=1)\n","            
preds += predictions.detach().cpu().numpy().tolist()\n","            # Running count instead of re-comparing every stored prediction each batch.\n","            correct += (predictions == target).sum().item()\n","            accuracy = correct / len(preds)\n","\n","            
running_loss += loss.item() * data.size(0)\n","\n","            tepoch.set_postfix(loss=running_loss / len(preds), accuracy=100. * accuracy)\n","\n","    return running_loss / len(preds)"]},{"cell_type":"markdown","metadata":{"id":"0RCQvoIREg2y"},"source":["# Visualization"]},{"cell_type":"code","execution_count":13,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:05.480300Z","iopub.status.busy":"2023-05-24T19:53:05.479614Z","iopub.status.idle":"2023-05-24T19:53:05.492918Z","shell.execute_reply":"2023-05-24T19:53:05.492102Z","shell.execute_reply.started":"2023-05-24T19:53:05.480267Z"},"id":"llX2fIEZEkAy","trusted":true},"outputs":[],"source":["from matplotlib import pyplot as plt\n","\n","def visualize_samples(dataloader, n, title=\"Sample\", anomaly_label=None):\n","    \"\"\"Show an n x n grid of normal and an n x n grid of abnormal samples side by side.\n","\n","    If `anomaly_label` is given, samples with that label are abnormal and all others\n","    normal; otherwise label 0 is treated as normal and any other label as abnormal.\n","    \"\"\"\n","    normal_samples = []\n","    abnormal_samples = []\n","\n","    def to_3_channels(image):\n","        if image.shape[0] == 1:\n","            return image.repeat(3, 1, 1)\n","        return image\n","\n","    def is_abnormal(label):\n","        if anomaly_label is None:\n","            return bool(label != 0)\n","        return bool(label == anomaly_label)\n","\n","    # Collect n x n samples\n","    for images, labels in dataloader:\n","        for i, l in enumerate(labels):\n","            image = to_3_channels(images[i])\n","            if is_abnormal(l):\n","                if len(abnormal_samples) < n * n:\n","                    abnormal_samples.append(image)\n","            elif len(normal_samples) < n * n:\n","                normal_samples.append(image)\n","            if len(normal_samples) == n * n and len(abnormal_samples) == n * n:\n","                break\n","        if len(normal_samples) == n * n and 
len(abnormal_samples) == n * n:\n","            break\n","\n","    normal_grid = torchvision.utils.make_grid(normal_samples, nrow=n)\n","    abnormal_grid = torchvision.utils.make_grid(abnormal_samples, nrow=n)\n","\n","\n","    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))\n","    fig.patch.set_alpha(0)\n","    fig.suptitle(title, fontsize=16)\n","\n","    axs[0].imshow(normal_grid.permute(1, 2, 0))\n","    axs[0].set_title('Normal', fontsize=14)\n","    
axs[0].axis('off')\n","\n","    axs[1].imshow(abnormal_grid.permute(1, 2, 0))\n","    axs[1].set_title('Abnormal', fontsize=14)\n","    axs[1].axis('off')\n","\n","    plt.show()"]},{"cell_type":"markdown","metadata":{"id":"8091738d","papermill":{"duration":0.014219,"end_time":"2023-04-23T16:50:04.375730","exception":false,"start_time":"2023-04-23T16:50:04.361511","status":"completed"},"tags":[]},"source":["# Training and Testing"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-05-24T19:53:50.935342Z","iopub.status.busy":"2023-05-24T19:53:50.934929Z","iopub.status.idle":"2023-05-24T19:54:29.385687Z","shell.execute_reply":"2023-05-24T19:54:29.384359Z","shell.execute_reply.started":"2023-05-24T19:53:50.935310Z"},"id":"0fcd36ef","papermill":{"duration":5286.540449,"end_time":"2023-04-23T18:18:34.746187","exception":false,"start_time":"2023-04-23T16:50:28.205738","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import torch\n","from torchvision.models import resnet18\n","from glob import glob\n","\n","# Sort so trial i always uses the same outlier file (glob order is filesystem-dependent).\n","fakes_paths = sorted(glob(fakes_paths_reg))\n","if len(fakes_paths) < no_trials:\n","    print(f'Warning: only {len(fakes_paths)} outlier files match {fakes_paths_reg}; running {len(fakes_paths)} of {no_trials} trials')\n","\n","for i, fakes_path in enumerate(fakes_paths[:no_trials]):\n","    print('run', i, \"started\")\n","    \n","    model = WideResNet(40, num_classes+1, 4,  dropRate=0.0).to(device)\n","    \n","    trainloader, testloader = get_loaders(fakes_path=fakes_path)\n","\n","    # Visualization: label num_classes marks the outliers\n","    visualize_samples(trainloader, 8, \"Trainset\", anomaly_label=num_classes)\n","    visualize_samples(testloader, 8, \"Testset\", anomaly_label=num_classes)\n","    \n","    train_attack = PGD_CLS(model, eps=attack_eps, steps=10, alpha=attack_alpha)\n","    test_attack = PGD_TEST(model, eps=attack_eps, 
steps=100, alpha=attack_alpha/10, num_classes=num_classes)\n","\n","    run(model, train_attack, test_attack, trainloader, testloader, 1, 10, device, 
num_classes=num_classes)"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["TkGzHEBA6KaT","xljjSpsRzADQ","f974ece4","64e94984","2caba216","0681d390","RnKwcR-t7VVb","0RCQvoIREg2y"],"gpuType":"T4","provenance":[]},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.10"},"papermill":{"default_parameters":{},"duration":5453.965645,"end_time":"2023-04-23T18:19:22.922648","environment_variables":{},"exception":null,"input_path":"__notebook__.ipynb","output_path":"__notebook__.ipynb","parameters":{},"start_time":"2023-04-23T16:48:28.957003","version":"2.4.0"}},"nbformat":4,"nbformat_minor":5}
