{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "91b2b4cb-d0bb-43ea-90d0-47bb1b537f88",
   "metadata": {},
   "source": [
    "# Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "320af0a3-c02e-421b-85e6-d9f6353ab0ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy \n",
    "import math\n",
    "import os\n",
    "import sys\n",
    "from collections import namedtuple\n",
    "import numpy as np\n",
    "npa = np.array\n",
    "import pandas as pd \n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "sys.path.insert(0, '[PATH_TO_CODE]')\n",
    "sys.path.insert(0, '[PATH_TO_CODE]/models')\n",
    "\n",
    "from utils import *\n",
    "import foolbox as fb\n",
    "from autoattack import AutoAttack\n",
    "import resnet_cifar\n",
    "import matplotlib.pyplot as plt \n",
    "import matplotlib as mpl \n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "\n",
    "import seaborn as sns\n",
    "# Figure styling: seaborn 'ticks' look with long major ticks, 'paper' context,\n",
    "# and thick axis / tick lines.\n",
    "sns.set_style(\"white\")\n",
    "sns.set_style(\"ticks\", rc={\"xtick.major.size\": 14, \"ytick.major.size\": 14})\n",
    "sns.set_context(\"paper\")\n",
    "mpl.rcParams.update({\n",
    "    'axes.linewidth': 2.5,\n",
    "    'ytick.major.width': 2.5,\n",
    "    'xtick.major.width': 2.5,\n",
    "})\n",
    "\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0306ed1-8245-45c8-8a91-325ac49659aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "ROOT_PATH = '[PATH_TO_DATA]'\n",
    "C10_DATA_PATH = os.path.join(ROOT_PATH, 'cifar10')\n",
    "\n",
    "# Evaluation data is read in a fixed order: with shuffle=True (and no seed) every pass\n",
    "# over the loader yields a different permutation, so the tensors cached below would\n",
    "# neither be reproducible nor line up with results from a later pass over `cifar10`.\n",
    "# NOTE(review): norm=False presumably leaves pixels in [0, 1] -- confirm in utils.\n",
    "cifar10 = get_cifar10_test_loader(batch_size=100, data_path=C10_DATA_PATH, norm=False, shuffle=False)\n",
    "\n",
    "# Cache the full test set as one image tensor and one label tensor for code that\n",
    "# needs tensors instead of a loader.\n",
    "c10_images, c10_labels = [], []\n",
    "for im, l in cifar10:\n",
    "    c10_images.append(im)\n",
    "    c10_labels.append(l)\n",
    "c10_images = torch.cat(c10_images, dim=0)\n",
    "c10_labels = torch.cat(c10_labels, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4426ce59-0512-475d-808e-4c5b5f04a205",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model configuration. pool_type, max_num_pools and noise_std are also used\n",
    "# later in the notebook to name the saved figure.\n",
    "dataset = 'cifar10'\n",
    "arch = 'resnet18'\n",
    "epoch = 199\n",
    "pool_type = 'mean'\n",
    "max_num_pools = 1\n",
    "noise_std = 0.1\n",
    "exp_name = '' \n",
    "\n",
    "# Bundle the settings into an attribute-style object for get_model.\n",
    "# `epoch` and `exp_name` are not part of it and are not used in this cell.\n",
    "Args = namedtuple('nt', ['dataset', 'arch', 'pool_type', 'max_num_pools', 'noise_std'])\n",
    "args = Args(dataset=dataset, arch=arch, pool_type=pool_type, max_num_pools=max_num_pools, noise_std=noise_std)\n",
    "model = get_model(args)\n",
    "\n",
    "# NOTE(review): torch.load unpickles the file, so only point it at trusted checkpoints.\n",
    "state_dict = torch.load('[PATH_TO_CHECKPOINT]', map_location=device)\n",
    "\n",
    "model.load_state_dict(state_dict)\n",
    "model.to(device)\n",
    "# Switch to evaluation mode; the trailing ';' keeps the module repr out of the cell output.\n",
    "model.eval();\n",
    "                        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8907df70-11a8-424e-bd51-dd16c22b34b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_closest_factors(num):\n",
    "    \"\"\"Split `num` into the factor pair whose members are closest to each other.\n",
    "\n",
    "    Args:\n",
    "        num: positive integer to factor.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(a, b)` of ints with `a <= b` and `a * b == num`, where `a` is the\n",
    "        largest divisor of `num` that does not exceed its square root.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `num` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if num < 1:\n",
    "        raise ValueError('num must be a positive integer, got {!r}'.format(num))\n",
    "    num_root = int(math.sqrt(num))\n",
    "    # math.sqrt works in floating point and can be off by one for large integers;\n",
    "    # nudge the estimate to the exact integer square root before searching.\n",
    "    while num_root * num_root > num:\n",
    "        num_root -= 1\n",
    "    while (num_root + 1) * (num_root + 1) <= num:\n",
    "        num_root += 1\n",
    "    # Walk down from the square root to the first divisor; 1 always terminates the loop.\n",
    "    while num % num_root != 0:\n",
    "        num_root -= 1\n",
    "    # Floor division keeps the cofactor exact (float division loses precision above 2**53).\n",
    "    return num_root, int(num // num_root)\n",
    "\n",
    "\n",
    "def get_activation(name, activation):\n",
    "    \"\"\"Create a hook that stores a detached copy of a module's output.\n",
    "\n",
    "    Args:\n",
    "        name: key under which the output is stored.\n",
    "        activation: mutable mapping that receives `activation[name]`.\n",
    "\n",
    "    Returns:\n",
    "        Callable taking (module, inputs, output), the argument order that\n",
    "        `Module.register_forward_hook` passes to its hooks.\n",
    "    \"\"\"\n",
    "    def _record(_module, _inputs, output):\n",
    "        # Only the output is kept; detach() drops it from the autograd graph.\n",
    "        activation[name] = output.detach()\n",
    "\n",
    "    return _record\n",
    "\n",
    "def normalize(img, uint=True):\n",
    "    \"\"\"Rescale `img` to the [0, 1] range or to zero mean / unit standard deviation.\n",
    "\n",
    "    Args:\n",
    "        img: array-like exposing min/max/mean/std methods and arithmetic\n",
    "            (e.g. a numpy array or a torch tensor).\n",
    "        uint: if True, min-max scale to [0, 1]; otherwise subtract the mean and\n",
    "            divide by the standard deviation.\n",
    "\n",
    "    Returns:\n",
    "        A new rescaled array; `img` itself is left untouched. A constant input\n",
    "        yields all zeros instead of NaN.\n",
    "    \"\"\"\n",
    "    if uint:\n",
    "        shifted = img - img.min()\n",
    "        scale = shifted.max()\n",
    "    else:\n",
    "        shifted = img - img.mean()\n",
    "        scale = img.std()\n",
    "    # A constant image has zero range / zero std; divide by 1 so the result is 0, not 0/0.\n",
    "    if scale == 0:\n",
    "        scale = 1\n",
    "    # Out-of-place division also handles integer inputs, where an in-place `/=` fails.\n",
    "    return shifted / scale\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c86c6d7a-b66d-4edf-bb41-17e3b4287aac",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Robustness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eafd4c7c-a00e-4fb6-b60d-2a81be17eb2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[67.75]\n"
     ]
    }
   ],
   "source": [
    "# L-inf perturbation budget for the PGD attack.\n",
    "# NOTE(review): the AutoAttack cell uses 8/255 (~0.0314) rather than 0.031 -- confirm the difference is intended.\n",
    "epsilon = 0.031\n",
    "fmodel = fb.PyTorchModel(model, bounds=(0,1), device=device)\n",
    "\n",
    "advs, success = [], []\n",
    "# 20 PGD iterations with an absolute step size of 1/255.\n",
    "attack = fb.attacks.LinfPGD(steps=20, abs_stepsize=1/255.)\n",
    "for images, labels in cifar10:  \n",
    "    # epsilons is a list, so the results are per-epsilon; the first return value is discarded.\n",
    "    _, current_advs, current_success = attack(fmodel, images.to(device), labels.to(device), epsilons=[epsilon])\n",
    "    advs.append(current_advs)\n",
    "    success.append(current_success)\n",
    "# Stack each batch's per-epsilon results, then join the batches along dim 1.\n",
    "adv_images = torch.cat([torch.stack(ad) for ad in advs], dim=1).cpu().numpy()\n",
    "success = torch.cat(success, dim=-1)\n",
    "attack_success = success.cpu().numpy()\n",
    "# Percentage of samples on which the attack did not succeed, averaged over the last axis.\n",
    "accuracy = 100 * (1 - attack_success.mean(-1))\n",
    "print(accuracy)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0b2bc3d-99de-4b57-b94d-1b388547dd4e",
   "metadata": {},
   "source": [
    "## AutoAttack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15a34095-c937-46c8-aee4-a2314bf159b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AutoAttack\n",
    "# apgd-ce, apgd-t, fab-t, square\n",
    "epsilon = 8/255.\n",
    "# AutoAttack defaults to device='cuda'; pass the notebook's `device` so the CPU\n",
    "# fallback selected at the top of the notebook is honoured here as well.\n",
    "adversary = AutoAttack(model, norm='Linf', eps=epsilon, version='standard', verbose=True, device=device)  # standard, rand\n",
    "# adversary.attacks_to_run = ['square']\n",
    "# adversary.square.n_restarts = 5\n",
    "# adversary.square.n_queries = 10000\n",
    "# Attack the cached test tensors in batches of 500; also return the labels predicted on the adversarial inputs.\n",
    "x_adv, y_adv = adversary.run_standard_evaluation(c10_images, c10_labels, bs=500, return_labels=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22750942-1854-4ea7-8c03-daf63b2e05cb",
   "metadata": {},
   "source": [
    "# Topography"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "f866a5a7-4eac-4010-8f48-da7e460f0d43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAARTklEQVR4nO3deXiW5ZXH8TvJ+5KQjQQQQti3WFBAEKzaWqwLrYJ2XCodFRdwBZQ6aqsdoc5VaouCy7gvdRnAEeuMiKhFi4rjFAsuGdGCENbEsARIgOxr/+ifc35Pr+d1KofO9/PnOdd575ckh+e6nvt5zp3W0dHREQC4k364vwAAG80JOEVzAk7RnIBTNCfgVCIqmZaW9nV9D+D/LbVhwpUTcIrmBJyiOQGnaE7AKZoTcIrmBJyK3EqJ8kbx62Z87YmnypplWz404x+WfkfWXNtrnMyNP+4NM/7ghtWyZvXWc8z4dT1PlDVnjrfXCSGE+R99YK+z+SxZM/YDu+aulkJZU56eK3PTvtXbjM9ZVitrBmbXy9zu9qQZv22C/n7/cPxGMz7l2l6ypqm4SeYuntTdjM9Ju07WjOxUYMZ3ZayXNTPrX5G528VaxRE7jNvzN5jx+TXv6iKBKyfgFM0JOEVzAk7RnIBTNCfgVMp3a0ffZN+563JMuaxZd+/RZty+h/sX98z6F5nbfvxmM37SVQNkjbqPO++ffi5rdog7kSGEcOrMobHWCSGEi1bbd6A7Ttgka/aV9oz4RFtVWZbMFY3Qd0ordneOv9b2V834/5TdJGtyqnNir1ObqJG57W2NZjw9LbU/8305FWa8obWTrMmNyMXFlRNwiuYEnKI5AadoTsApmhNwiuYEnEp5K2XtE3Zfd/5pV1mzfot9uz3KlqcPyVxa30Fm/Petz8dep+zRKplLzusvc6/VLoy9Vs0h+8nppia9hfHFXR/HXie9UucaBmbK3MBK+8H3KPnN+WZ82xr9cx3QEX+d3KT+e9jeYW+lNCWa9QfaJSGEEPIz6sx4TXq7rNmfbLMT+h0EiSsn4BTNCThFcwJO0ZyAUzQn4FRa1PmcTHwH/vaY+A4cYWhOwCmaE3CK5gScojkBp2hOwKmUH3yf0vubZvw7u34kaxbl7TXjq2p+KWsuLBkhc3cnLzfjy9fr2Tk3ts804+cefZyseTT/apl7pdR+WH16y1RZM2DcBWb8nStmy5qt6fZU9xBCOO36o8x4l8v1iwYzz5wgcxUt9sPbz12ZLWtOHfKOGZ912xhZkzVon8yddZr9UsOTycdkzQmjh5nxnCv0pPoh00fKXNVk+1SD0jf0SwgjjrXXKvrDDFmjcOUEnKI5AadoTsApmhNwiuYEnKI5AadS3kp54mJ7zP7escWypuySbWZ8VcQ6zx17lczVX1Bixk+Zs19/oH2CQ1g0yN6WCSGEhssHy9zJs2vshD5ZITycdpcZP9hfD7Tp9L58eUga16oPBC4ubJC5zWV6vpCS2FNqxnce+rb+Dgf0kRDKxtZPZW5AeZEZH7THPoj3r1m49Bkznhv0z2dkh96aiYsrJ+AUzQk4RXMCTtGcgFM0J+BUyndrH53/gRk/Z7p+SHxV1xfsxB69zquv64NrJ5Qca8Z/9eV/6Q8Ulq7U60wcc5zM3VX+duy1Nu2y78oWf2lPTQ8hhFfnr4i9TsO6Mpmr2HqCzO18KeJWs3D1mMlm/GDERPWMpvjXhmu+MVzmtlbbk9gbdsSfLB9CCBcfZ6+17CP9B7v20xRGuwtcOQGnaE7AKZoTcIrmBJyiOQGnaE7AKY5jAA4zjmMAjjA0J+AUzQk4RXMCTtGcgFMpP/h+abE9AmP87n+UNYvyqsz4qpq5subCEj32YV7iMjO+fIOe+D7r/3ji+9JSewr6jJYrZc2A46aY8asLH5A1axprZO6V1fZ09JIhC2RN6HqMTGV26mLG171/kqy5NHG7GS/qrf8ezrxDT7H/3tXdzPii
5M9kTW4ix4yP+sUkWTPwllEy90jOD8x4S4Ye8TKxeLQZH7JhnqxRuHICTtGcgFM0J+AUzQk4RXMCTtGcgFMpb6U8eYk98b0qYuL7pku3m/Hoie/TZE5NfP/O7IiJ71vscKoT3781+4CdiBjDM2TPRWZ8xE/0VPdNzxXoDxTaO+lDY8ed/l2Z215hH3Ic3tdr7cyz5+qUjBsqa5rTIwYMCauyPpG5SW2nmfGO9NRe4NiQX2HGS1r1xPc+Y3qID4u/PldOwCmaE3CK5gScojkBp2hOwCmaE3DqKxzH8EczPmmGPuz2vcIUjmN4LYXjGCq/zuMYVsZea2/4zIx/9Pr3ZM2nn94fe51kvdg3CiF88U6pzDVX21teUY6q7WrGE9uqZU1il/0WSZRxzfog3ILcDHuditbY64QQwrEt4vt16K2Zz9/TR2DExZUTcIrmBJyiOQGnaE7AKZoTcIqJ78BhxsR34AhDcwJO0ZyAUzQn4BTNCThFcwJOpfzg+/gC+wiFeUUnyJpLkrvM+OZ19rEKIYRw7ICHZO7mMXbdy53bZM2yxfZcnWED58ua2WP1cQy/FSc/vLzQPtIghBBGDV9uxo/qpY9IqOupf1Wrn+9rxqfm/Ies+bBYPwyeVTjQjK9Zo3+3K7v9xoy/UJ8ray6Ze7rMnXqz/YD7hsH66I7yKnsm0eD7LpA1g6bZxyeEEMJnw39oxtcdLJc154+3j37IXHyHrFG4cgJO0ZyAUzQn4BTNCThFcwJOpXy39plp9sGiOUP1tPUzFow145sj1vn1aHs6egghDJi404zvFofJhhDCMhGfd/TFsmbwxEqZK181xIy/LCtCKMy374YOGKYPk11fYx88HOUPvcU0+hBCbn9957U5J/6fxYKmUjPet+RmWVPTuVPsdR6p0ucDXNvbPrC4JTO1a9CKuj+Z8XFd9XiVxPCClNaycOUEnKI5AadoTsApmhNwiuYEnKI5AadS3krJK99hxvcN1NsYPdKTsddJ5On/Pw405Zvx5pbYy4SkmBYeQggHavVD7Aer4y+W3m7PjKktr9PfYb/9846SaBSH4IYQ2sv1VPem9PrYaw2uH2bGuzfqB+zTKuNfG85NGyNz3YcVmPG6Wn0ocZQTEvbWVn6G/je1bYi/5aVw5QScojkBp2hOwCmaE3CK5gScojkBpziOATjMOI4BOMLQnIBTNCfgFM0JOEVzAk6l/OD7o/fYD3y3Fulp66/VNJnx392gHyyfdP1BmcsrsWfQlDVXy5q1P+0Ve51csU4IIWxusmcmrb2tWNacPNyeg1OZpx9Ur8/TP9c9v7fnLN0d9HT0F3rpB/3T0uwp7R9V3iBrNo5+0Yzfvv49WfPNxESZ+0ntWWb8wGV6iv1zrz5oxiePmCJrer43TeZu6W9P4O+Taf8dhxDC3sZMMz53x+eyRuHKCThFcwJO0ZyAUzQn4BTNCThFcwJOpbyV0i9pz5lpydQzdfIb9Rh75eQiffhqskezGa/blfe1rBNCCPW77DlGUSq62HNmMjL1Z2XkxH8J4anB+pZ/yDpaptrzu9kJfSpFmLX5TTPeloxYZ4w+EiKIUxeWrFggS0bl2d877wz7wOQQQgh6pyeMyrZnBZUU6a21vgVFZpytFODvCM0JOEVzAk7RnIBTNCfgFM0JOJXyVkpxz91mfGdeZ1lzZh87Z7/P8BcD++ijARpy7VvdI7tlyxp1snXUOvVinai11DohhJCZZm+ZdGrX2zntTY0Rn2hLtLbrXJ3eZmlp1W/HKIX1/c14J/FvDSGEzuX2yeRRuh8aJXMlA8SbNnv1VliUmtosM96RoX92G3fG/zcpXDkBp2hOwCmaE3CK5gScojkBp5j4DhxmTHwHjjA0J+AUzQk4RXMCTtGcgFM0J+BUyg++P3a3/TB4c5F+SPy1A/bD2ytuKJA137+2RuZyj7ZnuexoOyBr1txqH8dw8c/0w8yZ/eRuU6gMDWZ8xXV6bs2d0+2jFZZuXS9rwjf0
A+Sl9/Uz49vfsOf6hBDC8qe/LXP9J9q/p0lXdJU1m5e9b68zVf/sho6yv3cIIZy90n6Qfv3vPpA1ibfsn9Guz5Ky5pQVQ2Vu61J7wFBy3lZZ8+N19s/8pdrFskbhygk4RXMCTtGcgFM0J+AUzQk49RUmvteZ8aiJ710aUpj4XqzvUiaL7PETzXviT3wf3ElP8e7cRd/JPVSt65QXPl9txluCXqdjv74Lrlzz78fLXNmflshcesvI2Gu98649vX1Rhn2gbQghFFbaY0AiHdQ1D977iBlvzNB/k1Eya+0xL69stdcJIYSCLuKQ49r463PlBJyiOQGnaE7AKZoTcIrmBJyiOQGnvsLE911mfGe+nrY+IYWJ70P66knsdTn2beuRhXrqvJz43k9PVG/N01scwzvibweMGD7GjFftrpY1TVn6/9EtIn7+2fq7Ldv7fZk7anR3M75pqSwJfUfZLxuc/NYFsmbg2GKZe1O8A9CQrn+3U06/1ow3tukH359453GZa26xt/4Sh+yXJ0IIoa5bhczFxZUTcIrmBJyiOQGnaE7AKZoTcIrmBJziOAbgMOM4BuAIQ3MCTtGcgFM0J+AUzQk4lfKD790uWG7G88pekjVJ8fz4pg3PyprcCxfJXP6OlWY8o0HPHSpf94AZn1qtHwQvm/K5zM2YesiMTz6/RtaMmjTMjGc3iPkzIYR+xfqB6iUL7ZcDVux7Qtbc8cP7ZW7ylZvM+C1T9CyeN//4mRn/8TW3ypqD6Xrie8Un9gPpH61aK2tq3x5of1Zarqy55E79ckD1gyvM+Ce/2CNr5iTsVyver/ytrFG4cgJO0ZyAUzQn4BTNCThFcwJO0ZyAUylvpVz5un08wMoeejugI8+eOxSl9yf6tnV7lv0dkomNsde58cvrZW7vvQtkrrixh8joLaWnHp9kxm+doQ9YvXx6b5lbstDe+sjcUyJrqjr0HJxtF42zE1OekzWVB3qa8dZ2vY3RlhH/2Iy0Q3qG0OyH/tmMN0YcERKlPsv+fkta5smarGz7CIdUcOUEnKI5AadoTsApmhNwiuYEnEr5bu01J88345XbDsqaLzK6xF7n7fP0HczL3rTv1u5J7xp7nfaIkSw9mmfJXCJRLjIRLwDUjzDj9/7qYVnTrXGzzIVwixlNtOj/e5+9f67MZZfZ09sfCvpubY/QYMavm/G8rOmoypC5mz+2/74aM/T09lnn2TVrG/QBx2sWPy1zLTvtO821zUfJmrrEfpmLiysn4BTNCThFcwJO0ZyAUzQn4BTNCTjFxHfgMGPiO3CEoTkBp2hOwCmaE3CK5gScojkBp1J+K6VwwjNmvGD327Im2WzHN65fKGvOzrtG5vYNWmfGe2+aJmv+s/4qM37mgPGypqGnftNg7kT7mIRT51TLmsFjf2DGC5L6yIW2nvptn9Kl9gyhx9fukDWP3XChzC3+zQ1mfPgxU2TNg0vtYykeuXOyrPnlz/VbLuefZ7/58ekbH8uaL5cPMuOFp+jfxYk/so9wCCGEipn20Q+v7T9G1jz63lNmvLTiRlmjcOUEnKI5AadoTsApmhNwiuYEnEr5bu05b9t3MNf1sef6hBBCyNXT25XiJnGLN4RQ32ZP8t42WB+4G+wbvOGFX98kS6Y+PFPmFn9mz86J0q/hSzPeklUna9Iz7buhUfrk2+uEEEJzJ/1SQ01e/Bce7ljx32a8KKFPAPi3+vh/fg3teuL7ra/OMePZ79p3x/+ayhJ7Kv6/Lpgta1qzmSEE/N2jOQGnaE7AKZoTcIrmBJyiOQGnUt5KOXeUfXRA08FsWVOWKIi9zowX7aMGQgghmW8/DF61XB8a+9113c14ea8iWTP3nmdlrnODfVjqky9OkDUPzH/CTnTfp9ep1dsbQ5ecYcazq/Xv4sn77pO5zKpGmVNO2jnajM+YJ/6tIYT6Wr3NslTE61pzZM3N0+3jGJp663XW6Gf5w5a37G2btPZ8WVOXlXJL/S9cOQGnaE7AKZoT
cIrmBJyiOQGnmPgOHGZMfAeOMDQn4BTNCThFcwJO0ZyAUzQn4FTkU7oRuywA/sa4cgJO0ZyAUzQn4BTNCThFcwJO0ZyAU38GMJEZDRlyNVMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# First-layer convolution weights, scaled to unit L2 norm along dim 1.\n",
    "w = model.conv1.conv.weight.detach().cpu()\n",
    "x = torch.nn.functional.normalize(w, dim=1)\n",
    "kw, kh = resnet_cifar.get_closest_factors(x.shape[0])\n",
    "# make_grid only takes `nrow` (images per row) and derives the number of rows from it;\n",
    "# it has no `ncol` parameter, so that argument was dropped.\n",
    "img = torchvision.utils.make_grid(x, nrow=kh, padding=1, normalize=True)\n",
    "plt.imshow(img.permute(1, 2, 0))\n",
    "plt.axis('off')\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs('figs', exist_ok=True)\n",
    "plt.savefig(f\"figs/c10_weightplot_{pool_type}_{max_num_pools}_{noise_std}.pdf\", format='pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c226bea-8223-443f-b2ca-b3cfc3dbf302",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
