{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71765b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f593ba32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "import torchvision.transforms as transforms\n",
    "import traceback\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07de6ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "def show_cam_on_image(img: np.ndarray,\n",
    "                      mask: np.ndarray,\n",
    "                      use_rgb: bool = False,\n",
    "                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n",
    "    \"\"\"Blend a colour-mapped CAM mask with an image.\n",
    "\n",
    "    The mask is inverted (``1 - mask``) before the colormap is applied, and the\n",
    "    heatmap stays in OpenCV's BGR channel order unless ``use_rgb`` is set.\n",
    "\n",
    "    :param img: The base image in RGB or BGR format, with values no greater than 1.\n",
    "    :param mask: The cam mask.\n",
    "    :param use_rgb: Convert the heatmap from BGR to RGB before blending;\n",
    "        set this to True when 'img' is in RGB format.\n",
    "    :param colormap: The OpenCV colormap to be used.\n",
    "    :returns: The blended image as ``np.uint8``, rescaled so that its maximum is 255.\n",
    "    :raises Exception: If ``np.max(img)`` is greater than 1.\n",
    "    \"\"\"\n",
    "    inverted_levels = np.uint8(255 * (1-mask))\n",
    "    colored = cv2.applyColorMap(inverted_levels, colormap)\n",
    "    if use_rgb:\n",
    "        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)\n",
    "    colored = np.float32(colored) / 255\n",
    "\n",
    "    if np.max(img) > 1:\n",
    "        raise Exception(\n",
    "            \"The input image should np.float32 in the range [0, 1]\")\n",
    "\n",
    "    # Sum the two layers, then renormalise by the brightest value.\n",
    "    blended = colored + img\n",
    "    return np.uint8(255 * (blended / np.max(blended)))\n",
    "\n",
    "def open_image(filename):\n",
    "    \"\"\"Load an image, resize it to 224x224 and scale it to [0, 1].\n",
    "\n",
    "    :param filename: Path (or file object) accepted by ``PIL.Image.open``.\n",
    "    :returns: A float ``np.ndarray`` of shape (224, 224, 3) with values in [0, 1].\n",
    "    \"\"\"\n",
    "    transformer = transforms.Compose([transforms.Resize((224, 224))])\n",
    "    # The context manager closes the underlying file even if decoding fails.\n",
    "    with Image.open(filename) as im:\n",
    "        # Force three channels so grayscale/CMYK/palette files can still be\n",
    "        # added to the 3-channel heatmap built by show_cam_on_image.\n",
    "        im = transformer(im.convert(\"RGB\"))\n",
    "        return np.array(im) / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceeb5c0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_path_mask(path_masks, dataset_name, model_name, method):\n",
    "    \"\"\"Return the folder ``<path_masks>/<dataset_name>_<model_name>_<method>``.\n",
    "\n",
    "    :param path_masks: Root folder of the saved masks.\n",
    "    :param dataset_name: Dataset identifier, first part of the folder name.\n",
    "    :param model_name: Model identifier, second part of the folder name.\n",
    "    :param method: Saliency method identifier, last part of the folder name.\n",
    "    :returns: ``path_masks`` joined with the composed folder name.\n",
    "    \"\"\"\n",
    "    folder = Path(f'{dataset_name}_{model_name}_{method}/')\n",
    "    return path_masks / folder\n",
    "\n",
    "def load_original_image(path_images, filename):\n",
    "    \"\"\"Load the JPEG in ``path_images`` that shares its stem with ``filename``.\n",
    "\n",
    "    :param path_images: Directory containing the ``.jpg`` images.\n",
    "    :param filename: Name of a file with the same stem (e.g. a ``.png`` mask);\n",
    "        only its extension is replaced.\n",
    "    :returns: Whatever ``open_image`` returns for the resolved ``.jpg`` path.\n",
    "    \"\"\"\n",
    "    # with_suffix swaps the real extension whatever its length, unlike\n",
    "    # slicing off the last four characters of the name.\n",
    "    jpg_name = Path(filename).with_suffix(\".jpg\")\n",
    "    return open_image(path_images / jpg_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b47b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saliency methods to compare; plot_filename draws one overlay column per\n",
    "# entry, in this order.\n",
    "methods = [\n",
    "    \"grad_cam\",\n",
    "    \"rise\",\n",
    "    \"extremal_perturbations\",\n",
    "    \"igos_pp\",\n",
    "    \"rt_saliency\",\n",
    "    \"explainer\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ad6e4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the saved masks; get_path_mask appends one\n",
    "# \"<dataset>_<model>_<method>\" sub-folder to it.\n",
    "path_masks = Path(\"../src/evaluation/masks/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90336a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_normalized_mask(mask_file):\n",
    "    \"\"\"Read the array stored under ``arr_0`` in ``mask_file`` and scale its maximum to 1.\"\"\"\n",
    "    # np.load's second positional argument is mmap_mode, so only the path is\n",
    "    # passed; the context manager closes the .npz archive after reading.\n",
    "    with np.load(mask_file) as archive:\n",
    "        mask = archive[\"arr_0\"]\n",
    "    peak = mask.max()\n",
    "    # Out-of-place division also works for integer masks; a non-positive peak\n",
    "    # is left untouched so an all-zero mask does not turn into NaNs.\n",
    "    return mask / peak if peak > 0 else mask\n",
    "\n",
    "\n",
    "def plot_filename(filename, dataset_name, model_name, methods):\n",
    "    \"\"\"Plot an image next to one saliency overlay per method and save the figure as a PDF.\n",
    "\n",
    "    Masks are read from the folders derived from the module-level ``path_masks``.\n",
    "\n",
    "    :param filename: File name of the image; its extension is swapped for\n",
    "        ``.jpg`` / ``.npz`` / ``.pdf`` as needed.\n",
    "    :param dataset_name: ``\"COCO\"`` selects the COCO2014 validation folder, any\n",
    "        other value the VOC2007 JPEG folder.\n",
    "    :param model_name: Model identifier used in the mask and output folder names.\n",
    "    :param methods: Saliency method names; one overlay column is drawn per method, in order.\n",
    "    :returns: None. The figure is saved to ``ATTENUATED_<dataset_name>_<model_name>/<stem>.pdf``.\n",
    "    \"\"\"\n",
    "    if dataset_name==\"COCO\":\n",
    "        path_images = Path(\"../datasets/COCO2014/val2014/\")\n",
    "    else:\n",
    "        path_images = Path(\"../datasets/VOC2007/VOCdevkit/VOC2007/JPEGImages/\")\n",
    "\n",
    "    try:\n",
    "        x = load_original_image(path_images, filename)\n",
    "    except Exception:\n",
    "        # Without the base image nothing can be drawn: report and skip this file\n",
    "        # instead of failing later on an undefined variable.\n",
    "        traceback.print_exc()\n",
    "        return\n",
    "\n",
    "    npz_name = Path(filename).with_suffix(\".npz\")\n",
    "    masks = []\n",
    "    try:\n",
    "        for method in methods:\n",
    "            p = get_path_mask(path_masks, dataset_name, model_name, method)\n",
    "            masks.append(_load_normalized_mask(p / npz_name))\n",
    "    except Exception:\n",
    "        # Best effort: keep the masks loaded so far and plot only those.\n",
    "        traceback.print_exc()\n",
    "\n",
    "    fig = plt.figure(figsize=(15, 4))\n",
    "    n_columns = len(methods) + 1\n",
    "    plt.subplot(1, n_columns, 1)\n",
    "    plt.imshow(x, vmin=0, vmax=1)\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "    # Column 1 holds the original image, so the overlays start at column 2.\n",
    "    for column, mask in enumerate(masks, start=2):\n",
    "        plt.subplot(1, n_columns, column)\n",
    "        plt.imshow(show_cam_on_image(x, mask), vmin=0, vmax=1)\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    outfolder = Path(\"ATTENUATED_\" + dataset_name + \"_\" + model_name)\n",
    "    outfolder.mkdir(exist_ok=True, parents=True)\n",
    "    pdf_name = Path(filename).with_suffix(\".pdf\")\n",
    "    plt.savefig(outfolder / pdf_name, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b9c213",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many randomly chosen images to plot, and for which dataset / model.\n",
    "n_images = 5\n",
    "dataset_name = \"VOC\"\n",
    "model_name = \"vgg16\"\n",
    "\n",
    "# The PNG files in the \"explainer\" mask folder define the candidate images.\n",
    "p_explainer = get_path_mask(path_masks, dataset_name, model_name, \"explainer\")\n",
    "# glob order depends on the file system; sorting first makes the shuffle\n",
    "# reproducible whenever the `random` module is seeded.\n",
    "image_list = sorted(p_explainer.glob(\"*.png\"))\n",
    "random.shuffle(image_list)\n",
    "\n",
    "# Slicing caps the number of plots and handles n_images == 0 without plotting.\n",
    "for p in image_list[:n_images]:\n",
    "    plot_filename(p.name, dataset_name, model_name, methods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c8a6035",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_gpu",
   "language": "python",
   "name": "pytorch_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
