{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "398suIH8O0ia"
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "import random\n",
        "import shutil\n",
        "import time\n",
        "import warnings\n",
        "from enum import Enum\n",
        "import sys \n",
        "\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.parallel\n",
        "import torch.backends.cudnn as cudnn\n",
        "import torch.distributed as dist\n",
        "import torch.optim\n",
        "from torch.nn.modules.loss import CrossEntropyLoss\n",
        "from torch.optim.lr_scheduler import StepLR\n",
        "import torch.multiprocessing as mp\n",
        "import torch.utils.data\n",
        "import torch.utils.data.distributed\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.models as models\n",
        "from torch.utils.data import Subset\n",
        "from torchvision.models import resnet18, ResNet18_Weights\n",
        "import numpy as np      \n",
        "import matplotlib.pyplot as plt\n",
        "import time\n",
        "from scipy.special import softmax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RTUDfMtZ_rQS"
      },
      "outputs": [],
      "source": [
        "# Seed every RNG used for target / poison / camouflage selection.\n",
        "# torch.manual_seed seeds the CPU and all CUDA generators, and a later\n",
        "# torch.cuda.manual_seed_all overrides any earlier CUDA seeding. Seeding CUDA\n",
        "# with seed + 2 and seed + 3 before a final seed + 5 was therefore dead code;\n",
        "# the calls below produce exactly the same final RNG states.\n",
        "\n",
        "seed = 211111111\n",
        "\n",
        "torch.manual_seed(seed + 1)           # torch CPU generator\n",
        "torch.cuda.manual_seed_all(seed + 5)  # every CUDA device (silently ignored without CUDA)\n",
        "np.random.seed(seed + 4)              # NumPy legacy global RNG\n",
        "random.seed(seed + 6)                 # Python `random` module\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ae0L53oPO5aJ"
      },
      "outputs": [],
      "source": [
        "from torch.nn.modules.transformer import TransformerDecoderLayer\n",
        "# Dataset utilities; the CIFAR10 wrapper (defined in a later cell) overrides __getitem__ to also return each image's index\n",
        "\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset, DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YHswvZYyO8ZP",
        "outputId": "84b3ea34-2281-4425-ed8a-80de8609f689"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "# Mount the user's Google Drive at /content/drive (requires a Colab runtime).\n",
        "import google.colab.drive as drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sX1tRX6ePCnT"
      },
      "outputs": [],
      "source": [
        "# Experiment switches.\n",
        "DATASET = 'CIFAR10'   # Choose between 'CIFAR2', 'CIFAR10'\n",
        "MODEL = 'RESNET18'    # Choose between 'RESNET18', 'VGG11'\n",
        "AUGMENTS = 1          # Use Data Augmentation\n",
        "SAVEMODEL = 0         # Save Clean Model\n",
        "LOADMODEL = 1         # Load Clean Model\n",
        "\n",
        "# Validate the choices up front so a typo fails here rather than silently\n",
        "# falling through the `if DATASET == ...` branches of later cells.\n",
        "if DATASET not in ('CIFAR2', 'CIFAR10'):\n",
        "    raise ValueError(f\"Unknown DATASET {DATASET!r}; choose 'CIFAR2' or 'CIFAR10'\")\n",
        "if MODEL not in ('RESNET18', 'VGG11'):\n",
        "    raise ValueError(f\"Unknown MODEL {MODEL!r}; choose 'RESNET18' or 'VGG11'\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wHeAkBsDPGNA"
      },
      "outputs": [],
      "source": [
        "# Map each CIFAR10 class name to its integer label (0-9).\n",
        "classDict = {name: label for label, name in enumerate(\n",
        "    ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])}\n",
        "\n",
        "# Label names for the binary task: 0 -> Machine, 1 -> Animal.\n",
        "binaryClasses = {0: 'Machine', 1: 'Animal'}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CyRqHwpRPWQn"
      },
      "outputs": [],
      "source": [
        "# Per-channel normalisation statistics for transforms.Normalize.\n",
        "if DATASET in ('CIFAR2', 'CIFAR10'):\n",
        "    # CIFAR10 mean / std\n",
        "    data_mean, data_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)\n",
        "else:\n",
        "    # ImageNet mean / std\n",
        "    data_mean, data_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KW-qpjfjPYaQ"
      },
      "outputs": [],
      "source": [
        "from torch.nn.modules.transformer import TransformerDecoderLayer\n",
        "# Overwrite getitem method to obtain the index of the images when iterating through the images\n",
        "\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "\n",
        "class CIFAR10(Dataset):\n",
        "    \"\"\"torchvision CIFAR10 wrapper whose items are (image, label, index) triples.\n",
        "\n",
        "    Args:\n",
        "        train: load the training split if True, otherwise the test split.\n",
        "        transform: transform passed to torchvision and applied to every image.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, train, transform):\n",
        "        self.cifar10 = torchvision.datasets.CIFAR10(\n",
        "                        root='./data', train=train, download=True, transform=transform)\n",
        "        # Aliases of the wrapped dataset's storage; `remove` keeps them in sync.\n",
        "        self.targets = self.cifar10.targets\n",
        "        self.classes = self.cifar10.classes\n",
        "        self.data = self.cifar10.data\n",
        "\n",
        "    # Overloaded the getitem method to return index as well\n",
        "    def __getitem__(self, index):\n",
        "        data, target = self.cifar10[index]\n",
        "        return data, target, index\n",
        "\n",
        "    def get_index(self, target_label):\n",
        "        \"\"\"Return the indices of all samples labelled `target_label` (no image loading).\"\"\"\n",
        "        return [index for index, label in enumerate(self.targets) if label == target_label]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.cifar10)\n",
        "\n",
        "    def remove(self, remove_list):\n",
        "        \"\"\"Delete the samples at positions `remove_list`, in place.\n",
        "\n",
        "        The wrapped dataset's `data` and `targets` are filtered together, and the\n",
        "        aliases on this wrapper are refreshed. torchvision's CIFAR10 indexes those\n",
        "        attributes in `__getitem__` and `__len__`. (Previously the filtered array\n",
        "        went into a discarded local, so nothing was removed.) Surviving samples are\n",
        "        re-indexed contiguously from 0.\n",
        "        \"\"\"\n",
        "        mask = np.ones(len(self.cifar10), dtype=bool)\n",
        "        mask[remove_list] = False\n",
        "        self.cifar10.data = self.cifar10.data[mask]\n",
        "        self.cifar10.targets = [t for t, keep in zip(self.cifar10.targets, mask) if keep]\n",
        "        self.data = self.cifar10.data\n",
        "        self.targets = self.cifar10.targets\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AAZelpqLPlIW",
        "outputId": "1942ad36-2340-4ff8-a896-a085e7432ea2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 170498071/170498071 [00:12<00:00, 13256714.08it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Extracting ./data/cifar-10-python.tar.gz to ./data\n",
            "Files already downloaded and verified\n"
          ]
        }
      ],
      "source": [
        "# Data Prep.\n",
        "\n",
        "# Inverse of Normalize(data_mean, data_std): Normalize maps x -> (x - m) / s,\n",
        "# so x = (y - (-m / s)) / (1 / s).\n",
        "inv_normalize = transforms.Normalize(\n",
        "   mean= [-m/s for m, s in zip(data_mean, data_std)],\n",
        "   std= [1/s for s in data_std]\n",
        ")\n",
        "\n",
        "# NOTE(review): AUGMENTS is set in the config cell but no augmentation is applied\n",
        "# here -- train and test share the same transform. Confirm this is intended.\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(data_mean, data_std),\n",
        "])\n",
        "\n",
        "transform_test = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(data_mean, data_std),\n",
        "])\n",
        "\n",
        "if DATASET == 'CIFAR2' or DATASET == 'CIFAR10':\n",
        "  trainset = CIFAR10(train=True, transform=transform_train)\n",
        "  testset = CIFAR10(train=False, transform=transform_test)\n",
        "else:\n",
        "  # Fail here instead of with a NameError on `trainset` in a later cell.\n",
        "  raise ValueError(f\"Unsupported DATASET {DATASET!r}; expected 'CIFAR2' or 'CIFAR10'\")\n",
        "\n",
        "# Number of output classes used by later cells.\n",
        "# NOTE(review): stays 10 even when DATASET == 'CIFAR2' -- confirm for the binary task.\n",
        "NUM_CLASS = 10\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SVe-TmQwPbuX"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Mini-batch loaders: reshuffled every epoch for training, fixed order for evaluation.\n",
        "trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=100, shuffle=True)\n",
        "testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=100, shuffle=False)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ov2y5U-mPeQn"
      },
      "outputs": [],
      "source": [
        "'''ResNet in PyTorch.\n",
        "\n",
        "For Pre-activation ResNet, see 'preact_resnet.py'.\n",
        "\n",
        "Reference:\n",
        "[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n",
        "    Deep Residual Learning for Image Recognition. arXiv:1512.03385\n",
        "'''\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import pdb\n",
        "import numpy as np\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "    \"\"\"Two 3x3-conv residual block (ResNet-18/34 style) with optional regularisers.\n",
        "\n",
        "    Args:\n",
        "        in_planes: input channels.\n",
        "        planes: output channels of both convs (expansion == 1).\n",
        "        stride: stride of the first conv and of the projection shortcut.\n",
        "        train_dp: dropout rate after the first conv, used only in training mode.\n",
        "        test_dp: if > 0, dropout (rate max(test_dp, train_dp)) is applied in train and eval mode.\n",
        "        droplayer: probability of skipping the residual branch on a forward call\n",
        "            (drawn from numpy's global RNG, in train and eval mode).\n",
        "        bdp: if > 0, the residual branch is multiplied by one Bernoulli(bdp) mask shared by\n",
        "            all samples in the batch and rescaled by 1 / bdp, as in Bottleneck. (This\n",
        "            argument used to be accepted but silently ignored by BasicBlock.)\n",
        "    \"\"\"\n",
        "    expansion = 1\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1, train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "        # if test_dp > 0: will always keep dp there\n",
        "        super(BasicBlock, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "        self.shortcut = nn.Sequential()\n",
        "        if stride != 1 or in_planes != self.expansion*planes:\n",
        "            # 1x1 projection so the shortcut matches the residual branch's shape\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(self.expansion*planes)\n",
        "            )\n",
        "        self.train_dp = train_dp\n",
        "        self.test_dp = test_dp\n",
        "        self.bdp = bdp\n",
        "\n",
        "        self.droplayer = droplayer\n",
        "\n",
        "    def forward(self, x):\n",
        "        action = np.random.binomial(1, self.droplayer)\n",
        "        if action == 1:\n",
        "            # drop-layer: skip the residual branch entirely\n",
        "            out = self.shortcut(x)\n",
        "        else:\n",
        "            out = F.relu(self.bn1(self.conv1(x)))\n",
        "            if self.test_dp > 0 or (self.training and self.train_dp>0):\n",
        "                dp = max(self.test_dp, self.train_dp)\n",
        "                out = F.dropout(out, dp, training=True)\n",
        "            # getattr keeps instances unpickled from an older definition (no `bdp`) working\n",
        "            bdp = getattr(self, 'bdp', 0)\n",
        "            if bdp > 0:\n",
        "                # each sample will be applied the same mask\n",
        "                bdp_mask = torch.bernoulli(\n",
        "                    bdp * torch.ones(1, out.size(1), out.size(2), out.size(3)).to(out.device)) / bdp\n",
        "                out = bdp_mask * out\n",
        "            out = self.bn2(self.conv2(out))\n",
        "            out += self.shortcut(x)\n",
        "\n",
        "        out = F.relu(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "    \"\"\"ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (output = 4 * planes channels).\n",
        "\n",
        "    Args:\n",
        "        in_planes: input channels.\n",
        "        planes: width of the inner convolutions.\n",
        "        stride: stride of the 3x3 conv and of the projection shortcut.\n",
        "        train_dp: dropout rate after the 3x3 conv, used only in training mode.\n",
        "        test_dp: if > 0, dropout (rate max(test_dp, train_dp)) is applied in train and eval mode.\n",
        "        droplayer: probability of skipping the residual branch on a forward call\n",
        "            (drawn from numpy's global RNG, in train and eval mode).\n",
        "        bdp: if > 0, the residual branch is multiplied by one Bernoulli(bdp) mask shared by\n",
        "            all samples in the batch and rescaled by 1 / bdp.\n",
        "    \"\"\"\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1, train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "        super(Bottleneck, self).__init__()\n",
        "        out_planes = self.expansion * planes\n",
        "\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
        "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
        "\n",
        "        # Identity shortcut unless the spatial size or the channel count changes.\n",
        "        if stride == 1 and in_planes == out_planes:\n",
        "            self.shortcut = nn.Sequential()\n",
        "        else:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(out_planes),\n",
        "            )\n",
        "\n",
        "        self.train_dp, self.test_dp, self.bdp = train_dp, test_dp, bdp\n",
        "        self.droplayer = droplayer\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Drop-layer: skip the whole residual branch with probability self.droplayer.\n",
        "        if np.random.binomial(1, self.droplayer) == 1:\n",
        "            return F.relu(self.shortcut(x))\n",
        "\n",
        "        branch = F.relu(self.bn1(self.conv1(x)))\n",
        "        branch = F.relu(self.bn2(self.conv2(branch)))\n",
        "\n",
        "        if self.test_dp > 0 or (self.training and self.train_dp > 0):\n",
        "            branch = F.dropout(branch, max(self.test_dp, self.train_dp), training=True)\n",
        "        if self.bdp > 0:\n",
        "            # One keep-mask of shape (1, C, H, W), broadcast over the batch dimension.\n",
        "            keep_prob = self.bdp * torch.ones(1, branch.size(1), branch.size(2), branch.size(3)).to(branch.device)\n",
        "            branch = torch.bernoulli(keep_prob) / self.bdp * branch\n",
        "\n",
        "        branch = self.bn3(self.conv3(branch))\n",
        "        return F.relu(branch + self.shortcut(x))\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "    \"\"\"CIFAR-style ResNet: 3x3 stem conv (no max-pool), four residual stages, linear head.\n",
        "\n",
        "    Args:\n",
        "        block: residual block class (e.g. BasicBlock or Bottleneck); must define `expansion`.\n",
        "        num_blocks: number of blocks in each of the four stages.\n",
        "        num_classes: output size of the final linear layer.\n",
        "        train_dp, test_dp, bdp: forwarded unchanged to every block.\n",
        "        droplayer: drop-layer budget; the k-th block overall (0-based, across all stages)\n",
        "            gets droplayer * k / sum(num_blocks), so deeper blocks are skipped more often.\n",
        "        middle_feat_num: how many intermediate layer4 block outputs `get_block_feats`\n",
        "            returns in addition to the pooled features.\n",
        "\n",
        "    The fixed 4x4 average pool expects 32x32 inputs (layer4 output is then 4x4).\n",
        "    \"\"\"\n",
        "    def __init__(self, block, num_blocks, num_classes=10, train_dp=0, test_dp=0, droplayer=0, bdp=0,\n",
        "                                            middle_feat_num=1):\n",
        "        super(ResNet, self).__init__()\n",
        "        self.in_planes = 64  # running input width, advanced by _make_layer\n",
        "\n",
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(64)\n",
        "\n",
        "        # Drop-layer probability grows linearly with the global block index.\n",
        "        nblks = sum(num_blocks)\n",
        "        dl_step = droplayer / nblks\n",
        "\n",
        "        dl_start = 0\n",
        "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, train_dp=train_dp, test_dp=test_dp,\n",
        "                                       dl_start=dl_start, dl_step=dl_step, bdp=bdp)\n",
        "\n",
        "        dl_start += dl_step * num_blocks[0]\n",
        "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, train_dp=train_dp, test_dp=test_dp,\n",
        "                                       dl_start=dl_start, dl_step=dl_step, bdp=bdp)\n",
        "\n",
        "        dl_start += dl_step * num_blocks[1]\n",
        "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, train_dp=train_dp, test_dp=test_dp,\n",
        "                                       dl_start=dl_start, dl_step=dl_step, bdp=bdp)\n",
        "\n",
        "        dl_start += dl_step * num_blocks[2]\n",
        "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, train_dp=train_dp, test_dp=test_dp,\n",
        "                                       dl_start=dl_start, dl_step=dl_step, bdp=bdp)\n",
        "        self.linear = nn.Linear(512*block.expansion, num_classes)\n",
        "\n",
        "        self.test_dp = test_dp\n",
        "        self.middle_feat_num = middle_feat_num\n",
        "\n",
        "    def get_block_feats(self, x):\n",
        "        \"\"\"Return intermediate layer4 features followed by the pooled penultimate features.\n",
        "\n",
        "        The list holds, in order, the outputs of the (at most `middle_feat_num`) layer4\n",
        "        blocks that precede its last block, then the flattened average-pooled output.\n",
        "        \"\"\"\n",
        "        feat_list = []\n",
        "\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.layer1(out)\n",
        "        out = self.layer2(out)\n",
        "        out = self.layer3(out)\n",
        "\n",
        "        for nl, layer in enumerate(self.layer4):\n",
        "            out = layer(out)\n",
        "            blocks_after = len(self.layer4) - nl - 1  # layer4 blocks still to run\n",
        "            if 0 < blocks_after <= self.middle_feat_num:\n",
        "                feat_list.append(out)\n",
        "\n",
        "        out = F.avg_pool2d(out, 4)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        feat_list.append(out)\n",
        "\n",
        "        return feat_list\n",
        "\n",
        "    def set_testdp(self, dp):\n",
        "        \"\"\"Set the always-on dropout rate `test_dp` of every residual block to `dp`.\"\"\"\n",
        "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
        "            for layer in stage:\n",
        "                layer.test_dp = dp\n",
        "        # keep the network-level attribute (set in __init__) in sync with the blocks\n",
        "        self.test_dp = dp\n",
        "\n",
        "    def _make_layer(self, block, planes, num_blocks, stride, train_dp=0, test_dp=0, dl_start=0, dl_step=0, bdp=0):\n",
        "        \"\"\"Build one stage of `num_blocks` blocks; only the first block uses `stride`.\n",
        "\n",
        "        Block ns gets drop-layer probability dl_start + dl_step * ns. `dl_start` used to\n",
        "        default to 9, an invalid probability that would make the blocks'\n",
        "        np.random.binomial call raise; 0 (no drop-layer) is the meaningful default.\n",
        "        \"\"\"\n",
        "        strides = [stride] + [1]*(num_blocks-1)\n",
        "        layers = []\n",
        "        for ns, stride in enumerate(strides):\n",
        "            layers.append(block(self.in_planes, planes, stride, train_dp=train_dp, test_dp=test_dp,\n",
        "                                droplayer=dl_start+dl_step*ns, bdp=bdp))\n",
        "            self.in_planes = planes * block.expansion\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def penultimate(self, x):\n",
        "        \"\"\"Return the flattened, average-pooled layer4 features (the input to `linear`).\"\"\"\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.layer1(out)\n",
        "        out = self.layer2(out)\n",
        "        out = self.layer3(out)\n",
        "        out = self.layer4(out)\n",
        "        out = F.avg_pool2d(out, 4)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        return out\n",
        "\n",
        "    def reset_last_layer(self):\n",
        "        \"\"\"Re-initialise the classifier: weights ~ N(0, 0.1^2), bias = 0.\"\"\"\n",
        "        self.linear.weight.data.normal_(0, 0.1)\n",
        "        self.linear.bias.data.zero_()\n",
        "\n",
        "    def forward(self, x, penu=False, block=False):\n",
        "        \"\"\"Return logits; penultimate features if `penu`; the get_block_feats list if `block` (checked first).\"\"\"\n",
        "        if block:\n",
        "            return self.get_block_feats(x)\n",
        "\n",
        "        out = self.penultimate(x)\n",
        "        if penu:\n",
        "            return out\n",
        "        out = self.linear(out)\n",
        "        return out\n",
        "\n",
        "    def get_penultimate_params_list(self):\n",
        "        \"\"\"Return the parameters of the final linear classifier (names containing 'linear').\"\"\"\n",
        "        return [param for name, param in self.named_parameters() if 'linear' in name]\n",
        "\n",
        "# Standard depths; every factory accepts the same regularisation arguments.\n",
        "def ResNet18(train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "    \"\"\"ResNet-18: BasicBlock, [2, 2, 2, 2] blocks per stage, 10 output classes.\"\"\"\n",
        "    return ResNet(BasicBlock, [2,2,2,2], train_dp=train_dp, test_dp=test_dp, droplayer=droplayer, bdp=bdp)\n",
        "\n",
        "def ResNet34(train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "    \"\"\"ResNet-34: BasicBlock, [3, 4, 6, 3] blocks per stage, 10 output classes.\"\"\"\n",
        "    return ResNet(BasicBlock, [3,4,6,3], train_dp=train_dp, test_dp=test_dp, droplayer=droplayer, bdp=bdp)\n",
        "\n",
        "def ResNet50(train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "    \"\"\"ResNet-50: Bottleneck, [3, 4, 6, 3] blocks per stage, 10 output classes.\"\"\"\n",
        "    return ResNet(Bottleneck, [3,4,6,3], train_dp=train_dp, test_dp=test_dp, droplayer=droplayer, bdp=bdp)\n",
        "\n",
        "def ResNet101(train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "    \"\"\"ResNet-101: Bottleneck, [3, 4, 23, 3] blocks per stage, 10 output classes.\"\"\"\n",
        "    return ResNet(Bottleneck, [3,4,23,3], train_dp=train_dp, test_dp=test_dp, droplayer=droplayer, bdp=bdp)\n",
        "\n",
        "def ResNet152(train_dp=0, test_dp=0, droplayer=0, bdp=0):\n",
        "    \"\"\"ResNet-152: Bottleneck, [3, 8, 36, 3] blocks per stage, 10 output classes.\"\"\"\n",
        "    return ResNet(Bottleneck, [3,8,36,3], train_dp=train_dp, test_dp=test_dp, droplayer=droplayer, bdp=bdp)\n",
        "\n",
        "\n",
        "def test():\n",
        "    \"\"\"Smoke test: run one random 32x32 RGB image through ResNet18 and print the output shape.\"\"\"\n",
        "    logits = ResNet18()(torch.randn(1, 3, 32, 32))\n",
        "    print(logits.size())\n",
        "\n",
        "# test()  # uncomment to run the smoke test\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHQ5OtY6QuOT"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): initial_conv is not used by the model construction below.\n",
        "initial_conv = [3, 1, 1]\n",
        "\n",
        "# Build the network selected by MODEL. Only a ResNet18 constructor is defined in the\n",
        "# cells above, so fail loudly instead of silently building ResNet18 for MODEL='VGG11'.\n",
        "if MODEL == 'RESNET18':\n",
        "    model = ResNet18()\n",
        "else:\n",
        "    raise NotImplementedError(f\"MODEL={MODEL!r} is not implemented; use 'RESNET18'\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nQLHbuMmQznw",
        "outputId": "53b351b8-a251-4f9f-fe61-c1de6ab6e850"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "8\n",
            "5\n"
          ]
        }
      ],
      "source": [
        "# Draw two distinct classes at random: the target class and the poison class.\n",
        "# The camouflage images are assumed to come from the target class.\n",
        "avail_classes = np.arange(NUM_CLASS)\n",
        "target_class, poison_class = np.random.choice(avail_classes, size=2, replace=False)\n",
        "camou_class = target_class\n",
        "\n",
        "print(target_class, poison_class, sep='\\n')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IU3RsdLDQ0yo",
        "outputId": "7feb8578-d620-47cc-9487-1cf20526a01c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Running on a GPU\n"
          ]
        }
      ],
      "source": [
        "# Training hyper-parameters, optimiser, LR schedule, loss function and device.\n",
        "\n",
        "epochs = 41\n",
        "eta = 0.01  # initial learning rate\n",
        "optimizer = torch.optim.SGD(params=model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)\n",
        "# The learning rate is multiplied by 0.8 on every scheduler.step() call.\n",
        "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)\n",
        "\n",
        "# BCE-with-logits for DATASET == 'CIFAR2', cross-entropy otherwise.\n",
        "loss_fun = nn.BCEWithLogitsLoss() if DATASET == 'CIFAR2' else nn.CrossEntropyLoss()\n",
        "\n",
        "if not torch.cuda.is_available():\n",
        "    device = torch.device('cpu')\n",
        "    print(\"Running on a CPU...Uhh, are you sure you want to do this?\")\n",
        "else:\n",
        "    device = torch.device('cuda')\n",
        "    print(\"Running on a GPU\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5n8eguYYQ2CC",
        "outputId": "7a29c8f5-be3a-403d-991e-895af60ca7a5"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ResNet(\n",
              "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "  (layer1): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer2): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer3): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer4): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ],
      "source": [
        "# Move the parameters to `device`, switch to training mode and display the architecture.\n",
        "model.to(device).train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 484
        },
        "id": "Wlw64UR6Q6xY",
        "outputId": "329be4f1-ec12-4a0f-c90a-c28cdf5fcf78"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Target image is chosen with ID [3950]\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f949dbb7880>"
            ]
          },
          "metadata": {},
          "execution_count": 16
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh/UlEQVR4nO3dfVCU5/3v8c8qsEiARURYqGjwIdrEh7Y2EiaJNZGq9DcZn/4wD52a1GNGg5mqTZPQX57bOaRmTpqHY3XOZBonZ2JM7USd5ExMEwzYtGgj1TEmKaMeWvEnYGMPu4CCINf5I822RIj3BbteLL5fM/eM7n659nvvDfvhZm+++IwxRgAAXGbDXDcAALgyEUAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnEhw3cCXdXd369SpU0pLS5PP53PdDgDAkjFGLS0tysvL07BhfZ/nDLoAOnXqlPLz8123AQAYoPr6eo0ZM6bP+2MWQBs3btQzzzyjxsZGzZgxQy+++KJmzZp1yY9LS0uT9Hnj6enpsWoPuKQjJ73XHv+b3dqhc5b1F7zX/t8zdmv/+bj3xUPNdmu3h4Z7rm1uDVutnTXK+8vXqKDfau2x47z3LUmT8yzWzrBaWjlB77Xj+n6t75VF20q2qG0JhzU+Pz/yet6XmATQ66+/rvXr12vz5s0qLCzUc889p/nz56u2tlbZ2dlf+bFf/NgtPT2dAIJTqV/9tdNDylV2a5+3fPe1o8t7bdJZu7WH+70H0LAku7WHJXp/IfclWq6d5P3la7jfLoASR9gFkN/i+I9ItVpaKRafh6mWL5k25TYB9IVLvY0Sk4sQnn32Wa1cuVL33HOPrr32Wm3evFkpKSn69a9/HYuHAwDEoagH0Pnz51VTU6Pi4uJ/PciwYSouLlZ1dfVF9R0dHQqHwz02AMDQF/UA+uyzz3ThwgXl5OT0uD0nJ0eNjY0X1ZeXlysQCEQ2LkAAgCuD898DKisrUygUimz19fWuWwIAXAZRvwghKytLw4cPV1NTU4/bm5qaFAxefDmH3++X3/INQgBA/Iv6GVBSUpJmzpypioqKyG3d3d2qqKhQUVFRtB8OABCnYnIZ9vr167V8+XJ9+9vf1qxZs/Tcc8+pra1N99xzTyweDgAQh2ISQMuWLdPf//53PfbYY2psbNQ3vvEN7d69+6ILEwAAVy6fMca4buLfhcNhBQIBhUIhfhEVUXWixa7+0zrvtWmWv1yYPNKuvt2i1qZvSfqLxXU/7Ra/ECtJLc0Wa1sen2SLX9C0/eXPDMv6q0d7rw1arp2b6712jOXnVYZFrc3ZSjgc1mgPr+POr4IDAFyZCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMxmQUHDEZdlp/tk6Z5rw367Na2/cL7zKK2y3JcTkaGXb2NhETvte3n7Na2GQuUYPmEZ1mM1pGkfIvxOsl2S1s9h5aHXscsBrElWHyOt3qs4wwIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4wSw4XDGCI+zq22PThiT7L7ygTa3lHLN2i3qvM776w/Y5ieXxsXm+pdh+Jx+2qLU9Pq0WH9BiMWiuzWPTnAEBAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAATjCKB1eMZMt6my8O2y+kwfSdX0qMahEd6Ra1tp/j7Wnea7s6vdd2d3urG0xfBwCAKwgBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADgxJGbBeRw7JInEvZLZHvukmHQBxI7t5+x4i9qERO+1LR5reT0GADgR9QB64okn5PP5emxTpkyJ9sMAAOJcTH4Ed9111+m9997714MkDImf9AEAoigmyZ
CQkKBgMBiLpQEAQ0RM3gM6evSo8vLyNH78eN111106ceJEn7UdHR0Kh8M9NgDA0Bf1ACosLNSWLVu0e/dubdq0SXV1dbr55pvV0tLSa315ebkCgUBky8/Pj3ZLAIBByGeMMbF8gObmZo0bN07PPvusVqxYcdH9HR0d6ujoiPw/HA4rPz9foVBI6ene/hgtl2EDQGz1/XOsi7WEw5oaCFzydTzmVwdkZGTommuu0bFjx3q93+/3y+/3x7oNAMAgE/MTgtbWVh0/fly5ubmxfigAQByJegA98MADqqqq0l//+lf98Y9/1OLFizV8+HDdcccd0X4oAEAci/qP4E6ePKk77rhDZ86c0ejRo3XTTTdp3759Gj16tNU6f+mUUju91SZ0eV83I9mqDaX6vNem2C0NAHHD4mXWc23UA2jbtm3RXhIAMARxURgAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgRMz/HEN/nTkptad5q031WCdJlqPglDXC8gMAYAgKWtR6nYvJGRAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgxKAdxXOuTRrm81abkeh93WTLWTwkNAB4H68jSV0e63h9BQA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAATgzaWXCJXWeV2OWxvU7vU4pSU/vZEAAgqjgDAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAATgzaWXDJgRQlp3mb8ZaQ5n3ddss+vE+ZAwDY4AwIAOCEdQDt3btXt912m/Ly8uTz+bRz584e9xtj9Nhjjyk3N1cjRoxQcXGxjh49Gq1+AQBDhHUAtbW1acaMGdq4cWOv92/YsEEvvPCCNm/erP379+uqq67S/Pnz1d5u+8MvAMBQZv0eUElJiUpKSnq9zxij5557To888ogWLlwoSXrllVeUk5OjnTt36vbbbx9YtwCAISOq7wHV1dWpsbFRxcXFkdsCgYAKCwtVXV3d68d0dHQoHA732AAAQ19UA6ixsVGSlJOT0+P2nJycyH1fVl5erkAgENny8/Oj2RIAYJByfhVcWVmZQqFQZKuvr3fdEgDgMohqAAWDQUlSU1NTj9ubmpoi932Z3+9Xenp6jw0AMPRFNYAKCgoUDAZVUVERuS0cDmv//v0qKiqK5kMBAOKc9VVwra2tOnbsWOT/dXV1OnTokDIzMzV27FitXbtWP//5zzVp0iQVFBTo0UcfVV5enhYtWhTNvgEAcc46gA4cOKBbbrkl8v/169dLkpYvX64tW7bowQcfVFtbm+699141Nzfrpptu0u7du5WcnGz1OCfbLijFd8FTbUNzh+d16/9u1YZyR3ofxjMx127tky3ea6ddO9tq7f/1P/+759rlC2+yWhsAosE6gObMmSNjTJ/3+3w+PfXUU3rqqacG1BgAYGhzfhUcAODKRAABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJywHsVzuRysrZc/Jc1bcXur53VT7UbSKTg6w3Nts+z+mN7Bw0c9154/edBq7Tff7/0v0PaGWXAAXOAMCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHBi0I7i+ex4nZKSr/JW3NXmed0uy1E8qV2jPNfWdbZbrf1R9e8tGsm2Wvv6aR6fOwBwhDMgAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgxKCdBXfduGwlp6R6qu1qD3leNznBbl5bY91hz7UJrd7nxknSN3O9P/1v/fY/rdb+j/nft6q/Enxyzq6+1WJuYKrPbu2r7cqVYlkPxAPOgAAAThBAAAAnCCAAgB
MEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnBu0onvZkSR5HoSQke9+N5uY2qz7qm5s8137jm1+zWnvZf9xiUZ1ttbZ01KK2y3Ltr1vW2/iHVXX4nPcxTJ/VBazWThid6bm21fIr6aS3KVMRGYnea7Pslua7UDjD5x4AwAkCCADghHUA7d27V7fddpvy8vLk8/m0c+fOHvfffffd8vl8PbYFCxZEq18AwBBhHUBtbW2aMWOGNm7c2GfNggUL1NDQENlee+21ATUJABh6rC9CKCkpUUlJyVfW+P1+BYPBfjcFABj6YvIeUGVlpbKzszV58mStXr1aZ86c6bO2o6ND4XC4xwYAGPqiHkALFizQK6+8ooqKCv3iF79QVVWVSkpKdOHChV7ry8vLFQgEIlt+fn60WwIADEJR/z2g22+/PfLvadOmafr06ZowYYIqKys1d+7ci+rLysq0fv36yP/D4TAhBABXgJhfhj1+/HhlZWXp2LFjvd7v9/uVnp7eYwMADH0xD6CTJ0/qzJkzys3NjfVDAQDiiPWP4FpbW3uczdTV1enQoUPKzMxUZmamnnzySS1dulTBYFDHjx/Xgw8+qIkTJ2r+/PlRbRwAEN98xhhj8wGVlZW65ZaLZ5gtX75cmzZt0qJFi3Tw4EE1NzcrLy9P8+bN089+9jPl5OR4Wj8cDisQCOh7j/9aickpnj6m+bMTnvuv3vuu51pJOv+hXb2NCdO8z44r+8/bL130bxrrD3uuvTrf27H5Qm6GXf3ESd/0XPu/f/N7q7UbW0d5rv2orvcLYfqSNanIc23BpG/ZrZ1l9xyOyU3yXPt1y7dQMyzm0mVZzKSTJJsfqJ+2W1oZFrXen73Pnbest1m/23Ltdsv6WLGZGBkOh5UfCCgUCn3l2yrWZ0Bz5szRV2XWO++8Y7skAOAKxCw4AIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwImo/z2gaAl1hZXQ1emp9vd/+qPndbtjONvN1vGP/stz7X+7/X/EsBNc2Wzm0nmfXyhJKrjKe+25Dqul06dN81w7qeBaq7Xb2+3mBi6+0/usxobPmq3Wbu3yPoVt0tenW62dmjHcc23GSO/rnmvxVscZEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAODEoB3F03bmHxruT/ZU290einE3wFDWFKNaSXV25TbCjX/yXFsTuzYkSR+/8mCMH2EwSLWoNZ6qOAMCADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABODNpZcB/X1MiXkOituLUtts0AwBWvNeorcgYEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAODFoR/F0fvh/XLcAAIghzoAAAE5YBVB5ebmuv/56paWlKTs7W4sWLVJtbW2Pmvb2dpWWlmrUqFFKTU3V0qVL1dTUFNWmAQDxzyqAqqqqVFpaqn379undd99VZ2en5s2bp7a2f02jXrdund58801t375dVVVVOnXqlJYsWRL1xgEAcc4MwOnTp40kU1VVZYwxprm52SQmJprt27dHaj799FMjyVRXV3taMxQKGUlsbGxsbHG+hUKhr3y9H9B7QKFQSJKUmZkpSaqpqVFnZ6eKi4sjNVOmTNHYsWNVXV3d6xodHR0Kh8M9NgDA0NfvAOru7tbatWt14403aurUqZKkxsZGJSUlKSMjo0dtTk6OGhsbe12nvLxcgUAgsuXn5/e3JQBAHOl3AJWWlurIkSPatm3bgBooKytTKBSKbPX19QNaDwAQH/r1e0Br1qzRW2+9pb1792rMmDGR24PBoM6fP6/m5uYeZ0FNTU0KBoO9ruX3++X3+/vTBgAgjlmdARljtGbNGu3YsUN79uxRQUFBj/tnzpypxM
REVVRURG6rra3ViRMnVFRUFJ2OAQBDgtUZUGlpqbZu3apdu3YpLS0t8r5OIBDQiBEjFAgEtGLFCq1fv16ZmZlKT0/X/fffr6KiIt1www0x2QEAQJyyuexafVxq9/LLL0dqzp07Z+677z4zcuRIk5KSYhYvXmwaGho8PwaXYbOxsbENje1Sl2H7/hksg0Y4HFYgEHDdBgBggEKhkNLT0/u8n1lwAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBP9+nMMuLKlZNnVp6VZFFt+Rp7rSvVc29453Grt80r2XJtkUStJI9Psxk1lZYyyqreRkGDxpFseny6LD2htbbda26bv9na7tbu6uuzqLXo5095htbbaLXo5Z9e3Eq7yXmuzdHeXdPrDS5ZxBgQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJxgFhwk7+PUJEkjbGa7SUqwWN9yBJeV8//P8gO6vH95dI22m+12LtmuvjXZ+8yurq4LVmvbaG+1O0DnLGrPNtutPcxi/F633Sg46Zzlczja4pO8pc1ubZvn3GZunCQp5L00K8d7bbe3PjgDAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJxgFA+stdjMV5HU0um9NnmE3drhFosZK5ajXpTgfdZLtyzmwkgKW45MCTd7H9+Skup9bI8kq1cBi+lE/6y36aXVam2P014+l2o3+khdZ+zqGyw+DxNG2a1tMyor2fIAjfB7r7X5uu/29kXPGRAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCWXCwHcGl85b1VmvbfkZ2Wc53s1o75L224b/s1rbdzwzvM9XONtvNpVOCzTA4y+e702LtRMsn5Zz3+Xiqtzw+tmyelgSb4W6SLMbMWX9iZWV7r+264L2WWXAAgMHMKoDKy8t1/fXXKy0tTdnZ2Vq0aJFqa2t71MyZM0c+n6/HtmrVqqg2DQCIf1YBVFVVpdLSUu3bt0/vvvuuOjs7NW/ePLW19TwVXrlypRoaGiLbhg0boto0ACD+Wf3AcPfu3T3+v2XLFmVnZ6umpkazZ8+O3J6SkqJgMBidDgEAQ9KA3gMKhT5/kzYzM7PH7a+++qqysrI0depUlZWV6ezZs32u0dHRoXA43GMDAAx9/b4Krru7W2vXrtWNN96oqVOnRm6/8847NW7cOOXl5enw4cN66KGHVFtbqzfeeKPXdcrLy/Xkk0/2tw0AQJzyGWNMfz5w9erVevvtt/XBBx9ozJgxfdbt2bNHc+fO1bFjxzRhwoSL7u/o6FBHR0fk/+FwWPn5+f1pCUOB9WXYMeninywuZ062/DPLMbwMW4lXyGXYnRa9NMfwdwckLsP+su5O6R/vKRQKKT09vc+yfp0BrVmzRm+99Zb27t37leEjSYWFhZLUZwD5/X75/RZ/lxwAMCRYBZAxRvfff7927NihyspKFRQUXPJjDh06JEnKzc3tV4MAgKHJKoBKS0u1detW7dq1S2lpaWpsbJQkBQIBjRgxQsePH9fWrVv1ve99T6NGjdLhw4e1bt06zZ49W9OnT4/JDgAA4pPVe0A+n6/X219++WXdfffdqq+v1/e//30dOXJEbW1tys/P1+LFi/XII4985c8B/104HFYgEPDaEoYa3gPqHe8B9bI27wH1srjd2o7fA+r3RQixQgABvbF4YUm2DCAN917abjEfL+ZsXshjHEBWvdh+12STQLbf2dh842TxTZC6Jf31kgHELDgAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADAiX7/QToAl5PF+Jb2WI+dGSwG034Oll5sx/w0xaQLrzgDAgA4QQABAJwggAAAThBAAAAnCCAAgB
MEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnrAJo06ZNmj59utLT05Wenq6ioiK9/fbbkfvb29tVWlqqUaNGKTU1VUuXLlVTU1PUmwYAxD+rABozZoyefvpp1dTU6MCBA7r11lu1cOFCffzxx5KkdevW6c0339T27dtVVVWlU6dOacmSJTFpHAAQ58wAjRw50rz00kumubnZJCYmmu3bt0fu+/TTT40kU11d7Xm9UChkJLGxsbGxxfkWCoW+8vW+3+8BXbhwQdu2bVNbW5uKiopUU1Ojzs5OFRcXR2qmTJmisWPHqrq6us91Ojo6FA6He2wAgKHPOoA++ugjpaamyu/3a9WqVdqxY4euvfZaNTY2KikpSRkZGT3qc3Jy1NjY2Od65eXlCgQCkS0/P996JwAA8cc6gCZPnqxDhw5p//79Wr16tZYvX65PPvmk3w2UlZUpFApFtvr6+n6vBQCIHwm2H5CUlKSJEydKkmbOnKkPP/xQzz//vJYtW6bz58+rubm5x1lQU1OTgsFgn+v5/X75/X77zgEAcW3AvwfU3d2tjo4OzZw5U4mJiaqoqIjcV1tbqxMnTqioqGigDwMAGGKszoDKyspUUlKisWPHqqWlRVu3blVlZaXeeecdBQIBrVixQuvXr1dmZqbS09N1//33q6ioSDfccEOs+gcAxCmrADp9+rR+8IMfqKGhQYFAQNOnT9c777yj7373u5KkX/7ylxo2bJiWLl2qjo4OzZ8/X7/61a9i0jgAIL75jDHGdRP/LhwOKxAIuG4DADBAoVBI6enpfd7PLDgAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBODLoAG2WAGAEA/Xer1fNAFUEtLi+sWAABRcKnX80E3C667u1unTp1SWlqafD5f5PZwOKz8/HzV19d/5WyheMd+Dh1Xwj5K7OdQE439NMaopaVFeXl5Gjas7/Mc6z9IF2vDhg3TmDFj+rw/PT19SB/8L7CfQ8eVsI8S+znUDHQ/vQyVHnQ/ggMAXBkIIACAE3ETQH6/X48//rj8fr/rVmKK/Rw6roR9lNjPoeZy7ueguwgBAHBliJszIADA0EIAAQCcIIAAAE4QQAAAJ+ImgDZu3Kirr75aycnJKiws1J/+9CfXLUXVE088IZ/P12ObMmWK67YGZO/evbrtttuUl5cnn8+nnTt39rjfGKPHHntMubm5GjFihIqLi3X06FE3zQ7Apfbz7rvvvujYLliwwE2z/VReXq7rr79eaWlpys7O1qJFi1RbW9ujpr29XaWlpRo1apRSU1O1dOlSNTU1Oeq4f7zs55w5cy46nqtWrXLUcf9s2rRJ06dPj/yyaVFRkd5+++3I/ZfrWMZFAL3++utav369Hn/8cf35z3/WjBkzNH/+fJ0+fdp1a1F13XXXqaGhIbJ98MEHrlsakLa2Ns2YMUMbN27s9f4NGzbohRde0ObNm7V//35dddVVmj9/vtrb2y9zpwNzqf2UpAULFvQ4tq+99tpl7HDgqqqqVFpaqn379undd99VZ2en5s2bp7a2tkjNunXr9Oabb2r79u2qqqrSqVOntGTJEodd2/Oyn5K0cuXKHsdzw4YNjjrunzFjxujpp59WTU2NDhw4oFtvvVULFy7Uxx9/LOkyHksTB2bNmmVKS0sj/79w4YLJy8sz5eXlDruKrscff9zMmDHDdRsxI8ns2LEj8v/u7m4TDAbNM888E7mtubnZ+P1+89prrznoMDq+vJ/GGLN8+XKzcOFCJ/3EyunTp40kU1VVZYz5/NglJiaa7du3R2o+/fRTI8lUV1e7anPAvr
yfxhjzne98x/zoRz9y11SMjBw50rz00kuX9VgO+jOg8+fPq6amRsXFxZHbhg0bpuLiYlVXVzvsLPqOHj2qvLw8jR8/XnfddZdOnDjhuqWYqaurU2NjY4/jGggEVFhYOOSOqyRVVlYqOztbkydP1urVq3XmzBnXLQ1IKBSSJGVmZkqSampq1NnZ2eN4TpkyRWPHjo3r4/nl/fzCq6++qqysLE2dOlVlZWU6e/asi/ai4sKFC9q2bZva2tpUVFR0WY/loBtG+mWfffaZLly4oJycnB635+Tk6C9/+YujrqKvsLBQW7Zs0eTJk9XQ0KAnn3xSN998s44cOaK0tDTX7UVdY2OjJPV6XL+4b6hYsGCBlixZooKCAh0/flw//elPVVJSourqag0fPtx1e9a6u7u1du1a3XjjjZo6daqkz49nUlKSMjIyetTG8/HsbT8l6c4779S4ceOUl5enw4cP66GHHlJtba3eeOMNh93a++ijj1RUVKT29nalpqZqx44duvbaa3Xo0KHLdiwHfQBdKUpKSiL/nj59ugoLCzVu3Dj95je/0YoVKxx2hoG6/fbbI/+eNm2apk+frgkTJqiyslJz58512Fn/lJaW6siRI3H/HuWl9LWf9957b+Tf06ZNU25urubOnavjx49rwoQJl7vNfps8ebIOHTqkUCik3/72t1q+fLmqqqouaw+D/kdwWVlZGj58+EVXYDQ1NSkYDDrqKvYyMjJ0zTXX6NixY65biYkvjt2Vdlwlafz48crKyorLY7tmzRq99dZbev/993v82ZRgMKjz58+rubm5R328Hs++9rM3hYWFkhR3xzMpKUkTJ07UzJkzVV5erhkzZuj555+/rMdy0AdQUlKSZs6cqYqKisht3d3dqqioUFFRkcPOYqu1tVXHjx9Xbm6u61ZioqCgQMFgsMdxDYfD2r9//5A+rpJ08uRJnTlzJq6OrTFGa9as0Y4dO7Rnzx4VFBT0uH/mzJlKTEzscTxra2t14sSJuDqel9rP3hw6dEiS4up49qa7u1sdHR2X91hG9ZKGGNm2bZvx+/1my5Yt5pNPPjH33nuvycjIMI2Nja5bi5of//jHprKy0tTV1Zk//OEPpri42GRlZZnTp0+7bq3fWlpazMGDB83BgweNJPPss8+agwcPmr/97W/GGGOefvppk5GRYXbt2mUOHz5sFi5caAoKCsy5c+ccd27nq/azpaXFPPDAA6a6utrU1dWZ9957z3zrW98ykyZNMu3t7a5b92z16tUmEAiYyspK09DQENnOnj0bqVm1apUZO3as2bNnjzlw4IApKioyRUVFDru2d6n9PHbsmHnqqafMgQMHTF1dndm1a5cZP368mT17tuPO7Tz88MOmqqrK1NXVmcOHD5uHH37Y+Hw+87vf/c4Yc/mOZVwEkDHGvPjii2bs2LEmKSnJzJo1y+zbt891S1G1bNkyk5uba5KSkszXvvY1s2zZMnPs2DHXbQ3I+++/byRdtC1fvtwY8/ml2I8++qjJyckxfr/fzJ0719TW1rptuh++aj/Pnj1r5s2bZ0aPHm0SExPNuHHjzMqVK+Pum6fe9k+SefnllyM1586dM/fdd58ZOXKkSUlJMYsXLzYNDQ3umu6HS+3niRMnzOzZs01mZqbx+/1m4sSJ5ic/+YkJhUJuG7f0wx/+0IwbN84kJSWZ0aNHm7lz50bCx5jLdyz5cwwAACcG/XtAAIChiQACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABO/H9OgniMcPOVpwAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Choose the target: one test image of `target_class`, drawn uniformly at random.\n",
        "# np.random was seeded in the configuration cell, so the pick is reproducible.\n",
        "# get_index presumably returns the dataset indices whose label equals the given class.\n",
        "\n",
        "poison_index = trainset.get_index(poison_class)\n",
        "target_candidates = testset.get_index(target_class)\n",
        "camou_index = trainset.get_index(camou_class)\n",
        "\n",
        "# Keep the chosen index wrapped in a list so it can be passed straight to Subset\n",
        "target_index = [np.random.choice(target_candidates)]\n",
        "\n",
        "targetset = torch.utils.data.Subset(testset, target_index)\n",
        "targetloader = torch.utils.data.DataLoader(targetset)\n",
        "\n",
        "print(\"Target image is chosen with ID {}\".format(target_index))\n",
        "# The tensor is presumably still normalized here (see inv_normalize further down), so colours may look distorted\n",
        "plt.imshow(testset[target_index[0]][0].permute(1, 2, 0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3jznSc0pQ7_F"
      },
      "outputs": [],
      "source": [
        "# Optional augmentation defenses (applied to training batches when AUGMENTS is set).\n",
        "#\n",
        "# torchvision's random transforms draw a single random decision per call, so a Compose\n",
        "# applied directly to a batched tensor flips/rotates every image in the batch the same way.\n",
        "# The wrappers below run the pipeline on each image separately, giving every sample its own\n",
        "# random flip and rotation angle.\n",
        "\n",
        "_aug_pipeline1 = transforms.Compose([\n",
        "  transforms.RandomHorizontalFlip(p=0.5),\n",
        "  transforms.RandomRotation(degrees=5)\n",
        "])\n",
        "\n",
        "_aug_pipeline = transforms.Compose([\n",
        "  transforms.RandomRotation(degrees=5),\n",
        "  transforms.RandomHorizontalFlip(p=0.5)\n",
        "])\n",
        "\n",
        "\n",
        "def _per_sample(pipeline):\n",
        "  \"\"\"Return a callable that applies `pipeline` independently to every image of a batch.\n",
        "\n",
        "  The batch is iterated along its first dimension and the transformed images are\n",
        "  stacked back into a single tensor (flip and non-expanding rotation keep each image's shape).\n",
        "  \"\"\"\n",
        "  def apply(batch):\n",
        "    return torch.stack([pipeline(img) for img in batch])\n",
        "  return apply\n",
        "\n",
        "\n",
        "aug_transform1 = _per_sample(_aug_pipeline1)\n",
        "aug_transform = _per_sample(_aug_pipeline)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "C_vxzWAdQ9Ix",
        "outputId": "aaa8bd8f-cff4-4e8b-fd03-aefd004a3cf5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--- 0.0006906986236572266 seconds ---\n"
          ]
        }
      ],
      "source": [
        "# Train the clean model, unless LOADMODEL says a saved checkpoint will be loaded instead.\n",
        "# Validation (test set + target image) runs every 10 epochs and after the final epoch.\n",
        "start_time = time.time()\n",
        "\n",
        "\n",
        "def _align(output, labels):\n",
        "  \"\"\"Shape outputs/labels for loss_fun: CIFAR2 uses one flattened output per sample and float labels.\"\"\"\n",
        "  if DATASET == 'CIFAR2':\n",
        "    return output.flatten(), labels.to(torch.float32)\n",
        "  return output, labels\n",
        "\n",
        "\n",
        "def _predict(output):\n",
        "  \"\"\"Turn outputs returned by _align into hard predictions comparable with the labels.\"\"\"\n",
        "  if DATASET == 'CIFAR2':\n",
        "    # single-output head: a negative output means class 0, otherwise class 1\n",
        "    return torch.where(output < 0, 0, 1)\n",
        "  return torch.argmax(output.data, dim=1)\n",
        "\n",
        "\n",
        "if not LOADMODEL:\n",
        "  for epoch in range(epochs):\n",
        "    train_loss = []\n",
        "    correct_preds = 0\n",
        "    total_preds = 0\n",
        "    model.train()\n",
        "    for inputs, labels, index in trainloader:\n",
        "      inputs, labels = inputs.to(device), labels.to(device)\n",
        "      if AUGMENTS:\n",
        "        inputs = aug_transform(inputs)\n",
        "\n",
        "      optimizer.zero_grad()                            # reset the gradients to zero\n",
        "      output, labels = _align(model(inputs), labels)   # generate model outputs\n",
        "      loss = loss_fun(output, labels)                  # calculate loss\n",
        "      loss.backward()                                  # compute gradients\n",
        "      optimizer.step()                                 # update parameters\n",
        "\n",
        "      predictions = _predict(output)\n",
        "      total_preds += labels.size(0)\n",
        "      correct_preds += (predictions == labels).sum().item()\n",
        "      train_loss.append(loss.item())\n",
        "\n",
        "    print(\"Training Epoch {}: Loss: {}, Accuracy: {}\".format(epoch, np.mean(train_loss), correct_preds / total_preds))\n",
        "\n",
        "    # Validation phase: every 10 epochs, plus the last epoch so the final model is always reported\n",
        "    if epoch % 10 == 0 or epoch == epochs - 1:\n",
        "      valid_losses = []\n",
        "      correct = 0\n",
        "      total = 0\n",
        "      model.eval()\n",
        "      with torch.no_grad():\n",
        "        # Evaluate model on the test set\n",
        "        for inputs, labels, index in testloader:\n",
        "          inputs, labels = inputs.to(device), labels.to(device)\n",
        "          output, labels = _align(model(inputs), labels)\n",
        "          # (a hinge embedding loss would need negated labels here: labels * -1)\n",
        "          valid_loss = loss_fun(output, labels)\n",
        "          valid_losses.append(valid_loss.item())\n",
        "\n",
        "          predictions = _predict(output)\n",
        "          total += labels.size(0)\n",
        "          correct += (predictions == labels).sum().item()\n",
        "\n",
        "        # Evaluate model on the target image\n",
        "        for inputs, labels, index in targetloader:\n",
        "          inputs, labels = inputs.to(device), labels.to(device)\n",
        "          output, labels = _align(model(inputs), labels)\n",
        "          target_loss = loss_fun(output, labels)\n",
        "          predictions = _predict(output)\n",
        "          print(\"Target Original Loss: {}, Prediction: {}, Label: {}\".format(target_loss, predictions.tolist(), labels.tolist()))\n",
        "\n",
        "      print(\"Validation Epoch {}: Valid loss: {}, Accuracy: {}\".format(epoch, np.mean(valid_losses), correct / total))\n",
        "    scheduler.step()\n",
        "\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qtqil-H-e9qe"
      },
      "outputs": [],
      "source": [
        "# Checkpoint path for the clean model; the save/load itself happens in the next cell.\n",
        "# The directory (presumably on the Colab-mounted Google Drive) is created if missing.\n",
        "# `os` is imported in the first cell; the .ptr extension is kept so existing checkpoints still load.\n",
        "MODEL_DIR = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\"\n",
        "os.makedirs(MODEL_DIR, exist_ok=True)\n",
        "PATH = os.path.join(MODEL_DIR, \"resnet_cifar.ptr\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pdexS3yNe_o7"
      },
      "outputs": [],
      "source": [
        "# Restore the clean model from PATH (LOADMODEL) and/or write its weights to PATH (SAVEMODEL).\n",
        "if LOADMODEL:\n",
        "  # map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU\n",
        "  # can also be restored in a CPU-only session (and vice versa)\n",
        "  model.load_state_dict(torch.load(PATH, map_location=device))\n",
        "  model.to(device)\n",
        "\n",
        "if SAVEMODEL:\n",
        "  torch.save(model.state_dict(), PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bwrSVA6tfJVR"
      },
      "outputs": [],
      "source": [
        "# Inverse normalization, used to display images in their original colour range.\n",
        "# Normalize maps x -> (x - m) / s per channel; undoing it is x -> x * s + m,\n",
        "# which is itself a Normalize with mean -m / s and std 1 / s.\n",
        "inv_normalize = transforms.Normalize(\n",
        "  mean=[-mean_c / std_c for mean_c, std_c in zip(data_mean, data_std)],\n",
        "  std=[1 / std_c for std_c in data_std],\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 448
        },
        "id": "_nOxB2ZqfKs6",
        "outputId": "cda0798c-19c7-4d86-873b-7e9831c9ffbb"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([3, 32, 32])\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtYUlEQVR4nO3df3CV5Z338c99n18JkAQRScgSKGgrtQqdZZVmbF0qWX7sjqOVeUbbzix2HR3d4Kyy3bbstFrd3YlrZ6xth+If28J2pkjXnaKPzqhVLGHcAl1YGWq7mxEeWvCBxMpTkpCQ8+O+r+cPS3YjoNc3JFxJeL9mzgwkV65c949zPufknPM5kXPOCQCACywOvQAAwMWJAAIABEEAAQCCIIAAAEEQQACAIAggAEAQBBAAIAgCCAAQRDb0At4rTVMdPXpUNTU1iqIo9HIAAEbOOfX29qqxsVFxfO7HOWMugI4ePaqmpqbQywAAnKcjR45o1qxZ5/z+qAXQ+vXr9Y1vfEOdnZ1auHChvvOd7+i66677wJ+rqamRJO1of01Tpkzx+l2R4S+J+Yxtk3PZnP/cOdtfNAtZ//GZjP86JEnO/9FjJUlNU5fLFdP4gYEB77GnBkq2ucv+TVInTvqvQ5J+19PnPfaUYRslqZLa9mFiOJ6nyrbjedKwzyu2ZStN/M/xSlI2zZ3L+O+TXMF43ayy3U5Myuf95zbeBuUL/uOrCrbbiVzef+5Mxn8f9vf16XOrWgZvz89lVALoRz/6kdauXasnn3xSixcv1hNPPKHly5ero6NDM2bMeN+fPf1ntylTpmjKlPdf/GmxJYCyoxdABXMAZbzHZgzrkGQLoIo1gIw3FDn/tWeztgDKGgKo7GzHvmi4sY1i/2Mp2QOoYjieKtmOZyXy3y/mAKr4XyfKxgDKZ0cvgKqqbNe36oJ/AFWNZgBV+a9DkvKjFECnfdDTKKPyIoTHH39cd911l77whS/oqquu0pNPPqlJkybp+9///mj8OgDAODTiAVQqlbR37161tLT89y+JY7W0tGjnzp1njC8Wi+rp6RlyAQBMfCMeQO+8846SJFF9ff2Qr9fX16uzs/OM8W1tbaqrqxu88AIEALg4BH8f0Lp169Td3T14OXLkSOglAQAugBF/EcL06dOVyWTU1dU15OtdXV1qaGg4Y3yhUFChUBjpZQAAxrgRfwSUz+e1aNEibdu2bfBraZpq27Ztam5uHulfBwAYp0blZdhr167V6tWr9Ud/9Ee67rrr9MQTT6ivr09f+MIXRuPXAQDGoVEJoNtuu02//e1v9eCDD6qzs1Mf//jH9eKLL57xwgQAwMVr1JoQ1qxZozVr1gx/gijj/eY+S2dc9D69RGeTyfi/wTA2vsk1k/OfO5L1jY7+b0YslxPT3KWKbXzZ8CbKsvGvwidP9XuP7ek7ZZq7ZHiDbt74PGZVtso0PjXsl6yxTSKT8z9v09R2fBLDqZIktne5ZjL+b0I2vk9YWUPLgiQVDLcTOUMDiiTlDHNnDWMlKY4saxn5scFfBQcAuDgRQACAIAggAEAQBBAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIEatiud8OZcq9a2Tif1rM0qWbhBJivxrTVLZ5q5U/NedJP61I5JUMdTrFEvlUZtbkk4Z1t5/ylYj02+onXGGyiZJqp48yXts3livEhnPlbLz34fO2ebOZi37xVb1EkX+41PjdTOVpbrHdv3J5203jYWM//hYtvPQUh/mnH99lCSdMlx/osh/GwdOFb3G8QgIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACIIAAgAEMWa74BKlSuTXa5QaKqQyka0T6lTFvyet37e77vdc4j8+MXY8GarDlFRsc5fLtvFFw/EpVoxdfYZqsnw+Z5o6texEz3P1tCg2dpMZeuzyNf4ddpKUpv
5rSYznuAy9Z5Hx5igdxZuvfM52ruRi/xMxMZ7jJcPthO2clRJD/15imLvoebvJIyAAQBAEEAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgiDFbxdN3akBRxm95ruxfJ+FfDHJ6cv/6CWeoNJGk1FCDYS1AsW2pba84ZxufGO7nxMa7RGniv89ja12Oaazt2Fu3M5Pxr3rJZg39RJKcYR8aWmEkSVHsvxezOdvNUWyoJ7JW1MSR7QBlDOd4krHNnTXs9KKxVssyulgxXNdiv2PJIyAAQBAEEAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABDEmO2C6+3tVZr6NRU5Q1dSxtAfJUmxoREsshZ8GUTGdWfinP/YjP/Yd8cbTxtDj5m1g6tYLHqPrZTKprmzhv2Sydv2iaUjzTreehZaVmI+xSNDf5jx2MuwTxJjiZ2zdt4Z+hGzxp0YWzoms7aFx0nee2xUrniPTRO/6xqPgAAAQYx4AH39619XFEVDLvPnzx/pXwMAGOdG5U9wH/vYx/TKK6/89y/Jjtm/9AEAAhmVZMhms2poaBiNqQEAE8SoPAf05ptvqrGxUfPmzdPnP/95HT58+Jxji8Wienp6hlwAABPfiAfQ4sWLtWnTJr344ovasGGDDh06pE996lPq7e096/i2tjbV1dUNXpqamkZ6SQCAMShyzvhZtUYnTpzQnDlz9Pjjj+vOO+884/vFYnHIS2l7enrU1NSk//3CTzR58mSv38HLsM/Ey7DPxMuwzzG3YSwvwz47y8uwrXOXDS/DHjBuZznxHztgeBl2f99J/a8/+4S6u7tVW1t7znGj/uqAqVOn6iMf+YgOHDhw1u8XCgUVCoXRXgYAYIwZ9fcBnTx5UgcPHtTMmTNH+1cBAMaREQ+gL37xi2pvb9evf/1r/exnP9NnPvMZZTIZffaznx3pXwUAGMdG/E9wb731lj772c/q+PHjuuyyy/TJT35Su3bt0mWXXWaap7+/JEV+f4OPDE9jWWswLM8ZZQ3PdUhStuD/HEO+4F+ZIUmFrP/4TNa27pxxOy1PHLjU9tyI5XmaJG/czpz/3FnDWEmKM8bngAzP1DhLdYukyHA/1LZqSfJ/ksH6ZLRlMyuetV6D42Pb+MSwmNR4t98ZboMyxscUFef/vI5LDMfSc+yIB9CWLVtGekoAwAREFxwAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIAggAEAQBBAAIAgCCAAQxKh/HMNwlUuJylm/PiFLNZk1cXM5/12UNYyVpFzWf7y1fy2O/bupchlbC5ex9kwZS++Z8TNhJkd+nxn17jqsp7vhs6CMJWmx9fOADL/AGT6bRpIqhg+FSSqGD5CRZPm0sSQxtsEZNjOTs51XieGzbyQpNXTeJcYPBLLsl4pxHzpDR14+538bVPYcyyMgAEAQBBAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIAggAEAQBBAAIIgxW8WTJiWlid/ysob6lji2bXJsyOg4slVsRIb6DstQScoVCt5j88aakoJxfC7v392TNdblRJF/PUgU2+qMMpH/WowNQqNaxZMa61iKcdl77ICxRqaSGtZtrTMybKYzXn8iS8/Puz/gzXjoZdnjptorSZGh4iuTGKp4yn7XeR4BAQCCIIAAAEEQQACAIAggAEAQBBAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIMZsF1zkEkW+BU7Ov//I2pWUzfpndD5n6xorVOe9x1bl/PvUJKmQ9z+0BePcecPckpQx7MNcxnafKE0tvWe2QjDLSjLG+3IZY3mcZXhq7DGLDN1+kfEmo1jxPz5RYjs+pYr/+Ei2fjzr7UTB0GHojH16maz/2hNrh51hfGpYdpLQBQcAGMMIIABAEAQQACAIAggAEAQBBAAIggACAARBAA
EAgiCAAABBEEAAgCAIIABAEAQQACCIMdsFF2ezirN+y4sMvU22FiYpNXRIpc7WN1WqVLzHdp/4nWnuaZdM9R5bNe0S09zWfejKJe+xSaVsmjsy3Yey3d+KY0O3X852VYpl6w2MIv/xUWzrA8tk/efOGs/xiqH3LLF2pBm2M5M3HnvjTWOa+vfSVQzX+3d/wNB5Z7xyJobbt9iwv2PP3c0jIABAEOYA2rFjh2666SY1NjYqiiI988wzQ77vnNODDz6omTNnqrq6Wi0tLXrzzTdHar0AgAnCHEB9fX1auHCh1q9ff9bvP/bYY/r2t7+tJ598Urt379bkyZO1fPlyDQwMnPdiAQATh/k5oJUrV2rlypVn/Z5zTk888YS++tWv6uabb5Yk/eAHP1B9fb2eeeYZ3X777ee3WgDAhDGizwEdOnRInZ2damlpGfxaXV2dFi9erJ07d571Z4rFonp6eoZcAAAT34gGUGdnpySpvr5+yNfr6+sHv/debW1tqqurG7w0NTWN5JIAAGNU8FfBrVu3Tt3d3YOXI0eOhF4SAOACGNEAamhokCR1dXUN+XpXV9fg996rUCiotrZ2yAUAMPGNaADNnTtXDQ0N2rZt2+DXenp6tHv3bjU3N4/krwIAjHPmV8GdPHlSBw4cGPz/oUOHtG/fPk2bNk2zZ8/W/fffr7//+7/Xhz/8Yc2dO1df+9rX1NjYqFtuuWUk1w0AGOfMAbRnzx59+tOfHvz/2rVrJUmrV6/Wpk2b9KUvfUl9fX26++67deLECX3yk5/Uiy++qKqqKtPvKSapMolfr0Sp4t8/USzbql6KRf9dVJpsqykp/s7/vVF79vzcNPdHrrjce+zHF15jmtvY9KKMoe4jk7FV1OTzBf+545xpbkNDjbmeqGLYJ5KUk6FuyvN6M7gWQ9VLqWirkSkaqpWS1LZPUsM+MdUqSYaZT/Pf55HxbIkM/TqZjLVyyDDet19HUj7nN9YcQEuWLJF7nz6oKIr0yCOP6JFHHrFODQC4iAR/FRwA4OJEAAEAgiCAAABBEEAAgCAIIABAEAQQACAIAggAEAQBBAAIggACAARBAAEAgjBX8VwoJ/sHlPjmY+LfZWWsGlM+57+LbC1Z0sn+Pu+xv+vpNc19/ET3qIyVpKp83jjef6e7Usk0dz6x9Mz5nyeSFMf+xz4u2drDssbOrkzGfy3v05R1Vqlht1h75pLU/1rhzA1s/uMrZdu1MzF29aWG26DIOLepr814XkWGfjdF/tfjnOcNLY+AAABBEEAAgCAIIABAEAQQACAIAggAEAQBBAAIggACAARBAAEAgiCAAABBEEAAgCDGbBVPeWBA2cizasP512A4Y+RmXM577ICz1ZT09ZzwHjtlkq3+pmayoXPIDZjmnjypYBpfyPpXjwwUbZUpkaFbqbu/bJp7oOg/Put/mrw73rBPJKlg+AW5rO1qnc/7zx1ljRVCMsxtvD9cSQ370Fgh5IzjLWs3NiXJshJnqKayjnfyv50tlf3G8ggIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACIIAAgAEMWa74CZX5VRd5dd/5lL//rDY1KwklQZOeo+NElsh2JS8Z9edpI9f9SHT3HPmNHiPrSr4dzxJUlXG1h1XkP92Fow9c0e6Or3HvtNjO/ZJPMl7bCZjuy9XZbzmFXL+P1BdbduHVRX/45839sxZJMb+Ncv9Z2ctYDOcs++ONnSqWTrsJKWWuWXogPz9T/iqGG47KwldcACAMYwAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACIIAAgAEMWareFJJqWcbRpTxr83wrYg4rVgueY+dMsVWgdJw6SXeY2fU15rm7jv5/7zHlk6dMs3tcraakpNl/6qk3r6yae6ek/5zlwdsp3tk2M6kYpu7WLFVpiQl/xqUcsl2jvdHhqokQyWQJGXz/uMHikXT3HHkvw
8jY0VNZKxWiuR/HkaxrXLIRYa1xMYqHsM+tBQIVcp+dV08AgIABEEAAQCCMAfQjh07dNNNN6mxsVFRFOmZZ54Z8v077rhDURQNuaxYsWKk1gsAmCDMAdTX16eFCxdq/fr15xyzYsUKHTt2bPDy1FNPndciAQATj/lFCCtXrtTKlSvfd0yhUFBDg//n0QAALj6j8hzQ9u3bNWPGDF155ZW69957dfz48XOOLRaL6unpGXIBAEx8Ix5AK1as0A9+8ANt27ZN//iP/6j29natXLlSyTle/tzW1qa6urrBS1NT00gvCQAwBo34+4Buv/32wX9fc801WrBggS6//HJt375dS5cuPWP8unXrtHbt2sH/9/T0EEIAcBEY9Zdhz5s3T9OnT9eBAwfO+v1CoaDa2tohFwDAxDfqAfTWW2/p+PHjmjlz5mj/KgDAOGL+E9zJkyeHPJo5dOiQ9u3bp2nTpmnatGl6+OGHtWrVKjU0NOjgwYP60pe+pCuuuELLly8f0YUDAMY3cwDt2bNHn/70pwf/f/r5m9WrV2vDhg3av3+//vmf/1knTpxQY2Ojli1bpr/7u79ToWDrSTtZ7Fcl9msfqpT8eockqafHvyNNkrp/d+5X8L3Xod/YOrgurZvsPXb27HrT3KXiSe+xhULeNHchaxtfXV3jPbbrtydMc5eSnPfYvgFLm5WUq67zHltl2EZJyuVs+zCf97/+TDZe17KG3rOBiu0cz6X+3X6lkq0LLpvx34fZrP95IkmVAdt2ZrL+fXppZJs7tZy2hm63d/kfe0uDXd+A322yOYCWLFki5869R1566SXrlACAixBdcACAIAggAEAQBBAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIAggAEAQI/55QCOl4iqquIrX2O6T3d7z/uxnr5nW8X8Onf1jJM7GpZa2JMm/PUqaVG3reKqu8j+0GUMXmCRlI1uv1iRDT1oq23aWK/57sWir4FIu59/VN8X4MSKZnHWf+++XqVOnmuaeVKj2HmvtdCxM8l93arz+TJ40xXtsdbX/sXx3LbbewOkz/LsaS2W/27XTKoa1TDLsE0nKGDrysjn/saf6+7zG8QgIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACGLMVvEk5bKSrN/y0tS/2uLUKb+KiNP6+nq8x2aztpqSxFCx0XdqwDR3LEOtia11RNkobxqfyb7tPTY23ifK5fz3uaV25PRqfJ04bqzWydrGW/bL/41sc2di/5sB6/FxWf/zME1sVTz5fJX3WOuxd8brRE3tNO+xibOUcEmK/fd5VWGSaeqqav/xhUn+dUalYtFrHI+AAABBEEAAgCAIIABAEAQQACAIAggAEAQBBAAIggACAARBAAEAgiCAAABBEEAAgCAIIABAEGO2C66/t1dJuew3uJJ4z1tbU2Nax6SCf9+UEmOBVGzpvrL1R2Vi//EZ42mQNfZq5TP+Y3PGjrR8zn87s6b9LeUM644N+1uSMsZaunzG/xhFiW0tkeF+qPEMV2rYL86ywyUp9bx9kJSW/MdK9qty5Z1+/7UYr8vOML7XOHev5dhH/senUvHr5+QREAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABDEmK3i+e1vjylfyHuNzRrqPiZVGap1JF16yTTvsVHqVz9xWl/PCe+xaWq7r5CJ/cdnItvcOcPckpQ31M5YKoQkW11OLmObO5v1n7yQM9YTZf3O7dNy2YL32GxkvVqP3v1Q22lrW0dqaFZyzlbD5KydQ4b5LeuWJMtw/1Ky00anhsm3RY1HQACAIEwB1NbWpmuvvVY1NTWaMWOGbrnlFnV0dAwZMzAwoNbWVl166aWaMmWKVq1apa6urhFdNABg/DMFUHt7u1pbW7Vr1y69/PLLKpfLWrZsmfr6+gbHPPDAA3ruuef09NNPq7
29XUePHtWtt9464gsHAIxvpj8Wv/jii0P+v2nTJs2YMUN79+7VDTfcoO7ubn3ve9/T5s2bdeONN0qSNm7cqI9+9KPatWuXPvGJT4zcygEA49p5PQfU3d0tSZo27d0n6vfu3atyuayWlpbBMfPnz9fs2bO1c+fOs85RLBbV09Mz5AIAmPiGHUBpmur+++/X9ddfr6uvvlqS1NnZqXw+r6lTpw4ZW19fr87OzrPO09bWprq6usFLU1PTcJcEABhHhh1Ara2teuONN7Rly5bzWsC6devU3d09eDly5Mh5zQcAGB+G9T6gNWvW6Pnnn9eOHTs0a9aswa83NDSoVCrpxIkTQx4FdXV1qaGh4axzFQoFFQr+73EAAEwMpkdAzjmtWbNGW7du1auvvqq5c+cO+f6iRYuUy+W0bdu2wa91dHTo8OHDam5uHpkVAwAmBNMjoNbWVm3evFnPPvusampqBp/XqaurU3V1terq6nTnnXdq7dq1mjZtmmpra3XfffepubmZV8ABAIYwBdCGDRskSUuWLBny9Y0bN+qOO+6QJH3zm99UHMdatWqVisWili9fru9+97sjslgAwMQROWduPRpVPT09qqur02e/cJvyeb++rFzWv4dr4NQp03pKxaL3WGPVmAZO9X3woEG2Aqlc7N9jljU+E5jL2HrMbFVzxp1o6bwzdtjFGcM+jG1dcLnYttOzGf/xkbFPLzb8Jd5YY2Y6Pi4dvZuiNLW1pKXWm0VLL51tZlN3nPX4RIbbiTT1P69KpZK+v3Gjuru7VVtbe85xdMEBAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQQzr4xguhKRcUiXyK63IGuokqidNMq1jsmG8rXJGukSXeI+11+X4V8NU5Y21MDlrjYyhGiay1cikzlAl4oxzG+6fWepsJClrrOKx1E2ZWap7jE1Jlh9IEluRTGSYO3W2ua0NZc5w3paN25kY1lKx1hlF/tcfyy4pDgx4jeMREAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABAEAQQACGLMdsGd6h9QpZx4jc3Iv88oV8ib1hFnRm8XpfLvhHLW+wqGfq8oY5s7YxwfZQzbmRrLxgzDK2Xb1M4weWTsx0uN51WSsXR2GfvADNLENrffNfj3Yyu2uWPDaZja6teUGjvV4pz/8Smmlr0iVQz7xdIbJ0mJKt5jMzn/286iZzEmj4AAAEEQQACAIAggAEAQBBAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIMZsFU+pf0Au51cTMWCoTLFWbLhq/w6PjHV3GuI/NlaJxIYqHidjfYex1yRx/tUjlnoVSaok/ts5ULat21I5FBnvyyXGSpuyodQmY6jtkWSqM7Jysf9aKqbiHtt1wlqplThbb1Na8l9MLsqZ5rYsPXW2g+kMVzjL1T51fovmERAAIAgCCAAQBAEEAAiCAAIABEEAAQCCIIAAAEEQQACAIAggAEAQBBAAIAgCCAAQBAEEAAhizHbBlYsDconf8iJDmVXF2Hvm4irvsYVqW55Hhm6lsmw9ZkVDlVWkvGnuJLF1djlDiVSUsc3t2zklSeWK8djLv8csKRVNc1v717JZQ6daxXi/MjIsxtn2oaWbLLKsQ1LFsJaK8fhY6/Fiy26JbF19lg62xLjyXM7/uh8Z9nfG+S2aR0AAgCBMAdTW1qZrr71WNTU1mjFjhm655RZ1dHQMGbNkyRJFUTTkcs8994zoogEA458pgNrb29Xa2qpdu3bp5ZdfVrlc1rJly9TX1zdk3F133aVjx44NXh577LERXTQAYPwzPQf04osvDvn/pk2bNGPGDO3du1c33HDD4NcnTZqkhoaGkVkhAGBCOq/ngLq7uyVJ06ZNG/L1H/7wh5o+fbquvvpqrVu3Tv39/eeco1gsqqenZ8gFADDxDftVcGma6v7779
f111+vq6++evDrn/vc5zRnzhw1NjZq//79+vKXv6yOjg79+Mc/Pus8bW1tevjhh4e7DADAODXsAGptbdUbb7yh1157bcjX77777sF/X3PNNZo5c6aWLl2qgwcP6vLLLz9jnnXr1mnt2rWD/+/p6VFTU9NwlwUAGCeGFUBr1qzR888/rx07dmjWrFnvO3bx4sWSpAMHDpw1gAqFggqFwnCWAQAYx0wB5JzTfffdp61bt2r79u2aO3fuB/7Mvn37JEkzZ84c1gIBABOTKYBaW1u1efNmPfvss6qpqVFnZ6ckqa6uTtXV1Tp48KA2b96sP/3TP9Wll16q/fv364EHHtANN9ygBQsWjMoGAADGJ1MAbdiwQdK7bzb9nzZu3Kg77rhD+Xxer7zyip544gn19fWpqalJq1at0le/+tURWzAAYGIw/wnu/TQ1Nam9vf28FnRauVTy7hCLY0O3Usb4yvOsf7eSra1NyuT8152PbevuTyreY5OK/1hJymRsXVa52L9vyrNC6n/+xCiNlQZK5377wHtFZds+yeRsT7+mif/xj43niqX5rGw8V8pl/26/2HjdzMT++9B6jkexrVPNchvkjJ2EFUMZnKV7T5LiTM5/rKGPsOTZvUcXHAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABDEsD8PaLRVyhXpA6p/TosN1TDOWLHhDA0rtoINKS//GgxrhVDGUMcS2XaJssa6HJf1r0GJnGGfSErSkvfYNLUdIctmxrFt7tjWDKOyZbDn9ea0KPI/V1LjWV4umVY+aiqJbR2VivEkzxqqkgz1N5KUpv5zJ8ZjnxiuE9mc/7rLRb/rJY+AAABBEEAAgCAIIABAEAQQACAIAggAEAQBBAAIggACAARBAAEAgiCAAABBEEAAgCAIIABAEGO2C66UJEp9O8rKhp4nY6daVBq9TrUo8u9hivK2/qhM1tKPZ9sn5TQxja/417Upjm1zWxq7nLELTs7QkWboDJSkcsW2naPJdNoaz5WSoYPN2L6mjOEKVyrZyvf6T50yjS8bOtjiTN40d2zpdTTepGcM/W6S/zlb8Ty/eQQEAAiCAAIABEEAAQCCIIAAAEEQQACAIAggAEAQBBAAIAgCCAAQBAEEAAiCAAIABDFmq3iStCIlfvUWUWLI0ZKtLyfO+u+iKGfL8zj1rwcpV2zrdvKvBilZqoyGITZUpmRiW6dN6t3XJBnaUn7PUsNk6BuSFGdt1UqxpTAnNnZCGaRlW6VNxVB/FOdsxz4xHNDEWPSTMVZfWaaPYuvNrv95WCzazkNLjVliqI+qeM7LIyAAQBAEEAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABDEmO2CS5PEu/0qSfw7imJj11il4t99FZVtHVzljP94Q52aJClJ/Due0tTWk1U2dsdlM/69Wrlc3jS3nP+OSZ2xqy8ynCvWAxTbOrsysf/acznb1Tpn6KWznSm2DrakbC3r8x9fSWwrT1P/2xRJKlf858/lbedhVaHKe2zG2KeXGvZ5MeN/zrrIb3/wCAgAEIQpgDZs2KAFCxaotrZWtbW1am5u1gsvvDD4/YGBAbW2turSSy/VlClTtGrVKnV1dY34ogEA458pgGbNmqVHH31Ue/fu1Z49e3TjjTfq5ptv1i9/+UtJ0gMPPKDnnntOTz/9tNrb23X06FHdeuuto7JwAMD4Zvpj8U033TTk///wD/+gDRs2aNeuXZo1a5a+973vafPmzbrxxhslSRs3btRHP/pR7dq1S5/4xCdGbtUAgHFv2M8BJUmiLVu2qK+vT83Nzdq7d6/K5bJaWloGx8yfP1+zZ8/Wzp07zzlPsVhUT0/PkAsAYOIzB9AvfvELTZkyRYVCQffcc4+2bt2qq666Sp2dncrn85o6deqQ8fX19ers7DznfG1tbaqrqxu8NDU1mTcCADD+mAPoyi
uv1L59+7R7927de++9Wr16tX71q18NewHr1q1Td3f34OXIkSPDngsAMH6Y3weUz+d1xRVXSJIWLVqkf//3f9e3vvUt3XbbbSqVSjpx4sSQR0FdXV1qaGg453yFQkGFQsG+cgDAuHbe7wNK01TFYlGLFi1SLpfTtm3bBr/X0dGhw4cPq7m5+Xx/DQBggjE9Alq3bp1Wrlyp2bNnq7e3V5s3b9b27dv10ksvqa6uTnfeeafWrl2radOmqba2Vvfdd5+am5t5BRwA4AymAHr77bf153/+5zp27Jjq6uq0YMECvfTSS/qTP/kTSdI3v/lNxXGsVatWqVgsavny5frud787rIVVKomcb0tEZKjLMVSaSFJcMtRPWB9PGlozLFUsVomxisda3TNQ8a/uGSjaan4iwylcTvzPE0lKDfUqMh6fbGb0WrAi46mSN1S9ZLO2dQ+U/I+ntYgnjv3rjwaKRdPcReN4Zzj+1dWTTHOXCv5ryRqrrHI5/6c/sqn/Nvruj8g575v5C6Knp0d1dXW66qp5ymT8bqGzOUPXWMF2gPJ5//HZKuPc1f7rHs3nyawBZOmZk6TUcOJaEUBnIoDOso5xHEDVhWrvsaMZQJWy//WnXCrphc1Pqbu7W7W1teccRxccACAIAggAEAQBBAAIggACAARBAAEAgiCAAABBEEAAgCAIIABAEAQQACCI0Xs79jCdLmZIEv93oVvqdeJKYlpPJTaMr9jeaR+X/d/JHceG3h4jcxWPtVFgVJsQ/FXGUBOCbLvcxNqEEMWGtgJjcUqlPDaaECzv4h/OeEsTQiVraxIpx/430+ZeGzc6+7D8++P+QefLmAug3t5eSVJHx6/DLgQAcF56e3tVV1d3zu+PuS64NE119OhR1dTUKIr+O517enrU1NSkI0eOvG+30HjHdk4cF8M2SmznRDMS2+mcU29vrxobGxW/z6PDMfcIKI5jzZo165zfr62tndAH/zS2c+K4GLZRYjsnmvPdzvd75HMaL0IAAARBAAEAghg3AVQoFPTQQw+N6ufijAVs58RxMWyjxHZONBdyO8fcixAAABeHcfMICAAwsRBAAIAgCCAAQBAEEAAgiHETQOvXr9eHPvQhVVVVafHixfr5z38eekkj6utf/7qiKBpymT9/fuhlnZcdO3bopptuUmNjo6Io0jPPPDPk+845Pfjgg5o5c6aqq6vV0tKiN998M8xiz8MHbecdd9xxxrFdsWJFmMUOU1tbm6699lrV1NRoxowZuuWWW9TR0TFkzMDAgFpbW3XppZdqypQpWrVqlbq6ugKteHh8tnPJkiVnHM977rkn0IqHZ8OGDVqwYMHgm02bm5v1wgsvDH7/Qh3LcRFAP/rRj7R27Vo99NBD+o//+A8tXLhQy5cv19tvvx16aSPqYx/7mI4dOzZ4ee2110Iv6bz09fVp4cKFWr9+/Vm//9hjj+nb3/62nnzySe3evVuTJ0/W8uXLNTAwcIFXen4+aDslacWKFUOO7VNPPXUBV3j+2tvb1draql27dunll19WuVzWsmXL1NfXNzjmgQce0HPPPaenn35a7e3tOnr0qG699daAq7bz2U5Juuuuu4Ycz8ceeyzQiodn1qxZevTRR7V3717t2bNHN954o26++Wb98pe/lHQBj6UbB6677jrX2to6+P8kSVxjY6Nra2sLuKqR9dBDD7mFCxeGXsaokeS2bt06+P80TV1DQ4P7xje+Mfi1EydOuEKh4J566qkAKxwZ791O55xbvXq1u/nmm4OsZ7S8/fbbTpJrb293zr177HK5nHv66acHx/znf/6nk+R27twZapnn7b3b6Zxzf/zHf+z+6q/+KtyiRskll1zi/umf/umCHssx/wioVCpp7969amlpGfxaHMdqaWnRzp07A65s5L355ptqbGzUvHnz9PnPf16HDx8OvaRRc+jQIXV2dg45rnV1dVq8ePGEO66StH37ds2YMUNXXnml7r33Xh0/fjz0ks5Ld3e3JG
natGmSpL1796pcLg85nvPnz9fs2bPH9fF873ae9sMf/lDTp0/X1VdfrXXr1qm/vz/E8kZEkiTasmWL+vr61NzcfEGP5ZgrI32vd955R0mSqL6+fsjX6+vr9V//9V+BVjXyFi9erE2bNunKK6/UsWPH9PDDD+tTn/qU3njjDdXU1IRe3ojr7OyUpLMe19PfmyhWrFihW2+9VXPnztXBgwf1t3/7t1q5cqV27typTGb0PudptKRpqvvvv1/XX3+9rr76aknvHs98Pq+pU6cOGTuej+fZtlOSPve5z2nOnDlqbGzU/v379eUvf1kdHR368Y9/HHC1dr/4xS/U3NysgYEBTZkyRVu3btVVV12lffv2XbBjOeYD6GKxcuXKwX8vWLBAixcv1pw5c/Qv//IvuvPOOwOuDOfr9ttvH/z3NddcowULFujyyy/X9u3btXTp0oArG57W1la98cYb4/45yg9yru28++67B/99zTXXaObMmVq6dKkOHjyoyy+//EIvc9iuvPJK7du3T93d3frXf/1XrV69Wu3t7Rd0DWP+T3DTp09XJpM54xUYXV1damhoCLSq0Td16lR95CMf0YEDB0IvZVScPnYX23GVpHnz5mn69Onj8tiuWbNGzz//vH76058O+diUhoYGlUolnThxYsj48Xo8z7WdZ7N48WJJGnfHM5/P64orrtCiRYvU1tamhQsX6lvf+tYFPZZjPoDy+bwWLVqkbdu2DX4tTVNt27ZNzc3NAVc2uk6ePKmDBw9q5syZoZcyKubOnauGhoYhx7Wnp0e7d++e0MdVkt566y0dP358XB1b55zWrFmjrVu36tVXX9XcuXOHfH/RokXK5XJDjmdHR4cOHz48ro7nB23n2ezbt0+SxtXxPJs0TVUsFi/ssRzRlzSMki1btrhCoeA2bdrkfvWrX7m7777bTZ061XV2doZe2oj567/+a7d9+3Z36NAh92//9m+upaXFTZ8+3b399tuhlzZsvb297vXXX3evv/66k+Qef/xx9/rrr7vf/OY3zjnnHn30UTd16lT37LPPuv3797ubb77ZzZ071506dSrwym3ebzt7e3vdF7/4Rbdz50536NAh98orr7g//MM/dB/+8IfdwMBA6KV7u/fee11dXZ3bvn27O3bs2OClv79/cMw999zjZs+e7V599VW3Z88e19zc7JqbmwOu2u6DtvPAgQPukUcecXv27HGHDh1yzz77rJs3b5674YYbAq/c5itf+Yprb293hw4dcvv373df+cpXXBRF7ic/+Ylz7sIdy3ERQM45953vfMfNnj3b5fN5d91117ldu3aFXtKIuu2229zMmTNdPp93f/AHf+Buu+02d+DAgdDLOi8//elPnaQzLqtXr3bOvftS7K997Wuuvr7eFQoFt3TpUtfR0RF20cPwftvZ39/vli1b5i677DKXy+XcnDlz3F133TXu7jydbfskuY0bNw6OOXXqlPvLv/xLd8kll7hJkya5z3zmM+7YsWPhFj0MH7Sdhw8fdjfccIObNm2aKxQK7oorrnB/8zd/47q7u8Mu3Ogv/uIv3Jw5c1w+n3eXXXaZW7p06WD4OHfhjiUfxwAACGLMPwcEAJiYCCAAQBAEEAAgCAIIABAEAQQACIIAAgAEQQABAIIggAAAQRBAAIAgCCAAQBAEEAAgCAIIABDE/wcfJFzkP1q2lgAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Show Target Image:\n",
        "print(testset[target_index[0]][0].shape)\n",
        "plt.imshow(inv_normalize(testset[target_index[0]][0]).permute(1, 2, 0))\n",
        "target_image = torch.utils.data.Subset(testset, indices=target_index)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1U5Ima2i801g"
      },
      "outputs": [],
      "source": [
        "# Retraining 1: build a fresh ResNet-18 for the clean model and move it to `device`.\n",
        "# NOTE(review): `ResNet18` and `device` come from earlier cells.\n",
        "model2 = ResNet18().to(device)\n",
        "\n",
        "# Truthy -> the next cell skips its training loop (`if not LOADMODEL:`).\n",
        "LOADMODEL = 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nKnWHOv08w4E",
        "outputId": "3a869976-53ec-410e-d4ff-976ca4690569"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--- 0.0007696151733398438 seconds ---\n"
          ]
        }
      ],
      "source": [
        "# Train the clean model from scratch (skipped when LOADMODEL is truthy).\n",
        "# NOTE(review): relies on eta, epochs, trainloader, testloader, targetloader, loss_fun,\n",
        "# AUGMENTS, aug_transform and DATASET defined in earlier cells.\n",
        "start_time = time.time()\n",
        "optimizer = torch.optim.SGD(params = model2.parameters(), lr = eta, weight_decay = 5e-4, momentum=0.9)\n",
        "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)\n",
        "device = torch.device('cuda')\n",
        "\n",
        "if not LOADMODEL:\n",
        "  for epoch in range(epochs):\n",
        "    train_loss = []\n",
        "\n",
        "    correct_preds = 0\n",
        "    total_preds = 0\n",
        "    model2.train()\n",
        "    for inputs, labels, index in trainloader:\n",
        "\n",
        "      inputs, labels = inputs.to(device), labels.to(device)\n",
        "      if AUGMENTS:\n",
        "        inputs = aug_transform(inputs)\n",
        "\n",
        "      optimizer.zero_grad()            # reset the gradients to zero\n",
        "      output = model2(inputs)            # Generate model outputs\n",
        "\n",
        "      if DATASET == 'CIFAR2':\n",
        "        labels = labels.to(torch.float32)\n",
        "        output = output.flatten()\n",
        "      loss = loss_fun(output, labels)   # Calculate loss\n",
        "\n",
        "      loss.backward()            # Compute gradients\n",
        "      optimizer.step()            # update parameters,\n",
        "\n",
        "      # For CIFAR2 the output is flattened to 1-D, so argmax(dim=1) would raise;\n",
        "      # threshold at 0 exactly like the validation code below.\n",
        "      if DATASET == 'CIFAR2':\n",
        "        predictions = torch.where(output.detach() < 0, 0, 1)\n",
        "      else:\n",
        "        predictions = torch.argmax(output.detach(), dim=1)\n",
        "\n",
        "      total_preds += labels.size(0)\n",
        "      correct_preds += (predictions == labels).sum().item()\n",
        "\n",
        "      train_loss.append(loss.item())\n",
        "\n",
        "    print(\"Training Epoch {}: Loss: {}, Accuracy: {}\".format(epoch, np.mean(train_loss), correct_preds / total_preds))\n",
        "\n",
        "    # validation phase - once every 10 epochs\n",
        "    if epoch % 10 == 0:\n",
        "      valid_losses = []\n",
        "      correct = 0\n",
        "      total = 0\n",
        "      model2.eval()\n",
        "\n",
        "      # Evaluate Model\n",
        "      for inputs, labels, index in testloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "\n",
        "        with torch.no_grad():\n",
        "          output = model2(inputs)\n",
        "          if DATASET == 'CIFAR2':\n",
        "            labels = labels.to(torch.float32)\n",
        "            output = output.flatten()\n",
        "\n",
        "          valid_loss = loss_fun(output, labels)\n",
        "          valid_losses.append(valid_loss.item())\n",
        "\n",
        "          if DATASET == 'CIFAR2':\n",
        "            predictions = torch.where(output < 0, 0, 1)\n",
        "          else:\n",
        "            predictions = torch.argmax(output, dim=1)\n",
        "          total += labels.size(0)\n",
        "          correct += (predictions == labels).sum().item()\n",
        "\n",
        "      # Evaluate Model on Target\n",
        "      for inputs, labels, index in targetloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        with torch.no_grad():\n",
        "          output = model2(inputs)\n",
        "          if DATASET == 'CIFAR2':\n",
        "            labels = labels.to(torch.float32)\n",
        "            output = output.flatten()\n",
        "          target_loss = loss_fun(output, labels)\n",
        "          print(\"Target Original Loss: {}\".format(target_loss))\n",
        "\n",
        "      print(\"Validation Epoch {}: Valid loss: {}, Accuracy: {}\".format(epoch, np.mean(valid_losses), correct / total))\n",
        "    scheduler.step()\n",
        "\n",
        "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
        "\n",
        "# The clean model is saved / loaded in the next cell.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "idSgQviX85nc"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "# Checkpoint location for the clean model.\n",
        "PATH = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\"\n",
        "os.makedirs(PATH, exist_ok=True)\n",
        "PATH += \"/resnet_cifar.ptr\"\n",
        "\n",
        "SAVEMODEL2 = 0\n",
        "LOADMODEL2 = 1\n",
        "if LOADMODEL2:\n",
        "  # map_location lets a checkpoint written on a different device load onto `device`.\n",
        "  model2.load_state_dict(torch.load(PATH, map_location=device))\n",
        "  model2.to(device)\n",
        "\n",
        "if SAVEMODEL2:\n",
        "  torch.save(model2.state_dict(), PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xhn4eZCMfOmc"
      },
      "outputs": [],
      "source": [
        "# Bullseye Polytope parameters\n",
        "budget = 5  # number of poisoned (base) images\n",
        "# per-channel (3-tuple) CIFAR mean / std\n",
        "cifar_mean = (0.4914, 0.4822, 0.4465)\n",
        "cifar_std = (0.2023, 0.1994, 0.2010)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ERRrLtAajMLn"
      },
      "outputs": [],
      "source": [
        "# Gradient-matching attack parameters\n",
        "camou_budget = 5        # number of camouflage images\n",
        "R = 2                   # number of restarts\n",
        "epsilon = 16            # perturbation bound\n",
        "attackiter = 251        # optimisation steps\n",
        "loss_opt = sys.maxsize  # best loss found so far (starts very large)\n",
        "delta_opt = 0           # best perturbation found so far\n",
        "poison_opt = 'adam'     # optimizer name for poison crafting (not an image set)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sIQp7Rwrxi1U"
      },
      "outputs": [],
      "source": [
        "def proj_onto_simplex(coeffs, psum=1.0):\n",
        "    \"\"\"\n",
        "    Euclidean projection of `coeffs` onto the simplex {w : w >= 0, sum(w) = psum}.\n",
        "\n",
        "    Adapted from https://github.com/hsnamkoong/robustopt/blob/master/src/simple_projections.py\n",
        "    (the probability simplex when psum == 1). The work is done in NumPy on a flattened,\n",
        "    detached CPU copy; the result is reshaped to `coeffs.size()` and moved to `coeffs.device`.\n",
        "\n",
        "    Returns:\n",
        "        A new `torch.Tensor` (default float dtype); no gradient flows back to `coeffs`.\n",
        "    \"\"\"\n",
        "    flat = coeffs.view(-1).detach().cpu().numpy()\n",
        "    descending = np.flip(np.sort(flat))\n",
        "    # running sum of the largest entries minus the required total\n",
        "    excess = np.cumsum(descending) - psum\n",
        "    ranks = np.arange(1, flat.shape[0] + 1)\n",
        "    # rho = largest rank whose sorted entry stays positive after removing its share of the excess\n",
        "    support = descending - excess / ranks > 0\n",
        "    rho = ranks[support][-1]\n",
        "    shift = excess[support][-1] / float(rho)\n",
        "    projected = np.maximum(flat - shift, 0)\n",
        "    return torch.Tensor(projected.reshape(coeffs.size())).to(coeffs.device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dlPuErHLxveh"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "waPfbKeixdKN"
      },
      "outputs": [],
      "source": [
        "def least_squares_simplex(A, b, x_init, tol=1e-6, verbose=False, device='cuda'):\n",
        "    \"\"\"\n",
        "    Minimise ||A x - b|| over x on the probability simplex (inner loop of Algorithm 1).\n",
        "\n",
        "    Projected gradient descent. A forward gradient step on the least-squares objective is\n",
        "    accepted only if it does not increase the objective; otherwise the step size is halved.\n",
        "    Accepted steps are projected with `proj_onto_simplex`.\n",
        "\n",
        "    Args:\n",
        "        A: 2-D tensor of shape (m, n).\n",
        "        b: tensor whose first dimension is m (multiplied against A, e.g. an (m, 1) column).\n",
        "        x_init: starting point, or None to start from an (n, 1) zero vector on `device`.\n",
        "        tol: stop once the relative change of an accepted step falls below this value.\n",
        "        verbose: print the relative change of every accepted step.\n",
        "        device: device for the zero start and the random probe vector.\n",
        "\n",
        "    Returns:\n",
        "        The last accepted iterate, after convergence or at most 10000 iterations.\n",
        "    \"\"\"\n",
        "    rows, cols = A.size()\n",
        "    assert b.size()[0] == A.size()[0], 'Matrix and vector do not have compatible dimensions'\n",
        "\n",
        "    x = torch.zeros(cols, 1).to(device) if x_init is None else x_init\n",
        "\n",
        "    # A^T A and A^T b are precomputed, which is cheaper per step when A is tall.\n",
        "    AtA = A.t().mm(A)\n",
        "    Atb = A.t().mm(b)\n",
        "\n",
        "    def objective(z):\n",
        "        return torch.norm(A.mm(z) - b).item()\n",
        "\n",
        "    def gradient(z):\n",
        "        return AtA.mm(z) - Atb\n",
        "\n",
        "    # One random-probe estimate of the spectral radius of A^T A.\n",
        "    probe = torch.normal(0, torch.ones(cols, 1)).to(device)\n",
        "    lipschitz = torch.norm(A.t().mm(A.mm(probe))) / torch.norm(probe)\n",
        "\n",
        "    # The nominal step is 2 / L. The estimate of L may be too small (step too large);\n",
        "    # the halving below compensates for that.\n",
        "    step = 2 / lipschitz\n",
        "\n",
        "    for it in range(10000):\n",
        "        candidate = x - step * gradient(x)  # forward step on the smooth objective\n",
        "        if objective(candidate) > objective(x):\n",
        "            step = step / 2  # objective increased: shrink the step and retry\n",
        "            continue\n",
        "        x_next = proj_onto_simplex(candidate)  # backward step: project onto the simplex\n",
        "        rel_change = torch.norm(x - x_next) / max(torch.norm(x), 1e-8)\n",
        "        if verbose:\n",
        "            print('iter %d: error = %0.4e' % (it, rel_change))\n",
        "        if rel_change < tol:\n",
        "            break\n",
        "        x = x_next\n",
        "\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eaAy1o48x6jS"
      },
      "outputs": [],
      "source": [
        "class PoisonBatch(torch.nn.Module):\n",
        "    \"\"\"\n",
        "    Wrap a batch of poison images as a single trainable Parameter so that standard\n",
        "    PyTorch optimizers can update them.\n",
        "\n",
        "    Args:\n",
        "        base_list: sequence of equally shaped image tensors; they are stacked along a new\n",
        "            leading dimension and a copy is stored in `self.poison`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, base_list):\n",
        "        super().__init__()\n",
        "        stacked = torch.stack(base_list, dim=0)\n",
        "        self.poison = torch.nn.Parameter(stacked.clone())\n",
        "\n",
        "    def forward(self):\n",
        "        \"\"\"Return the current poison batch (the Parameter itself).\"\"\"\n",
        "        return self.poison\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IlsBQzeFllHi"
      },
      "outputs": [],
      "source": [
        "def get_poison_tuples(poison_batch, poison_label):\n",
        "    \"\"\"\n",
        "    Convert the current poisons into a list of `(image, poison_label)` pairs.\n",
        "\n",
        "    Each image is a detached CPU copy of one row of `poison_batch.poison`, so the tuples are\n",
        "    safe to save or to feed into a dataset.\n",
        "    \"\"\"\n",
        "    images = poison_batch.poison.data\n",
        "    return [(images[i].detach().cpu(), poison_label) for i in range(images.size(0))]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SfimBZMDx5m3"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wFzgqYomxWgt"
      },
      "outputs": [],
      "source": [
        "def get_CP_loss(net_list, target_feature_list, poison_batch, s_coeff_list, net_repeat, tol=1e-6):\n",
        "    \"\"\"\n",
        "    One step of the outer loop of Algorithm 1, excluding the poison update and clipping.\n",
        "\n",
        "    For each network this does three things:\n",
        "      1. Computes the poisons' penultimate features, averaged over `net_repeat` forward\n",
        "         passes when it is > 1.\n",
        "      2. Refits the convex-combination coefficients with `least_squares_simplex`.\n",
        "      3. Measures the target's reconstruction error\n",
        "         0.5 * ||target - sum_i c_i * poison_i||^2 / ||target||^2.\n",
        "\n",
        "    Returns:\n",
        "        total_loss: reconstruction error averaged over `net_list` (differentiable w.r.t. the poisons).\n",
        "        s_coeff_list: the same list object, with each entry replaced by the refit coefficients.\n",
        "        coeffs_time: whole seconds spent fitting the coefficients.\n",
        "    \"\"\"\n",
        "    # assert len(net_list) == 1 or net_repeat == 3\n",
        "    def poison_features(net):\n",
        "        if net_repeat > 1:\n",
        "            passes = [net(x=poison_batch(), penu=True) for _ in range(net_repeat)]\n",
        "            return sum(passes) / len(passes)\n",
        "        if net_repeat == 1:\n",
        "            return net(x=poison_batch(), penu=True)\n",
        "        assert False\n",
        "\n",
        "    poison_feat_mat_list = [poison_features(net) for net in net_list]\n",
        "\n",
        "    started = time.time()\n",
        "    # `idx` rather than `nn`, which would shadow the torch.nn alias imported at the top.\n",
        "    for idx, (pfeat_mat, target_feat) in enumerate(zip(poison_feat_mat_list, target_feature_list)):\n",
        "        s_coeff_list[idx] = least_squares_simplex(A=pfeat_mat.t().detach(), b=target_feat.t().detach(),\n",
        "                                                  x_init=s_coeff_list[idx], tol=tol)\n",
        "    coeffs_time = int(time.time() - started)\n",
        "\n",
        "    recon_losses = []\n",
        "    for s_coeff, target_feat, poison_feat_mat in zip(s_coeff_list, target_feature_list, poison_feat_mat_list):\n",
        "        residual = target_feat - torch.sum(s_coeff * poison_feat_mat, 0, keepdim=True)\n",
        "        target_norm_square = torch.sum(target_feat ** 2)\n",
        "        recon_losses.append(0.5 * torch.sum(residual ** 2) / target_norm_square)\n",
        "\n",
        "    total_loss = sum(recon_losses) / len(net_list)\n",
        "\n",
        "    return total_loss, s_coeff_list, coeffs_time\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V_rH8ygxkmtc"
      },
      "outputs": [],
      "source": [
        "def make_convex_polytope_poisons(subs_net_list, target_net, base_tensor_list, target, device, opt_method='adam',\n",
        "                                 lr=0.1, momentum=0.9, iterations=4000, epsilon=0.1,\n",
        "                                 decay_ites=[10000, 15000], decay_ratio=0.1,\n",
        "                                 mean=torch.Tensor((0.4914, 0.4822, 0.4465)).reshape(1, 3, 1, 1),\n",
        "                                 std=torch.Tensor((0.2023, 0.1994, 0.2010)).reshape(1, 3, 1, 1),\n",
        "                                 chk_path='', poison_idxes=[], poison_label=-1,\n",
        "                                 tol=1e-6, start_ite=0, poison_init=None, end2end=False, mode='convex',\n",
        "                                 net_repeat=1):\n",
        "    \"\"\"\n",
        "    Craft poison images with the convex-polytope objective (outer loop of Algorithm 1).\n",
        "\n",
        "    Each iteration does the following:\n",
        "      - Refits the convex-combination coefficients on every substitute network (via `get_CP_loss`).\n",
        "      - Takes one optimizer step on the poisons.\n",
        "      - Clips the perturbation to [-epsilon, epsilon] and the image to [0, 1], both in\n",
        "        un-normalised pixel space.\n",
        "    Every 50 iterations (and on the last one) it prints the loss in `target_net[0]` and writes\n",
        "    a checkpoint `poison_XXXXX.pth` to `chk_path`.\n",
        "\n",
        "    Args:\n",
        "        subs_net_list: substitute networks, called as `net(x=..., penu=True)`\n",
        "            (or `block=True` when `end2end`).\n",
        "        target_net: sequence whose first element is the victim network used for evaluation.\n",
        "        base_tensor_list: list of normalised base images the perturbations are measured from.\n",
        "        target: target image batch.\n",
        "        device: device for the poisons, bases, target and normalisation tensors.\n",
        "        opt_method: 'adam' or 'sgd' (case-insensitive).\n",
        "        decay_ites / decay_ratio: iterations at which the learning rate is multiplied by decay_ratio.\n",
        "        mean / std: normalisation statistics, broadcastable to the image batch.\n",
        "        poison_init: starting poisons; defaults to `base_tensor_list` when None.\n",
        "        mode: only 'convex' is implemented.\n",
        "        net_repeat: number of forward passes averaged per substitute network.\n",
        "\n",
        "    Returns:\n",
        "        List of `(poison_image_on_cpu, poison_label)` tuples.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: for an unsupported `opt_method` or `mode`. Previously these left\n",
        "            `optimizer` / `total_loss` unbound and failed later with an UnboundLocalError.\n",
        "    \"\"\"\n",
        "    opt_method = opt_method.lower()\n",
        "    if opt_method not in ('sgd', 'adam'):\n",
        "        raise ValueError(\"Unsupported opt_method %r (expected 'sgd' or 'adam')\" % opt_method)\n",
        "    if mode != 'convex':\n",
        "        raise ValueError(\"Unsupported mode %r (only 'convex' is implemented)\" % mode)\n",
        "\n",
        "    target_net[0].eval()\n",
        "    if poison_init is None:\n",
        "        poison_init = base_tensor_list\n",
        "    poison_batch = PoisonBatch(poison_init).to(device)\n",
        "\n",
        "    if opt_method == 'sgd':\n",
        "        optimizer = torch.optim.SGD(poison_batch.parameters(), lr=lr, momentum=momentum)\n",
        "        print(\"Using SGD\")\n",
        "    else:\n",
        "        optimizer = torch.optim.Adam(poison_batch.parameters(), lr=lr, betas=(momentum, 0.999))\n",
        "    target = target.to(device)\n",
        "    std, mean = std.to(device), mean.to(device)\n",
        "    # Keep the bases on the poisons' device so the clipping arithmetic below does not mix devices.\n",
        "    base_tensor_batch = torch.stack(base_tensor_list, 0).to(device)\n",
        "    base_range01_batch = base_tensor_batch * std + mean\n",
        "\n",
        "    # Because we have turned on DP for the substitute networks,\n",
        "    # the target image's feature becomes random.\n",
        "    # We can try enforcing the convex polytope in one of the multiple realizations of the feature,\n",
        "    # but empirically one realization is enough.\n",
        "    target_feat_list = []\n",
        "    # Coefficients for the convex combination.\n",
        "    # Initializing from the coefficients of last step gives faster convergence.\n",
        "    s_init_coeff_list = []\n",
        "    n_poisons = len(base_tensor_list)\n",
        "    for net in subs_net_list:\n",
        "        net.eval()\n",
        "        if end2end:\n",
        "            block_feats = [feat.detach() for feat in net(x=target, block=True)]\n",
        "            target_feat_list.append(block_feats)\n",
        "            s_coeff = [torch.ones(n_poisons, 1).to(device) / n_poisons for _ in range(len(block_feats))]\n",
        "        else:\n",
        "            target_feat_list.append(net(x=target, penu=True).detach())\n",
        "            s_coeff = torch.ones(n_poisons, 1).to(device) / n_poisons\n",
        "        s_init_coeff_list.append(s_coeff)\n",
        "\n",
        "    # Target features in the victim network, kept for evaluation. The victim is always\n",
        "    # target_net[0] (the end2end branch previously called the container itself).\n",
        "    victim = target_net[0]\n",
        "    if end2end:\n",
        "        target_feat_in_target = [feat.detach() for feat in victim(x=target, block=True)]\n",
        "        target_init_coeff = [[torch.ones(n_poisons, 1).to(device) / n_poisons\n",
        "                              for _ in range(len(target_feat_in_target))]]\n",
        "    else:\n",
        "        target_feat_in_target = victim(x=target, penu=True).detach()\n",
        "        target_init_coeff = [torch.ones(n_poisons, 1).to(device) / n_poisons]\n",
        "\n",
        "    print(target_feat_in_target)\n",
        "    cp_loss_func = get_CP_loss\n",
        "\n",
        "    coeffs_time = 0\n",
        "    poisons_time = 0\n",
        "    for ite in range(start_ite, iterations):\n",
        "        if ite in decay_ites:\n",
        "            for param_group in optimizer.param_groups:\n",
        "                param_group['lr'] *= decay_ratio\n",
        "            # report the decayed rate, not the initial `lr` argument\n",
        "            print(\"%s Iteration %d, Adjusted lr to %.2e\" % (time.strftime(\"%Y-%m-%d %H:%M:%S\"), ite,\n",
        "                                                             optimizer.param_groups[0]['lr']))\n",
        "\n",
        "        poison_batch.zero_grad()\n",
        "        t = time.time()\n",
        "        total_loss, s_init_coeff_list, coeffs_time_tmp = cp_loss_func(subs_net_list, target_feat_list, poison_batch,\n",
        "                                                                      s_init_coeff_list,\n",
        "                                                                      net_repeat=net_repeat,\n",
        "                                                                      tol=tol)\n",
        "        coeffs_time += coeffs_time_tmp\n",
        "\n",
        "        total_loss.backward()\n",
        "        optimizer.step()\n",
        "        poisons_time += int(time.time() - t)\n",
        "\n",
        "        # clip the perturbations into the range\n",
        "        perturb_range01 = torch.clamp((poison_batch.poison.data - base_tensor_batch) * std, -epsilon, epsilon)\n",
        "        perturbed_range01 = torch.clamp(base_range01_batch.data + perturb_range01.data, 0, 1)\n",
        "        poison_batch.poison.data = (perturbed_range01 - mean) / std\n",
        "\n",
        "        if ite % 50 == 0 or ite == iterations - 1:\n",
        "            # evaluate the convex-polytope loss of the current poisons in the victim network\n",
        "            target_loss, target_init_coeff, _ = cp_loss_func([victim],\n",
        "                                                             [target_feat_in_target],\n",
        "                                                             poison_batch,\n",
        "                                                             target_init_coeff,\n",
        "                                                             net_repeat=1,\n",
        "                                                             tol=tol)\n",
        "\n",
        "            # compute the difference in target\n",
        "            print(\" %s Iteration %d \\t Training Loss: %.3e \\t Loss in Target Net: %.3e\\t  \" % (\n",
        "                time.strftime(\"%Y-%m-%d %H:%M:%S\"), ite, total_loss.item(), target_loss.item()))\n",
        "            sys.stdout.flush()\n",
        "\n",
        "            # save the checkpoints (losses detached so no autograd graph is kept in the file)\n",
        "            poison_tuple_list = get_poison_tuples(poison_batch, poison_label)\n",
        "            torch.save({'poison': poison_tuple_list, 'idx': poison_idxes, 'coeffs_time': coeffs_time,\n",
        "                        'poisons_time': poisons_time, 'target_loss': target_loss.detach(),\n",
        "                        'total_loss': total_loss.detach(),\n",
        "                        'coeff_list': s_init_coeff_list, 'coeff_list_in_victim': target_init_coeff},\n",
        "                       os.path.join(chk_path, \"poison_%05d.pth\" % ite))\n",
        "\n",
        "    return get_poison_tuples(poison_batch, poison_label)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_uLKUbDplvh9"
      },
      "outputs": [],
      "source": [
        "class AverageMeter(object):\n",
        "    \"\"\"Keep the latest value of a metric together with its sample-weighted running average.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"Zero the latest value, average, weighted sum and sample count.\"\"\"\n",
        "        self.val = self.avg = self.sum = self.count = 0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        \"\"\"Record `val` as the latest value, weighted by `n` samples, and refresh the average.\"\"\"\n",
        "        self.val = val\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count\n",
        "\n",
        "\n",
        "def accuracy(output, target, topk=(1,)):\n",
        "    \"\"\"\n",
        "    Compute precision@k, in percent, for each k in `topk`.\n",
        "\n",
        "    Args:\n",
        "        output: score tensor whose dimension 1 indexes classes (batch, num_classes).\n",
        "        target: tensor of true class indices with `target.size(0)` == batch size.\n",
        "        topk: tuple of k values.\n",
        "\n",
        "    Returns:\n",
        "        List of 1-element tensors, one per k: the percentage of samples whose true class is\n",
        "        among the k highest scores.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        maxk = max(topk)\n",
        "        batch_size = target.size(0)\n",
        "\n",
        "        _, pred = output.topk(maxk, 1, True, True)\n",
        "        pred = pred.t()\n",
        "        correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
        "\n",
        "        res = []\n",
        "        for k in topk:\n",
        "            # reshape, not view: `correct` derives from a transposed tensor and is generally\n",
        "            # non-contiguous, so view(-1) raises for k > 1.\n",
        "            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
        "            res.append(correct_k.mul_(100.0 / batch_size))\n",
        "        return res\n",
        "\n",
        "\n",
        "def train_network_with_poison(net, target_img, poison_tuple_list, poisoned_dset,\n",
        "                              base_idx_list, testset, poison_dict, savemodel=None):\n",
        "    \"\"\"\n",
        "    Fine-tune only the final linear layer of `net` on a poisoned training set, then evaluate it.\n",
        "\n",
        "    In every training batch, a sample whose dataset index is a key of `poison_dict` has its\n",
        "    image replaced by `poison_tuple_list[poison_dict[index]][0]`. Features come from\n",
        "    `net.penultimate` under no_grad, so only the parameters from\n",
        "    `net.get_penultimate_params_list()` are trained. Training uses Adam with lr 0.1 for\n",
        "    60 epochs, multiplying the lr by 0.1 at epochs 30 and 45. After the last epoch the\n",
        "    target and the poisons are scored, then top-1 accuracy on `testset` is measured.\n",
        "\n",
        "    Args:\n",
        "        net: model exposing `get_penultimate_params_list()`, `penultimate(x)` and `linear(feat)`.\n",
        "        target_img: target image batch scored after training.\n",
        "        poison_tuple_list: list of `(poison_image, label)` tuples.\n",
        "        poisoned_dset: training dataset yielding `(image, label, index)`.\n",
        "        base_idx_list: unused.\n",
        "        testset: clean test dataset yielding `(image, label, index)`.\n",
        "        poison_dict: dataset index -> position in `poison_tuple_list`.\n",
        "        savemodel: unused (saving is commented out below).\n",
        "\n",
        "    Returns:\n",
        "        dict with keys 'clean acc', 'prediction', 'poisons predictions', 'scores',\n",
        "        'malicious score' and 'camera'.\n",
        "\n",
        "    NOTE(review): reads the notebook globals `device`, `target_class` and `poison_class`.\n",
        "    \"\"\"\n",
        "    # requires implementing a get_penultimate_params_list() method to get the parameter identifier of the net's last\n",
        "    # layer\n",
        "    params = net.get_penultimate_params_list()\n",
        "    print(\"Using Adam for retraining\")\n",
        "    optimizer = torch.optim.Adam(params, lr=0.1, weight_decay=0)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    poisoned_loader = torch.utils.data.DataLoader(poisoned_dset, batch_size=64, shuffle=True)\n",
        "    # The test set of clean CIFAR10\n",
        "    test_loader = torch.utils.data.DataLoader(testset, batch_size=500)\n",
        "\n",
        "    num_epochs = 60\n",
        "    net.train()\n",
        "    for epoch in range(num_epochs):\n",
        "        loss_meter = AverageMeter()\n",
        "        acc_meter = AverageMeter()\n",
        "        time_meter = AverageMeter()\n",
        "\n",
        "        if epoch in [30, 45]:\n",
        "            for param_group in optimizer.param_groups:\n",
        "                param_group['lr'] *= 0.1\n",
        "\n",
        "        end_time = time.time()\n",
        "        for ite, (input, label, indices) in enumerate(poisoned_loader):\n",
        "            # swap in the poisoned version of every base image present in this batch\n",
        "            for i, index in enumerate(indices):\n",
        "                if int(index) in poison_dict:\n",
        "                    input[i] = poison_tuple_list[poison_dict[int(index)]][0]\n",
        "            input, label = input.to(device), label.to(device)\n",
        "\n",
        "            # the feature extractor is frozen; only the linear head receives gradients\n",
        "            with torch.no_grad():\n",
        "                feat = net.penultimate(input).detach()\n",
        "            output = net.linear(feat)\n",
        "            loss = criterion(output, label)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            prec1 = accuracy(output, label)[0]\n",
        "\n",
        "            time_meter.update(time.time() - end_time)\n",
        "            end_time = time.time()\n",
        "            loss_meter.update(loss.item(), input.size(0))\n",
        "            acc_meter.update(prec1.item(), input.size(0))\n",
        "\n",
        "            if (epoch == 0 or epoch == num_epochs - 1) and (ite == len(poisoned_loader) - 1):\n",
        "                print(\"{2}, Epoch {0}, Iteration {1}, loss {loss.val:.3f} ({loss.avg:.3f}), \"\n",
        "                      \"acc {acc.val:.3f} ({acc.avg:.3f})\".\n",
        "                      format(epoch, ite, time.strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
        "                             loss=loss_meter, acc=acc_meter))\n",
        "            sys.stdout.flush()\n",
        "\n",
        "        if epoch == num_epochs - 1:\n",
        "            net.eval()\n",
        "            # Score the target and the poisons without building autograd graphs. `.to(device)` is\n",
        "            # always applied: the old `device == 'cuda'` string test is unreliable for torch.device.\n",
        "            with torch.no_grad():\n",
        "                target_pred = net(target_img.to(device))\n",
        "                target_scores = [float(n) for n in list(softmax(target_pred.view(-1).cpu().detach().numpy()))]\n",
        "                score, target_pred = target_pred.topk(1, 1, True, True)\n",
        "                poison_pred_list = []\n",
        "                for poison_img, _ in poison_tuple_list:\n",
        "                    base_scores = net(poison_img[None, :, :, :].to(device))\n",
        "                    base_score, base_pred = base_scores.topk(1, 1, True, True)\n",
        "                    poison_pred_list.append(base_pred.item())\n",
        "            print(\n",
        "                \"Target Label: {}, Poison label: {}, Prediction:{}, Target's Score:{}, Poisons' Predictions:{}\".format(\n",
        "                    target_class, poison_class, target_pred[0][0].item(), target_scores,\n",
        "                    poison_pred_list))\n",
        "\n",
        "    # Evaluate the results on the clean test set\n",
        "    val_acc_meter = AverageMeter()\n",
        "    with torch.no_grad():\n",
        "        net.eval()\n",
        "        for ite, (input, label, index) in enumerate(test_loader):\n",
        "            input, label = input.to(device), label.to(device)\n",
        "\n",
        "            output = net(input)\n",
        "\n",
        "            prec1 = accuracy(output, label)[0]\n",
        "            val_acc_meter.update(prec1.item(), input.size(0))\n",
        "\n",
        "            if ite % 100 == 0 or ite == len(test_loader) - 1:\n",
        "                print(\"{2} Epoch {0}, Val iteration {1}, \"\n",
        "                      \"acc {acc.val:.3f} ({acc.avg:.3f})\".\n",
        "                      format(epoch, ite, time.strftime(\"%Y-%m-%d %H:%M:%S\"), acc=val_acc_meter))\n",
        "\n",
        "    print(\"* Prec: {}\".format(val_acc_meter.avg))\n",
        "\n",
        "    # if savemodel is not None:\n",
        "    #     torch.save(net.state_dict(), savemodel)\n",
        "\n",
        "    return {'clean acc': val_acc_meter.avg, 'prediction': target_pred[0][0].item(),\n",
        "            'poisons predictions': poison_pred_list,\n",
        "            'scores': target_scores, 'malicious score': target_scores[poison_class], 'camera': {}}\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xDWUzYxqzud1"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ebKipVxYLAlW"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r9oeGWuDwIHf"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ixi3ZVSMksCX"
      },
      "outputs": [],
      "source": [
        "# Bullseye Polytope attack parameters.\n",
        "# NOTE(review): budget / cifar_mean / cifar_std repeat values set in an earlier cell.\n",
        "budget = 5 # number of poisoned images\n",
        "cifar_mean = (0.4914, 0.4822, 0.4465)\n",
        "cifar_std = (0.2023, 0.1994, 0.2010)\n",
        "poison_opt = 'adam'\n",
        "poison_lr = 0.1\n",
        "poison_momentum = 0.9\n",
        "poison_ites = 1000\n",
        "poison_epsilon = 0.5\n",
        "poison_decay_ites = []\n",
        "poison_decay_ratio = 0.1\n",
        "chk_path = \"./drive/MyDrive/Poisoning_Machine_Unlearning/\"\n",
        "\n",
        "# Draw `budget` base images at random from the candidate pool.\n",
        "# NOTE(review): this overwrites `poison_index` with its own sample, so re-running the cell\n",
        "# re-samples from the already-chosen subset rather than from the original candidates.\n",
        "poison_index = np.random.choice(poison_index, budget, replace=False)\n",
        "# dataset index -> position in poison_index (plain int keys, matching the int(index) lookups)\n",
        "poison_dict = {int(val): position for position, val in enumerate(poison_index)}\n",
        "\n",
        "targets = torch.stack([data[0] for data in target_image], dim=0).to(device)\n",
        "intended_classes = torch.tensor([poison_class]).to(device=device, dtype=torch.long)\n",
        "true_classes = torch.tensor([data[1] for data in target_image]).to(device=device, dtype=torch.long)\n",
        "\n",
        "tol = 1e-6\n",
        "# number of forward passes averaged per substitute network (an int count, not a flag)\n",
        "net_repeat = 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CiiVvSZqzPJ9"
      },
      "outputs": [],
      "source": [
        "def fetch_target(target_label, target_index, start_idx, path, subset, transforms):\n",
        "    \"\"\"\n",
        "    Return the already-selected target image(s) stacked into one batch on `device`.\n",
        "\n",
        "    NOTE(review): every argument is ignored; the images come from the notebook-level\n",
        "    `target_image` Subset and are moved to the global `device`.\n",
        "    \"\"\"\n",
        "    print(\"fetch target\")\n",
        "    images = [sample[0] for sample in target_image]\n",
        "    return torch.stack(images, dim=0).to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pOBXrP-_Iju9"
      },
      "outputs": [],
      "source": [
        "def fetch_poison_bases(poison_label, num_poison, subset, path, transforms_img):\n",
        "    \"\"\"\n",
        "    Collect the CIFAR10 training images whose dataset index is listed in `subset`.\n",
        "\n",
        "    Args:\n",
        "        poison_label, num_poison, path: currently unused; kept for call-site compatibility.\n",
        "        subset: iterable of dataset indices to select (e.g. the sampled `poison_index`).\n",
        "        transforms_img: transform passed to the CIFAR10 training set.\n",
        "\n",
        "    Returns:\n",
        "        (base_tensor_list, base_idx_list): the selected images stacked into one tensor and\n",
        "        their enumeration positions in the training set, both in training-set order\n",
        "        (not in `subset` order).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if no training sample matches `subset`.\n",
        "    \"\"\"\n",
        "    print(\"fetch base\")\n",
        "    trainset = CIFAR10(train=True, transform=transforms_img)\n",
        "    # Plain-int set: O(1) membership, and works for numpy or 0-d tensor indices alike.\n",
        "    wanted = {int(i) for i in subset}\n",
        "    base_tensor_list, base_idx_list = [], []\n",
        "    # NOTE(review): the filter uses the index yielded by the dataset while the returned\n",
        "    # list records the enumeration position; the two are assumed to be equal.\n",
        "    for idx, (img, _label, index) in enumerate(trainset):\n",
        "        if int(index) in wanted:\n",
        "            base_tensor_list.append(img)\n",
        "            base_idx_list.append(idx)\n",
        "    if not base_tensor_list:\n",
        "        raise ValueError(\"fetch_poison_bases: no training sample matched the requested subset\")\n",
        "    return torch.stack(base_tensor_list), base_idx_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CHJ7FhwHRbp7"
      },
      "outputs": [],
      "source": [
        "def loss_from_center(subs_net_list, target_feat_list, poison_batch, net_repeat, end2end):\n",
        "    \"\"\"\n",
        "    Average relative distance between the mean poison feature and each target feature.\n",
        "\n",
        "    For every (net, center) pair the poison batch is fed through `net` `net_repeat`\n",
        "    times (calling `poison_batch()` afresh each time, with `penu=True`), the outputs\n",
        "    are averaged, reduced to their mean over dim 0 and compared with `center` as\n",
        "    norm(mean - center) / norm(center) along dim 1. The per-net values are summed and\n",
        "    divided by len(subs_net_list). `end2end` is accepted but unused.\n",
        "    \"\"\"\n",
        "    def _relative_distance(net, center):\n",
        "        # one forward pass per repeat, then average them\n",
        "        outputs = [net(x=poison_batch(), penu=True) for _ in range(net_repeat)]\n",
        "        mean_output = sum(outputs) / len(outputs)\n",
        "        offset = mean_output.mean(dim=0) - center\n",
        "        return (offset.norm(dim=1) / center.norm(dim=1)).mean()\n",
        "\n",
        "    total = sum(_relative_distance(net, center)\n",
        "                for net, center in zip(subs_net_list, target_feat_list))\n",
        "    return total / len(subs_net_list)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kHKX4PofYcjt"
      },
      "outputs": [],
      "source": [
        "# load the pre-trained models\n",
        "def run_code():\n",
        "    \"\"\"\n",
        "    Craft Bullseye-Polytope poisons on the surrogate net and evaluate them on the victim net.\n",
        "\n",
        "    Relies on notebook globals: `model` (surrogate), `model2` (victim), `target_class`,\n",
        "    `target_index`, `poison_class`, `poison_index`, `budget`, the `poison_*`\n",
        "    hyper-parameters, `cifar_mean`/`cifar_std`, `chk_path`, `tol`, `trainset`, `testset`,\n",
        "    and the helpers `fetch_target`, `fetch_poison_bases`, `make_convex_polytope_poisons`\n",
        "    and `train_network_with_poison`.\n",
        "\n",
        "    Returns:\n",
        "        (poison_tuple_list, base_idx_list, poison_dict), where poison_dict maps each base\n",
        "        index to its position in base_idx_list.\n",
        "    \"\"\"\n",
        "    sub_net_list = [model]\n",
        "    print(\"subs nets, effective num: {}\".format(len(sub_net_list)))\n",
        "\n",
        "    print(\"Loading the victims networks\")\n",
        "    targets_net = [model2]\n",
        "    # Short labels for logging; printing a network object dumps its whole repr.\n",
        "    targets_net_names = [\"victim_{}_{}\".format(i, type(net).__name__)\n",
        "                         for i, net in enumerate(targets_net)]\n",
        "\n",
        "    transform_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(cifar_mean, cifar_std),\n",
        "    ])\n",
        "\n",
        "    # Get the target image\n",
        "    target = fetch_target(target_class, target_index, 50, subset=target_index,\n",
        "                          path=\"./data\", transforms=transform_test)\n",
        "\n",
        "    # Checkpoints go to {chk_path}/mean/{poison_ites}/{target_index}\n",
        "    run_dir = os.path.join(chk_path, 'mean', str(poison_ites), str(target_index))\n",
        "    os.makedirs(run_dir, exist_ok=True)\n",
        "    print(\"Path: {}\".format(run_dir))\n",
        "\n",
        "    # fetch the training samples selected in poison_index\n",
        "    base_tensor_list, base_idx_list = fetch_poison_bases(poison_class, budget, subset=poison_index,\n",
        "                                                         path='./data', transforms_img=transform_test)\n",
        "    base_tensor_list = [bt.to('cuda') for bt in base_tensor_list]\n",
        "    print(\"Selected base image indices: {}\".format(base_idx_list))\n",
        "\n",
        "    t = time.time()\n",
        "    # The clean bases double as the poisons' starting point (poison_init).\n",
        "    poison_tuple_list = make_convex_polytope_poisons(sub_net_list, targets_net, base_tensor_list,\n",
        "                                                     target, device='cuda', opt_method=poison_opt,\n",
        "                                                     lr=poison_lr, momentum=poison_momentum,\n",
        "                                                     iterations=poison_ites, epsilon=poison_epsilon,\n",
        "                                                     decay_ites=poison_decay_ites,\n",
        "                                                     decay_ratio=poison_decay_ratio,\n",
        "                                                     mean=torch.Tensor(cifar_mean).reshape(1, 3, 1, 1),\n",
        "                                                     std=torch.Tensor(cifar_std).reshape(1, 3, 1, 1),\n",
        "                                                     chk_path=run_dir, poison_idxes=base_idx_list,\n",
        "                                                     poison_label=poison_class,\n",
        "                                                     tol=tol,\n",
        "                                                     end2end=False,\n",
        "                                                     start_ite=0,\n",
        "                                                     poison_init=base_tensor_list,\n",
        "                                                     mode='convex')\n",
        "    tt = time.time()\n",
        "\n",
        "    print(\"Evaluating against victims networks\")\n",
        "    # base index mapped to its position among the crafted poisons\n",
        "    poison_dict = {index: i for i, index in enumerate(base_idx_list)}\n",
        "\n",
        "    res = []\n",
        "    for tnet, tnet_name in zip(targets_net, targets_net_names):\n",
        "        print(tnet_name)\n",
        "        pred = train_network_with_poison(tnet, target, poison_tuple_list, trainset, base_idx_list, testset, poison_dict)\n",
        "        res.append(pred)\n",
        "        print(\"--------\")\n",
        "\n",
        "    print(\"------SUMMARY------\")\n",
        "    # elapsed time covers poison crafting only, not victim retraining\n",
        "    print(\"TIME ELAPSED (mins): {}\".format(int((tt - t) / 60)))\n",
        "    print(\"TARGET INDEX: {}\".format(target_index))\n",
        "    # NOTE(review): in the saved run the victim log shows 'Prediction:5' with poison label 5,\n",
        "    # yet this comparison printed 0 -- check what train_network_with_poison returns\n",
        "    # (it may not be a bare class label).\n",
        "    for tnet_name, r in zip(targets_net_names, res):\n",
        "        print(tnet_name, int(r == poison_class))\n",
        "\n",
        "    return poison_tuple_list, base_idx_list, poison_dict"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hpSIWhiQN1eo",
        "outputId": "1c65c9dc-f9c4-4717-d504-aa3369ee05e5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "subs nets, effective num: 1\n",
            "Loading the victims networks\n",
            "fetch target\n",
            "Path: ./drive/MyDrive/Poisoning_Machine_Unlearning/mean/1000/[3950]\n",
            "fetch base\n",
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Selected base image indices: [9355, 19338, 21737, 25157, 38079]\n",
            "tensor([[0.3844, 0.1127, 0.5763, 1.3533, 0.0036, 0.1609, 0.5820, 0.1210, 0.4921,\n",
            "         0.3321, 1.0170, 0.0647, 0.2350, 0.8950, 0.3634, 0.2452, 0.2774, 0.2852,\n",
            "         1.4937, 0.1139, 0.2353, 0.5114, 0.1976, 0.8876, 0.5219, 0.2104, 1.4093,\n",
            "         1.4835, 0.4308, 0.6945, 0.4287, 0.1161, 0.1083, 0.1979, 0.1586, 0.6016,\n",
            "         0.2713, 0.5853, 0.0220, 1.5277, 1.1277, 0.1875, 1.1149, 1.9977, 1.0559,\n",
            "         0.8771, 0.1981, 0.2169, 0.3167, 0.3267, 0.8558, 1.6952, 0.1839, 0.1262,\n",
            "         0.3966, 0.1520, 0.6357, 1.2591, 0.0353, 0.1083, 0.2240, 0.0059, 0.0916,\n",
            "         0.4562, 0.2320, 0.1372, 0.2587, 0.0965, 1.3439, 0.0255, 0.6766, 0.3003,\n",
            "         0.0587, 1.6923, 0.3020, 0.3641, 0.1072, 1.2275, 0.0688, 1.1079, 0.3375,\n",
            "         0.3137, 0.4425, 0.7269, 0.1517, 1.3284, 0.2470, 0.0557, 1.4235, 0.7662,\n",
            "         1.3476, 0.1525, 0.7496, 0.3293, 0.3210, 0.3246, 0.0938, 0.7569, 1.5743,\n",
            "         0.0872, 0.2159, 1.7689, 0.7822, 0.4410, 0.2527, 0.1236, 0.7834, 0.1170,\n",
            "         0.2729, 1.2302, 0.4088, 0.0419, 0.8573, 0.2758, 0.9109, 0.3199, 0.2043,\n",
            "         1.5084, 1.7806, 0.0636, 0.1450, 0.6858, 0.0940, 0.7381, 0.1296, 0.5399,\n",
            "         0.4302, 0.0184, 1.8013, 0.1026, 0.2996, 0.6068, 1.4616, 0.3113, 1.4782,\n",
            "         0.7001, 1.4765, 0.4852, 0.3341, 0.0881, 0.1507, 1.9523, 0.0337, 1.6032,\n",
            "         0.1217, 1.0276, 1.4730, 0.3229, 0.0386, 1.0912, 0.8368, 0.1011, 0.6848,\n",
            "         0.0736, 0.6756, 0.1452, 0.4664, 1.4205, 0.4906, 0.0931, 0.2128, 0.2588,\n",
            "         0.2777, 0.2198, 0.9161, 0.1378, 0.2759, 0.2680, 1.1098, 0.0744, 0.0967,\n",
            "         1.7433, 1.4113, 0.2885, 0.9029, 0.1750, 1.3663, 0.6250, 1.2069, 0.3388,\n",
            "         0.1214, 0.3127, 0.9003, 0.4768, 0.0596, 1.3117, 0.6678, 0.3273, 0.2983,\n",
            "         0.0446, 0.2931, 0.3743, 1.5208, 0.1237, 0.5532, 0.4221, 0.2359, 0.0728,\n",
            "         1.0576, 0.5392, 0.6460, 1.4236, 0.2515, 0.4822, 0.2666, 0.5525, 0.3088,\n",
            "         0.7179, 0.2591, 0.8913, 0.4484, 0.8781, 1.5848, 1.0184, 0.3969, 1.6154,\n",
            "         1.1198, 1.0757, 0.5099, 0.0915, 0.0115, 1.6928, 0.0253, 0.4320, 1.1229,\n",
            "         1.7738, 0.9744, 0.1421, 0.0369, 0.2925, 1.3270, 0.2571, 0.7265, 0.3258,\n",
            "         0.3360, 0.2128, 0.7876, 0.0000, 0.0635, 1.6180, 0.5239, 1.7046, 0.2300,\n",
            "         0.1855, 1.5250, 0.3636, 0.8122, 0.9562, 0.2267, 0.1697, 2.3337, 0.3712,\n",
            "         0.2699, 0.3345, 1.0885, 0.8689, 0.0768, 0.5751, 1.5274, 0.0687, 0.9432,\n",
            "         0.0045, 0.3379, 0.3602, 0.8054, 0.2238, 0.3054, 0.2552, 0.5814, 1.4262,\n",
            "         0.5904, 0.0053, 0.6948, 0.0140, 0.1420, 0.9909, 0.8311, 1.0012, 0.4388,\n",
            "         0.1764, 1.2875, 0.4531, 1.4724, 0.0676, 0.2518, 0.3533, 1.2801, 1.5236,\n",
            "         0.3248, 0.8515, 0.2757, 0.2388, 0.4219, 0.5242, 0.5373, 0.2659, 0.9848,\n",
            "         0.5720, 0.2823, 0.2980, 0.4439, 0.3783, 0.6486, 0.5868, 0.3444, 1.1831,\n",
            "         0.1377, 0.7241, 0.2469, 0.0774, 0.5377, 0.2932, 0.6675, 0.1927, 0.2524,\n",
            "         0.3661, 0.7988, 0.3930, 0.0942, 0.4064, 1.3398, 0.1719, 0.0824, 1.0881,\n",
            "         0.1875, 0.0773, 0.7188, 0.7633, 0.3649, 0.2241, 1.2658, 1.2805, 0.3134,\n",
            "         0.3252, 0.1160, 0.6568, 1.7791, 0.5187, 0.0526, 0.3724, 0.2316, 2.2749,\n",
            "         0.7802, 1.6159, 0.3474, 0.2068, 0.1791, 0.1827, 0.0728, 1.1373, 0.0910,\n",
            "         0.8828, 0.8075, 1.6378, 1.3705, 0.0895, 0.1654, 0.0466, 1.0647, 0.2636,\n",
            "         0.0565, 0.2196, 0.5440, 1.5431, 0.7056, 0.9628, 1.3227, 1.3508, 0.1878,\n",
            "         1.1227, 0.8891, 0.1055, 1.6481, 0.4395, 1.0931, 0.5949, 0.2429, 0.3685,\n",
            "         1.7689, 0.7593, 1.5252, 1.1900, 1.7352, 1.6100, 0.0930, 0.3473, 0.2850,\n",
            "         0.2441, 0.4510, 0.3691, 0.0696, 0.0634, 0.7810, 0.6795, 0.3007, 0.1041,\n",
            "         0.4640, 0.7600, 1.7742, 0.2145, 0.4267, 0.1182, 0.4055, 0.4505, 0.1665,\n",
            "         1.0697, 0.7978, 0.1105, 0.6780, 0.1727, 1.0777, 0.9981, 0.2871, 0.6724,\n",
            "         0.1861, 0.3469, 0.1708, 1.3020, 1.6102, 0.2128, 0.3826, 0.2603, 1.2062,\n",
            "         0.4469, 0.1529, 0.9250, 0.9986, 0.1247, 1.2632, 1.4332, 1.9175, 0.9194,\n",
            "         1.3487, 0.7084, 1.2917, 0.8800, 0.0236, 2.2883, 0.7652, 1.0059, 0.3301,\n",
            "         1.5774, 1.5489, 0.5523, 0.5189, 0.6832, 0.4246, 0.4658, 0.2525, 0.1789,\n",
            "         0.1173, 0.8279, 1.1099, 0.6521, 0.0067, 0.0576, 1.3366, 0.6255, 0.1838,\n",
            "         0.0454, 0.2585, 0.0469, 0.4256, 0.9530, 0.8285, 0.7758, 0.4816, 0.3334,\n",
            "         0.5346, 1.4029, 1.4570, 0.8638, 0.9864, 1.6038, 1.7536, 0.3076, 0.9472,\n",
            "         0.0439, 0.0964, 0.6399, 0.7604, 0.1005, 0.6967, 1.2758, 1.3471, 0.2968,\n",
            "         0.1594, 0.2556, 0.1023, 0.2119, 1.0572, 0.0165, 1.4627, 0.3920, 0.7915,\n",
            "         0.4041, 1.2797, 0.2991, 0.2764, 0.2755, 1.0206, 0.1695, 0.0769, 1.3972,\n",
            "         2.1347, 0.9236, 1.4910, 0.7001, 1.8239, 0.8797, 0.9915, 1.6839]],\n",
            "       device='cuda:0')\n",
            " 2023-05-22 22:10:25 Iteration 0 \t Training Loss: 3.349e-01 \t Loss in Target Net: 1.170e-01\t  \n",
            " 2023-05-22 22:10:30 Iteration 50 \t Training Loss: 2.126e-03 \t Loss in Target Net: 2.127e-03\t  \n",
            " 2023-05-22 22:10:35 Iteration 100 \t Training Loss: 8.154e-04 \t Loss in Target Net: 7.947e-04\t  \n",
            " 2023-05-22 22:10:40 Iteration 150 \t Training Loss: 4.752e-04 \t Loss in Target Net: 4.651e-04\t  \n",
            " 2023-05-22 22:10:45 Iteration 200 \t Training Loss: 5.081e-04 \t Loss in Target Net: 3.659e-04\t  \n",
            " 2023-05-22 22:10:51 Iteration 250 \t Training Loss: 2.689e-04 \t Loss in Target Net: 2.606e-04\t  \n",
            " 2023-05-22 22:10:57 Iteration 300 \t Training Loss: 2.618e-04 \t Loss in Target Net: 2.433e-04\t  \n",
            " 2023-05-22 22:11:03 Iteration 350 \t Training Loss: 1.685e-04 \t Loss in Target Net: 1.960e-04\t  \n",
            " 2023-05-22 22:11:09 Iteration 400 \t Training Loss: 3.410e-04 \t Loss in Target Net: 2.679e-04\t  \n",
            " 2023-05-22 22:11:16 Iteration 450 \t Training Loss: 1.782e-04 \t Loss in Target Net: 2.437e-04\t  \n",
            " 2023-05-22 22:11:22 Iteration 500 \t Training Loss: 1.217e-04 \t Loss in Target Net: 1.197e-04\t  \n",
            " 2023-05-22 22:11:28 Iteration 550 \t Training Loss: 7.863e-05 \t Loss in Target Net: 9.451e-05\t  \n",
            " 2023-05-22 22:11:35 Iteration 600 \t Training Loss: 6.384e-05 \t Loss in Target Net: 8.462e-05\t  \n",
            " 2023-05-22 22:11:43 Iteration 650 \t Training Loss: 1.332e-03 \t Loss in Target Net: 1.165e-03\t  \n",
            " 2023-05-22 22:11:49 Iteration 700 \t Training Loss: 1.200e-04 \t Loss in Target Net: 1.102e-04\t  \n",
            " 2023-05-22 22:11:55 Iteration 750 \t Training Loss: 4.548e-05 \t Loss in Target Net: 4.681e-05\t  \n",
            " 2023-05-22 22:12:01 Iteration 800 \t Training Loss: 4.211e-05 \t Loss in Target Net: 3.036e-05\t  \n",
            " 2023-05-22 22:12:07 Iteration 850 \t Training Loss: 4.795e-05 \t Loss in Target Net: 8.277e-05\t  \n",
            " 2023-05-22 22:12:13 Iteration 900 \t Training Loss: 3.205e-05 \t Loss in Target Net: 2.757e-05\t  \n",
            " 2023-05-22 22:12:19 Iteration 950 \t Training Loss: 4.657e-05 \t Loss in Target Net: 6.353e-05\t  \n",
            " 2023-05-22 22:12:26 Iteration 999 \t Training Loss: 1.625e-05 \t Loss in Target Net: 1.653e-05\t  \n",
            "Evaluating against victims networks\n",
            "ResNet(\n",
            "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "  (layer1): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer2): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer3): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer4): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
            ")\n",
            "Using Adam for retraining\n",
            "2023-05-22 22:12:51, Epoch 0, Iteration 781, loss 0.000 (0.315), acc 100.000 (98.368)\n",
            "2023-05-22 22:37:19, Epoch 59, Iteration 781, loss 0.000 (0.051), acc 100.000 (99.942)\n",
            "Target Label: 8, Poison label: 5, Prediction:5, Target's Score:[0.0, 0.0, 0.0, 0.0, 0.0, 0.9993101358413696, 0.0, 0.0, 0.0006898678839206696, 0.0], Poisons' Predictions:[2, 8, 5, 5, 8]\n",
            "2023-05-22 22:37:20 Epoch 59, Val iteration 0, acc 88.600 (88.600)\n",
            "2023-05-22 22:37:24 Epoch 59, Val iteration 19, acc 87.200 (87.150)\n",
            "* Prec: 87.1500015258789\n",
            "--------\n",
            "------SUMMARY------\n",
            "TIME ELAPSED (mins): 2\n",
            "TARGET INDEX: [3950]\n",
            "ResNet(\n",
            "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "  (layer1): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer2): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer3): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer4): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
            ") 0\n"
          ]
        }
      ],
      "source": [
        "# Craft the poisons with run_code() and evaluate them against the victim network(s).\n",
        "poison_tuple_list, base_idx_list, poison_dict = run_code() "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "H40vNehNxvFi",
        "outputId": "6ed2c566-7fa6-44b6-8e23-33fbbafd9370"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor([[[[0.2470]],\n",
            "\n",
            "         [[0.2435]],\n",
            "\n",
            "         [[0.2616]]]])\n",
            "tensor([[[[0.4914]],\n",
            "\n",
            "         [[0.4822]],\n",
            "\n",
            "         [[0.4465]]]])\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ResNet(\n",
              "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "  (layer1): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer2): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer3): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (layer4): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential(\n",
              "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (shortcut): Sequential()\n",
              "    )\n",
              "  )\n",
              "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 41
        }
      ],
      "source": [
        "# Initialize poison_delta\n",
        "\n",
        "std_tensor = torch.tensor(data_std)[None, :, None, None]\n",
        "mean_tensor = torch.tensor(data_mean)[None, :, None, None]\n",
        "\n",
        "print(std_tensor)\n",
        "print(mean_tensor)\n",
        "\n",
        "model.eval()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rLMw2g1jeJCC"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JwTpwB0o7J5T",
        "outputId": "b8fdbb58-abcd-471d-f224-a531549b78b4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{9355: 0, 19338: 1, 21737: 2, 25157: 3, 38079: 4}\n",
            "[9355, 19338, 21737, 25157, 38079]\n"
          ]
        }
      ],
      "source": [
        "# Map each poison's training-set index to its position in base_idx_list.\n",
        "poison_dict = {train_idx: pos for pos, train_idx in enumerate(base_idx_list)}\n",
        "print(poison_dict)\n",
        "print(base_idx_list)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kSkqJWa6NRHg"
      },
      "outputs": [],
      "source": [
        "# Prepare the camouflage set: draw `camou_budget` distinct candidates without replacement.\n",
        "# NOTE(review): this rebinds `camou_index` to a sample of itself, so re-running this cell\n",
        "# samples from the previous sample rather than from the original candidates; re-run the\n",
        "# cell that builds the candidate list first.\n",
        "camou_index = np.random.choice(camou_index, camou_budget, replace=False)\n",
        "\n",
        "# Position of each sampled training-set index inside camou_index.\n",
        "camou_dict = {train_idx: pos for pos, train_idx in enumerate(camou_index)}\n",
        "\n",
        "camouset = Subset(trainset, camou_index)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "efEMv4Eb4bzl",
        "outputId": "0fd34b52-0627-4a8e-a533-dd9db16319d4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[ 6762 39385 47871 40547 48457]\n"
          ]
        }
      ],
      "source": [
        "# Show the sampled camouflage training-set indices.\n",
        "print(f\"{camou_index}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oTMR8cVd3s6p"
      },
      "outputs": [],
      "source": [
        "# Victim network: build ResNet18 and (optionally) restore its weights from PATH.\n",
        "MODEL2_DIR = \"./drive/MyDrive/Poisoning_Machine_Unlearning/model\"\n",
        "os.makedirs(MODEL2_DIR, exist_ok=True)\n",
        "PATH = os.path.join(MODEL2_DIR, \"resnet_cifar.ptr\")\n",
        "\n",
        "model2 = ResNet18()\n",
        "\n",
        "LOADMODEL2 = 1\n",
        "if LOADMODEL2:\n",
        "  # map_location lets a checkpoint written on a GPU machine be restored on `device`.\n",
        "  model2.load_state_dict(torch.load(PATH, map_location=device))\n",
        "model2 = model2.to(device)\n",
        "\n",
        "# NOTE(review): with LOADMODEL2 = 0 this overwrites PATH with freshly constructed weights.\n",
        "if SAVEMODEL2:\n",
        "  torch.save(model2.state_dict(), PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3RsY9L2Y42V8"
      },
      "outputs": [],
      "source": [
        "class AverageMeter(object):\n",
        "    \"\"\"Track the latest value and the running, count-weighted mean of a metric.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"Zero every statistic.\"\"\"\n",
        "        self.val = 0\n",
        "        self.avg = 0\n",
        "        self.sum = 0\n",
        "        self.count = 0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        \"\"\"Record `val` as the latest value, counted `n` times in the running mean.\"\"\"\n",
        "        self.val = val\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count\n",
        "\n",
        "\n",
        "def accuracy(output, target, topk=(1,)):\n",
        "    \"\"\"Compute precision@k (in percent) for each k in `topk`.\n",
        "\n",
        "    `output` holds per-class scores of shape (batch, num_classes) and `target` the class\n",
        "    indices of shape (batch,). Returns a list with one 1-element tensor per k.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        maxk = max(topk)\n",
        "        batch_size = target.size(0)\n",
        "\n",
        "        # After the transpose, row j of `pred` holds every sample's j-th best guess.\n",
        "        _, pred = output.topk(maxk, 1, True, True)\n",
        "        pred = pred.t()\n",
        "        correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
        "\n",
        "        res = []\n",
        "        for k in topk:\n",
        "            # `correct` derives from a transpose and may be non-contiguous, in which case\n",
        "            # .view(-1) raises for k > 1; .reshape(-1) copies only when it has to.\n",
        "            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
        "            res.append(correct_k.mul_(100.0 / batch_size))\n",
        "        return res\n",
        "\n",
        "\n",
        "def train_network_with_poison_camou(net, target_img, poison_tuple_list, camou_tuple_list,\n",
        "                                    poisoned_dset, base_idx_list, camou_idx_list, testset,\n",
        "                                    poison_dict, camou_dict, savemodel=None,\n",
        "                                    epochs=60, lr=0.1, lr_milestones=(30, 45)):\n",
        "    \"\"\"Retrain only the last layer of `net` with the crafted images swapped into `poisoned_dset`,\n",
        "    then report how the target and the poisons are classified and the clean test accuracy.\n",
        "\n",
        "    Args:\n",
        "        net: model providing get_penultimate_params_list(), penultimate() and linear.\n",
        "        target_img: target image batch, passed directly to `net`.\n",
        "        poison_tuple_list, camou_tuple_list: sequences whose items' [0] is the replacement image.\n",
        "        poisoned_dset: training set yielding (image, label, dataset_index).\n",
        "        base_idx_list, camou_idx_list, savemodel: accepted for compatibility; not used.\n",
        "        testset: clean test set yielding (image, label, index).\n",
        "        poison_dict, camou_dict: dataset index -> position in poison_tuple_list / camou_tuple_list.\n",
        "        epochs: number of retraining epochs (must be >= 1).\n",
        "        lr: initial Adam learning rate.\n",
        "        lr_milestones: epochs at which the learning rate is multiplied by 0.1.\n",
        "\n",
        "    Returns:\n",
        "        dict with keys 'clean acc', 'prediction', 'poisons predictions', 'scores',\n",
        "        'malicious score' (the target's softmax score for `poison_class`) and 'camera' (always {}).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `epochs` < 1.\n",
        "\n",
        "    Uses the notebook globals `device`, `target_class` and `poison_class`.\n",
        "    \"\"\"\n",
        "    if epochs < 1:\n",
        "        raise ValueError(\"epochs must be >= 1\")\n",
        "\n",
        "    # Only the last-layer parameters are optimised; the feature extractor stays frozen.\n",
        "    params = net.get_penultimate_params_list()\n",
        "    print(\"Using Adam for retraining\")\n",
        "    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=0)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    poisoned_loader = torch.utils.data.DataLoader(poisoned_dset, batch_size=64, shuffle=True)\n",
        "    # The test set of clean CIFAR10\n",
        "    test_loader = torch.utils.data.DataLoader(testset, batch_size=500)\n",
        "\n",
        "    # Keep BatchNorm on its stored statistics: in train() mode the frozen feature extractor\n",
        "    # would still re-estimate its running mean/var on the poisoned batches.\n",
        "    net.eval()\n",
        "    for epoch in range(epochs):\n",
        "        loss_meter = AverageMeter()\n",
        "        acc_meter = AverageMeter()\n",
        "\n",
        "        if epoch in lr_milestones:\n",
        "            for param_group in optimizer.param_groups:\n",
        "                param_group['lr'] *= 0.1\n",
        "\n",
        "        for ite, (inputs, label, indices) in enumerate(poisoned_loader):\n",
        "            # Swap in the crafted images; a camouflage wins if an index is in both dicts.\n",
        "            for i, index in enumerate(indices):\n",
        "                index = int(index)\n",
        "                if index in camou_dict:\n",
        "                    inputs[i] = camou_tuple_list[camou_dict[index]][0]\n",
        "                elif index in poison_dict:\n",
        "                    inputs[i] = poison_tuple_list[poison_dict[index]][0]\n",
        "\n",
        "            inputs, label = inputs.to(device), label.to(device)\n",
        "\n",
        "            with torch.no_grad():\n",
        "                feat = net.penultimate(inputs)\n",
        "            output = net.linear(feat)\n",
        "            loss = criterion(output, label)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            prec1 = accuracy(output, label)[0]\n",
        "            loss_meter.update(loss.item(), inputs.size(0))\n",
        "            acc_meter.update(prec1.item(), inputs.size(0))\n",
        "\n",
        "            if (epoch == 0 or epoch == epochs - 1) and ite == len(poisoned_loader) - 1:\n",
        "                print(\"{2}, Epoch {0}, Iteration {1}, loss {loss.val:.3f} ({loss.avg:.3f}), \"\n",
        "                      \"acc {acc.val:.3f} ({acc.avg:.3f})\".\n",
        "                      format(epoch, ite, time.strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
        "                             loss=loss_meter, acc=acc_meter))\n",
        "                sys.stdout.flush()\n",
        "\n",
        "    # Classify the target and every poison with the retrained network.\n",
        "    with torch.no_grad():\n",
        "        target_logits = net(target_img.to(device))\n",
        "        target_scores = [float(s) for s in softmax(target_logits.view(-1).cpu().numpy())]\n",
        "        target_pred = target_logits.topk(1, 1, True, True)[1][0][0].item()\n",
        "        poison_pred_list = [net(poison_img[None, :, :, :].to(device)).topk(1, 1, True, True)[1].item()\n",
        "                            for poison_img, _ in poison_tuple_list]\n",
        "    print(\n",
        "        \"Target Label: {}, Poison label: {}, Prediction:{}, Target's Score:{}, Poisons' Predictions:{}\".format(\n",
        "            target_class, poison_class, target_pred, target_scores, poison_pred_list))\n",
        "\n",
        "    # Evaluate the results on the clean test set\n",
        "    val_acc_meter = AverageMeter()\n",
        "    with torch.no_grad():\n",
        "        for ite, (inputs, label, _) in enumerate(test_loader):\n",
        "            inputs, label = inputs.to(device), label.to(device)\n",
        "            prec1 = accuracy(net(inputs), label)[0]\n",
        "            val_acc_meter.update(prec1.item(), inputs.size(0))\n",
        "\n",
        "            if ite % 100 == 0 or ite == len(test_loader) - 1:\n",
        "                print(\"{2} Epoch {0}, Val iteration {1}, \"\n",
        "                      \"acc {acc.val:.3f} ({acc.avg:.3f})\".\n",
        "                      format(epoch, ite, time.strftime(\"%Y-%m-%d %H:%M:%S\"), acc=val_acc_meter))\n",
        "\n",
        "    print(\"* Prec: {}\".format(val_acc_meter.avg))\n",
        "\n",
        "    return {'clean acc': val_acc_meter.avg, 'prediction': target_pred,\n",
        "            'poisons predictions': poison_pred_list,\n",
        "            'scores': target_scores, 'malicious score': target_scores[poison_class], 'camera': {}}\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v4WKdu3ypTbV"
      },
      "outputs": [],
      "source": [
        "# Craft the camouflage images against the surrogate `model`, then retrain each victim with them.\n",
        "def run_code():\n",
        "    \"\"\"Craft camouflages with the convex-polytope attack and evaluate them on the victim networks.\n",
        "\n",
        "    Relies on notebook globals: model, model2, target_class, target_index, camou_class,\n",
        "    camou_index, budget, the poison_* hyper-parameters, tol, device, trainset, testset,\n",
        "    poison_tuple_list, base_idx_list, poison_dict, poison_class, fetch_target,\n",
        "    fetch_poison_bases and make_convex_polytope_poisons.\n",
        "\n",
        "    Returns:\n",
        "        (camou_tuple_list, camou_idx_list, camou_dict), where camou_dict maps each camouflage's\n",
        "        training-set index to its position in camou_idx_list (assumed to match the order of\n",
        "        camou_tuple_list).\n",
        "    \"\"\"\n",
        "    sub_net_list = [model]\n",
        "    print(\"subs nets, effective num: {}\".format(len(sub_net_list)))\n",
        "\n",
        "    print(\"Loading the victims networks\")\n",
        "    targets_net = [model2]\n",
        "\n",
        "    cifar_mean = (0.4914, 0.4822, 0.4465)\n",
        "    cifar_std = (0.2023, 0.1994, 0.2010)\n",
        "    transform_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(cifar_mean, cifar_std),\n",
        "    ])\n",
        "\n",
        "    # Get the target image\n",
        "    target = fetch_target(target_class, target_index, 50, subset=target_index,\n",
        "                          path=\"./data\", transforms=transform_test)\n",
        "\n",
        "    # <root>/mean/<poison_ites>/<target_index>\n",
        "    chk_path = os.path.join(\"./drive/MyDrive/Poisoning_Machine_Unlearning/\", 'mean',\n",
        "                            str(poison_ites), str(target_index))\n",
        "    os.makedirs(chk_path, exist_ok=True)\n",
        "    print(\"Path: {}\".format(chk_path))\n",
        "\n",
        "    # Fetch the base images for the camouflages (selected via subset=camou_index).\n",
        "    base_tensor_list, camou_idx_list = fetch_poison_bases(camou_class, budget, subset=camou_index,\n",
        "                                                          path='./data', transforms_img=transform_test)\n",
        "    base_tensor_list = [bt.to(device) for bt in base_tensor_list]\n",
        "    print(\"Selected base image indices: {}\".format(camou_idx_list))\n",
        "\n",
        "    # Crafting time only: the clock stops before the victims are retrained.\n",
        "    t = time.time()\n",
        "    camou_tuple_list = make_convex_polytope_poisons(sub_net_list, targets_net, base_tensor_list,\n",
        "                                                    target, device=device, opt_method=poison_opt,\n",
        "                                                    lr=poison_lr, momentum=poison_momentum,\n",
        "                                                    iterations=poison_ites, epsilon=poison_epsilon,\n",
        "                                                    decay_ites=poison_decay_ites,\n",
        "                                                    decay_ratio=poison_decay_ratio,\n",
        "                                                    mean=torch.Tensor(cifar_mean).reshape(1, 3, 1, 1),\n",
        "                                                    std=torch.Tensor(cifar_std).reshape(1, 3, 1, 1),\n",
        "                                                    chk_path=chk_path, poison_idxes=camou_idx_list,\n",
        "                                                    poison_label=camou_class,\n",
        "                                                    tol=tol,\n",
        "                                                    end2end=False,\n",
        "                                                    start_ite=0,\n",
        "                                                    poison_init=base_tensor_list,\n",
        "                                                    mode='convex')\n",
        "    tt = time.time()\n",
        "\n",
        "    print(\"Evaluating against victims networks\")\n",
        "    camou_dict = {train_idx: pos for pos, train_idx in enumerate(camou_idx_list)}\n",
        "\n",
        "    res = []\n",
        "    for tnet in targets_net:\n",
        "        print(type(tnet).__name__)\n",
        "        # Index lists passed by keyword so they cannot land in each other's parameter slot.\n",
        "        pred = train_network_with_poison_camou(tnet, target, poison_tuple_list, camou_tuple_list, trainset,\n",
        "                                               base_idx_list=base_idx_list, camou_idx_list=camou_idx_list,\n",
        "                                               testset=testset, poison_dict=poison_dict,\n",
        "                                               camou_dict=camou_dict)\n",
        "        res.append(pred)\n",
        "        print(\"--------\")\n",
        "\n",
        "    print(\"------SUMMARY------\")\n",
        "    print(\"TIME ELAPSED (mins): {}\".format(int((tt - t) / 60)))\n",
        "    print(\"TARGET INDEX: {}\".format(target_index))\n",
        "    for tnet, r in zip(targets_net, res):\n",
        "        # r is the result dict; 1 if the victim now predicts poison_class for the target.\n",
        "        print(type(tnet).__name__, int(r['prediction'] == poison_class))\n",
        "\n",
        "    return camou_tuple_list, camou_idx_list, camou_dict"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_Q4N7d2HiaaI",
        "outputId": "fc868630-81ac-4691-f962-1b449538d3d3"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[3950]"
            ]
          },
          "metadata": {},
          "execution_count": 48
        }
      ],
      "source": [
        "# Target selector passed to fetch_target; run_code also embeds str(target_index) in its checkpoint path.\n",
        "target_index"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "g3L9H4YlXT4X",
        "outputId": "6895fd78-989b-48dc-acfe-85a618611336"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "subs nets, effective num: 1\n",
            "Loading the victims networks\n",
            "fetch target\n",
            "Path: ./drive/MyDrive/Poisoning_Machine_Unlearning/mean/1000/[3950]\n",
            "fetch base\n",
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Selected base image indices: [6762, 39385, 40547, 47871, 48457]\n",
            "tensor([[0.3844, 0.1127, 0.5763, 1.3533, 0.0036, 0.1609, 0.5820, 0.1210, 0.4921,\n",
            "         0.3321, 1.0170, 0.0647, 0.2350, 0.8950, 0.3634, 0.2452, 0.2774, 0.2852,\n",
            "         1.4937, 0.1139, 0.2353, 0.5114, 0.1976, 0.8876, 0.5219, 0.2104, 1.4093,\n",
            "         1.4835, 0.4308, 0.6945, 0.4287, 0.1161, 0.1083, 0.1979, 0.1586, 0.6016,\n",
            "         0.2713, 0.5853, 0.0220, 1.5277, 1.1277, 0.1875, 1.1149, 1.9977, 1.0559,\n",
            "         0.8771, 0.1981, 0.2169, 0.3167, 0.3267, 0.8558, 1.6952, 0.1839, 0.1262,\n",
            "         0.3966, 0.1520, 0.6357, 1.2591, 0.0353, 0.1083, 0.2240, 0.0059, 0.0916,\n",
            "         0.4562, 0.2320, 0.1372, 0.2587, 0.0965, 1.3439, 0.0255, 0.6766, 0.3003,\n",
            "         0.0587, 1.6923, 0.3020, 0.3641, 0.1072, 1.2275, 0.0688, 1.1079, 0.3375,\n",
            "         0.3137, 0.4425, 0.7269, 0.1517, 1.3284, 0.2470, 0.0557, 1.4235, 0.7662,\n",
            "         1.3476, 0.1525, 0.7496, 0.3293, 0.3210, 0.3246, 0.0938, 0.7569, 1.5743,\n",
            "         0.0872, 0.2159, 1.7689, 0.7822, 0.4410, 0.2527, 0.1236, 0.7834, 0.1170,\n",
            "         0.2729, 1.2302, 0.4088, 0.0419, 0.8573, 0.2758, 0.9109, 0.3199, 0.2043,\n",
            "         1.5084, 1.7806, 0.0636, 0.1450, 0.6858, 0.0940, 0.7381, 0.1296, 0.5399,\n",
            "         0.4302, 0.0184, 1.8013, 0.1026, 0.2996, 0.6068, 1.4616, 0.3113, 1.4782,\n",
            "         0.7001, 1.4765, 0.4852, 0.3341, 0.0881, 0.1507, 1.9523, 0.0337, 1.6032,\n",
            "         0.1217, 1.0276, 1.4730, 0.3229, 0.0386, 1.0912, 0.8368, 0.1011, 0.6848,\n",
            "         0.0736, 0.6756, 0.1452, 0.4664, 1.4205, 0.4906, 0.0931, 0.2128, 0.2588,\n",
            "         0.2777, 0.2198, 0.9161, 0.1378, 0.2759, 0.2680, 1.1098, 0.0744, 0.0967,\n",
            "         1.7433, 1.4113, 0.2885, 0.9029, 0.1750, 1.3663, 0.6250, 1.2069, 0.3388,\n",
            "         0.1214, 0.3127, 0.9003, 0.4768, 0.0596, 1.3117, 0.6678, 0.3273, 0.2983,\n",
            "         0.0446, 0.2931, 0.3743, 1.5208, 0.1237, 0.5532, 0.4221, 0.2359, 0.0728,\n",
            "         1.0576, 0.5392, 0.6460, 1.4236, 0.2515, 0.4822, 0.2666, 0.5525, 0.3088,\n",
            "         0.7179, 0.2591, 0.8913, 0.4484, 0.8781, 1.5848, 1.0184, 0.3969, 1.6154,\n",
            "         1.1198, 1.0757, 0.5099, 0.0915, 0.0115, 1.6928, 0.0253, 0.4320, 1.1229,\n",
            "         1.7738, 0.9744, 0.1421, 0.0369, 0.2925, 1.3270, 0.2571, 0.7265, 0.3258,\n",
            "         0.3360, 0.2128, 0.7876, 0.0000, 0.0635, 1.6180, 0.5239, 1.7046, 0.2300,\n",
            "         0.1855, 1.5250, 0.3636, 0.8122, 0.9562, 0.2267, 0.1697, 2.3337, 0.3712,\n",
            "         0.2699, 0.3345, 1.0885, 0.8689, 0.0768, 0.5751, 1.5274, 0.0687, 0.9432,\n",
            "         0.0045, 0.3379, 0.3602, 0.8054, 0.2238, 0.3054, 0.2552, 0.5814, 1.4262,\n",
            "         0.5904, 0.0053, 0.6948, 0.0140, 0.1420, 0.9909, 0.8311, 1.0012, 0.4388,\n",
            "         0.1764, 1.2875, 0.4531, 1.4724, 0.0676, 0.2518, 0.3533, 1.2801, 1.5236,\n",
            "         0.3248, 0.8515, 0.2757, 0.2388, 0.4219, 0.5242, 0.5373, 0.2659, 0.9848,\n",
            "         0.5720, 0.2823, 0.2980, 0.4439, 0.3783, 0.6486, 0.5868, 0.3444, 1.1831,\n",
            "         0.1377, 0.7241, 0.2469, 0.0774, 0.5377, 0.2932, 0.6675, 0.1927, 0.2524,\n",
            "         0.3661, 0.7988, 0.3930, 0.0942, 0.4064, 1.3398, 0.1719, 0.0824, 1.0881,\n",
            "         0.1875, 0.0773, 0.7188, 0.7633, 0.3649, 0.2241, 1.2658, 1.2805, 0.3134,\n",
            "         0.3252, 0.1160, 0.6568, 1.7791, 0.5187, 0.0526, 0.3724, 0.2316, 2.2749,\n",
            "         0.7802, 1.6159, 0.3474, 0.2068, 0.1791, 0.1827, 0.0728, 1.1373, 0.0910,\n",
            "         0.8828, 0.8075, 1.6378, 1.3705, 0.0895, 0.1654, 0.0466, 1.0647, 0.2636,\n",
            "         0.0565, 0.2196, 0.5440, 1.5431, 0.7056, 0.9628, 1.3227, 1.3508, 0.1878,\n",
            "         1.1227, 0.8891, 0.1055, 1.6481, 0.4395, 1.0931, 0.5949, 0.2429, 0.3685,\n",
            "         1.7689, 0.7593, 1.5252, 1.1900, 1.7352, 1.6100, 0.0930, 0.3473, 0.2850,\n",
            "         0.2441, 0.4510, 0.3691, 0.0696, 0.0634, 0.7810, 0.6795, 0.3007, 0.1041,\n",
            "         0.4640, 0.7600, 1.7742, 0.2145, 0.4267, 0.1182, 0.4055, 0.4505, 0.1665,\n",
            "         1.0697, 0.7978, 0.1105, 0.6780, 0.1727, 1.0777, 0.9981, 0.2871, 0.6724,\n",
            "         0.1861, 0.3469, 0.1708, 1.3020, 1.6102, 0.2128, 0.3826, 0.2603, 1.2062,\n",
            "         0.4469, 0.1529, 0.9250, 0.9986, 0.1247, 1.2632, 1.4332, 1.9175, 0.9194,\n",
            "         1.3487, 0.7084, 1.2917, 0.8800, 0.0236, 2.2883, 0.7652, 1.0059, 0.3301,\n",
            "         1.5774, 1.5489, 0.5523, 0.5189, 0.6832, 0.4246, 0.4658, 0.2525, 0.1789,\n",
            "         0.1173, 0.8279, 1.1099, 0.6521, 0.0067, 0.0576, 1.3366, 0.6255, 0.1838,\n",
            "         0.0454, 0.2585, 0.0469, 0.4256, 0.9530, 0.8285, 0.7758, 0.4816, 0.3334,\n",
            "         0.5346, 1.4029, 1.4570, 0.8638, 0.9864, 1.6038, 1.7536, 0.3076, 0.9472,\n",
            "         0.0439, 0.0964, 0.6399, 0.7604, 0.1005, 0.6967, 1.2758, 1.3471, 0.2968,\n",
            "         0.1594, 0.2556, 0.1023, 0.2119, 1.0572, 0.0165, 1.4627, 0.3920, 0.7915,\n",
            "         0.4041, 1.2797, 0.2991, 0.2764, 0.2755, 1.0206, 0.1695, 0.0769, 1.3972,\n",
            "         2.1347, 0.9236, 1.4910, 0.7001, 1.8239, 0.8797, 0.9915, 1.6839]],\n",
            "       device='cuda:0')\n",
            " 2023-05-22 22:37:40 Iteration 0 \t Training Loss: 7.362e-02 \t Loss in Target Net: 2.975e-02\t  \n",
            " 2023-05-22 22:37:46 Iteration 50 \t Training Loss: 1.550e-03 \t Loss in Target Net: 1.502e-03\t  \n",
            " 2023-05-22 22:37:51 Iteration 100 \t Training Loss: 5.224e-04 \t Loss in Target Net: 5.138e-04\t  \n",
            " 2023-05-22 22:37:58 Iteration 150 \t Training Loss: 3.584e-04 \t Loss in Target Net: 3.954e-04\t  \n",
            " 2023-05-22 22:38:06 Iteration 200 \t Training Loss: 2.809e-04 \t Loss in Target Net: 3.058e-04\t  \n",
            " 2023-05-22 22:38:14 Iteration 250 \t Training Loss: 1.936e-04 \t Loss in Target Net: 1.821e-04\t  \n",
            " 2023-05-22 22:38:21 Iteration 300 \t Training Loss: 1.069e-04 \t Loss in Target Net: 1.037e-04\t  \n",
            " 2023-05-22 22:38:29 Iteration 350 \t Training Loss: 1.214e-04 \t Loss in Target Net: 1.187e-04\t  \n",
            " 2023-05-22 22:38:37 Iteration 400 \t Training Loss: 1.104e-04 \t Loss in Target Net: 1.048e-04\t  \n",
            " 2023-05-22 22:38:47 Iteration 450 \t Training Loss: 7.166e-05 \t Loss in Target Net: 8.904e-05\t  \n",
            " 2023-05-22 22:38:56 Iteration 500 \t Training Loss: 4.722e-05 \t Loss in Target Net: 4.672e-05\t  \n",
            " 2023-05-22 22:39:05 Iteration 550 \t Training Loss: 7.997e-05 \t Loss in Target Net: 7.238e-05\t  \n",
            " 2023-05-22 22:39:16 Iteration 600 \t Training Loss: 3.800e-05 \t Loss in Target Net: 3.308e-05\t  \n",
            " 2023-05-22 22:39:25 Iteration 650 \t Training Loss: 2.267e-05 \t Loss in Target Net: 2.637e-05\t  \n",
            " 2023-05-22 22:39:35 Iteration 700 \t Training Loss: 2.475e-05 \t Loss in Target Net: 5.392e-05\t  \n",
            " 2023-05-22 22:39:47 Iteration 750 \t Training Loss: 1.142e-04 \t Loss in Target Net: 1.219e-04\t  \n",
            " 2023-05-22 22:39:58 Iteration 800 \t Training Loss: 5.194e-05 \t Loss in Target Net: 4.266e-05\t  \n",
            " 2023-05-22 22:40:10 Iteration 850 \t Training Loss: 2.442e-05 \t Loss in Target Net: 2.991e-05\t  \n",
            " 2023-05-22 22:40:21 Iteration 900 \t Training Loss: 2.271e-05 \t Loss in Target Net: 3.001e-05\t  \n",
            " 2023-05-22 22:40:32 Iteration 950 \t Training Loss: 3.335e-05 \t Loss in Target Net: 2.967e-05\t  \n",
            " 2023-05-22 22:40:46 Iteration 999 \t Training Loss: 1.082e-03 \t Loss in Target Net: 1.545e-03\t  \n",
            "Evaluating against victims networks\n",
            "ResNet(\n",
            "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "  (layer1): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer2): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer3): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer4): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
            ")\n",
            "Using Adam for retraining\n",
            "2023-05-22 22:41:12, Epoch 0, Iteration 781, loss 1.509 (0.252), acc 93.750 (98.520)\n",
            "2023-05-22 23:05:48, Epoch 59, Iteration 781, loss 0.000 (0.070), acc 100.000 (99.936)\n",
            "Target Label: 8, Poison label: 5, Prediction:8, Target's Score:[0.0, 0.0, 0.0, 0.0, 0.0, 4.1245933971367776e-05, 0.0, 0.0, 0.9999587535858154, 0.0], Poisons' Predictions:[3, 8, 5, 5, 8]\n",
            "2023-05-22 23:05:48 Epoch 59, Val iteration 0, acc 88.800 (88.800)\n",
            "2023-05-22 23:05:53 Epoch 59, Val iteration 19, acc 86.400 (87.120)\n",
            "* Prec: 87.12000198364258\n",
            "--------\n",
            "------SUMMARY------\n",
            "TIME ELAPSED (mins): 3\n",
            "TARGET INDEX: [3950]\n",
            "ResNet(\n",
            "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "  (layer1): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer2): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer3): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (layer4): Sequential(\n",
            "    (0): BasicBlock(\n",
            "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential(\n",
            "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
            "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      )\n",
            "    )\n",
            "    (1): BasicBlock(\n",
            "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
            "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
            "      (shortcut): Sequential()\n",
            "    )\n",
            "  )\n",
            "  (linear): Linear(in_features=512, out_features=10, bias=True)\n",
            ") 0\n"
          ]
        }
      ],
      "source": [
        "camou_tuple_list, camou_idx_list, camou_dict = run_code()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Results and Seeds"
      ],
      "metadata": {
        "id": "WbXcKvtteaza"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zWfIaCLN5PDW"
      },
      "outputs": [],
      "source": [
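        "# Each result cell below records one run: seed, #poisons, #camouflages, target class, poison class, and an outcome tuple (x, y)\n",
        "#   x = 1 if the poison attack worked, 0 otherwise\n",
        "#   y = 1 if the camouflage worked, 0 otherwise\n",
        "# The second line lists: clean-model accuracy, poisoned-model accuracy, camouflage+poison-model accuracy\n",
        "# (note: some runs record accuracy as a fraction, e.g. 0.8768, others as a percentage, e.g. 88.400)"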
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# seed = 200000000, 5 poisons, 5 camou, target class = 0,  poison_class = 8 (1, 1)\n",
        "# 0.8768, 88.400, 83.83"
      ],
      "metadata": {
        "id": "X3gD1bMpfIpd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4wDLIZV6daoH"
      },
      "outputs": [],
      "source": [
        "# seed = 200000001, 5 poisons, 5 camou, target class = 0,  poison_class = 6 (1, 1)\n",
        "#0.8768, 90.800, 91.400"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HzMUPKARy8lM"
      },
      "outputs": [],
      "source": [
        "# seed = 200000011, 5 poisons, 5 camou, target class = 2,  poison_class = 5 (1, 1)\n",
        "#0.8758, 0.8759, 0.8732"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NWLXVyYYdlb_"
      },
      "outputs": [],
      "source": [
        "# seed = 200000111, 5 poisons, 5 camou, target class = 0,  poison_class = 9 (1, 0)\n",
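        "#0.8758, 0.8759, 0.8732\n",
        "# NOTE(review): these accuracies are identical to the seed 200000011 run although the outcome differs ((1, 1) vs (1, 0)); verify they were not copied by mistake"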
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vJuCG9bIdjGc"
      },
      "outputs": [],
      "source": [
        "# seed = 200001111, 5 poisons, 5 camou, target class = 1,  poison_class = 2 (1, 0)\n",
        "# 0.8758, 0.8759, 0.8800"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nMSXVq-2udvj"
      },
      "outputs": [],
      "source": [
        "# seed = 200011111 (1, 0)\n",
        "#0.8695"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHy0gWzz3pEW"
      },
      "outputs": [],
      "source": [
        "# seed = 200111111, 5 poisons, 5 camou, target class = 8, poison_class = 5 (1, 1)\n",
        "#0.8728, 0.868, 0.8707"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xEKouPTjw76l"
      },
      "outputs": [],
      "source": [
        "# seed = 201111111 (1, 0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W5zacyGCNG3I"
      },
      "outputs": [],
      "source": [
        "# seed = 211111111, 5 poisons, 5 camou, target class = 8, poison_class = 5 (1, 1)\n",
        "# 0.870, 0.8715, 0.8712"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ItS7l0AmJBpZ"
      },
      "outputs": [],
      "source": [
        "# seed = 221111111, 5 poisons, 5 camou, target class = 8,  poison_class = 5 (1, 1)\n",
        "# 0.8746, 0.8743, 0.8703"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EW1FXI29Jw3y"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}