{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import argparse\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda\n",
    "from tensorflow.python.keras.layers.normalization import BatchNormalization\n",
    "from tensorflow.python.keras import regularizers\n",
    "from tensorflow.python.keras.models import Model, Sequential\n",
    "from tensorflow.python.keras import optimizers\n",
    "from tensorflow.python.keras.callbacks import ModelCheckpoint\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from utils.explanations import calculate_stability, calculate_robust_astute_sampled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "datatype = 'XOR'\n",
    "run_times = 1\n",
    "prop_points = 0.05\n",
    "calculate = True\n",
    "epsilon_range = np.arange(0.01, 1.1, 0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dict = pickle.load(open('data/' + datatype + '.pk', 'rb'))\n",
    "\n",
    "x_train, _, x_val, _, _, input_shape = data_dict['x_train'], data_dict['y_train'], \\\n",
    "                                       data_dict['x_val'], data_dict['y_val'], \\\n",
    "                                       data_dict['datatype_val'], data_dict['input_shape']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "median_rad = 0.5 * np.median(pdist(x_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_astuteness_file = 'plots/rise_' + datatype + '_astuteness_classifiers.pk'\n",
    "classifiers = ['2layer', '4layer', 'linear', 'svm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completing Run 1 of 1\n",
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-09-28 22:23:48.878739: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
      "2021-09-28 22:23:48.901938: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3600000000 Hz\n",
      "2021-09-28 22:23:48.902830: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x55a896724130 executing computations on platform Host. Devices:\n",
      "2021-09-28 22:23:48.902873: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>\n",
      "100%|████████████████████████████████████████| 22/22 [34:32<00:00, 94.19s/it]\n",
      "100%|███████████████████████████████████████| 22/22 [42:11<00:00, 115.07s/it]\n",
      "100%|████████████████████████████████████████| 22/22 [22:57<00:00, 62.59s/it]\n",
      "100%|████████████████████████████████████████| 22/22 [31:24<00:00, 85.66s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAg1ElEQVR4nO3deXhV9b3v8fc3E0kgMiUqo6AHFCgyhUG89lD7qNiDWsVWPWofrcNRi/f0WC1t7XGuVaoePK1V0StYe2+h1+opCNXWerQtwzUhCbMDKEqMSJjEZO/M3/tHQhoxIZuws9cePq/nyQN7r5W9P4vhw+K31/r9zN0REZHElxZ0ABERiQ4VuohIklChi4gkCRW6iEiSUKGLiCSJjKDeOD8/34cNGxbU24uIJKS1a9fudveC9rYFVujDhg2juLg4qLcXEUlIZvZBR9s05CIikiRU6CIiSUKFLiKSJFToIiJJQoUuIpIkOi10M3vGzHaZ2cYOtpuZ/aeZbTWz9WY2MfoxRUSkM5GcoS8CZh5m+7nAiJav64HHjz6WiIgcqU6vQ3f3v5jZsMPscgHwK2+eh3eNmfUxswHu/nG0Qra15pfXkbd/yxF9j+N8nNHEhh4NfJqm6YLjSVaNc9KmetIbgk4iEjs2dCAX3/9y1F83GjcWDQJ2tHlc3vLcFwrdzK6n+SyeoUOHRuGtO1ZlTWzs0cj67HrW92hgd0ZzkZv6PK7826uNjHvbaQo6iEgMbUzb1y2vG41Ct3aea7c23X0BsACgsLCwS9U67aan2n2+vrGedZXrWFWxitUVq9m0ZxOOk5eZx9QBUzlt4GmcNvA0huQN6crbSjeo+stf2PH2v1Dwb/9G/r9cH3QckZgZ002vG41CLwfatuRgoCIKr3tY7s72A9tZXbGa1RWreXPnm4QaQqRbOqcWnMqN42/ktAGn8aX8L5GRFtgMB9KBppoadt57H1knnkj/q68KOo5IUohG0y0F5pjZYmAq8Gl3jZ8DvLX3LRa/tZhVFav4uLr5bYbkDeG8k87jtIGnMeX4KeRl5XXX20uU7FnwFPU7djB00SIsKyvoOCJJodNCN7PfADOAfDMrB+4EMgHc/QlgBfA1YCsQAq7urrAAu8O7+eP2PzJ1wFSuHXuthlESUO3777Pnqac45rzz6DltatBxRJKGBbVIdGFhoXdltsWGpubLITSMkpjcnR3XXEN4w0ZOWrGcjIJ2ZwEVkQ6Y2Vp3L2xvW8LdKZqRlqEyT2Cfvfwy1atWU/Ddf1WZi0RZwhW6JK7Gqio+uf+nZI8eTd9LLw06jkjS0amuxMzun/+cht27GfzLx7D09KDjiCQdnaFLTNRs2cLe535Nn0svIWfs2KDjiCQlFbp0O29qYuddd5Pety/Hfve7QccRSVoqdOl2+3/3O8Lr1nHc928jvXfvoOOIJC0VunSrhn37qHzoYXILCznm/PODjiOS1FTo0q12PfQQjdXVHH/nHZi1N+2PiESLCl26TaikhE9/9wL9r76KHiNGBB1HJOmp0KVbeH09O++8i4yBA8i/8cag44ikBF2HLt1i76//N7Xvvsvgx35BWm5u0HFEUoLO0CXq6nfuZPfPf06vGTPodeaZQccRSRkqdIm6T376AN7UxHE/vl0fhIrEkApdoqrqr3/ls1deIf+GG8gaPDjoOCIpRYUuUdNUU8POe+4la/hw+n27W6fFF5F26ENRiZo9Tz3dsgrRQtK0CpFIzOkMXaKibvt29ixYwDGzZtFz2rSg44ikJBW6RMUnP30A69GD4+Z+P+goIilLhS5HrbGqmqq//Y2+l12mVYhEAqRCl6MWLi2FxkZyteCzSKBU6HLUQkVFkJFB7oQJQUcRSWkqdDlqoeJisseM1i3+IgFToctRaQqHCW/YQM/Jk4OOIpLyVOhyVMLr1kN9PTmFhUFHEUl5KnQ5KqGiIkhLI3fSpKCjiKQ8FboclVBxMdmnnEJ6Xl7QUURSngpduqypro5wWRm5kzXcIhIPVOjSZTUbNuC1teTqA1GRuKBCly4LFRUDkKPxc5G4oEKXLgsVFdFjxAgy+vYNOoqIoEKXLvKGBkKlpRo/F4kjERW6mc00s7fNbKuZ/aCd7b3NbJmZrTOzTWam1Q2SXM3mzXgopPFzkTjSaaGbWTrwGHAuMBq4zMxGH7Lbd4DN7j4OmAE8bGZa4SCJHRw/z9UNRSJxI5Iz9CnAVnd/z93rgMXABYfs40CeNa8I3AvYCzRENanElVBREVnDhmm6XJE4EkmhDwJ2tHlc3vJcW78ARgEVwAbgX9296dAXMrPrzazYzIorKyu7GFmC5o2NhNau1XCLSJyJpNCtnef8kMfnAGXAQGA88AszO+YL3+S+wN0L3b2wQGd2Cav2nXdo+uwzfSAqEmciKfRyYEibx4NpPhNv62rgBW+2FXgfOCU6ESXehIqKAHSGLhJnIin0ImCEmQ1v+aDzUmDpIft8CHwVwMyOA04G3otmUIkfoaJiMgcNInPAgKCjiEgbGZ3t4O4NZjYHeAVIB55x901mdkPL9ieAe4FFZraB5iGaue6+uxtzS0DcnVBxMb3+8R+DjiIih+i00AHcfQWw4pDnnmjz8wrg7OhGk3hUt20bjfv2afxcJA7pTlE5Iho/F4lfKnQ5IqGiYjKOPZbMIUM631lEYkqFLhFzd0JFReROnkzzPWQiEk9U6BKx+g8/pKGyUsMtInFKhS4R+/v4uT4QFYlHKnSJWKioiPR+/cg68cSgo4hIO1ToErFQUTG5hYUaPxeJUyp0iUj9Rx9RX1Gh8XOROKZCl4iEilvmP9f4uUjcUqFLRKqLikjr3ZseI0cGHUVEOqBCl4iEi4rJnTQJS9MfGZF4pb+d0qn6Xbuo++ADLTcnEudU6NIpzd8ikhhU6NKpUHExaT17kj1Ka5aIxDMVunQqVFREzsSJWEZEsy2LSEBU6HJYDXv3Urd1m8bPRRKACl0O6+/Xn2v8XCTeqdDlsELFxVh2NjlfGhN0FBHphApdDitUVEzO+PFYVlbQUUSkEyp06VDjgQPUvvWWbvcXSRAqdOlQaO1acCe3UOPnIolAhS4dChUVY5mZ5Iw7NegoIhIBFbp0KFRcTPapp5KWnR10FBGJgApd2tVYVU3Npk0aPxdJICp0aVe4rAwaGzV+LpJAVOjSrlBREaSnkzthfNBRRCRCKnRpV6i4mOwxY0jr2TPoKCISIRW6fEFTTQ0169dr/FwkwajQ5QvC69bj9fWav0UkwajQ5QtCRUVgRu7EiUFHEZEjEFGhm9lMM3vbzLaa2Q862GeGmZWZ2SYzeyO6MSWWQkVF9Bh1CunHHBN0FBE5Ap0WupmlA48B5wKjgcvMbPQh+/QBfgmc7+5jgG9EP6rEgtfVES4r0/znIgkokjP0KcBWd3/P3euAxcAFh+zzz8AL7v4hgLvvim5MiZXwxo14ba3Gz0USUCSFPgjY0eZxectzbY0E+prZ62a21sy+1d4Lmdn1ZlZsZsWVlZVdSyzdKlTUsqCFztBFEk4khW7tPOeHPM4AJgH/BJwD/LuZjfzCN7kvcPdCdy8sKCg44rDS/UJFRfQY8Q9k9O0bdBQROUKRFHo5MKTN48FARTv7vOzu1e6+G/gLMC46ESVWvKGBcEmJhltEElQkhV4EjDCz4WaWBVwKLD1kn98DZ5hZhpnlAlOBLdGNKt2tZssWmkIhDbeIJKiMznZw9wYzmwO8AqQDz7j7JjO7oWX7E+6+xcxeBtYDTcDT7r6xO4NL9O199ldYZia5U6YEHUVEuqDTQgdw9xXAikOee+KQxz8Dfha9aBJL1WvWcOCll8i/6SYy8vODjiMiXaA7RQWvq2Pn3feQOXQo/a+/Lug4ItJFEZ2hS3Lb88xC6t5/nyFPLdDqRCIJTGfoKa6uvJzdjz9O3jnn0OuMM4KOIyJHQYWewtydT+77CaSnc9wP252iR0QSiAo9hVW99hpVr79OwZw5ZB5/fNBxROQoqdBTVFMoxM6f/IQeI0fS78orgo4jIlGgD0VT1O7HH6eh4mMG/Z+HsMzMoOOISBToDD0F1W7dyp6Fi+g9+yItYiGSRFToKcbd2Xn3PaT37Mmxt94adBwRiSIVeoo5sHQpoaIiCr53i2ZUFEkyKvQU0vjpp3zy4Dxyxo2jz8UXBx1HRKJMH4qmkF3z59O4fz/H/6+nsTT9Wy6SbPS3OkWE169n/+Il9LvyCrJHjQo6joh0AxV6CvDGRnbedTcZ+fnk33xz0HFEpJtoyCUF7Fu8mJrNmxn0yMOk9+oVdBwR6SY6Q09yDZWVVP7HfHpOn07euecGHUdEupEKPcl9Mu9neG0tx/37jzFrb71vEUkWKvQkVr1mDQeWLaP/ddfRY/jwoOOISDdToSep1lWIhgzRKkQiKUIfiiap1lWIFjypVYhEUoTO0JNQ6ypEZ59Nry9/Oeg4IhIjKvQktOvBeVqFSCQFqdCTTFNtLVWvv07fb3yDzAEDgo4jIjGkQk8yNZs24fX15E6ZHHQUEYkxFXqSCZeUAJAzYULASUQk1lToSSZUUkrWCSeQ0a9f0FFEJMZU6EnE3QmXlpKjZeVEUpIKPYnUbd9O47595EzUcItIKlKhJ5FwSSkAuRo/F0lJKvQkEi4rJa13b7JOPDHoKCISABV6EgmVlJI7fryWlxNJURH9zTezmWb2tpltNbMObz80s8lm1mhmWoE4xhr376du2zZdriiSwjotdDNLBx4DzgVGA5eZ2egO9nsQeCXaIaVzobIyAH0gKpLCIjlDnwJsdff33L0OWAxc0M5+NwO/A3ZFMZ9EKFxSChkZ5IwdG3QUEQlIJIU+CNjR5nF5y3OtzGwQcCHwxOFeyMyuN7NiMyuurKw80qxyGOGSErJHjSItJyfoKCISkEgKvb11y/yQx/OBue7eeLgXcvcF7l7o7oUFBQURRpTOeH094Q0byNVwi0hKi2SBi3JgSJvHg4GKQ/YpBBa3rFmZD3zNzBrc/b+iEVIOr2bLFry2lpwJukNUJJVFUuhFwAgzGw58BFwK/HPbHdy9dcFKM1sEvKQyj52QJuQSESIodHdvMLM5NF+9kg484+6bzOyGlu2HHTeX7hcuLSNz0CAyjzs26CgiEqCI1hR19xXAikOea7fI3f2qo48lkXJ3wiUl5E6bFnQUEQmYbilMcPUffURDZSU5E8YHHUVEAqZCT3Dh0pYJuTRlrkjKU6EnuFBJCWk9e9JjxIigo4hIwFToCS5cUkrOuHFYenrQUUQkYCr0BNZYVUXtO+9ohSIRAVToCS1ctg7c9YGoiAAq9IQWLimBtDRyxo0POoqIxAEVegILl5XS4+STSe/VM+goIhIHVOgJyhsaCJetI1fDLSLSQoWeoGrfeYemUEgTcolIKxV6ggq13lCkCblEpJkKPUGFS0rJOPZYMgYODDqKiMQJFXqCCpWWkDNxIi1z0IuIqNATUf3OnTRUfKzhFhH5HBV6Ajo4IZcWtBCRtiKaD13iS6ikFMvJIfuUU4KOIhJ19fX1lJeXU1NTE3SUQGVnZzN48GAyMzMj/h4VegIKl5aSM3YsdgS/0SKJory8nLy8PIYNG5aynxG5O3v27KG8vJzhw4d3/g0tNOSSYJpCIWq2bNFwiyStmpoa+vfvn7JlDmBm9O/f/4j/l6JCTzDh9RugsVEfiEpSO9Iyv+TJ1Vzy5OpuShOMrvyDpkJPMOGylg9Ex48PNoiIxB0VeoIJlZSQ9Q8nkd67d9BRRJLSjh07+MpXvsKoUaMYM2YMjz76KABXXXUVzz//fMDpDk8fiiYQb2oiXLaOY845J+goIkkrIyODhx9+mIkTJ/LZZ58xadIkzjrrrG5/34aGBjIyjq6SVegJpG7bNpoOHNAHopIy7l62ic0VBzrdb/PHzftEMo4+euAx3HnemA63DxgwgAEDBgCQl5fHqFGj+Oijjz63zz333MOyZcsIh8NMnz6dJ598kvfee49vfOMblJSUAPDuu+9y6aWXsnbtWtauXcstt9xCVVUV+fn5LFq0iAEDBjBjxgymT5/OypUrOf/88/ne977Xaf7D0ZBLAgmVaEIukVjavn07paWlTJ069XPPz5kzh6KiIjZu3Eg4HOall17ipJNOonfv3pSVlQGwcOFCrrrqKurr67n55pt5/vnnWbt2Ld/+9re5/fbbW19r//79vPHGG0dd5qAz9IQSLikhvV8/Mk84IegoIjFxuDPptg6emS/5l9Oi9t5VVVXMnj2b+fPnc8wxx3xu23//938zb948QqEQe/fuZcyYMZx33nlce+21LFy4kEceeYQlS5bw5ptv8vbbb7Nx48bWYZvGxsbW/wEAXHLJJVHLrEJPIKGyUnImTEjp63NFYqG+vp7Zs2dz+eWXc9FFF31uW01NDTfddBPFxcUMGTKEu+66q/V68dmzZ3P33Xdz5plnMmnSJPr3709FRQVjxoxh9er2h4N69ozeimMackkQDbt3U//BhxpuEelm7s4111zDqFGjuOWWW76w/WB55+fnU1VV9bkrX7KzsznnnHO48cYbufrqqwE4+eSTqaysbC30+vp6Nm3a1C3ZVegJItQ6IZdWKBLpTitXruS5557jtddeY/z48YwfP54VK1a0bu/Tpw/XXXcdY8eO5etf/zqTJ0/+3PdffvnlmBlnn302AFlZWTz//PPMnTuXcePGMX78eFatWtUt2c3du+WFO1NYWOjFxcWBvHci+mTez9j33HOMLC4irUePoOOIdJstW7YwatSooGN02UMPPcSnn37Kvffee9Sv1d6vhZmtdffC9vbXGHqCCJeUkP2lL6nMReLYhRdeyLZt23jttdcCef+IhlzMbKaZvW1mW83sB+1sv9zM1rd8rTKzcdGPmrqaamup2bSJHI2fi8S1F198kfXr15Ofnx/I+3da6GaWDjwGnAuMBi4zs9GH7PY+8I/ufipwL7Ag2kFTWc2mTXh9Pbm6oUhEDiOSM/QpwFZ3f8/d64DFwAVtd3D3Ve6+r+XhGmBwdGOmtnDLnWe6Q1REDieSQh8E7GjzuLzluY5cA/yhvQ1mdr2ZFZtZcWVlZeQpU1yopJSsE04go3//oKOIxKeF/9T8leIiKfT27mJp99IYM/sKzYU+t73t7r7A3QvdvbCgoCDylCnM3ZtXKNLZuYh0IpJCLweGtHk8GKg4dCczOxV4GrjA3fdEJ57Ubd9O4759+kBUJMYaGxuZMGECs2bNAhJj+txICr0IGGFmw80sC7gUWNp2BzMbCrwAXOnu70Q/ZuoKt07IpRuKRGLp0Ucfjen18A0NDUf9Gp1eh+7uDWY2B3gFSAeecfdNZnZDy/YngDuA/sAvW+YZaejownc5MuGyUtKOOYasE08MOopI7P3hB7BzQ+f77Vzf/GMk4+jHj4VzHzjsLuXl5Sxfvpzbb7+dRx555AvbE3r6XHdf4e4j3f0kd/9Jy3NPtJQ57n6tu/d19/EtXyrzKAmVlJIzYTyWplkaRGLlu9/9LvPmzSOtg793mj5Xjljj/v3UbdtG7/POCzqKSDA6OZNudfDM/OrlR/2WL730EsceeyyTJk3i9ddfb3cfTZ8rRyzU8i+9rnARiZ2VK1eydOlSVqxYQU1NDQcOHOCKK65oXR5O0+dKl4RLSiE9nZxTxwYdRSRl/PSnP6W8vJzt27ezePFizjzzTH7961+3btf0udIl4ZISskeNIi0nJ+goItJC0+e2Q9PnHp7X1/N24WT6XPJNjv/Rj4KOIxIzmj737zR9bpKo2bIFr63V9eciCSTo6XNV6HFq77O/wrKyyC3UFaAiieLFF18M9P01hh6Hqlev5sDy5fS/7joyAppXWUQSjwo9zjTV1bHznnvJHDqU/tdfF3QcEUkgGnKJM3ufeYa6999nyFMLtNycSISufrn5EsGFMxcGnCRYOkOPI3Xl5ex+/AnyzjmHXmecEXQcEUkwKvQ44e58cu99WHo6x/3wC8u2ikgM9erVC4CKigouvvjigNNEToUeJ6r+/Geq3niD/JtvJvP444OOIyLAwIEDu30O9GhMm3uQxtDjQFN1NTt/cj89Ro6k3xWXBx1HJG48+OaDvLX3rU73O7jPwbH0wzml3ynMndLuompfsH37dmbNmsXGjRtZtGgRS5cuJRQKsW3bNi688ELmzZsHwB//+EfuvPNOamtrOemkk1i4cCG9evVqd5pdM4v6tLkH6Qw9Dux+/HEaPv6Y4++6E8vMDDqOiHSgrKyMJUuWsGHDBpYsWcKOHTvYvXs39913H6+++iolJSUUFha2zqHe3jS7B0Vz2tyDdIYesNp332XPomfpPfsi3RUqcohIz6RjdZXLV7/6VXr37g3A6NGj+eCDD9i/fz+bN2/m9NNPB6Curo7TTjsN6HiaXYjutLkHqdAD5O7svPse0nv25Nhbbw06joh0okebS4nT09NpaGjA3TnrrLP4zW9+87l9DzfNLkR32tyDNOQSoE9//3tCxcUU3Po9Mvr2DTqOiHTBtGnTWLlyJVu3bgUgFArxzjvvHHaa3e6iM/SANO7fz655PyNn/Hj6zJ4ddBwR6aKCggIWLVrEZZddRm1tLQD33XcfI0eObJ1md9iwYV+YZrc7aPrcgHx8113s/+3/ZfgLvyP7lFOCjiMSNxJ9+txoOtLpczXkEoDw+vXsX/Jb+l15hcpcRKJGhR5j3tjIx3fdRUZBAfk33xx0HBFJIhpDj7F9v1lM7eYtDPqPR0hvub1YRCQadIYeQ/W7dlE5fz49p08nb+bMoOOISJJRocfQrnk/w2trOf6Of8fMgo4jkjQ+uPJbfHDlt4KOETgVeoxUr17NgZdeov9115E1bFjQcUQkCanQY6B1FaIhQ7QKkYh0G30oGgOtqxAteJK07Oyg44hIklKhd7PWVYjOPpteX/5y0HFEEsrO+++ndkvn0+fWvNW8TyTj6D1GncLxP/pRh9urq6v55je/SXl5OY2Njdx2220sX76c3/72twC8/vrrPPzwwyxbtoxevXrxne98h1dffZW+ffty//338/3vf58PP/yQ+fPnc/7550d4pNGhIZdu4O7Ubt3K3mefpfw7cyA9neN+9MOgY4lIBF5++WUGDhzIunXr2LhxI1//+tdZs2YN1dXVACxZsqR1psTq6mpmzJjB2rVrycvL48c//jF/+tOfePHFF7njjjtinl1n6FHSsG8f1atWUb1yFdUrV9LwyScAZA0bxsD77tUqRCJdcLgz6bYOnpmf8Nyvjvo9x44dy6233srcuXOZNWsWZ5xxBjNnzmTZsmVcfPHFLF++vHVhi6ysLGa2XII8duxYevToQWZmJmPHjmX79u1HneVIRVToZjYTeBRIB5529wcO2W4t278GhICr3L0kylnjSlNdHeGSUqpXrqR61SpqNm8Gd9J696bntGn0PH06PaefTtbgQUFHFZEjMHLkSNauXcuKFSv44Q9/yNlnn80ll1zCY489Rr9+/Zg8eTJ5eXkAZGZmtl6CnJaW1jq9blpaWlSXlotUp4VuZunAY8BZQDlQZGZL3X1zm93OBUa0fE0FHm/5MWm4O3XvvUf1ypVUrVxJ6M0iPByGjAxyxo+j4H/eTM/TTyd7zBgsPT3ouCLSRRUVFfTr148rrriCXr16sWjRIm6//XauueYannrqqW5ZmCJaIjlDnwJsdff3AMxsMXAB0LbQLwB+5c1TN64xsz5mNsDdP4524Kq//o1PHnyg8x2jrOnTAzRUVgLNwyh9LrqInqdPJ3fKFN3CL5JENmzYwG233UZaWhqZmZk8/vjjpKenM2vWLBYtWsSzzz4bdMQOdTp9rpldDMx092tbHl8JTHX3OW32eQl4wN3/1vL4z8Bcdy8+5LWuB64HGDp06KQPPvjgiAOHSkvZuyj2v6Bp2T3ImTRJwygi3UzT5/7dkU6fG8kZenv3qB/6r0Ak++DuC4AF0DwfegTv/QW5EyaQO2FCV75VRCSpRXLZYjkwpM3jwUBFF/YREZFuFEmhFwEjzGy4mWUBlwJLD9lnKfAtazYN+LQ7xs9FJDUEtZJaPOnKr0GnQy7u3mBmc4BXaL5s8Rl332RmN7RsfwJYQfMli1tpvmzx6iNOIiICZGdns2fPHvr375+ys5K6O3v27CH7CKcK0ZqiIhJX6uvrKS8vp6amJugogcrOzmbw4MFkZmZ+7vmj/VBURCRmMjMzGT58eNAxEpLmchERSRIqdBGRJKFCFxFJEoF9KGpmlcCR3iqaD+zuhjjxJBWOEXScyUbHGTsnuHtBexsCK/SuMLPijj7dTRapcIyg40w2Os74oCEXEZEkoUIXEUkSiVboC4IOEAOpcIyg40w2Os44kFBj6CIi0rFEO0MXEZEOqNBFRJJE3BW6mc00s7fNbKuZ/aCd7WZm/9myfb2ZTQwi59GK4Dgvbzm+9Wa2yszGBZHzaHV2nG32m2xmjS0rZCWcSI7TzGaYWZmZbTKzN2KdMRoi+HPb28yWmdm6luNMyJlXzewZM9tlZhs72B6fPeTucfNF8/S824ATgSxgHTD6kH2+BvyB5lWSpgH/L+jc3XSc04G+LT8/N1mPs81+r9E8DfPFQefupt/PPjSvwzu05fGxQefupuP8EfBgy88LgL1AVtDZu3CsXwYmAhs72B6XPRRvZ+itC1K7ex1wcEHqtloXpHb3NUAfMxsQ66BHqdPjdPdV7r6v5eEamleBSjSR/H4C3Az8DtgVy3BRFMlx/jPwgrt/CODuiXiskRynA3nWPJF5L5oLvSG2MY+eu/+F5uwdicseirdCHwTsaPO4vOW5I90n3h3pMVxD89lAoun0OM1sEHAh8EQMc0VbJL+fI4G+Zva6ma01s2/FLF30RHKcvwBG0bwE5QbgX929KTbxYioueyje5kOP2oLUcS7iYzCzr9Bc6P+jWxN1j0iOcz4w190bE3h1mkiOMwOYBHwVyAFWm9kad3+nu8NFUSTHeQ5QBpwJnAT8ycz+6u4HujlbrMVlD8VboafKgtQRHYOZnQo8DZzr7ntilC2aIjnOQmBxS5nnA18zswZ3/6+YJIyOSP/c7nb3aqDazP4CjAMSqdAjOc6rgQe8eaB5q5m9D5wCvBmbiDETlz0Ub0MuqbIgdafHaWZDgReAKxPsLK6tTo/T3Ye7+zB3HwY8D9yUYGUOkf25/T1whpllmFkuMBXYEuOcRyuS4/yQ5v+FYGbHAScD78U0ZWzEZQ/F1Rm6p8iC1BEe5x1Af+CXLWevDR7Hs7y1J8LjTHiRHKe7bzGzl4H1QBPwtLu3e0lcvIrw9/NeYJGZbaB5WGKuuwc93ewRM7PfADOAfDMrB+4EMiG+e0i3/ouIJIl4G3IREZEuUqGLiCQJFbqISJJQoYuIJAkVuohIklChi4gkCRW6iEiS+P9c+JnxXieWPgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "if calculate:\n",
    "    total_astuteness = np.zeros(shape=(run_times, len(classifiers), len(epsilon_range)))\n",
    "    for i in range(run_times):\n",
    "        print('Completing Run ' + str(i + 1) + ' of ' + str(run_times))\n",
    "        for j in range(len(classifiers)):\n",
    "            if classifiers[j] == '2layer':\n",
    "                activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'\n",
    "                model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "                net = Dense(200, activation=activation, name='dense1',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(model_input)\n",
    "                net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "                net = Dense(200, activation=activation, name='dense2',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                net = BatchNormalization()(net)\n",
    "                preds = Dense(2, activation='softmax', name='dense4',\n",
    "                              kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                bbox_model = Model(model_input, preds)\n",
    "                bbox_model.load_weights('models/' + datatype + '_blackbox.hdf5',\n",
    "                                        by_name=True)\n",
    "                pred_model = Model(model_input, preds)\n",
    "\n",
    "            elif classifiers[j] == '4layer':\n",
    "                activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'\n",
    "\n",
    "                model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "                net = Dense(50, activation=activation, name='dense1',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(model_input)\n",
    "                net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "                net = Dense(50, activation=activation, name='dense2',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                net = BatchNormalization()(net)\n",
    "                net = Dense(50, activation=activation, name='dense3',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                net = BatchNormalization()(net)\n",
    "                net = Dense(50, activation=activation, name='dense4',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                net = BatchNormalization()(net)\n",
    "                preds = Dense(2, activation='softmax', name='dense5',\n",
    "                              kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                bbox_model = Model(model_input, preds)\n",
    "                bbox_model.load_weights('models/' + datatype + '_blackbox_extra.hdf5',\n",
    "                                        by_name=True)\n",
    "                pred_model = Model(model_input, preds)\n",
    "\n",
    "\n",
    "            elif classifiers[j] == 'linear':\n",
    "                activation = None\n",
    "\n",
    "                model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "\n",
    "                net = Dense(200, activation=activation, name='dense1',\n",
    "                            kernel_regularizer=regularizers.l2(1e-3))(model_input)\n",
    "                net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "\n",
    "                preds = Dense(2, activation='softmax', name='dense4',\n",
    "                              kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "                bbox_model = Model(model_input, preds)\n",
    "                bbox_model.load_weights('models/' + datatype + '_blackbox_linear.hdf5',\n",
    "                                        by_name=True)\n",
    "                pred_model = Model(model_input, preds)\n",
    "            elif classifiers[j] == 'svm':\n",
    "                pred_model = pickle.load(open('models/' + datatype + '_svm.pk', 'rb'))\n",
    "            fname = 'explained_weights/rise/' + 'rise_' + datatype + '_' + classifiers[j] + '_' + str(\n",
    "                i) + '.gz'\n",
    "            explanations = np.loadtxt(fname, delimiter=',')\n",
    "            if classifiers[j] == 'svm':\n",
    "                for k in tqdm(range(len(epsilon_range))):\n",
    "                    _, total_astuteness[i, j, k], _ = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                                                      explainer=pred_model,\n",
    "                                                                                      explainer_type='rise',\n",
    "                                                                                      explanation_type='attribution',\n",
    "                                                                                      ball_r=median_rad,\n",
    "                                                                                      epsilon=epsilon_range[k],\n",
    "                                                                                      num_points=int(\n",
    "                                                                                          prop_points * len(\n",
    "                                                                                              x_val)),\n",
    "                                                                                      NN=False,\n",
    "                                                                                      data_explanation=explanations)\n",
    "            else:\n",
    "                for k in tqdm(range(len(epsilon_range))):\n",
    "                    _, total_astuteness[i, j, k], _ = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                                                      explainer=pred_model,\n",
    "                                                                                      explainer_type='rise',\n",
    "                                                                                      explanation_type='attribution',\n",
    "                                                                                      ball_r=median_rad,\n",
    "                                                                                      epsilon=epsilon_range[k],\n",
    "                                                                                      num_points=int(\n",
    "                                                                                          prop_points * len(\n",
    "                                                                                              x_val)),\n",
    "                                                                                      NN=True,\n",
    "                                                                                      data_explanation=explanations)\n",
    "    pickle.dump(total_astuteness, open(save_astuteness_file, 'wb'))\n",
    "else:\n",
    "    total_astuteness = pickle.load(open(save_astuteness_file, 'rb'))\n",
    "astuteness_mean = total_astuteness.mean(axis=0)\n",
    "astuteness_std = total_astuteness.std(axis=0)\n",
    "image_name = 'plots/rise_' + datatype + '_astuteness_classifiers.PNG'\n",
    "fig, ax = plt.subplots()\n",
    "for i in range(len(classifiers)):\n",
    "    ax.errorbar(x=epsilon_range, y=astuteness_mean[i, :], yerr=astuteness_std[i, :],\n",
    "                label=classifiers[i])\n",
    "plt.legend()\n",
    "plt.savefig(image_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
