{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'torchvision'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-6dea78d8567c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mtimer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchvision\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtransforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchvision'"
     ]
    }
   ],
   "source": [
    "## Create dataset of perturbed CIFAR10 under all domains\n",
    "## In the following, a valid question (no pun intended) is: why have a validation set? The reason is that the models\n",
    "## used to generate the adversarial attacks were trained on this same training/validation split, so the distribution of\n",
    "## adversarial examples may differ between samples the model saw during training and samples it did not (validation).\n",
    "## WARNING\n",
    "## WARNING\n",
    "## WARNING: IF YOU CHANGED THE SEED WHEN TRAINING THE MODEL AND WANT TO RUN THIS CODE, SET `seed` BELOW TO THE SAME\n",
    "## VALUE, OTHERWISE THE TRAINING/VALIDATION SPLIT HERE WILL NOT MATCH THE ONE THE MODEL WAS TRAINED ON.\n",
    "\n",
    "# Clear the interactive namespace so a re-run starts from a clean kernel state.\n",
    "# `InteractiveShell.magic` is deprecated; `run_line_magic` is the supported API. get_ipython()\n",
    "# returns None outside IPython (e.g. when this code is run as a plain script), so guard it.\n",
    "from IPython import get_ipython\n",
    "_ipython_shell = get_ipython()\n",
    "if _ipython_shell is not None:\n",
    "    _ipython_shell.run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import advertorch.attacks as attacks\n",
    "from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "import json\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Make sure validation splits are the same at all time (e.g. even after loading): this seed must match\n",
    "# the one used when the attacked models were trained.\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy and PyTorch global RNGs with `seed`.\n",
    "\n",
    "    Passed as `worker_init_fn` to the test DataLoader. NOTE(review): DataLoader calls\n",
    "    worker_init_fn(worker_id), so with worker processes enabled each worker would be seeded\n",
    "    with its worker id instead of `seed`; with num_workers=0 it is never called.\n",
    "    \"\"\"\n",
    "    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seed_rng(seed)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Number of DataLoader worker processes (0 = load data in the main process).\n",
    "num_workers = 0\n",
    "# The last batch of a loader may be smaller than its batch size; generate_augmented_data_from_loader\n",
    "# handles a partial final batch, so the dataset sizes need not be multiples of these values.\n",
    "batch_size_train_and_valid = 128\n",
    "batch_size_test = 100\n",
    "\n",
    "# proportion of full training set used for validation\n",
    "valid_size = 0.2\n",
    "\n",
    "# Train-time augmentation pipeline, defined for reference but not used in this cell: the datasets below\n",
    "# use transform_to_tensor, so no random augmentation is applied to the images being attacked.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "transform_to_tensor = transforms.ToTensor()\n",
    "\n",
    "train_and_valid_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transform_to_tensor)\n",
    "test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transform_to_tensor)\n",
    "\n",
    "# Split the official training set into train/validation with a generator seeded by `seed`, so the split is\n",
    "# reproducible and matches the one used when the attacked models were trained.\n",
    "num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))\n",
    "num_train_samples = len(train_and_valid_data) - num_valid_samples\n",
    "train_data, valid_data = torch.utils.data.random_split(\n",
    "    train_and_valid_data, [num_train_samples, num_valid_samples],\n",
    "    generator=torch.Generator().manual_seed(seed))\n",
    "\n",
    "# shuffle=False (explicit): generated samples are stored in iteration order, so the saved tensors stay\n",
    "# aligned with the order of the underlying split.\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size_train_and_valid,\n",
    "                                           shuffle=False, num_workers=num_workers)\n",
    "valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size_train_and_valid,\n",
    "                                           shuffle=False, num_workers=num_workers)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test, shuffle=False,\n",
    "                                          num_workers=num_workers, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block of two 3x3 convolutions (the building block of ResNet18 below).\n",
    "\n",
    "    Outputs `expansion * planes` channels. When the stride or the channel count changes, the\n",
    "    identity shortcut is replaced by a 1x1 convolution + batch-norm projection.\n",
    "    \"\"\"\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "\n",
    "        needs_projection = stride != 1 or in_planes != out_planes\n",
    "        if needs_projection:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = self.shortcut(x)\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        return F.relu(out + residual)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    \"\"\"Bottleneck residual block: 1x1 reduce, 3x3 (possibly strided) conv, 1x1 expand by `expansion`.\n",
    "\n",
    "    Outputs `expansion * planes` channels; a 1x1 convolution + batch-norm projection shortcut is\n",
    "    used whenever the stride or channel count changes, otherwise the shortcut is the identity.\n",
    "    \"\"\"\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
    "\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes),\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = self.shortcut(x)\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = F.relu(self.bn2(self.conv2(out)))\n",
    "        out = self.bn3(self.conv3(out))\n",
    "        return F.relu(out + residual)\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"ResNet for 3-channel images with a single 3x3 stem convolution (no max-pooling).\n",
    "\n",
    "    Args:\n",
    "        block: residual block class exposing an `expansion` attribute (e.g. BasicBlock or Bottleneck).\n",
    "        num_blocks: number of blocks in each of the four stages.\n",
    "        num_classes: size of the output logits.\n",
    "    \"\"\"\n",
    "    def __init__(self, block, num_blocks, num_classes=10):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        \"\"\"Build one stage: the first block uses `stride`, the remaining ones stride 1.\"\"\"\n",
    "        strides = [stride] + [1] * (num_blocks - 1)\n",
    "        layers = []\n",
    "        for block_stride in strides:\n",
    "            layers.append(block(self.in_planes, planes, block_stride))\n",
    "            # The next block's input width is this block's output width.\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.layer1(out)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        out = self.layer4(out)\n",
    "        # Global average pooling. For 32x32 inputs layer4 yields a 4x4 map, so this equals\n",
    "        # avg_pool2d(out, 4); unlike a fixed 4x4 kernel it also handles other input sizes.\n",
    "        out = F.adaptive_avg_pool2d(out, 1)\n",
    "        out = torch.flatten(out, 1)\n",
    "        out = self.linear(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "def ResNet18():\n",
    "    \"\"\"ResNet-18: four BasicBlock stages of depth 2 each, with ResNet's default 10 output classes.\"\"\"\n",
    "    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# A single network instance is reused: every checkpoint in the loop below is loaded into it.\n",
    "model = ResNet18().to(device)\n",
    "\n",
    "# Use all visible GPUs when more than one is available.\n",
    "if device.type == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "    print(\"Using DataParallel\")\n",
    "    model = torch.nn.DataParallel(model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Attack objects used to generate the perturbed domains. Each one wraps the global `model` object, so it\n",
    "# attacks whatever weights are currently loaded into it (the checkpoint loop below loads weights in place).\n",
    "# divided by 10 eps, eps_iter and CW's lr, added as input binary_search_steps to CW attacks\n",
    "# NOTE(review): the line above is a historical note from the original author; its exact meaning is unclear.\n",
    "\n",
    "\n",
    "# Seen (\"_std\" / \"_base\") attacks.\n",
    "adversary_PGD_Linf_std = attacks.LinfPGDAttack(\n",
    "    model, loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"), eps=8/255,\n",
    "    nb_iter=40, eps_iter=2/255, rand_init=True, clip_min=0.0,\n",
    "    clip_max=1.0, targeted=False)\n",
    "\n",
    "# NOTE(review): for the L2 and L1 PGD attacks below, 40 steps of eps_iter=2/255 can move at most\n",
    "# ~40*2/255 ~= 0.31 (in the attack's norm) away from the random start, well below eps=0.5 (L2) and\n",
    "# eps=10 (L1). Confirm these step sizes are intended.\n",
    "adversary_PGD_L2_std = attacks.L2PGDAttack(\n",
    "    model, loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"), eps=0.5,\n",
    "    nb_iter=40, eps_iter=2/255, rand_init=True, clip_min=0.0,\n",
    "    clip_max=1.0, targeted=False)\n",
    "\n",
    "adversary_PGD_L1_std = attacks.L1PGDAttack(\n",
    "    model, loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"), eps=10.,\n",
    "    nb_iter=40, eps_iter=2/255, rand_init=True, clip_min=0.0,\n",
    "    clip_max=1.0, targeted=False)\n",
    "\n",
    "adversary_CW = attacks.CarliniWagnerL2Attack(\n",
    "    model, num_classes=10, max_iterations=20, learning_rate=0.01,\n",
    "    binary_search_steps=5, clip_min=0.0, clip_max=1.0)\n",
    "\n",
    "adversary_deepfool = DeepfoolLinfAttack(\n",
    "        model, num_classes=10, nb_iter=30, eps=0.011, clip_min=0.0, clip_max=1.0)\n",
    "\n",
    "# Unseen attacks used for validation. Compared with the seen versions above they use more iterations\n",
    "# and larger budgets: PGD Linf eps=12/255 (vs 8/255, same eps_iter=2/255), Deepfool eps=8/255 (vs 0.011),\n",
    "# CW learning_rate=0.012 (vs 0.01) with 7 binary-search steps (vs 5).\n",
    "adversary_CW_unseen = attacks.CarliniWagnerL2Attack(\n",
    "    model, num_classes=10, max_iterations=30, learning_rate=0.012,\n",
    "    binary_search_steps=7, clip_min=0.0, clip_max=1.0)\n",
    "\n",
    "adversary_PGD_Linf_unseen = attacks.LinfPGDAttack(\n",
    "    model, loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"), eps=12/255,\n",
    "    nb_iter=70, eps_iter=2/255, rand_init=True, clip_min=0.0,\n",
    "    clip_max=1.0, targeted=False)\n",
    "\n",
    "adversary_deepfool_unseen = DeepfoolLinfAttack(\n",
    "        model, num_classes=10, nb_iter=50, eps=8/255, clip_min=0.0, clip_max=1.0)\n",
    "\n",
    "# NOTE(review): seed=None leaves AutoAttack's random components unseeded; confirm whether a fixed seed\n",
    "# is wanted for reproducibility.\n",
    "adversary_autoattack_unseen = AutoAttack(model, norm='Linf', eps=8/255, \n",
    "        version='standard', seed=None, verbose=False)\n",
    "\n",
    "# Not used by the current `domains` list below ('Autoattack_L2' is absent from it).\n",
    "adversary_autoattack_L2_unseen = AutoAttack(model, norm='L2', eps=0.5, \n",
    "        version='standard', seed=None, verbose=False)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def generate_domains(domain_name, data, label, batch_size=batch_size_test, bool_correct_preds_per_domain=None):\n",
    "    \"\"\"Return the samples of `data` selected by the domain's mask, mapped into domain `domain_name`.\n",
    "\n",
    "    Args:\n",
    "        domain_name: 'clean' (identity) or the name of one attack, e.g. 'PGD_Linf_std', 'CW_mod', 'Autoattack'.\n",
    "        data: batch of inputs, indexed as data[mask, :, :, :] (4-D).\n",
    "        label: labels aligned with the first dimension of `data`.\n",
    "        batch_size: unused; kept so existing callers keep working.\n",
    "        bool_correct_preds_per_domain: optional dict mapping domain name -> boolean mask over the batch.\n",
    "            Only samples whose entry is True are kept/perturbed. If None or empty, every sample is used.\n",
    "\n",
    "    Returns:\n",
    "        The selected (and, except for 'clean', perturbed) samples, or None if the mask selects nothing.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `domain_name` is not a known domain.\n",
    "    \"\"\"\n",
    "    if bool_correct_preds_per_domain:\n",
    "        mask = bool_correct_preds_per_domain[domain_name]\n",
    "    else:\n",
    "        # The mask must be boolean: an integer tensor of ones would act as an index array and\n",
    "        # select sample #1 repeatedly instead of every sample.\n",
    "        mask = torch.ones_like(label, dtype=torch.bool)\n",
    "    masked_data = data[mask, :, :, :]\n",
    "    masked_label = label[mask]\n",
    "\n",
    "    # All the data might have been masked. In that case return None.\n",
    "    if len(masked_data) == 0:\n",
    "        return None\n",
    "\n",
    "    # The attack objects are only looked up when the chosen lambda runs, i.e. for the requested domain.\n",
    "    perturbations = {\n",
    "        'clean': lambda x, y: x,\n",
    "        'PGD_L1_std': lambda x, y: adversary_PGD_L1_std.perturb(x, y),\n",
    "        'PGD_L2_std': lambda x, y: adversary_PGD_L2_std.perturb(x, y),\n",
    "        'PGD_Linf_std': lambda x, y: adversary_PGD_Linf_std.perturb(x, y),\n",
    "        'Deepfool_base': lambda x, y: adversary_deepfool.perturb(x, y),\n",
    "        'CW_base': lambda x, y: adversary_CW.perturb(x, y),\n",
    "        'PGD_Linf_mod': lambda x, y: adversary_PGD_Linf_unseen.perturb(x, y),\n",
    "        'Deepfool_mod': lambda x, y: adversary_deepfool_unseen.perturb(x, y),\n",
    "        'CW_mod': lambda x, y: adversary_CW_unseen.perturb(x, y),\n",
    "        'Autoattack': lambda x, y: adversary_autoattack_unseen.run_standard_evaluation(x, y, bs=len(y)),\n",
    "        'Autoattack_L2': lambda x, y: adversary_autoattack_L2_unseen.run_standard_evaluation(x, y, bs=len(y)),\n",
    "    }\n",
    "    if domain_name not in perturbations:\n",
    "        raise ValueError(\"Unknown domain name: \" + repr(domain_name))\n",
    "    return perturbations[domain_name](masked_data, masked_label)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def loss_helper(model, data_all_domains, label_all_domains, num_domains, num_correct_per_domain, tensor_list_losses_epoch):\n",
    "    \"\"\"Compute the mean (ERM) and variance (REx) of the per-domain cross-entropy losses of `model`.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping an input batch to logits.\n",
    "        data_all_domains: indexable by domain index 0..num_domains-1; one input batch per domain.\n",
    "        label_all_domains: indexable like data_all_domains; one label batch per domain.\n",
    "        num_domains: number of domains to evaluate.\n",
    "        num_correct_per_domain: indexable by domain index; entry d is incremented in place by the\n",
    "            number of correct top-1 predictions on domain d.\n",
    "        tensor_list_losses_epoch: tensor updated in place with the per-domain losses (presumably of\n",
    "            shape (num_domains,) -- TODO confirm with callers).\n",
    "\n",
    "    Returns:\n",
    "        (ERM_term, REx_variance_term): mean and unbiased variance (torch.var default) of the per-domain\n",
    "        losses, both still attached to the autograd graph. With num_domains == 1 the variance is NaN.\n",
    "    \"\"\"\n",
    "    list_losses = []\n",
    "    for domain in range(num_domains):\n",
    "        preds = model(data_all_domains[domain])\n",
    "        list_losses.append(F.cross_entropy(preds, label_all_domains[domain]))\n",
    "        num_correct_per_domain[domain] += (torch.argmax(preds, dim=1) == label_all_domains[domain]).sum().item()\n",
    "\n",
    "    tensor_list_losses = torch.stack(list_losses)\n",
    "\n",
    "    ERM_term = torch.sum(tensor_list_losses) / num_domains\n",
    "    REx_variance_term = torch.var(tensor_list_losses)\n",
    "\n",
    "    # Accumulate a detached copy: an in-place add of a graph-carrying tensor would make the epoch\n",
    "    # accumulator part of this batch's autograd graph and keep every batch's graph alive for the epoch.\n",
    "    tensor_list_losses_epoch += tensor_list_losses.detach()\n",
    "\n",
    "    return ERM_term, REx_variance_term\n",
    "\n",
    "def REx_loss(ERM_term, REx_variance_term, beta):\n",
    "    \"\"\"REx (variance-penalty) objective: the ERM term plus `beta` times the variance of per-domain risks.\"\"\"\n",
    "    weighted_penalty = beta * REx_variance_term\n",
    "    return weighted_penalty + ERM_term\n",
    "\n",
    " \n",
    "def compute_loss(is_REx, beta, loss_terms, model, list_data_all_domains, list_label_all_domains, num_domains, \n",
    "                 num_train_correct_preds_per_domain, tensor_list_losses_epoch_train):\n",
    "    \"\"\"Return the training loss for one batch: the REx objective if `is_REx`, else the ERM term.\n",
    "\n",
    "    `loss_terms` (presumably a NumPy array) is updated in place with the scalar terms: [ERM, variance]\n",
    "    when `is_REx`, otherwise [ERM]. Per-domain correct counts and epoch losses are updated by loss_helper.\n",
    "    \"\"\"\n",
    "    ERM_term, REx_variance_term = loss_helper(model, list_data_all_domains, list_label_all_domains, num_domains,\n",
    "                                              num_train_correct_preds_per_domain, tensor_list_losses_epoch_train)\n",
    "    if not is_REx:\n",
    "        loss_terms += np.array([ERM_term.item()])\n",
    "        return ERM_term\n",
    "    loss_terms += np.array([ERM_term.item(), REx_variance_term.item()])\n",
    "    return REx_loss(ERM_term, REx_variance_term, beta)\n",
    "\n",
    "\n",
    "# Keep track across restarts of which samples were still correctly predicted, for each attack\n",
    "def track_correct_pred_per_domain(model, data_all_domains, labels, domains, bool_correct_per_domain):\n",
    "    \"\"\"Update, in place, which samples are still correctly classified under each domain.\n",
    "\n",
    "    data_all_domains[domain] holds, in order, only the samples whose entry in\n",
    "    bool_correct_per_domain[domain] is True (they were generated through that mask), or None when\n",
    "    the mask selected nothing. Each True entry is replaced by whether `model` classifies the\n",
    "    corresponding sample correctly; False entries stay False, so a sample fooled by any restart stays fooled.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping an input batch to logits.\n",
    "        data_all_domains: dict domain -> masked input batch or None.\n",
    "        labels: labels of the full (unmasked) batch.\n",
    "        domains: domain names to update.\n",
    "        bool_correct_per_domain: dict domain -> boolean tensor over the full batch (mutated in place).\n",
    "    \"\"\"\n",
    "    for domain in domains:\n",
    "        masked_data = data_all_domains[domain]\n",
    "        # Case when the mask filtered all data. `is None` avoids comparing a tensor with None.\n",
    "        if masked_data is None:\n",
    "            continue\n",
    "\n",
    "        preds = model(masked_data)\n",
    "        still_correct = bool_correct_per_domain[domain]\n",
    "        # Scatter the predictions back to the positions selected by the mask in one vectorized step\n",
    "        # (instead of a per-element Python loop). The mask is cloned because it indexes the very\n",
    "        # tensor being written.\n",
    "        mask = still_correct.clone()\n",
    "        still_correct[mask] = torch.argmax(preds, dim=1) == labels[mask]\n",
    "\n",
    "# Compute the number of correct predictions against each attack after all the restarts\n",
    "def update_num_correct_pred_per_domain(num_correct_per_domain, bool_correct_per_domain, domains):\n",
    "    \"\"\"Add, in place, the number of True entries of each domain's mask to num_correct_per_domain[domain].\"\"\"\n",
    "    for name in domains:\n",
    "        still_correct = bool_correct_per_domain[name]\n",
    "        num_correct_per_domain[name] += still_correct.sum().item()\n",
    "\n",
    "# Compute the number of correct predictions if the attacker was using an ensemble of all attacks. Skip the attacks in skipped_domains_worst_case from calculation.\n",
    "def get_num_correct_worst_case(bool_correct_per_domain, domains, skipped_domains_worst_case=None):\n",
    "    \"\"\"Count the samples classified correctly under every considered domain.\n",
    "\n",
    "    This is the number of samples that survive an attacker allowed to pick, per sample, any of the\n",
    "    domains in `domains` (ensemble / worst case). Domains in `skipped_domains_worst_case` are ignored;\n",
    "    if every domain is skipped, all samples are counted.\n",
    "\n",
    "    Args:\n",
    "        bool_correct_per_domain: dict domain -> boolean tensor over the same samples.\n",
    "        domains: non-empty sequence of domain names.\n",
    "        skipped_domains_worst_case: optional collection of domain names to leave out (default: none).\n",
    "\n",
    "    Returns:\n",
    "        The count as a Python number (from .item()).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `domains` is empty.\n",
    "    \"\"\"\n",
    "    if len(domains) == 0:\n",
    "        raise ValueError(\"No domain has been defined !\")\n",
    "    # None instead of a mutable [] default; any collection supporting `in` works.\n",
    "    skipped = skipped_domains_worst_case if skipped_domains_worst_case is not None else ()\n",
    "\n",
    "    bool_correct_worst_case = torch.ones_like(bool_correct_per_domain[domains[0]], dtype=torch.bool)\n",
    "    for domain in domains:\n",
    "        if domain in skipped:\n",
    "            continue\n",
    "        bool_correct_worst_case = torch.logical_and(bool_correct_worst_case, bool_correct_per_domain[domain])\n",
    "\n",
    "    return bool_correct_worst_case.sum().item()\n",
    "\n",
    "# Get which attacks were seen based on model filename. TODO consider renaming to \"seen_domains\" as clean counts here.\n",
    "def get_seen_attacks(model_name):\n",
    "    \"\"\"Return the domains (including 'clean') a model was trained on, inferred from its filename.\n",
    "\n",
    "    The name is split on '_' and its tokens are matched against the rules below; the first matching\n",
    "    rule wins ('MSD' is checked last). Returns an empty list when no rule matches.\n",
    "    \"\"\"\n",
    "    tokens = model_name.split('_')\n",
    "    all_pgd = ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']\n",
    "    precedence = [\n",
    "        ('Linf', ['clean', 'PGD_Linf_std']),\n",
    "        ('L2', ['clean', 'PGD_L2_std']),\n",
    "        ('L1', ['clean', 'PGD_L1_std']),\n",
    "        ('clean', ['clean']),\n",
    "        ('std', ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base']),\n",
    "        ('PGDs', ['clean'] + all_pgd),\n",
    "    ]\n",
    "    for token, seen in precedence:\n",
    "        if token in tokens:\n",
    "            return seen\n",
    "    if 'MSD' in tokens:\n",
    "        # With an 'ERM' token the clean domain is not counted as seen.\n",
    "        return all_pgd if 'ERM' in tokens else ['clean'] + all_pgd\n",
    "    return []\n",
    "\n",
    "\n",
    "\n",
    "def generate_augmented_data_from_loader(model, data_loader, data_shape, batch_size, domains):\n",
    "    \"\"\"Map every sample of `data_loader` into every domain of `domains` and collect the results on CPU.\n",
    "\n",
    "    Args:\n",
    "        model: network passed to ctx_noparamgrad_and_eval while the domains are generated. (The attack\n",
    "            objects used by generate_domains wrap the global `model`.)\n",
    "        data_loader: loader yielding (data, label) batches; expected not to shuffle, so that stored\n",
    "            positions follow the dataset order.\n",
    "        data_shape: shape of the output data tensor, [num_samples, len(domains), C, H, W].\n",
    "        batch_size: forwarded to generate_domains.\n",
    "        domains: ordered domain names; domains[i] is stored at index i of dim 1.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (data, labels) of CPU tensors: data[n, i] is sample n under domains[i] and labels[n]\n",
    "        its label (labels are stored in a float tensor created by torch.zeros).\n",
    "    \"\"\"\n",
    "    adversarial_data = (torch.zeros(data_shape), torch.zeros(data_shape[0]))\n",
    "    # Write offset of the current batch, advanced by the real batch length. Unlike\n",
    "    # i_loader * batch_size, this stays correct if the loader's batch size differs from `batch_size`.\n",
    "    idx_start = 0\n",
    "    for data, label in data_loader:\n",
    "        data, label = data.to(device), label.to(device)\n",
    "        idx_end = idx_start + len(data)\n",
    "\n",
    "        # Every sample starts as \"correctly predicted\", so generate_domains perturbs the whole batch.\n",
    "        bool_track_correct_pred_per_domain = {\n",
    "            domain: torch.ones_like(label, dtype=torch.bool) for domain in domains\n",
    "        }\n",
    "\n",
    "        with ctx_noparamgrad_and_eval(model):\n",
    "            # Clean data is a domain.\n",
    "            data_all_domains = {\n",
    "                domain: generate_domains(domain, data, label, batch_size=batch_size,\n",
    "                                         bool_correct_preds_per_domain=bool_track_correct_pred_per_domain)\n",
    "                for domain in domains\n",
    "            }\n",
    "\n",
    "        for i_domain, domain in enumerate(domains):\n",
    "            # detach() so the stored data never carries an autograd graph.\n",
    "            adversarial_data[0][idx_start:idx_end, i_domain, :, :, :] = data_all_domains[domain].detach().cpu()\n",
    "        # Labels are identical across domains: write them once per batch.\n",
    "        adversarial_data[1][idx_start:idx_end] = label.cpu()\n",
    "\n",
    "        idx_start = idx_end\n",
    "\n",
    "    return adversarial_data\n",
    "\n",
    "def generate_augmented_data_from_loader_and_save(model, data_loader, data_shape, batch_size, domains, working_dir_of_save, data_filename):\n",
    "    \"\"\"Generate the perturbed data for `data_loader` and torch.save it to working_dir_of_save/data_filename.\n",
    "\n",
    "    All arguments except the last two are forwarded to generate_augmented_data_from_loader.\n",
    "    Missing directories on the save path are created. Returns None.\n",
    "    \"\"\"\n",
    "    adversarial_data = generate_augmented_data_from_loader(model, data_loader, data_shape, batch_size, domains)\n",
    "    # makedirs rather than mkdir: a parent such as \"data/adversarial-cifar10/\" may not exist yet.\n",
    "    os.makedirs(working_dir_of_save, exist_ok=True)\n",
    "    torch.save(adversarial_data, os.path.join(working_dir_of_save, data_filename))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "WORKING_DIR = \"results/CIFAR10/\"\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "# Checkpoints to process: the regular files directly inside TRAINED_MODEL_PATH, sorted for a deterministic\n",
    "# processing order. Subdirectories are ignored, and a missing directory fails here with FileNotFoundError\n",
    "# rather than leaving these names undefined.\n",
    "model_filenames = sorted(\n",
    "    name for name in os.listdir(TRAINED_MODEL_PATH)\n",
    "    if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, name))\n",
    ")\n",
    "model_paths = [TRAINED_MODEL_PATH + name for name in model_filenames]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Domains to generate, in storage order: domains[i] is stored at index i of dim 1 of the saved tensors.\n",
    "domains = [\n",
    "    'clean',\n",
    "    'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std',\n",
    "    'Deepfool_base', 'CW_base',\n",
    "    'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod',\n",
    "    'Autoattack',\n",
    "]\n",
    "num_domains = len(domains)\n",
    "# Domains left out of the worst-case (ensemble) accuracy; not used in this cell.\n",
    "skipped_domains_worst_case = ['CW_mod']\n",
    "\n",
    "num_test_batches = len(test_loader)\n",
    "\n",
    "######################\n",
    "# test the model #\n",
    "######################\n",
    "model.eval()\n",
    "\n",
    "\n",
    "\n",
    "# Unix-only module, used to report the peak resident memory after each split is generated.\n",
    "import resource\n",
    "\n",
    "# Shapes of the tensors filled with perturbed data: one slot per (sample, domain). They do not depend on\n",
    "# the checkpoint, so they are computed once.\n",
    "adversarial_data_train_shape = [num_train_samples, num_domains, 3, 32, 32]\n",
    "adversarial_data_valid_shape = [num_valid_samples, num_domains, 3, 32, 32]\n",
    "adversarial_data_test_shape = [len(test_data), num_domains, 3, 32, 32]\n",
    "\n",
    "for model_name, model_path in zip(model_filenames, model_paths):\n",
    "    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.\n",
    "    checkpoint = torch.load(model_path, map_location=device)\n",
    "    starting_epoch = checkpoint['epoch']\n",
    "    # Some models have been saved as DataParallel and others not, so parameter names may differ by a\n",
    "    # \"module.\" prefix; load_state_dict raises RuntimeError on such a key mismatch. In that case retry on\n",
    "    # the wrapped module (only possible when `model` is a DataParallel wrapper).\n",
    "    try:\n",
    "        model.load_state_dict(checkpoint['current_model'])\n",
    "    except RuntimeError:\n",
    "        if not hasattr(model, 'module'):\n",
    "            raise\n",
    "        print(\"Failed to load model onto model, attempting to load onto model.module...\")\n",
    "        model.module.load_state_dict(checkpoint['current_model'])\n",
    "        print(\"Successfully loaded onto model.module.\")\n",
    "\n",
    "    # Where the adversarial data will be saved\n",
    "    working_dir_of_save = \"data/adversarial-cifar10/\" + model_name + \"/\"\n",
    "\n",
    "    generate_augmented_data_from_loader_and_save(model, train_loader, adversarial_data_train_shape, batch_size_train_and_valid, domains, working_dir_of_save, \"adversarial_data_train.pt\")\n",
    "    print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)\n",
    "    generate_augmented_data_from_loader_and_save(model, valid_loader, adversarial_data_valid_shape, batch_size_train_and_valid, domains, working_dir_of_save, \"adversarial_data_valid.pt\")\n",
    "    print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)\n",
    "    generate_augmented_data_from_loader_and_save(model, test_loader, adversarial_data_test_shape, batch_size_test, domains, working_dir_of_save, \"adversarial_data_test.pt\")\n",
    "    print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)\n",
    "\n",
    "    # Save various metadata; \"mapping_idx_to_transform\" maps the domain index (dim 1 of the saved data) to its name.\n",
    "    metadata = {\n",
    "        \"mapping_idx_to_transform\": [str(i) + \": \" + domain for i, domain in enumerate(domains)],\n",
    "        \"model_used\": model_name,\n",
    "        \"seed\": seed,\n",
    "    }\n",
    "    with open(working_dir_of_save + \"metadata.json\", \"w\") as outfile:\n",
    "        json.dump(metadata, outfile)\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([40000])\n",
      "torch.Size([4014, 3, 3, 32, 32]) torch.Size([4014])\n",
      "torch.Size([10000])\n",
      "torch.Size([986, 3, 3, 32, 32]) torch.Size([986])\n",
      "The user did not set --resume to True. Training from scratch...\n",
      "Did not force a scheduler reset.\n",
      "Using DataParallel\n",
      "Epoch: 1 \tTraining Loss: 3.253968 \tValidation Loss: 3.230947\n",
      "Top 1 training accuracy: 0.395283 \tTop 1 validation accuracy: 0.407032\n",
      "Top 2 training accuracy: 0.735592 \tTop 2 validation accuracy: 0.746450\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (inf --> 3.230947). \n",
      "Epoch: 2 \tTraining Loss: 3.121182 \tValidation Loss: 3.169043\n",
      "Top 1 training accuracy: 0.450839 \tTop 1 validation accuracy: 0.434415\n",
      "Top 2 training accuracy: 0.785999 \tTop 2 validation accuracy: 0.756592\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.230947 --> 3.169043). \n",
      "Epoch: 3 \tTraining Loss: 3.061275 \tValidation Loss: 3.138187\n",
      "Top 1 training accuracy: 0.473260 \tTop 1 validation accuracy: 0.445909\n",
      "Top 2 training accuracy: 0.802275 \tTop 2 validation accuracy: 0.763354\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.169043 --> 3.138187). \n",
      "Epoch: 4 \tTraining Loss: 3.022740 \tValidation Loss: 3.118922\n",
      "Top 1 training accuracy: 0.487045 \tTop 1 validation accuracy: 0.451995\n",
      "Top 2 training accuracy: 0.809749 \tTop 2 validation accuracy: 0.774510\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.138187 --> 3.118922). \n",
      "Epoch: 5 \tTraining Loss: 2.994718 \tValidation Loss: 3.105989\n",
      "Top 1 training accuracy: 0.496097 \tTop 1 validation accuracy: 0.454699\n",
      "Top 2 training accuracy: 0.813735 \tTop 2 validation accuracy: 0.780933\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.118922 --> 3.105989). \n",
      "Epoch: 6 \tTraining Loss: 2.973235 \tValidation Loss: 3.096954\n",
      "Top 1 training accuracy: 0.501163 \tTop 1 validation accuracy: 0.457404\n",
      "Top 2 training accuracy: 0.816891 \tTop 2 validation accuracy: 0.781609\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.105989 --> 3.096954). \n",
      "Epoch: 7 \tTraining Loss: 2.956113 \tValidation Loss: 3.090620\n",
      "Top 1 training accuracy: 0.506145 \tTop 1 validation accuracy: 0.460784\n",
      "Top 2 training accuracy: 0.819963 \tTop 2 validation accuracy: 0.781609\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.096954 --> 3.090620). \n",
      "Epoch: 8 \tTraining Loss: 2.942108 \tValidation Loss: 3.086264\n",
      "Top 1 training accuracy: 0.509633 \tTop 1 validation accuracy: 0.463489\n",
      "Top 2 training accuracy: 0.822289 \tTop 2 validation accuracy: 0.782961\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.090620 --> 3.086264). \n",
      "Epoch: 9 \tTraining Loss: 2.930427 \tValidation Loss: 3.083356\n",
      "Top 1 training accuracy: 0.512789 \tTop 1 validation accuracy: 0.464165\n",
      "Top 2 training accuracy: 0.824033 \tTop 2 validation accuracy: 0.786004\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.086264 --> 3.083356). \n",
      "Epoch: 10 \tTraining Loss: 2.920532 \tValidation Loss: 3.081514\n",
      "Top 1 training accuracy: 0.515197 \tTop 1 validation accuracy: 0.465517\n",
      "Top 2 training accuracy: 0.825112 \tTop 2 validation accuracy: 0.783300\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.083356 --> 3.081514). \n",
      "Epoch: 11 \tTraining Loss: 2.912044 \tValidation Loss: 3.080477\n",
      "Top 1 training accuracy: 0.517273 \tTop 1 validation accuracy: 0.466531\n",
      "Top 2 training accuracy: 0.826524 \tTop 2 validation accuracy: 0.784652\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.081514 --> 3.080477). \n",
      "Epoch: 12 \tTraining Loss: 2.904690 \tValidation Loss: 3.080080\n",
      "Top 1 training accuracy: 0.518186 \tTop 1 validation accuracy: 0.468898\n",
      "Top 2 training accuracy: 0.826358 \tTop 2 validation accuracy: 0.787018\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Validation loss decreased (3.080477 --> 3.080080). \n",
      "Epoch: 13 \tTraining Loss: 2.898273 \tValidation Loss: 3.080227\n",
      "Top 1 training accuracy: 0.519930 \tTop 1 validation accuracy: 0.470588\n",
      "Top 2 training accuracy: 0.826939 \tTop 2 validation accuracy: 0.788371\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 14 \tTraining Loss: 2.892648 \tValidation Loss: 3.080878\n",
      "Top 1 training accuracy: 0.521591 \tTop 1 validation accuracy: 0.472617\n",
      "Top 2 training accuracy: 0.827354 \tTop 2 validation accuracy: 0.789723\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 15 \tTraining Loss: 2.887707 \tValidation Loss: 3.082044\n",
      "Top 1 training accuracy: 0.524165 \tTop 1 validation accuracy: 0.472279\n",
      "Top 2 training accuracy: 0.827188 \tTop 2 validation accuracy: 0.790061\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 16 \tTraining Loss: 2.883368 \tValidation Loss: 3.083780\n",
      "Top 1 training accuracy: 0.525909 \tTop 1 validation accuracy: 0.471940\n",
      "Top 2 training accuracy: 0.827936 \tTop 2 validation accuracy: 0.788371\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 17 \tTraining Loss: 2.879568 \tValidation Loss: 3.086186\n",
      "Top 1 training accuracy: 0.525743 \tTop 1 validation accuracy: 0.472955\n",
      "Top 2 training accuracy: 0.828102 \tTop 2 validation accuracy: 0.788709\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 18 \tTraining Loss: 2.876252 \tValidation Loss: 3.089392\n",
      "Top 1 training accuracy: 0.525909 \tTop 1 validation accuracy: 0.472279\n",
      "Top 2 training accuracy: 0.828185 \tTop 2 validation accuracy: 0.789047\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 19 \tTraining Loss: 2.873361 \tValidation Loss: 3.093503\n",
      "Top 1 training accuracy: 0.526325 \tTop 1 validation accuracy: 0.473293\n",
      "Top 2 training accuracy: 0.828517 \tTop 2 validation accuracy: 0.788371\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 20 \tTraining Loss: 2.870821 \tValidation Loss: 3.098485\n",
      "Top 1 training accuracy: 0.524581 \tTop 1 validation accuracy: 0.475997\n",
      "Top 2 training accuracy: 0.828268 \tTop 2 validation accuracy: 0.789047\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 21 \tTraining Loss: 2.868524 \tValidation Loss: 3.103991\n",
      "Top 1 training accuracy: 0.524498 \tTop 1 validation accuracy: 0.477012\n",
      "Top 2 training accuracy: 0.828932 \tTop 2 validation accuracy: 0.789385\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 22 \tTraining Loss: 2.866319 \tValidation Loss: 3.109236\n",
      "Top 1 training accuracy: 0.524415 \tTop 1 validation accuracy: 0.477012\n",
      "Top 2 training accuracy: 0.829181 \tTop 2 validation accuracy: 0.789723\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 23 \tTraining Loss: 2.864034 \tValidation Loss: 3.113182\n",
      "Top 1 training accuracy: 0.524165 \tTop 1 validation accuracy: 0.478026\n",
      "Top 2 training accuracy: 0.829181 \tTop 2 validation accuracy: 0.789047\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 24 \tTraining Loss: 2.861531 \tValidation Loss: 3.115138\n",
      "Top 1 training accuracy: 0.524415 \tTop 1 validation accuracy: 0.477350\n",
      "Top 2 training accuracy: 0.828517 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 25 \tTraining Loss: 2.858785 \tValidation Loss: 3.115270\n",
      "Top 1 training accuracy: 0.524996 \tTop 1 validation accuracy: 0.477012\n",
      "Top 2 training accuracy: 0.828019 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 26 \tTraining Loss: 2.855912 \tValidation Loss: 3.114409\n",
      "Top 1 training accuracy: 0.526491 \tTop 1 validation accuracy: 0.477350\n",
      "Top 2 training accuracy: 0.828434 \tTop 2 validation accuracy: 0.791413\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 27 \tTraining Loss: 2.853092 \tValidation Loss: 3.113403\n",
      "Top 1 training accuracy: 0.527238 \tTop 1 validation accuracy: 0.479378\n",
      "Top 2 training accuracy: 0.828600 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 28 \tTraining Loss: 2.850469 \tValidation Loss: 3.112710\n",
      "Top 1 training accuracy: 0.528401 \tTop 1 validation accuracy: 0.481068\n",
      "Top 2 training accuracy: 0.829679 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 29 \tTraining Loss: 2.848099 \tValidation Loss: 3.112440\n",
      "Top 1 training accuracy: 0.530560 \tTop 1 validation accuracy: 0.481068\n",
      "Top 2 training accuracy: 0.830427 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 30 \tTraining Loss: 2.845972 \tValidation Loss: 3.112532\n",
      "Top 1 training accuracy: 0.531141 \tTop 1 validation accuracy: 0.481744\n",
      "Top 2 training accuracy: 0.830842 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 31 \tTraining Loss: 2.844050 \tValidation Loss: 3.112891\n",
      "Top 1 training accuracy: 0.531224 \tTop 1 validation accuracy: 0.482421\n",
      "Top 2 training accuracy: 0.830842 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 32 \tTraining Loss: 2.842294 \tValidation Loss: 3.113433\n",
      "Top 1 training accuracy: 0.531971 \tTop 1 validation accuracy: 0.483097\n",
      "Top 2 training accuracy: 0.831340 \tTop 2 validation accuracy: 0.791413\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 33 \tTraining Loss: 2.840673 \tValidation Loss: 3.114100\n",
      "Top 1 training accuracy: 0.531971 \tTop 1 validation accuracy: 0.483435\n",
      "Top 2 training accuracy: 0.831839 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 34 \tTraining Loss: 2.839163 \tValidation Loss: 3.114852\n",
      "Top 1 training accuracy: 0.531888 \tTop 1 validation accuracy: 0.482759\n",
      "Top 2 training accuracy: 0.832088 \tTop 2 validation accuracy: 0.792427\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 35 \tTraining Loss: 2.837747 \tValidation Loss: 3.115663\n",
      "Top 1 training accuracy: 0.532054 \tTop 1 validation accuracy: 0.482759\n",
      "Top 2 training accuracy: 0.832088 \tTop 2 validation accuracy: 0.792765\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 36 \tTraining Loss: 2.836413 \tValidation Loss: 3.116516\n",
      "Top 1 training accuracy: 0.532719 \tTop 1 validation accuracy: 0.482082\n",
      "Top 2 training accuracy: 0.832005 \tTop 2 validation accuracy: 0.791751\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 37 \tTraining Loss: 2.835151 \tValidation Loss: 3.117399\n",
      "Top 1 training accuracy: 0.532968 \tTop 1 validation accuracy: 0.482082\n",
      "Top 2 training accuracy: 0.832171 \tTop 2 validation accuracy: 0.791751\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 38 \tTraining Loss: 2.833954 \tValidation Loss: 3.118302\n",
      "Top 1 training accuracy: 0.533134 \tTop 1 validation accuracy: 0.483435\n",
      "Top 2 training accuracy: 0.832171 \tTop 2 validation accuracy: 0.791413\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 39 \tTraining Loss: 2.832815 \tValidation Loss: 3.119220\n",
      "Top 1 training accuracy: 0.533715 \tTop 1 validation accuracy: 0.482759\n",
      "Top 2 training accuracy: 0.832420 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 40 \tTraining Loss: 2.831730 \tValidation Loss: 3.120146\n",
      "Top 1 training accuracy: 0.533964 \tTop 1 validation accuracy: 0.482421\n",
      "Top 2 training accuracy: 0.832420 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 41 \tTraining Loss: 2.830693 \tValidation Loss: 3.121078\n",
      "Top 1 training accuracy: 0.534131 \tTop 1 validation accuracy: 0.483097\n",
      "Top 2 training accuracy: 0.832669 \tTop 2 validation accuracy: 0.790399\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 42 \tTraining Loss: 2.829702 \tValidation Loss: 3.122012\n",
      "Top 1 training accuracy: 0.534380 \tTop 1 validation accuracy: 0.482421\n",
      "Top 2 training accuracy: 0.832669 \tTop 2 validation accuracy: 0.790737\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 43 \tTraining Loss: 2.828752 \tValidation Loss: 3.122946\n",
      "Top 1 training accuracy: 0.534131 \tTop 1 validation accuracy: 0.482082\n",
      "Top 2 training accuracy: 0.832669 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 44 \tTraining Loss: 2.827842 \tValidation Loss: 3.123879\n",
      "Top 1 training accuracy: 0.534214 \tTop 1 validation accuracy: 0.482082\n",
      "Top 2 training accuracy: 0.832586 \tTop 2 validation accuracy: 0.791075\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 45 \tTraining Loss: 2.826968 \tValidation Loss: 3.124808\n",
      "Top 1 training accuracy: 0.534463 \tTop 1 validation accuracy: 0.480730\n",
      "Top 2 training accuracy: 0.832918 \tTop 2 validation accuracy: 0.791413\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n",
      "Epoch: 46 \tTraining Loss: 2.826128 \tValidation Loss: 3.125733\n",
      "Top 1 training accuracy: 0.534629 \tTop 1 validation accuracy: 0.481068\n",
      "Top 2 training accuracy: 0.832752 \tTop 2 validation accuracy: 0.791413\n",
      "Top 3 training accuracy: 1.000000 \tTop 3 validation accuracy: 1.000000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_27282/2299174750.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    544\u001b[0m             \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m#, y=label)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    545\u001b[0m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdomain_label\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 546\u001b[0;31m             \u001b[0mvalid_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    547\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    548\u001b[0m             \u001b[0;31m# Update count of number of correct predictions per domain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "## Train domain classifier on perturbed data with a ViT. You'll need torch 1.13 here and for the eval.\n",
    "##\n",
    "## The original plan was a class-conditional ResNet18 (note the extra projection term on the last layer),\n",
    "## see https://arxiv.org/abs/1802.05637 and https://arxiv.org/abs/2006.04621 (the latter has a repo), but it\n",
    "## performs much worse than the pretrained ViT.\n",
    "##\n",
    "## Make sure to select attacks and classes you're interested in by modifying the training flags.\n",
    "\n",
    "\n",
    "# Clear the interactive namespace, equivalent to `%reset -sf` (soft reset keeping history, no confirmation prompt).\n",
    "from IPython import get_ipython\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "import torchvision\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# import advertorch.attacks as attacks\n",
    "# from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "# from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "import argparse\n",
    "from distutils.util import strtobool\n",
    "from math import pi\n",
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sbn\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Optional. Initial learning rate value, default=0.1.\")\n",
    "# argument_parser.add_argument(\"--wd\", type=float, help=\"Optional. Weight decay value for the optimiser. Default: 5e-4 for CIFAR10, or past value from checkpoint if resuming.\")\n",
    "# argument_parser.add_argument(\"--output_suffix\", type=str, help=\"Optional. Suffix for path where files should be created.\")\n",
    "# argument_parser.add_argument(\"--only_classes\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified CIFAR10 classes (by index in CIFAR10).\")\n",
    "# argument_parser.add_argument(\"--only_domains\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified domains (by domain indices, see metadata.json).\")\n",
    "# argument_parser.add_argument(\"--only_perturbation\", type=lambda x: bool(strtobool(x)), help=\"Optional. Only consider added perturbation (subtract original image) in analysis.\")\n",
    "# argument_parser.add_argument(\"--standardise\", type=lambda x: bool(strtobool(x)), help=\"Optional. Recommended with --only_perturbation. Standardise data during preprocessing.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "# Domain names in the order of the dataset's domain axis; index 0 is the unperturbed ('clean') image.\n",
    "domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base',\n",
    "           'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']\n",
    "\n",
    "class TrainingFlags():\n",
    "    \"\"\"In-notebook replacement for the command-line arguments of the commented-out argparse parser.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Optimiser overrides; None keeps the defaults chosen further down in this cell.\n",
    "        self.wd = None\n",
    "        self.lr_init = None\n",
    "        # Restrict the analysis to these CIFAR10 classes (label values).\n",
    "        self.only_classes = [0.]\n",
    "        # Restrict the analysis to these indices of the `domains` list.\n",
    "        self.only_domains = [3, 7]\n",
    "        # Subtract the unperturbed x from each x_adv, keeping only the perturbation.\n",
    "        self.only_perturbation = True\n",
    "        self.standardise = True\n",
    "        # Standardise each image separately rather than over the whole dataset.\n",
    "        self.instance_standardisation = True\n",
    "        # The output path is derived from the flags above; this suffix is appended to it.\n",
    "        self.output_suffix = \"\"\n",
    "        # Keep as \"ViT\": the old class-conditional ResNet worked poorly and is untested with the ViT / torch 1.13 setup.\n",
    "        self.model_type = \"ViT\"\n",
    "\n",
    "parsed_args = TrainingFlags()\n",
    "\n",
    "\n",
    "# Fixed seed so that validation splits stay identical across runs (e.g. after reloading).\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed numpy, Python's `random` and torch with `seed`.\n",
    "\n",
    "    Passed to the DataLoaders as `worker_init_fn`, in which case PyTorch calls it in each worker\n",
    "    process with the worker id as `seed`; it is not called when the loaders use no workers.\n",
    "    \"\"\"\n",
    "    for seeder in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "num_workers = 0\n",
    "# The test set size must be a multiple of batch_size_test.\n",
    "batch_size_train_and_valid = 128\n",
    "batch_size_test = 100\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Operations are done in place where possible to save memory.\n",
    "def preprocess_data(data_path, parsed_args=parsed_args):\n",
    "    \"\"\"Load a saved (data, label) pair and apply the preprocessing selected by `parsed_args`.\n",
    "\n",
    "    Steps, in order: optionally replace every domain by its perturbation (x_domain - x_clean),\n",
    "    keep only `only_domains`, keep only samples whose label is in `only_classes`, standardise.\n",
    "\n",
    "    Args:\n",
    "        data_path: file written with torch.save containing `(tensor_data, tensor_label)`.\n",
    "            tensor_data is indexed [sample, domain, channel, ...] with domain 0 the unperturbed\n",
    "            image (presumably (N, num_domains, 3, H, W) -- TODO confirm); tensor_label holds the\n",
    "            original dataset labels.\n",
    "        parsed_args: flags object (see TrainingFlags) providing only_perturbation, only_domains,\n",
    "            only_classes, standardise and instance_standardisation.\n",
    "\n",
    "    Returns:\n",
    "        The preprocessed `(tensor_data, tensor_label)` tuple.\n",
    "    \"\"\"\n",
    "    # Note: tensor_data does not include the original dataset label, which is in tensor_label.\n",
    "    # NOTE(review): torch.load unpickles the file -- only load trusted data.\n",
    "    tensor_data, tensor_label = torch.load(data_path)\n",
    "\n",
    "    # Only consider the added perturbations as data, by subtracting the original sample from each adversarial example.\n",
    "    # Assumes 0th dim is original unperturbed sample #, 1st dim corresponds to perturbation used.\n",
    "    if parsed_args.only_perturbation is True:\n",
    "        num_domains = tensor_data.shape[1]\n",
    "        # Iterate in reverse so the reference (domain 0) is zeroed LAST. Zeroing it first would make\n",
    "        # every later subtraction subtract zero and leave the other domains unchanged.\n",
    "        for i_domains in reversed(range(num_domains)):\n",
    "            tensor_data[:, i_domains] -= tensor_data[:, 0]\n",
    "\n",
    "    # Only consider some perturbation domains if applicable\n",
    "    if parsed_args.only_domains is not None:\n",
    "        tensor_data = tensor_data[:, parsed_args.only_domains]\n",
    "\n",
    "    # Only consider some CIFAR10 classes if applicable (vectorised mask instead of a per-sample Python loop)\n",
    "    if parsed_args.only_classes is not None:\n",
    "        print(tensor_label.shape)\n",
    "        keep = torch.zeros_like(tensor_label, dtype=torch.bool)\n",
    "        for wanted_class in parsed_args.only_classes:\n",
    "            keep |= tensor_label == wanted_class\n",
    "        tensor_label = tensor_label[keep]\n",
    "        tensor_data = tensor_data[keep]\n",
    "\n",
    "    # Standardise TODO ONLY USE TRAINING (statistics are currently taken from whichever split is loaded)\n",
    "    # Plain truthiness: `x is not (None or False)` is just `x is not False`, which also standardised when x was None.\n",
    "    if parsed_args.standardise:\n",
    "        # Index 0 is the CIFAR10 sample, index 2 the RGB channel. Statistics are pooled over every other\n",
    "        # dim; with instance standardisation dim 0 is kept as well, giving per-sample statistics.\n",
    "        excluded_dims = (0, 2) if parsed_args.instance_standardisation else (2,)\n",
    "        standardisation_dimensions = [d for d in range(tensor_data.dim()) if d not in excluded_dims]\n",
    "        means = tensor_data.mean(dim=standardisation_dimensions, keepdim=True)\n",
    "        stds = tensor_data.std(dim=standardisation_dimensions, keepdim=True)\n",
    "        # Guard against an exactly-zero std (e.g. only the all-zero clean-domain perturbation kept), which gave NaN.\n",
    "        stds = stds.clamp_min(torch.finfo(stds.dtype).tiny)\n",
    "        tensor_data = (tensor_data - means) / stds\n",
    "\n",
    "    print(tensor_data.shape, tensor_label.shape)\n",
    "    return tensor_data, tensor_label\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# --- ViT transforms ---\n",
    "# Resize every input image to 224x224 before it is fed to the ViT.\n",
    "resize_transform = transforms.Resize((224, 224))\n",
    "\n",
    "# Dataset wrapper that applies a transform (here, the ViT resize) to each input sample.\n",
    "class CustomTensorDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset over tensors sharing their first dimension.\n",
    "\n",
    "    `tensors[0]` holds the inputs (passed through `transform` when one is given) and\n",
    "    `tensors[1]` the targets, returned unchanged.\n",
    "    \"\"\"\n",
    "    def __init__(self, tensors, transform=None):\n",
    "        assert all(t.size(0) == tensors[0].size(0) for t in tensors)\n",
    "        self.tensors = tensors\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.tensors[0].size(0)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sample = self.tensors[0][index]\n",
    "        sample = self.transform(sample) if self.transform else sample\n",
    "        return sample, self.tensors[1][index]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Directory holding adversarial_data_train.pt, adversarial_data_valid.pt and metadata.json.\n",
    "data_path = \"data/adversarial-cifar10/model_PGD_Linf_103.pt/\"\n",
    "\n",
    "# Training split: preprocess_data returns a (data, label) tuple, which CustomTensorDataset wraps as-is.\n",
    "# NOTE(review): neither loader sets shuffle=True or passes num_workers, so batches follow the file order every\n",
    "# epoch and worker_init_fn is never invoked (DataLoader defaults to num_workers=0). TODO confirm this is intended.\n",
    "tensor_train_data = preprocess_data(data_path+\"adversarial_data_train.pt\", parsed_args)\n",
    "train_data = CustomTensorDataset(tensor_train_data, transform=resize_transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "# del train_data, train_loader\n",
    "# Validation split, preprocessed with the same flags.\n",
    "tensor_valid_data = preprocess_data(data_path+\"adversarial_data_valid.pt\", parsed_args)\n",
    "valid_data = CustomTensorDataset(tensor_valid_data, transform=resize_transform)\n",
    "valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "# Dataset metadata stored alongside the splits.\n",
    "with open(data_path + \"metadata.json\", 'r') as json_file:\n",
    "    metadata = json.load(json_file)\n",
    "# del valid_data, valid_loader\n",
    "# test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transform_to_tensor)\n",
    "# test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Basic ResNet residual block: two 3x3 conv + batch-norm layers plus a shortcut connection.\"\"\"\n",
    "    # Ratio between this block's output channels and `planes`.\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        # A strided 1x1 projection is needed only when the resolution or channel count changes;\n",
    "        # otherwise the empty Sequential passes its input through unchanged.\n",
    "        needs_projection = stride != 1 or in_planes != out_planes\n",
    "        self.shortcut = (\n",
    "            nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes),\n",
    "            )\n",
    "            if needs_projection\n",
    "            else nn.Sequential()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        hidden = self.bn2(self.conv2(hidden))\n",
    "        return F.relu(hidden + self.shortcut(x))\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    \"\"\"ResNet bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand, each with batch norm, plus a shortcut.\"\"\"\n",
    "    # Output channels are `expansion * planes`.\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
    "        # Identity shortcut when shape is preserved, strided 1x1 projection otherwise.\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes),\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = x\n",
    "        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):\n",
    "            hidden = F.relu(bn(conv(hidden)))\n",
    "        hidden = self.bn3(self.conv3(hidden))\n",
    "        return F.relu(hidden + self.shortcut(x))\n",
    "\n",
    "\n",
    "class ClassConditionalResNet(nn.Module):\n",
    "    \"\"\"ResNet whose logits can optionally be conditioned on the original (unperturbed) class label.\n",
    "\n",
    "    Projection-discriminator style (https://arxiv.org/abs/1802.05637): when a label `y` is given, the\n",
    "    inner product between a spectrally normalised class embedding and the pooled features is added\n",
    "    to every output logit.\n",
    "\n",
    "    Args:\n",
    "        block: residual block class (e.g. BasicBlock or Bottleneck) with an `expansion` attribute.\n",
    "        num_blocks: number of blocks in each of the four stages.\n",
    "        num_output_classes: number of logits produced (the domains to discriminate).\n",
    "        num_original_classes: number of original dataset labels that may be passed as `y`.\n",
    "    \"\"\"\n",
    "    def __init__(self, block, num_blocks, num_output_classes=10, num_original_classes=10):\n",
    "        super(ClassConditionalResNet, self).__init__()\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
    "        self.class_conditional = torch.nn.utils.spectral_norm(nn.Embedding(num_original_classes, 512 * block.expansion))\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_output_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        # Only the first block of a stage is strided; self.in_planes tracks the running channel count.\n",
    "        strides = [stride] + [1] * (num_blocks - 1)\n",
    "        layers = []\n",
    "        for stride in strides:\n",
    "            layers.append(block(self.in_planes, planes, stride))\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x, y=None):\n",
    "        \"\"\"Return logits of shape (batch, num_output_classes).\n",
    "\n",
    "        Args:\n",
    "            x: 3-channel input images; the fixed 4x4 average pool expects the last stage's feature\n",
    "                map to be 4x4 (e.g. 32x32 inputs), otherwise `self.linear` receives the wrong width.\n",
    "            y: optional original-class labels (indices into the embedding). When None (the default,\n",
    "                so `model(x)` also works) the class-conditional term is skipped.\n",
    "        \"\"\"\n",
    "        features = F.relu(self.bn1(self.conv1(x)))\n",
    "        features = self.layer1(features)\n",
    "        features = self.layer2(features)\n",
    "        features = self.layer3(features)\n",
    "        features = self.layer4(features)\n",
    "        features = F.avg_pool2d(features, 4)\n",
    "        features = features.view(features.size(0), -1)\n",
    "        out = self.linear(features)\n",
    "        if y is not None:\n",
    "            # Projection term <embed(y), features>, shape (batch, 1), broadcast-added to every logit.\n",
    "            out += torch.sum(self.class_conditional(y) * features, dim=1, keepdim=True)\n",
    "        return out\n",
    "\n",
    "def ClassConditionalResNet18(num_output_classes=10, num_original_classes=10):\n",
    "    \"\"\"Build a class-conditional ResNet-18 (BasicBlock, two blocks per stage).\n",
    "\n",
    "    num_output_classes: number of classes the model can output (the domains).\n",
    "    num_original_classes: number of labels (e.g. CIFAR10 labels) of the unperturbed dataset.\n",
    "    \"\"\"\n",
    "    blocks_per_stage = [2] * 4\n",
    "    return ClassConditionalResNet(BasicBlock, blocks_per_stage,\n",
    "                                  num_output_classes=num_output_classes,\n",
    "                                  num_original_classes=num_original_classes)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def prepare_model(parsed_args, num_original_classes=10, num_output_classes=10, model_type=\"ViT\"):\n",
    "    \"\"\"Build the domain classifier: a frozen, ImageNet-pretrained ViT-B/16 with a new trainable linear head.\n",
    "\n",
    "    Args:\n",
    "        parsed_args: training flags; when `only_domains` is set, its length is the number of outputs.\n",
    "        num_original_classes: unused by the ViT (kept for the class-conditional ResNet interface).\n",
    "        num_output_classes: number of outputs used when `parsed_args.only_domains` is None.\n",
    "        model_type: currently ignored -- a ViT is always built, since the class-conditional ResNet\n",
    "            (see ClassConditionalResNet18) performed worse.\n",
    "\n",
    "    Returns:\n",
    "        The model on `device`; only the replacement head has requires_grad=True.\n",
    "    \"\"\"\n",
    "    model = torchvision.models.vit_b_16(weights=\"IMAGENET1K_SWAG_LINEAR_V1\")\n",
    "\n",
    "    # Freeze the pretrained backbone; only the head created below is trained.\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    if parsed_args.only_domains is not None:\n",
    "        num_outputs = len(parsed_args.only_domains)\n",
    "    else:\n",
    "        num_outputs = num_output_classes\n",
    "    # Replace the ImageNet head; hidden_dim is the ViT embedding width (768 for ViT-B/16).\n",
    "    model.heads = torch.nn.Linear(in_features=model.hidden_dim, out_features=num_outputs, bias=True)\n",
    "    # Move to the device only after swapping the head, so the new head is placed there too.\n",
    "    model.to(device)\n",
    "\n",
    "    # if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "    #     print(\"Using DataParallel\")\n",
    "    #     model = torch.nn.DataParallel(model)\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def prepare_data_for_attack_discriminator(data, label, num_domains, device=device):\n",
    "    \"\"\"Flatten a batch of per-domain samples into a single domain-major batch for the domain classifier.\n",
    "\n",
    "    `data` is indexed [sample, domain, ...]. Returns `(data, label, domain_label)` on `device`: data has\n",
    "    all samples of domain 0 first, then domain 1, and so on; label is the original labels tiled to match;\n",
    "    domain_label holds the domain index of each row.\n",
    "    \"\"\"\n",
    "    data = data.to(device)\n",
    "    label = label.to(device).long()\n",
    "    # Unperturbed samples in this batch; equals batch_size except possibly on the last iteration of an epoch.\n",
    "    batch_size = data.shape[0]\n",
    "    # Bring the domain axis to the front, then merge it with the sample axis.\n",
    "    domain_first = data.transpose(0, 1)\n",
    "    data = domain_first.reshape([num_domains * batch_size] + list(domain_first.shape)[2:])\n",
    "    # [0]*batch_size + [1]*batch_size + ... matches the domain-major row order above.\n",
    "    domain_label = torch.arange(num_domains).repeat_interleave(batch_size).to(device)\n",
    "    label = label.repeat(num_domains)\n",
    "    return data, label, domain_label\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "save_interval = 1\n",
    "top_k = 3\n",
    "\n",
    "\n",
    "# Keep only the names of the selected domains, in the order given by only_domains.\n",
    "if parsed_args.only_domains is not None:\n",
    "    domains = [domains[int(i)] for i in parsed_args.only_domains]\n",
    "\n",
    "num_domains = len(domains)\n",
    "if num_domains < top_k:\n",
    "    print(\"There are less domains than k for top_k accuracy, top_k set to num_domains\")\n",
    "    top_k = num_domains\n",
    "\n",
    "model = prepare_model(parsed_args, model_type=parsed_args.model_type)\n",
    "model.to(device)\n",
    "# number of epochs to train the model\n",
    "n_epochs_AIT = 1001\n",
    "# initialize tracker for minimum validation loss (np.inf: the np.Inf alias was removed in NumPy 2.0)\n",
    "valid_loss_min = np.inf  # set initial \"min\" to infinity\n",
    "# NOTE(review): lr_init, momentum and weight_decay only feed the commented-out SGD optimiser. Adam below\n",
    "# uses its own default lr and no weight decay, unless lr_init is overridden further down in this cell.\n",
    "lr_init = 0.1\n",
    "best_epoch = 0\n",
    "starting_epoch = 0\n",
    "# # the following var decides when we start waterfalling in the REx term (includes the chosen epoch)\n",
    "# waterfall_epoch = 326\n",
    "momentum = 0.9\n",
    "if parsed_args.wd is not None:\n",
    "    weight_decay = parsed_args.wd\n",
    "else:\n",
    "    weight_decay = 5e-4\n",
    "# optimizer = torch.optim.SGD(model.parameters(), lr = lr_init, momentum=momentum, weight_decay=weight_decay)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)\n",
    "\n",
    "\n",
    "# Build the output directory name from the chosen training flags.\n",
    "parsed_args_into_path = \"\"\n",
    "if parsed_args.only_classes is not None:\n",
    "    parsed_args_into_path += \"classes_\" + \"-\".join(str(int(s)) for s in parsed_args.only_classes)\n",
    "if parsed_args.only_domains is not None:\n",
    "    parsed_args_into_path += \"domains_\" + \"-\".join(str(int(s)) for s in parsed_args.only_domains)\n",
    "if parsed_args.only_perturbation is True:\n",
    "    parsed_args_into_path += \"only_pert\"\n",
    "if parsed_args.standardise is True:\n",
    "    if parsed_args.instance_standardisation is True:\n",
    "        parsed_args_into_path += \"instancestd\"\n",
    "    else:\n",
    "        parsed_args_into_path += \"std\"\n",
    "TRAINED_MODEL_PATH = \"experiments/discriminate_attacks/\" + data_path.split('/')[-2] + \"/\" + parsed_args.model_type + \"/\" + parsed_args_into_path + parsed_args.output_suffix + \"_ERM/\"\n",
    "\n",
    "path_of_checkpoint = \"\"\n",
    "writer = SummaryWriter(TRAINED_MODEL_PATH)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# MODIFY GIVEN PARSED ARGUMENTS\n",
    "# Override the optimiser's learning rate when the user supplied one. getattr() replaces the previous\n",
    "# bare `except:`, which also swallowed unrelated errors and stayed silent when lr_init was None.\n",
    "user_lr_init = getattr(parsed_args, \"lr_init\", None)\n",
    "if user_lr_init is not None:\n",
    "    lr_init = user_lr_init\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = lr_init\n",
    "    print(\"lr_init set to %f by user\" % user_lr_init)\n",
    "else:\n",
    "    print(\"No learning rate passed as argument; using default/past value of \", lr_init)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Parallelise model if possible: wrap it in DataParallel only when several CUDA devices are visible\n",
    "use_data_parallel = str(device) == \"cuda\" and torch.cuda.device_count() > 1\n",
    "if use_data_parallel:\n",
    "    print(\"Using DataParallel\")\n",
    "    model = torch.nn.DataParallel(model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "for epoch in range(starting_epoch, n_epochs_AIT):\n",
    "    # Running sums over every stacked (sample, domain) item; turned into averages after the epoch\n",
    "    train_loss = 0\n",
    "    valid_loss = 0\n",
    "\n",
    "    topk_accuracies_train = torch.zeros(top_k)\n",
    "    topk_accuracies_valid = torch.zeros(top_k)\n",
    "\n",
    "    # Rows: true domain, columns: top-1 predicted domain\n",
    "    confusion_matrix_train = torch.zeros([num_domains, num_domains])\n",
    "    confusion_matrix_valid = torch.zeros([num_domains, num_domains])\n",
    "\n",
    "\n",
    "    ###################\n",
    "    # Train the model #\n",
    "    ###################\n",
    "    model.train()\n",
    "    for data, label in train_loader:\n",
    "        data, label, domain_label = prepare_data_for_attack_discriminator(data, label, num_domains, device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        preds = model(data)#, y=label)\n",
    "        loss = F.cross_entropy(preds, domain_label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # cross_entropy returns the batch mean, so weight it by the number of stacked items in the batch\n",
    "        train_loss += loss.item() * data.size(0)\n",
    "\n",
    "        # Update count of number of correct predictions per domain\n",
    "        with torch.no_grad():\n",
    "            predicted_topk = torch.topk(preds, top_k, dim=1).indices\n",
    "            for pred, target in zip(predicted_topk, domain_label):\n",
    "                confusion_matrix_train[int(target), int(pred[0])] += 1\n",
    "                for iter_topk in range(0, top_k):\n",
    "                    if target in pred[:iter_topk+1]:\n",
    "                        topk_accuracies_train[iter_topk] += 1\n",
    "\n",
    "\n",
    "    ######################\n",
    "    # Validate the model #\n",
    "    ######################\n",
    "    model.eval()\n",
    "    for data, label in valid_loader:\n",
    "        data, label, domain_label = prepare_data_for_attack_discriminator(data, label, num_domains, device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            preds = model(data)#, y=label)\n",
    "            loss = F.cross_entropy(preds, domain_label)\n",
    "            valid_loss += loss.item() * data.size(0)\n",
    "\n",
    "            # Update count of number of correct predictions per domain\n",
    "            predicted_topk = torch.topk(preds, top_k, dim=1).indices\n",
    "            for pred, target in zip(predicted_topk, domain_label):\n",
    "                confusion_matrix_valid[int(target), int(pred[0])] += 1\n",
    "                for iter_topk in range(0, top_k):\n",
    "                    if target in pred[:iter_topk+1]:\n",
    "                        topk_accuracies_valid[iter_topk] += 1\n",
    "\n",
    "\n",
    "    # Every loader sample is expanded into num_domains stacked items, so the losses (sums of per-item\n",
    "    # cross-entropies) and the accuracies are both averaged over len(sampler) * num_domains items.\n",
    "    # Dividing the losses by len(sampler) alone reported num_domains times the mean cross-entropy.\n",
    "    num_train_items = len(train_loader.sampler) * num_domains\n",
    "    num_valid_items = len(valid_loader.sampler) * num_domains\n",
    "\n",
    "    train_loss = train_loss / num_train_items\n",
    "    valid_loss = valid_loss / num_valid_items\n",
    "\n",
    "    training_acc = torch.sum(confusion_matrix_train.diagonal()) / num_train_items\n",
    "    valid_acc = torch.sum(confusion_matrix_valid.diagonal()) / num_valid_items\n",
    "\n",
    "    topk_accuracies_train /= num_train_items\n",
    "    topk_accuracies_valid /= num_valid_items\n",
    "\n",
    "    # Sanity check: top-1 accuracy from the confusion matrix must equal the top-1 counter\n",
    "    if training_acc != topk_accuracies_train[0]:\n",
    "        print(training_acc)\n",
    "        print(topk_accuracies_train)\n",
    "        raise ValueError(\"Something's wrong with the topk acc.\")\n",
    "\n",
    "    print(\"Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}\".format(\n",
    "        epoch+1, \n",
    "        train_loss,\n",
    "        valid_loss\n",
    "        ))\n",
    "\n",
    "\n",
    "    for iter_topk in range(0, top_k):\n",
    "        print(\"Top {} training accuracy: {:.6f} \\tTop {} validation accuracy: {:.6f}\".format(\n",
    "            iter_topk+1,\n",
    "            topk_accuracies_train[iter_topk],\n",
    "            iter_topk+1,\n",
    "            topk_accuracies_valid[iter_topk]\n",
    "            ), flush=True)\n",
    "\n",
    "    if valid_loss <= valid_loss_min:\n",
    "        print('Validation loss decreased ({:.6f} --> {:.6f}). '.format(\n",
    "        valid_loss_min,\n",
    "        valid_loss))\n",
    "        best_epoch = epoch\n",
    "        valid_loss_min = valid_loss\n",
    "\n",
    "    path_of_checkpoint = TRAINED_MODEL_PATH + 'model_ERM_' + str(epoch+1) + '.pt'\n",
    "\n",
    "    # epoch_lr = schedule.get_last_lr()[0]\n",
    "    # schedule.step()\n",
    "    # # lr to start from in checkpoint\n",
    "    # lr_init = schedule.get_last_lr()[0]\n",
    "\n",
    "    # model is wrapped in DataParallel only on multi-GPU machines; `model.module` raised AttributeError\n",
    "    # otherwise. Always save the bare module's weights so checkpoints load the same way everywhere.\n",
    "    model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model\n",
    "    checkpoint = {'current_model': model_to_save.state_dict(),\n",
    "                  'optimiser': optimizer.state_dict(),\n",
    "                #   'schedule': schedule.state_dict(),\n",
    "                  'learning_rate': lr_init,\n",
    "                  'epoch': epoch + 1,\n",
    "                  'best_epoch': best_epoch,\n",
    "                  'seed': seed\n",
    "                 }\n",
    "\n",
    "    if (epoch+1) % save_interval == 0: #and epoch+1 >= 50:\n",
    "        torch.save(checkpoint, path_of_checkpoint)\n",
    "\n",
    "    # writer.add_scalar('Learning_rate', epoch_lr, epoch+1)\n",
    "    writer.add_scalar('Momentum', momentum, epoch+1)\n",
    "    writer.add_scalar('Weight_decay', weight_decay, epoch+1)\n",
    "\n",
    "    writer.add_scalar('Training_loss', train_loss, epoch+1)\n",
    "    writer.add_scalar('Validation_loss', valid_loss, epoch+1)\n",
    "\n",
    "    writer.add_scalar('Training_accuracy', training_acc, epoch+1)\n",
    "    writer.add_scalar('Validation_accuracy', valid_acc, epoch+1)\n",
    "\n",
    "writer.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test domain classifier on perturbed data\n",
    "\n",
    "\n",
    "# clear memory\n",
    "from IPython import get_ipython\n",
    "# IPython's .magic() is deprecated; run_line_magic is the supported way to run %reset -sf\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "import torchvision\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# import advertorch.attacks as attacks\n",
    "# from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "# from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "import argparse\n",
    "from distutils.util import strtobool\n",
    "from math import pi\n",
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sbn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Optional. Initial learning rate value, default=0.1.\")\n",
    "# argument_parser.add_argument(\"--wd\", type=float, help=\"Optional. Weight decay value for the optimiser. Default: 5e-4 for CIFAR10, or past value from checkpoint if resuming.\")\n",
    "# argument_parser.add_argument(\"--output_suffix\", type=str, help=\"Optional. Suffix for path where files should be created.\")\n",
    "# argument_parser.add_argument(\"--only_classes\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified CIFAR10 classes (by index in CIFAR10).\")\n",
    "# argument_parser.add_argument(\"--only_domains\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified domains (by domain indices, see metadata.json).\")\n",
    "# argument_parser.add_argument(\"--only_perturbation\", type=lambda x: bool(strtobool(x)), help=\"Optional. Only consider added perturbation (subtract original image) in analysis.\")\n",
    "# argument_parser.add_argument(\"--standardise\", type=lambda x: bool(strtobool(x)), help=\"Optional. Recommended with --only_perturbation. Standardise data during preprocessing.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "class TrainingFlags():\n",
    "    def __init__(self):\n",
    "        # Only CIFAR10 classes to consider perturbating\n",
    "        self.only_classes = [0.]\n",
    "        # Only attacks/indices to consider in domains array\n",
    "        self.only_domains = [3, 7]\n",
    "        # Subtract unperturbed x from x_adv\n",
    "        self.only_perturbation = True\n",
    "        self.standardise = True\n",
    "        # Standardise per image instead of per dataset\n",
    "        self.instance_standardisation = True\n",
    "parsed_args = TrainingFlags()\n",
    "\n",
    "\n",
    "# Make sure validation splits are the same at all time (e.g. even after loading)\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "   np.random.seed(seed)\n",
    "   random.seed(seed)\n",
    "   torch.manual_seed(seed)\n",
    "   return\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "num_workers = 0\n",
    "# Make sure test_data is a multiple of batch_size_test\n",
    "batch_size_train_and_valid = 128\n",
    "batch_size_test = 100\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Need to be in place to save memory\n",
    "def preprocess_data(data_path, parsed_args=parsed_args):\n",
    "    \"\"\"Load a saved (data, labels) pair and apply the preprocessing selected by `parsed_args`.\n",
    "\n",
    "    Steps, in order: optional subtraction of the clean sample (only_perturbation), optional domain\n",
    "    selection (only_domains), optional class filtering (only_classes) and optional standardisation\n",
    "    (standardise / instance_standardisation).\n",
    "\n",
    "    Args:\n",
    "        data_path: file holding a `(tensor_data, tensor_label)` pair, as read by torch.load.\n",
    "        parsed_args: flags object; defaults to the module-level `parsed_args`.\n",
    "\n",
    "    Returns:\n",
    "        (tensor_data, tensor_label). tensor_data does not include the original dataset label,\n",
    "        which is in tensor_label.\n",
    "    \"\"\"\n",
    "    # NOTE(review): torch.load unpickles the file; only use it on trusted data.\n",
    "    tensor_data, tensor_label = torch.load(data_path)\n",
    "\n",
    "    # Only consider the added perturbations as data, by subtracting the original sample from each adversarial example.\n",
    "    # Assumes 0th dim is original unperturbed sample #, 1st dim corresponds to perturbation used (index 0 = clean).\n",
    "    if parsed_args.only_perturbation is True:\n",
    "        num_domains = tensor_data.shape[1]\n",
    "        # Iterate in reverse so the clean sample (domain 0) is zeroed LAST. Iterating forwards zeroed it\n",
    "        # first, after which every other domain had zeros subtracted and kept its full adversarial image.\n",
    "        for i_domains in reversed(range(num_domains)):\n",
    "            tensor_data[:, i_domains] -= tensor_data[:, 0]\n",
    "\n",
    "    # Only consider some perturbation domains if applicable (indices may arrive as floats from argparse)\n",
    "    if parsed_args.only_domains is not None:\n",
    "        tensor_data = tensor_data[:, [int(i) for i in parsed_args.only_domains]]\n",
    "\n",
    "    # Only consider some CIFAR10 classes if applicable\n",
    "    if parsed_args.only_classes is not None:\n",
    "        indices = [iter_x for iter_x, x in enumerate(tensor_label) if x in parsed_args.only_classes]\n",
    "        print(tensor_label.shape)\n",
    "        tensor_label = tensor_label[indices]\n",
    "        tensor_data = tensor_data[indices]\n",
    "\n",
    "    # Standardise TODO ONLY USE TRAINING\n",
    "    # (The previous test `is not (None or False)` reduced to `is not False`, so standardise=None still\n",
    "    # standardised.)\n",
    "    if parsed_args.standardise:\n",
    "        # Get indices of tensor_data over which to standardise. If instance standardisation, we standardise per sample, not over dataset.\n",
    "        # Index 0 corresponds to which original CIFAR10 data sample, index 2 to RGB channels.\n",
    "        if getattr(parsed_args, \"instance_standardisation\", False):\n",
    "            standardisation_dimensions = [x for x in range(0, len(tensor_data.shape)) if x != 0 and x != 2]\n",
    "        else:\n",
    "            standardisation_dimensions = [x for x in range(0, len(tensor_data.shape)) if x != 2]\n",
    "        means = tensor_data.mean(dim=standardisation_dimensions, keepdim=True)\n",
    "        stds = tensor_data.std(dim=standardisation_dimensions, keepdim=True)\n",
    "        # NOTE(review): a slice with zero variance has std 0 and yields NaN/inf here.\n",
    "        tensor_data = (tensor_data - means) / stds\n",
    "\n",
    "    print(tensor_data.shape, tensor_label.shape)\n",
    "    return tensor_data, tensor_label\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ViT transforms: resize the (CIFAR10-sized) inputs to 224x224, presumably the input resolution\n",
    "# expected by the pretrained ViT built in prepare_model\n",
    "resize_transform = transforms.Resize(size=(224, 224))\n",
    "\n",
    "# Define a transform to reshape the input data\n",
    "class CustomTensorDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"TensorDataset with support of transforms.\n",
    "    \"\"\"\n",
    "    def __init__(self, tensors, transform=None):\n",
    "        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)\n",
    "        self.tensors = tensors\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x = self.tensors[0][index]\n",
    "\n",
    "        if self.transform:\n",
    "            x = self.transform(x)\n",
    "\n",
    "        y = self.tensors[1][index]\n",
    "\n",
    "        return x, y\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.tensors[0].size(0)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_path = \"data/adversarial-cifar10/model_PGD_Linf_103.pt/\"\n",
    "\n",
    "# Held-out adversarial test set, preprocessed with the same flags and resized for the ViT\n",
    "test_file_path = data_path + \"adversarial_data_test.pt\"\n",
    "tensor_test_data = preprocess_data(test_file_path, parsed_args)\n",
    "test_data = CustomTensorDataset(tensor_test_data, transform=resize_transform)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_data,\n",
    "    batch_size=batch_size_test,\n",
    "    worker_init_fn=seed_init_fn,\n",
    ")\n",
    "\n",
    "# Metadata file stored next to the adversarial data\n",
    "with open(data_path + \"metadata.json\", 'r') as json_file:\n",
    "    metadata = json.load(json_file)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "\n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or in_planes != self.expansion * planes:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(self.expansion * planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        out = F.relu(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super(Bottleneck, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(self.expansion * planes)\n",
    "\n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or in_planes != self.expansion * planes:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(self.expansion * planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = F.relu(self.bn2(self.conv2(out)))\n",
    "        out = self.bn3(self.conv3(out))\n",
    "        out += self.shortcut(x)\n",
    "        out = F.relu(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "class ClassConditionalResNet(nn.Module):\n",
    "    def __init__(self, block, num_blocks, num_output_classes=10, num_original_classes=10):\n",
    "        super(ClassConditionalResNet, self).__init__()\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
    "        self.class_conditional = torch.nn.utils.spectral_norm(nn.Embedding(num_original_classes, 512 * block.expansion))\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_output_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        strides = [stride] + [1] * (num_blocks - 1)\n",
    "        layers = []\n",
    "        for stride in strides:\n",
    "            layers.append(block(self.in_planes, planes, stride))\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        features = F.relu(self.bn1(self.conv1(x)))\n",
    "        features = self.layer1(features)\n",
    "        features = self.layer2(features)\n",
    "        features = self.layer3(features)\n",
    "        features = self.layer4(features)\n",
    "        features = F.avg_pool2d(features, 4)\n",
    "        features = features.view(features.size(0), -1)\n",
    "        out = self.linear(features)\n",
    "        if y is not None:\n",
    "            out += torch.sum(self.class_conditional(y) * features, dim=1, keepdim=True)\n",
    "#         print(x.size(), out.size())\n",
    "        return out\n",
    "\n",
    "# Output classes are the classes the model can output (domains), original classes are the labels (e.g. CIFAR10 labels) of the unperturbed dataset\n",
    "def ClassConditionalResNet18(num_output_classes=10, num_original_classes=10):\n",
    "    \"\"\"Build a ResNet-18 layout ClassConditionalResNet: BasicBlocks, two per stage, four stages.\"\"\"\n",
    "    blocks_per_stage = [2] * 4\n",
    "    return ClassConditionalResNet(\n",
    "        BasicBlock,\n",
    "        blocks_per_stage,\n",
    "        num_output_classes=num_output_classes,\n",
    "        num_original_classes=num_original_classes,\n",
    "    )\n",
    "\n",
    "def prepare_model(parsed_args, num_original_classes=10, num_output_classes=10, model_type=\"ViT\"):\n",
    "    \"\"\"Build an ImageNet-pretrained ViT-B/16 with a frozen backbone and a fresh, trainable linear head.\n",
    "\n",
    "    Args:\n",
    "        parsed_args: flags object; when `only_domains` is not None its length sets the number of outputs.\n",
    "        num_original_classes: unused by the ViT path (kept for the ClassConditionalResNet18 alternative).\n",
    "        num_output_classes: number of outputs used when `parsed_args.only_domains` is None.\n",
    "        model_type: accepted for caller compatibility but currently ignored; a ViT is always built.\n",
    "\n",
    "    Returns:\n",
    "        The model, moved to the module-level `device`.\n",
    "    \"\"\"\n",
    "    # The four only_classes/only_domains branches all built the identical ViT, so they are collapsed.\n",
    "    # Alternative backbone: ClassConditionalResNet18(num_original_classes=..., num_output_classes=...)\n",
    "    model = torchvision.models.vit_b_16(weights=\"IMAGENET1K_SWAG_LINEAR_V1\")\n",
    "\n",
    "    # Freeze the pretrained backbone; only the head created below has requires_grad=True\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    # Replace the ImageNet head so out_features is the number of domains (previously len(None) raised a\n",
    "    # TypeError when only_domains was None). It stays a bare Linear (state-dict keys heads.weight /\n",
    "    # heads.bias) so existing checkpoints still load. 768 is the ViT-B/16 hidden size.\n",
    "    if parsed_args.only_domains is not None:\n",
    "        num_output_classes = len(parsed_args.only_domains)\n",
    "    model.heads = torch.nn.Linear(in_features=768, out_features=num_output_classes, bias=True)\n",
    "\n",
    "    # Move to the device only after the head swap; moving first left the new head on the CPU.\n",
    "    model.to(device)\n",
    "\n",
    "    # if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "    #     print(\"Using DataParallel\")\n",
    "    #     model = torch.nn.DataParallel(model)\n",
    "    return model\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# if parsed_args.only_classes is None:\n",
    "#     if parsed_args.only_domains is None:\n",
    "#         model = ClassConditionalResNet18()\n",
    "#     else:\n",
    "#         model = ClassConditionalResNet18(num_output_classes=len(parsed_args.only_domains))\n",
    "# else:\n",
    "#     if parsed_args.only_domains is None:\n",
    "#         model = ClassConditionalResNet18(num_original_classes=len(parsed_args.only_classes))\n",
    "#     else:\n",
    "#         model = ClassConditionalResNet18(num_original_classes=len(parsed_args.only_classes), num_output_classes=len(parsed_args.only_domains))\n",
    "# # model = ClassConditionalResNet18(num_original_classes=len(parsed_args.only_domains), num_output_classes=len(parsed_args.only_domains))\n",
    "# model.to(device)\n",
    "\n",
    "# if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "#     print(\"Using DataParallel\")\n",
    "#     model = torch.nn.DataParallel(model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def prepare_data_for_attack_discriminator(data, label, num_domains, device=device):\n",
    "    \"\"\"Flatten a batch of multi-domain samples into one domain-major batch with domain labels.\n",
    "\n",
    "    Args:\n",
    "        data: tensor whose first two axes are (sample, domain); trailing axes are kept as-is.\n",
    "        label: original dataset labels, one per sample.\n",
    "        num_domains: size of the domain axis of `data`.\n",
    "        device: device the returned tensors are moved to.\n",
    "\n",
    "    Returns:\n",
    "        (data, label, domain_label): data of shape (num_domains * batch, ...) whose rows\n",
    "        [d*batch:(d+1)*batch] all come from domain d, the int64 labels tiled once per domain,\n",
    "        and the int64 domain index of every row.\n",
    "    \"\"\"\n",
    "    data = data.to(device)\n",
    "    label = label.to(device).long()\n",
    "    # Usually batch_size, except possibly for the last batch of an epoch\n",
    "    batch = data.shape[0]\n",
    "    # (sample, domain, ...) -> (domain, sample, ...) -> (domain * sample, ...)\n",
    "    stacked = data.swapaxes(0, 1).reshape([num_domains * batch] + list(data.shape)[2:])\n",
    "    domain_label = torch.arange(num_domains).repeat_interleave(batch).to(device)\n",
    "    return stacked, label.repeat(num_domains), domain_label\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base',\n",
    "                'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']\n",
    "domains_latex = [r\"None\", r\"$P_1$\", r\"$P_2$\", r\"$P_\\infty$\", r\"$\\textit{DF}_\\infty$\", r\"$\\textit{CW}_2$\",\n",
    "                r\"$P_\\infty^\\bullet$\", r\"$\\textit{DF}_\\infty^\\bullet$\", r\"$\\textit{CW}_2^\\bullet$\", r\"$\\textit{AA}_\\infty$\"]\n",
    "\n",
    "if parsed_args.only_domains is not None:\n",
    "    domains = [domains[int(i)] for i in parsed_args.only_domains]\n",
    "    domains_latex = [domains_latex[int(i)] for i in parsed_args.only_domains]\n",
    "\n",
    "num_domains = len(domains)\n",
    "# Report up to top-3 accuracy, but never more than the number of domains. Derived from num_domains so it\n",
    "# also works when only_domains is None (len(parsed_args.only_domains) used to raise a TypeError).\n",
    "top_k = min(3, num_domains)\n",
    "\n",
    "model = prepare_model(parsed_args)\n",
    "model.to(device)\n",
    "\n",
    "WORKING_DIR = \"results/adversarial-cifar10/model_PGD_Linf_103.pt/ViT/\"#classes_0domains_3-7only_pertinstancestd_ERM/\"\n",
    "if parsed_args.only_classes is not None:\n",
    "    WORKING_DIR += \"classes_\" + \"-\".join(str(int(s)) for s in parsed_args.only_classes)\n",
    "if parsed_args.only_domains is not None:\n",
    "    WORKING_DIR += \"domains_\" + \"-\".join(str(int(s)) for s in parsed_args.only_domains)\n",
    "if parsed_args.only_perturbation is True:\n",
    "    WORKING_DIR += \"only_pert\"\n",
    "if parsed_args.standardise is True:\n",
    "    if parsed_args.instance_standardisation is True:\n",
    "        WORKING_DIR += \"instancestd\"\n",
    "    else:\n",
    "        WORKING_DIR += \"std\"\n",
    "WORKING_DIR += \"_ERM/\"\n",
    "\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "# Checkpoint files located directly in TRAINED_MODEL_PATH, in a deterministic (sorted) order. The previous\n",
    "# os.walk loop descended into sub-directories (overwriting the list with their files, joined to the wrong\n",
    "# directory) and silently left model_paths undefined when the directory did not exist.\n",
    "model_filenames = sorted(\n",
    "    filename for filename in os.listdir(TRAINED_MODEL_PATH)\n",
    "    if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, filename))\n",
    ")\n",
    "model_paths = [TRAINED_MODEL_PATH + filename for filename in model_filenames]\n",
    "\n",
    "\n",
    "\n",
    "# adversarial-invariance/experiments/discriminate_attacks/model_PGD_Linf_103.pt/classes_0domains_1-2-3-4-5-6-7-8-9only_pertstd_ERM/model_ERM_127.pt\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Parallelise model if possible (more than one visible CUDA device)\n",
    "multiple_gpus = str(device) == \"cuda\" and torch.cuda.device_count() > 1\n",
    "if multiple_gpus:\n",
    "    print(\"Using DataParallel\")\n",
    "    model = torch.nn.DataParallel(model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "for model_num, model_path in enumerate(model_paths):\n",
    "    # map_location lets checkpoints saved on a GPU be evaluated on a CPU-only machine (and vice versa).\n",
    "    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
    "    checkpoint = torch.load(model_path, map_location=device)\n",
    "    # Some models have been saved as DataParallel. For those we need to load to model, for others to model.module. Lessons\n",
    "    # for the next project: always save module. :) Although this would still be a problem for for loop such as here as even\n",
    "    # if I DataParallel in the for loop, the next iteration of the for loop tries to load non-DataParallel to a DataParalleled model\n",
    "    try:\n",
    "        model.load_state_dict(checkpoint['current_model'])\n",
    "    except RuntimeError:\n",
    "        # load_state_dict raises RuntimeError on missing/unexpected keys. Retry on the wrapped module only\n",
    "        # if there is one; otherwise re-raise the real error instead of masking it with an AttributeError.\n",
    "        if not hasattr(model, \"module\"):\n",
    "            raise\n",
    "        print(\"Failed to load model onto model, attempting to load onto model.module...\")\n",
    "        model.module.load_state_dict(checkpoint['current_model'])\n",
    "        print(\"Successfully loaded onto model.module.\")\n",
    "\n",
    "    # Keep track of total (REx or ERM) losses\n",
    "    test_loss = 0\n",
    "\n",
    "    topk_accuracies_test = torch.zeros(top_k)\n",
    "\n",
    "    # Rows: true domain, columns: top-1 predicted domain\n",
    "    confusion_matrix_test = torch.zeros([num_domains, num_domains]).to(device)\n",
    "\n",
    "\n",
    "    ##################\n",
    "    # Test the model #\n",
    "    ##################\n",
    "    model.eval()\n",
    "    for data, label in test_loader:\n",
    "        data, label, domain_label = prepare_data_for_attack_discriminator(data, label, num_domains, device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            preds = model(data)#, y=label)\n",
    "            loss = F.cross_entropy(preds, domain_label)\n",
    "            # cross_entropy returns the batch mean, so weight it by the number of stacked items in the batch\n",
    "            test_loss += loss.item() * data.size(0)\n",
    "\n",
    "            # Update count of number of correct predictions per domain\n",
    "            predicted_topk = torch.topk(preds, top_k, dim=1).indices\n",
    "            for pred, target in zip(predicted_topk, domain_label):\n",
    "                confusion_matrix_test[int(target), int(pred[0])] += 1\n",
    "                for iter_topk in range(0, top_k):\n",
    "                    if target in pred[:iter_topk+1]:\n",
    "                        topk_accuracies_test[iter_topk] += 1\n",
    "\n",
    "\n",
    "    # Every loader sample is expanded into num_domains stacked items, so the loss (a sum of per-item\n",
    "    # cross-entropies) and the accuracies are both averaged over len(sampler) * num_domains items.\n",
    "    # Dividing the loss by len(sampler) alone reported num_domains times the mean cross-entropy.\n",
    "    num_test_items = len(test_loader.sampler) * num_domains\n",
    "    test_loss = test_loss / num_test_items\n",
    "    test_acc = torch.sum(confusion_matrix_test.diagonal()) / num_test_items\n",
    "    topk_accuracies_test /= num_test_items\n",
    "    # shouldn't divide by num_domains as the confusion matrix is mostly read per column or row, in which\n",
    "    # case the max expected number of samples should (but not quite for predicted labels) be len(test_loader.sampler)\n",
    "    confusion_matrix_test /= len(test_loader.sampler)\n",
    "    # if test_acc != topk_accuracies_test[0]:\n",
    "    #     raise ValueError(\"The two ways of computing acc don't match !\")\n",
    "\n",
    "\n",
    "    print(\"Model: {} \\tTest Loss: {:.6f}\".format(\n",
    "        model_filenames[model_num], \n",
    "        test_loss\n",
    "        ))\n",
    "\n",
    "\n",
    "    for iter_topk in range(0, top_k):\n",
    "        print(\"Top {} test accuracy: {:.6f}\".format(\n",
    "            iter_topk+1,\n",
    "            topk_accuracies_test[iter_topk]\n",
    "            ), flush=True)\n",
    "\n",
    "    confusion_matrix_test = confusion_matrix_test.cpu().numpy()\n",
    "    results = {}\n",
    "    results[\"topk_accuracies\"] = topk_accuracies_test\n",
    "    results[\"confusion_matrix\"] = confusion_matrix_test\n",
    "\n",
    "\n",
    "    plt.rcParams['text.usetex'] = True\n",
    "    df_confmat = pd.DataFrame(confusion_matrix_test, index = domains_latex, columns=domains_latex)\n",
    "    fig = plt.figure(figsize=(15,10))#, dpi=1200)\n",
    "    # sbn.set(font_scale=3)\n",
    "    # Recommended font sizes: [20, 24] for 9 attacks, [30, 36] for 5, [25, 30] for 3, [60, 72] for 2\n",
    "    custom_fontsizes = [25, 30]\n",
    "    heatmap = sbn.heatmap(df_confmat, annot=True, fmt=\".2f\", annot_kws={\n",
    "                'fontsize': custom_fontsizes[0]})\n",
    "    x_ticks_pos, _ = plt.xticks()\n",
    "    heatmap.set_xticks(x_ticks_pos + 0.2)\n",
    "    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=0, ha='right', fontsize=custom_fontsizes[1])\n",
    "    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=custom_fontsizes[1])\n",
    "    cbar = heatmap.collections[0].colorbar\n",
    "    cbar.ax.tick_params(labelsize=custom_fontsizes[0])\n",
    "    plt.ylabel('True label', fontsize=custom_fontsizes[1])\n",
    "    plt.xlabel('Predicted label', fontsize=custom_fontsizes[1])\n",
    "\n",
    "    working_dir_of_save = WORKING_DIR + \"test_accs/\"\n",
    "    # makedirs also creates missing parents and tolerates an existing directory\n",
    "    os.makedirs(working_dir_of_save, exist_ok=True)\n",
    "    np.save(working_dir_of_save + model_filenames[model_num], results)\n",
    "    fig.savefig(working_dir_of_save + model_filenames[model_num] + \"_confusion_matrix.pdf\", bbox_inches='tight')\n",
    "    plt.show()\n",
    "    # Release the figure so memory does not grow with the number of checkpoints evaluated\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test CW2 sample distances to clean image\n",
    "\n",
    "\n",
    "# Clear the kernel namespace so this cell starts from a clean state.\n",
    "# `run_line_magic` is the supported form of `%reset -sf` (`.magic(...)` is deprecated):\n",
    "# -s = soft reset (user namespace only), -f = no confirmation prompt.\n",
    "from IPython import get_ipython\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "import torchvision\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# import advertorch.attacks as attacks\n",
    "# from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "# from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "import argparse\n",
    "from distutils.util import strtobool\n",
    "from math import pi\n",
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sbn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Optional. Initial learning rate value, default=0.1.\")\n",
    "# argument_parser.add_argument(\"--wd\", type=float, help=\"Optional. Weight decay value for the optimiser. Default: 5e-4 for CIFAR10, or past value from checkpoint if resuming.\")\n",
    "# argument_parser.add_argument(\"--output_suffix\", type=str, help=\"Optional. Suffix for path where files should be created.\")\n",
    "# argument_parser.add_argument(\"--only_classes\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified CIFAR10 classes (by index in CIFAR10).\")\n",
    "# argument_parser.add_argument(\"--only_domains\", type=lambda s: [float(x) for x in s.split(',')], help=\"Optional. Restrict analysis to specified domains (by domain indices, see metadata.json).\")\n",
    "# argument_parser.add_argument(\"--only_perturbation\", type=lambda x: bool(strtobool(x)), help=\"Optional. Only consider added perturbation (subtract original image) in analysis.\")\n",
    "# argument_parser.add_argument(\"--standardise\", type=lambda x: bool(strtobool(x)), help=\"Optional. Recommended with --only_perturbation. Standardise data during preprocessing.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "class TrainingFlags:\n",
    "    \"\"\"Notebook stand-in for the commented-out argparse namespace.\n",
    "\n",
    "    Index lists hold floats, mirroring what the argparse `type=` lambdas would produce.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # CIFAR10 class indices to keep.\n",
    "        self.only_classes = [0.]\n",
    "        # Domain indices to keep (see metadata.json). Other selections tried: [3., 7.], [1., 2., ..., 9.]\n",
    "        self.only_domains = [2.]\n",
    "        # Subtract the clean sample so only the added perturbation remains.\n",
    "        self.only_perturbation = True\n",
    "        # Neither dataset-wide nor per-instance standardisation.\n",
    "        self.standardise = False\n",
    "        self.instance_standardisation = False\n",
    "\n",
    "\n",
    "parsed_args = TrainingFlags()\n",
    "\n",
    "\n",
    "# Fixed seed so that validation splits are reproducible (e.g. also after reloading).\n",
    "seed = 0\n",
    "\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed numpy, Python's `random` and torch with the same value.\n",
    "\n",
    "    NOTE(review): when used as a DataLoader `worker_init_fn`, PyTorch passes the worker id as the\n",
    "    first argument, which then replaces `seed`; with no worker processes it is not called at all.\n",
    "    \"\"\"\n",
    "    for seeder in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "num_workers = 0\n",
    "# Make sure test_data is a multiple of batch_size_test\n",
    "batch_size_train_and_valid, batch_size_test = 128, 100\n",
    "\n",
    "\n",
    "class CustomTensorDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Like `torch.utils.data.TensorDataset`, but with an optional transform.\n",
    "\n",
    "    `tensors[0]` holds the inputs (the only tensor the transform is applied to) and\n",
    "    `tensors[1]` the targets; all tensors must share the same first dimension.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tensors, transform=None):\n",
    "        assert all(t.size(0) == tensors[0].size(0) for t in tensors)\n",
    "        self.tensors = tensors\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sample = self.tensors[0][index]\n",
    "        sample = self.transform(sample) if self.transform else sample\n",
    "        return sample, self.tensors[1][index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.tensors[0].shape[0]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Need to be in place to save memory\n",
    "def preprocess_data(data_path, parsed_args=parsed_args):\n",
    "    \"\"\"Load the `(data, labels)` pair stored at `data_path` and apply the filters in `parsed_args`.\n",
    "\n",
    "    NOTE(review): assumes `tensor_data` is (sample, domain, channel, height, width) with domain 0\n",
    "    being the clean image, as implied by the indexing below; confirm against the dataset writer.\n",
    "\n",
    "    Returns:\n",
    "        (tensor_data, tensor_label) after the optional perturbation-only, domain, class and\n",
    "        standardisation steps.\n",
    "    \"\"\"\n",
    "    # Note: tensor_data in this function does not include original dataset label, which is in tensor_label\n",
    "    tensor_data, tensor_label = torch.load(data_path)\n",
    "\n",
    "    # Only consider the added perturbations as data, by subtracting the original sample from each adversarial example.\n",
    "    # Assumes 0th dim is original unperturbed sample #, 1st dim corresponds to perturbation used.\n",
    "    # Iterate in reverse so the clean domain 0 is subtracted from every other domain before it is\n",
    "    # itself zeroed; iterating forward would zero it first and make all later subtractions no-ops.\n",
    "    if parsed_args.only_perturbation is True:\n",
    "        num_domains = tensor_data.shape[1]\n",
    "        for i_domains in reversed(range(num_domains)):\n",
    "            tensor_data[:, i_domains] -= tensor_data[:, 0]\n",
    "\n",
    "    # Only consider some perturbation domains if applicable.\n",
    "    # The indices arrive as floats (argparse lambda / TrainingFlags); tensor indices must be integers.\n",
    "    if parsed_args.only_domains is not None:\n",
    "        domain_indices = [int(d) for d in parsed_args.only_domains]\n",
    "        tensor_data = tensor_data[:, domain_indices]\n",
    "\n",
    "    # Only consider some CIFAR10 classes if applicable\n",
    "    if parsed_args.only_classes is not None:\n",
    "        indices = [iter_x for iter_x, x in enumerate(tensor_label) if x in parsed_args.only_classes]\n",
    "        print(tensor_label.shape)\n",
    "        tensor_label = tensor_label[indices]\n",
    "        tensor_data = tensor_data[indices]\n",
    "\n",
    "    # Standardise only when the flag is truthy: both None (flag not given) and False skip this step.\n",
    "    if parsed_args.standardise:\n",
    "        # Get indices of tensor_data over which to standardise. If instance standardisation, we standardise per sample, not over dataset.\n",
    "        # Index 0 corresponds to which original CIFAR10 data sample, index 2 to RGB channels.\n",
    "        if parsed_args.instance_standardisation:\n",
    "            standardisation_dimensions = [x for x in range(tensor_data.dim()) if x not in (0, 2)]\n",
    "        else:\n",
    "            standardisation_dimensions = [x for x in range(tensor_data.dim()) if x != 2]\n",
    "        means = tensor_data.mean(dim=standardisation_dimensions, keepdim=True)\n",
    "        stds = tensor_data.std(dim=standardisation_dimensions, keepdim=True)\n",
    "        tensor_data = (tensor_data - means) / stds\n",
    "\n",
    "    print(tensor_data.shape, tensor_label.shape)\n",
    "    return tensor_data, tensor_label\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_path = \"data/adversarial-cifar10/model_PGD_Linf_103.pt/\"\n",
    "\n",
    "\n",
    "tensor_test_data = preprocess_data(data_path+\"adversarial_data_test.pt\", parsed_args)\n",
    "test_data = CustomTensorDataset(tensor_test_data)#, transform=resize_transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "with open(data_path + \"metadata.json\", 'r') as json_file:\n",
    "    metadata = json.load(json_file)\n",
    "\n",
    "# Count how many samples of the first selected domain lie inside the L2 ball of radius l2_radius,\n",
    "# looking at the first max_batches_test test batches only.\n",
    "max_batches_test = 11\n",
    "l2_radius = 0.5\n",
    "how_many_smaller_total = 0\n",
    "num_samples_checked = 0\n",
    "with torch.no_grad():\n",
    "    for which_batch_test, (data, label) in enumerate(test_loader):\n",
    "        if which_batch_test >= max_batches_test:\n",
    "            break\n",
    "        data = data.to(device)\n",
    "        # Per-sample L2 norm; flattening by the actual batch size keeps a smaller final batch working.\n",
    "        epsilons = data[:, 0].flatten(start_dim=1).norm(p=2, dim=1)\n",
    "        print(epsilons.shape)\n",
    "        # Reject adv examples not in the L2 ball of radius l2_radius\n",
    "        mask = epsilons <= l2_radius\n",
    "        how_many_smaller_total += mask.sum().item()\n",
    "        num_samples_checked += data.size(0)\n",
    "\n",
    "print(how_many_smaller_total)\n",
    "print(num_samples_checked, \"of\", len(test_loader.sampler))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128, 10, 3, 32, 32])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "## Current and subsequent cells: visualisation stuff for the custom adversarially perturbed CIFAR10 dataset.\n",
    "\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import advertorch.attacks as attacks\n",
    "from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Fixed seed so that validation splits are reproducible (e.g. also after reloading).\n",
    "seed = 0\n",
    "\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed numpy, Python's `random` and torch with the same value.\n",
    "\n",
    "    NOTE(review): when used as a DataLoader `worker_init_fn`, PyTorch passes the worker id as the\n",
    "    first argument, which then replaces `seed`; with no worker processes it is not called at all.\n",
    "    \"\"\"\n",
    "    for seeder in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "num_workers = 0\n",
    "# Make sure test_data is a multiple of batch_size_test\n",
    "batch_size_train_and_valid, batch_size_test = 128, 100\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_path = \"data/adversarial-cifar10/model_ERM_clean_56.pt/\"\n",
    "tensor_training_data = torch.load(data_path+\"adversarial_data_train.pt\")\n",
    "train_data = torch.utils.data.TensorDataset(tensor_training_data[0], tensor_training_data[1])\n",
    "num_train_samples = len(train_data)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid)\n",
    "\n",
    "\n",
    "# del train_data, train_loader\n",
    "# tensor_valid_data = torch.load(data_path+\"adversarial_data_valid.pt\")\n",
    "# valid_data = torch.utils.data.TensorDataset(tensor_valid_data[0], tensor_valid_data[1])\n",
    "# num_valid_samples = len(valid_data)\n",
    "# valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid)\n",
    "\n",
    "# Peek at the first batch only (shape sanity check). Taking it with next(iter(...)) avoids a\n",
    "# loop-and-break and does not overwrite IPython's `_` (last output) variable.\n",
    "data, label = next(iter(train_loader))\n",
    "data, label = data.to(device), label.to(device)\n",
    "print(data.shape)\n",
    "print(label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([160, 3, 32, 32]) torch.Size([160])\n",
      "torch.cuda.LongTensor\n",
      "0\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'numpy.ndarray' object has no attribute 'type'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_417390/1213707255.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_my_tensor_on_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfusion_matrix_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiagonal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'type'"
     ]
    }
   ],
   "source": [
    "# tempotensor = data.swapaxes(0,1)\n",
    "# tempotensor = tempotensor.reshape([1280, 3, 32, 32])\n",
    "# tempotensor2 = data[:,2,:,:,:]\n",
    "# print(tempotensor.shape)\n",
    "# print((tempotensor2[0] - data[0,2]).sum())\n",
    "# print(label[0])\n",
    "# print((tempotensor[:128] - data[:,0]).sum())\n",
    "\n",
    "# test_shape = [5] + list(data.shape)[2:]\n",
    "# print(test_shape)\n",
    "\n",
    "\n",
    "# testochampo = torch.Tensor(range(0, 10))\n",
    "# # testochampo[:5] = 1\n",
    "# testochampo = (torch.transpose(testochampo.repeat(1,5), 0, 1)).flatten()\n",
    "# print(testochampo)\n",
    "\n",
    "# tempochampo = torch.ones(60,60)\n",
    "# print(torch.sum(tempochampo.diagonal()))\n",
    "# print(len(train_data))\n",
    "# print(len(train_loader.sampler))\n",
    "# print(len(valid_loader))\n",
    "print(data.shape, label.shape)\n",
    "print(label.type())\n",
    "\n",
    "\n",
    "def is_my_tensor_on_device():\n",
    "    \"\"\"Return a small zero tensor moved to `device`, to check which device is in use.\"\"\"\n",
    "    return torch.zeros(3).to(device)\n",
    "\n",
    "\n",
    "# get_device() gives the CUDA ordinal, or -1 for a CPU tensor.\n",
    "print(is_my_tensor_on_device().get_device())\n",
    "# confusion_matrix_train can be a numpy array (which has no `.type()`); `.dtype` exists on both\n",
    "# numpy arrays and torch tensors.\n",
    "print(confusion_matrix_train.diagonal().dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([40000])\n",
      "torch.Size([4014, 5, 3, 32, 32]) torch.Size([4014])\n",
      "torch.Size([10000])\n",
      "torch.Size([986, 5, 3, 32, 32]) torch.Size([986])\n",
      "tensor([0.])\n",
      "torch.Size([986, 1, 3, 1, 1])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdIUlEQVR4nO2de6xd9XXnv+s879vX1za2ARvzajIkLYa5UNLQKqWThEAqkmkUJZVSWkV11DbVMOpohDJtQ6tqGqomUf6IUjkFlVZ50RAmpBNlQhgUJopKMME8gksC1IBd42uw7/ue9+ofZ9PeMOe77vW5955j+H0/kuVz9zp7/9b+7X3W2Wd/91rL3B1CiHTJ9dsBIUR/URAQInEUBIRIHAUBIRJHQUCIxFEQECJx+hIEzOxaM3vKzJ42s5v74cMyXw6b2eNmdtDMDvR47NvNbMrMnli2bMLM7jWzn2T/b+6jL7eY2dFsbg6a2XU98GOXmd1vZk+a2Y/M7L9ky3s+L4Ev/ZiXATP7gZk9mvnyJ9ny883sweyz9BUzK532xt29p/8A5AE8A+ACACUAjwK4pNd+LPPnMICtfRr7lwBcDuCJZcv+AsDN2eubAdzaR19uAfDfejwnOwFcnr0eBfBjAJf0Y14CX/oxLwZgJHtdBPAggKsA3AngA9nyvwLwO6e77X5cCVwJ4Gl3f9bdawC+DOCGPvjRd9z9AQAnX7X4BgB3ZK/vAPCePvrSc9z9mLv/MHs9B+AQgHPQh3kJfOk53mY++7OY/XMA1wD4ara8q3npRxA4B8ALy/4+gj5NbIYD+LaZPWxm+/roxytsd/dj2esXAWzvpzMAPmpmj2U/F3ry0+QVzGwPgMvQ/tbr67y8yhegD/NiZnkzOwhgCsC9aF9RT7t7I3tLV58l3RgErnb3ywG8C8Dvmdkv9duhV/D2NV4/n+v+HIALAewFcAzAJ3s1sJmNALgLwE3uPrvc1ut56eBLX+bF3ZvuvhfAuWhfUb9xPbbbjyBwFMCuZX+fmy3rC+5+NPt/CsDdaE9uPzluZjsBIPt/ql+OuPvx7MRrAfg8ejQ3ZlZE+0P3BXf/Wra4L/PSyZd+zcsruPs0gPsBvAXAuJkVMlNXn6V+BIGHAFyc3dUsAfgAgHv64AfMbNjMRl95DeAdAJ6I19pw7gFwY/b6RgBf75cjr3zoMt6LHsyNmRmA2wAccvdPLTP1fF6YL32al21mNp69HgTwdrTvUdwP4H3Z27qbl17e4Vx2p/M6tO+0PgPgf/TDh8yPC9BWJx4F8KNe+wLgS2hfTtbR/j33YQBbANwH4CcAvgNgoo++/B2AxwE8hvaHcGcP/Lga7Uv9xwAczP5d1495CXzpx7z8HIBHsjGfAPDHy87hHwB4GsDfAyif7rYt25AQIlF0Y1CIxFEQECJxFASESBwFASESR0FAiMTpWxA4Qx7RBSBfGPKlM683X/p5JXDGTCTkC0O+dOZ15Yt+DgiROGt6WMjMrgXwGbRrBPy1u38ien+xPOzl4QkAQL26gGJ5+N+NoRuB0YyvxU0/RaO6gMIyXzwIjdYKNhSMt9ptNioLKAwM8zcv32a0f1F4j/ZhGafjyyqn+rRYftT/P1+6HXAdno17tS+5RvDmyJXoGAX7Z8v2oV5ZQHG5L7XOB7dSmUatvtBxq4VOC1eDmeUBfBbtZ5iPAHjIzO5x9yfZOuXhCbz5nTd1tOVr/OhYEKgaZT6Tnqem8ANUH+bG4gL3xXN8vcYQH6+wyG2t4AhF+9cc4L7kq11+EoLgYRvw4Gn0IWmW+f5ZizuTqwUDRoEl2L/BU6uMqq+iERyjZpHbCsHxGz5a7bj8oR9+lq6zlp8DKg4ixOuAtQSBM604iBCiCzb8xqCZ7TOzA2Z2oF5d2OjhhBCnyVqCwKqKg7j7fnefdPfJn7oRKIQ4I+j6xiCWFQdB+8P/AQC/Hq3geX7TrVXk67UKwc2/6GZOYCtUuruT
Fd+QClYMhotugOWa3NYI5mxxBx+wfIrvw8BLfL3mAB/P6twWHdtcsF5pnvuyFNxUq49wW3mab7Ow1N1N3+ooP4CtoAB4vvM9PABAaYGfTLl6cIwGO98tjvzvOgi4e8PMPgrg/6AtEd7u7j/qdntCiP6wlisBuPs3AXxznXwRQvQBPTEoROIoCAiROAoCQiSOgoAQibOmG4OnjXNJKJIwIkmkVQjWa0aSDzXBAlkulLuC59Ij+bBZ7s6XQiXwpcHnZe5CnvFSPskTEpqlIFkrF+R+dDmfkfxbDOTD8qngfAmeyY/kym4zj1rBvhcX+TaL83zFaB+sQbYZuK8rASESR0FAiMRREBAicRQEhEgcBQEhEkdBQIjE6alEaC1eRiwqT1ULyn1FZcmiTMHGYCCzRCUNAxkwGi/KPmwGmWYRURmtgRN8vMp5XH6a28NPiU1P8/HqUZZ4IPVVJrjx5KXcz/wmrucNPjJIbaNHuFYbSaBRmbdQ6gvK0RUX+P6VTvEUw8o2ns7ZKpHv9eAY6EpAiMRREBAicRQEhEgcBQEhEkdBQIjEURAQInF6m0UYkGPZTwCsFWSv5YOuP0FhzIhI8qkFRSzndvOYOnAyyHqbibIkqSnMTIzms3CCa5JjkyeorXF0K7XVNvF5WbiY66oX7JmitmJQZXWmyg/u5uuPU9vUF8+jtoGgk1BtJCh2G3yVRhl/9ZGghRR4amkrkDJpxqYkQiEEQ0FAiMRREBAicRQEhEgcBQEhEkdBQIjE6XEWoaM011nDsKAoaJQtFxYaDQqUVsa5PBMVuIyKZtbGg/HO4rbtD/JtDpziA0bFWStb+KEtv8zHq9T5ert//TC1jZeWqG2syCuiPjVzFrU1A+3tP+86SG1vH36S2j64879S29BUkJU5zY9DISwKGkiLgbwdkQ+yVQvznbMro8/XmoKAmR0GMAegCaDh7pNr2Z4Qovesx5XAL7v7S+uwHSFEH9A9ASESZ61BwAF828weNrN9nd5gZvvM7ICZHajXFtY4nBBivVnrz4Gr3f2omZ0F4F4z+yd3f2D5G9x9P4D9ADA6fm53bVyEEBvGmq4E3P1o9v8UgLsBXLkeTgkhekfXVwJmNgwg5+5z2et3APjTcJ0WUFjsLKcU5nmmWXOYN6xzCyTCalCossK3WR3n05IPCo0WZ3lMrV/Kfwrl6rww5sAUl95aJS5zzl/NbcWL5qjNAn30f+65m9qerfMMw9v/5Wpq+9nN/0Jt1216lNp+eZDLjr/13Dupbeh4d3JzvhpIbEuBjFvlmYkejBfJ4vlF3ksyOicYa/k5sB3A3db+EBYAfNHdv7WG7Qkh+kDXQcDdnwVw6Tr6IoToA5IIhUgcBQEhEkdBQIjEURAQInF6mkXYyhuqE52lucoWLtlFvQFLM1wuyS9yPS+/wHu9FWd5kUcEmXvW4uu9eAmXbhZ28Ficr3D5cPoiPmeFC7kMeOMbeNriVG2U2u6euZza/tPoE9T2u+fcT2078rPU9qYSPz33z+yhtoP3XEJtY9NcsitUuC0f2HJVfg5G8rYFxWAbQ/x8aRUCKXq083phti21CCGSQEFAiMRREBAicRQEhEgcBQEhEkdBQIjE6Xmh0cJiZ6mlNsYlkQZvnQc3vgtuQ9SWr0eSD88KywXZjoMneExtLXE/N//aUWo7ucD34bLtfL16i8/nXJP38Svk+Lw8vzRBbT8p76C2dw49S207CyPU9u1FLq99+n+/m9rGT3DprVkO+louBlmEFX5ORDIgogK6ga00w8drBOO1WPHSKNuWWoQQSaAgIETiKAgIkTgKAkIkjoKAEImjICBE4vRWImw6StOds/eKczwe1caDQqNBVl9UADI/xwtVRlidb9M8kIPqfP9+/7z/S23XD81Q27eXhqmt7vzQTjXGqO3Wh3mRTrzEsySHr+ZZme8e/mdqO9VcpLaPfOcmajv3+1zKXNwW9P/jCX+ojfD1CotBAc8gyzUH7qfVgvNzgUvR
uaXO/QbbvnSWlKOMRV0JCJE4CgJCJI6CgBCJoyAgROIoCAiROAoCQiTOihKhmd0O4N0Aptz9zdmyCQBfAbAHwGEA73f3Uytuq+XIVTprNLbIJab8PE8jbJX5LliQKRgRyoBLgZ+LPDvPy1xi+kl1O3cmkAinm1wifGDmZ6jt7DLf5tgo7304G0iE37p3ktouuYH3GzxWH6e2bf8Y9dXr7tjWR7ikjPlgtGJ335e5Be5n1DfQnJ/zuRkuqxbnOn8eoozF1ezZ3wC49lXLbgZwn7tfDOC+7G8hxGuQFYOAuz8A4OSrFt8A4I7s9R0A3rO+bgkhekW39wS2u/ux7PWLaHcoFkK8BlnzjUF3dwQPTprZPjM7YGYHag3+W0YI0R+6DQLHzWwnAGT/T7E3uvt+d59098lSgZfKEkL0h26DwD0Absxe3wjg6+vjjhCi16xGIvwSgLcB2GpmRwB8HMAnANxpZh8G8ByA969msGY5h4XzOheWHDgR9NU7xX9G5Gs8LcyDXnahwtQIikpu4b36Xng7t5XHeG/ApxfPorZjY49S28GF3dT23W/tpbbqLp6h9o43PUltrR0vUNsPvngptd36/XdR2w2XHaQ2D76i8lV+AIde4uvN7g56/AX1QnOR3BxkEVqTr+d5voNRMdFijcuHRj4PUYbrikHA3T9ITL+y0rpCiDMfPTEoROIoCAiROAoCQiSOgoAQiaMgIETi9LTQqOcN1U2dJRrP8Qy8/GYuiZRfDrL6ZnhGnFV5scbGWbwQ53Pv4r3zypfyRMqhApcdD89tobZnt3BfLhygz2ghaDeI8Yd4NuADIxdR21W7DlNb/Rdn+YAv8YfEthZ56l51M8/423yIy5zNQX5aFyf49159mI9XmeDbLJ8KqpdG8mGLG/Mk2xYAPMf3ITdH5PRW0HuTWoQQSaAgIETiKAgIkTgKAkIkjoKAEImjICBE4vRUIgR4ZpgHNSVro9zoxrWwoaBnW/VsLr0dvp5ncLUGuHTT/PE4te268jC1/dymo9T2fH2C2s4ucknyDcF4Ty+dT232OM+EfKDC5cOhxwapbevLXAprvoV/Dy3uiJr8cTmvMcTPl1zQxq+4EPhZCgqURv0wX+YSaGuMz5kX+LzkZ7ksbiwDNuyXKIRIGgUBIRJHQUCIxFEQECJxFASESBwFASESp+cSIZMqIgmmWeS20izXPhbO53LX0fdxrejis3lBzcPf48U9tx3kmVr/lDuP2t57/SPUdsMI9+W7Szz7cPcwlw+P7t1EbYuPb6a2/BEux255ks/n0GGeYfj3z1xGbTi7Qk2N4SBTcJ7LuI3B6FzimZ61sUCmDiRCzPACs7ki34fGOJcPo4xAHyQZooGPuhIQInEUBIRIHAUBIRJHQUCIxFEQECJxFASESJzV9CK8HcC7AUy5+5uzZbcA+G0AJ7K3fczdv7nitpqOgZnOMszSZi7BFJe4DMgKlwLA1BXcl9YS3/WBPJeYrr3+IWr7xtgktW3jq+HPNv0qtT1/5fep7e/+39XUNnCcz0vtP/Dejjsmj1PbQpUXfH1+mMuO597HbWd/khcMPX4Fl8laBS5J5he5hDZwkh/bXI2vV36Zy5XNAX4u+Q4u4+Zmgx6beS7peZkfh8Z4ZxnXnw+Kk1LLv/M3AK7tsPzT7r43+7diABBCnJmsGATc/QEAJ3vgixCiD6zlnsBHzewxM7vdzPhjZkKIM5pug8DnAFwIYC+AYwA+yd5oZvvM7ICZHWhUF7ocTgixUXQVBNz9uLs33b0F4PMArgzeu9/dJ919slAe7tZPIcQG0VUQMLOdy/58L4An1scdIUSvWY1E+CUAbwOw1cyOAPg4gLeZ2V60cwIPA/jIagYzB3L1znJfjidwobjApRtrcPlw7BleMLSwyCWYx2p7qO2PrvsGtV1z/ZPU9ocv/Ca17bmL7/x3/9cvUNsbn+K9CHFqhppau3dS20uXb6e2uUv4XJd38596x6/k/Rsv+quXqG3T
1rOp7dTP8GM79hz/bisu8LluFYPinot833M1vs3mMJfz0OTbtAqXQLsqUGpBMVRqeWWj7h/ssPi2ldYTQrw20BODQiSOgoAQiaMgIETiKAgIkTgKAkIkTk8LjXreaF9BC+SSZpnLG2M/5mkNAyd4YcyTb+ZFSEd/zDPwfmPHb1FbZZHLQTue5zLSwItcXsudmKY2WlQSgI1yWc4OPUNtWx7h2XLbBvh8Vq9+E7VNTfLjN/cfuQzYKvD1inP8fKmNBFJfkCmYD87Byo4hastVg23Wg6Kg5aABZ9CLMJIWc2w8D9bhIwkhUkBBQIjEURAQInEUBIRIHAUBIRJHQUCIxOmtRAigRVSRPE+agget3hpbeI2CqEdcISheGo1X/sYYtW2e53JQY4BvdOkcvg+lIAutML1EbWhwW26cF/60Cpcdm7Pz1Fb8zsPUtvu5C6ht+vKzqK00x2XVwSDr9Nhb+Vxv/wGX5ViGKwCAH9ow+7BVDmyBDFhYCAqpnuSSsjU7T0wkwetKQIjEURAQInEUBIRIHAUBIRJHQUCIxFEQECJxeioR5pqOgenOEoYHhRCjLMLGEN+FxnDQ3zAoXjo0xfvVNYPMr1Yx8DOQCFlmZXu8QEYa5cU2B44FvedyQeyvB9LUZi4tos7nrPX8UWrbFPTVO/6LvKfN9FVVarvl5++htltPvp/aJg7x42BBFl4h6H3YKPG5bmzm526UQVlu8PFyS+T4KYtQCMFQEBAicRQEhEgcBQEhEkdBQIjEURAQInFW04twF4C/BbAd7UTA/e7+GTObAPAVAHvQ7kf4fnc/FW6rBRRIL7hIEgEi6YavVR3j6w1NcSks6m/YHON+RvtgQdZbfSjY92D/coEkubibF1Itv8xlufwo73OXmwlay1e4ZGe1GrX5U89S2/iOn6W2U1fw76/hHB9vxzVHqG3+hXOorbjEZbnCApdHczXu58LZ/Dhw8RdhodFuWM2VQAPAH7j7JQCuAvB7ZnYJgJsB3OfuFwO4L/tbCPEaY8Ug4O7H3P2H2es5AIcAnAPgBgB3ZG+7A8B7NshHIcQGclr3BMxsD4DLADwIYLu7H8tML6L9c0EI8Rpj1UHAzEYA3AXgJnefXW5zdwf59Wpm+8zsgJkdqNWC35RCiL6wqiBgZkW0A8AX3P1r2eLjZrYzs+8EMNVpXXff7+6T7j5ZKvEyWkKI/rBiEDAzA3AbgEPu/qllpnsA3Ji9vhHA19ffPSHERrOaLMK3AvgQgMfN7GC27GMAPgHgTjP7MIDnAPD0rIxWwVDZ0ln8iKQ+JisCQDPI0soHhSNzgQxYH+HTUh/i43nQWi7HVaQww7BQifrOBVJmsE0PCmNWt/OrtfwY70VYOsrVYavw/oa+xG2l7z5ObRe0uHz434u/Rm0fv+Ib1PbnF/JTePtDXCK0JrcVF7hcWdjMhcDmQHCeBZmszrIFgyzdFYOAu38PANvCr6y0vhDizEZPDAqROAoCQiSOgoAQiaMgIETiKAgIkTg9LTTaygOVzZ3jTmku0AiHAu0tSMCLZMBGsM2oYGhkqw9xX4rBw5LFxUjK5OtFMmCUfRgVS40k11qQlWmtcWorLCxSW67IM+l8ka9X+v6PqG3XIJcP79x1BbVtuepFams+uo3aop6Clg+kucVA+g4kwupm3i+yON/5GHngh64EhEgcBQEhEkdBQIjEURAQInEUBIRIHAUBIRKnpxKhtXhWXNRvMMowbAV7UAx6xNWHefxrRhJhMF6eJ4yFGYZRpmBxjstI9ZFgH0rBfPJpQX6Jj1cLeh9WJ7jUV9jEi576IF/Palwftek5ahv65xlqO3LX+dQ2+wa+7+cFcl6+ym2tAT5npRl+wjTq3WUYdoOuBIRIHAUBIRJHQUCIxFEQECJxFASESBwFASESp6cSYa7hGJrqLPvM7uKSSNTHL1/j8prngmKbUfZh0Ost6jdYXAoKfwaSXaHCNbtWsF60f7lgzqIMQwTbjKTM0jTv7RjJ
gI1xXrzU6nxefOsItUWFYicO8Z6JE0/xfS+/OE9tCCTXqChofj4oQtoKzqUynzOaLRgUGtWVgBCJoyAgROIoCAiROAoCQiSOgoAQiaMgIETirCgRmtkuAH8LYDva4tJ+d/+Mmd0C4LcBnMje+jF3/2a4LQdytc56SiSvRbJciyuLYRHSaJuR7Nhtb8CIsJ9iletPpdlABwzCe+kkl6Zq41zOK83yrL78EpcIm0NBMdFIugpkslaQddocDHY+MJVf5vJhK9iHXIXPS2E66MNYDHoKBsVLo+xDKhuzHoVY3XMCDQB/4O4/NLNRAA+b2b2Z7dPu/per2IYQ4gxlNQ1JjwE4lr2eM7NDAM7ZaMeEEL3htO4JmNkeAJcBeDBb9FEze8zMbjezzevtnBBi41l1EDCzEQB3AbjJ3WcBfA7AhQD2on2l8Emy3j4zO2BmB2q1oAOHEKIvrCoImFkR7QDwBXf/GgC4+3F3b7p7C8DnAVzZaV133+/uk+4+WSoNr5ffQoh1YsUgYGYG4DYAh9z9U8uW71z2tvcCeGL93RNCbDSrUQfeCuBDAB43s4PZso8B+KCZ7UVbNjwM4CMrbcgNaA6QXmlRW73AVh8MpD6uWsEC+SmSFvOBDBj2e+tSPiy/zOUgC7Id65u4pBXuXyBJhpJrKZC78vy7JhdlCgbyYZQJWZrhkl1jkPsZ9fjLE2kbiDP+co1g/4IMwyjrL5L7cjUiGwfztRp14HvofPjDZwKEEK8N9MSgEImjICBE4igICJE4CgJCJI6CgBCJ09NCo3AuCVnQrC/qnRf1+HOuFCEf2KJ+g41ykN01zx2N+inmg0KjVPJBnGkW9XZsbA+KewZyV3Qcov54UaHYwgI/EF7k22xFmZdBP8V8IOPWh6OsviBrscTnMzcWZB8GEm90HBoD/ARl51koX1OLECIJFASESBwFASESR0FAiMRREBAicRQEhEic3kqEAZGMVAwy2yxKMQyIVouksEKX2YCtQMrMBf3/GqNBkc5A9on2IZJAW4EsF9EMkhaLC1EPw0ACbXFfGpGcZ9yWa3BfIqk2+rpsDAVSZoH7UpoPpMzAl6gHZYMUWQ37VlKLECIJFASESBwFASESR0FAiMRREBAicRQEhEicM0YijAp4RgU1C0t8m/VAuslFvdkWA0kykDKbA1HR0+72rzEUZVcG2wykMAsy4gpBBl6UKdgsdifVtoJim5GsFREVS422GWV6NoM5awTHPcrmLFSD84W3MER9NDgn2LkUTKWuBIRIHAUBIRJHQUCIxFEQECJxFASESBwFASESZ0WJ0MwGADwAoJy9/6vu/nEzOx/AlwFsAfAwgA+5O2+cBwDGM9+iXn2RhBatVx8OioLOcSmssMht9RE+ZY1ADhpY4rJVJL2FhVSDfnVRhmExKIgayWvRcbBmUGg0SM5bPIunH+Zr3R33qAhpYZ43qGwM816EtdFAbg6K1kbFbkMJtLtk1fC4M1ZzJVAFcI27XwpgL4BrzewqALcC+LS7XwTgFIAPn/boQoi+s2IQ8Dbz2Z/F7J8DuAbAV7PldwB4z0Y4KITYWFZ1T8DM8llb8ikA9wJ4BsC0+79d7BwBcM6GeCiE2FBWFQTcvenuewGcC+BKAG9c7QBmts/MDpjZgXptoTsvhRAbxmmpA+4+DeB+AG8BMG5mr9wlOxfAUbLOfnefdPfJYml4Lb4KITaAFYOAmW0zs/Hs9SCAtwM4hHYweF/2thsBfH2DfBRCbCCrySLcCeAOM8ujHTTudPd/MLMnAXzZzP4MwCMAblvNgEwWydUCHSmQwiIZMCoqmQuksG7lmVAqCrPXomzAYJvBvCxOBBlqtSDrzfl8Fme4vBadSk1S/BKI56VV5KNF2ZzRORHKnMH5UqhwW1S0Nj4nuC0qpBruAzMF5/SKQcDdHwNwWYflz6J9f0AI8RpGTwwKkTgKAkIkjoKAEImjICBE4igICJE45oE8te6DmZ0A8Fz251YAL/Vs8Bj5
0hn50pnXoi/nufu2ToaeBoGfGtjsgLtP9mXwVyFfOiNfOvN680U/B4RIHAUBIRKnn0Fgfx/HfjXypTPypTOvK1/6dk9ACHFmoJ8DQiSOgoAQiaMgIETiKAgIkTgKAkIkzr8CCIDox8DakPUAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAECCAYAAAD+eGJTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAalklEQVR4nO2dfYxc5XXGnzOzM/vpZf3FYgzhmxLnA0M2hLRpSklCaRoJokYIKqX8geKoClIjJZUQlRoatVJSNUT5o0rlFBSnTUMoIQIlqA1FSChqQrIQ82m+Y7AX2+u1vfZ6Pzw7M6d/zHW6oXOeHd/dnTF+n59kefaeue89953ZZ+/cZ855zd0hhEiXQqcTEEJ0FomAEIkjERAicSQCQiSORECIxJEICJE4HREBM7vWzF40s1fM7LZO5LAgl51m9oyZbTez0TYf+24zGzezZxdsW2NmD5vZy9n/qzuYyx1mNpbNzXYz+3gb8jjbzB41s+fN7Dkz+8tse9vnheTSiXnpMbNfmNlTWS5/m20/z8wez36Xvm9m5RMe3N3b+g9AEcCrAM4HUAbwFIBN7c5jQT47Aazr0LE/DOByAM8u2PYPAG7LHt8G4KsdzOUOAF9s85xsAHB59ngVgJcAbOrEvJBcOjEvBmAge1wC8DiAKwHcC+DGbPs/A/iLEx27E1cCVwB4xd1fc/cKgHsAXNeBPDqOuz8G4OBbNl8HYFv2eBuA6zuYS9tx9z3u/mT2eArADgAb0YF5Ibm0HW9wNPuxlP1zAFcDuC/bnmteOiECGwHsWvDzbnRoYjMcwE/M7Akz29LBPI4z7O57ssd7AQx3MhkAt5rZ09nHhbZ8NDmOmZ0L4DI0/up1dF7ekgvQgXkxs6KZbQcwDuBhNK6oJ929mj0l1++SbgwCH3L3ywH8MYDPmdmHO53QcbxxjdfJ73V/E8AFADYD2APga+06sJkNAPgBgM+7+5GFsXbPS5NcOjIv7l5z980AzkLjivqS5Ri3EyIwBuDsBT+flW3rCO4+lv0/DuCHaExuJ9lnZhsAIPt/vFOJuPu+7I1XB/AttGluzKyExi/dd939/mxzR+alWS6dmpfjuPskgEcBfBDAkJl1ZaFcv0udEIFfArgou6tZBnAjgAc7kAfMrN/MVh1/DOAaAM/yvVacBwHcnD2+GcADnUrk+C9dxifRhrkxMwNwF4Ad7n7nglDb5yXKpUPzst7MhrLHvQA+hsY9ikcBfCp7Wr55aecdzgV3Oj+Oxp3WVwH8dSdyyPI4Hw134ikAz7U7FwDfQ+Nych6Nz3O3AFgL4BEALwP4bwBrOpjLvwJ4BsDTaPwSbmhDHh9C41L/aQDbs38f78S8kFw6MS/vBfCr7JjPAvibBe/hXwB4BcB/AOg+0bEtG0gIkSi6MShE4kgEhEgciYAQiSMRECJxJAJCJE7HROAk+YouAOUSoVyac6rl0skrgZNmIqFcIpRLc06pXPRxQIjEWdKXhczsWgDfQKNHwL+4+1fY88ulfu/pGQIAVOanUS71/19wJb6z1OK5zVdnUOrq+83P1u7vTy3Is1KdQXlBLm1nYS61WZSLvSt8vNaeVqnNoFxcMC8d/JJbpT6LcmHhvLQ5lwWH+/+5NGe2NoVKfdaaxbqabWwFMysC+Cc0vsO8G8AvzexBd38+2qenZwjvv/xzzcer1eODkTlm+9ExScjqLBeSDNmNYfPVxZ/UjELT1zQblMTovOQ8d4LV2ZzlnLRqLd9+eaGve85zyInnOPefHbwvjC3l44CagwhxCrAUETjZmoMIIXKQ++NAq2QWxhYA6O4+baUPJ4Q4QZZyJdBScxB33+ruI+4+8ls3AoUQJwVLuRL4TXMQNH75bwTwZ3QPM3ix+Q0rI/c6jNyUoXfyc96v8UKsjfSmYSHnXeJim51adjwWozdv43N38iKx15becGM3PvPulzMXZ8djN0XZjd285Bgztwi4e9XMbgXwX2hYhHe7
+3N5xxNCdIYl3RNw94cAPLRMuQghOoC+MShE4kgEhEgciYAQiSMRECJxVvzLQr+FO6wWWCbESXFm6zD7qYtoHLF1whyxBPuQ4KUiCZLzY3myOcuJ5bXliE3mZWZJxr6xse/PM6uPvpfyzZmR94RXc9aF5LQWw9eInJquBIRIHImAEIkjERAicSQCQiSORECIxJEICJE47bUIc0IrzYhd4sRmYXaXE8eOVibmrFDLazv6ChShtR1W9VZnvharhMxZzcnGZLmwFndd5FeMWKDexr6FuhIQInEkAkIkjkRAiMSRCAiROBIBIRJHIiBE4rTfIlxu2WE2IKvEylmdx3Mhdl6R+I5sTpgPyKzMFVgpiY7J7DV6fvkq/pxZiwXyts5bXenMxp3PNyZ5Txh7vzBrMZpP9tLFISFECkgEhEgciYAQiSMRECJxJAJCJI5EQIjEab9FGNhvtFKQQNcizGuT5cyFwuSWNarMu64ewUCadBJotWNeWzVv48+8Vl/OSk9qN3flaxRLYVWnldiStGgtSXLaSxIBM9sJYApADUDV3UeWMp4Qov0sx5XAH7r7xDKMI4ToALonIETiLFUEHMBPzOwJM9vS7AlmtsXMRs1stDI/vcTDCSGWm6V+HPiQu4+Z2ekAHjazF9z9sYVPcPetALYCwOCqje3rmSSEaIklXQm4+1j2/ziAHwK4YjmSEkK0j9xXAmbWD6Dg7lPZ42sAfLmFHZtuZlVheasBKUT+mBXGrEXWFNSqpMKQ2ocklhfWNDPnfLK1D2mzVDbXpfjtafNkjT9qDedcL5K+P9mOOd+7zI7t6Y73C88vzn8pHweGAfww69rbBeDf3f0/lzCeEKID5BYBd38NwKXLmIsQogPIIhQicSQCQiSORECIxJEICJE47a0iNIMXA6uCNdRkrlVXTh1jDUrZ8Zh7yBpHEouQ2ofMzmOwaVmJCjzmneZtbMqIquUW3zGM2Fwl3u0YiREr00mMYTNzufYLKxpZX9Z8RxJCnCpIBIRIHImAEIkjERAicSQCQiSORECIxGl/o9HAqqgOlMJdJi+MY4V51hwyTqNrNt6vezLesTxF1p0jzTYLrBIysk0Bbq+x/Zidx9ayy7umoLGFEXNCzt2ZNcwqPeeOxbsN9ISxo5cOh7G9H4hz6b74SJwL8aK7Hj4jjA0/fjiMFSfi40XoSkCIxJEICJE4EgEhEkciIETiSASESByJgBCJ01aL0EEaNpKmi7VyPGad2WRE4iqr4v3m1sQ7FudiK6xrNj5ez2Rsy3XNxLHiTNxQs1CJY5WhuBllrSc+h94xsjYEsx0LrLlnzqpFUj5q8/GcVYf6wtj0pqEwtv+yOM9rPvpkGHvwjEfD2OpinAvjnk2rw9iXn/mTMPaOr5zWdLuPx6+5rgSESByJgBCJIxEQInEkAkIkjkRAiMSRCAiROItahGZ2N4BPABh393dn29YA+D6AcwHsBHCDux9a9GiGsNEoqwZc92zc5JFZhLWeWOPmhuJYtY+M2c1iYQiVQbKuXo3YN8QfnV0f53LmR3aFsTcPDIWxc+4kucyQZptsLUJiZXo5npfaaXFV38F39oaxA78bV3oWJsMQjNjUE5X+MPZvR94Zxmqkw2fJYptzT2UojM0ejM8dtZk4FtDKlcC3AVz7lm23AXjE3S8C8Ej2sxDibciiIuDujwE4+JbN1wHYlj3eBuD65U1LCNEu8t4TGHb3PdnjvWisUCyEeBuy5BuD7u4g3+80sy1mNmpmo/MV8pVUIURHyCsC+8xsAwBk/49HT3T3re4+4u4jpXJ8c0UI0RnyisCDAG7OHt8M4IHlSUcI0W5asQi/B+AqAOvMbDeALwH4CoB7zewWAK8DuKHVA0aWHm3ESWLFKolV4o6TxWOx/s0PxDZZZSC2fJz02rTYJUP5aJxnrRwfb/D3wwsw/P3594exL1bjl2vivXGDy76J2JoqHYlPcGY4bhQ7sTk+v/6LJsPYFy75URj76nN/FMZm
i7GPu/GM2OUuF+Lze3gitginKvHxDk7HFYYzLw6FsQt/FK9TWDzc/CO31eL32KIi4O43BaGPLLavEOLkR98YFCJxJAJCJI5EQIjEkQgIkTgSASESp71rERrgXc0toTrRIyMWIav8YvtF1YwAUD4SV3eVD8djVvvzrcfH8iwR+3DP6Olh7KaJLWGsuCuuzuvrieel0h+/RkfOia2wK258Kox9bPWzYezBicvC2D17rghjlWOxJblu3VQYW9MTV+C9PLk+jO0bb97cEwC69sTzMvhyGMIFTx8NY8UD8TmEjXwJuhIQInEkAkIkjkRAiMSRCAiROBIBIRJHIiBE4rR3LcKChVVxzFxzJ00sa2RNuoH49A5eEsfKR2LLbnBn3MSy+1Acq5diva2sis/+2Po41reXzMubsQ1YIr1disdiS7LnECmFtNiW+/mb54SxJ8c3hrHDL6wNY+XD8bnX18bnMFGJ53P/m0NhrLQ/fr8MvRGGMPQaeb/sjW1AmzkWD7rM6EpAiMSRCAiROBIBIRJHIiBE4kgEhEgciYAQidP2KsJ66cSrnIz0IIXFls/MujhWJ2d++OL4gLXueG3Atc/Ftk5xJrbXukmz1G6ywuP84PKfe2UVsTIHYhvQyZj1XwyFsbhlJtBDXvfKEKkQXROvmWiHyOv3VPzeHHqFNPc8So5HLGy2fiPI+8zrsQWa51i6EhAicSQCQiSORECIxJEICJE4EgEhEkciIETitLIW4d0APgFg3N3fnW27A8BnAOzPnna7uz+02Fhu8dp6pFCQ50csx6652EYqVOP9ijNxbOYM0hR0KrZ1uqfIuogkzyPnxi/R4ffHttWZwxNhbGznujB22nPx8WrdpHIv7qeJLuIDVuNiR8ytJ2s0DsbWW3E8Tmb18/Hx1jwXV/UVKvHxWNNaL8RWrc2TMYmlZ4Xl/dvdymjfBnBtk+1fd/fN2b9FBUAIcXKyqAi4+2MADrYhFyFEB1jKdcWtZva0md1tZquXLSMhRFvJKwLfBHABgM0A9gD4WvREM9tiZqNmNlqdIy1thBAdIZcIuPs+d6+5ex3AtwCEy8G4+1Z3H3H3ka6e/rx5CiFWiFwiYGYbFvz4SQDxWlJCiJOaVizC7wG4CsA6M9sN4EsArjKzzQAcwE4An23lYG6xJVQgNgujwIq0yBp/1T62H6kmeykec+3P94Wx8auGw9j0NbE19acXbQ9jG8qTYey12XjtvFd7449lO1bHeZafGAhjhbiQDpV4qT7MD8Y2YGE+fh36dsQVjXPr49eoHru4FC+RVrhkPUwaY1YfKZ11+p4/8QrDRUXA3W9qsvmuEz6SEOKkRN8YFCJxJAJCJI5EQIjEkQgIkTgSASESp81rEQLV3ua2TxezRIhUVUm1FateY3S953AYq+0aDGP1X8eL0p1+JLYBX7r4gjA2dvZQGBs9+I4wViOT9gfrXw5jrxyIKwxrxF6bHY6tqTrpGFqcjvMcfpzZh/GYb/xOHJuZja3FeplU/NVYRSqx5cj700H2Iza1kf08sh3VaFQIESERECJxJAJCJI5EQIjEkQgIkTgSASESp71rERbi6r1aD2liGbs6MOKyzDObJV4aELPTsRc2feV8GOs98L4w1v/Q9jB2wV/9LIy9en3YqgFjH4413NfFZX27Dw6FscruuOeDrSaW3TFS8TcW59k3Ho954N2xZTd7Xnx+/S/Fr9/AbmJlduezCFEgFbCkktXoMoXMMmfVhyfeyFdXAkIkjkRAiMSRCAiROBIBIRJHIiBE4kgEhEictlcR1nqbWx9WIx4GCRHjBvUiiZIxy6/0hrG+vaQirhL7jrPXXBrG+n8aV/X1/Xh7GDt/YlMYG39f3Em12h8vANgXO6AokFjxWDwv/XtjL6xrNrbsxj8Qj9k9cCyMrX4xtvqKx4hFWIr/JhZI5V49599SI+thsspEalfGe8XHyjGaEOIUQiIgROJIBIRIHImAEIkjERAicSQCQiROK2sRng3g
OwCG0XDktrr7N8xsDYDvAzgXjfUIb3D3Q3yweC04q7L13PiYYYj1fySWJKvuYk1PD59D1seL+3ei78xLwljpKFlXrys+B7ZfeSrOhVl9bK5rJVIFSt5l8/3xhK4bjWPFeVLt6KQRJ13zksw1sQ+NvNFIMSCMlPbVi6SikfyuhEOS920rVwJVAF9w900ArgTwOTPbBOA2AI+4+0UAHsl+FkK8zVhUBNx9j7s/mT2eArADwEYA1wHYlj1tG4DrVyhHIcQKckL3BMzsXACXAXgcwLC778lCe9H4uCCEeJvRsgiY2QCAHwD4vLsfWRhzd0fwgcrMtpjZqJmN1qanl5SsEGL5aUkEzKyEhgB8193vzzbvM7MNWXwDgPFm+7r7VncfcfeRYn98M0cI0RkWFQEzMwB3Adjh7ncuCD0I4Obs8c0AHlj+9IQQK00rVYS/B+DTAJ4xs+3ZttsBfAXAvWZ2C4DXAdyw2EBuQL0U2BssE2YDEquPVb0x27HeTdayO4McL+59CY8dHxw9K45ZnVS2kWap7NyZdVqokLUdZ0nTTJJLZRWx11guxApj+9XK7HhsTGYbk7UIi+R4zG6mdmWMdZ24Rehszc7FDujuP0X8a/iRxfYXQpzc6BuDQiSORECIxJEICJE4EgEhEkciIETitH0twnpPYG+QpousEssLrPIr3o9aRazCkFVwERuQwewudvJO1sCrxr1S6ZjMImTrRTJ7lNuAJMZsTraOH4kVaPVoTvuQVatSSzLej9qj5PxCH2+JVYRCiFMYiYAQiSMRECJxJAJCJI5EQIjEkQgIkTjttQjN4aXAFyF2F115Lc+ybItg86QqjNiOdEy67hzZkTWqJGstkh6WvAKPVFDWSTNRCxrILna8vM1g2ZzxMeNY7spLamWSc2CWJLMyqaXcfDOrWNSVgBCJIxEQInEkAkIkjkRAiMSRCAiROBIBIRKnzRYhYN3NPRqvET1iZYSkuovGWGViMfaRnHlvDHYK8zkbTuZcT5GOSecsX0UctbRYLqxSMOdry2IFVsma13akVZKsSW6+uY4g/Wp1JSBE6kgEhEgciYAQiSMRECJxJAJCJI5EQIjEWdQiNLOzAXwHwDAaBstWd/+Gmd0B4DMA9mdPvd3dH+KDOQpBFSFrGMrWImQ4s5HojiswJoM0BbWu2A8iy8uhViGVkNPkZSc2ILNceeNPVtKYr7ln7vJRVg2Y01allZ55KyFZ1SnLM8iFNcFt5XsCVQBfcPcnzWwVgCfM7OEs9nV3/8cWxhBCnKS0siDpHgB7ssdTZrYDwMaVTkwI0R5O6J6AmZ0L4DIAj2ebbjWzp83sbjNbvdzJCSFWnpZFwMwGAPwAwOfd/QiAbwK4AMBmNK4Uvhbst8XMRs1stDY1vfSMhRDLSksiYGYlNATgu+5+PwC4+z53r7l7HcC3AFzRbF933+ruI+4+UlzVv1x5CyGWiUVFwMwMwF0Adrj7nQu2b1jwtE8CeHb50xNCrDStuAO/B+DTAJ4xs+3ZttsB3GRmm9EwJXYC+OxiAxUKjp6e5l0Z5+djD4O5VrSjZhdpxEnKqurEmurui7tKlkqx5zN1ML4K2nTeWBgbLM2FsRcOnB7Gpme7w9g8sQi9h5XLkflkNiCD2Y7EJqO2HBmTrl2Zcz1MmguxAeukOSu3HUlsJSxCd/8pmjv1/DsBQoi3BfrGoBCJIxEQInEkAkIkjkRAiMSRCAiROG1tNFooOAZ6jjWNzZdiPWLNPYvE8jHi61RZY1PCWacdDmOv7F8XxnpWNT9vAPjzM/8njL1eiccc3XV2GKuP9cW5TMbn3j0Sn99Fa/eHsV+9EedSKMa2Y+VwbGXSnq7MrlyBSk9msbHKRGO/YXkbsOYooHQ1GhVCREgEhEgciYAQiSMRECJxJAJCJI5EQIjEaatFWLQ6hnpmm8amKrFVVCde0UC5EsY29MV2V16GSs3zB4Ay6Rz5nsE3w9hkLa4w3DW3JozVa7FvVZyL56xe
ij2mq896KYydXp4KY6OHLgxjdiT+W9M7Hec5PxjnWTs9ft2dNekM1sIEACN/EmvHiEdImrp6MWdDVIIx7zQ6nCxCIUSERECIxJEICJE4EgEhEkciIETiSASESJy2WoTuhrlq6YT329B/JIyt7Z4JY/1dceVeXyG2mAaK8X6rinHjzw+sei2MPT51fhj78e53hTFW7VgoxGVo1XPiPN919p4w9p7+3WFs+/Q7wphVSHPWQ3GsZ4JUepL9JtfmqxQ8/8yJMMYs5dGx+NznxgbCmHeTUsFgXU4A1Haka2WGlmS8k64EhEgciYAQiSMRECJxJAJCJI5EQIjEkQgIkTiLWoRm1gPgMQDd2fPvc/cvmdl5AO4BsBbAEwA+7e6x7wagWi9g4mjzirnhwbhC7fKhXWFsrh5bjgfn4+q8jd2HwtiZpckwViKLy23sisfcXogtpgMHY4tpcDCuWvzohS+GsXN6DoSxDaU4zxdmzwxj1XpcSVfvje0uL5B1H8k7sHsyHrP31/FCfrPnxW/DtT3TYWx9+WgYq1TiRHv3xuc3c25ctbjm9Nj6PjIVN4q1N3rDWNhQlDRfbeVK4BiAq939UgCbAVxrZlcC+CqAr7v7hQAOAbilhbGEECcZi4qANzgukaXsnwO4GsB92fZtAK5fiQSFECtLS/cEzKyYLUs+DuBhAK8CmHT349fGuwFsXJEMhRArSksi4O41d98M4CwAVwC4pNUDmNkWMxs1s9HqkfgrvkKIznBC7oC7TwJ4FMAHAQyZ/WZ9lbMAjAX7bHX3EXcf6RqMb3YIITrDoiJgZuvNbCh73AvgYwB2oCEGn8qedjOAB1YoRyHECtJKFeEGANvMrIiGaNzr7j8ys+cB3GNmfwfgVwDuWmyg7q4qLl433jTGqgFLFtssU94TxtaTxpisipDZgIOFuDpvqBBXH24ok6anOZfOY41NV3fFVlg/OXeWZ4EtgleO7TziLNJzZ4db/VJ8vL69sW381K74k+wvL4jt2OKu+H3GKMzEJ39wfDCMlfrmwxj5dUDv3uYTSl7yxUXA3Z8GcFmT7a+hcX9ACPE2Rt8YFCJxJAJCJI5EQIjEkQgIkTgSASESx9yXf6208GBm+wG8nv24DkDc9bG9KJfmKJfmvB1zOcfd1zcLtFUEfuvAZqPuPtKRg78F5dIc5dKcUy0XfRwQInEkAkIkTidFYGsHj/1WlEtzlEtzTqlcOnZPQAhxcqCPA0IkjkRAiMSRCAiROBIBIRJHIiBE4vwvCtAJfpKnRUsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the adversarial train/validation splits and sanity-check them visually.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "tensor_training_data = preprocess_data(data_path+\"adversarial_data_train.pt\", parsed_args)\n",
    "tensor_valid_data = preprocess_data(data_path+\"adversarial_data_valid.pt\", parsed_args)\n",
    "\n",
    "# For each split, plot the [0, 0] slice of element 0 averaged over its first axis\n",
    "# (NOTE(review): presumably the channel axis; confirm against preprocess_data).\n",
    "for split_name, split_data in ((\"train\", tensor_training_data), (\"valid\", tensor_valid_data)):\n",
    "    plt.matshow(torch.mean(split_data[0][0, 0], dim=0).cpu().numpy())\n",
    "    plt.title(f\"{split_name}: element 0 [0, 0], mean over dim 0\")\n",
    "\n",
    "# Distinct values held in element 1 of the training split\n",
    "print(torch.unique(tensor_training_data[1]))\n",
    "\n",
    "# Mean over every axis except 0 and 2; keepdim=True keeps the rank so it can broadcast back\n",
    "reduce_dims = [d for d in range(len(tensor_valid_data[0].shape)) if d not in (0, 2)]\n",
    "means = tensor_valid_data[0].mean(dim=reduce_dims, keepdim=True)\n",
    "print(means.shape)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "65ece6dca1d30e560a7eedc0faf594e5b278fbb4bf81aedd4a08d0f646ac509d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
