{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W2EucuE8UmUv"
      },
      "outputs": [],
      "source": [
        "#from dataset import Dataset\n",
        "import time\n",
        "from keras.datasets import cifar10\n",
        "import numpy as np\n",
        "from sklearn.datasets import make_classification\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import random\n",
        "import sys\n",
        "\n",
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.optim.lr_scheduler import _LRScheduler\n",
        "import torch.utils.data as data\n",
        "from torch.nn.modules.loss import CrossEntropyLoss\n",
        "\n",
        "import torchvision.datasets as datasets\n",
        "\n",
        "import copy\n",
        "from copy import deepcopy\n",
        "import random\n",
        "import time\n",
        "\n",
        "import json\n",
        "import os\n",
        "from PIL import Image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r4inPj9EZm0O",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "91b8bbfd-70ee-4959-b87a-2364c2a8921c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive at /content/drive (Colab-only API).\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1fNPemfCN-wv"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8XhIZ3QFOAHH"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hn9PUMGl_Wb0"
      },
      "outputs": [],
      "source": [
        "# Experiment configuration - change these, then Restart and Run All.\n",
        "DATASET = 'CIFAR10'   # 'CIFAR10' or 'CIFAR2' (binary animal-vs-machine relabelling)\n",
        "MODEL = 'RESNET18'    # 'RESNET18' or 'VGG11'\n",
        "\n",
        "AUGMENTS = False      # use data augmentation\n",
        "SAVEMODEL = True      # save the clean model\n",
        "LOADMODEL = False     # load the clean model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Qs3vTkc4s1S"
      },
      "outputs": [],
      "source": [
        "# Set seed here for target / poison / camouflage selection.\n",
        "# Every RNG gets its own offset of the base seed.\n",
        "\n",
        "seed = 555555\n",
        "\n",
        "#seed = 10540012\n",
        "\n",
        "torch.manual_seed(seed + 1)           # torch generator (torch also applies it to CUDA devices)\n",
        "torch.cuda.manual_seed_all(seed + 5)  # every CUDA device; overrides the CUDA seed set on the line above\n",
        "np.random.seed(seed + 4)              # NumPy global RNG (e.g. the target/poison class draw)\n",
        "random.seed(seed + 6)                 # Python's random module"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GLLqPlIWNrFR"
      },
      "source": [
        "## Prepare Datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4rw8ZyaWUyGh"
      },
      "outputs": [],
      "source": [
        "# Class Dictionary for CIFAR10 (class name -> label index)\n",
        "classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,\n",
        "             'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}\n",
        "\n",
        "# Label names for the binary CIFAR2 task. The data-prep cell maps the animal\n",
        "# labels (2-7: bird .. horse) to 0 and the vehicle labels (plane, car, ship, truck)\n",
        "# to 1, and names the classes ['animal', 'machine'], so 0 is 'Animal'.\n",
        "binaryClasses = {0: 'Animal', 1: 'Machine'}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7tBrfubjN4t7"
      },
      "outputs": [],
      "source": [
        "# Per-channel normalisation statistics: CIFAR10's for both CIFAR variants,\n",
        "# ImageNet's for any other DATASET.\n",
        "if DATASET in ('CIFAR2', 'CIFAR10'):\n",
        "    data_mean, data_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)\n",
        "else:\n",
        "    data_mean, data_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HmRD5ua40aiu"
      },
      "outputs": [],
      "source": [
        "from torch.nn.modules.transformer import TransformerDecoderLayer\n",
        "# Overwrite getitem method to obtain the index of the images when iterating through the images\n",
        "\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "\n",
        "class CIFAR10(Dataset):\n",
        "    \"\"\"Thin wrapper around torchvision's CIFAR10 whose samples also carry their index.\n",
        "\n",
        "    ``targets``, ``classes`` and ``data`` are aliases of the wrapped dataset's attributes,\n",
        "    so relabelling ``targets`` in place (as the CIFAR2 conversion does) also relabels\n",
        "    the wrapped dataset.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, train, transform):\n",
        "        self.cifar10 = torchvision.datasets.CIFAR10(\n",
        "            root='./data', train=train, download=True, transform=transform)\n",
        "        self.targets = self.cifar10.targets\n",
        "        self.classes = self.cifar10.classes\n",
        "        self.data = self.cifar10.data\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        \"\"\"Return ``(image, label, index)``; the index lets callers track individual samples.\"\"\"\n",
        "        data, target = self.cifar10[index]\n",
        "        return data, target, index\n",
        "\n",
        "    def get_index(self, target_label):\n",
        "        \"\"\"Return the indices of all samples labelled ``target_label``.\n",
        "\n",
        "        Only ``self.targets`` is scanned, so no image is loaded or transformed.\n",
        "        \"\"\"\n",
        "        return [index for index, label in enumerate(self.targets) if label == target_label]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.cifar10)\n",
        "\n",
        "    def remove(self, remove_list):\n",
        "        \"\"\"Drop the samples whose indices are in ``remove_list``.\n",
        "\n",
        "        The images and labels of the wrapped dataset are filtered and the ``data`` /\n",
        "        ``targets`` aliases are re-pointed at the filtered versions. Samples after a\n",
        "        removed one shift down, so previously computed indices become stale.\n",
        "        NOTE(review): assumes the wrapped torchvision dataset sizes and indexes itself\n",
        "        from its ``data`` / ``targets`` attributes - confirm for the installed version.\n",
        "        \"\"\"\n",
        "        keep = np.ones(len(self.cifar10), dtype=bool)\n",
        "        keep[remove_list] = False\n",
        "        self.cifar10.data = self.cifar10.data[keep]\n",
        "        self.cifar10.targets = [t for t, kept in zip(self.cifar10.targets, keep) if kept]\n",
        "        self.data = self.cifar10.data\n",
        "        self.targets = self.cifar10.targets\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vry-6elmU0zj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "05e5cf38-29f1-4d32-dd5c-53338eb8012f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n"
          ]
        }
      ],
      "source": [
        "# Data Prep.\n",
        "\n",
        "# Inverse of Normalize(data_mean, data_std): maps x to x * std + mean, undoing the normalisation.\n",
        "inv_normalize = transforms.Normalize(\n",
        "   mean= [-m/s for m, s in zip(data_mean, data_std)],\n",
        "   std= [1/s for s in data_std]\n",
        ")\n",
        "\n",
        "# AUGMENTS enables training-time augmentation (random 32x32 crop with 4px padding +\n",
        "# random horizontal flip); the test split is never augmented.\n",
        "train_augmentations = [\n",
        "    transforms.RandomCrop(32, padding=4),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "] if AUGMENTS else []\n",
        "\n",
        "transform_train = transforms.Compose(train_augmentations + [\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(data_mean, data_std),\n",
        "])\n",
        "\n",
        "transform_test = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(data_mean, data_std),\n",
        "])\n",
        "\n",
        "# Only the CIFAR variants are loaded in this cell.\n",
        "if DATASET == 'CIFAR2' or DATASET == 'CIFAR10':\n",
        "  trainset = CIFAR10(train=True, transform=transform_train)\n",
        "  testset = CIFAR10(train=False, transform=transform_test)\n",
        "\n",
        "# Binary Classification\n",
        "\n",
        "NUM_CLASS = 10\n",
        "\n",
        "if DATASET == 'CIFAR2':\n",
        "\n",
        "  # Converts to binary CIFAR10: animal labels (2-7) become 0, vehicle labels (0, 1, 8, 9) become 1.\n",
        "  NUM_CLASS = 1  # a single output logit for the binary task\n",
        "  for split in (trainset, testset):\n",
        "    # Slice assignment keeps the list object shared with the wrapped torchvision dataset.\n",
        "    split.targets[:] = [0 if t in (2, 3, 4, 5, 6, 7) else 1 for t in split.targets]\n",
        "    split.classes = ['animal', 'machine']\n",
        "\n",
        "elif DATASET == 'Imagenet':\n",
        "  # Tiny Imagenet\n",
        "  NUM_CLASS = 200\n",
        "elif DATASET == 'FOOD101':\n",
        "  NUM_CLASS = 101\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "24mUpaOtSSP7"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "# Optional subsampling: uncomment to replace trainset / testset with random subsets.\n",
        "# NOTE(review): np.random.choice samples with replacement by default, so indices can repeat\n",
        "# (pass replace=False for a proper subset); a torch Subset also lacks the .targets /\n",
        "# .get_index attributes of the CIFAR10 wrapper - check later cells before enabling.\n",
        "#train_indices = np.random.choice(len(trainset), 5000)\n",
        "#tets_indices = np.random.choice(len(testset), 1000)\n",
        "\n",
        "#trainset = torch.utils.data.Subset(trainset, train_indices)\n",
        "#testset = torch.utils.data.Subset(testset, tets_indices)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hGfXn0cMS5aC"
      },
      "outputs": [],
      "source": [
        "# Batches of 100 images; the training order is reshuffled every epoch, the test order is fixed.\n",
        "trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=100)\n",
        "testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=100)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "otfUIW83CIkY"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7SCfQTmpNxCB"
      },
      "source": [
        "## Initialize Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7bIs-1AYXToO"
      },
      "outputs": [],
      "source": [
        "class VGG(nn.Module):\n",
        "    \"\"\"VGG classifier: conv ``features`` -> 7x7 adaptive average pool -> 3-layer MLP.\n",
        "\n",
        "    Args:\n",
        "        features: module producing 512-channel feature maps (e.g. ``get_vgg_layers(...)``).\n",
        "        output_dim: number of logits returned by ``forward``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, features, output_dim):\n",
        "        super().__init__()\n",
        "\n",
        "        hidden = 4096\n",
        "        self.features = features\n",
        "        self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(512 * 7 * 7, hidden), nn.ReLU(inplace=True), nn.Dropout(0.5),\n",
        "            nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Dropout(0.5),\n",
        "            nn.Linear(hidden, output_dim),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return raw logits of shape ``(batch, output_dim)`` (no final activation).\"\"\"\n",
        "        pooled = self.avgpool(self.features(x))\n",
        "        flat = pooled.view(pooled.shape[0], -1)\n",
        "        return self.classifier(flat)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lN9UKPVUXZ7e"
      },
      "outputs": [],
      "source": [
        "# Layer specs for get_vgg_layers: an int is a 3x3 conv with that many output channels, 'M' a 2x2 max-pool.\n",
        "vgg11_config = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']\n",
        "# Keep this a plain list - a trailing comma after the closing bracket would silently make it a 1-tuple.\n",
        "vgg16_config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']\n",
        "\n",
        "def get_vgg_layers(config, batch_norm, in_channels=3):\n",
        "    \"\"\"Build the convolutional trunk of a VGG network from ``config``.\n",
        "\n",
        "    Args:\n",
        "        config: sequence of ints and 'M'. An int adds a 3x3 Conv2d (padding 1) with that\n",
        "            many output channels, followed by BatchNorm2d (if ``batch_norm``) and ReLU;\n",
        "            'M' adds MaxPool2d(kernel_size=2).\n",
        "        batch_norm: whether to insert BatchNorm2d after every convolution.\n",
        "        in_channels: channels of the network input (default 3).\n",
        "\n",
        "    Returns:\n",
        "        nn.Sequential containing the layers in config order.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if an entry is neither 'M' nor an int.\n",
        "    \"\"\"\n",
        "    layers = []\n",
        "    for c in config:\n",
        "        assert c == 'M' or isinstance(c, int), f'invalid VGG config entry: {c!r}'\n",
        "        if c == 'M':\n",
        "            layers.append(nn.MaxPool2d(kernel_size=2))\n",
        "            continue\n",
        "        layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))\n",
        "        if batch_norm:\n",
        "            layers.append(nn.BatchNorm2d(c))\n",
        "        layers.append(nn.ReLU(inplace=True))\n",
        "        in_channels = c\n",
        "\n",
        "    return nn.Sequential(*layers)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zAsLDrq8zrP"
      },
      "outputs": [],
      "source": [
        "# Models: AlexNet - Not used\n",
        "# NOTE(review): with 32x32 inputs the feature map is already 1x1 before the last\n",
        "# MaxPool2d(kernel_size=3, stride=2), which cannot pool it - check before using MODEL = 'ALEXNET' on CIFAR.\n",
        "\n",
        "class AlexNet(nn.Module):\n",
        "    \"\"\"AlexNet: five-conv ``features`` -> 6x6 adaptive average pool -> 3-layer MLP ``classifier``.\n",
        "\n",
        "    Args:\n",
        "        num_classes: number of logits produced by ``forward``.\n",
        "        dropout: drop probability of the two Dropout layers in the classifier.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:\n",
        "        super().__init__()\n",
        "\n",
        "        def conv_relu(in_ch, out_ch, **conv_kwargs):\n",
        "            # Conv2d followed by an in-place ReLU, returned as a list to splat into Sequential.\n",
        "            return [nn.Conv2d(in_ch, out_ch, **conv_kwargs), nn.ReLU(inplace=True)]\n",
        "\n",
        "        self.features = nn.Sequential(\n",
        "            *conv_relu(3, 64, kernel_size=11, stride=4, padding=2),\n",
        "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
        "            *conv_relu(64, 192, kernel_size=5, padding=2),\n",
        "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
        "            *conv_relu(192, 384, kernel_size=3, padding=1),\n",
        "            *conv_relu(384, 256, kernel_size=3, padding=1),\n",
        "            *conv_relu(256, 256, kernel_size=3, padding=1),\n",
        "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
        "        )\n",
        "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Dropout(p=dropout), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True),\n",
        "            nn.Dropout(p=dropout), nn.Linear(4096, 4096), nn.ReLU(inplace=True),\n",
        "            nn.Linear(4096, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Return logits of shape ``(batch, num_classes)``.\"\"\"\n",
        "        pooled = self.avgpool(self.features(x))\n",
        "        return self.classifier(torch.flatten(pooled, 1))\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qyL7UfDREeSw"
      },
      "outputs": [],
      "source": [
        "class ResNet(torchvision.models.ResNet):\n",
        "    \"\"\"ResNet generalization for CIFAR-like thingies.\n",
        "\n",
        "    This is a minor modification of\n",
        "    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py,\n",
        "    adding additional options.\n",
        "\n",
        "    Args:\n",
        "        block: residual block class passed to ``_make_layer`` (e.g. torchvision's BasicBlock).\n",
        "        layers: number of blocks per stage; stage ``i`` has ``base_width * 2**i`` channels.\n",
        "        num_classes: number of outputs of the final linear layer (raw logits, no sigmoid).\n",
        "        zero_init_residual: if True, zero the weight of the last BatchNorm in every residual\n",
        "            branch (``bn2`` of BasicBlock, ``bn3`` of Bottleneck).\n",
        "        groups: stored as ``self.groups`` (read by torchvision's ``_make_layer``).\n",
        "        base_width: channels of the stem convolution and of the first stage.\n",
        "        replace_stride_with_dilation: None (no dilation) or one flag per stage, 4 entries.\n",
        "        norm_layer: normalisation layer constructor used throughout.\n",
        "        strides: stride of the first block of each stage.\n",
        "        initial_conv: ``[kernel_size, stride, padding]`` of the 3-channel stem convolution.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, block, layers, num_classes=2, zero_init_residual=False,\n",
        "                 groups=1, base_width=64, replace_stride_with_dilation=[False, False, False, False],\n",
        "                 norm_layer=torch.nn.BatchNorm2d, strides=[1, 2, 2, 2], initial_conv=[3, 1, 1]):\n",
        "        \"\"\"Initialize as usual. Layers and strides are scriptable.\"\"\"\n",
        "        super(torchvision.models.ResNet, self).__init__()  # torch.nn.Module\n",
        "        self._norm_layer = norm_layer\n",
        "\n",
        "        self.dilation = 1\n",
        "        if replace_stride_with_dilation is None:\n",
        "            replace_stride_with_dilation = [False, False, False, False]\n",
        "        if len(replace_stride_with_dilation) != 4:\n",
        "            raise ValueError(\"replace_stride_with_dilation should be None \"\n",
        "                             \"or a 4-element tuple, got {}\".format(replace_stride_with_dilation))\n",
        "        self.groups = groups\n",
        "\n",
        "        self.inplanes = base_width\n",
        "        self.base_width = 64  # Do this to circumvent BasicBlock errors. The value is not actually used.\n",
        "        self.conv1 = torch.nn.Conv2d(3, self.inplanes, kernel_size=initial_conv[0],\n",
        "                                     stride=initial_conv[1], padding=initial_conv[2], bias=False)\n",
        "        self.bn1 = norm_layer(self.inplanes)\n",
        "        self.relu = torch.nn.ReLU(inplace=True)\n",
        "\n",
        "        # One stage per entry of `layers`; the width doubles after every stage.\n",
        "        layer_list = []\n",
        "        width = self.inplanes\n",
        "        for idx, layer in enumerate(layers):\n",
        "            layer_list.append(self._make_layer(block, width, layer, stride=strides[idx], dilate=replace_stride_with_dilation[idx]))\n",
        "            width *= 2\n",
        "        self.layers = torch.nn.Sequential(*layer_list)\n",
        "\n",
        "        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))\n",
        "        # `width` was doubled once past the last stage, hence the // 2.\n",
        "        self.fc = torch.nn.Linear(width // 2 * block.expansion, num_classes)\n",
        "\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, torch.nn.Conv2d):\n",
        "                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
        "            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):\n",
        "                torch.nn.init.constant_(m.weight, 1)\n",
        "                torch.nn.init.constant_(m.bias, 0)\n",
        "\n",
        "        # Zero-initialize the last BN in each residual branch,\n",
        "        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n",
        "        # This improves the arch by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n",
        "        if zero_init_residual:\n",
        "            for m in self.modules():\n",
        "                if isinstance(m, torchvision.models.resnet.Bottleneck):\n",
        "                    torch.nn.init.constant_(m.bn3.weight, 0)\n",
        "                elif isinstance(m, torchvision.models.resnet.BasicBlock):\n",
        "                    torch.nn.init.constant_(m.bn2.weight, 0)\n",
        "\n",
        "    def _forward_impl(self, x):\n",
        "        \"\"\"Stem conv -> BN -> ReLU -> residual stages -> global average pool -> linear.\n",
        "\n",
        "        No max-pool follows the stem, and the raw ``fc`` output (logits) is returned.\n",
        "        \"\"\"\n",
        "        # See note [TorchScript super()]\n",
        "        x = self.conv1(x)\n",
        "        x = self.bn1(x)\n",
        "        x = self.relu(x)\n",
        "\n",
        "        x = self.layers(x)\n",
        "\n",
        "        x = self.avgpool(x)\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = self.fc(x)  # raw logits; no sigmoid is applied\n",
        "        return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "azxkoJcpXlTX"
      },
      "outputs": [],
      "source": [
        "# Select model\n",
        "# NOTE: vgg11_layers is built for every MODEL; constructing it draws from the torch RNG,\n",
        "# so removing it would change the initialisation of whichever model is built below.\n",
        "vgg11_layers = get_vgg_layers(vgg11_config, batch_norm=True)\n",
        "initial_conv = [3, 1, 1]  # ResNet stem conv: [kernel_size, stride, padding]\n",
        "\n",
        "if MODEL == 'VGG11':\n",
        "  model = VGG(vgg11_layers, output_dim=NUM_CLASS)\n",
        "elif MODEL == 'RESNET18':\n",
        "  model = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASS, base_width=64, initial_conv=initial_conv)\n",
        "elif MODEL == \"ALEXNET\":\n",
        "  model = AlexNet(num_classes = NUM_CLASS)\n",
        "else:\n",
        "  raise ValueError(f'Unknown MODEL {MODEL!r}; expected one of VGG11, RESNET18, ALEXNET')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xuaUG8-wN2WI"
      },
      "source": [
        "## Losses and training parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lX4gOjRE5Q5u"
      },
      "outputs": [],
      "source": [
        "# Randomly select poison and target class:\n",
        "# Assume Camouflage chosen from the same class as target.\n",
        "\n",
        "# CIFAR2 has two labels but a single output logit (NUM_CLASS == 1), so the label count,\n",
        "# not NUM_CLASS, bounds the draw; otherwise np.random.choice(size=2, replace=False) fails.\n",
        "num_labels = 2 if NUM_CLASS == 1 else NUM_CLASS\n",
        "avail_classes = np.arange(num_labels)\n",
        "[target_class, poison_class] = np.random.choice(avail_classes, replace=False, size=2)\n",
        "camou_class = target_class\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-bWbxMqnP-hy",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5238448d-fd95-4dbb-834a-712aaab777de"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Running on a GPU\n"
          ]
        }
      ],
      "source": [
        "# Setting up training params\n",
        "\n",
        "epochs = 41   # number of epochs\n",
        "eta = 0.01    # SGD learning rate\n",
        "optimizer = torch.optim.SGD(\n",
        "    params=model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)\n",
        "\n",
        "# Binary task (single logit) -> BCE on logits; multi-class -> cross-entropy.\n",
        "loss_fun = nn.BCEWithLogitsLoss() if DATASET == 'CIFAR2' else nn.CrossEntropyLoss()\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device('cuda')\n",
        "    print(\"Running on a GPU\")\n",
        "else:\n",
        "    device = torch.device('cpu')\n",
        "    print(\"Running on a CPU...Uhh, are you sure you want to do this?\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A-Yov0q4zvrf",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4f52068f-9b91-4bbe-f880-a7d52cf0d5c1"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ResNet(\n",
              "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "  (relu): ReLU(inplace=True)\n",
              "  (layers): Sequential(\n",
              "    (0): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (2): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (3): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
              "  (fc): Linear(in_features=512, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 186
        }
      ],
      "source": [
        "# Move the model to the selected device and put it in training mode; the returned model is displayed.\n",
        "model.to(device).train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tudqnx_ETBZi",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 320
        },
        "outputId": "6c2182fa-3921-4ff6-8e4d-f3ae34c1a4e2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Target image is chosen with ID [4098]\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f2a4db56b90>"
            ]
          },
          "metadata": {},
          "execution_count": 187
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVb0lEQVR4nO3df5DV1XnH8fcDCy66IOgS3ACKP0gaJhqxWyJG08QMjk2c+qOJjdOxmNqQNDrVjJlGzSQxnaSxncTEtk4yJNKQxIomarSJrbGOjflBkQUU0DWiiALh1xoIbITAwtM/7mW6mO9zdvf+XDif1wzD3fPs+X4PX/bZe+/3ueccc3dE5Mg3otkDEJHGULKLZELJLpIJJbtIJpTsIplQsotkoqWazmZ2IXA7MBL4prvfmvr+9vZ2nzZtWjWnlDro3rwzjL22cVMYa50yNYzNmHR0YbslxrFs2bJEVAbL3Qsvc8XJbmYjgTuAOcAGYKmZPeTuz0Z9pk2bRldXV6WnlDrpvPWRMLbs018IY6fecFsYW3x9Z2H76MQ4zFK/CqRa1byMnwW84O5r3X0vsAi4uDbDEpFaqybZJwPr+329odwmIsNQ3W/Qmdk8M+sys65t27bV+3QiEqgm2TcC/e/QTCm3HcLd57t7p7t3Tpw4sYrTiUg1qkn2pcB0MzvZzEYDHwQeqs2wRKTWKr4b7+59ZnYt8Ail0tsCd3+mZiOThpk9a3YYW9b3mzB26QUzw1h01/3O9a8lRnJWIrY8EZPBqKrO7u4PAw/XaCwiUkf6BJ1IJpTsIplQsotkQskukgklu0gmqrobL8NPVNj65JNxnyvOHxfG5s+7OYxdOWPkIEf1//76xGOG3KfkpETs5QqPmRc9s4tkQskukgklu0gmlOwimVCyi2TisLgbf/ZXf1TY3joq7vM/17yvTqMZ3opXfoMPzYr7tCaON/Oc88LY5bfeF8ZWLf5u4qiV0B33aumZXSQTSnaRTCjZRTKhZBfJhJJdJBNKdpFMDJvS2zHXx7uLvHb7DUM+nt15SRjz5Q8M+XiHu9TqbnclYm2tcWFuyU3vr3g8DTPlquL2Dd9q5Chq77y/L25f8fWwi57ZRTKhZBfJhJJdJBNKdpFMKNlFMqFkF8lEVaU3M1sH7AL2A33u3pn6/t8AxfPXKiuvJa34QRhKbSSUKlEdqS5NxD7VfZhvu3RYl9jibbnYHGzLtW9/2KUWdfZ3u3tPDY4jInWkl/Eimag22R34sZktM7N5tRiQiNRHtS/jz3X3jWb2BuBRM3vO3Z/o/w3lXwLzACaeeGKVpxORSlX1zO7uG8t/bwUeAH5v8SN3n+/une7eeezEidWcTkSqUHGym9kxZjb24GPgAmB1rQYmIrVVzcv4ScADZnbwOP/u7v+V6vDCmg1c9L5PVHHK2vjDN30gjC15/nthLLFm42EtWqQSoHfbKw0bxxFt5ifD0Jceu7Ww/TsLu8M+T398xpCHUHGyu/ta4G2V9heRxlLpTSQTSnaRTCjZRTKhZBfJhJJdJBPm7o07mVnjTpZ0fBz603j2nT94U01H8Yt9ceycxD52jWQz4jIl3d9v3EAOA0d/LL4eu+74szBWyTNutEjopzs7WdvVZbU6j4gchpTsIplQsotkQskukgklu0gmMr0bX5mv/up3he0f6hgd9lmcON7mtXFs7imDHFQN/DoRO94Kb+zm693/EIY++tW4WnPaW+JD3lBB5WVl0H5FZyfP6G68SN6U7CKZULKLZELJLpIJJbtIJpTsIplQ6a0WpsebRk2+7OowdtM1HwtjH54an+6dn/jnMLbk/mDbq9198QF7i0uKAMycGYb++LJ446j2juJlwzcvXhr2+fk9i+Jx9Lwax/p+GccItkmq0Nu/Ho//vDnx7mediVJq3+7i9i/e8UjY55mbPhMcbDXuv1XpTSRnSnaRTCjZRTKhZB
fJhJJdJBNKdpFMDFh6M7MFwEXAVnd/a7ntOOAeYBqwDrjc3bcPeLIjtfSWdFIYOf6v/i6MtSQ269ly73fj040P1teb0Br3aYtD77ns3fGpJsb/tpa24nF87Jy4PLWq+1dhrHd3XB5cuur5MPbcmhcL259ZnJiP2B0fj+lxDe3t110X99vXG4aWfHxO3K8C7l5x6e1bwIWva7sReMzdpwOPlb8WkWFswGQv77f++inPFwMLy48XApfUeFwiUmOVvmef5O6byo83U9rRVUSGsWq2bAbA3T31XtzM5gHzqj2PiFSn0mf2LWbWAVD+e2v0je4+39073T2+MyMidVdpsj8EzC0/ngs8WJvhiEi9DKb0djfwLqAd2AJ8FvgBcC9wIvAypdJbat3Cg8fKsPSWMP39YWjc6W8KYx/4y6vC2Oxzphe2nz4xHkaiKMeORGxKIvZvy4t/HD531nFhn3p86GNv0P4ficU+/+UL8azCnyyIF5wspUfzRaW3Ad+zu/sVQeg9VY1IRBpKn6ATyYSSXSQTSnaRTCjZRTKhZBfJRNWfoJPKvf3yeErBtz//F2FsWuKY8a5ztXcgEWufWlxiey7R54RELC7YpUXX47zEApAbrvvbMNY6/Yww9shN8QzB4UDP7CKZULKLZELJLpIJJbtIJpTsIplQsotkQqW3Jppzflyqiee8Vea1RKwnEYuXSUz/8Cz9xf8Wtl//uTviTuOPDUN/PveqMDbnsniphHWbits/f+0X43H89Kdx7Pzz4tgwp2d2kUwo2UUyoWQXyYSSXSQTSnaRTOhufBP17vpNGHuFN4axzYlj7gna+xJ99iViKaljrl5RvO0SKxJbVyXc83h8F/+eqyo6ZGUe/s8Gnqy29Mwukgklu0gmlOwimVCyi2RCyS6SCSW7SCYGLL2Z2QLgImCru7+13HYL8GFgW/nbbnb3h+s1yMZqS8RS00KGbuniR8PY4xe8JYzt2BFtapSQ+J9uZWTcrTURS5wudRWlOQbzzP4t4MKC9q+4+5nlP0dIooscuQZMdnd/Ahhw00YRGd6qec9+rZmtNLMFZjahZiMSkbqoNNm/BpwKnAlsAr4cfaOZzTOzLjPrqvBcIlIDFSW7u29x9/3ufgD4BjAr8b3z3b3T3ePlRESk7ipKdjPr6PflpcDq2gxHROplMKW3u4F3Ae1mtgH4LPAuMzsTcGAd8JE6jrHBKiivtU5OxI4PQy174llvqf+ZlsR8s3DW256h9wHo2x73Gz82LrCte+n5IJIqysUz/dJz7LYmYrUtlx7OBkx2d7+ioPnOOoxFROpIn6ATyYSSXSQTSnaRTCjZRTKhZBfJRIMXnBxJXHpJlKGGucnnzQxjGxevDWMzZ4efReIDo+Lz9XQcHcai/9BU4Sql0jmAN/f+LojEG1uNuPySMHagN1EgfHJpHOsJFr4kGh+kfxYP31KentlFMqFkF8mEkl0kE0p2kUwo2UUyoWQXyYS5e8NONmLcZB919kcLY3v74uLQ5KmTCttbW1vDPq1j4qpiW9sxYWxJai+vFcXzf57+o7ic9Lal0ewveHOisHV2+1Fh7O6e+N82dc5Fhe3jZ5wR9mlvHx/GEpPlmD1zRhhbsWp5Yfsj9y4K+4ybHl/HnWsSpbLxx8axnleL27ftSPRJzaJ7MhEbHtzditr1zC6SCSW7SCaU7CKZULKLZELJLpKJht6NHzVhqh9//seLgy3xHeZtvcUTE1qSexol7lifEK919uLSFfExlxavmP3DuAfF98ZL4ukskFqIf2MiNlz86zcXFra3Ed8Ff6D72TD2XPcr8cla4spFz67iu/ivbkhcxV0VTnbZHE96aiTdjRfJnJJdJBNKdpFMKNlFMqFkF8mEkl0kEwOW3sxsKvBtYBKl7Z7mu/vtZnYccA8wjdIWUJe7+/bksY5+g3Pa5YWxcVPjbZJGBWW0ttZ4AkRfoiq3acdvw9iBNYmSTPcdhc3j4h7sTMSKp/eUTEnEliViw8
U7gvbUmnaP1GMgGaqm9NYH3ODuM4CzgWvMbAZwI/CYu08HHit/LSLD1IDJ7u6b3H15+fEuoBuYDFwMHPzkxEIgXhpURJpuSO/ZzWwaMBNYAkxy903l0GbSr0pFpMkGnexm1gbcB1zv7oe8FfXSG//CN/9mNs/Musysi77dVQ1WRCo3qGQ3s1GUEv0ud7+/3LzFzDrK8Q6CTbLdfb67d7p7Jy1jajFmEanAgMluZkZpP/Zud7+tX+ghYG758VzgwdoPT0RqZTDbP70DuBJYZWZPldtuBm4F7jWzq4GXgeKaWn/7+mBb8ZpgOxMjGTE+KNiMjRdIG59Yl6ylL57VtHfP0LehSpXXUrZUGDsc/LzZA5DfM2Cyu/vPgMK6HfCe2g5HROpFn6ATyYSSXSQTSnaRTCjZRTKhZBfJxGBKb43RuycMHQjaX+2NS2i7dsUz2xJrUbI3NS0r/ETwcCqUnRS0v9zQUWSpJd4Oi754Mc1G0TO7SCaU7CKZULKLZELJLpIJJbtIJpTsIpkYPqW3lD3B7LZUCS0xs41oFh1w9Anxwpevtc4pDryUKGv1rIxjDH2G3cAqKLGdfkMcW1W8v50UOP2MONY2O471BmXi9amfq2hPwuK97UDP7CLZULKLZELJLpIJJbtIJpTsIpkYPnfj+/bHsdTMlbhTGNnbG69dl5wI03pMcfvJb4r7TIzv7vNSYgJNci28Gk+qOOGUOLaqtqc63L05EfvlikWJaOLngMnFzS3xzzBE/2drwx56ZhfJhJJdJBNKdpFMKNlFMqFkF8mEkl0kE1bagDXxDWZTgW9TWoDNgfnufruZ3QJ8GNhW/tab3f3h5LFGHeeMDzaRGZsoM0QlL+ISGq2pskVCqltLcQlwRKI0eKAvXlsvnAABsL4njm3+VRxjeSI2HKTKqIn/Txk0dy/cwWkwBew+4AZ3X25mY4FlZvZoOfYVd/9SrQYpIvUzmL3eNgGbyo93mVk34acARGS4GtJ7djObBswElpSbrjWzlWa2wMwm1HhsIlJDg052M2sD7gOud/edwNeAU4EzKT3zF65yYGbzzKzLzLo4EE+sF5H6GlSym9koSol+l7vfD+DuW9x9v7sfAL4BzCrq6+7z3b3T3TsZcVStxi0iQzRgspuZAXcC3e5+W7/2jn7fdimwuvbDE5FaGczd+HcAVwKrzOypctvNwBVmdialctw64CNVnXFfqlPw8j81+j2JkldQQgOgb+gluwOJitGIMfHxWhKxvW3HxgftaI9jLwWzq3aktqhKrZNXayqvNctg7sb/DCiq2yVr6iIyvOgTdCKZULKLZELJLpIJJbtIJpTsIplo7IKTfiAuiaUqXqNSq0AGUgtYkoglKna0BuNIVJMO7I5jfWPiyz96bPwBpL2pGYJRya5nUtxn24lxrGdrHCMxMy+x8KE0h57ZRTKhZBfJhJJdJBNKdpFMKNlFMqFkF8lEY0tv+w+kF1mMhLPUKhx+X4Uzr6JufYlxpBajTMy+25uamZfYA2zEhOKSXepctCTWGRibKHtOOy8ex9Q3Fo9jcWLzuDWp2XevVhjrTcTyomd2kUwo2UUyoWQXyYSSXSQTSnaRTCjZRTIx4F5vNT3Z0ZOc064oDrbHCyyOmFi8iGLb+LgsNGZMtD8ctLXF/Vpa4zJUW1BGa0uUp1oTxxuf2I+upS2O9bWMjGPBbL89iQU4e3ri0tWG9fFCla2JsuJp02cUtqcqgD+6/4E4+MTSOLbh5TjG80F7qlx3eIv2etMzu0gmlOwimVCyi2RCyS6SCSW7SCYGvBtvZq3AE8BRlGaefN/dP2tmJwOLgOOBZcCV7r43dazxp830c7/8k8JY+4RxYb8JwW5HLYlNolN3fRM3wUnc4Ket8B5nevm81FSdClbWA9LL5EWx1NSfHanY9jjWlxhIdL72E+I+6zbEsfXPxmvhvbDulbjfmhcL23e+FPdh/cY4FhwPgB3PxrEGrslXzd343wHnu/vbKG3PfK
GZnQ38I/AVdz8N2A5cXavBikjtDZjsXnJwnuCo8h8Hzge+X25fCFxSlxGKSE0Mdn/2keUdXLcCjwIvAjvc/eCrtQ3A5PoMUURqYVDJ7u773f1MYAowC/iDwZ7AzOaZWZeZde3deeR+aklkuBvS3Xh33wE8DswGxpvZwftPU4DCuxruPt/dO929c/S4YO9wEam7AZPdzCaa2fjy4zHAHKCbUtK/v/xtc4EH6zVIEaneYBZx6wAWmtlISr8c7nX3H5rZs8AiM/s8sAK4c6AD2ciRtI4NJqik6leRxNZKjK3geMCORDlpTzDG1qAkB9CSqGz2pvrFoaRo+Hv2Db0PxDteQXrpvd6g9rZhc9xnd2KMHdPfEMdmxLHduzqLx/FSXMpbtyEuk+3YHq9pt70nnjR0YE2i9LY8KNlteDzuQ3yuyIA/U+6+EphZ0L6W0vt3ETkM6BN0IplQsotkQskukgklu0gmlOwimWjsGnRm24CDC4a1Az0NO3lM4ziUxnGow20cJ7n7xKJAQ5P9kBObdbl7cRFU49A4NI6aj0Mv40UyoWQXyUQzk31+E8/dn8ZxKI3jUEfMOJr2nl1EGksv40Uy0ZRkN7MLzeyXZvaCmd3YjDGUx7HOzFaZ2VNm1tXA8y4ws61mtrpf23Fm9qiZrSn/nVhOs67juMXMNpavyVNm9t4GjGOqmT1uZs+a2TNmdl25vaHXJDGOhl4TM2s1syfN7OnyOD5Xbj/ZzJaU8+YeMxs9pAO7e0P/ACMpLWt1CjAaeBqY0ehxlMeyDmhvwnnfCZwFrO7X9k/AjeXHNwL/2KRx3AJ8osHXowM4q/x4LKUN2mY0+pokxtHQawIY0FZ+PApYApwN3At8sNz+deBvhnLcZjyzzwJecPe1Xlp6ehFwcRPG0TTu/gTw69c1X0xp4U5o0AKewTgazt03ufvy8uNdlBZHmUyDr0liHA3lJTVf5LUZyT4ZWN/v62YuVunAj81smZnNa9IYDprk7pvKjzcDk5o4lmvNbGX5ZX7d3070Z2bTKK2fsIQmXpPXjQMafE3qschr7jfoznX3s4A/Aa4xs3c2e0BQ+s1O6RdRM3wNOJXSHgGbgC836sRm1gbcB1zv7jv7xxp5TQrG0fBr4lUs8hppRrJvBKb2+zpcrLLe3H1j+e+twAM0d+WdLWbWAVD+O143qY7cfUv5B+0A8A0adE3MbBSlBLvL3e8vNzf8mhSNo1nXpHzuIS/yGmlGsi8FppfvLI4GPgg81OhBmNkxZjb24GPgAmB1ulddPURp4U5o4gKeB5Or7FIacE3MzCitYdjt7rf1CzX0mkTjaPQ1qdsir426w/i6u43vpXSn80XgU00awymUKgFPA880chzA3ZReDu6j9N7rakp75j0GrAH+GziuSeP4DrAKWEkp2ToaMI5zKb1EXwk8Vf7z3kZfk8Q4GnpNgDMoLeK6ktIvls/0+5l9EngB+B5w1FCOq0/QiWQi9xt0ItlQsotkQskukgklu0gmlOwimVCyi2RCyS6SCSW7SCb+D/XIegz2vKtMAAAAAElFTkSuQmCC\n"
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ],
      "source": [
        "# Collect candidate indices for each role and choose one random target.\n",
        "# get_index is a dataset helper not shown in this section; presumably it\n",
        "# returns the indices of all samples whose label equals the given class.\n",
        "# Poison and camouflage candidates come from the training split, target\n",
        "# candidates from the test split.\n",
        " \n",
        "poison_index = trainset.get_index(poison_class)\n",
        "target_index = testset.get_index(target_class)\n",
        "camou_index = trainset.get_index(camou_class)\n",
        "\n",
        "# Choose Target: one random test image, kept as a one-element list so it\n",
        "# can be passed directly to data.Subset as an index list.\n",
        "target_index = [np.random.choice(target_index)]\n",
        "\n",
        "\n",
        "targetset = data.Subset(testset, target_index)\n",
        "targetloader = torch.utils.data.DataLoader(targetset)\n",
        "\n",
        "print(\"Target image is chosen with ID {}\".format(target_index))\n",
        "# NOTE(review): shown without inv_normalize, so if testset is normalized the\n",
        "# colors are distorted and matplotlib may clip values.\n",
        "plt.imshow(testset[target_index[0]][0].permute(1, 2, 0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MgPkXoVwyUAj"
      },
      "outputs": [],
      "source": [
        "# Optional augmentation defenses, applied to training inputs when AUGMENTS is set.\n",
        "#\n",
        "# NOTE(review): torchvision random transforms draw their random parameters once\n",
        "# per call, so calling them on a whole (N, C, H, W) batch flips/rotates every\n",
        "# image in the batch identically; apply them per image for per-sample augmentation.\n",
        "\n",
        "# Flip + small-rotation variant (not referenced elsewhere in this section).\n",
        "aug_transform1 = transforms.Compose([\n",
        "transforms.RandomHorizontalFlip(p=0.5),\n",
        "transforms.RandomRotation(degrees=10)\n",
        "])\n",
        "\n",
        "# Flip-only variant; this is the one the training loops call.\n",
        "aug_transform = transforms.Compose([\n",
        "transforms.RandomHorizontalFlip(p=0.5)\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sodf_epUN-Vt"
      },
      "source": [
        "## Fit the clean model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m5QeHu_cC-ue",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4f366535-c93a-4f81-ed4d-f6ae45e4c5be"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training Epoch 0: Loss: 1.4363257912397385, Accuracy: 0.47886\n",
            "Target Original Loss: 2.5079946517944336\n",
            "Validation Epoch 0: Valid loss: 1.181295331120491, Accuracy: 0.5733\n",
            "Training Epoch 1: Loss: 0.8971330616474151, Accuracy: 0.68112\n",
            "Training Epoch 2: Loss: 0.6310325750112533, Accuracy: 0.77724\n",
            "Training Epoch 3: Loss: 0.44488299039006235, Accuracy: 0.8439\n",
            "Training Epoch 4: Loss: 0.29879983243346214, Accuracy: 0.89514\n"
          ]
        }
      ],
      "source": [
        "# Train the clean model from scratch unless LOADMODEL is set\n",
        "# (in that case the weights are loaded from PATH in a later cell).\n",
        "if not LOADMODEL:\n",
        "  for epoch in range(epochs):\n",
        "\n",
        "    # Step-wise LR decay: eta/10 from epoch 11, eta/100 from epoch 21.\n",
        "    # Re-creating the optimizer also resets its momentum buffers.\n",
        "    # TODO: use a torch.optim.lr_scheduler instead.\n",
        "    if epoch == 11:\n",
        "      optimizer = torch.optim.SGD(params = model.parameters(), lr = eta/10, weight_decay = 5e-4, momentum=0.9)\n",
        "    if epoch == 21:\n",
        "      optimizer = torch.optim.SGD(params = model.parameters(), lr = eta/100, weight_decay = 5e-4, momentum=0.9)\n",
        "\n",
        "    train_loss = []\n",
        "    correct_preds = 0\n",
        "    total_preds = 0\n",
        "    model.train()\n",
        "    for inputs, labels, index in trainloader:\n",
        "      inputs, labels = inputs.to(device), labels.to(device)\n",
        "      if AUGMENTS:\n",
        "        # Augment each image independently: torchvision random transforms draw\n",
        "        # one random decision per call, so calling them on the whole batch\n",
        "        # would flip every image in the batch together.\n",
        "        inputs = torch.stack([aug_transform(img) for img in inputs])\n",
        "\n",
        "      optimizer.zero_grad()\n",
        "      output = model(inputs)\n",
        "\n",
        "      if DATASET == 'CIFAR2':\n",
        "        # Binary setup: flat logits and float targets for the loss.\n",
        "        labels = labels.to(torch.float32)\n",
        "        output = output.flatten()\n",
        "      loss = loss_fun(output, labels)\n",
        "\n",
        "      loss.backward()\n",
        "      optimizer.step()\n",
        "\n",
        "      if DATASET == 'CIFAR2':\n",
        "        predictions = torch.where(output < 0, 0, 1)  # threshold logits at 0\n",
        "      else:\n",
        "        predictions = torch.argmax(output.data, dim=1)\n",
        "\n",
        "      total_preds += labels.size(0)\n",
        "      correct_preds += (predictions == labels).sum().item()\n",
        "      train_loss.append(loss.item())\n",
        "\n",
        "    print(\"Training Epoch {}: Loss: {}, Accuracy: {}\".format(epoch, np.mean(train_loss), correct_preds / total_preds))\n",
        "\n",
        "    # Validation phase - once every 10 epochs.\n",
        "    if epoch % 10 == 0:\n",
        "      valid_losses = []\n",
        "      correct = 0\n",
        "      total = 0\n",
        "      model.eval()\n",
        "\n",
        "      with torch.no_grad():\n",
        "        # Loss and accuracy on the test set.\n",
        "        for inputs, labels, index in testloader:\n",
        "          inputs, labels = inputs.to(device), labels.to(device)\n",
        "          output = model(inputs)\n",
        "          if DATASET == 'CIFAR2':\n",
        "            labels = labels.to(torch.float32)\n",
        "            output = output.flatten()\n",
        "\n",
        "          valid_losses.append(loss_fun(output, labels).item())\n",
        "\n",
        "          if DATASET == 'CIFAR2':\n",
        "            predictions = torch.where(output < 0, 0, 1)\n",
        "          else:\n",
        "            predictions = torch.argmax(output.data, dim=1)\n",
        "          total += labels.size(0)\n",
        "          correct += (predictions == labels).sum().item()\n",
        "\n",
        "        # Loss of the clean model on the target image under its true label.\n",
        "        for inputs, labels, index in targetloader:\n",
        "          inputs, labels = inputs.to(device), labels.to(device)\n",
        "          output = model(inputs)\n",
        "          if DATASET == 'CIFAR2':\n",
        "            labels = labels.to(torch.float32)\n",
        "            output = output.flatten()\n",
        "          target_loss = loss_fun(output, labels)\n",
        "          print(\"Target Original Loss: {}\".format(target_loss))\n",
        "\n",
        "      print(\"Validation Epoch {}: Valid loss: {}, Accuracy: {}\".format(epoch, np.mean(valid_losses), correct / total))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r0Z7SV8zCout"
      },
      "outputs": [],
      "source": [
        "# Checkpoint location for the clean model, read/written by the load/save cell below.\n",
        "# The parent directory is created up front so torch.save can write into it.\n",
        "\n",
        "import os\n",
        "PATH = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\" + \"/resnet_cifar.ptr\"\n",
        "os.makedirs(os.path.dirname(PATH), exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iTsAKd9tTAZN"
      },
      "outputs": [],
      "source": [
        "# Load the clean model's weights from PATH and/or save the current weights to PATH.\n",
        "if LOADMODEL:\n",
        "  # map_location lets a checkpoint saved on a GPU runtime be loaded on a\n",
        "  # CPU-only runtime (and vice versa).\n",
        "  model.load_state_dict(torch.load(PATH, map_location=device))\n",
        "  model.to(device)\n",
        "\n",
        "if SAVEMODEL:\n",
        "  torch.save(model.state_dict(), PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wQCH6jOwTFWX"
      },
      "outputs": [],
      "source": [
        "# Creating a subset of all machine images:\n",
        "\n",
        "#machine_index = []\n",
        "\n",
        "#for index in range(len(testset)):\n",
        "#  if testset.targets[index] == -1:\n",
        "#    machine_index.append(index)\n",
        "\n",
        "#print(machine_index)\n",
        "\n",
        "# torch.util.data.Subset()\n",
        "#machine_loader = \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UCvuj8ebXsGt"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R9S-Xo-zy5F9"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ad5k6yysaUcX"
      },
      "outputs": [],
      "source": [
        "# Inverse normalization for display. Normalize computes (x - mean') / std';\n",
        "# with mean' = -mean/std and std' = 1/std this gives x * std + mean, which\n",
        "# presumably undoes the dataset normalization defined by data_mean/data_std.\n",
        "\n",
        "inv_normalize = transforms.Normalize(\n",
        "    std=[1 / sigma for sigma in data_std],\n",
        "    mean=[-mu / sigma for mu, sigma in zip(data_mean, data_std)],\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xLCHNbZcZfdO"
      },
      "outputs": [],
      "source": [
        "# Show the target image (de-normalized for display) and wrap it in a\n",
        "# one-item Subset; the target-gradient cell reads the image from target_image.\n",
        "print(testset[target_index[0]][0].shape)\n",
        "# permute(1, 2, 0): presumably CHW tensor -> HWC layout expected by imshow.\n",
        "plt.imshow(inv_normalize(testset[target_index[0]][0]).permute(1, 2, 0))\n",
        "target_image = data.Subset(testset, indices=target_index)\n",
        "# NOTE(review): hard-coded class 5, independent of target_class/poison_class\n",
        "# and not referenced in this section - confirm it is still needed.\n",
        "target_label = torch.Tensor([5]).to(device).long()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iLU8bxC7bdg4"
      },
      "outputs": [],
      "source": [
        "# Poison-brewing hyperparameters.\n",
        "budget = 500 # number of poisoned (and camouflage) images sampled from the training set\n",
        "R = 1 # restarts: independent random initializations; the lowest-loss one is kept\n",
        "epsilon = 16 # L-inf bound on the 0-255 pixel scale (divided by 255 and std where used)\n",
        "attackiter = 251 # optimization steps per restart\n",
        "# The three values below are not read by the brewing cells in this section.\n",
        "loss_opt = sys.maxsize # optimal loss\n",
        "delta_opt = 0 # optimal delta\n",
        "poison_opt = [] # optimal poison images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jd6Z4UIRalno"
      },
      "outputs": [],
      "source": [
        "# Choose Poison Images: sample `budget` distinct images of poison_class from\n",
        "# the training set, and map each dataset index to its row in poison_delta.\n",
        "# Sampling from the full candidate pool (not from poison_index itself) keeps\n",
        "# this cell safe to re-run.\n",
        "\n",
        "poison_index = np.random.choice(trainset.get_index(poison_class), budget, replace=False)\n",
        "# int() keys so lookups with plain Python ints (from ids.tolist()) are unambiguous.\n",
        "poison_dict = {int(val): row for row, val in enumerate(poison_index)}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kQSyPAS9ds49"
      },
      "outputs": [],
      "source": [
        "# Loader over only the selected poison images, in poison_index order\n",
        "# (no shuffling); the last partial batch is kept.\n",
        "poisonset = data.Subset(trainset, indices=poison_index)\n",
        "poisonloader = data.DataLoader(\n",
        "    poisonset,\n",
        "    drop_last=False,\n",
        "    batch_size=128,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8yI1T2EsOoyd"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P9ZmUHSbwSYq"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "groYJxo-dxsp"
      },
      "outputs": [],
      "source": [
        "# Initialize poison_delta\n",
        "\n",
        "std_tensor = torch.tensor(data_std)[None, :, None, None]\n",
        "mean_tensor = torch.tensor(data_mean)[None, :, None, None]\n",
        "\n",
        "print(std_tensor)\n",
        "print(mean_tensor)\n",
        "\n",
        "model.eval()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uG9txCdASWfh"
      },
      "outputs": [],
      "source": [
        "# Function to calculate gradient:\n",
        "\n",
        "def gradient(model, images, labels, criterion=None):\n",
        "    \"\"\"Compute the gradient of criterion(model(images), labels) w.r.t. the model parameters.\n",
        "\n",
        "    Args:\n",
        "        model: network whose parameters are differentiated.\n",
        "        images: input batch.\n",
        "        labels: class labels for `images`.\n",
        "        criterion: loss called as criterion(outputs, labels); when None the\n",
        "            notebook-level loss_fun is used.\n",
        "\n",
        "    Returns:\n",
        "        (gradients, grad_norm): tuple of gradients in model.parameters() order,\n",
        "        and their global L2 norm as a detached scalar tensor.\n",
        "    \"\"\"\n",
        "    if criterion is None:\n",
        "        criterion = loss_fun\n",
        "    outputs = model(images)\n",
        "    if DATASET == 'CIFAR2':\n",
        "        # Binary setup: flat logits and float targets.\n",
        "        loss = criterion(outputs.flatten(), labels.float())\n",
        "    else:\n",
        "        loss = criterion(outputs, labels)\n",
        "    gradients = torch.autograd.grad(loss, model.parameters())\n",
        "    # Global L2 norm over all parameter gradients.\n",
        "    grad_norm = sum(grad.detach().pow(2).sum() for grad in gradients).sqrt()\n",
        "    return gradients, grad_norm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vetgOKJD_xst"
      },
      "outputs": [],
      "source": [
        "# Adversarial target gradient: gradient of the loss for labelling the target\n",
        "# image(s) as poison_class. compute_loss aligns poison gradients with it.\n",
        "\n",
        "targets = torch.stack([sample[0] for sample in target_image], dim=0).to(device)\n",
        "true_classes = torch.tensor([sample[1] for sample in target_image]).to(device=device, dtype=torch.long)\n",
        "intended_classes = torch.tensor([poison_class]).to(device=device, dtype=torch.long)\n",
        "print(true_classes)\n",
        "\n",
        "target_grad, target_grad_norm = gradient(model, targets, intended_classes)\n",
        "print(target_grad_norm)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2CPlkTgeYLZr"
      },
      "outputs": [],
      "source": [
        "# Similarity loss for poison brewing: 1 - cos(poison gradient, target gradient),\n",
        "# with both gradients taken w.r.t. the model parameters. Minimizing it aligns\n",
        "# the training gradient of the perturbed poison batch with target_grad.\n",
        "\n",
        "norm_type = 2      # not used by compute_loss; kept for reference\n",
        "support_data = {}  # passed to compute_loss for compatibility; unused\n",
        "\n",
        "def compute_loss(inputs, labels, support_data):\n",
        "  \"\"\"Return (1 - cosine similarity, #correct) for one poison batch.\n",
        "\n",
        "  Calls .backward() on the similarity loss, so gradients accumulate into every\n",
        "  leaf tensor in the graph: the perturbation slice added to `inputs`, and also\n",
        "  the .grad of the model parameters.\n",
        "\n",
        "  Args:\n",
        "    inputs: batch of (already perturbed) images on `device`.\n",
        "    labels: class labels of `inputs` on `device`.\n",
        "    support_data: unused; kept so existing call sites keep working.\n",
        "\n",
        "  Returns:\n",
        "    (loss, poison_correct): the detached CPU loss and the number of images in\n",
        "    the batch that the model currently classifies as `labels`.\n",
        "  \"\"\"\n",
        "  outputs = model(inputs)\n",
        "\n",
        "  if DATASET == 'CIFAR2':\n",
        "    # Binary setup: flat logits and float targets. The cast happens before the\n",
        "    # loss is computed so loss_fun receives float labels.\n",
        "    labels = labels.to(torch.float32)\n",
        "    outputs = outputs.flatten()\n",
        "    poison_prediction = torch.where(outputs < 0, 0, 1)\n",
        "  else:\n",
        "    poison_prediction = torch.argmax(outputs.data, dim=1)\n",
        "\n",
        "  poison_correct = (poison_prediction == labels).sum().item()\n",
        "\n",
        "  poison_loss = loss_fun(outputs, labels)\n",
        "  # create_graph=True so the similarity below can be differentiated again\n",
        "  # (w.r.t. the perturbation that produced `inputs`).\n",
        "  poison_grad = torch.autograd.grad(poison_loss, model.parameters(), retain_graph=True, create_graph=True)\n",
        "\n",
        "  dot_product = 0\n",
        "  poison_norm = 0\n",
        "  for p_grad, t_grad in zip(poison_grad, target_grad):\n",
        "    dot_product += (p_grad * t_grad).sum()\n",
        "    poison_norm += p_grad.pow(2).sum()\n",
        "  poison_norm = poison_norm.sqrt()\n",
        "\n",
        "  target_losses = 1 - dot_product / (target_grad_norm * poison_norm)\n",
        "  target_losses.backward()\n",
        "\n",
        "  return target_losses.detach().cpu(), poison_correct"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uFHTe9bdXw7i"
      },
      "source": [
        "# Witches' Brew"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1GKM6Dbj_nOZ"
      },
      "outputs": [],
      "source": [
        "# Brew Poisons (Witches' Brew): for each restart, optimize one additive\n",
        "# perturbation per poison image so the poison-batch gradient aligns with the\n",
        "# adversarial target gradient (via compute_loss), then project each perturbation\n",
        "# into the L-inf ball of radius epsilon/255 (normalized units) and into the\n",
        "# valid [0, 1] pixel box around its clean image.\n",
        "\n",
        "poison_deltas = []\n",
        "minimum_loss = float('inf')  # lowest mean loss seen across restarts\n",
        "minimum_loss_trial = 0\n",
        "num_batches = len(poisonloader)\n",
        "eps_bound = epsilon / std_tensor / 255  # per-channel L-inf bound, normalized units\n",
        "\n",
        "for trial in range(R):\n",
        "  init_lr = 0.1\n",
        "  print(\"Trial #{}:\".format(trial))\n",
        "\n",
        "  # Random start, scaled and then clipped into the epsilon ball.\n",
        "  poison_delta = torch.randn(len(poison_index), *trainset[0][0].shape)\n",
        "  poison_delta *= eps_bound\n",
        "  poison_delta.data = torch.max(torch.min(poison_delta, eps_bound), -eps_bound)\n",
        "\n",
        "  att_optimizer = torch.optim.Adam([poison_delta], lr=init_lr)\n",
        "\n",
        "  # Per-batch gradients are written into slices of this tensor below.\n",
        "  poison_delta.grad = torch.zeros_like(poison_delta)\n",
        "  poison_delta.requires_grad_()\n",
        "\n",
        "  # Clean (normalized) poison images, used by the pixel-box projection.\n",
        "  poison_bounds = torch.zeros_like(poison_delta)\n",
        "  for step in range(attackiter):\n",
        "    target_loss = 0\n",
        "    poison_correct = 0\n",
        "    for batch, (inputs, labels, ids) in enumerate(poisonloader):\n",
        "      inputs = inputs.to(device)\n",
        "      labels = labels.to(device)\n",
        "      show_example = step % 50 == 0 and batch == 0\n",
        "      if show_example:\n",
        "        plt.imshow(inv_normalize(inputs[0]).permute(1, 2, 0).cpu().detach().numpy())\n",
        "        plt.show()\n",
        "\n",
        "      # Match loader ids to rows of poison_delta via poison_dict.\n",
        "      poison_slices, batch_positions = [], []\n",
        "      for batch_id, image_id in enumerate(ids.tolist()):\n",
        "        lookup = poison_dict.get(image_id)\n",
        "        if lookup is not None:\n",
        "          poison_slices.append(lookup)\n",
        "          batch_positions.append(batch_id)\n",
        "\n",
        "      if len(batch_positions) > 0:\n",
        "        # Differentiable device copy of this batch's perturbations; its gradient\n",
        "        # is copied back into poison_delta.grad after compute_loss.\n",
        "        delta_slice = poison_delta[poison_slices].detach().to(device)\n",
        "        delta_slice.requires_grad_()\n",
        "        poison_images = inputs[batch_positions]  # advanced indexing: copy of the clean images\n",
        "        inputs[batch_positions] += delta_slice\n",
        "      if show_example:\n",
        "        plt.imshow(inv_normalize(inputs[0]).permute(1, 2, 0).cpu().detach().numpy())\n",
        "        plt.show()\n",
        "\n",
        "      loss, p_correct = compute_loss(inputs, labels, support_data)\n",
        "\n",
        "      poison_delta.grad[poison_slices] = delta_slice.grad.detach().to(device=torch.device('cpu'))\n",
        "      poison_bounds[poison_slices] = poison_images.detach().to(device=torch.device('cpu'))\n",
        "\n",
        "      target_loss += loss\n",
        "      poison_correct += p_correct\n",
        "\n",
        "    if step % 50 == 0:\n",
        "      print(\"For iterations {} Target-Poison Loss is {}\".format(step, target_loss / num_batches))\n",
        "      print(\"For iterations {} Poison accuracy is {}\".format(step, poison_correct / budget))\n",
        "\n",
        "    att_optimizer.step()\n",
        "    # Keep .grad as a zero tensor: since PyTorch 2.0 zero_grad() defaults to\n",
        "    # set_to_none=True, which would make the slice assignment into\n",
        "    # poison_delta.grad fail on the next step.\n",
        "    att_optimizer.zero_grad(set_to_none=False)\n",
        "\n",
        "    with torch.no_grad():\n",
        "      # Projection step: L-inf ball, then the valid pixel range of each image.\n",
        "      poison_delta.data = torch.max(torch.min(poison_delta, eps_bound), -eps_bound)\n",
        "      poison_delta.data = torch.max(torch.min(poison_delta, (1 - mean_tensor) / std_tensor - poison_bounds), -mean_tensor / std_tensor - poison_bounds)\n",
        "\n",
        "  # Per-batch mean loss of the last iteration (measured before its final update);\n",
        "  # both sides of the comparison are per-batch means.\n",
        "  trial_loss = target_loss / num_batches\n",
        "  poison_deltas.append(poison_delta)\n",
        "  if trial_loss < minimum_loss:\n",
        "    minimum_loss = trial_loss\n",
        "    minimum_loss_trial = trial\n",
        "\n",
        "poison_delta = poison_deltas[minimum_loss_trial]\n",
        "\n",
        "print(\"Trial #{} selected with target loss {}\".format(minimum_loss_trial, minimum_loss))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E6MQ1IFa7rn5"
      },
      "outputs": [],
      "source": [
        "# Save the selected poison perturbations as a .npy array.\n",
        "import os\n",
        "PATH = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\"\n",
        "os.makedirs(PATH, exist_ok=True)\n",
        "PATH += \"/poison_cifar_l2dcharts.npy\"\n",
        "\n",
        "# .cpu() keeps this working even if poison_delta has been moved to a GPU.\n",
        "with open(PATH, 'wb') as f:\n",
        "  np.save(f, poison_delta.detach().cpu().numpy())\n",
        "\n",
        "# To reload the saved perturbations instead of re-brewing:\n",
        "#with open(PATH, 'rb') as f:\n",
        "#  poison_delta = np.load(f)\n",
        "#poison_delta = torch.from_numpy(poison_delta)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mlrhMbWEA7DW"
      },
      "outputs": [],
      "source": [
        "# Examining poison images: shows each image of the first 11 poisonloader\n",
        "# batches before and after adding its perturbation.\n",
        "# Disabled: the code is wrapped in a string literal, so running this cell only\n",
        "# evaluates the string. Remove the surrounding quotes to enable it.\n",
        "\n",
        "'''\n",
        "for batch, example in enumerate(poisonloader):\n",
        "        inputs, labels, ids = example\n",
        "\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(torch.float32).to(device)\n",
        "\n",
        "        poison_order = []\n",
        "        batch_ids = []\n",
        "\n",
        "        # Use poison_dict to match poison_delta[i] to the correct poison image:\n",
        "        for batch_id, image_id in enumerate(ids.tolist()):\n",
        "            batch_ids.append(batch_id)\n",
        "            poison_order.append(poison_dict[image_id])\n",
        "    \n",
        "        delta_slice = poison_delta[poison_order].detach().to(device)\n",
        "        delta_slice.requires_grad_()\n",
        "        #poison_images = inputs[batch_ids]\n",
        "\n",
        "        for input in inputs:\n",
        "          print(\"Original images:\")\n",
        "          plt.imshow(inv_normalize(input).permute(1, 2, 0).cpu().detach().numpy())\n",
        "          plt.show()\n",
        "        inputs[batch_ids] += delta_slice.to(device)\n",
        "        for input in inputs:\n",
        "          print(\"Poisoned Images:\")\n",
        "          plt.imshow(inv_normalize(input).permute(1, 2, 0).cpu().detach().numpy())\n",
        "          plt.show()\n",
        "\n",
        "        if batch == 10:\n",
        "          break\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ltCzcqtyinjP"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-qLYeR6j0MYW"
      },
      "outputs": [],
      "source": [
        "# Retraining 1: train a fresh victim model on the poisoned training set (each\n",
        "# selected poison image gets its row of poison_delta added) and, every 10\n",
        "# epochs, report whether the target is classified as intended_classes[0].\n",
        "if MODEL == 'VGG11':\n",
        "  vgg11_layers = get_vgg_layers(vgg11_config, batch_norm=True)\n",
        "  # NUM_CLASS keeps the output size consistent with the ResNet branch.\n",
        "  model2 = VGG(vgg11_layers, output_dim=NUM_CLASS)\n",
        "elif MODEL == 'RESNET18':\n",
        "  model2 = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASS, base_width=64, initial_conv=initial_conv)\n",
        "else:\n",
        "  raise ValueError(\"Unsupported MODEL: {}\".format(MODEL))\n",
        "\n",
        "model2 = model2.to(device)\n",
        "optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta, weight_decay = 5e-4, momentum=0.9)\n",
        "optimizer.zero_grad()\n",
        "\n",
        "# The perturbations are constants here. Detaching them keeps the victim's\n",
        "# backward pass from building a graph into (and accumulating .grad on)\n",
        "# poison_delta, which still has requires_grad=True after brewing.\n",
        "fixed_poison_delta = poison_delta.detach()\n",
        "\n",
        "for epoch in range(epochs):\n",
        "\n",
        "  # Step-wise LR decay at the same epochs as the clean-model training cell\n",
        "  # (11 and 21), so the victim follows the same schedule.\n",
        "  if epoch == 11:\n",
        "    optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta/10, weight_decay = 5e-4, momentum=0.9)\n",
        "  if epoch == 21:\n",
        "    optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta/100, weight_decay = 5e-4, momentum=0.9)\n",
        "\n",
        "  train_loss = []\n",
        "  correct_preds = 0\n",
        "  total_preds = 0\n",
        "  model2.train()\n",
        "  for inputs, labels, index in trainloader:\n",
        "    inputs, labels = inputs.to(device), labels.to(device)\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    # Use poison_dict to match rows of poison_delta to images in this batch.\n",
        "    picture_id = []\n",
        "    poison_order = []\n",
        "    for order, image_id in enumerate(index.tolist()):\n",
        "      row = poison_dict.get(image_id)\n",
        "      if row is not None:\n",
        "        picture_id.append(order)\n",
        "        poison_order.append(row)\n",
        "\n",
        "    if len(poison_order) > 0:\n",
        "      inputs[picture_id] += fixed_poison_delta[poison_order].to(device)\n",
        "\n",
        "    if AUGMENTS:\n",
        "      # Augment each image independently (one random draw per image, not per batch).\n",
        "      inputs = torch.stack([aug_transform(img) for img in inputs])\n",
        "\n",
        "    output = model2(inputs)\n",
        "\n",
        "    if DATASET == 'CIFAR2':\n",
        "      labels = labels.to(torch.float32)\n",
        "      output = output.flatten()\n",
        "\n",
        "    loss = loss_fun(output, labels)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    if DATASET == 'CIFAR2':\n",
        "      predictions = torch.where(output < 0, 0, 1)\n",
        "    else:\n",
        "      predictions = torch.argmax(output.data, dim=1)\n",
        "\n",
        "    total_preds += labels.size(0)\n",
        "    correct_preds += (predictions == labels).sum().item()\n",
        "    train_loss.append(loss.item())\n",
        "\n",
        "  print(\"Training Epoch {}: Loss: {}, Accuracy: {}\".format(epoch, np.mean(train_loss), correct_preds / total_preds))\n",
        "\n",
        "  # Validation phase - once every 10 epochs.\n",
        "  if epoch % 10 == 0:\n",
        "    model2.eval()\n",
        "    valid_losses = []\n",
        "    correct = 0\n",
        "    total = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "      # Is the target classified as the intended (poison) class?\n",
        "      for inputs, labels, index in targetloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        output = model2(inputs)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "          output = output.flatten()\n",
        "          predictions = torch.where(output < 0, 0, 1)\n",
        "        else:\n",
        "          predictions = torch.argmax(output.data, dim=1)\n",
        "        print(output)\n",
        "\n",
        "        if predictions[0] != intended_classes[0]:\n",
        "          print(\"Target is not fooled.\")\n",
        "        else:\n",
        "          print(\"Target is fooled.\")\n",
        "\n",
        "      # Clean loss and accuracy on the test set.\n",
        "      for inputs, labels, index in testloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        output = model2(inputs)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "          output = output.flatten()\n",
        "\n",
        "        valid_losses.append(loss_fun(output, labels).item())\n",
        "\n",
        "        if DATASET == 'CIFAR2':\n",
        "          predictions = torch.where(output < 0, 0, 1)\n",
        "        else:\n",
        "          predictions = torch.argmax(output.data, dim=1)\n",
        "        total += labels.size(0)\n",
        "        correct += (predictions == labels).sum().item()\n",
        "\n",
        "    print(\"Validation Epoch {}: Valid loss: {}, Accuracy: {}\".format(epoch, np.mean(valid_losses), correct / total))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2WH5qDc8lEf2"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cRCwI-Wd-jp_"
      },
      "outputs": [],
      "source": [
        "# Camou set:\n",
        "\n",
        "camou_index = np.random.choice(camou_index, budget, replace=False)\n",
        "camou_dict = {}\n",
        "\n",
        "for index, val in enumerate(camou_index):\n",
        "  camou_dict[val] = index\n",
        "\n",
        "camouset = data.Subset(trainset, camou_index)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b9z_CkxvnG8h"
      },
      "outputs": [],
      "source": [
        "print(len(trainset))\n",
        "camouloader = DataLoader(camouset, batch_size=128, drop_last=False)\n",
        "combinedloader = DataLoader(trainset,shuffle=True,batch_size=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XvYiGcRyuJiU"
      },
      "outputs": [],
      "source": [
        "target_grad, target_grad_norm = gradient(model, targets, true_classes)\n",
        "\n",
        "print(target_grad_norm)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "boTSfnLghhH9"
      },
      "outputs": [],
      "source": [
        "# Brew Camou Images:\n",
        "\n",
        "camou_deltas = []\n",
        "minimum_loss = 1\n",
        "minimum_loss_trial = 0\n",
        "\n",
        "for trial in range(R):\n",
        "  init_lr = 0.1\n",
        "  if trial > -(R // -2):\n",
        "    init_lr = 0.01\n",
        "\n",
        "  print(\"Trial #{}:\".format(trial))\n",
        "\n",
        "  camou_delta = torch.randn(len(camou_index), *trainset[0][0].shape)\n",
        "  camou_delta *= epsilon / std_tensor / 255\n",
        "  camou_delta.data = torch.max(torch.min(camou_delta, epsilon / (std_tensor * 255)), -epsilon / (std_tensor * 255))\n",
        "\n",
        "  att_optimizer = torch.optim.Adam([camou_delta], lr=init_lr)\n",
        "\n",
        "  camou_delta.grad = torch.zeros_like(camou_delta)\n",
        "  camou_delta.requires_grad_()\n",
        "\n",
        "  camou_bounds = torch.zeros_like(camou_delta)\n",
        "  for iter in range(attackiter):\n",
        "\n",
        "      target_loss = 0\n",
        "      camou_correct = 0\n",
        "      for batch, example in enumerate(camouloader):\n",
        "\n",
        "        inputs, labels, ids = example\n",
        "\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(device)\n",
        "\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "\n",
        "        if iter % 50 == 0 and batch == 0:\n",
        "          plt.imshow(inv_normalize(inputs[0]).permute(1, 2, 0).cpu().detach().numpy())\n",
        "          plt.show()\n",
        "\n",
        "        ### Add delta to the correct images\n",
        "\n",
        "        camou_slices, batch_positions = [], []\n",
        "        for batch_id, image_id in enumerate(ids.tolist()):\n",
        "            lookup = camou_dict.get(image_id)\n",
        "            if lookup is not None:\n",
        "                camou_slices.append(lookup)\n",
        "                batch_positions.append(batch_id)\n",
        "\n",
        "        if len(batch_positions) > 0:\n",
        "            delta_slice = camou_delta[camou_slices].detach().to(device)\n",
        "            delta_slice.requires_grad_()\n",
        "            camou_images = inputs[batch_positions]\n",
        "            inputs[batch_positions] += delta_slice\n",
        "        if iter % 50 == 0 and batch == 0:\n",
        "          plt.imshow(inv_normalize(inputs[0]).permute(1, 2, 0).cpu().detach().numpy())\n",
        "          plt.show()\n",
        "\n",
        "###################################################################################\n",
        "        loss, p_correct = compute_loss(inputs, labels, support_data)\n",
        "###################################################################################\n",
        "\n",
        "        # Update Step:\n",
        "        camou_delta.grad[camou_slices] = delta_slice.grad.detach().to(device=torch.device('cpu'))\n",
        "        camou_bounds[camou_slices] = camou_images.detach().to(device=torch.device('cpu'))\n",
        "\n",
        "\n",
        "        target_loss += loss\n",
        "        camou_correct += p_correct\n",
        "\n",
        "      if iter % 50 == 0:\n",
        "        print(\"For iterations {} Target-Camou Loss is {}\".format(iter, target_loss/(batch + 1)))\n",
        "        print(\"For iterations {} Camou accuracy is {}\".format(iter, camou_correct / budget))\n",
        "\n",
        "      att_optimizer.step()\n",
        "      att_optimizer.zero_grad()\n",
        "  \n",
        "      with torch.no_grad():\n",
        "        camou_delta.data = torch.max(torch.min(camou_delta, epsilon / std_tensor / 255), -epsilon / std_tensor / 255)\n",
        "        camou_delta.data = torch.max(torch.min(camou_delta, (1 - mean_tensor) / std_tensor - camou_bounds), -mean_tensor / std_tensor - camou_bounds)\n",
        "\n",
        "  camou_deltas.append(camou_delta)\n",
        "  if target_loss < minimum_loss: \n",
        "    minimum_loss = target_loss\n",
        "    minimum_loss_trial = trial\n",
        "\n",
        "camou_delta = camou_deltas[minimum_loss_trial]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2F0B2UJq7zvp"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "PATH = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\"\n",
        "os.makedirs(PATH, exist_ok = True) \n",
        "PATH += \"/camou_cifar_l2dcharts.npy\"\n",
        "\n",
        "with open(PATH, 'wb') as f:\n",
        "  np.save(f, camou_delta.detach().numpy())\n",
        "\n",
        "#with open(PATH, 'rb') as f:\n",
        "#  camou_delta = np.load(f)\n",
        "#camou_delta = torch.from_numpy(camou_delta)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(len(poisonset))"
      ],
      "metadata": {
        "id": "Cx1u0z0D0oAC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "era7kwYNheI9"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "for batch, example in enumerate(camouloader):\n",
        "        inputs, labels, ids = example\n",
        "\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(device)\n",
        "\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "          \n",
        "        camou_order = []\n",
        "        batch_ids = []\n",
        "\n",
        "        # Use poison_dict to match poison_delta[i] to the correct poison image:\n",
        "        for batch_id, image_id in enumerate(ids.tolist()):\n",
        "            batch_ids.append(batch_id)\n",
        "            camou_order.append(camou_dict[image_id])\n",
        "    \n",
        "        delta_slice = camou_delta[camou_order].detach().to(device)\n",
        "        delta_slice.requires_grad_()\n",
        "        #poison_images = inputs[batch_ids]\n",
        "\n",
        "        inputs[batch_ids] += delta_slice.to(device)\n",
        "        for input in inputs:\n",
        "          print(\"Poisoned Images:\")\n",
        "          plt.imshow(inv_normalize(input).permute(1, 2, 0).cpu().detach().numpy())\n",
        "          plt.show()\n",
        "\n",
        "        if batch == 10:\n",
        "          break\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nH1yCNRg5_v6"
      },
      "outputs": [],
      "source": [
        "# Retraining # 2 with camouflage images:\n",
        "vgg11_layers = get_vgg_layers(vgg11_config, batch_norm=True)\n",
        "\n",
        "if MODEL == 'VGG11':\n",
        "  model2 = VGG(vgg11_layers, output_dim=10)\n",
        "elif MODEL == 'RESNET18':\n",
        "  model2 = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASS, base_width=64, initial_conv=initial_conv)\n",
        "\n",
        "#model2 = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASS, base_width=64, initial_conv=initial_conv)\n",
        "model2 = model2.to(device)\n",
        "optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta, weight_decay = 5e-4, momentum=0.9)\n",
        "optimizer.zero_grad()\n",
        "\n",
        "for epoch in range(epochs):\n",
        "  \n",
        "  if epoch == 11:\n",
        "    optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta/10, weight_decay = 5e-4, momentum=0.9)\n",
        "  if epoch == 20:\n",
        "    optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta/100, weight_decay = 5e-4, momentum=0.9)\n",
        "    \n",
        "  train_loss = []\n",
        "\n",
        "  correct_preds = 0\n",
        "  total_preds = 0\n",
        "  for inputs, labels, index in combinedloader:\n",
        "    model2.train()\n",
        "    #\n",
        "    inputs, labels = inputs.to(device), labels.to(device)\n",
        "    optimizer.zero_grad()            # reset the gradients to zero\n",
        "\n",
        "    picture_id = []\n",
        "    poison_order = []\n",
        "\n",
        "    picture_cid = []\n",
        "    camou_order = []\n",
        "\n",
        "    for order, id in enumerate(index.tolist()):\n",
        "      if poison_dict.get(id) is not None:\n",
        "        picture_id.append(order)\n",
        "        poison_order.append(poison_dict[id])\n",
        "        \n",
        "    for order, id in enumerate(index.tolist()):\n",
        "      if camou_dict.get(id) is not None:\n",
        "        picture_cid.append(order)\n",
        "        camou_order.append(camou_dict[id])\n",
        "\n",
        "    if len(camou_order) > 0:\n",
        "      inputs[picture_cid] += camou_delta[camou_order].to(device)\n",
        "\n",
        "    if len(poison_order) > 0:\n",
        "      inputs[picture_id] += poison_delta[poison_order].to(device)\n",
        "    \n",
        "    if AUGMENTS:\n",
        "      inputs = aug_transform(inputs)\n",
        "    \n",
        "    output = model2(inputs)            # Generate model outputs\n",
        "    #labels = labels.to(torch.float32).unsqueeze(1)\n",
        "    if DATASET == 'CIFAR2':\n",
        "      labels = labels.to(torch.float32)\n",
        "      output = output.flatten()\n",
        "\n",
        "    # negative labels: when using hinge embedding loss only\n",
        "    flipped_labels = labels# * -1\n",
        "    loss = loss_fun(output, flipped_labels)   # Calculate loss\n",
        "\n",
        "    loss.backward()            # Compute gradients\n",
        "    optimizer.step()            # update parameters,\n",
        "\n",
        "    #predictions = torch.argmax(output, dim=1)\n",
        "    #trainset.classes = ['machine', 'animal']\n",
        "    \n",
        "    #predictions = torch.argmax(output.data, dim=1)\n",
        "    if DATASET == 'CIFAR2':\n",
        "      predictions = torch.where(output < 0, 0, 1)\n",
        "    else:\n",
        "      predictions = torch.argmax(output.data, dim=1)\n",
        "\n",
        "    total_preds += labels.size(0)\n",
        "    correct_preds += (predictions == labels).sum().item()\n",
        "\n",
        "    train_loss.append(loss.item())\n",
        "\n",
        "  print(\"Training Epoch {}: Loss: {}, Accuracy: {}\".format(epoch, np.mean(train_loss), correct_preds / total_preds))\n",
        "  # validation phase - once every 10 epochs\n",
        "      \n",
        "  if epoch % 10 == 0:\n",
        "    valid_losses = []\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    model2.eval()\n",
        "    for inputs, labels, index in targetloader:\n",
        "      #\n",
        "      inputs, labels = inputs.to(device), labels.to(device)\n",
        "      with torch.no_grad():\n",
        "        output = model2(inputs)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "          output = output.flatten()\n",
        "        #predictions = torch.argmax(output.data, dim=1)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          predictions = torch.where(output < 0, 0, 1)\n",
        "        else:\n",
        "          predictions = torch.argmax(output.data, dim=1)\n",
        "        print(output)\n",
        "      \n",
        "      if predictions[0] != intended_classes[0]:\n",
        "        print(\"Target is not fooled.\")\n",
        "      else:\n",
        "        print(\"Target is fooled.\")\n",
        "\n",
        "    for inputs, labels, index in testloader:\n",
        "      #\n",
        "      inputs, labels = inputs.to(device), labels.to(device)\n",
        "      with torch.no_grad():\n",
        "        output = model2(inputs)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          labels = labels.to(torch.float32)\n",
        "          output = output.flatten()\n",
        "        \n",
        "        valid_loss = loss_fun(output, labels) # Calculate loss\n",
        "        valid_losses.append(valid_loss.item())\n",
        "\n",
        "        #predictions = torch.argmax(output, dim=1)\n",
        "        if DATASET == 'CIFAR2':\n",
        "          predictions = torch.where(output < 0, 0, 1)\n",
        "        else:\n",
        "          predictions = torch.argmax(output.data, dim=1)\n",
        "        total += labels.size(0)\n",
        "        correct += (predictions == labels).sum().item()\n",
        "\n",
        "    print(\"Validation Epoch {}: Valid loss: {}, Accuracy: {}\".format(epoch, np.mean(valid_losses), correct / total))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xFFi37RIavF_"
      },
      "outputs": [],
      "source": [
        "# Detach deltas to use in plots\n",
        "\n",
        "poison_detach = poison_delta.detach()\n",
        "camou_detach = camou_delta.detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mzh7kaoqPHFo"
      },
      "outputs": [],
      "source": [
        "# Plots and figures for papers\n",
        "'''\n",
        "fig = plt.figure(figsize=(5,3))\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)  \n",
        "# setting values to rows and column variables\n",
        "rows = 3\n",
        "columns = 5\n",
        "\n",
        "for i in range(3):\n",
        "  for j in range(1,6):\n",
        "    index = (i*5)+j-1\n",
        "    fig.add_subplot(rows, columns, i*5+j)\n",
        "    plt.imshow(inv_normalize(trainset[poison_index[index]][0]).permute(1, 2, 0)) \n",
        "    plt.axis('off')\n",
        "plt.savefig(f\"/content/drive/MyDrive/dummy_cifar_photogrid3.png\", bbox_inches='tight')\n",
        "\n",
        "\n",
        "fig = plt.figure(figsize=(5,3))\n",
        "  \n",
        "# setting values to rows and column variables\n",
        "rows = 3\n",
        "columns = 5\n",
        "\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)\n",
        "from google.colab import files\n",
        "#plt.tight_layout()\n",
        "\n",
        "\n",
        "for i in range(3):\n",
        "  for j in range(1,6):\n",
        "    index = (i*5)+j-1\n",
        "    poison_order = poison_index[index]\n",
        "    fig.add_subplot(rows, columns, i*5+j)\n",
        "    plt.imshow(inv_normalize(trainset[poison_order][0] + poison_detach[poison_dict[poison_order]]).permute(1, 2, 0))\n",
        "    plt.axis('off')\n",
        "plt.savefig(f\"/content/drive/MyDrive/dummy_cifar_photogrid4.png\", bbox_inches='tight')\n",
        "\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VVNf-DP8bC68"
      },
      "outputs": [],
      "source": [
        "# Plots and figures for papers\n",
        "'''\n",
        "fig = plt.figure(figsize=(5,3))\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)  \n",
        "# setting values to rows and column variables\n",
        "rows = 3\n",
        "columns = 5\n",
        "\n",
        "for i in range(3):\n",
        "  for j in range(1,6):\n",
        "    index = (i*5)+j-1\n",
        "    fig.add_subplot(rows, columns, i*5+j)\n",
        "    plt.imshow(inv_normalize(trainset[camou_index[index]][0]).permute(1, 2, 0)) \n",
        "    plt.axis('off')\n",
        "plt.savefig(f\"/content/drive/MyDrive/dummy_cifar_photogrid3.png\", bbox_inches='tight')\n",
        "\n",
        "\n",
        "fig = plt.figure(figsize=(5,3))\n",
        "  \n",
        "# setting values to rows and column variables\n",
        "rows = 3\n",
        "columns = 5\n",
        "\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)\n",
        "from google.colab import files\n",
        "#plt.tight_layout()\n",
        "\n",
        "\n",
        "for i in range(3):\n",
        "  for j in range(1,6):\n",
        "    index = (i*5)+j-1\n",
        "    camou_order = camou_index[index]\n",
        "    fig.add_subplot(rows, columns, i*5+j)\n",
        "    plt.imshow(inv_normalize(trainset[camou_order][0] + camou_detach[camou_dict[camou_order]]).permute(1, 2, 0))\n",
        "    plt.axis('off')\n",
        "plt.savefig(f\"/content/drive/MyDrive/dummy_cifar_photogrid4.png\", bbox_inches='tight')\n",
        "\n",
        "plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kPSUpa2VL4sC"
      },
      "outputs": [],
      "source": [
        "plt.imshow(inv_normalize(targetset[0][0]).permute(1, 2, 0))\n",
        "plt.axis('off')\n",
        "plt.savefig(f\"/content/drive/MyDrive/target_deer.png\", bbox_inches='tight')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P1mraa5GikFa"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BpVgevpXJLmw"
      },
      "source": [
        "## Compute L2 Distance graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4xD0HcZ_JNWt"
      },
      "outputs": [],
      "source": [
        "import scipy.spatial.distance as dist\n",
        "\n",
        "target_image = targetset[0][0].flatten()\n",
        "\n",
        "poison_images = torch.stack([(data[0] + poison_detach[poison_dict[data[2]]]).flatten() for data in poisonset], dim=0)\n",
        "clean_poison_images = torch.stack([data[0].flatten() for data in poisonset], dim=0)\n",
        "poison_class_set = data.Subset(trainset, trainset.get_index(poison_class))\n",
        "poison_class_images = torch.stack([data[0].flatten()  for data in poison_class_set], dim=0)\n",
        "\n",
        "\n",
        "camou_images = torch.stack([(data[0] + camou_detach[camou_dict[data[2]]]).flatten() for data in camouset], dim=0)\n",
        "clean_camou_images = torch.stack([data[0].flatten() for data in camouset], dim=0)\n",
        "camou_class_set = data.Subset(trainset, trainset.get_index(camou_class))\n",
        "camou_class_images = torch.stack([data[0].flatten()  for data in camou_class_set], dim=0)\n",
        "\n",
        "\n",
        "l2_ds_target_poison = []\n",
        "l2_ds_target_poison_class = []\n",
        "l2_ds_poisoncentroid_poison = []\n",
        "l2_ds_poisoncentroid_poison_poison_class = []\n",
        "\n",
        "poison_centroid = torch.mean(poison_class_images) \n",
        "\n",
        "for image in poison_images:\n",
        "  l2d_t = dist.euclidean(image, target_image)\n",
        "  l2_ds_target_poison_class.append(l2d_t)\n",
        "\n",
        "  l2d_pc = dist.euclidean(image, poison_centroid)\n",
        "  l2_ds_poisoncentroid_poison.append(l2d_pc)\n",
        "\n",
        "for image in poison_class_images:\n",
        "  l2d_t = dist.euclidean(image, target_image)\n",
        "  l2_ds_target_poison.append(l2d_t)\n",
        "\n",
        "  l2d_pc = dist.euclidean(image, poison_centroid)\n",
        "  l2_ds_poisoncentroid_poison_poison_class.append(l2d_pc)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "centroids_bin = np.arange(10, 110, 5)\n",
        "target_bin =  np.arange(30, 140, 5)"
      ],
      "metadata": {
        "id": "j2Lhb3qQ1idY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1NwrfgNiVyDC"
      },
      "outputs": [],
      "source": [
        "plt.hist(l2_ds_target_poison_class, bins=target_bin, alpha = 0.5, density=True, label=\"total poison class\")\n",
        "plt.hist(l2_ds_target_poison, bins=target_bin, alpha = 0.4, density=True, label=\"poisons\")\n",
        "plt.title(\"L2 Distance to Target Image Features\", fontsize=14)\n",
        "plt.ylabel(\"Count\", fontsize=14)\n",
        "#plt.xlabel(\"L2 Distance to Target Image Features\")\n",
        "plt.legend(fontsize=12)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fSLy6TsPw1wC"
      },
      "outputs": [],
      "source": [
        "plt.hist(l2_ds_poisoncentroid_poison_poison_class, bins = centroids_bin, alpha = 0.5, density=True, label=\"total poison class\")\n",
        "plt.hist(l2_ds_poisoncentroid_poison, bins = centroids_bin, alpha = 0.4, density=True, label=\"poisons\")\n",
        "plt.title(\"L2 Distance to Poison Class Mean\", fontsize=14)\n",
        "plt.ylabel(\"Count\", fontsize=14)\n",
        "plt.legend(fontsize=12)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IMDHtwHaxWqf"
      },
      "outputs": [],
      "source": [
        "l2_ds_target_camou = []\n",
        "l2_ds_target_camou_class = []\n",
        "l2_ds_camoucentroid_camou = []\n",
        "l2_ds_camoucentroid_camou_class = []\n",
        "\n",
        "camou_centroid = torch.mean(camou_class_images) \n",
        "\n",
        "for image in camou_images:\n",
        "  l2d_t = dist.euclidean(image, target_image)\n",
        "  l2_ds_target_camou_class.append(l2d_t)\n",
        "\n",
        "  l2d_pc = dist.euclidean(image, camou_centroid)\n",
        "  l2_ds_camoucentroid_camou.append(l2d_pc)\n",
        "\n",
        "for image in camou_class_images:\n",
        "  l2d_t = dist.euclidean(image, target_image)\n",
        "  l2_ds_target_camou.append(l2d_t)\n",
        "\n",
        "  l2d_pc = dist.euclidean(image, camou_centroid)\n",
        "  l2_ds_camoucentroid_camou_class.append(l2d_pc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zz3fSi7Bx7gA"
      },
      "outputs": [],
      "source": [
        "plt.hist(l2_ds_camoucentroid_camou_class, bins = centroids_bin, alpha = 0.5, density=True, label=\"total camouflage class\")\n",
        "plt.hist(l2_ds_camoucentroid_camou, bins = centroids_bin, alpha = 0.4, density=True, label=\"camouflages\")\n",
        "plt.title(\"L2 Distance to Camouflage Class Mean\", fontsize=14)\n",
        "plt.ylabel(\"Count\", fontsize=14)\n",
        "plt.legend(fontsize=12)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PplVV7r3yca4"
      },
      "outputs": [],
      "source": [
        "plt.hist(l2_ds_target_camou_class, bins = target_bin,alpha = 0.5, density=True, label=\"total\\ncamouflage class\")\n",
        "plt.hist(l2_ds_target_camou, bins = target_bin , alpha = 0.4, density=True, label=\"camouflages\")\n",
        "plt.title(\"L2 Distance to Target Image Features\", fontsize=14)\n",
        "plt.ylabel(\"Count\", fontsize=14)\n",
        "plt.legend(fontsize=12)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "background_execution": "on",
      "collapsed_sections": [],
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}