{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the alignment of the score with the input gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "from utils.datasets import get_dataloaders\n",
    "from models.resnet import ResNet18\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "from utils.edm_score import cifar10_score, input_gradient_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_reg_const(model_name, const_name):\n",
    "    magic_string = f'{const_name}='\n",
    "    pos = model_name.find(magic_string) \n",
    "    pos = pos + len(magic_string)\n",
    "    reg_const = model_name[pos:pos+model_name[pos:].find('_')]\n",
    "    return float(reg_const)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gnorm regularized models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../saved_models/gnorm_c10/'\n",
    "\n",
    "# Keep one representative checkpoint per regularization constant.\n",
    "# os.listdir() returns names in arbitrary order, so they are sorted to make\n",
    "# the chosen representative reproducible across runs and machines.\n",
    "model_files = {}\n",
    "for file in sorted(os.listdir(path)):\n",
    "    if not file.endswith('.pt'):\n",
    "        continue\n",
    "\n",
    "    reg_const = get_reg_const(file, 'gnorm_const')\n",
    "    if reg_const not in model_files:\n",
    "        model_files[reg_const] = os.path.join(path, file)\n",
    "\n",
    "# Constants in increasing order, and the checkpoint paths in that same order.\n",
    "gnorms = sorted(model_files)\n",
    "model_files = [model_files[x] for x in gnorms]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Randomized Smoothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../saved_models/randomized_smoothing_c10/'\n",
    "\n",
    "# Keep one representative checkpoint per noise level.\n",
    "# os.listdir() returns names in arbitrary order, so they are sorted to make\n",
    "# the chosen representative reproducible across runs and machines.\n",
    "model_files = {}\n",
    "for file in sorted(os.listdir(path)):\n",
    "    if not file.endswith('.pt'):\n",
    "        continue\n",
    "\n",
    "    reg_const = get_reg_const(file, 'noise_level')\n",
    "    if reg_const not in model_files:\n",
    "        model_files[reg_const] = os.path.join(path, file)\n",
    "\n",
    "# `gnorms` keeps its name for the cells below; here it holds the noise levels\n",
    "# in increasing order, with the checkpoint paths in that same order.\n",
    "gnorms = sorted(model_files)\n",
    "model_files = [model_files[x] for x in gnorms]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Smoothness Penalty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../saved_models/smooth_c10/'\n",
    "\n",
    "# Keep one representative checkpoint per regularization constant.\n",
    "# os.listdir() returns names in arbitrary order, so they are sorted to make\n",
    "# the chosen representative reproducible across runs and machines.\n",
    "model_files = {}\n",
    "for file in sorted(os.listdir(path)):\n",
    "    if not file.endswith('.pt'):\n",
    "        continue\n",
    "\n",
    "    reg_const = get_reg_const(file, 'gnorm_const')\n",
    "    if reg_const not in model_files:\n",
    "        model_files[reg_const] = os.path.join(path, file)\n",
    "\n",
    "# Constants in increasing order, and the checkpoint paths in that same order.\n",
    "gnorms = sorted(model_files)\n",
    "model_files = [model_files[x] for x in gnorms]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Models for the plot on page 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_files = ['../saved_models/standard_c10/resnet18_reg=none_cifar10.pt',\n",
    "               '../saved_models/gnorm_c10/full/resnet18_reg=gnorm_const=0.07278953843983153_cifar10.pt',\n",
    "               '../saved_models/gnorm_c10/full/resnet18_reg=gnorm_const=57.361525104486816_cifar10.pt']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the selected model files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every selected checkpoint into a fresh ResNet18, keyed by file name.\n",
    "models = {}\n",
    "for model_file in model_files:\n",
    "    model = ResNet18()\n",
    "    # map_location='cpu' lets checkpoints that were saved from a GPU be\n",
    "    # restored on machines without CUDA; the model is kept on the CPU here.\n",
    "    state_dict = torch.load(model_file, map_location='cpu')\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "    model.to('cpu')\n",
    "    models[os.path.basename(model_file)] = model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute input gradients and scores for the entire test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The test loader is iterated several times below and the results are paired\n",
    "# by position, so it is assumed to yield a fixed order - TODO confirm that\n",
    "# get_dataloaders does not shuffle the test split.\n",
    "trainloader, testloader = get_dataloaders('cifar10', batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 0.5  # passed to cifar10_score as its second argument\n",
    "\n",
    "# NOTE(review): gradients, images and scores come from separate passes over\n",
    "# `testloader` and are paired by index later - assumes a fixed iteration order.\n",
    "\n",
    "# Input gradients of every model, concatenated over all test batches.\n",
    "input_gradients = {}\n",
    "for model_name, model in models.items():\n",
    "    gradient_batches = []\n",
    "    for img, _ in tqdm.tqdm(testloader):\n",
    "        img = img.to(device)\n",
    "        batch_gradient = input_gradient_sum(model, img, device=device)\n",
    "        gradient_batches.append(batch_gradient.detach().cpu())\n",
    "    input_gradients[model_name] = torch.cat(gradient_batches)\n",
    "\n",
    "# The test images themselves and their scores, stacked the same way.\n",
    "image_batches, score_batches = [], []\n",
    "for img, _ in tqdm.tqdm(testloader):\n",
    "    img = img.to(device)\n",
    "    image_batches.append(img.detach().cpu())\n",
    "    score_batches.append(cifar10_score(img, sigma, device=device).detach().cpu())\n",
    "images = torch.vstack(image_batches)\n",
    "scores = torch.vstack(score_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scale_to_unit_range(batch):\n",
    "    '''Divide every sample of `batch` by its own largest absolute value.\n",
    "\n",
    "    The first dimension is treated as the sample dimension.  Samples that are\n",
    "    identically zero are returned unchanged instead of becoming NaN (0 / 0).\n",
    "    '''\n",
    "    peak = batch.abs().flatten(start_dim=1).max(dim=1).values\n",
    "    peak = torch.where(peak > 0, peak, torch.ones_like(peak))\n",
    "    # Reshape to (N, 1, ..., 1) so the division broadcasts over each sample.\n",
    "    return batch / peak.reshape(-1, *([1] * (batch.dim() - 1)))\n",
    "\n",
    "\n",
    "# Scale scores and input gradients per image so that their entries lie in [-1, 1].\n",
    "scores = _scale_to_unit_range(scores)\n",
    "for model_name in models:\n",
    "    input_gradients[model_name] = _scale_to_unit_range(input_gradients[model_name])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's first make a big visualization of all the different images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We have 30 different models; keep every second constant so the plot shows 15.\n",
    "# For the randomized smoothing models, parse 'noise_level' instead of 'gnorm_const'.\n",
    "selected_consts = gnorms[0::2]\n",
    "representative_models = {\n",
    "    name: net\n",
    "    for name, net in models.items()\n",
    "    if get_reg_const(name, 'gnorm_const') in selected_consts\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 2 + len(representative_models)  # image row, score row, one row per model\n",
    "ncols = 12                              # the first 12 test images\n",
    "fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 28))\n",
    "\n",
    "for idx in range(ncols):\n",
    "    # Images are scaled by 255; scores and gradients lie in [-1, 1] and are\n",
    "    # mapped around mid-gray (128).\n",
    "    column = [images[idx] * 255, scores[idx] * 127.5 + 128]\n",
    "    column += [input_gradients[model_name][idx] * 127.5 + 128\n",
    "               for model_name in representative_models]\n",
    "\n",
    "    for row, scaled in enumerate(column):\n",
    "        pixels = scaled.clip(0, 255).to(torch.uint8)\n",
    "        # Move the leading (channel) axis last, as imshow expects.\n",
    "        axs[row, idx].imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "# Label every model row once, on its first panel. The slice drops the first\n",
    "# 13 and the last 3 characters of the file name (presumably the\n",
    "# 'resnet18_reg=' prefix and the '.pt' suffix).\n",
    "for model_idx, model_name in enumerate(representative_models):\n",
    "    axs[2 + model_idx, 0].set_ylabel(model_name[13:-3], rotation=0, ha='right', va='center')\n",
    "\n",
    "# axis('off') would hide the row labels as well, so only ticks and frames are removed.\n",
    "for ax in axs.flat:\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(False)\n",
    "\n",
    "os.makedirs('../figures', exist_ok=True)  # savefig does not create directories\n",
    "fig.tight_layout(pad=0.5)\n",
    "fig.savefig('../figures/cifar10-gradients-big.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The plot on page 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every third regularization constant, minus one model that is left out of the figure.\n",
    "plot_models = {k: v for k, v in models.items() if get_reg_const(k, 'gnorm_const') in gnorms[0::3]}\n",
    "# pop(..., None): the excluded checkpoint may not be among the loaded models.\n",
    "plot_models.pop('resnet18_reg=gnorm_const=0.0014873521072935117_cifar10.pt', None)\n",
    "\n",
    "img_idx = 0\n",
    "nrows = 1\n",
    "# One panel for the image, one for the score and one per model, so the row\n",
    "# always matches the number of selected models (11 panels for 9 models).\n",
    "ncols = 2 + len(plot_models)\n",
    "fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2))\n",
    "\n",
    "# The image is scaled by 255; score and gradients lie in [-1, 1] and are\n",
    "# mapped around mid-gray (128).\n",
    "panels = [images[img_idx] * 255, scores[img_idx] * 127.5 + 128]\n",
    "panels += [input_gradients[model_name][img_idx] * 127.5 + 128 for model_name in plot_models]\n",
    "\n",
    "for ax, scaled in zip(axs, panels):\n",
    "    pixels = scaled.clip(0, 255).to(torch.uint8)\n",
    "    ax.imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "    ax.axis('off')\n",
    "\n",
    "fig.tight_layout(pad=.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ncols = 5                # number of test images to export\n",
    "img_offset = 10          # index of the first exported test image\n",
    "nrows = 2 + len(models)  # image row, score row, one row per model\n",
    "\n",
    "os.makedirs('../figures', exist_ok=True)  # savefig does not create directories\n",
    "\n",
    "\n",
    "def _save_panel(rgb, filename):\n",
    "    '''Show one panel as a borderless 4x4 inch figure and save it to `filename`.'''\n",
    "    fig, ax = plt.subplots(figsize=(4, 4))\n",
    "    ax.imshow(rgb, vmin=0, vmax=255)\n",
    "    ax.axis('off')\n",
    "    fig.tight_layout(pad=0.)\n",
    "    fig.savefig(filename)\n",
    "    plt.show()\n",
    "    plt.close(fig)  # do not accumulate open figures inside the loop\n",
    "\n",
    "\n",
    "# panels[row][col] keeps every exported panel for the overview grid below.\n",
    "panels = [[None] * ncols for _ in range(nrows)]\n",
    "for idx in range(ncols):\n",
    "    sample = img_offset + idx\n",
    "    column = [\n",
    "        (images[sample] * 255, f'../figures/fig2-cifar10-img-{idx}.png'),\n",
    "        (scores[sample] * 127.5 + 128, f'../figures/fig2-cifar10-img-{idx}-score.png'),\n",
    "    ]\n",
    "    for model_idx, model_name in enumerate(models):\n",
    "        column.append((input_gradients[model_name][sample] * 127.5 + 128,\n",
    "                       f'../figures/fig2-cifar10-img-{idx}-model-{model_idx}.png'))\n",
    "\n",
    "    for row, (scaled, filename) in enumerate(column):\n",
    "        pixels = scaled.clip(0, 255).to(torch.uint8)\n",
    "        panels[row][idx] = pixels.cpu().numpy().squeeze().transpose((1, 2, 0))\n",
    "        _save_panel(panels[row][idx], filename)\n",
    "\n",
    "# Overview grid, built from its own axes: one column per image, rows are\n",
    "# image, score and models. Reusing an `axs` left over from an earlier cell\n",
    "# would either fail or save an empty figure.\n",
    "fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)\n",
    "for row in range(nrows):\n",
    "    for col in range(ncols):\n",
    "        axs[row, col].imshow(panels[row][col], vmin=0, vmax=255)\n",
    "        axs[row, col].axis('off')\n",
    "\n",
    "fig.tight_layout(pad=1.0)\n",
    "fig.savefig('../figures/cifar10-page-2.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The plot on page 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lpips\n",
    "\n",
    "# LPIPS distance module built on the 'alex' network (noted as giving the best forward scores).\n",
    "loss_fn_alex = lpips.LPIPS(net='alex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'resnet18_reg=gnorm_const=0.07278953843983153_cifar10.pt'\n",
    "gradients = input_gradients[model_name]\n",
    "\n",
    "# LPIPS distance between every input gradient and the score of the same image.\n",
    "# Evaluated in batches and without autograd bookkeeping, since only the\n",
    "# numbers are needed. Assumes LPIPS returns one value per sample for a\n",
    "# batched input - TODO confirm against the lpips documentation.\n",
    "lpips_batch_size = 256\n",
    "distances = []\n",
    "with torch.no_grad():\n",
    "    for start in range(0, scores.shape[0], lpips_batch_size):\n",
    "        stop = start + lpips_batch_size\n",
    "        batch_distances = loss_fn_alex(gradients[start:stop], scores[start:stop])\n",
    "        distances.extend(batch_distances.flatten().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the 20 images with the smallest LPIPS distance, closest first.\n",
    "# argsort leaves `distances` untouched, so this cell can be re-run safely;\n",
    "# kind='stable' keeps the lower index first when two distances are equal.\n",
    "values = np.argsort(distances, kind='stable')[:20].tolist()\n",
    "print(values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every selected index, show the image, its score and the input gradient\n",
    "# of `model_name`, each as a separate figure.\n",
    "for idx in values:\n",
    "    print(idx)\n",
    "    panels = (\n",
    "        (images[idx, :, :, :] * 255, {'vmin': 0, 'vmax': 255}),  # image\n",
    "        (scores[idx] * 127.5 + 128, {}),                         # score\n",
    "        (input_gradients[model_name][idx] * 127.5 + 128, {}),    # input gradient\n",
    "    )\n",
    "    for scaled, imshow_kwargs in panels:\n",
    "        pixels = scaled.clip(0, 255).to(torch.uint8)\n",
    "        plt.imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), **imshow_kwargs)\n",
    "        plt.axis('off')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: image, score, input gradient of `model_name`; one column per example.\n",
    "nrows, ncols = 3, 4\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 6))\n",
    "\n",
    "# NOTE(review): hand-picked test-set indices, presumably taken from the\n",
    "# closest examples printed above - confirm.\n",
    "for icol, idx in enumerate([7908, 2649, 6748, 821]):\n",
    "    rows = (\n",
    "        (images[idx, :, :, :] * 255, {'vmin': 0, 'vmax': 255}),\n",
    "        (scores[idx, :, :, :] * 127.5 + 128, {'vmin': 0, 'vmax': 255}),\n",
    "        (input_gradients[model_name][idx, :, :, :] * 127.5 + 128, {}),\n",
    "    )\n",
    "    for irow, (scaled, imshow_kwargs) in enumerate(rows):\n",
    "        pixels = scaled.clip(0, 255).to(torch.uint8)\n",
    "        axs[irow, icol].imshow(pixels.cpu().numpy().squeeze().transpose((1, 2, 0)), **imshow_kwargs)\n",
    "\n",
    "for ax in axs.flat:\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.tight_layout(pad=0.75)\n",
    "# NOTE(review): shares its file stem with the big grid figure (saved as .pdf) - confirm the name.\n",
    "plt.savefig('../figures/cifar10-gradients-big.png')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deeplearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "36d4be3fdecffea904e02ae39f4d4fa6bdc0b9865970412f32cf489e5dafede0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
