{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST with a distractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torchvision\n",
    "\n",
    "from models.resnet import ResNet18\n",
    "\n",
    "from utils.datasets import get_dataloaders\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "import utils\n",
    "from utils.edm_score import input_gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# avoid type-3 fonts\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "sns.set_style(\"white\")\n",
    "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create an MNIST data set with a single distractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_trainloader, mnist_testloader = get_dataloaders(\"mnist\", batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_images = iter(mnist_trainloader).__next__()[0].to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import random\n",
    "\n",
    "def draw_text_simple(img, text=None):\n",
    "    \"\"\" draws the letter A on an image\"\"\"\n",
    "    img = img.copy()\n",
    "    draw = ImageDraw.Draw(img)\n",
    "    fnt = ImageFont.truetype(\"Pillow/Tests/fonts/FreeMono.ttf\", random.choice([35]))\n",
    "    draw.text((9, 2), \"A\", font=fnt, \n",
    "            stroke_width=1,\n",
    "            fill=\"white\",\n",
    "            stroke_fill=\"white\")\n",
    "    return img\n",
    "\n",
    "# create 28x28 black and white image with pillow\n",
    "img = Image.new('L', (40, 40), color=0)\n",
    "img = draw_text_simple(img, \"AB\")\n",
    "\n",
    "img = np.array(img)\n",
    "plt.imshow(img, cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_distractor(text):\n",
    "    img = Image.new('L', (40, 40), color=0)\n",
    "    img = draw_text_simple(img, text)\n",
    "    img = img.resize((28, 28))\n",
    "    return np.array(img)\n",
    "\n",
    "img = create_distractor(\"AB\")\n",
    "plt.imshow(img, cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random, string\n",
    "\n",
    "def create_distractor_image(mnist_image, text=None):\n",
    "    if text is None:\n",
    "        text = ''.join(random.choice(string.ascii_letters) for _ in range(2))\n",
    "    distractor = torch.Tensor(create_distractor(text)) / 255\n",
    "    image = torch.zeros((56, 28))\n",
    "    on_top = False\n",
    "    if np.random.random() < 0.5: # number on top block\n",
    "        image[:28, :] = mnist_image\n",
    "        image[28:, :] = distractor\n",
    "    else:\n",
    "        image[:28, :] = distractor\n",
    "        image[28:, :] = mnist_image\n",
    "        on_top = True\n",
    "    return image, on_top\n",
    "\n",
    "\n",
    "img2, on_top = create_distractor_image(eval_images[0][0])\n",
    "plt.imshow(img2, cmap='gray')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the minst word distractor data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = torchvision.datasets.MNIST('data/', train=True, transform=torchvision.transforms.ToTensor())\n",
    "testset = torchvision.datasets.MNIST('data/', train=False, transform=torchvision.transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_distractor_mnist_x = torch.zeros((60000, 1, 56, 28))\n",
    "word_distractor_mnist_y = torch.zeros((60000,))\n",
    "dataloader = torch.utils.data.DataLoader(trainset, batch_size=1)\n",
    "for idx, (img, label) in enumerate(iter(dataloader)):\n",
    "    word_distractor_mnist_x[idx] = create_distractor_image(img[0][0])[0]\n",
    "    word_distractor_mnist_y[idx] = label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_distractor_mnist_test_x = torch.zeros((10000, 1, 56, 28))\n",
    "word_distractor_mnist_test_y = torch.zeros((10000,))\n",
    "dataloader = torch.utils.data.DataLoader(testset, batch_size=1)\n",
    "for idx, (img, label) in enumerate(iter(dataloader)):\n",
    "    word_distractor_mnist_test_x[idx] = create_distractor_image(img[0][0])[0]\n",
    "    word_distractor_mnist_test_y[idx] = label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save((word_distractor_mnist_x, word_distractor_mnist_y, word_distractor_mnist_test_x, word_distractor_mnist_test_y), '../datasets/simple_word_distractor_mnist.pt')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and compare the two models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "\n",
    "model_files = ['../saved_models/simple_word_distractor_mnist/resnet18_reg=none_simple_word_distractor_mnist.pt',\n",
    "               '../saved_models/simple_word_distractor_mnist/mnist_simple_word_distractor_adv_robust_l2.pth']\n",
    "\n",
    "models = {}\n",
    "for model_file in model_files:\n",
    "    model = ResNet18(in_channel=1)\n",
    "    model.load_state_dict(torch.load(model_file))\n",
    "    model.eval()\n",
    "    model.to('cpu')\n",
    "    models[os.path.basename(model_file)] = model    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainloader, testloader = get_dataloaders(\"simple_word_distractor_mnist\", batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name, model in models.items():\n",
    "    model.to('cuda')\n",
    "    print(model_name, utils.datasets.test(model, testloader, device))\n",
    "    model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_images = []\n",
    "for img, _ in tqdm.tqdm(testloader):\n",
    "    eval_images.append(img.detach().cpu())\n",
    "eval_images = torch.vstack(eval_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_gradients = {k : [] for k, _ in models.items()}\n",
    "for model_name, model in models.items():\n",
    "    for img, _ in tqdm.tqdm(testloader):\n",
    "        img = img.to(device)\n",
    "        gradient = input_gradient(model, img).detach().cpu()\n",
    "        input_gradients[model_name].append(gradient)\n",
    "input_gradients = {k:torch.cat(v) for k,v in input_gradients.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scale the lenght of input gradients so that they lie in [-1,1]\n",
    "for model_name, _ in models.items():\n",
    "    for idx in range(input_gradients[model_name].shape[0]):\n",
    "        input_gradients[model_name][idx] = input_gradients[model_name][idx]  / input_gradients[model_name][idx].abs().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "on_top = [1, 0, 0, 1, 1, 0, 0 ,0 ,0 , 1]\n",
    "__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = eval_images[idx, :, :, :].clone().detach().squeeze()\n",
    "    noise = 0.25*torch.randn_like(img) \n",
    "    if on_top[idx]:\n",
    "        img[:28, :] = img[:28, :] + noise[:28, :]\n",
    "    else:\n",
    "        img[28:, :] = img[28:, :] + noise[28:, :]\n",
    "    img = (img * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[idx].imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)\n",
    "    axs[idx].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "on_top = [1, 0, 0, 1, 1, 0, 0 ,0 ,0 , 1]\n",
    "__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = eval_images[idx, :, :, :].clone().detach().squeeze()\n",
    "    noise = 0.25*torch.randn_like(img) \n",
    "    if not on_top[idx]:\n",
    "        img[:28, :] = img[:28, :] + noise[:28, :]\n",
    "    else:\n",
    "        img[28:, :] = img[28:, :] + noise[28:, :]\n",
    "    img = (img * 255).clip(0, 255).to(torch.uint8)\n",
    "    axs[idx].imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)\n",
    "    axs[idx].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'resnet18_reg=none_simple_word_distractor_mnist.pt'\n",
    "__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[idx].imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)\n",
    "    axs[idx].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'mnist_simple_word_distractor_adv_robust_l2.pth'\n",
    "__, axs = plt.subplots(nrows=1, ncols=10, figsize=(20, 30))\n",
    "\n",
    "for idx in range(10):\n",
    "    img = (input_gradients[model_name][idx, :, :, :] * 127.5 + 128).clip(0, 255).to(torch.uint8)\n",
    "    axs[idx].imshow(img.cpu().numpy().squeeze(), cmap='gray', vmin=0, vmax=255)\n",
    "    axs[idx].axis('off')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare robustness on signal versus distractor part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax_l1(x, y):\n",
    "    return (F.softmax(x, dim=1) - F.softmax(y, dim=1)).abs().sum(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_trainloader, mnist_testloader = get_dataloaders(\"mnist\", batch_size=1)\n",
    "\n",
    "# create a new evaluation data set where we remember the position of the distractor\n",
    "eval_images = []\n",
    "eval_images_on_top = []\n",
    "for idx, (mnist_images, label) in enumerate(mnist_testloader):\n",
    "    img, on_top = create_distractor_image(mnist_images[0][0])\n",
    "    img = img\n",
    "    eval_images.append(img)\n",
    "    eval_images_on_top.append(on_top)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_noise_std = np.logspace(-2, 1, num=30)\n",
    "data_noise_std = np.insert(data_noise_std, 0, 0)\n",
    "model_results = {}\n",
    "\n",
    "for model_name, model in models.items():\n",
    "    # noise on entire image\n",
    "    all_noise_results = {sigma: [] for sigma in data_noise_std}\n",
    "    for img, on_top in tqdm.tqdm(zip(eval_images, eval_images_on_top)):\n",
    "        for sigma in data_noise_std:\n",
    "            noise = sigma * torch.randn_like(img)\n",
    "            noisy_image = img.clone().detach() + noise\n",
    "            logits = model(noisy_image.to(device).unsqueeze(0).unsqueeze(0))\n",
    "            all_noise_results[sigma].append(logits.detach().cpu())\n",
    "    all_noise_results = {x: torch.vstack(y) for x,y in all_noise_results.items()}\n",
    "\n",
    "    # noise only on signal\n",
    "    signal_noise_results = {sigma: [] for sigma in data_noise_std}\n",
    "    for img, on_top in tqdm.tqdm(zip(eval_images, eval_images_on_top)):\n",
    "        for sigma in data_noise_std:\n",
    "            noise = sigma * torch.randn_like(img)\n",
    "            noisy_image = img.clone().detach()\n",
    "            if not on_top:\n",
    "                noisy_image[:28, :] = noisy_image[:28, :] + noise[:28, :]\n",
    "            else:\n",
    "                noisy_image[28:, :] = noisy_image[28:, :] + noise[28:, :]\n",
    "            logits = model(noisy_image.to(device).unsqueeze(0).unsqueeze(0))\n",
    "            signal_noise_results[sigma].append(logits.detach().cpu())\n",
    "    signal_noise_results = {x: torch.vstack(y) for x,y in signal_noise_results.items()}\n",
    "    \n",
    "    # noise only on distractor\n",
    "    distractor_noise_results = {sigma: [] for sigma in data_noise_std}\n",
    "    for img, on_top in tqdm.tqdm(zip(eval_images, eval_images_on_top)):\n",
    "        for sigma in data_noise_std:\n",
    "            noise = sigma * torch.randn_like(img)\n",
    "            noisy_image = img.clone().detach()\n",
    "            if on_top:\n",
    "                noisy_image[:28, :] = noisy_image[:28, :] + noise[:28, :]\n",
    "            else:\n",
    "                noisy_image[28:, :] = noisy_image[28:, :] + noise[28:, :]\n",
    "            logits = model(noisy_image.to(device).unsqueeze(0).unsqueeze(0))\n",
    "            distractor_noise_results[sigma].append(logits.detach().cpu())\n",
    "    distractor_noise_results = {x: torch.vstack(y) for x,y in distractor_noise_results.items()}\n",
    "\n",
    "\n",
    "    # store results\n",
    "    model_results[model_name] = [all_noise_results, distractor_noise_results, signal_noise_results]\n",
    "\n",
    "    # plot\n",
    "    plot_result = {}\n",
    "    original = all_noise_results[0.0]\n",
    "    for sigma in data_noise_std: \n",
    "        noisy = all_noise_results[sigma]\n",
    "        estimate = softmax_l1(noisy, original)\n",
    "        plot_result[sigma] = estimate.mean().item()\n",
    "    plt.plot(data_noise_std, list(plot_result.values()), label='Noise on Entire Image')\n",
    "\n",
    "    plot_result = {}\n",
    "    original = distractor_noise_results[0.0]\n",
    "    for sigma in data_noise_std: \n",
    "        noisy =  distractor_noise_results[sigma]\n",
    "        estimate = softmax_l1(noisy, original)\n",
    "        plot_result[sigma] = estimate.mean().item()\n",
    "    plt.plot(data_noise_std, list(plot_result.values()), label='Noise on Distractor')\n",
    "\n",
    "    plot_result = {}\n",
    "    original = signal_noise_results[0.0]\n",
    "    for sigma in data_noise_std: \n",
    "        noisy =  signal_noise_results[sigma]\n",
    "        estimate = softmax_l1(noisy, original)\n",
    "        plot_result[sigma] = estimate.mean().item()\n",
    "    plt.plot(data_noise_std, list(plot_result.values()), label='Noise on Signal')\n",
    "\n",
    "    plt.xscale('log')\n",
    "    plt.xlabel(\"Magnitude of Noise Added to Image\")\n",
    "    plt.ylabel(\"L1 Deviation with Original Softmax Score\")\n",
    "    plt.legend(fontsize=6)  \n",
    "    plt.tight_layout()\n",
    "    plt.legend(loc=4)\n",
    "    plt.title(model_name)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model_results, '../data/relative_noise_robustness_results.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_noise_std = np.logspace(-2, 1, num=30)\n",
    "data_noise_std = np.insert(data_noise_std, 0, 0)\n",
    "model_results = torch.load('../data/relative_noise_robustness_results.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=1.)\n",
    "\n",
    "#plt.figure(figsize=(8,6))\n",
    "\n",
    "model_name_1 = 'resnet18_reg=none_simple_word_distractor_mnist.pt'\n",
    "model_name_2 = 'mnist_simple_word_distractor_adv_robust_l2.pth'\n",
    "\n",
    "for model_name in [model_name_1, model_name_2]:\n",
    "    all_noise, distractor_noise, signal_noise = model_results[model_name]\n",
    "    \n",
    "    plot_result = {}\n",
    "    original = all_noise[0.0]\n",
    "    for sigma in data_noise_std: \n",
    "        dn =  distractor_noise[sigma]\n",
    "        sn = signal_noise[sigma]\n",
    "        dn_estimate = softmax_l1(dn, original)\n",
    "        sn_estimate = softmax_l1(sn, original)\n",
    "        plot_result[sigma] = sn_estimate.mean().item() / (dn_estimate.mean().item()+1e-12)\n",
    "    plt.plot(data_noise_std, list(plot_result.values()), 'o--', label= 'Resnet18' if model_name == 'resnet18_reg=none_simple_word_distractor_mnist.pt' else 'Robust Resnet18', ms=6, lw=2)\n",
    "\n",
    "plt.hlines(1, 0, 10, color = 'red', lw=2) # , label=\"Equally Robustness to Signal and Distractor Noise\"\n",
    "plt.xscale('log')\n",
    "#plt.yscale('log')\n",
    "plt.xlabel(\"Noise Level\")\n",
    "plt.ylabel(\"Signal/Distractor\")\n",
    "plt.legend(fontsize=6)  \n",
    "plt.tight_layout()\n",
    "# Put a legend below current axis\n",
    "#plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),\n",
    "#          fancybox=True, shadow=True, ncol=1)\n",
    "plt.legend()\n",
    "plt.ylim([0,12])\n",
    "plt.yticks([1,4,7,10])\n",
    "plt.title(\"Relative Noise Robustness (MNIST)\")\n",
    "plt.savefig('../figures/relative-noise-robustness.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deeplearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "36d4be3fdecffea904e02ae39f4d4fa6bdc0b9865970412f32cf489e5dafede0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
