{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ad572f07-0e37-4908-9f78-9b0f2e9145b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/RAID5/DataStorage/xiongren/py38/lib/python3.8/site-packages/mmcv/__init__.py:20: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current working directory: /RAID5/DataStorage/xiongren/github/InputIBA\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/RAID5/DataStorage/xiongren/py38/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/RAID5/DataStorage/xiongren/py38/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import sys\n",
    "try:\n",
    "    import input_iba\n",
    "except ModuleNotFoundError:\n",
    "    sys.path.insert(0, '..')\n",
    "    import input_iba\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "from input_iba.models import build_attributor\n",
    "from input_iba.datasets import build_dataset, build_pipeline\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import mmcv\n",
    "\n",
    "\n",
    "from input_iba.datasets import build_dataset\n",
    "from input_iba.evaluation import VisionSensitivityN\n",
    "from input_iba.models import build_classifiers\n",
    "from input_iba.utils import get_valid_set\n",
    "import matplotlib.colors as colors\n",
    "import cv2\n",
    "\n",
    "\n",
    "import os\n",
    "# The config path below is relative to the repository root. Only step up one\n",
    "# directory when the config is not reachable yet, so re-running this cell\n",
    "# does not keep walking further up the directory tree.\n",
    "config_path = 'configs/vgg_imagenet1.py'\n",
    "if not os.path.exists(config_path):\n",
    "    os.chdir('..')\n",
    "print(f'Current working directory: {os.getcwd()}')\n",
    "cfg = mmcv.Config.fromfile(config_path)\n",
    "device = 'cuda:0'\n",
    "\n",
    "# Build the estimation dataset/loader and the attributor from the config, then\n",
    "# run the attributor's estimation pass over that loader.\n",
    "est_set = build_dataset(cfg.data['estimation'])\n",
    "est_loader = DataLoader(est_set, **cfg.data['data_loader'])\n",
    "attributor = build_attributor(cfg.attributor, default_args=dict(device=device))\n",
    "attributor.estimate(est_loader, cfg.estimation_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b8e5bfc-e996-4c01-9dbc-275fe88f501f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize\n",
    "from PIL import Image\n",
    "image_size = 224\n",
    "\n",
    "# Preprocessing: resize, convert to a tensor, then normalize each channel.\n",
    "# NOTE(review): Resize with a single int rescales the shorter edge and keeps the\n",
    "# aspect ratio, so a non-square photo stays non-square -- confirm the input is\n",
    "# square or add CenterCrop(image_size) if a fixed-size input is required.\n",
    "preprocess = Compose([\n",
    "    Resize(image_size),\n",
    "    ToTensor(),\n",
    "    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Close the file handle once the tensor is built, and force three channels so\n",
    "# grayscale/RGBA files do not break the 3-channel Normalize.\n",
    "with Image.open(\"3.jpg\") as airliner_file:\n",
    "    img_airliner = preprocess(airliner_file.convert(\"RGB\"))\n",
    "\n",
    "# Class index used as the attribution target and to read the softmax score.\n",
    "target_air = 404"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "557c5be9-906b-432a-9904-f0466e77f35c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inputIBA_att(attributor, input_tensor, device, target, attribution_cfg=None):\n",
    "    \"\"\"Run the attributor on one input and return its input-level mask.\n",
    "\n",
    "    Args:\n",
    "        attributor: object exposing ``make_attribution`` and a ``buffer`` mapping.\n",
    "        input_tensor: input handed to ``make_attribution`` unchanged.\n",
    "        device: not used by this function; kept so existing calls keep working.\n",
    "        target: target handed to ``make_attribution`` unchanged.\n",
    "        attribution_cfg: optional attribution settings. When omitted, the\n",
    "            notebook-level ``cfg.attribution_cfg`` is used.\n",
    "\n",
    "    Returns:\n",
    "        The ``'input_mask'`` entry of ``attributor.buffer`` after attribution.\n",
    "    \"\"\"\n",
    "    if attribution_cfg is None:\n",
    "        # Fall back to the global config loaded in the setup cell.\n",
    "        attribution_cfg = cfg.attribution_cfg\n",
    "    attributor.make_attribution(input_tensor,\n",
    "                                target,\n",
    "                                attribution_cfg=attribution_cfg)\n",
    "    return attributor.buffer['input_mask']\n",
    "\n",
    "\n",
    "cap_air = inputIBA_att(attributor, img_airliner.to(device), device, target_air)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8a9992e-842f-49f4-a128-87db19905e5b",
   "metadata": {},
   "outputs": [],
   "source": "import torch\nimport torchvision\nimport numpy as np\n\n# Assume mask, image, pretrained_vgg16, device, target are already defined\npretrained_vgg16 = torchvision.models.vgg16(pretrained=True)\npretrained_vgg16.eval()\n\ntotal_pixels = 3 * 244 * 244\nn_values = range(1, total_pixels, 1000)  # From 1 to total pixels, every 1000\n\n# Method name and corresponding mask\nmethod_name = 'InputIBA'  # Or other method name\nmask = cap_air  # Single mask, not a list\n\nimage = img_airliner\ntarget = target_air\nconfs = []\nmask = cap_air\n# If mask is 2D, expand to 3 channels\nif len(mask.shape) == 2:\n    mask = np.stack([mask, mask, mask], 0)\n\nfor n in n_values:\n    top_indices = np.argsort(mask.flatten())[-n:]\n    m_result = np.zeros_like(mask)\n    np.put(m_result, top_indices, 1)  # Mark top N pixels in m_result\n    \n    m_result_tensor = torch.tensor(m_result, dtype=torch.float)\n    ib_image1 = m_result_tensor * image + (1 - m_result_tensor) * torch.randn_like(image)\n    \n    # Calculate confidence score\n    conf = torch.softmax(pretrained_vgg16(ib_image1[None]), dim=1)[0, target].item()\n    confs.append([n, conf])\n\n# Store results\nres_id = {method_name: confs}"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "150eb124-04c8-4fa5-8589-ed6973b5e2c3",
   "metadata": {},
   "outputs": [],
   "source": "import matplotlib.pyplot as plt\n\n# Assume confs has been calculated, format is [[n1, conf1], [n2, conf2], ...]\n\n# Extract n values and confidence scores\nn_list = [item[0] for item in confs]\nconf_list = [item[1] for item in confs]\n\n# Plot line chart\nplt.figure(figsize=(10, 6))\nplt.plot(n_list, conf_list, marker='o', markersize=3, linewidth=1.5, label=method_name)\n\nplt.xlabel('Number of Top Pixels (n)', fontsize=12)\nplt.ylabel('Confidence Score', fontsize=12)\nplt.title(f'Confidence vs. Number of Top Pixels - {method_name}', fontsize=14)\nplt.legend()\nplt.grid(True, linestyle='--', alpha=0.7)\n\nplt.tight_layout()\nplt.show()"
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8835a5fe-f6cb-41de-985a-7cffdca9216c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8395055338177264"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.integrate import trapezoid\n",
    "\n",
    "\n",
    "# Insertion scores: the confidence column of the [n, confidence] pairs.\n",
    "ins_scores = [score for _, score in confs]\n",
    "\n",
    "# Trapezoidal area under the insertion curve, uniform spacing 1 / len(ins_scores).\n",
    "auc_dic = trapezoid(ins_scores, dx=1. / len(ins_scores))\n",
    "auc_dic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d737b1-4a5c-4907-b782-8358cf63874b",
   "metadata": {},
   "outputs": [],
   "source": "confs_del = []\nif len(mask.shape) == 2:\n    mask = np.stack([mask,mask,mask],0)\nfor n in n_values:\n    top_indices = np.argsort(mask.flatten())[-n:]\n    m_result = np.ones_like(mask)\n    np.put(m_result, top_indices, 0)  # Mark top N pixels in m_result\n\n    m_result_tensor = torch.tensor(m_result, dtype=torch.float)\n    ib_image1 = m_result_tensor * image + (1 - m_result_tensor) * torch.randn_like(image)\n    \n    conf_d = torch.softmax(pretrained_vgg16(ib_image1[None]), dim=1)[0, target].item()\n    confs_del.append([n,conf_d])"
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4bba05c0-7217-41ab-978a-fe267593d88a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.008348894398324886"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Deletion scores: the confidence column of the [n, confidence] pairs.\n",
    "del_scores = np.array(confs_del)[:, 1].tolist()\n",
    "\n",
    "# Trapezoidal area under the deletion curve. The spacing is derived from this\n",
    "# curve's own length, so the cell does not depend on the insertion scores.\n",
    "auc_dic_del = trapezoid(del_scores, dx=1. / len(del_scores))\n",
    "auc_dic_del"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a975e62d-e480-45d0-ba91-e2a1f021097a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8311566394194015"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Difference between the insertion AUC and the deletion AUC.\n",
    "auc_dic - auc_dic_del"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feb38007-9b1b-4153-ba35-d48670a80ba6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}