{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ebd66700",
   "metadata": {},
   "source": [
    "## Demo_Quality\n",
    "This demo evaluates the image quality of poisoned samples by comparing them with their clean counterparts using SSIM and FID.\n",
    "\n",
    "To run this demo from scratch, you first need to generate a BadNet attack result by running the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b950f4fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Launch the attack script with the interpreter that backs this kernel, so the\n",
    "# script sees the same installed packages as the notebook. A bare `python` is\n",
    "# resolved through PATH and may point at a different environment.\n",
    "! \"{sys.executable}\" ../../attack/badnet.py --save_folder_name badnet_demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f81f973",
   "metadata": {},
   "source": [
    "Alternatively, run the following command in a terminal from the project root directory:\n",
    "\n",
    "```bash\n",
    "python attack/badnet.py --save_folder_name badnet_demo\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87bd9f5a",
   "metadata": {},
   "source": [
    "### Step 1: Import modules and set arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "71b7087b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import yaml\n",
    "import torch\n",
    "import shap\n",
    "import numpy as np\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "sys.path.append(\"../../\")\n",
    "sys.path.append(os.getcwd())\n",
    "from visual_utils import *\n",
    "from utils.aggregate_block.dataset_and_transform_generate import (\n",
    "    get_transform,\n",
    "    get_dataset_denormalization,\n",
    ")\n",
    "from utils.aggregate_block.fix_random import fix_random\n",
    "from utils.aggregate_block.model_trainer_generate import generate_cls_model\n",
    "from utils.save_load_attack import load_attack_result\n",
    "from utils.defense_utils.dbd.model.utils import (\n",
    "    get_network_dbd,\n",
    "    load_state,\n",
    "    get_criterion,\n",
    "    get_optimizer,\n",
    "    get_scheduler,\n",
    ")\n",
    "from utils.defense_utils.dbd.model.model import SelfModel, LinearModel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fb719c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the argument namespace for this demo.\n",
    "args = get_args(True)\n",
    "\n",
    "# ----- Demo-only overrides -----\n",
    "# Prefix the YAML path with ../../ and point at the badnet_demo attack result.\n",
    "args.yaml_path = \"../../\" + args.yaml_path\n",
    "args.result_file_attack = \"badnet_demo\"\n",
    "# ----- End of demo-only overrides -----\n",
    "\n",
    "# Read the YAML configuration, then overlay every argument that is not None,\n",
    "# so explicitly set arguments take precedence over the YAML values.\n",
    "with open(args.yaml_path, \"r\") as yaml_file:\n",
    "    config = yaml.safe_load(yaml_file)\n",
    "explicit_args = {name: value for name, value in vars(args).items() if value is not None}\n",
    "config.update(explicit_args)\n",
    "args.__dict__ = config\n",
    "\n",
    "args = preprocess_args(args)\n",
    "fix_random(int(args.random_seed))\n",
    "\n",
    "# Folder that holds the saved attack result loaded in the next step.\n",
    "save_path_attack = \"../..//record/\" + args.result_file_attack"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f959b510",
   "metadata": {},
   "source": [
    "### Step 2: Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b8b67ac9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:save_path MUST have 'record' in its abspath, and data_path in attack result MUST have 'data' in its path\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "loading...\n",
      "max_num_samples is given, use sample number limit now.\n",
      "subset bd dataset with length:  5000\n",
      "Create visualization dataset with \n",
      " \t Dataset: bd_train \n",
      " \t Number of samples: 5000  \n",
      " \t Selected classes: [0 1 2 3 4 5 6 7 8 9]\n"
     ]
    }
   ],
   "source": [
    "# Load result\n",
    "result_attack = load_attack_result(save_path_attack + \"/attack_result.pt\")\n",
    "selected_classes = np.arange(args.num_classes)\n",
    "\n",
    "# Select classes to visualize: with more than c_sub classes, keep the target\n",
    "# class plus c_sub-1 other classes drawn at random.\n",
    "if args.num_classes > args.c_sub:\n",
    "    selected_classes = np.delete(selected_classes, args.target_class)\n",
    "    selected_classes = np.random.choice(selected_classes, args.c_sub - 1, replace=False)\n",
    "    selected_classes = np.append(selected_classes, args.target_class)\n",
    "\n",
    "# keep the same transforms for train and test dataset for better visualization\n",
    "result_attack[\"clean_train\"].wrap_img_transform = result_attack[\"clean_test\"].wrap_img_transform\n",
    "result_attack[\"bd_train\"].wrap_img_transform = result_attack[\"bd_test\"].wrap_img_transform\n",
    "\n",
    "# Create dataset\n",
    "args.visual_dataset = 'bd_train'\n",
    "if args.visual_dataset == 'mixed':\n",
    "    bd_test_with_trans = result_attack[\"bd_test\"]\n",
    "    visual_dataset = generate_mix_dataset(bd_test_with_trans, args.target_class, args.pratio, selected_classes, max_num_samples=args.n_sub)\n",
    "elif args.visual_dataset == 'clean_train':\n",
    "    clean_train_with_trans = result_attack[\"clean_train\"]\n",
    "    visual_dataset = generate_clean_dataset(clean_train_with_trans, selected_classes, max_num_samples=args.n_sub)\n",
    "elif args.visual_dataset == 'clean_test':\n",
    "    clean_test_with_trans = result_attack[\"clean_test\"]\n",
    "    visual_dataset = generate_clean_dataset(clean_test_with_trans, selected_classes, max_num_samples=args.n_sub)\n",
    "elif args.visual_dataset == 'bd_train':\n",
    "    bd_train_with_trans = result_attack[\"bd_train\"]\n",
    "    visual_dataset = generate_bd_dataset(bd_train_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)\n",
    "elif args.visual_dataset == 'bd_test':\n",
    "    bd_test_with_trans = result_attack[\"bd_test\"]\n",
    "    visual_dataset = generate_bd_dataset(bd_test_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)\n",
    "else:\n",
    "    # Raise instead of `assert False`: asserts are stripped under `python -O`,\n",
    "    # and the message now names the option that is actually checked.\n",
    "    raise ValueError(f'Illegal visual_dataset: {args.visual_dataset!r}')\n",
    "\n",
    "print(f'Create visualization dataset with \\n \\t Dataset: {args.visual_dataset} \\n \\t Number of samples: {len(visual_dataset)}  \\n \\t Selected classes: {selected_classes}')\n",
    "\n",
    "# Create data loader\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    visual_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False\n",
    ")\n",
    "\n",
    "# Create denormalization function. Start from the identity so the name is always\n",
    "# bound (and never left over from an earlier kernel run) when the transform\n",
    "# pipeline contains no Normalize step.\n",
    "# NOTE(review): identity assumes un-normalized images are already in the range\n",
    "# the quality metrics expect - confirm for datasets without Normalize.\n",
    "denormalizer = lambda img: img\n",
    "for trans_t in data_loader.dataset.wrap_img_transform.transforms:\n",
    "    if isinstance(trans_t, transforms.Normalize):\n",
    "        denormalizer = get_dataset_denormalization(trans_t)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67cbfec4",
   "metadata": {},
   "source": [
    "### Step 3: SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39104beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics import StructuralSimilarityIndexMeasure\n",
    "\n",
    "# Positions (within visual_dataset) of the samples flagged as poisoned.\n",
    "visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))\n",
    "bd_idx = np.where(visual_poison_indicator == 1)[0]\n",
    "\n",
    "ssim = StructuralSimilarityIndexMeasure()\n",
    "ssim_list = []\n",
    "if bd_idx.size > 0:\n",
    "    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')\n",
    "    # Index the dataset with the positions stored in bd_idx (not the loop counter),\n",
    "    # so only poisoned samples are compared with their clean versions.\n",
    "    for sample_idx in bd_idx.tolist():\n",
    "        bd_sample = denormalizer(visual_dataset[sample_idx][0]).unsqueeze(0)\n",
    "        # Presumably temporary_all_clean makes the dataset return clean images\n",
    "        # inside the block; verify against visual_utils.\n",
    "        with temporary_all_clean(visual_dataset):\n",
    "            clean_sample = denormalizer(visual_dataset[sample_idx][0]).unsqueeze(0)\n",
    "        ssim_list.append(ssim(bd_sample, clean_sample).item())\n",
    "    print(f'Average SSIM: {np.mean(ssim_list)}')\n",
    "else:\n",
    "    # Avoid taking the mean of an empty list, which would print nan with a warning.\n",
    "    print('No poisoned samples found; SSIM was not computed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c2b0104",
   "metadata": {},
   "source": [
    "### Step 4: FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57497927",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "\n",
    "# Positions (within visual_dataset) of the samples flagged as poisoned.\n",
    "visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))\n",
    "bd_idx = np.where(visual_poison_indicator == 1)[0]\n",
    "\n",
    "# normalize=True tells torchmetrics that the inputs are float images in [0, 1].\n",
    "# NOTE(review): assumes denormalizer returns images in that range - confirm.\n",
    "fid = FrechetInceptionDistance(feature=64, normalize = True)\n",
    "if bd_idx.size > 0:\n",
    "    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')\n",
    "    # Index the dataset with the positions stored in bd_idx (not the loop counter),\n",
    "    # so the two distributions are built from poisoned samples and their clean versions.\n",
    "    for sample_idx in bd_idx.tolist():\n",
    "        bd_sample = denormalizer(visual_dataset[sample_idx][0]).unsqueeze(0)\n",
    "        with temporary_all_clean(visual_dataset):\n",
    "            clean_sample = denormalizer(visual_dataset[sample_idx][0]).unsqueeze(0)\n",
    "        fid.update(clean_sample, real=True)\n",
    "        fid.update(bd_sample, real=False)\n",
    "    fid_value = fid.compute().numpy()\n",
    "    print(f'FID: {fid_value}')\n",
    "else:\n",
    "    # fid_value is only defined when at least one poisoned sample was processed.\n",
    "    print('No poisoned samples found; FID was not computed.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "870cf186",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "6869619afde5ccaa692f7f4d174735a0f86b1f7ceee086952855511b0b6edec0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
