{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e82e0d2-4c7f-47bb-9d60-79ab7bbcfed3",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "94c4742a-51ab-4429-a0e9-7cc9903cd2b9",
   "metadata": {},
   "source": [
    "# Devide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "432e23d8-23b7-428f-93f1-1581bbcbb21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "DEVICE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b98b193-8cee-489d-a5ba-a7e270653eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c0d2678-d24c-4581-8dde-b5a8db77dd54",
   "metadata": {},
   "source": [
    "# Directories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1262ade4-c761-4708-8fee-d996da1b177f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be23ff4c-90d8-4d64-8c30-da3077bb01ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT_DIR:    Path = Path('')\n",
    "DATA_DIR:    Path = ROOT_DIR/'datasets'/'imagenet2012'\n",
    "RESULTS_DIR: Path = ROOT_DIR/'dave_attribution'/'results'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51388ebf-aabb-4bdb-95f1-40b6827c7d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(RESULTS_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "925f3d08-e2af-4f6f-bb71-504260f2188e",
   "metadata": {},
   "source": [
    "# Core model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b846777-4550-4350-bf1d-9e7b709edd6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdceb032-2461-410a-8ee6-34f07471c05d",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME: str = 'vit_base_patch16_224'  # {'deit3_base_patch16_224.fb_in1k', deit_base_patch16_224, vit_base_patch16_224}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ee8560-5428-4d9a-85d1-0a8dd3fd6c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = timm.create_model(\n",
    "    MODEL_NAME,\n",
    "    pretrained=True,\n",
    ").eval().to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9941eee0-cc5c-4d77-a4cb-d07fafafed81",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c83abcd-8981-442d-85ca-8b6ae2ec7fa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import ImageFolder\n",
    "from utils.data import extract_subset, get_input_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9cb2d73-4221-4f1e-8af3-190010176243",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVKIT_PATH = DATA_DIR/'ILSVRC2012_devkit_t12/data/meta.mat'\n",
    "VAL_FOLD_DIR: Path = DATA_DIR/'val_sorted'\n",
    "\n",
    "SELECTED_CLASSES = [\n",
    "    'ostrich',\n",
    "    'golden retriever',\n",
    "    'panda',\n",
    "    'red panda',\n",
    "    'sports car',\n",
    "    'penguin',\n",
    "    'soccer ball',\n",
    "    'chickadee',\n",
    "    'airliner',\n",
    "]\n",
    "\n",
    "BATCH_SIZE: int = 10\n",
    "NUM_WORKERS: int = os.cpu_count() // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a0b397-fb47-4f4e-b955-8a536f870b03",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = get_input_transform(MODEL_NAME, model, img_size=224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822ebee9-a57c-40e7-b5b1-3dc0cb0ef6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = ImageFolder(\n",
    "    VAL_FOLD_DIR,\n",
    "    transform=transforms,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2f2eeb4-057f-4225-8dc9-ac141b1b3657",
   "metadata": {},
   "source": [
    "Extract selected subset (OPTIONAL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9fd2b13-e5db-4b19-8943-48c2b9b79908",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = extract_subset(\n",
    "    dataset=dataset,\n",
    "    classes=SELECTED_CLASSES,\n",
    "    devkit_path=DEVKIT_PATH,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd503c36-a7fa-4682-9277-f4cdf05f7a4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17da7453-a6fe-481f-b7b4-a17050084c26",
   "metadata": {},
   "source": [
    "Iterate data_loader (OPTIONAL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e90cf5-b93b-4d4b-a261-b21c2971b017",
   "metadata": {},
   "outputs": [],
   "source": [
    "for x, _ in tqdm(data_loader, total=len(data_loader)):\n",
    "    del x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e326d7dd-03d1-4252-a5c9-dcabd94aa7c2",
   "metadata": {},
   "source": [
    "# Explainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81a6991a-732f-4389-8aae-44aa7ebc3b69",
   "metadata": {},
   "source": [
    "### Augmentation CFG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93642458-cb35-46ec-a16e-c888ec0ff052",
   "metadata": {},
   "outputs": [],
   "source": [
    "from augment import AugCFG, DiffAugment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9edaf61-4230-4700-8c7d-6e889169e009",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = (224, 224)\n",
    "\n",
    "H_FLIP_PROB = 0.5\n",
    "\n",
    "AFFINE_PROB            = 1.0\n",
    "AFFINE_ROTATE_RANGE    = (-15.0, 15.0)\n",
    "AFFINE_TRANSLATE_RANGE = (0.1, 0.1)\n",
    "AFFINE_SCALE_RANGE     = (0.9, 1.1)\n",
    "\n",
    "CROP_SCALE_RANGE = (0.9, 1.1)\n",
    "CROP_RATIO_RANGE = (0.9, 1.1)\n",
    "CROP_PAD         = 4\n",
    "\n",
    "MAX_RESIZE_RATIO = 0.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb45acf-b044-44a5-a12e-a51f8d4752b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "aug_cfg = AugCFG(\n",
    "    image_size=IMAGE_SIZE,\n",
    "    h_flip_prob=H_FLIP_PROB,\n",
    "    affine_prob=AFFINE_PROB,\n",
    "    affine_rotate_range=AFFINE_ROTATE_RANGE,\n",
    "    affine_translate_range=AFFINE_TRANSLATE_RANGE,\n",
    "    affine_scale_range=AFFINE_SCALE_RANGE,\n",
    "    crop_scale_range=CROP_SCALE_RANGE,\n",
    "    crop_ratio_range=CROP_RATIO_RANGE,\n",
    "    crop_pad=CROP_PAD,\n",
    "    max_resize_ratio=MAX_RESIZE_RATIO,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f698f9b3-71bb-44ef-82c2-886be6dff787",
   "metadata": {},
   "source": [
    "### Postproc CFG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16638468-c662-4794-9969-31fc2d1249b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.post_processing import PostProcessing, PostprocCFG, GaussianCFG, BilateralCFG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e213ca4-b2de-4f23-99ca-f7e615c5648e",
   "metadata": {},
   "outputs": [],
   "source": [
    "GAUSS_KERNEL_SIZE: int   = 11\n",
    "GAUSS_SIGMA:       float = 7.0\n",
    "\n",
    "BILAT_KERNEL_SIZE:   int   = 5\n",
    "BILAT_SIGMA_SPATIAL: float = 0.5\n",
    "BILAT_SIGMA_RANGE:   float = 0.05\n",
    "EPS:                 float = 1e-8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3682877c-7c93-47db-84c3-fa3f64d9cd26",
   "metadata": {},
   "outputs": [],
   "source": [
    "postproc_cfg = PostprocCFG(\n",
    "    gauss=GaussianCFG(\n",
    "        kernel_size=GAUSS_KERNEL_SIZE,\n",
    "        sgm=GAUSS_SIGMA,\n",
    "    ),\n",
    "    bilat=BilateralCFG(\n",
    "        kernel_size=BILAT_KERNEL_SIZE,\n",
    "        sgm_spatial=BILAT_SIGMA_SPATIAL,\n",
    "        sgm_range=BILAT_SIGMA_RANGE,\n",
    "    ),\n",
    "    eps=EPS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9db01a89-be21-49e1-84d7-3b8d75f63cad",
   "metadata": {},
   "source": [
    "### Attribution CFG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf80cb28-7690-4a96-a7c0-905244f08cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from explainer import DAVEExplainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df9595a4-875f-40bb-96bc-48765d015709",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_STEPS: int = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3534c20c-327f-43e4-84e6-7c3ab9efb18e",
   "metadata": {},
   "outputs": [],
   "source": [
    "explainer = DAVEExplainer(\n",
    "    vit_model=model,\n",
    "    aug_cfg=aug_cfg,\n",
    "    postproc_cfg=postproc_cfg,\n",
    "    device=DEVICE,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b58afdf-09da-4896-ab5a-7bf784a0ead4",
   "metadata": {},
   "source": [
    "# Attribute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a03a5ffc-1857-44fd-b9c3-0d216a96f7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "from utils.data import IMAGENET_MEAN, IMAGENET_STD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57754027-52c6-4309-b802-f9a112dc54d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "global_idx = 0\n",
    "\n",
    "for step, (x, y) in tqdm(\n",
    "    enumerate(data_loader),\n",
    "    desc=\"Attributing...\",\n",
    "    total=len(data_loader),\n",
    "):\n",
    "    N = x.shape[0]\n",
    "    x, y = x.to(DEVICE), y.to(DEVICE)\n",
    "    y = y.unsqueeze(1)\n",
    "    result = explainer.attribute_supervised(\n",
    "        x=x,\n",
    "        y=y,\n",
    "        num_steps=NUM_STEPS,\n",
    "        clamp=False,\n",
    "    )\n",
    "    a_m = result['attribution']\n",
    "\n",
    "    mu  = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(x.device)\n",
    "    sgm = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(x.device)\n",
    "    x = (x * sgm + mu).detach().cpu()\n",
    "\n",
    "    for i in range(N):\n",
    "        fig = plt.figure(figsize=(15, 7), dpi=200)\n",
    "        gs = gridspec.GridSpec(\n",
    "            2, 2,\n",
    "            height_ratios=[4, 1],\n",
    "            hspace=0.25\n",
    "        )\n",
    "\n",
    "        img = x[i].permute(1, 2, 0).detach().cpu()\n",
    "        a_m_i = a_m[i]\n",
    "\n",
    "        ax1 = fig.add_subplot(gs[0, 0])\n",
    "        ax1.imshow(img)\n",
    "        ax1.set_aspect(\"equal\", adjustable=\"box\")\n",
    "        ax1.axis(\"off\")\n",
    "\n",
    "        ax2 = fig.add_subplot(gs[0, 1])\n",
    "        ax2.imshow(a_m_i, cmap=\"magma\", vmin=a_m_i.min(), vmax=a_m_i.max())\n",
    "        ax2.set_aspect(\"equal\", adjustable=\"box\")\n",
    "        ax2.axis(\"off\")\n",
    "\n",
    "        plt.show()\n",
    "        plt.close()\n",
    "\n",
    "    if global_idx >= 500:\n",
    "        break\n",
    "    global_idx += 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80fe1d31-9478-4cc3-bf8e-9b52a39f8467",
   "metadata": {},
   "source": [
    "# Single sample sanity check!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dcb0f33-74a1-4673-96ee-7786396cc511",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "img = Image.open('test_img.png').convert(\"RGB\")\n",
    "x = transform(img).unsqueeze(0)\n",
    "y = torch.zeros((1,)) + 19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61db9e09-9a23-4edb-b9eb-9bb778f8dfa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = x.shape[0]\n",
    "x, y = x.to(DEVICE), y.to(DEVICE)\n",
    "y = y.unsqueeze(1)\n",
    "result = explainer.attribute_supervised(\n",
    "    x=x,\n",
    "    y=y,\n",
    "    num_steps=500,\n",
    "    clamp=False,\n",
    ")\n",
    "a_m = result['attribution']\n",
    "\n",
    "mu  = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(x.device)\n",
    "sgm = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(x.device)\n",
    "x = (x * sgm + mu).detach().cpu()\n",
    "\n",
    "for i in range(N):\n",
    "    fig = plt.figure(figsize=(15, 7), dpi=200)\n",
    "    gs = gridspec.GridSpec(\n",
    "        2, 2,\n",
    "        height_ratios=[4, 1],\n",
    "        hspace=0.25\n",
    "    )\n",
    "\n",
    "    img = x[i].permute(1, 2, 0).detach().cpu()\n",
    "    a_m_i = a_m[i]\n",
    "\n",
    "    ax1 = fig.add_subplot(gs[0, 0])\n",
    "    ax1.imshow(img)\n",
    "    ax1.set_aspect(\"equal\", adjustable=\"box\")\n",
    "    ax1.axis(\"off\")\n",
    "\n",
    "    ax2 = fig.add_subplot(gs[0, 1])\n",
    "    ax2.imshow(a_m_i, cmap=\"magma\", vmin=a_m_i.min(), vmax=a_m_i.max())\n",
    "    ax2.set_aspect(\"equal\", adjustable=\"box\")\n",
    "    ax2.axis(\"off\")\n",
    "\n",
    "    plt.show()\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46244775-f5fe-4119-9b16-d0b66ef468f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
