{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "852e9274",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Optimizer\n",
    "import torch.backends.cudnn as cudnn\n",
    "import tqdm\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import TensorDataset, DataLoader,Subset\n",
    "import torchvision.models as models\n",
    "import torch.nn.functional as F\n",
    "from models import *\n",
    "\n",
    "import os\n",
    "import copy\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import cv2 as cv\n",
    "from util import *\n",
    "\n",
    "random_seed = 0\n",
    "np.random.seed(random_seed)\n",
    "random.seed(random_seed)\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "torch.cuda.set_device(0)\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "923b33b2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "'''\n",
    "The path for target dataset and public out-of-distribution (POOD) dataset. The setting used \n",
    "here is CIFAR-10 as the target dataset and Tiny-ImageNet as the POOD dataset. Their directory\n",
    "structure is as follows:\n",
    "\n",
    "dataset_path--cifar-10-batches-py\n",
    "            |\n",
    "            |-tiny-imagenet-200\n",
    "'''\n",
    "dataset_path = '/home/minzhou/data/'\n",
    "\n",
    "#The target class label\n",
    "lab = 2\n",
    "\n",
    "#Noise size, default is full image size\n",
    "noise_size = 32\n",
    "\n",
    "#Radius of the L-inf ball\n",
    "l_inf_r = 16/255\n",
    "\n",
    "#Model for generating surrogate model and trigger\n",
    "surrogate_model = ResNet18_201().cuda()\n",
    "generating_model = ResNet18_201().cuda()\n",
    "\n",
    "#Surrogate model training epochs\n",
    "surrogate_epochs = 200\n",
    "\n",
    "#Learning rate for poison-warm-up\n",
    "generating_lr_warmup = 0.1\n",
    "warmup_round = 5\n",
    "\n",
    "#Learning rate for trigger generating\n",
    "generating_lr_tri = 0.01      \n",
    "gen_round = 1000\n",
    "\n",
    "#Training batch size\n",
    "train_batch_size = 350\n",
    "\n",
    "#The model for adding the noise\n",
    "patch_mode = 'add'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4adeef42",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d533122a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#The argumention use for surrogate model training stage\n",
    "transform_surrogate_train = transforms.Compose([\n",
    "    transforms.Resize(32),\n",
    "    transforms.RandomCrop(32, padding=4),  \n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "#The argumention use for all training set\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),  \n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "#The argumention use for all testing set\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a3e2df",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "ori_train = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=False, transform=transform_train)\n",
    "ori_test = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=False, transform=transform_test)\n",
    "outter_trainset = torchvision.datasets.ImageFolder(root=dataset_path + 'tiny-imagenet-200/train/', transform=transform_surrogate_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48a8f73b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Outter train dataset\n",
    "train_label = [get_labels(ori_train)[x] for x in range(len(get_labels(ori_train)))]\n",
    "test_label = [get_labels(ori_test)[x] for x in range(len(get_labels(ori_test)))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8760de53",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Inner train dataset\n",
    "train_target_list = list(np.where(np.array(train_label)==lab)[0])\n",
    "train_target = Subset(ori_train,train_target_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9201691b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "concoct_train_dataset = concoct_dataset(train_target,outter_trainset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d5d50c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "surrogate_loader = torch.utils.data.DataLoader(concoct_train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=16)\n",
    "\n",
    "poi_warm_up_loader = torch.utils.data.DataLoader(train_target, batch_size=train_batch_size, shuffle=True, num_workers=16)\n",
    "\n",
    "trigger_gen_loaders = torch.utils.data.DataLoader(train_target, batch_size=train_batch_size, shuffle=True, num_workers=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb29f99b",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#  Training surrogate modle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b9ce8c1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Batch_grad\n",
    "condition = True\n",
    "noise = torch.zeros((1, 3, noise_size, noise_size), device=device)\n",
    "\n",
    "\n",
    "surrogate_model = surrogate_model\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "# outer_opt = torch.optim.RAdam(params=base_model.parameters(), lr=generating_lr_outer)\n",
    "surrogate_opt = torch.optim.SGD(params=surrogate_model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
    "surrogate_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(surrogate_opt, T_max=surrogate_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe7d6ffe",
   "metadata": {
    "scrolled": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Training the surrogate model\n",
    "print('Training the surrogate model')\n",
    "for epoch in range(0, surrogate_epochs):\n",
    "    surrogate_model.train()\n",
    "    loss_list = []\n",
    "    for images, labels in surrogate_loader:\n",
    "        images, labels = images.cuda(), labels.cuda()\n",
    "        surrogate_opt.zero_grad()\n",
    "        outputs = surrogate_model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        loss_list.append(float(loss.data))\n",
    "        surrogate_opt.step()\n",
    "    surrogate_scheduler.step()\n",
    "    ave_loss = np.average(np.array(loss_list))\n",
    "    print('Epoch:%d, Loss: %.03f' % (epoch, ave_loss))\n",
    "#Save the surrogate model\n",
    "save_path = './checkpoint/surrogate_pretrain_' + str(surrogate_epochs) +'.pth'\n",
    "torch.save(surrogate_model.state_dict(),save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cff016a4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Poison warm up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d6be24",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Prepare models and optimizers for poi_warm_up training\n",
    "poi_warm_up_model = generating_model\n",
    "poi_warm_up_model.load_state_dict(surrogate_model.state_dict())\n",
    "\n",
    "poi_warm_up_opt = torch.optim.RAdam(params=poi_warm_up_model.parameters(), lr=generating_lr_warmup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5174da4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Poi_warm_up stage\n",
    "poi_warm_up_model.train()\n",
    "for param in poi_warm_up_model.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "#Training the surrogate model\n",
    "for epoch in range(0, warmup_round):\n",
    "    poi_warm_up_model.train()\n",
    "    loss_list = []\n",
    "    for images, labels in poi_warm_up_loader:\n",
    "        images, labels = images.cuda(), labels.cuda()\n",
    "        poi_warm_up_model.zero_grad()\n",
    "        poi_warm_up_opt.zero_grad()\n",
    "        outputs = poi_warm_up_model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward(retain_graph = True)\n",
    "        loss_list.append(float(loss.data))\n",
    "        poi_warm_up_opt.step()\n",
    "    ave_loss = np.average(np.array(loss_list))\n",
    "    print('Epoch:%d, Loss: %e' % (epoch, ave_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a27fab21",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Trigger generating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b156e248",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Trigger generating stage\n",
    "for param in poi_warm_up_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "batch_pert = torch.autograd.Variable(noise.cuda(), requires_grad=True)\n",
    "batch_opt = torch.optim.RAdam(params=[batch_pert],lr=generating_lr_tri)\n",
    "for minmin in tqdm.notebook.tqdm(range(gen_round)):\n",
    "    loss_list = []\n",
    "    for images, labels in trigger_gen_loaders:\n",
    "        images, labels = images.cuda(), labels.cuda()\n",
    "        new_images = torch.clone(images)\n",
    "        clamp_batch_pert = torch.clamp(batch_pert,-l_inf_r*2,l_inf_r*2)\n",
    "        new_images = torch.clamp(apply_noise_patch(clamp_batch_pert,new_images.clone(),mode=patch_mode),-1,1)\n",
    "        per_logits = poi_warm_up_model.forward(new_images)\n",
    "        loss = criterion(per_logits, labels)\n",
    "        loss_regu = torch.mean(loss)\n",
    "        batch_opt.zero_grad()\n",
    "        loss_list.append(float(loss_regu.data))\n",
    "        loss_regu.backward(retain_graph = True)\n",
    "        batch_opt.step()\n",
    "    ave_loss = np.average(np.array(loss_list))\n",
    "    ave_grad = np.sum(abs(batch_pert.grad).detach().cpu().numpy())\n",
    "    print('Gradient:',ave_grad,'Loss:', ave_loss)\n",
    "    if ave_grad == 0:\n",
    "        break\n",
    "\n",
    "noise = torch.clamp(batch_pert,-l_inf_r*2,l_inf_r*2)\n",
    "best_noise = noise.clone().detach().cpu()\n",
    "plt.imshow(np.transpose(noise[0].detach().cpu(),(1,2,0)))\n",
    "plt.show()\n",
    "print('Noise max val:',noise.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4065f896",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Save the trigger\n",
    "import time\n",
    "save_name = './checkpoint/best_noise'+'_'+ time.strftime(\"%m-%d-%H_%M_%S\",time.localtime(time.time())) \n",
    "np.save(save_name, best_noise)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b12b8499",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Testing  attack effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9902b993",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Using this block if you only want to test the attack result.\n",
    "# import imageio\n",
    "# import cv2 as cv\n",
    "# best_noise = torch.zeros((1, 3, noise_size, noise_size), device=device)\n",
    "# noise_npy = np.load('./checkpoint/resnet18_trigger.npy')\n",
    "# best_noise = torch.from_numpy(noise_npy).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab66953",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Poisoning amount use for the target class\n",
    "poison_amount = 25\n",
    "\n",
    "#Model uses for testing\n",
    "noise_testing_model = ResNet18().cuda()    \n",
    "\n",
    "#Training parameters\n",
    "training_epochs = 200\n",
    "training_lr = 0.1\n",
    "test_batch_size = 150\n",
    "\n",
    "#The multiple of noise amplification during testing\n",
    "multi_test = 3\n",
    "\n",
    "#random seed for testing stage\n",
    "random_seed = 65"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e2fd55",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "np.random.seed(random_seed)\n",
    "random.seed(random_seed)\n",
    "torch.manual_seed(random_seed)\n",
    "model = noise_testing_model\n",
    "\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=training_lr, momentum=0.9, weight_decay=5e-4)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4075516",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "transform_tensor = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "poi_ori_train = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=False, transform=transform_tensor)\n",
    "poi_ori_test = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=False, transform=transform_tensor)\n",
    "transform_after_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),  \n",
    "    transforms.RandomHorizontalFlip(),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "205f61b5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Poison traing\n",
    "random_poison_idx = random.sample(train_target_list, poison_amount)\n",
    "poison_train_target = poison_image(poi_ori_train,random_poison_idx,best_noise.cpu(),transform_after_train)\n",
    "print('Traing dataset size is:',len(poison_train_target),\" Poison numbers is:\",len(random_poison_idx))\n",
    "clean_train_loader = DataLoader(poison_train_target, batch_size=test_batch_size, shuffle=True, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b27b061",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Attack success rate testing\n",
    "test_non_target = list(np.where(np.array(test_label)!=lab)[0])\n",
    "test_non_target_change_image_label = poison_image_label(poi_ori_test,test_non_target,best_noise.cpu()*multi_test,lab,None)\n",
    "asr_loaders = torch.utils.data.DataLoader(test_non_target_change_image_label, batch_size=test_batch_size, shuffle=True, num_workers=2)\n",
    "print('Poison test dataset size is:',len(test_non_target_change_image_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f60a531",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Clean acc test dataset\n",
    "clean_test_loader = torch.utils.data.DataLoader(ori_test, batch_size=test_batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e49ab88e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Target clean test dataset\n",
    "test_target = list(np.where(np.array(test_label)==lab)[0])\n",
    "target_test_set = Subset(ori_test,test_target)\n",
    "target_test_loader = torch.utils.data.DataLoader(target_test_set, batch_size=test_batch_size, shuffle=True, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6760be1a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from util import AverageMeter\n",
    "train_ACC = []\n",
    "test_ACC = []\n",
    "clean_ACC = []\n",
    "target_ACC = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "722d3cd3",
   "metadata": {
    "scrolled": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "for epoch in tqdm.notebook.tqdm(range(training_epochs)):\n",
    "    # Train\n",
    "    model.train()\n",
    "    acc_meter = AverageMeter()\n",
    "    loss_meter = AverageMeter()\n",
    "    pbar = tqdm.notebook.tqdm(clean_train_loader, total=len(clean_train_loader))\n",
    "    for images, labels in pbar:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        model.zero_grad()\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(images)\n",
    "        loss = criterion(logits, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        _, predicted = torch.max(logits.data, 1)\n",
    "        acc = (predicted == labels).sum().item()/labels.size(0)\n",
    "        acc_meter.update(acc)\n",
    "        loss_meter.update(loss.item())\n",
    "        pbar.set_description(\"Acc %.2f Loss: %.2f\" % (acc_meter.avg*100, loss_meter.avg))\n",
    "    train_ACC.append(acc_meter.avg)\n",
    "    print('Train_loss:',loss)\n",
    "    scheduler.step()\n",
    "    \n",
    "    # Testing attack effect\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    for i, (images, labels) in enumerate(asr_loaders):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = model(images)\n",
    "            out_loss = criterion(logits,labels)\n",
    "            _, predicted = torch.max(logits.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    acc = correct / total\n",
    "    test_ACC.append(acc)\n",
    "    print('\\nAttack success rate %.2f' % (acc*100))\n",
    "    print('Test_loss:',out_loss)\n",
    "    \n",
    "    correct_clean, total_clean = 0, 0\n",
    "    for i, (images, labels) in enumerate(clean_test_loader):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = model(images)\n",
    "            out_loss = criterion(logits,labels)\n",
    "            _, predicted = torch.max(logits.data, 1)\n",
    "            total_clean += labels.size(0)\n",
    "            correct_clean += (predicted == labels).sum().item()\n",
    "    acc_clean = correct_clean / total_clean\n",
    "    clean_ACC.append(acc_clean)\n",
    "    print('\\nTest clean Accuracy %.2f' % (acc_clean*100))\n",
    "    print('Test_loss:',out_loss)\n",
    "    \n",
    "    correct_tar, total_tar = 0, 0\n",
    "    for i, (images, labels) in enumerate(target_test_loader):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = model(images)\n",
    "            out_loss = criterion(logits,labels)\n",
    "            _, predicted = torch.max(logits.data, 1)\n",
    "            total_tar += labels.size(0)\n",
    "            correct_tar += (predicted == labels).sum().item()\n",
    "    acc_tar = correct_tar / total_tar\n",
    "    target_ACC.append(acc_tar)\n",
    "    print('\\nTarget test clean Accuracy %.2f' % (acc_tar*100))\n",
    "    print('Test_loss:',out_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa785df0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#ours -- higher_configureations\n",
    "from matplotlib import pyplot as plt\n",
    "half = np.arange(0,training_epochs)\n",
    "plt.figure(figsize=(12.5,8))\n",
    "plt.plot(half, np.asarray(train_ACC)[half], label='Training ACC', linestyle=\"-.\", marker=\"o\", linewidth=3.0, markersize = 8)\n",
    "plt.plot(half, np.asarray(test_ACC)[half], label='Attack success rate', linestyle=\"-.\", marker=\"o\", linewidth=3.0, markersize = 8)\n",
    "plt.plot(half, np.asarray(clean_ACC)[half], label='Clean test ACC', linestyle=\"-.\", marker=\"o\", linewidth=3.0, markersize = 8)\n",
    "plt.plot(half, np.asarray(target_ACC)[half], label='Target class clean test ACC', linestyle=\"-\", marker=\"o\", linewidth=3.0, markersize = 8)\n",
    "# plt.plot(half, np.asarray(test_unl_ACC)[half], label='protected test ACC', linestyle=\"-.\", marker=\"o\", linewidth=3.0, markersize = 8)\n",
    "plt.ylabel('ACC', fontsize=24)\n",
    "plt.xticks(fontsize=20)\n",
    "plt.xlabel('Epoches', fontsize=24)\n",
    "plt.yticks(np.arange(0,1.1, 0.1),fontsize=20)\n",
    "plt.legend(fontsize=20,bbox_to_anchor=(1.016, 1.2),ncol=2)\n",
    "plt.grid(color=\"gray\", linestyle=\"-\")\n",
    "plt.show()\n",
    "\n",
    "dis_idx = clean_ACC.index(max(clean_ACC))\n",
    "print(train_ACC[dis_idx])\n",
    "print('attack',test_ACC[dis_idx])\n",
    "print(clean_ACC.index(max(clean_ACC)))\n",
    "print('all class clean', clean_ACC[dis_idx])\n",
    "print('target clean',target_ACC[dis_idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}