{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "from src.model_defs import mnist_cnn_4layer\n",
    "from auto_LiRPA import BoundedModule, BoundedTensor\n",
    "import torch\n",
    "import torchvision\n",
    "from autoattack import AutoAttack\n",
    "import random\n",
    "from src.attack_pgd import attack_pgd"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "source": [
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine without CUDA.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "source": [
    "# Read the stored tensor from disk (presumably PGD adversarial examples, judging by the\n",
    "# file name - TODO confirm) and display its dimensions.\n",
    "adv_path = \"madrycnn_diversity_pgd.torch\"\n",
    "adv_data = torch.load(adv_path)\n",
    "adv_data.shape"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([100, 100, 28, 28])"
      ]
     },
     "metadata": {},
     "execution_count": 25
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "source": [
    "# NOTE(review): `model` is not defined in any cell of this notebook; construct and load it\n",
    "# before running this cell, otherwise a fresh kernel stops here with a NameError.\n",
    "#\n",
    "# One [positive, zero, negative] counter slot per top-level ReLU. `children()` is used instead\n",
    "# of `modules()` because the counting loop walks `model.children()` and indexes this list by\n",
    "# the order of the ReLUs it meets there; a recursive walk could add slots that never get filled.\n",
    "relus = [[None, None, None] for m in model.children() if isinstance(m, torch.nn.ReLU)]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "N = 100\n",
    "\n",
    "# Feed inputs on the device that holds the model's weights; fall back to `device` when the\n",
    "# model has no parameters. `Tensor.to` returns a new tensor, so its result must be kept.\n",
    "first_param = next(model.parameters(), None)\n",
    "target_device = first_param.device if first_param is not None else torch.device(device)\n",
    "\n",
    "# Only sign counts are collected, so skip building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    for n in range(N):\n",
    "        # Insert a singleton axis at position 1 before running the layers.\n",
    "        x_adv = adv_data[n].unsqueeze(1).to(target_device)\n",
    "\n",
    "        relu_idx = 0\n",
    "        for module in model.children():\n",
    "            if isinstance(module, torch.nn.ReLU):\n",
    "                # At this point x_adv is the input of the ReLU, i.e. its pre-activation.\n",
    "                pre_act = x_adv.flatten()\n",
    "                if relus[relu_idx][0] is None:\n",
    "                    # Lazily allocate the three counters once the layer's size is known.\n",
    "                    relus[relu_idx] = [torch.zeros_like(pre_act) for _ in range(3)]\n",
    "                # Per-entry counters: positive / exactly zero / negative pre-activation.\n",
    "                relus[relu_idx][0] += pre_act > 0\n",
    "                relus[relu_idx][1] += pre_act == 0\n",
    "                relus[relu_idx][2] += pre_act < 0\n",
    "\n",
    "                relu_idx += 1\n",
    "            x_adv = module(x_adv)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "source": [
    "import matplotlib.pyplot as plt"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "source": [
    "# Turn the raw counts into add-one smoothed frequencies: (count + 1) / (total + 3).\n",
    "for i in range(len(relus)):\n",
    "    # Convert only slots that still hold tensor counters. This skips slots that were never\n",
    "    # filled (still None) and makes re-running the cell a no-op instead of a TypeError.\n",
    "    if not all(torch.is_tensor(r) for r in relus[i]):\n",
    "        continue\n",
    "    relu_sum = sum(relus[i]) + 3\n",
    "    # Keep NumPy arrays: plt.hist accepts them directly, so no large Python lists are needed.\n",
    "    relus[i] = [((r + 1) / relu_sum).detach().cpu().numpy() for r in relus[i]]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "source": [
    "import numpy as np"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "source": [
    "plt.hist(relus[0][0])"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(array([124436.,  26268.,  14220.,   8954.,   7316.,   6978.,  10185.,\n",
       "         12488.,  17229.,  85526.]),\n",
       " array([0.00970874, 0.10679612, 0.2038835 , 0.30097088, 0.39805827,\n",
       "        0.49514565, 0.592233  , 0.6893204 , 0.78640777, 0.88349515,\n",
       "        0.98058254], dtype=float32),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "metadata": {},
     "execution_count": 35
    },
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUF0lEQVR4nO3dcazd5X3f8fdndkNJM4gBg5hNZzq8toAaNXjEa7cqm6fgJFXNJJCctcXKLFllrMumSQt00pAaWQJtGh3aoEKBYVgUsNy0eMtoapll2VQwuTRpwFDKXcjAw8M3NaMsFXQm3/1xnisd31w/9/oc33u5vu+XdHR+5/t7nt95Htm6n/v8fuf8bqoKSZJO5S8s9QAkSe9tBoUkqcugkCR1GRSSpC6DQpLUtXqpB3CmXXTRRbVhw4alHoYkLSvPPPPMd6tq7Wz7zrqg2LBhAxMTE0s9DElaVpL8z1Pt89STJKnLoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnqMigkSV0GhSSp66z7Zva4Ntz65SV53+/c8ckleV9JmosrCklSl0EhSeoyKCRJXXMGRZIHkhxL8txQ7V8m+aMk30ry20k+OLTvtiSTSV5Mct1Q/Zokz7Z9dydJq5+T5NFWP5Rkw1CfHUleao8dZ2rSkqT5m8+K4kFg64zaAeDqqvop4I+B2wCSXAlsB65qfe5Jsqr1uRfYBWxsj+lj7gTeqKorgLuAO9uxLgBuBz4CXAvcnmTN6U9RkjSOOYOiqr4GHJ9R+72qOtFePgWsb9vbgEeq6p2qehmYBK5NcilwXlU9WVUFPARcP9RnT9veB2xpq43rgANVdbyq3mAQTjMDS5K0wM7ENYq/DzzettcBrw7tO9Jq69r2zPpJfVr4vAlc2DnWD0iyK8lEkompqamxJiNJOtlYQZHknwMngC9Ml2ZpVp36qH1OLlbdV1WbqmrT2rWz/iU/SdKIRg6KdnH554FfbKeTYPBb/2VDzdYDr7X6+lnqJ/VJsho4n8GprlMdS5K0iEYKiiRbgc8Cv1BVfza0az+wvX2S6XIGF62frqqjwFtJNrfrDzcBjw31mf5E0w3AEy14vgJ8LMmadhH7Y60mSVpEc97CI8kXgY8CFyU5wuCTSLcB5wAH2qdcn6qqX6mqw0n2As8zOCV1S1W92w51M4NPUJ3L4JrG9HWN+4GHk0wyWElsB6iq40k+B3y9tfv1qjrporokaeHNGRRV9alZyvd32u8Gds9SnwCunqX+NnDjKY71APDAXGOUJC0cv5ktSeoyKCRJXQaFJKnLoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnqMigkSV0GhSSpy6CQJHUZFJKkLoNCktRlUEiSugwKSVKXQSFJ6jIoJEldBoUkqcugkCR1GRSSpC6DQpLUZVBIkroMCklS15xBkeSBJMeSPDdUuyDJgSQvtec1Q/tuSzKZ5MUk1w3Vr0nybNt3d5K0+jlJHm31Q0k2DPXZ0d7jpSQ7ztisJUnzNp8VxYPA1hm1W4GDVbURONhek+RKYDtwVetzT5JVrc+9wC5gY3tMH3Mn8EZVXQHcBdzZjnUBcDvwEeBa4PbhQJIkLY45g6KqvgYcn1HeBuxp23uA64fqj1TVO1X1MjAJXJvkUuC8qnqyqgp4aEaf6WPtA7a01cZ1wIGqOl5VbwAH+MHAkiQtsFGvUVxSVUcB2vPFrb4OeHWo3ZFWW9e2Z9ZP6lNVJ4A3gQs7x/oBSXYlmUgyMTU1NeKUJEmzOdMXszNLrTr1UfucXKy6r6o2VdWmtWvXzmugkqT5GTUoXm+nk2jPx1r9CHDZULv1wGutvn6W+kl9kqwGzmdwqutUx5IkLaJRg2I/MP0ppB3AY0P17e2TTJczuGj9dDs99VaSze36w00z+kwf6wbgiXYd4yvAx5KsaRexP9ZqkqRFtHquBkm+CHwUuCjJEQafRLoD2JtkJ/AKcCNAVR1O
shd4HjgB3FJV77ZD3czgE1TnAo+3B8D9wMNJJhmsJLa3Yx1P8jng663dr1fVzIvqkqQFNmdQVNWnTrFryyna7wZ2z1KfAK6epf42LWhm2fcA8MBcY5QkLRy/mS1J6jIoJEldBoUkqcugkCR1GRSSpC6DQpLUZVBIkroMCklSl0EhSeoyKCRJXQaFJKnLoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnqMigkSV1z/ilUSdLp2XDrl5fkfb9zxycX5LiuKCRJXQaFJKnLoJAkdRkUkqQug0KS1GVQSJK6xgqKJP8kyeEkzyX5YpIfTnJBkgNJXmrPa4ba35ZkMsmLSa4bql+T5Nm27+4kafVzkjza6oeSbBhnvJKk0zdyUCRZB/wjYFNVXQ2sArYDtwIHq2ojcLC9JsmVbf9VwFbgniSr2uHuBXYBG9tja6vvBN6oqiuAu4A7Rx2vJGk04556Wg2cm2Q18H7gNWAbsKft3wNc37a3AY9U1TtV9TIwCVyb5FLgvKp6sqoKeGhGn+lj7QO2TK82JEmLY+SgqKr/Bfwr4BXgKPBmVf0ecElVHW1tjgIXty7rgFeHDnGk1da17Zn1k/pU1QngTeDCmWNJsivJRJKJqampUackSZrFOKee1jD4jf9y4C8BP5Lkl3pdZqlVp97rc3Kh6r6q2lRVm9auXdsfuCTptIxz6unvAC9X1VRV/T/gS8DPAK+300m052Ot/RHgsqH+6xmcqjrStmfWT+rTTm+dDxwfY8ySpNM0TlC8AmxO8v523WAL8AKwH9jR2uwAHmvb+4Ht7ZNMlzO4aP10Oz31VpLN7Tg3zegzfawbgCfadQxJ0iIZ+e6xVXUoyT7gD4ATwDeA+4APAHuT7GQQJje29oeT7AWeb+1vqap32+FuBh4EzgUebw+A+4GHk0wyWElsH3W8kqTRjHWb8aq6Hbh9RvkdBquL2drvBnbPUp8Arp6l/jYtaCRJS8NvZkuSugwKSVKXQSFJ6jIoJEldBoUkqcugkCR1GRSSpC6DQpLUZVBIkroMCklSl0EhSeoyKCRJXQaFJKnLoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnqMigkSV0GhSSpy6CQJHUZFJKkLoNCktQ1VlAk+WCSfUn+KMkLSf56kguSHEjyUnteM9T+tiSTSV5Mct1Q/Zokz7Z9dydJq5+T5NFWP5RkwzjjlSSdvnFXFP8G+N2q+gngQ8ALwK3AwaraCBxsr0lyJbAduArYCtyTZFU7zr3ALmBje2xt9Z3AG1V1BXAXcOeY45UknaaRgyLJecDPAfcDVNWfV9X/AbYBe1qzPcD1bXsb8EhVvVNVLwOTwLVJLgXOq6onq6qAh2b0mT7WPmDL9GpDkrQ4xllR/BgwBfz7JN9I8vkkPwJcUlVHAdrzxa39OuDVof5HWm1d255ZP6lPVZ0A3gQunDmQJLuSTCSZmJqaGmNKkqSZxgmK1cCHgXur6qeB79FOM53CbCuB6tR7fU4uVN1XVZuqatPatWv7o5YknZZxguIIcKSqDrXX+xgEx+vtdBLt+dhQ+8uG+q8HXmv19bPUT+qTZDVwPnB8jDFLkk7TyEFRVf8beDXJj7fSFuB5YD+wo9V2AI+17f3A9vZJpssZXLR+up2eeivJ5nb94aYZfaaPdQPwRLuOIUlaJKvH7P+rwBeSvA/4NvBpBuGzN8lO4BXgRoCqOpxkL4MwOQHcUlXvtuPcDDwInAs83h4wuFD+cJJJBiuJ7WOOV5J0msYKiqr6JrBpll1bTtF+N7B7lvoEcPUs9bdpQSNJWhp+M1uS1GVQSJK6DApJUpdBIUnqMigkSV0GhSSpy6CQJHUZFJKkLoNCktRlUEiSugwKSVKXQSFJ6jIoJEldBoUkqcugkCR1GRSSpC6DQpLUZVBIkroMCklSl0EhSeoyKCRJXQaFJKnLoJAkdRkUkqSusYMiyaok30jyn9rrC5IcSPJSe14z1Pa2JJNJXkxy3VD9
miTPtn13J0mrn5Pk0VY/lGTDuOOVJJ2eM7Gi+AzwwtDrW4GDVbURONhek+RKYDtwFbAVuCfJqtbnXmAXsLE9trb6TuCNqroCuAu48wyMV5J0GsYKiiTrgU8Cnx8qbwP2tO09wPVD9Ueq6p2qehmYBK5NcilwXlU9WVUFPDSjz/Sx9gFbplcbkqTFMe6K4jeAfwZ8f6h2SVUdBWjPF7f6OuDVoXZHWm1d255ZP6lPVZ0A3gQunDmIJLuSTCSZmJqaGnNKkqRhIwdFkp8HjlXVM/PtMkutOvVen5MLVfdV1aaq2rR27dp5DkeSNB+rx+j7s8AvJPkE8MPAeUn+A/B6kkur6mg7rXSstT8CXDbUfz3wWquvn6U+3OdIktXA+cDxMcYsSTpNI68oquq2qlpfVRsYXKR+oqp+CdgP7GjNdgCPte39wPb2SabLGVy0frqdnnoryeZ2/eGmGX2mj3VDe48fWFFIkhbOOCuKU7kD2JtkJ/AKcCNAVR1Oshd4HjgB3FJV77Y+NwMPAucCj7cHwP3Aw0kmGawkti/AeCVJHWckKKrqq8BX2/afAFtO0W43sHuW+gRw9Sz1t2lBI0laGn4zW5LUZVBIkroMCklSl0EhSeoyKCRJXQaFJKnLoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnqMigkSV0GhSSpy6CQJHUZFJKkLoNCktS1EH8KVSPYcOuXl+y9v3PHJ5fsvSW997mikCR1GRSSpC6DQpLUZVBIkroMCklSl0EhSeoaOSiSXJbkvyR5IcnhJJ9p9QuSHEjyUnteM9TntiSTSV5Mct1Q/Zokz7Z9dydJq5+T5NFWP5RkwxhzlSSNYJwVxQngn1bVTwKbgVuSXAncChysqo3Awfaatm87cBWwFbgnyap2rHuBXcDG9tja6juBN6rqCuAu4M4xxitJGsHIQVFVR6vqD9r2W8ALwDpgG7CnNdsDXN+2twGPVNU7VfUyMAlcm+RS4LyqerKqCnhoRp/pY+0DtkyvNiRJi+OMXKNop4R+GjgEXFJVR2EQJsDFrdk64NWhbkdabV3bnlk/qU9VnQDeBC48E2OWJM3P2EGR5APAbwH/uKr+tNd0llp16r0+M8ewK8lEkompqam5hixJOg1jBUWSH2IQEl+oqi+18uvtdBLt+VirHwEuG+q+Hnit1dfPUj+pT5LVwPnA8ZnjqKr7qmpTVW1au3btOFOSJM0wzqeeAtwPvFBV/3po135gR9veATw2VN/ePsl0OYOL1k+301NvJdncjnnTjD7Tx7oBeKJdx5AkLZJx7h77s8AvA88m+War/RpwB7A3yU7gFeBGgKo6nGQv8DyDT0zdUlXvtn43Aw8C5wKPtwcMgujhJJMMVhLbxxivJGkEIwdFVf13Zr+GALDlFH12A7tnqU8AV89Sf5sWNJJ0Opby1v1nG7+ZLUnqMigkSV0GhSSpy6CQJHX5N7O1ZBf9/Fvd0vLgikKS1GVQSJK6PPUkaUH5fYblzxWFJKnLoJAkdRkUkqQur1FIK4DXCTQOVxSSpC5XFFoyftFPWh5cUUiSugwKSVKXp5604nhhVzo9rigkSV0GhSSpy6CQJHUZFJKkLoNCktRlUEiSugwKSVKXQSFJ6loWQZFka5IXk0wmuXWpxyNJK8l7PiiSrAL+HfBx4ErgU0muXNpRSdLK8Z4PCuBaYLKqvl1Vfw48Amxb4jFJ0oqxHO71tA54dej1EeAjww2S7AJ2tZf/N8mLp/keFwHfHXmEy9dKnTc495U497N+3rnzlLvmM/e/fKodyyEoMkutTnpRdR9w38hvkExU1aZR+y9XK3Xe4NxX4txX6rxh/Lkvh1NPR4DLhl6vB15borFI0oqzHILi68DGJJcneR+wHdi/xGOSpBXjPX/qqapOJPmHwFeAVcADVXX4DL/NyKetlrmVOm9w7ivRSp03jDn3VNXcrSRJK9ZyOPUkSVpCBoUkqWvFBMVctwHJwN1t/7eSfHgpxrkQ
5jH3X2xz/laS30/yoaUY50KY7+1fkvy1JO8muWExx7dQ5jPvJB9N8s0kh5P818Ue40KZx//385P8xyR/2Ob+6aUY55mW5IEkx5I8d4r9o/+Mq6qz/sHgIvj/AH4MeB/wh8CVM9p8Anicwfc2NgOHlnrcizj3nwHWtO2Pr6S5D7V7AvjPwA1LPe5F+jf/IPA88KPt9cVLPe5FnPuvAXe27bXAceB9Sz32MzD3nwM+DDx3iv0j/4xbKSuK+dwGZBvwUA08BXwwyaWLPdAFMOfcq+r3q+qN9vIpBt9VORvM9/Yvvwr8FnBsMQe3gOYz778HfKmqXgGoqpU09wL+YpIAH2AQFCcWd5hnXlV9jcFcTmXkn3ErJShmuw3IuhHaLEenO6+dDH7rOBvMOfck64C/C/zmIo5roc3n3/yvAmuSfDXJM0luWrTRLaz5zP3fAj/J4Iu7zwKfqarvL87wltTIP+Pe89+jOEPmvA3IPNssR/OeV5K/xSAo/saCjmjxzGfuvwF8tqreHfyCeVaYz7xXA9cAW4BzgSeTPFVVf7zQg1tg85n7dcA3gb8N/BXgQJL/VlV/usBjW2oj/4xbKUExn9uAnK23CpnXvJL8FPB54ONV9SeLNLaFNp+5bwIeaSFxEfCJJCeq6ncWZYQLY77/379bVd8Dvpfka8CHgOUeFPOZ+6eBO2pw4n4yycvATwBPL84Ql8zIP+NWyqmn+dwGZD9wU/tkwGbgzao6utgDXQBzzj3JjwJfAn75LPiNcticc6+qy6tqQ1VtAPYB/2CZhwTM7//7Y8DfTLI6yfsZ3JH5hUUe50KYz9xfYbCSIsklwI8D317UUS6NkX/GrYgVRZ3iNiBJfqXt/00Gn3j5BDAJ/BmD3zqWvXnO/V8AFwL3tN+sT9RZcJfNec79rDOfeVfVC0l+F/gW8H3g81U168cql5N5/pt/DngwybMMTsd8tqqW/e3Hk3wR+ChwUZIjwO3AD8H4P+O8hYckqWulnHqSJI3IoJAkdRkUkqQug0KS1GVQSJK6DApJUpdBIUnq+v9zADgYPXkCJAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.7.10",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.10 64-bit ('torch': conda)"
  },
  "interpreter": {
   "hash": "55559f2be42a5b6cb796781e868e42274b92524fa645acedffc2257e2d2820dc"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}