{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from loader import Box\n",
    "from models.unet_model import UNet\n",
    "import cfg\n",
    "from torchvision.datasets import CIFAR10\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Experiment configuration, parsed from a fixed argv (one flag/value pair per line):\n",
    "# CIFAR-10, ResNet-18, 'ia' attack, all-to-one toward target label 0, on cuda:0.\n",
    "opt = cfg.get_arguments().parse_args([\n",
    "    '--dataset', 'cifar',\n",
    "    '--tlabel', '0',\n",
    "    '--model', 'resnet18',\n",
    "    '--attack', 'ia',\n",
    "    '--device', 'cuda:0',\n",
    "    '--size', '32',\n",
    "    '--num_classes', '10',\n",
    "    '--batch_size', '128',\n",
    "    '--attack_type', 'all-to-one',\n",
    "    '--root', '',\n",
    "])\n",
    "# Later cells rely on box.get_state_dict, box.get_transform, box.poisoned, box.device and box.tlabel.\n",
    "box = Box(opt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load Backdoor Information (param1, param2) and Backdoored Model (cls_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# param1 / param2: backdoor parameters later passed to box.poisoned;\n",
    "# cls_model: the backdoored classifier under evaluation.\n",
    "[param1, param2, cls_model] = box.get_state_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "test_tf = box.get_transform(train=\"test\")\n",
    "# Clean (un-poisoned) CIFAR-10 test split; shuffle=False keeps the batch order fixed across runs.\n",
    "cln_testset = CIFAR10(\n",
    "    root=\"./datasets\",\n",
    "    train=False,\n",
    "    transform=test_tf,\n",
    "    download=True,\n",
    ")\n",
    "cln_testloader = DataLoader(\n",
    "    dataset=cln_testset,\n",
    "    batch_size=opt.batch_size,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test Poisoned Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test Poisoned Samples: 100%|██████████| 79/79 [00:06<00:00, 12.82it/s, BA=94.07, ASR=99.41]\n"
     ]
    }
   ],
   "source": [
    "cls_model.eval()\n",
    "total_ba = 0\n",
    "total_asr = 0\n",
    "correct_ba = 0\n",
    "correct_asr = 0\n",
    "pbar = tqdm(cln_testloader, desc=\"Test Poisoned Samples\")\n",
    "# Evaluation only: torch.no_grad() stops autograd from recording a graph for every batch.\n",
    "with torch.no_grad():\n",
    "    for cln_imgs, labels in pbar:\n",
    "        cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)\n",
    "        poi_imgs = box.poisoned(cln_imgs, param1, param2)\n",
    "        cln_outputs = cls_model(cln_imgs)\n",
    "        poi_outputs = cls_model(poi_imgs)\n",
    "\n",
    "        _, cln_pred = cln_outputs.max(1)\n",
    "        _, poi_pred = poi_outputs.max(1)\n",
    "\n",
    "        # BA: accuracy of the model on the clean images, over every sample.\n",
    "        total_ba += labels.size(0)\n",
    "        correct_ba += (cln_pred == labels).sum().item()\n",
    "        # ASR: fraction of poisoned images classified as the target label, counted only\n",
    "        # over samples whose true label is not already the target label.\n",
    "        non_target = labels != box.tlabel\n",
    "        total_asr += non_target.sum().item()\n",
    "        correct_asr += (poi_pred[non_target] == box.tlabel).sum().item()\n",
    "\n",
    "        ba = 100. * correct_ba / total_ba\n",
    "        # max(..., 1) avoids ZeroDivisionError if the first batch holds only target-label samples.\n",
    "        asr = 100. * correct_asr / max(total_asr, 1)\n",
    "\n",
    "        pbar.set_postfix({\"BA\": \"{:.2f}\".format(ba), \"ASR\": \"{:.2f}\".format(asr)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BTI-DBF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test BTI-DBF: 100%|██████████| 79/79 [00:03<00:00, 22.76it/s, ASR=95.27]\n"
     ]
    }
   ],
   "source": [
    "cls_model.eval()\n",
    "total = 0\n",
    "correct_asr = 0\n",
    "inv_generator = UNet(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4)\n",
    "inv_generator.load_state_dict(torch.load(\"./inv_generator/inv_cifar10_ia_t0.pt\", map_location=\"cpu\"))\n",
    "inv_generator.to(box.device)\n",
    "inv_generator.eval()\n",
    "pbar = tqdm(cln_testloader, desc=\"Test BTI-DBF\")\n",
    "# Evaluation only: no gradients are needed through the generator or the classifier.\n",
    "with torch.no_grad():\n",
    "    for cln_imgs, labels in pbar:\n",
    "        cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)\n",
    "        # Map clean images through inv_generator (presumably the trigger-inversion generator)\n",
    "        # and measure how often the backdoored model then predicts the target label.\n",
    "        inv_imgs = inv_generator(cln_imgs)\n",
    "        inv_outputs = cls_model(inv_imgs)\n",
    "\n",
    "        _, inv_pred = inv_outputs.max(1)\n",
    "\n",
    "        # NOTE(review): unlike the other cells, this ASR counts *all* samples, including those\n",
    "        # whose true label already equals box.tlabel; kept as-is to match the reported number.\n",
    "        total += inv_pred.size(0)\n",
    "        correct_asr += (inv_pred == box.tlabel).sum().item()\n",
    "\n",
    "        asr = 100. * correct_asr / total\n",
    "\n",
    "        pbar.set_postfix({\"ASR\": \"{:.2f}\".format(asr)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BTI-DBF (U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test BTI-DBF (U): 100%|██████████| 79/79 [00:05<00:00, 15.51it/s, BA=91.35, ASR=1.83]\n"
     ]
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "total_ba = 0\n",
    "total_asr = 0\n",
    "correct_ba = 0\n",
    "correct_asr = 0\n",
    "pbar = tqdm(cln_testloader, desc=\"Test BTI-DBF (U)\")\n",
    "# Clone the backdoored model's architecture, then overwrite its weights with the unlearned checkpoint.\n",
    "unlearn_model = deepcopy(cls_model)\n",
    "unlearn_model.load_state_dict(torch.load(\"ul_model/unlearn_model.pt\", map_location=\"cpu\"))\n",
    "unlearn_model = unlearn_model.to(box.device)\n",
    "unlearn_model.eval()\n",
    "# Evaluation only: torch.no_grad() stops autograd from recording a graph for every batch.\n",
    "with torch.no_grad():\n",
    "    for cln_imgs, labels in pbar:\n",
    "        cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)\n",
    "        poi_imgs = box.poisoned(cln_imgs, param1, param2)\n",
    "        cln_outputs = unlearn_model(cln_imgs)\n",
    "        poi_outputs = unlearn_model(poi_imgs)\n",
    "\n",
    "        _, cln_pred = cln_outputs.max(1)\n",
    "        _, poi_pred = poi_outputs.max(1)\n",
    "\n",
    "        # BA: accuracy of the unlearned model on the clean images, over every sample.\n",
    "        total_ba += labels.size(0)\n",
    "        correct_ba += (cln_pred == labels).sum().item()\n",
    "        # ASR: poisoned images classified as the target label, counted only over\n",
    "        # samples whose true label is not already the target label.\n",
    "        non_target = labels != box.tlabel\n",
    "        total_asr += non_target.sum().item()\n",
    "        correct_asr += (poi_pred[non_target] == box.tlabel).sum().item()\n",
    "\n",
    "        ba = 100. * correct_ba / total_ba\n",
    "        # max(..., 1) avoids ZeroDivisionError if the first batch holds only target-label samples.\n",
    "        asr = 100. * correct_asr / max(total_asr, 1)\n",
    "\n",
    "        pbar.set_postfix({\"BA\": \"{:.2f}\".format(ba), \"ASR\": \"{:.2f}\".format(asr)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BTI-DBF (P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test BTI-DBF (P): 100%|██████████| 79/79 [00:06<00:00, 12.59it/s, BA=90.63, ASR=5.28]\n"
     ]
    }
   ],
   "source": [
    "cls_model.eval()\n",
    "total_ba = 0\n",
    "total_asr = 0\n",
    "correct_ba = 0\n",
    "correct_asr = 0\n",
    "pur_generator = UNet(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4)\n",
    "pur_generator.load_state_dict(torch.load(\"pur_generator/pur_cifar10_ia_t0.pt\", map_location=\"cpu\"))\n",
    "pur_generator.to(box.device)\n",
    "pur_generator.eval()\n",
    "pbar = tqdm(cln_testloader, desc=\"Test BTI-DBF (P)\")\n",
    "# Evaluation only: no gradients are needed through the generator or the classifier.\n",
    "with torch.no_grad():\n",
    "    for cln_imgs, labels in pbar:\n",
    "        cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)\n",
    "        poi_imgs = box.poisoned(cln_imgs, param1, param2)\n",
    "        # Both clean and poisoned inputs pass through pur_generator before classification.\n",
    "        cln_pur_imgs = pur_generator(cln_imgs)\n",
    "        poi_pur_imgs = pur_generator(poi_imgs)\n",
    "        cln_pur_outputs = cls_model(cln_pur_imgs)\n",
    "        poi_pur_outputs = cls_model(poi_pur_imgs)\n",
    "\n",
    "        _, cln_pur_pred = cln_pur_outputs.max(1)\n",
    "        _, poi_pur_pred = poi_pur_outputs.max(1)\n",
    "\n",
    "        # BA: accuracy on generator-processed clean images, over every sample.\n",
    "        total_ba += labels.size(0)\n",
    "        correct_ba += (cln_pur_pred == labels).sum().item()\n",
    "        # ASR: generator-processed poisoned images still classified as the target label,\n",
    "        # counted only over samples whose true label is not already the target label.\n",
    "        non_target = labels != box.tlabel\n",
    "        total_asr += non_target.sum().item()\n",
    "        correct_asr += (poi_pur_pred[non_target] == box.tlabel).sum().item()\n",
    "\n",
    "        ba = 100. * correct_ba / total_ba\n",
    "        # max(..., 1) avoids ZeroDivisionError if the first batch holds only target-label samples.\n",
    "        asr = 100. * correct_asr / max(total_asr, 1)\n",
    "\n",
    "        pbar.set_postfix({\"BA\": \"{:.2f}\".format(ba), \"ASR\": \"{:.2f}\".format(asr)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test Detection: 100%|██████████| 79/79 [00:07<00:00, 10.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision: 93.19705462138697\n",
      "Recall: 94.22222222222223\n",
      "F1 score: 93.70683463174761\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "tp = 0\n",
    "fp = 0\n",
    "tn = 0\n",
    "fn = 0\n",
    "total_poi = 0\n",
    "pbar = tqdm(cln_testloader, desc=\"Test Detection\")\n",
    "cls_model.eval()\n",
    "pur_generator.eval()\n",
    "# Detection rule: an input is flagged when passing it through pur_generator changes the\n",
    "# classifier's prediction. Poisoned inputs yield TP/FN, their clean counterparts FP/TN.\n",
    "with torch.no_grad():\n",
    "    for cln_img, targets in pbar:\n",
    "        cln_img, targets = cln_img.to(box.device), targets.to(box.device)\n",
    "        poi_img = box.poisoned(cln_img, param1, param2)\n",
    "        poi_outputs = cls_model(poi_img)\n",
    "        cln_outputs = cls_model(cln_img)\n",
    "        pur_poi_img = pur_generator(poi_img)\n",
    "        pur_cln_img = pur_generator(cln_img)\n",
    "        pur_poi_outputs = cls_model(pur_poi_img)\n",
    "        pur_cln_outputs = cls_model(pur_cln_img)\n",
    "\n",
    "        _, poi_pred = poi_outputs.max(1)\n",
    "        _, cln_pred = cln_outputs.max(1)\n",
    "        _, pur_poi_pred = pur_poi_outputs.max(1)\n",
    "        _, pur_cln_pred = pur_cln_outputs.max(1)\n",
    "\n",
    "        # Samples whose true label already equals the target label are excluded.\n",
    "        keep = targets != box.tlabel\n",
    "        n_keep = keep.sum().item()\n",
    "        poi_flagged = (poi_pred[keep] != pur_poi_pred[keep]).sum().item()\n",
    "        cln_flagged = (cln_pred[keep] != pur_cln_pred[keep]).sum().item()\n",
    "        total_poi += n_keep\n",
    "        tp += poi_flagged\n",
    "        fn += n_keep - poi_flagged\n",
    "        fp += cln_flagged\n",
    "        tn += n_keep - cln_flagged\n",
    "\n",
    "# Guard empty denominators (e.g. nothing flagged at all) instead of raising ZeroDivisionError.\n",
    "precision = 100. * tp / (tp + fp) if tp + fp else 0.\n",
    "recall = 100. * tp / (tp + fn) if tp + fn else 0.\n",
    "f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.\n",
    "print(f\"Precision: {precision}\")\n",
    "print(f\"Recall: {recall}\")\n",
    "print(f\"F1 score: {f1_score}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "backdoor",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
