{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410fcbd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6015dc3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from evaluation.eval_utils.compute_masks import vedaldi2019\n",
    "from evaluation.eval_utils.compute_scores import segmented_generator, get_model_and_data\n",
    "from torchray.utils import get_device\n",
    "from torchray.attribution.rise import rise\n",
    "from torchray.attribution.grad_cam import grad_cam\n",
    "from torchray.attribution.guided_backprop import guided_backprop\n",
    "from PIL import Image\n",
    "\n",
    "from utils.image_utils import get_unnormalized_image\n",
    "from models.explainer_classifier import ExplainerClassifierModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0de578f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "def show_cam_on_image(img: np.ndarray,\n",
    "                      mask: np.ndarray,\n",
    "                      use_rgb: bool = False,\n",
    "                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n",
    "    \"\"\"Blend a colour-mapped mask with an image and return the overlay.\n",
    "\n",
    "    The mask is inverted (``1 - mask``) before the colormap is applied, so a\n",
    "    mask value of 1 lands on the low end of the colormap. The heatmap is BGR\n",
    "    unless ``use_rgb`` is set.\n",
    "\n",
    "    :param img: Base image in RGB or BGR format; its maximum must not exceed 1.\n",
    "    :param mask: The cam mask; it is scaled by 255 and cast to uint8.\n",
    "    :param use_rgb: Convert the heatmap from BGR to RGB; set to True if 'img' is in RGB format.\n",
    "    :param colormap: The OpenCV colormap to be used.\n",
    "    :returns: uint8 overlay: heatmap plus image, divided by its maximum and scaled by 255.\n",
    "    :raises Exception: if the maximum of 'img' is greater than 1.\n",
    "    \"\"\"\n",
    "    inverted = np.uint8(255 * (1 - mask))\n",
    "    colored = cv2.applyColorMap(inverted, colormap)\n",
    "    # applyColorMap yields BGR; convert only when the caller asks for RGB.\n",
    "    heatmap = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB) if use_rgb else colored\n",
    "    heatmap = np.float32(heatmap) / 255\n",
    "\n",
    "    if np.max(img) > 1:\n",
    "        raise Exception(\n",
    "            \"The input image should np.float32 in the range [0, 1]\")\n",
    "\n",
    "    overlay = heatmap + img\n",
    "    # Renormalise by the peak so the sum of the two layers fits in uint8.\n",
    "    return np.uint8(255 * (overlay / np.max(overlay)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1531662",
   "metadata": {},
   "outputs": [],
   "source": [
    "path_masks = Path(\"../src/evaluation/masks/\")\n",
    "data_path = Path(\"../datasets/VOC2007/\")\n",
    "dataset_name = \"VOC\"\n",
    "model_name = \"vgg16\"\n",
    "model_path = \"../src/checkpoints/pretrained_classifiers/vgg16_voc.ckpt\"\n",
    "path_segmentation = Path('../datasets/VOC2007/VOCdevkit/VOC2007/SegmentationClass/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dc9e63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pascal VOC class names; the list position is the class id.\n",
    "voc_classes = [\n",
    "    \"aeroplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\",\n",
    "    \"bus\", \"car\", \"cat\", \"chair\", \"cow\",\n",
    "    \"diningtable\", \"dog\", \"horse\", \"motorbike\", \"person\",\n",
    "    \"pottedplant\", \"sheep\", \"sofa\", \"train\", \"tvmonitor\",\n",
    "]\n",
    "# Lookup table: class id -> class name.\n",
    "d_classes = dict(enumerate(voc_classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d27aa64",
   "metadata": {},
   "outputs": [],
   "source": [
    "def open_segmentation_mask(segmentation_filename, size=(224, 224)):\n",
    "    \"\"\"Load a segmentation image as a resized grayscale float array.\n",
    "\n",
    "    :param segmentation_filename: Path of the segmentation image to read.\n",
    "    :param size: Target size handed to ``transforms.Resize``; defaults to the\n",
    "        previously hard-coded ``(224, 224)``.\n",
    "    :returns: numpy array of the resized 8-bit grayscale image divided by 255.\n",
    "    \"\"\"\n",
    "    import torchvision.transforms as transforms\n",
    "    from PIL import Image\n",
    "\n",
    "    transformer = transforms.Compose([transforms.Resize(size)])\n",
    "    # Image.open is lazy and keeps the file handle open; the context manager\n",
    "    # releases it once convert() has produced an in-memory copy.\n",
    "    with Image.open(segmentation_filename) as raw:\n",
    "        mask = raw.convert('L')\n",
    "    mask = transformer(mask)\n",
    "    mask = np.array(mask) / 255.0\n",
    "    # Binarisation is left disabled: mask[mask > 0] = 1\n",
    "    return mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6307903d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_torch_image(x):\n",
    "    \"\"\"Draw a tensor on the current matplotlib axes with the axis hidden.\n",
    "\n",
    "    The tensor is detached, moved to the CPU and squeezed. A 2-D result is\n",
    "    shown as-is; anything else is reordered with axes (1, 2, 0) first.\n",
    "\n",
    "    :param x: Tensor to display; values are shown on a fixed [0, 1] scale.\n",
    "    \"\"\"\n",
    "    arr = x.detach().cpu().numpy().squeeze()\n",
    "    # Move the leading axis last unless the array is already 2-D.\n",
    "    shown = arr if arr.ndim == 2 else np.transpose(arr, (1, 2, 0))\n",
    "    plt.imshow(shown, vmin=0, vmax=1)\n",
    "    plt.axis(\"off\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19bed087",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, data_module = get_model_and_data(data_path, dataset_name, model_name, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ecf7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the unnormalized version of every sample the generator yields.\n",
    "imgs = []\n",
    "for x, category_id, filename in segmented_generator(data_module, path_segmentation):\n",
    "    imgs.append(get_unnormalized_image(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5406bdc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "# Subplot slot k (1-based, row-major) shows imgs[k]; imgs[0] is never drawn.\n",
    "for slot in range(1, 101):\n",
    "    plt.subplot(10, 10, slot)\n",
    "    plot_torch_image(imgs[slot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5cf8a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "# Subplot slot k (1-based, row-major) shows imgs[100 + k].\n",
    "for slot in range(1, 101):\n",
    "    plt.subplot(10, 10, slot)\n",
    "    plot_torch_image(imgs[100 + slot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7537375",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `c` is the 0-based position, in segmented_generator order, of the sample to explain.\n",
    "# positive examples\n",
    "# c = 54\n",
    "# c = 145\n",
    "\n",
    "# c = 4\n",
    "# c = 77\n",
    "\n",
    "# negative example (horses and people)\n",
    "# c = 6\n",
    "# c = 165\n",
    "# c = 175\n",
    "# c = 67\n",
    "\n",
    "# mixed result\n",
    "# c = 21\n",
    "# c = 189\n",
    "\n",
    "\n",
    "# c = 192\n",
    "\n",
    "c = 35\n",
    "# c = 53\n",
    "\n",
    "# Walk the generator up to position `c`. `c` is no longer counted down, so the\n",
    "# chosen index can still be inspected after the cell has run.\n",
    "for sample_pos, s in enumerate(segmented_generator(data_module, path_segmentation)):\n",
    "    x, category_id, filename = s\n",
    "    if sample_pos >= c:\n",
    "        break\n",
    "else:\n",
    "    # An index past the end used to select the last sample silently.\n",
    "    raise IndexError(f\"sample index {c} is out of range for the segmented set\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8eaf76",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63556637",
   "metadata": {},
   "outputs": [],
   "source": [
    "[d_classes[e] for e in category_id]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71662d86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the selected sample (left) next to the segmentation image stored\n",
    "# under the same file name in path_segmentation (right).\n",
    "seg_mask = open_segmentation_mask(path_segmentation / filename)\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.subplot(121)\n",
    "plot_torch_image(get_unnormalized_image(x))\n",
    "plt.colorbar()\n",
    "plt.subplot(122)\n",
    "plt.imshow(seg_mask)\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3fc6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.hist(seg_mask.flatten(), 100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "083098de",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = [1, 7, 11, 13, 14]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5926e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = len(voc_classes)\n",
    "# load_from_checkpoint builds and returns the restored model, so it is called\n",
    "# on the class. Constructing a throw-away instance first is unnecessary, and\n",
    "# (assuming ExplainerClassifierModel is a LightningModule) recent PyTorch\n",
    "# Lightning releases reject the call on an instance.\n",
    "explainer_classifier = ExplainerClassifierModel.load_from_checkpoint(\n",
    "    \"../src/checkpoints/explainer_vgg16_voc.ckpt\", num_classes=num_classes)\n",
    "explainer = explainer_classifier.explainer\n",
    "# Put the classifier and the explainer on the same device.\n",
    "device = get_device()\n",
    "model = model.to(device)\n",
    "explainer = explainer.to(device)\n",
    "explainer.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e109273",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5059c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_ci = model.forward(x).sigmoid().detach().cpu().numpy().squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dec5377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explainer\n",
    "explainer_masks = explainer.forward(x).sigmoid().detach().cpu().numpy().squeeze()#[classes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23972e89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class ids that occur in category_id for the selected sample.\n",
    "true_classes = np.unique(category_id)\n",
    "\n",
    "f, a = plt.subplots(3, 7, figsize=(40, 20))\n",
    "\n",
    "for ci, c in enumerate(explainer_masks):\n",
    "    if ci == 0:\n",
    "        # Slot (0, 0) holds the input image; the masks fill the slots after it.\n",
    "        image_ax = a[0][0]\n",
    "        image_ax.imshow(np.transpose(get_unnormalized_image(x[0,...].cpu()), (1,2,0)))\n",
    "        image_ax.axis(\"off\")\n",
    "\n",
    "    # Mask ci goes to row-major slot ci + 1 of the 7-column grid.\n",
    "    y_pos, x_pos = divmod(ci + 1, 7)\n",
    "    mask_ax = a[y_pos][x_pos]\n",
    "    mask_ax.imshow(c)\n",
    "    mask_ax.axis(\"off\")\n",
    "    # Classes found in true_classes get a leading '*' in the title.\n",
    "    sub_title = (f\"*{ci}: {d_classes[ci]}: {p_ci[ci]:.2f}\" if ci in true_classes\n",
    "                 else f\"{ci}: {d_classes[ci]}: {p_ci[ci]:.2f}\")\n",
    "    mask_ax.set_title(sub_title, fontsize=12)\n",
    "\n",
    "#plt.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b1d95f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extremal perturbation: one mask per entry of `classes`, stacked on axis 0.\n",
    "vedaldi_masks = np.array([\n",
    "    vedaldi2019(model, x, c).detach().cpu().numpy().squeeze()\n",
    "    for c in classes\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101832bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grad CAM: one saliency map per entry of `classes`, taken at the last\n",
    "# module of model.feature_extractor and resized by grad_cam (resize=True).\n",
    "gradcam_masks = []\n",
    "for c in classes:\n",
    "    gradcam_masks.append(grad_cam(model, x, c, saliency_layer=model.feature_extractor[-1], resize=True).detach().cpu().numpy().squeeze())\n",
    "gradcam_masks = np.array(gradcam_masks)\n",
    "# Min-max normalise jointly over all classes (one shared minimum and maximum).\n",
    "gradcam_masks = gradcam_masks - np.min(gradcam_masks)\n",
    "gradcam_peak = np.max(gradcam_masks)\n",
    "# Constant maps have zero range; skip the division so they stay 0 instead of NaN.\n",
    "if gradcam_peak > 0:\n",
    "    gradcam_masks = gradcam_masks / gradcam_peak"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101d9401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RISE: rise() is called once without a target and its result is indexed\n",
    "# by class id below; each selected map is min-max normalised on its own.\n",
    "rise_masks = []\n",
    "segmentations = rise(model, x).detach().cpu().numpy().squeeze()\n",
    "for c in classes:\n",
    "    class_mask = segmentations[c]\n",
    "    class_mask = class_mask - np.amin(class_mask)\n",
    "    class_peak = np.amax(class_mask)\n",
    "    # A constant map has zero range; skip the division so it stays 0 instead of NaN.\n",
    "    if class_peak > 0:\n",
    "        class_mask = class_mask / class_peak\n",
    "    rise_masks.append(class_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3fac80d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided backprop: one map per entry of `classes`, stacked on axis 0.\n",
    "gbackprop_masks = np.array([\n",
    "    guided_backprop(model, x, c,resize=True).detach().cpu().numpy().squeeze()\n",
    "    for c in classes\n",
    "])\n",
    "# The maps are not min-max normalised in this cell.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f64255e",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_id, [d_classes[e] for e in category_id]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aaa20b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the masks of the classes of interest (row i <-> classes[i]).\n",
    "# NOTE(review): this overwrites the full per-class array, so running the cell a\n",
    "# second time indexes the already-reduced array; re-run the explainer cell first.\n",
    "explainer_masks = np.take(explainer_masks, classes, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df944f96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of overlays: one row per attribution method, one column per class.\n",
    "im_ori = np.transpose(get_unnormalized_image(x).detach().cpu().numpy().squeeze(), (1,2,0))\n",
    "n_classes = len(classes)\n",
    "\n",
    "all_masks = [explainer_masks, gradcam_masks, rise_masks, vedaldi_masks]\n",
    "n_methods = len(all_masks)\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "slot = 0  # running 1-based subplot index, row-major\n",
    "for masks in all_masks:\n",
    "    for j, c in enumerate(classes):\n",
    "        slot += 1\n",
    "        plt.subplot(n_methods, n_classes, slot)\n",
    "        im = show_cam_on_image(im_ori, masks[j])\n",
    "        plt.imshow(im, vmin=0, vmax=1)\n",
    "        plt.title(d_classes[c])\n",
    "        plt.axis(\"off\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2cc5d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided-backprop maps share one grey scale, [cmin, cmax], across all classes.\n",
    "plt.figure(figsize=(20, 4))\n",
    "cmax = np.max(gbackprop_masks)\n",
    "cmin = 0\n",
    "for col, c in enumerate(classes, start=1):\n",
    "    plt.subplot(1, n_classes, col)\n",
    "    plt.imshow(gbackprop_masks[col - 1], vmin=cmin, vmax=cmax, cmap=plt.cm.gray_r)\n",
    "    plt.title(d_classes[c])\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "470eee0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the figures\n",
    "c_list =  [d_classes[e] for e in classes]\n",
    "im_ori = np.transpose(get_unnormalized_image(x).detach().cpu().numpy().squeeze(), (1,2,0))\n",
    "\n",
    "save_folder = Path(\"negative\")\n",
    "\n",
    "def save_fig(masks, im_ori, c_list, base_name, save_folder):\n",
    "    \"\"\"Write one heat-map overlay PNG per (mask, class name) pair.\n",
    "\n",
    "    :param masks: Iterable of masks, paired positionally with ``c_list``.\n",
    "    :param im_ori: Base image handed to ``show_cam_on_image``.\n",
    "    :param c_list: Class names used as file-name suffixes.\n",
    "    :param base_name: File-name prefix; files are named ``<base_name>_<class>.png``.\n",
    "    :param save_folder: ``Path`` of the output directory; created if missing.\n",
    "    \"\"\"\n",
    "    save_folder.mkdir(parents=True, exist_ok=True)\n",
    "    for mask, class_name in zip(masks, c_list):\n",
    "        overlay = show_cam_on_image(im_ori, mask)\n",
    "        target = save_folder / Path(base_name + \"_\" + class_name + \".png\")\n",
    "        Image.fromarray(overlay).save(target)\n",
    "\n",
    "# Export the overlays of each method (no save_fig call is made for rise_masks here).\n",
    "save_fig(explainer_masks, im_ori, c_list, \"ours\", save_folder)\n",
    "save_fig(vedaldi_masks, im_ori, c_list, \"vedaldi\", save_folder)\n",
    "save_fig(gradcam_masks, im_ori, c_list, \"gradcam\", save_folder)\n",
    "# Guided-backprop maps are divided by their global maximum before export.\n",
    "cmax = np.max(gbackprop_masks)\n",
    "save_fig(gbackprop_masks/cmax, im_ori, c_list, \"guided_backprop\", save_folder)\n",
    "\n",
    "\n",
    "# cmax = np.max(gbackprop_masks)\n",
    "# for j, c in enumerate(c_list):\n",
    "#     im = Image.fromarray(((1-gbackprop_masks[j]/cmax)*255).astype(np.uint8))\n",
    "#     im.save( save_folder / Path(\"guided_backprop\" + \"_\" + c + \".png\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf31faa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_gpu",
   "language": "python",
   "name": "pytorch_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
