{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4cf95ee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Optimizer\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torchvision\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import os\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import hypergrad as hg\n",
    "from itertools import repeat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c605af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from poi_util import poison_dataset,patching_test\n",
    "import poi_util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "76241913",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "def get_results(model, criterion, data_loader, device):\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (inputs, targets) in enumerate(data_loader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, targets.long())\n",
    "\n",
    "            val_loss += loss.item()\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "        return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ca879abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = {'small_VGG16': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],}\n",
    "drop_rate = [0.3,0.4,0.4]\n",
    "\n",
    "class VGG(nn.Module):\n",
    "    def __init__(self, vgg_name):\n",
    "        super(VGG, self).__init__()\n",
    "        self.features = self._make_layers(cfg[vgg_name])\n",
    "        self.classifier = nn.Linear(2048, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.features(x)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.classifier(out)\n",
    "        return out\n",
    "\n",
    "    def _make_layers(self, cfg):\n",
    "        layers = []\n",
    "        in_channels = 3\n",
    "        key = 0\n",
    "        for x in cfg:\n",
    "            if x == 'M':\n",
    "                layers += [nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "                           nn.Dropout(drop_rate[key])]\n",
    "                key += 1\n",
    "            else:\n",
    "                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),\n",
    "                           nn.BatchNorm2d(x),\n",
    "                           nn.ELU(inplace=True)]\n",
    "                in_channels = x\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "model = VGG('small_VGG16').to(device)\n",
    "outer_opt = torch.optim.Adam(params=model.parameters())\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e696263a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==> Preparing data..\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print('==> Preparing data..')\n",
    "from torchvision.datasets import CIFAR10\n",
    "root = './datasets'\n",
    "testset = CIFAR10(root, train=False, transform=None, download=True)\n",
    "x_test, y_test = testset.data, testset.targets\n",
    "x_test = x_test.astype('float32')/255\n",
    "y_test = np.asarray(y_test)\n",
    "\n",
    "attack_name = 'badnets'\n",
    "target_lab = '8'\n",
    "x_poi_test,y_poi_test= patching_test(x_test, y_test, attack_name, target_lab=target_lab)\n",
    "\n",
    "y_test = torch.Tensor(y_test.reshape((-1,)).astype(np.int))\n",
    "y_poi_test = torch.Tensor(y_poi_test.reshape((-1,)).astype(np.int))\n",
    "\n",
    "x_test = torch.Tensor(np.transpose(x_test,(0,3,1,2)))\n",
    "x_poi_test = torch.Tensor(np.transpose(x_poi_test,(0,3,1,2)))\n",
    "\n",
    "test_set = TensorDataset(x_test[5000:],y_test[5000:])\n",
    "unl_set = TensorDataset(x_test[:5000],y_test[:5000])\n",
    "att_val_set = TensorDataset(x_poi_test[:5000],y_poi_test[:5000])\n",
    "\n",
    "#data loader for verifying the clean test accuracy\n",
    "clnloader = torch.utils.data.DataLoader(\n",
    "    test_set, batch_size=200, shuffle=False, num_workers=2)\n",
    "\n",
    "#data loader for verifying the attack success rate\n",
    "poiloader_cln = torch.utils.data.DataLoader(\n",
    "    unl_set, batch_size=200, shuffle=False, num_workers=2)\n",
    "\n",
    "poiloader = torch.utils.data.DataLoader(\n",
    "    att_val_set, batch_size=200, shuffle=False, num_workers=2)\n",
    "\n",
    "#data loader for the unlearning step\n",
    "unlloader = torch.utils.data.DataLoader(\n",
    "    unl_set, batch_size=100, shuffle=False, num_workers=2)\n",
    "\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer',\n",
    "           'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c9cd26d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#define the inner loss L2\n",
    "def loss_inner(perturb, model_params):\n",
    "    images = images_list[0].cuda()\n",
    "    labels = labels_list[0].long().cuda()\n",
    "    per_img = torch.clamp(images+perturb[0],min=0,max=1)\n",
    "    per_logits = model.forward(per_img)\n",
    "    loss = F.cross_entropy(per_logits, labels, reduction='none')\n",
    "    loss_regu = torch.mean(-loss) +0.001*torch.pow(torch.norm(perturb[0]),2)\n",
    "    return loss_regu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "86863669",
   "metadata": {},
   "outputs": [],
   "source": [
    "#define the outer loss L1\n",
    "def loss_outer(perturb, model_params):\n",
    "    portion = 0.01\n",
    "    images, labels = images_list[batchnum].cuda(), labels_list[batchnum].long().cuda()\n",
    "    patching = torch.zeros_like(images, device='cuda')\n",
    "    number = images.shape[0]\n",
    "    rand_idx = random.sample(list(np.arange(number)),int(number*portion))\n",
    "    patching[rand_idx] = perturb[0]\n",
    "    unlearn_imgs = torch.clamp(images+patching,min=0,max=1)\n",
    "    logits = model(unlearn_imgs)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    loss = criterion(logits, labels)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ea973174",
   "metadata": {},
   "outputs": [],
   "source": [
    "images_list, labels_list = [], []\n",
    "for index, (images, labels) in enumerate(unlloader):\n",
    "    images_list.append(images)\n",
    "    labels_list.append(labels)\n",
    "inner_opt = hg.GradientDescent(loss_inner, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e58ab57c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b106445b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = './checkpoint/'+attack_name+'_'+str(target_lab)+'_02_'+'ckpt.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "53b0d206",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# initialize theta\n",
    "model = VGG('small_VGG16').to(device)\n",
    "outer_opt = torch.optim.Adam(params=model.parameters())\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "model.load_state_dict(torch.load(model_path)['net'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4cf11b11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original ACC: 0.8494\n",
      "Original ASR: 0.9764\n"
     ]
    }
   ],
   "source": [
    "ACC = get_results(model, criterion, clnloader, device)\n",
    "ASR = get_results(model, criterion, poiloader, device)\n",
    "\n",
    "print('Original ACC:', ACC)\n",
    "print('Original ASR:', ASR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c4f96e2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3988c6cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conducting Defence\n"
     ]
    }
   ],
   "source": [
    "#inner loop and optimization by batch computing\n",
    "import tqdm\n",
    "print(\"Conducting Defence\")\n",
    "\n",
    "model = VGG('small_VGG16').to(device)\n",
    "outer_opt = torch.optim.Adam(params=model.parameters())\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "model.load_state_dict(torch.load(model_path)['net'])\n",
    "model.eval()\n",
    "ASR_list = [get_results(model, criterion, poiloader, device)]\n",
    "ACC_list = [get_results(model, criterion, clnloader, device)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "46e055b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Round: 0\n",
      "ASR: 0.1232\n",
      "ACC: 0.8108\n"
     ]
    }
   ],
   "source": [
    "for round in range(1):\n",
    "    batch_pert = torch.zeros_like(x_test[:1], requires_grad=True, device='cuda')\n",
    "    batch_opt = torch.optim.SGD(params=[batch_pert],lr=10)\n",
    "#     batch_opt = torch.optim.Adam(params=[batch_pert],lr=10) #you can fine-tune this\n",
    "    \n",
    "    for images, labels in poiloader:\n",
    "        images = images.to(device)\n",
    "        ori_lab = torch.argmax(model.forward(images),axis = 1).long()\n",
    "        per_logits = model.forward(torch.clamp(images+batch_pert,min=0,max=1))\n",
    "        loss = F.cross_entropy(per_logits, ori_lab, reduction='mean')\n",
    "        loss_regu = torch.mean(-loss) +0.001*torch.pow(torch.norm(batch_pert),2)\n",
    "        batch_opt.zero_grad()\n",
    "        loss_regu.backward(retain_graph = True)\n",
    "        batch_opt.step()\n",
    "\n",
    "    #l2-ball\n",
    "    pert = batch_pert * min(1, 10 / torch.norm(batch_pert))\n",
    "\n",
    "    #unlearn step         \n",
    "    for batchnum in range(len(images_list)):\n",
    "        outer_opt.zero_grad()\n",
    "        hg.fixed_point(pert, list(model.parameters()), 5, inner_opt, loss_outer)\n",
    "        outer_opt.step()\n",
    "\n",
    "    ASR_list.append(get_results(model,criterion,poiloader,device))\n",
    "    ACC_list.append(get_results(model,criterion,clnloader,device))\n",
    "    print('Round:',round)\n",
    "    print('ASR:',get_results(model,criterion,poiloader,device))\n",
    "    print('ACC:',get_results(model,criterion,clnloader,device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "616de939",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
