{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7380c37d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "import PIL\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.datasets import test\n",
    "\n",
    "from utils.edm_score import input_gradient\n",
    "\n",
    "from robustness import model_utils, datasets # https://github.com/MadryLab/robustness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae5dfe52",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "batch_size = 32\n",
    "num_workers = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc79e3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training and validation loader, inlcuding the standard data transform\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], \n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "imagenet_val_unnormalized = torchvision.datasets.ImageNet('/scratch_local/datasets/ImageNet2012', \n",
    "                                                         split='val',            \n",
    "                                                         transform = transforms.Compose([\n",
    "                                                            transforms.Resize(256),\n",
    "                                                            transforms.CenterCrop(224),\n",
    "                                                            transforms.ToTensor(),\n",
    "                                                        ]))\n",
    "\n",
    "val_loader_unnormalized = torch.utils.data.DataLoader(imagenet_val_unnormalized, batch_size=batch_size, shuffle=False, # for madry pre-trained models\n",
    "                                                      num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82766ae1",
   "metadata": {},
   "source": [
    "## Validation Accuracy of pre-trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f31d9558",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load the models from https://github.com/microsoft/robust-models-transfer\n",
    "l2_epsilons = ['0', '0.01', '0.03', '0.05', '0.1', '0.25', '0.5', '1', '3', '5']\n",
    "l2_robust_models = [f'../saved_models/imagenet_robust/resnet18_l2_eps{eps}.ckpt' for eps in  l2_epsilons]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963ff612",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wrapper(torch.nn.Module):\n",
    "    def __init__(self, wrapped):\n",
    "        super().__init__()\n",
    "        self.wrapped = wrapped\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.wrapped(x)\n",
    "        # insert fancy logic here\n",
    "        return out[0]\n",
    "  \n",
    "    def __getattr__(self, name):\n",
    "        try:\n",
    "                return super().__getattr__(name)\n",
    "        except AttributeError:\n",
    "            if name == \"wrapped\":\n",
    "                raise AttributeError()\n",
    "            return getattr(self.wrapped, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23ab414f",
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_val_64x64 = torchvision.datasets.ImageNet('/scratch_local/datasets/ImageNet2012', \n",
    "                                             split='val',            \n",
    "                                             transform = transforms.Compose([\n",
    "                                                transforms.Resize(256),\n",
    "                                                transforms.CenterCrop(224),\n",
    "                                                transforms.Resize(64),\n",
    "                                                transforms.Resize(224),\n",
    "                                                transforms.ToTensor(),\n",
    "                                                normalize,\n",
    "                                            ]))\n",
    "\n",
    "val_loader_64x64 = torch.utils.data.DataLoader(imagenet_val_64x64, batch_size=batch_size, shuffle=False,\n",
    "                                         num_workers=num_workers, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ac5d6a5",
   "metadata": {},
   "source": [
    "test(resnet50_model, val_loader_64x64, device) # 56.69"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f644a383",
   "metadata": {},
   "source": [
    "## Load the robust models and compute some input gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "959562dd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "imgnet_ds = datasets.ImageNet('/scratch_local/datasets/ImageNet2012')\n",
    "models = {model_file: Wrapper(model_utils.make_and_restore_model(arch='resnet18', dataset=imgnet_ds, resume_path = model_file)[0]) for model_file in l2_robust_models}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa72f328",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, model in models.items():\n",
    "    model.eval()\n",
    "    model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "919d6ca8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_gradients = {k : [] for k, _ in models.items()}\n",
    "for model_name, model in models.items():\n",
    "    model.to(device)\n",
    "    for img, _ in tqdm.tqdm(val_loader_unnormalized):\n",
    "        img = img.to(device)\n",
    "        gradient = input_gradient(model, img).detach().cpu()\n",
    "        input_gradients[model_name].append(gradient)\n",
    "        break\n",
    "    model.to('cpu')\n",
    "input_gradients = {k:torch.cat(v) for k,v in input_gradients.items()}\n",
    "\n",
    "# scale the input gradients so that they lie in [-1,1]\n",
    "for model_name, _ in models.items():\n",
    "    for idx in range(input_gradients[model_name].shape[0]):\n",
    "        input_gradients[model_name][idx] = input_gradients[model_name][idx]  / input_gradients[model_name][idx].abs().max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f21d91e",
   "metadata": {},
   "source": [
    "## Load the diffusion model and compute some scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c66c636",
   "metadata": {},
   "outputs": [],
   "source": [
    "import diffusion_model\n",
    "\n",
    "diffusion = diffusion_model.Diffusion(f'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c8e126",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_loader_unnormalized_bs_1 = torch.utils.data.DataLoader(imagenet_val_unnormalized, batch_size=1, shuffle=True, num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56cdebd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 1.2\n",
    "\n",
    "diffusion.to(device)\n",
    "images = []\n",
    "scores = []\n",
    "for idx, (img, label) in tqdm.tqdm(enumerate(val_loader_unnormalized_bs_1)):\n",
    "    images.append(img.clone().detach())\n",
    "    img  = transforms.Resize((64,64), antialias=True)(img)\n",
    "    score = diffusion.get_score(img.to(device), sigma, class_labels=label)\n",
    "    scores.append(score.detach().cpu())\n",
    "    if idx > 12:\n",
    "        break\n",
    "images = torch.vstack(images)\n",
    "scores = torch.vstack(scores)\n",
    "diffusion.to('cpu')\n",
    "\n",
    "for idx in range(scores.shape[0]): # score to [-1, 1]\n",
    "    scores[idx] = scores[idx] / scores[idx].abs().max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9e5305c",
   "metadata": {},
   "source": [
    "## Model gradients and score, the standard plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43a39045",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 3+len(models)\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=10, figsize=(2 * nrows, 25))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = (images[idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    \n",
    "    img = (scores[idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "    # different models\n",
    "    for model_idx, (model_name, _) in enumerate(models.items()):\n",
    "        img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "        axs[2+model_idx, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "        axs[2+model_idx, 0].set_ylabel(model_name[-8:-4])\n",
    "\n",
    "    # random baseline\n",
    "    img = torch.randn_like(images[idx, :, :, :])\n",
    "    img = (img / img.abs().max() * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[nrows-1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    axs[nrows-1, 0].set_ylabel('Random Normal')\n",
    "\n",
    "for ax in axs:\n",
    "    for idx in range(10):\n",
    "        ax[idx].axis('off')\n",
    "\n",
    "axs[1,0].set_ylabel('Score')\n",
    "axs[0,0].set_ylabel('Image')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc53ab99",
   "metadata": {},
   "source": [
    "## Calculate Input Gradients and Scores for 1000 images from the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6723eda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "torch.manual_seed(0)\n",
    "random.seed(0)\n",
    "\n",
    "images = []\n",
    "scores = []\n",
    "input_gradients = {k : [] for k, _ in models.items()}\n",
    "for idx, (img, label) in tqdm.tqdm(enumerate(val_loader_unnormalized_bs_1)):\n",
    "    # image\n",
    "    images.append(img.detach().cpu())\n",
    "    img = img.to(device)\n",
    "    # input gradient, for all models\n",
    "    for model_name, model in models.items():\n",
    "        model.to(device)\n",
    "        assert img.grad is None\n",
    "        ig = input_gradient(model, img).detach().cpu()\n",
    "        input_gradients[model_name].append(ig)\n",
    "        model.to('cpu')\n",
    "    # score\n",
    "    diffusion.to(device)\n",
    "    img  = transforms.Resize((64,64), antialias=True)(images[-1])\n",
    "    score = diffusion.get_score(img.to(device), sigma, class_labels=label)\n",
    "    scores.append(score.detach().cpu())\n",
    "    diffusion.to('cpu')\n",
    "    if idx >= 1000:\n",
    "        break\n",
    "images = torch.vstack(images)\n",
    "input_gradients = {k:torch.cat(v) for k,v in input_gradients.items()}\n",
    "scores = torch.vstack(scores)\n",
    "\n",
    "# scale the lenght of score and input gradients so that they lie in [-1,1]\n",
    "for model_name, _ in models.items():\n",
    "    for idx in range(input_gradients[model_name].shape[0]):\n",
    "        input_gradients[model_name][idx] = input_gradients[model_name][idx]  / input_gradients[model_name][idx].abs().max()\n",
    "for idx in range(scores.shape[0]): # score to [-1, 1]\n",
    "    scores[idx] = scores[idx] / scores[idx].abs().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba15c0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save((images, scores, input_gradients), '../datasets/imagenet_resnet18_img_score_gradients.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6426c7f",
   "metadata": {},
   "source": [
    "## visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15cc6ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 2+len(models)\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=10, figsize=(20, 24))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = (images[idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    \n",
    "    img = (scores[idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "    # different models\n",
    "    for model_idx, (model_name, _) in enumerate(models.items()):\n",
    "        img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "        axs[2+model_idx, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "        axs[2+model_idx, 0].set_ylabel(model_name[15:20])\n",
    "\n",
    "for ax in axs:\n",
    "    for idx in range(10):\n",
    "        ax[idx].axis('off')\n",
    "\n",
    "plt.tight_layout(pad=0.5)\n",
    "plt.savefig('../figures/imagenet-big-appendix.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7a15cdc",
   "metadata": {},
   "source": [
    "## LPIPS similarity metric for input gradients and score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b841d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "images, scores, input_gradients = torch.load('../datasets/imagenet_resnet18_img_score_gradients.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b242ad59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lpips\n",
    "\n",
    "# init lpips \n",
    "torch.hub.set_dir(\"../tmp/.cache/torchhub\") # set hub to writeable directory\n",
    "loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9385dc5a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results = []\n",
    "for model_name, _ in models.items():\n",
    "    if model_name == '../saved_models/imagenet_robust/resnet18_l2_eps0.ckpt': # cant't plot 0 in log plot\n",
    "        continue     \n",
    "    print(model_name)\n",
    "    distances = []\n",
    "    for img_idx in range(100): # for all images\n",
    "\n",
    "        score = scores[img_idx]\n",
    "        ig = input_gradients[model_name][img_idx]\n",
    "        # bilinear downsize for the input gradient to have the same size as the score \n",
    "        ig = torchvision.transforms.Resize(size=(64,64))(ig) \n",
    "        distance = loss_fn_alex(ig, score)\n",
    "        distances.append(distance.item())        \n",
    "    print(np.mean(distances))\n",
    "    results.append(np.mean(distances))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef4439e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# avoid type-3 fonts\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "sns.set_style(\"white\")\n",
    "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=1.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d75ae0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "epsilons = ['0.01', '0.03', '0.05', '0.1', '0.25', '0.5', '1', '3', '5']\n",
    "accuracies = np.array([69.90, 69.24, 69.15, 68.77, 67.43, 65.49, 62.32, 53.12, 45.59]) / 100 # https://github.com/microsoft/robust-models-transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73999d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,6))\n",
    "ax1 = plt.gca()\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax1.plot(epsilons, accuracies, 'o--', ms=10, color='#1f77b4', label='Accuracy')\n",
    "ax1.set_ylabel('Accuracy', color='#1f77b4')\n",
    "ax1.tick_params(axis='y', colors='#1f77b4')\n",
    "\n",
    "\n",
    "ax2.plot(epsilons, [1-x for x in results], 'o--', ms=10, color='#ff7f0e')\n",
    "ax2.set_ylabel('1-LPIPS', color='#ff7f0e')\n",
    "ax2.tick_params(axis='y', colors='#ff7f0e')\n",
    "\n",
    "plt.title('ImageNet')\n",
    "\n",
    "#ax1.get_xaxis().get_major_formatter().labelOnlyBase = False\n",
    "#ax1.set_xscale('log')\n",
    "#ax1.set_xticks([float(x) for x in epsilons])\n",
    "\n",
    "ax1.set_xlabel('Adversarial Perturbation Budget (Epsilon)')\n",
    "plt.savefig('../figures/imagenet_lpips.pdf')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5255ca93",
   "metadata": {},
   "source": [
    "## Page 2 Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aeeb048",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 5\n",
    "ncols = 5\n",
    "img_offset = 15\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10))\n",
    "\n",
    "for idx in range(ncols):\n",
    "    # image\n",
    "    img = images[img_offset+idx, :, :, :]\n",
    "    img = torchvision.transforms.Resize(size=(64,64))(img) \n",
    "    img = (img * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # score\n",
    "    img = (scores[img_offset+idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # models\n",
    "    for model_idx, model_name in enumerate(['../saved_models/imagenet_robust/resnet18_l2_eps0.ckpt',\n",
    "                                            '../saved_models/imagenet_robust/resnet18_l2_eps5.ckpt',\n",
    "                                            '../saved_models/imagenet_robust/resnet18_l2_eps5.ckpt']):\n",
    "        ig = input_gradients[model_name][img_offset+idx, :, :, :]\n",
    "        ig = torchvision.transforms.Resize(size=(64,64))(ig) \n",
    "        img = (ig * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "        axs[2+model_idx, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "        axs[2+model_idx, 0].set_ylabel(model_name[13:-3])\n",
    "\n",
    "for ax in axs:\n",
    "    for idx in range(ncols):\n",
    "        ax[idx].axis('off')\n",
    "\n",
    "plt.tight_layout(pad=.4)\n",
    "plt.savefig('../figures/cifar10-page-2.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b55a261",
   "metadata": {},
   "source": [
    "## Page 7 Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2acc4893",
   "metadata": {},
   "outputs": [],
   "source": [
    "values = []\n",
    "for i in range(20):\n",
    "    x = np.argmin(distances)\n",
    "    values.append(x)\n",
    "    distances[x] = 1000\n",
    "print(values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bd28cb5",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = '../saved_models/imagenet_robust/resnet18_l2_eps5.ckpt'\n",
    "for idx in range(25):\n",
    "    print(idx)\n",
    "    img = images[idx]\n",
    "    img = torchvision.transforms.Resize(size=(64,64))(img) \n",
    "    plt.imshow((img * 255).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    plt.imshow((scores[idx] * 127.5 + 128).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    ig = input_gradients[model_name][idx] \n",
    "    ig = torchvision.transforms.Resize(size=(64,64))(ig) \n",
    "    plt.imshow((ig * 127.5 + 128).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dd8f26",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = '../saved_models/imagenet_robust/resnet18_l2_eps5.ckpt'\n",
    "for idx in values:\n",
    "    print(idx)\n",
    "    img = images[idx]\n",
    "    img = torchvision.transforms.Resize(size=(64,64))(img) \n",
    "    plt.imshow((img * 255).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    plt.imshow((scores[idx] * 127.5 + 128).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    ig = input_gradients[model_name][idx] \n",
    "    ig = torchvision.transforms.Resize(size=(64,64))(ig) \n",
    "    plt.imshow((ig * 127.5 + 128).clip(0, 255).to(torch.uint8).cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34b21698",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 3\n",
    "ncols = 4\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 6))\n",
    "\n",
    "for icol, idx in enumerate([89, 93, 9, 8]):\n",
    "    # image\n",
    "    img = (images[idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0, icol].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # score\n",
    "    img = (scores[idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1, icol].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # models\n",
    "    img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[2, icol].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "for ax in axs:\n",
    "    for idx in range(ncols):\n",
    "        ax[idx].axis('off')\n",
    "\n",
    "plt.tight_layout(pad=0.4)\n",
    "plt.savefig('imagenet-gradients-big.png')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
