{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcbea600",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "import PIL\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.edm_score import input_gradient\n",
    "\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d296be4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# avoid type-3 fonts\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "sns.set_style(\"white\")\n",
    "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40841337",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63ee03b8",
   "metadata": {},
   "source": [
    "## load the data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7796652",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_images, val_labels = torch.load('../datasets/imagenet-64x64-val.pth')\n",
    "valset = torch.utils.data.TensorDataset(val_images, val_labels)\n",
    "valloader = torch.utils.data.DataLoader(valset, batch_size=512, shuffle=True, num_workers=8)\n",
    "valloader_single_images = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=True, num_workers=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "975ea4d4",
   "metadata": {},
   "source": [
    "## load pre-trained networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "323571d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../saved_models/imagenet_robust/imagenet64x64/'\n",
    "models = {}\n",
    "model_names = [#'resnet18_l2_eps0.0.pth',\n",
    "               'resnet18_l2_eps0.01.pth',\n",
    "               'resnet18_l2_eps0.1.pth',\n",
    "               'resnet18_l2_eps5.0.pth',\n",
    "               'resnet18_l2_eps10.0.pth',\n",
    "               'resnet18_l2_eps20.0.pth',\n",
    "               'resnet18_l2_eps50.0.pth',\n",
    "               'resnet18_l2_eps100.0.pth',\n",
    "               'resnet18_l2_eps200.0.pth',\n",
    "               'resnet18_l2_eps500.0.pth',\n",
    "               'resnet18_l2_eps2500.0.pth',\n",
    "               'resnet18_l2_eps5000.0.pth']\n",
    "\n",
    "for file in model_names:\n",
    "    pos = file.find('eps') \n",
    "    pos = pos + len('eps')\n",
    "    eps = file[pos:pos+file[pos:].find('.pth')]\n",
    "    eps = float(eps)\n",
    "    model = resnet18()\n",
    "    state_dict = torch.load(os.path.join(path, file))\n",
    "    # remove `module.` from distributed training\n",
    "    if 'module.conv1.weight' in state_dict.keys():\n",
    "        cleaned_state_dict = OrderedDict()\n",
    "        for k, v in state_dict.items():\n",
    "            name = k[7:]\n",
    "            cleaned_state_dict[name] = v\n",
    "        state_dict = cleaned_state_dict\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.to('cpu')\n",
    "    model.eval()\n",
    "    models[file] = model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79aef7b9",
   "metadata": {},
   "source": [
    "## accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df713166",
   "metadata": {},
   "outputs": [],
   "source": [
    "epsilons =  [0.01, 0.1, 5, 10, 20, 50, 100, 200, 500, 2500, 5000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f1df54",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracies = {}\n",
    "for model_name, model in models.items():\n",
    "    model.to(device)\n",
    "    val_loss = 0\n",
    "    val_zero_one_loss = 0\n",
    "    for img, label in tqdm.tqdm(valloader):\n",
    "        img = img / 255\n",
    "        img, label = img.to(device), label.to(device)\n",
    "        pred = model(img)\n",
    "        val_zero_one_loss += (pred.softmax(dim=1).argmax(dim=1) != label).sum().item()\n",
    "    accuracies[model_name] = 1-val_zero_one_loss / len(valloader.dataset)\n",
    "    print(f'{model_name} Val Acc. {accuracies[model_name]}')\n",
    "    model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fda52d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(epsilons, accuracies.values(), 'o--')\n",
    "plt.xscale('log')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b3aeb8c",
   "metadata": {},
   "source": [
    "## load the diffusion model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c5ccac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import diffusion_model\n",
    "\n",
    "sigma = 1.2\n",
    "diffusion = diffusion_model.Diffusion(f'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737ddffa",
   "metadata": {},
   "source": [
    "## input gradients and scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bb678bd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "N_images = 1000\n",
    "\n",
    "images = []\n",
    "scores = []\n",
    "input_gradients = {k : [] for k, _ in models.items()}\n",
    "\n",
    "for idx, (img, label) in tqdm.tqdm(enumerate(valloader_single_images)):\n",
    "    img = img / 255\n",
    "    images.append(img.detach().cpu())\n",
    "    img, label = img.to(device), label.to(device)\n",
    "    # input gradient, for all models\n",
    "    for model_name, model in models.items():\n",
    "        model.to(device)\n",
    "        assert img.grad is None\n",
    "        ig = input_gradient(model, img).detach().cpu()\n",
    "        input_gradients[model_name].append(ig)\n",
    "        model.to('cpu')\n",
    "    # score\n",
    "    diffusion.to(device)\n",
    "    score = diffusion.get_score(img.to(device), sigma, class_labels=label)\n",
    "    scores.append(score.detach().cpu())\n",
    "    diffusion.to('cpu')\n",
    "    if idx >= N_images:\n",
    "        break\n",
    "images = torch.vstack(images)\n",
    "input_gradients = {k:torch.cat(v) for k,v in input_gradients.items()}\n",
    "scores = torch.vstack(scores)        \n",
    "\n",
    "# scale the lenght of score and input gradients so that they lie in [-1,1]\n",
    "for model_name, _ in models.items():\n",
    "    for idx in range(input_gradients[model_name].shape[0]):\n",
    "        input_gradients[model_name][idx] = input_gradients[model_name][idx]  / input_gradients[model_name][idx].abs().max()\n",
    "for idx in range(scores.shape[0]): # score to [-1, 1]\n",
    "    scores[idx] = scores[idx] / scores[idx].abs().max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa80b96",
   "metadata": {},
   "source": [
    "## visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1abb3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrows = 2+len(models)\n",
    "__, axs = plt.subplots(nrows=nrows, ncols=10, figsize=(20, 26))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = (images[idx, :, :, :]*255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "    \n",
    "    img = (scores[idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "    # different models\n",
    "    for model_idx, (model_name, _) in enumerate(models.items()):\n",
    "        img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "        axs[2+model_idx, idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "for ax in axs:\n",
    "    for idx in range(10):\n",
    "        ax[idx].axis('off')\n",
    "\n",
    "plt.tight_layout(pad=0.5)\n",
    "plt.savefig('../figures/imagenet64x64-gradients-big.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "210c3891",
   "metadata": {},
   "source": [
    "## LPIPS metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f396bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lpips\n",
    "\n",
    "# init lpips \n",
    "torch.hub.set_dir(\"../tmp/.cache/torchhub\") # set hub to writeable directory\n",
    "loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b765325c",
   "metadata": {},
   "outputs": [],
   "source": [
    "lpips_results = []\n",
    "for model_name, _ in models.items():\n",
    "    print(model_name)\n",
    "    distances = []\n",
    "    for img_idx in range(N_images): # for all images\n",
    "        score = scores[img_idx]\n",
    "        ig = input_gradients[model_name][img_idx]\n",
    "        # bilinear downsize for the input gradient to have the same size as the score \n",
    "        ig = torchvision.transforms.Resize(size=(64,64))(ig) \n",
    "        distance = loss_fn_alex(ig, score)\n",
    "        distances.append(distance.item())        \n",
    "    print(np.mean(distances))\n",
    "    lpips_results.append(np.mean(distances))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93008668",
   "metadata": {},
   "source": [
    "## plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11d1c8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,6))\n",
    "ax1 = plt.gca()\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "ax1.plot(epsilons, accuracies.values(), 'o--', ms=10, color='#1f77b4', label='Accuracy')\n",
    "ax1.set_ylabel('Accuracy', color='#1f77b4')\n",
    "ax1.tick_params(axis='y', colors='#1f77b4')\n",
    "\n",
    "ax2.plot(epsilons, [1-x for x in lpips_results], 'o--', ms=10, color='#ff7f0e')\n",
    "ax2.set_ylabel('1-LPIPS', color='#ff7f0e')\n",
    "ax2.tick_params(axis='y', colors='#ff7f0e')\n",
    "\n",
    "plt.title('ImageNet-64x64')\n",
    "plt.xscale('log')\n",
    "ax1.set_xlabel('Adversarial Perturbation Budget (Epsilon)')\n",
    "plt.savefig('../figures/imagenet64x64_lpips.pdf')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98e4a828",
   "metadata": {},
   "source": [
    "## plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eadaca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_models = {k: v for k, v in models.items()}\n",
    "plot_models.pop('resnet18_l2_eps10.0.pth')\n",
    "plot_models.pop('resnet18_l2_eps100.0.pth')\n",
    "plot_models.pop('resnet18_l2_eps200.0.pth')\n",
    "\n",
    "for img_idx in [24]:\n",
    "    nrows = 1\n",
    "    ncols = 11\n",
    "    __, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2*ncols, 2))\n",
    "\n",
    "    # image\n",
    "    img = (images[img_idx, :, :, :] * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[0].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # score\n",
    "    img = (scores[img_idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[1].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)), vmin=0, vmax=255)\n",
    "\n",
    "    # models\n",
    "    for model_idx, (model_name, _) in enumerate(plot_models.items()):\n",
    "        img = (input_gradients[model_name][img_idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "        axs[2+model_idx].imshow(img.cpu().numpy().squeeze().transpose((1, 2, 0)))\n",
    "\n",
    "    for idx in range(ncols):\n",
    "        axs[idx].axis('off')\n",
    "    \n",
    "    plt.tight_layout(pad=.2)\n",
    "    plt.savefig('../figures/imagenet-example.pdf')\n",
    "    plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
