{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a43ba62b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import argparse\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda\n",
    "from tensorflow.python.keras.layers.normalization import BatchNormalization\n",
    "from tensorflow.python.keras import regularizers\n",
    "from tensorflow.python.keras.models import Model, Sequential\n",
    "from tensorflow.python.keras import optimizers\n",
    "from tensorflow.python.keras.callbacks import ModelCheckpoint\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from utils.explanations import calculate_stability, calculate_robust_astute_sampled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d4afd323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "datatype = 'switch'  # dataset name; also used to build the data/model/explanation/plot file paths\n",
    "run_times = 1  # number of explanation runs; astuteness is averaged (mean/std) over runs\n",
    "prop_points = 0.05  # fraction of len(x_val) passed as num_points to each astuteness estimate\n",
    "calculate = True  # True: recompute astuteness and cache it; False: reload the cached pickle\n",
    "epsilon_range = np.arange(0.01, 1.1, 0.05)  # epsilon values swept in the astuteness experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a44f3887",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: unpickling can execute arbitrary code -- only load trusted data files.\n",
    "with open('data/' + datatype + '.pk', 'rb') as data_file:\n",
    "    data_dict = pickle.load(data_file)\n",
    "\n",
    "# The labels (y_train, y_val) and datatype_val entries are not needed here.\n",
    "x_train = data_dict['x_train']\n",
    "x_val = data_dict['x_val']\n",
    "input_shape = data_dict['input_shape']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "592d2b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ball radius passed to the astuteness computation: half of the median\n",
    "# pairwise (Euclidean, pdist's default) distance between validation points.\n",
    "median_rad = np.median(pdist(x_val)) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e698ae31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifier variants to compare; each name selects how its model is loaded below.\n",
    "classifiers = ['2layer', '4layer', 'linear', 'svm']\n",
    "# Cache for the astuteness array: written when calculate is True, read otherwise.\n",
    "save_astuteness_file = 'plots/rise_{}_astuteness_classifiers.pk'.format(datatype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b120909c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completing Run 1 of 1\n",
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-09-28 22:23:45.613670: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
      "2021-09-28 22:23:45.637962: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3600000000 Hz\n",
      "2021-09-28 22:23:45.638903: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x5652000066b0 executing computations on platform Host. Devices:\n",
      "2021-09-28 22:23:45.638937: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>\n",
      "100%|███████████████████████████████████████| 22/22 [37:49<00:00, 103.18s/it]\n",
      "100%|███████████████████████████████████████| 22/22 [44:48<00:00, 122.18s/it]\n",
      "100%|████████████████████████████████████████| 22/22 [22:23<00:00, 61.06s/it]\n",
      "100%|████████████████████████████████████████| 22/22 [20:58<00:00, 57.23s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAwHUlEQVR4nO3deXxU5b348c+TyWRfCFlYDAgiICAQkUVZLGKtaBW1el1qa0WrP7V6b68bdnOvva7V3tuqtBds7a3aequAoFWvBSZBKlsIBEvYIYRkkhCSzEyWWZ7fH5PEGLLMhEnOmZPv+/XiBTNzzsz3QPjmyfc8z/dRWmuEEEJEvxijAxBCCBEZktCFEMIiJKELIYRFSEIXQgiLkIQuhBAWEWvUB2dlZelRo0YZ9fFCCBGVtmzZUqW1zu7sNcMS+qhRo9i8ebNRHy+EEFFJKXWoq9ek5CKEEBYhCV0IISxCEroQQliEJHQhhLAISehCCGERPSZ0pdQypZRTKbWzi9eVUupXSqm9SqkipdS0yIcphBCiJ6GM0F8HFnbz+qXA2JZfdwCvnHpYQgghwtXjPHSt9Xql1KhuDrkS+IMO9uHdqJQapJQaprU+FqkgjeRp9rG25DB/3vM2vkBzWOcqHSDTs59EX20fRSf628jd9WQ4w/s6EKKjpKnTWfTDlyP+vpFYWHQacKTd49KW505K6EqpOwiO4hk5cmQEPjryAgFNcVkd6/dU4thTyZZDNeiUTSQOfwcArVVI76PQKFp6zcudCku4eGuA+R8F/00DBsciotuOwOd98r6RSOidZbhOd83QWi8FlgJMnz7dNDtrHKttwLGnCseeKvL3VFLj8QIwcVgat84dTX38P1l1BLZ8Zwtxtrju36x0M3z0Mzi8ATLHwtcfg7O+CSq0bwTCnOo++ICjH99PyoUXkvufv0LFGrbIWljApD5630h8VZYCI9o9zgXKIvC+fcbT7OMf+4+3jMKr2Ot0AZCdGs+FZ+Vwwdhs5pyZRXZqPABPfPYXBicM7j6ZV++D/3scdq2A5By4/Jdwzs1gk//40c5VUMDRh5aQOG0ap/3yRUnmwrQi8ZW5ErhHKfUWMAuoNWv9/ONdFSzLP8DmQ8fx+jXxsTHMOiOTG2aMYO7YLMYPSUV1MpKu8FQwJGlI52/qqoR1z8CW5WCLh/k/gvPvgfiUPr4a0R8aduyg9N5/Jf6MMxjxym+ISUgwOiQhutRjQldKvQnMB7KUUqXAo4AdQGv9KrAGuAzYC3iAxX0V7Knw+gPc93Yh6Ul2bp0zmnljs5k+KoMEu63Hc50e58kJvdkNn/0aCl4GbwOcewvMfxhScvrmAkS/a9q/nyO330FsZiYjfrsUW1qa0SEJ0a1QZrnc2MPrGvhBxCLqI4VHTlDf5OPZa6dw6eRhYZ1b4a5gStaU4AO/Dwr/CH//BbjKYcIVcNGjkDW2D6IWRvEeO8bh274PsbGM/O/fYc+Rb9TC/AZMMXDtbie2GMWcsVlhndfkb6KmqYacpBz45xr45DGo2g0jZsF1f4CRs/omYGEYX00Nh79/O4H6ek5/4w/EmXRGlhAdDZiEvq6kknNHZpCWYA/rPKfHCcCQrf8Dh7ZD5plw/f/IzBWLCrjdHLnzTrxHjjDid78lYcIEo0MSImQDIqFX1jex82gdD14yPuxzK9wVAAxxlsA3X4Bp3wNbeN8URHTQzc2U/uu/0bhjJ7m/epnkmTONDkmIsAyIhO7YUwnA18Z1umtTt9pG6EOmwozvRzQuYR46EKDs4R/hLihg2M+fIvXrXzc6JCHCNiDWMK4rqSQrJY6Jw8KfpVDh3AHAkPGLIh2WMAmtNRU/f5q6NWvIeeB+Bl1zjdEhCdErlk/o/oBmfUklF4zNJiYm/Jp3xdHPSQ4ESD772j6ITphB1SuvUPM//8PgxYsZfNttRocjRK9ZPqHvOFpLjcfL18aHX24BcNbsZYiKh9QuFhaJqFbz
5ptU/eo/Sb/qKnIefKDThWVCRAvLJ/R1uytRCuaN7UVCd35Bhb+BISlDIx+YMFzdBx9Q/sSTpMyfz7Ann0DFWP6/g7A4y38FrytxMiV3EIOTe2iq1Znid6mItZGTOTHygQlDuTds+Gp/FrvMXBLRz9IJ/YSnmcIjJ3o1uwXAV/wuVbZYhgw6PcKRCSM17trFkXvuJX706GB/lsREo0MSIiIsndAde6oI6N5NV8T5BdU1+/Arum7MJaKO3+Wm9N//HVtqKiN+91vpzyIsxdLz0NeVVJKeaCdvxKDwTy5+F2dLm1RJ6Nagtab88cfxHinl9D/8XvqzCMux7Ahda826kkrmjc3C1ovpihS/R8WwswEYkiwJ3Qpq31tB3apVZP3gbpKmTzc6HCEizrIJ/Ytj9VTWN/W63ELVbiqGBm+G5iTJSC7aNe0/QPmTT5I0YwZZd95pdDhC9AnLJvS1JcEl+71K6MXvgoqhIn0o9hg7GfEZEY5O9KdAUxNH77+fmLg4hj//HMrWcw98IaKRZWvo63ZXMmFYGjlpvdhhpvg9OH0OFd56cpJyZLFJlHM+9zxNX3xB7iu/wT5EymfCuiw5Qq9v9LLlUA3ze7M6tKXcwsQrO9+pSESV+k8/peaPf2Tw924m9cILjQ5HiD5lyYS+YV81voDufbkFBRMWBfcSlRuiUct77BjHfvRjEiZOJPv++40OR4g+Z8mEvq6kkpT4WKaN7EXtu/g9GDUXnZIjI/Qopn0+jj74INrr5bQXXyAmrhcrhYWIMpZL6Fpr1u2uZPaYTOJiw7y8duWW2qZamvxNktCjVNVvXqFh8xaGPvYocaNGGR2OEP3Ccgl9X6WLoycaetddsUO5BWTKYjRy/+Nzql59lfSrriJ9kfSxFwOH5RL62t29352otdxC6pC2hC419Ojiq6mh7MEHiRs5kqE/+6nR4QjRryyX0NeVVHJmTgq5GUnhndiu3AJ8mdCl5BI1tNYc+9GP8dfUcNovXyQmOdnokIToV5ZK6A3Nfv5x4Pgpz26B4F6iMSqGrMSsyAYp+kzNG2/gWruWnIceImHCBKPDEaLfWSqhb9xfTbMvcMrlFoAKdwVZCVnExlh27ZWlNOwspuK550lZsICM79xkdDhCGMJSCX1dSSUJ9hhmjh4c3okdyi0QLLnIDdHo4He5OXr/fcRmZjLs50/Jyl4xYFlq+LmupJLzz8gkwR5mr44O5RYIllxOT5ONLaJB+RNftsSNzZC+O2LgsswI/VC1mwNV7oiUWyBYcpEbouZ34r33qFspLXGFAAsl9PUlLdMVx4dZJumk3OLxeqhvacwlzKtp/wHKn5CWuEK0skxCX7u7kpGDkxiVGeZ0xU7KLTIHPToc+/GPpSWuEO1YIqE3+fxs2FfN18Zlh39DrLNyi8xBNz1vWRkNhYVk3n67tMQVokVICV0ptVAptVsptVcp9XAnr6crpVYppbYrpYqVUosjH2rXNh+socHrD79dbiflFgjeEAVJ6Gbmys8HIOVrFxgciRDm0WNCV0rZgF8DlwITgRuVUhM7HPYDYJfWeiowH3hBKdVv7e3WlVQSZ4vhvDMywzuxk3ILBG+IgvRxMTO3I5/YYcOIGzPG6FCEMI1QRugzgb1a6/1a62bgLeDKDsdoIFUF6x0pwHHAF9FIu7FudyUzRmeQHB/mLMyWnYnal1sgWHJJj08nIbYXux2JPqe9XtyffUbK3Lky51yIdkJJ6KcBR9o9Lm15rr3/AiYAZcAO4N+01oGOb6SUukMptVkptbmysrKXIX9V2YkGdlfUhz9dsbXcMumqk16q8MiURTNrKCwk4HKRfME8o0MRwlRCSeidDYF0h8eXAIXAcCAP+C+lVNpJJ2m9VGs9XWs9PTu7F/PFO9E2XXFcmOWRLsotECy5SLnFvFyOfIiNJfm884wORQhTCSWhlwIj2j3OJTgSb28x8FcdtBc4AJwVmRC7t66kkmHpCYwbkhLeiV2UWwDZqcjkXPkO
kvLysKWmGh2KEKYSSkLfBIxVSo1uudF5A7CywzGHgYsAlFJDgPHA/kgG2hmvP0D+nqrwpyt2U27x+r1UN1bLHHST8lVW0rTrC5LnSblFiI56vIuotfYppe4B/gbYgGVa62Kl1J0tr78KPAm8rpTaQbBEs0RrXdWHcQNQeOQE9U2+8Ovn3ZRbKhuCJRwZoZuTq6AAgJR5cw2ORAjzCWlaiNZ6DbCmw3OvtvtzGfCNyIbWs7W7ndhiFLPPDLNneTflFllUZG7u9Q5s2VnES79zIU4S1StF15VUMm3kINIT7aGf1E25BWQOuplpvx93QQEpc2S6ohCdidqEXlnfxM6jdcwPtxlXN+UWkD4uZta4cyf+2lqSpdwiRKeiNqE79vRyM+huyi0QTOiJsYmk2mUGhdm4HPmgFMmzZxsdihCmFLUJfV1JJVkpcUwcdtJ09671UG6BL/ugy4/05uN2OEiYMlk2sRCiC1GZ0P0BzfqSSi4Ym01MTBiJt4dyC8gcdLPy1dTQUFREylyZrihEV6Iyoe84WkuNx8vXwu2u+M81cPrsLsst0LLsX+rnpuPesAG0JkWW+wvRpahM6Ot2V6IUzBsbRkJv9oCzOJjQuxDQASo9lTLDxYTcjnxs6ekknH220aEIYVrRmdBLnEzJHcTg5DA69Dp3gQ7AsKldHnK88Tg+7ZOSi8noQABXfj7Jc+bIzkRCdCPqEvoJTzOFR06EP7vl2Pbg70OndHmIzEE3p6bdu/FXVclyfyF6EHUJ3bGnioDuxXTFY9shYRAMGtnlITIH3ZxcjpbdiebOMTgSIcwt6hL6eWdk8sw1k5mamx7eieVFMGwKdDMdUZb9m5N7/XriJ04gNkItl4WwqqhL6Nmp8Vw/YySxtjBC93uhYle35RYITlmMjYllcMLgU4xSRIq/vh5PYaFMVxQiBFGX0Hulcjf4m2BYXreHVbgryEnMIUYNjL+WaODeuBF8PumuKEQIBkbmKi8K/j6s+xF6hUd2KjIbtyOfmORkEvPyjA5FCNMbGAn92HawJ0Hmmd0e5vQ45YaoiWitceU7SJ59PsoeRkdNIQaoAZLQi2DI2RDT9RxmrbVsDm0yzfv34ys7RrLUz4UIifUTeiAA5Tt6LLfUe+tp8DVIycVEXOsdgOxOJESorJ/Qaw5Ac323K0Thy0VFUnIxD7fDQdyZY7APH250KEJEBesn9BBWiILMQTebgMeDZ9Mmma4oRBgGRkKPsUNO93tQOj1OQBK6WXg2bUJ7vbI7kRBhsH5CLy+CnLMgNr7bwyrcFSgU2YmyGtEMXI58VEICSdOnGx2KEFHD2gld6+AMl6Hd188hWHIZnDAYu02mx5mBy7GepFkziYnv/huxEOJL1k7odWXgqerxhijIxhZm0nzoEN5Dh0mZd4HRoQgRVayd0ENcIQqyStRMXPkt3RWlfi5EWKyd0I9tB1RwUVEPZC9R83A78rGPHEnc6acbHYoQUcXiCb0ouNw/PqXbwxp9jdQ21TI0eWg/BSa6Emhuxv2Pf5AyV0bnQoTL2gm9tQd6D1qnLErJxXgNW7agGxpkuqIQvWDdhO45DrVHQr4hCjIH3Qxc6x0ou53kWbOMDkWIqGPdhB7iClGAcnc5ICN0M3DnO0iaMZ2YpCSjQxEi6lg/oYcwQpdVoubgPXaMpj17pbuiEL0UUkJXSi1USu1WSu1VSj3cxTHzlVKFSqlipdS6yIbZC+VFkD4CknreTq7CU0FqXCpJdhkVGkmmKwpxamJ7OkApZQN+DVwMlAKblFIrtda72h0zCPgNsFBrfVgpZXzt4lhRSOUWkCmLZuF25BM7dChxZ3a/EYkQonOhjNBnAnu11vu11s3AW8CVHY75NvBXrfVhAK21M7JhhqnJBdV7Qyq3QLCPiyR0Y2mvF/eGDaTMm4tSyuhwhIhKoST004Aj7R6XtjzX3jggQym1Vim1RSl1c2dvpJS6Qym1WSm1ubKysncRh6JiJ6BDmrIIskrUDBq2
byfgcpE8T+rnQvRWKAm9s+GS7vA4FjgX+CZwCfAzpdS4k07SeqnWerrWenp2dh92NQzjhqg34KWqoUr6uBjM5cgHm43k8883OhQholaPNXSCI/IR7R7nAmWdHFOltXYDbqXUemAqUBKRKMN1rAiSsiB1WI+HVjdUo9FScjGY2+Eg8Zw8bKmpRociRNQKZYS+CRirlBqtlIoDbgBWdjhmBTBPKRWrlEoCZgFfRDbUMJRvD5ZbQqjFti4qkpLLqXEVFFC3Zg2+mpqwz/VVVdG4a5fsTiTEKepxhK619iml7gH+BtiAZVrrYqXUnS2vv6q1/kIp9SFQBASA32mtd/Zl4F3yNYHzC5j99ZAOb9tLVEbovaabmzl6778S8HhAKRImTiR5zhyS58wh6Zw8VFxct+e7CwoAZLm/EKcolJILWus1wJoOz73a4fFzwHORC62XnF9AwBfylEVZ9n/qPFu3EfB4yL7/PvD5cBUUUL1sGdVLl6KSkkiaMZ2UlgQfd8YZJ81ica13YMvKImFC99sECiG6F1JCjyph3BCF4Bz0eFs86fHpfRiUtbkc68FuJ+PGb2NLSSbrrrvwu9x4Pv8cd0EB7oICKtatByB26FCS58wmZc4cks4/H1taGu6CAlK+9jVUjHUXLgvRH6yX0MuLIC4VMkaHdHjrHHSZ+9x7bkc+Seeeiy0lue05W0oyqQsuJHXBhQA0lx7FvaEAd8EG6j/+hNr//SsoRdwZZ+A/cUKmKwoRAdZL6MdaWuaGONqTOeinxltRQVNJCTkPPtDtcXG5pxF33XVkXHcd2u+nsbgYd0EBroIC7KePJGXunH6KWAjrslZCD/iDi4qmfS/kUyo8FeTl5PVdTBbnbum/Ek5DLWWzkThlColTppB11119FZoQA461ipbVe8HrCXmFqNYap8cpI/RT4HLkE5uTQ/y4sUaHIsSAZ62EHuYN0ZqmGrwBr8xw6SXt8+HesIFk6b8ihClYL6Hb4iHrpK4DnWqdgz40SfYS7Y2GoiICdXWkzLvA6FCEEFgtoZcXwZBJYLOHdLjsJXpqXA5HsP/KbOm/IoQZWCehax0coYdYP4d2i4qkMVevuB35JE6dii0tzehQhBBYKaGfOAyNtSGvEIXgXqI2ZSMzIbMPA7MmX3U1jTt3yu5CQpiIdRJ62w3RvJBPcXqcZCVmYYux9U1MFtbWf0UaaglhGtZJ6OVFoGwwZGLIp1R4KqTc0ksuRz62wYNJmBT637cQom9ZJ6EfK4Ls8WBPDPkU2Uu0d3QggDs/n+S5c6T/ihAmYp3/jce2h1U/h5YRuiT0sDUW78JfU0OK9F8RwlSskdDrK8BVHtYMF1ezC7fXLVMWe8Gd7wClSJ4j/VeEMBNrJPTyouDvIa4QhS/noMsIPXyu9Q4SJk0idvBgo0MRQrRjjYTeOsNl6OSQTyn3lAMyBz1c/tpaGrZvJ+UCKbcIYTbWSOjlRcH+5wmhb1Ihq0R7x/3ZZxAIyHRFIUzIGgk9zBWi8GUfF0no4XE5HMSkpZE4JfSfhoQQ/SP6E3rDCag52KsZLoMTBhNvi++TsKxIa43bkU/y7NmoWGu10hfCCqI/oZfvCP4exgpRQPqg90JTSQk+p1OW+wthUhZI6K0zXGQOel9zOxyALPcXwqyiP6EfK4LUYZAS3mhbRujhcznyiR8/HvsQ+XsTwowskNDDXyHa5G/ieONxGaGHwe9y49m6VcotQphYdCf0Zg9U7Q673NK2qEjmoIfM84+N4PVKuUUIE4vuhO7cBToQ1gpRkDnoveFyOIhJSiJp2jlGhyKE6EJ0J/S2FaK9m4Mue4mGpnW6YtL556Pi4owORwjRhehO6OVFkDAIBo0M6zQZoYen+cBBvEePSv1cCJOL7oTeukJUqbBOq/BUkGxPJiUupY8CsxZ3fut0RUnoQphZ9CZ0vxcqdoVdbgGZgx4u13oHcaNHE5eba3QoQohuRG9Cr9wN/qawV4hCMKFLuSU0gcZGPJs2
kSzlFiFML6SErpRaqJTarZTaq5R6uJvjZiil/EqpayMXYhd6uUIUgjdFZYQeGs+mTeimJlLmXWB0KEKIHvSY0JVSNuDXwKXAROBGpdRJOwO3HPcM8LdIB9mpY0VgT4LMM8M6zR/wU9VQJSP0ELkcDlR8PEkzphsdihCiB6GM0GcCe7XW+7XWzcBbwJWdHHcv8L+AM4Lxde3YdhhyNsTYwjqturEav/YzNFmmLIbC7cgnaeZMYhISjA5FCNGDUBL6acCRdo9LW55ro5Q6DbgaeLW7N1JK3aGU2qyU2lxZWRlurF8KBIJdFntZbgHZei4UzaWlNB84INMVhYgSoST0zuYE6g6PXwKWaK393b2R1nqp1nq61np6dnZ2iCF2ouYANNeHvUIUZA56OKS7ohDRJZRdCkqBEe0e5wJlHY6ZDrylgvPBs4DLlFI+rfV7kQjyJL1cIQqyl2g4XI587Lm5xI0eZXQoQogQhJLQNwFjlVKjgaPADcC32x+gtR7d+mel1OvA+32WzCE4wyXGDjkTwj7V6XFij7GTEZ/RB4FZh25uxr1xI+lXLkKFuXBLCGGMHhO61tqnlLqH4OwVG7BMa12slLqz5fVu6+Z94th2yDkLYsPfPq51Drokqe55tm5DezykzJNyixDRIqSNIbXWa4A1HZ7rNJFrrW859bC6DSY4ZXHcwl6dLnPQQ+NyrAe7naSZs4wORQgRouhbKVpXBp6qXt0QhWDJRRJ6z9yOfJKmTcOWkmx0KEKIEEVfQj+FFaJa62AfF7kh2i1vRQVNJSWkXCDlFiGiSfQl9EGnw5wfBhcVham2qZYmf5NMWeyBOz8fkOmKQkSbkGropjJkIlz8eK9OrfDIoqJQuBz5xObkED9urNGhCCHCEH0j9FPQltCl5NIl7fPh3rCB5HlzZSaQEFFmQCX0ts2hZYTepYaiIgJ1dTJdUYgoNKASeoWnghgVQ2ZiptGhmJbL4QCbjeTZs40ORQgRpoGV0N0VZCZkYo+xGx2Kabkd+SROnYotLc3oUIQQYRpQCV3moHfPV11N486d0l1RiCg1oBJ6mbtMboh2w7VeuisKEc0GTEKva67jYO1Bzhp8ltGhmFbdqpXYc3NJOHuS0aEIIXphwCT0osoiNJq8nDyjQzElb3k57s82kr5IuisKEa0GTELf5tyGTdmYkhV+y4CBoHbVKtCa9CsXGR2KEKKXBkxC3+7czriMcSTZk4wOxXS01tSuWEHiOecQd/rpRocjhOilAZHQfQEfRVVFUm7pQuOuXTTv3Uf6lZ3t/S2EiBbR18ulF3bX7KbB18A5OecYHYop1a5YgbLbSbu0dz3mhYgkr9dLaWkpjY2NRodiqISEBHJzc7HbQ183MyASeqGzEEASeie010vd+6tJWbAAW3q60eEIQWlpKampqYwaNWrA3qDXWlNdXU1paSmjR4/u+YQWA6LkUugsZEjSEIYmDzU6FNNx5efjP35cyi3CNBobG8nMzBywyRxAKUVmZmbYP6UMiIS+zblNRuddqF2xEltGhqwOFaYSbjK//rXPuP61z/ooGmP05hua5RN6ubucCk+F3BDthL+uDtenn5L2zW+iwqjTCSHMyfIJfZtzG4Ak9E7UffghurlZyi1CtHPkyBEuvPBCJkyYwKRJk3j55ZcBuOWWW3jnnXcMjq57lr8pus25jcTYRMZnjDc6FNOpXbGSuDFjZKm/EO3ExsbywgsvMG3aNOrr6zn33HO5+OKL+/xzfT4fsbGnlpItn9ALnYVMyZpCbIzlLzUszYcP07BlC9n33Tegbz4Jc3t8VTG7yup6PG7XseAxodTRJw5P49Eruh7EDBs2jGHDhgGQmprKhAkTOHr06FeOeeKJJ1i1ahUNDQ3Mnj2b1157jf379/Mv//IvbN26FYA9e/Zwww03sGXLFrZs2cJ9992Hy+UiKyuL119/nWHDhjF//nxmz55NQUEBixYt4v777+8x/u5YuuTi8XooqSlhas5Uo0Mx
ndqVq0Ap0q+43OhQhDCtgwcPsm3bNmbNmvWV5++55x42bdrEzp07aWho4P3332fMmDGkp6dTWFgIwPLly7nlllvwer3ce++9vPPOO2zZsoVbb72Vn/zkJ23vdeLECdatW3fKyRwsPkIvqirCr/0yw6WD1qX+SefNwt4yEhHCjLobSbfXOjJ/+/+dH7HPdrlcXHPNNbz00kukddjw5e9//zvPPvssHo+H48ePM2nSJK644gq+//3vs3z5cl588UXefvttPv/8c3bv3s3OnTvbyjZ+v7/tJwCA66+/PmIxWzqhFzoLUSimZEtDrvYatm3De+QIWXffbXQoQpiS1+vlmmuu4aabbuJb3/rWV15rbGzk7rvvZvPmzYwYMYLHHnusbb74Nddcw+OPP86CBQs499xzyczMpKysjEmTJvHZZ52Xg5KTkyMWt6VLLoXOQsYMGkNanGyn1l7teytQiYmk9sONHiGijdaa2267jQkTJnDfffed9Hpr8s7KysLlcn1l5ktCQgKXXHIJd911F4sXLwZg/PjxVFZWtiV0r9dLcXFxn8Ru2YTuD/jZXrldyi0dBJqaqPvwQ1Iv/jq2lMiNDISwioKCAt544w0+/fRT8vLyyMvLY82aNW2vDxo0iNtvv53Jkydz1VVXMWPGjK+cf9NNN6GU4hvf+AYAcXFxvPPOOyxZsoSpU6eSl5fHhg0b+iR2y5Zc9tXuw+V1SULvwPX3tQTq6mTuubCUSNbO586di9b6pOcvu+yytj8/9dRTPPXUU52en5+fz6233orNZmt7Li8vj/Xr15907Nq1a0894HYsm9BbG3LlZecZGofZ1K5YQWxODsnnnWd0KEJYztVXX82+ffv49NNPDfn8kEouSqmFSqndSqm9SqmHO3n9JqVUUcuvDUopw+cJbnNuIzMhk9zUXKNDMQ1fdTUuh4P0RVeg2o0ehBCR8e6771JUVERWVpYhn99jQldK2YBfA5cCE4EblVITOxx2APia1noK8CSwNNKBhqvQWcg5OefIopl26lavAZ+PtEWyzZwQVhTKCH0msFdrvV9r3Qy8BXylAKu13qC1rml5uBEwdFhc1VBFqatU+rd0ULtiBfETJ5AwbpzRoQgh+kAoCf004Ei7x6Utz3XlNuCDzl5QSt2hlNqslNpcWVkZepRhkoZcJ2vau5fG4mIGyc1QYUXLvxn8NcCFktA7q1mcfAsYUEpdSDChL+nsda31Uq31dK319Ozs7NCjDFOhs5C4mDgmDu5YGRq4alesBJuNtG/KF70QVhVKQi8FRrR7nAuUdTxIKTUF+B1wpda6OjLh9U6hs5Czs87GbpMe3wDa76d21SpS5s4l1qCbNUJEG7/fzznnnMPllwf7HUVD+9xQEvomYKxSarRSKg64AVjZ/gCl1Ejgr8B3tdYlkQ8zdI2+RnYd3yXllnY8n3+Or7yc9Kuk3CJEqF5++WUmTJjQb5/n8/lO+T16nIeutfYppe4B/gbYgGVa62Kl1J0tr78KPAJkAr9pmVXi01pPP+XoeqG4uhhfwCcLitqpfW8FMamppFx4odGhCBGeDx6G8h09H1deFPw9lDr60Mlw6X90e0hpaSmrV6/mJz/5CS+++OJJr0d1+1yt9Rqt9Tit9Rit9c9bnnu1JZmjtf6+1jpDa53X8suQZA5f3hCdmm34VHhTCLjd1H38MWkLLyEmIcHocISICj/84Q959tlniYnpPEVK+9x+UugsZFTaKDISMowOxRTqP/kE7fHIUn8RnXoYSbdpHZkvXn3KH/n++++Tk5PDueee2+XSfGmf2w+01hRWFrJgxAKjQzGN2hUrsefmkjhtmtGhCBEVCgoKWLlyJWvWrKGxsZG6ujq+853vtG0PJ+1z+8mBugPUNtXKDdEW3ooK3J99RvqiRagufnQUQnzVL37xC0pLSzl48CBvvfUWCxYs4I9//GPb69I+t5+0NeSShA5A3apVoDXpV8pSfyEixcztc1VnbSL7w/Tp0/Xm
zZsj+p6PFDzCp0c+xXG9Y8D3cNFas/+KK7ClpjHqzT8ZHY4QIfviiy/6dbpgpD3//PPU1tby5JNPnvJ7dfZ3oZTa0tXEE0vV0Lc5t5GXnTfgkzlA465dNO/dx9DHHjM6FCEGDKPb51omodc01nCw7iBXnimzOSDYiEvZ7aRdutDoUIQYMN59911DP98yNfTtldsBZEERoL1e6lavIeXCC7GlpxsdjhCin1gmoW9zbiM2JpZJmZOMDsVwroIC/NXVstRfiAHGMgm90FnIxMETSYiV1ZC1763AlpFByty5RociRL9Y/OFiFn+42OgwDGeJhO71eymuLpbpikBDcTH1H31E+lVXoeLijA5HCNGPLJHQdx3fRZO/acAndO33U/7Y49gGDybrrjuNDkeIqJWSkgJAWVkZ1157rcHRhM4SCb11QdFAvyF64i9/oXHHDoYseQhbWprR4QgR9YYPH97nPdAj0Ta3lSWmLRY6C8lNySUrceBu3uCrrsb54i9JmjWLtJaG/EJEu2c+f4Z/Hv9nj8e1HhNKHf2swWexZGanm6qd5ODBg1x++eXs3LmT119/nZUrV+LxeNi3bx9XX301zz77LAAfffQRjz76KE1NTYwZM4bly5eTkpLSaZtdpVTE2+a2ivoRutY6uKBogJdbnM89T6ChgaGP/EwWVgnRRwoLC3n77bfZsWMHb7/9NkeOHKGqqoqnnnqKTz75hK1btzJ9+vS2HuqdtdltFcm2ua2ifoRe6iqlurF6QJdbPJs2Ufvee2TecQfxY8YYHY4QERPqSLp1ZL584fK+DIeLLrqI9Ja1HRMnTuTQoUOcOHGCXbt2MWfOHACam5s5//zzga7b7EJk2+a2ivqEPtAbcmmvl2OPP459+HC5ESpEH4uPj2/7s81mw+fzobXm4osv5s033/zKsd212YXIts1tFfUll23ObaTYUxiTPjBHpsd//3ua9+5jyE9/SkxiotHhCDHgnHfeeRQUFLB3714APB4PJSUl3bbZ7SvRP0KvLGRq9lRsMTajQ+l33rIyKn/9G1IWLCB1gewXKoQRsrOzef3117nxxhtpamoC4KmnnmLcuHFtbXZHjRp1UpvdvhDV7XPrmuuY++Zc7s67mzunDrxyQ+m99+Jy5DNm9fvYTzvN6HCEiIhob58bSeG2z43qkktRZREaPSDr5/Vr11L/8Sdk3X23JHMhBBDlCb3QWYhN2ZiSNcXoUPpVoKGBiiefIm7MGDJv+Z7R4QghTCKqa+iFzkLGZYwjyZ5kdCj9quq11/AePcrI3/9e+rUIIdpE7QjdF/BRVFU04MotTfsPUP3fy0hbdAXJs2YaHY4QwkSiNqGX1JTQ4GsYUAuKtNaUP/EEMQkJDHnoIaPDEcI0Dn33Zg5992ajwzBc1JZctjm3AeZuyOUtK6N62XJs6ekMvuV72FJTT+n96lavwbNxI0Me+RmxWQO3b40QonNRm9ALnYUMSRrC0OShRodyEn9tLVVLl1Lzxh9Ba7TXS82f/kTW3XeTcf11vap7++vrqXjmP0g4+2wy+mDJsBAi+kVtyaWwstB0o/NAUxPVy5az9xuXcHzZctIuu4wxf/uQUe+8Q/z48VT8/Ofsu/wK6j74gHDn/1e+/Cv8VdUMffRRlG3gLaISQvQsKkfo5e5yyt3l5E3KMzoUAHQgQN3q1VT+8iW8ZWUkz5tHzgP3kzB+PAD24cMZuXwZ7vx8nM89z9F/v4+EZcvJefABkmf2fGOzobiYmj/9iYwbbyRx8tl9fTlCmEb500/T9EXP7XMb/xk8JpQ6evyEsxj64x93+brb7ea6666jtLQUv9/Pgw8+yOrVq/nzn/8MwNq1a3nhhRdYtWoVKSkp/OAHP+CTTz4hIyODp59+moceeojDhw/z0ksvsWjRohCvNDKicoTeWj83wwwX94YNHLj2WsoefIiYQemMXPbfjPzt0rZk3kopRcq8eYx+968Me/ppfJWVHL75exy58y6a9uzp8v21
30/5409gy8gg+4f/1teXI8SA9+GHHzJ8+HC2b9/Ozp07ueqqq9i4cSNutxuAt99+u61TotvtZv78+WzZsoXU1FR++tOf8vHHH/Puu+/yyCOP9HvsUTlCL3QWkhibyPiM8T0f3Eca//lPnM+/gDs/H/vw4Qx/7jnSvnkZKqb775HKZmPQt64m7bJLOf7GG1S/tpT9V15F+reuJvvee7EPGfKV40/85R0ai4oY/uwzsguRGHC6G0m31zoyP/2NP5zyZ06ePJkHHniAJUuWcPnllzNv3jwWLlzIqlWruPbaa1m9enXbxhZxcXEsXLiw7bz4+HjsdjuTJ0/m4MGDpxxLuEIaoSulFiqldiul9iqlHu7kdaWU+lXL60VKqWmRD/VL25zbmJI1hdiY/v9+5C0ro2zJwxy4+ls07NhBzpIlnPHhB6RfcXmPyby9mIQEsm6/nTEff8Tg736X2hUr2XfJQpy/fAl/fT3QugvRiyTNnElaSw9lIUTfGjduHFu2bGHy5Mn86Ec/4oknnuD666/nz3/+M59++ikzZswgtWXGmt1ub9tQJiYmpq29bkxMTES3lgtVjxlIKWUDfg1cCkwEblRKTexw2KXA2JZfdwCvRDjONh6vh5KaEqbmTO2rj+iUv7aWiueeY9/CS6n74AMyb7uVMz/6G5mLbyHmFFZrxmZkMORHDzPmgzWkXnQR1a+9xr5vXMLxP7xBxTPPBHchevQR2YVIiH5SVlZGUlIS3/nOd3jggQfYunUr8+fPZ+vWrfz2t7/tk40pIiWUIe5MYK/Wej+AUuot4EpgV7tjrgT+oINTNzYqpQYppYZprY9FOuBda/7Es0ubGJb8Lvvsf4v023fJV15BwO0m/coryf7Xe7EPHx7R94/LzeW0F55n8OLFOJ97joqnnwaQXYiE6Gc7duzgwQcfJCYmBrvdziuvvILNZuPyyy/n9ddf5/e//73RIXapx/a5SqlrgYVa6++3PP4uMEtrfU+7Y94H/kNrnd/y+P+AJVrrzR3e6w6CI3hGjhx57qFDh8IOuPD/3ubw737N5KzJ2GPsYZ/fWzHJyQy++bsknHVWn3+W1hq3w4G7oIDsH/5QNq4QA4q0z/1SuO1zQxmhd/azfsfvAqEcg9Z6KbAUgv3QQ/jsk+RddD15F5n3R55IUEqRcsEFpFxwgdGhCCGiSCh38UqBEe0e5wJlvThGCCFEHwoloW8CxiqlRiul4oAbgJUdjlkJ3Nwy2+U8oLYv6udCiIHBqJ3UzKQ3fwc9lly01j6l1D3A3wAbsExrXayUurPl9VeBNcBlwF7AAywOOxIhhAASEhKorq4mMzNzwM7u0lpTXV1NQkJCWOdF9Z6iQgjr8Xq9lJaW0tjYaHQohkpISCA3Nxe7/auTP071pqgQQvQbu93O6NGjjQ4jKkVlLxchhBAnk4QuhBAWIQldCCEswrCbokqpSiDcpaJZQFUfhGMmA+EaQa7TauQ6+8/pWuvszl4wLKH3hlJqc1d3d61iIFwjyHVajVynOUjJRQghLEISuhBCWES0JfSlRgfQDwbCNYJcp9XIdZpAVNXQhRBCdC3aRuhCCCG6IAldCCEswnQJ3WwbUveVEK7zppbrK1JKbVBK9e8mqhHS03W2O26GUsrfskNW1AnlOpVS85VShUqpYqXUuv6OMRJC+LpNV0qtUkptb7nOqOy8qpRappRyKqV2dvG6OfOQ1to0vwi2590HnAHEAduBiR2OuQz4gOAuSecB/zA67j66ztlARsufL7XqdbY77lOCbZivNTruPvr3HERwH96RLY9zjI67j67zx8AzLX/OBo4DcUbH3otrvQCYBuzs4nVT5iGzjdDbNqTWWjcDrRtSt9e2IbXWeiMwSCk1rL8DPUU9XqfWeoPWuqbl4UaCu0BFm1D+PQHuBf4XcPZncBEUynV+G/ir1vowgNY6Gq81lOvUQKoKNjJPIZjQff0b5qnTWq8nGHtXTJmHzJbQTwOOtHtc2vJcuMeY
XbjXcBvB0UC06fE6lVKnAVcDr/ZjXJEWyr/nOCBDKbVWKbVFKXVzv0UXOaFc538BEwhuQbkD+DetdaB/wutXpsxDZuuHHrENqU0u5GtQSl1IMKHP7dOI+kYo1/kSsERr7Y/i3WlCuc5Y4FzgIiAR+EwptVFrXdLXwUVQKNd5CVAILADGAB8rpRxa67o+jq2/mTIPmS2hD5QNqUO6BqXUFOB3wKVa6+p+ii2SQrnO6cBbLck8C7hMKeXTWr/XLxFGRqhft1VaazfgVkqtB6YC0ZTQQ7nOxcB/6GChea9S6gBwFvB5/4TYb0yZh8xWchkoG1L3eJ1KqZHAX4HvRtkorr0er1NrPVprPUprPQp4B7g7ypI5hPZ1uwKYp5SKVUolAbOAL/o5zlMVynUeJvhTCEqpIcB4YH+/Rtk/TJmHTDVC1wNkQ+oQr/MRIBP4Tcvo1adN3OWtMyFeZ9QL5Tq11l8opT4EioAA8DutdadT4swqxH/PJ4HXlVI7CJYllmitjW43Gzal1JvAfCBLKVUKPArYwdx5SJb+CyGERZit5CKEEKKXJKELIYRFSEIXQgiLkIQuhBAWIQldCCEsQhK6EEJYhCR0IYSwiP8PWxmiDpzfsycAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def _build_dense_classifier(n_features, hidden_units, activation, output_name, weights_path):\n",
    "    \"\"\"Build a dense softmax classifier and load its trained weights by layer name.\n",
    "\n",
    "    Hidden layers are Dense layers named dense1, dense2, ... (one per entry of\n",
    "    hidden_units, which gives its width), each followed by BatchNormalization.\n",
    "    The 2-unit softmax output layer is named output_name. Every Dense layer uses\n",
    "    an L2(1e-3) kernel regularizer.\n",
    "    \"\"\"\n",
    "    model_input = Input(shape=(n_features,), dtype='float32')\n",
    "    net = model_input\n",
    "    for layer_idx, units in enumerate(hidden_units, start=1):\n",
    "        net = Dense(units, activation=activation, name='dense' + str(layer_idx),\n",
    "                    kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "        net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "    preds = Dense(2, activation='softmax', name=output_name,\n",
    "                  kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    model = Model(model_input, preds)\n",
    "    # NOTE(review): the BatchNormalization layers are auto-named, so by_name loading\n",
    "    # only restores them if the generated names match those in the weights file;\n",
    "    # generated names depend on how many BN layers the session already holds -- TODO confirm.\n",
    "    model.load_weights(weights_path, by_name=True)\n",
    "    return model\n",
    "\n",
    "\n",
    "def _load_classifier(name):\n",
    "    \"\"\"Return the trained model for one entry of `classifiers`.\n",
    "\n",
    "    Raises ValueError for an unknown name instead of silently reusing whatever\n",
    "    model a previous loop iteration left behind.\n",
    "    \"\"\"\n",
    "    hidden_activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'\n",
    "    prefix = 'models/' + datatype\n",
    "    if name == '2layer':\n",
    "        return _build_dense_classifier(input_shape, [200, 200], hidden_activation,\n",
    "                                       'dense4', prefix + '_blackbox.hdf5')\n",
    "    if name == '4layer':\n",
    "        return _build_dense_classifier(input_shape, [50, 50, 50, 50], hidden_activation,\n",
    "                                       'dense5', prefix + '_blackbox_extra.hdf5')\n",
    "    if name == 'linear':\n",
    "        return _build_dense_classifier(input_shape, [200], None,\n",
    "                                       'dense4', prefix + '_blackbox_linear.hdf5')\n",
    "    if name == 'svm':\n",
    "        # NOTE: unpickling can execute arbitrary code -- only load trusted model files.\n",
    "        with open(prefix + '_svm.pk', 'rb') as model_file:\n",
    "            return pickle.load(model_file)\n",
    "    raise ValueError('Unknown classifier: ' + repr(name))\n",
    "\n",
    "\n",
    "def _astuteness_curve(model, explanations, is_nn):\n",
    "    \"\"\"Astuteness of the RISE explanations of `model` at every epsilon in epsilon_range.\"\"\"\n",
    "    num_points = int(prop_points * len(x_val))\n",
    "    curve = np.zeros(len(epsilon_range))\n",
    "    for k in tqdm(range(len(epsilon_range))):\n",
    "        # Only the middle value of the returned triple (the astuteness) is kept.\n",
    "        _, curve[k], _ = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                         explainer=model,\n",
    "                                                         explainer_type='rise',\n",
    "                                                         explanation_type='attribution',\n",
    "                                                         ball_r=median_rad,\n",
    "                                                         epsilon=epsilon_range[k],\n",
    "                                                         num_points=num_points,\n",
    "                                                         NN=is_nn,\n",
    "                                                         data_explanation=explanations)\n",
    "    return curve\n",
    "\n",
    "\n",
    "if calculate:\n",
    "    # The models do not depend on the run index (only the explanation files do),\n",
    "    # so load each one once, in classifier order, and reuse it across runs.\n",
    "    models = [_load_classifier(name) for name in classifiers]\n",
    "    total_astuteness = np.zeros(shape=(run_times, len(classifiers), len(epsilon_range)))\n",
    "    for i in range(run_times):\n",
    "        print('Completing Run ' + str(i + 1) + ' of ' + str(run_times))\n",
    "        for j, (name, model) in enumerate(zip(classifiers, models)):\n",
    "            fname = ('explained_weights/rise/rise_' + datatype + '_' + name\n",
    "                     + '_' + str(i) + '.gz')\n",
    "            explanations = np.loadtxt(fname, delimiter=',')\n",
    "            # NN=False only for the pickled SVM; the Keras models use NN=True.\n",
    "            total_astuteness[i, j, :] = _astuteness_curve(model, explanations,\n",
    "                                                          is_nn=(name != 'svm'))\n",
    "    with open(save_astuteness_file, 'wb') as out_file:\n",
    "        pickle.dump(total_astuteness, out_file)\n",
    "else:\n",
    "    with open(save_astuteness_file, 'rb') as in_file:\n",
    "        total_astuteness = pickle.load(in_file)\n",
    "\n",
    "# Mean and standard deviation across runs for every (classifier, epsilon) pair.\n",
    "astuteness_mean = total_astuteness.mean(axis=0)\n",
    "astuteness_std = total_astuteness.std(axis=0)\n",
    "image_name = 'plots/rise_' + datatype + '_astuteness_classifiers.PNG'\n",
    "fig, ax = plt.subplots()\n",
    "for i, name in enumerate(classifiers):\n",
    "    ax.errorbar(x=epsilon_range, y=astuteness_mean[i, :], yerr=astuteness_std[i, :],\n",
    "                label=name)\n",
    "ax.set_xlabel('epsilon')\n",
    "ax.set_ylabel('astuteness')\n",
    "ax.set_title('RISE astuteness per classifier (' + datatype + ')')\n",
    "ax.legend()\n",
    "plt.savefig(image_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
