{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2020-09-03T04:02:00.715360Z",
     "iopub.status.busy": "2020-09-03T04:02:00.714492Z",
     "iopub.status.idle": "2020-09-03T04:02:28.050988Z",
     "shell.execute_reply": "2020-09-03T04:02:28.052370Z"
    },
    "papermill": {
     "duration": 27.357068,
     "end_time": "2020-09-03T04:02:28.052701",
     "exception": false,
     "start_time": "2020-09-03T04:02:00.695633",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "from scipy.stats import beta\n",
    "import warnings\n",
    "from src.lesion import perform_lesion_experiment, do_lesion_hypo_tests\n",
    "from src.pointers import DATA_PATHS\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import bates_quantile\n",
    "from scipy.stats import sem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-03T04:02:28.076453Z",
     "iopub.status.busy": "2020-09-03T04:02:28.075430Z",
     "iopub.status.idle": "2020-09-03T04:02:28.080359Z",
     "shell.execute_reply": "2020-09-03T04:02:28.081139Z"
    },
    "papermill": {
     "duration": 0.020965,
     "end_time": "2020-09-03T04:02:28.081383",
     "exception": false,
     "start_time": "2020-09-03T04:02:28.060418",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "models = [# 'CNN-VGG-CIFAR10',\n",
    "          'CNN-VGG-CIFAR10+DROPOUT+L2REG',\n",
    "          # 'CNN-VGG-CIFAR10+L1REG', 'CNN-VGG-CIFAR10+L2REG',\n",
    "          # 'CNN-VGG-CIFAR10+DROPOUT', , 'CNN-VGG-CIFAR10+MOD-INIT'\n",
    "          ]\n",
    "\n",
    "n_clust = 16\n",
    "n_shuffles = 19\n",
    "n_workers = 5\n",
    "n_reps = 5\n",
    "is_unpruned = True\n",
    "dataset_name = 'cifar10_full'\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '1,3'\n",
    "results_dir = '/project/nn_clustering/results/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-03T04:02:28.105788Z",
     "iopub.status.busy": "2020-09-03T04:02:28.104840Z",
     "iopub.status.idle": "2020-09-04T01:55:10.680675Z",
     "shell.execute_reply": "2020-09-04T01:55:10.681365Z"
    },
    "papermill": {
     "duration": 78762.594806,
     "end_time": "2020-09-04T01:55:10.681770",
     "exception": false,
     "start_time": "2020-09-03T04:02:28.086964",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/7 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 14%|█▍        | 1/7 [1:26:46<8:40:38, 5206.45s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 29%|██▊       | 2/7 [2:52:34<7:12:24, 5188.83s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 43%|████▎     | 3/7 [3:53:28<5:15:14, 4728.53s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 57%|█████▋    | 4/7 [8:10:36<6:36:54, 7938.21s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 71%|███████▏  | 5/7 [14:18:05<6:45:43, 12171.61s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 86%|████████▌ | 6/7 [20:34:54<4:15:02, 15302.65s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 7/7 [21:52:42<00:00, 12112.36s/it]  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 7/7 [21:52:42<00:00, 11251.78s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_unpruned</th>\n",
       "      <th>model_tag</th>\n",
       "      <th>network</th>\n",
       "      <th>dataset</th>\n",
       "      <th>chi2_p_means</th>\n",
       "      <th>chi2_p_stds</th>\n",
       "      <th>combined_p_means</th>\n",
       "      <th>combined_p_stds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>MLP</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.362690e-20</td>\n",
       "      <td>1.217738e-01</td>\n",
       "      <td>3.436946e-09</td>\n",
       "      <td>8.425660e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>MLP</td>\n",
       "      <td>mnist</td>\n",
       "      <td>6.569656e-07</td>\n",
       "      <td>5.160085e-02</td>\n",
       "      <td>1.241211e-05</td>\n",
       "      <td>8.630156e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>MLP</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.133536e-14</td>\n",
       "      <td>7.853003e-07</td>\n",
       "      <td>3.308756e-07</td>\n",
       "      <td>5.022651e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>MLP</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.047182e-13</td>\n",
       "      <td>1.002984e-02</td>\n",
       "      <td>3.284628e-09</td>\n",
       "      <td>9.939614e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>MLP</td>\n",
       "      <td>mnist</td>\n",
       "      <td>8.391256e-24</td>\n",
       "      <td>2.443453e-05</td>\n",
       "      <td>9.473283e-08</td>\n",
       "      <td>7.941233e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>CIFAR10</td>\n",
       "      <td>MLP</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>2.560037e-06</td>\n",
       "      <td>8.638070e-14</td>\n",
       "      <td>7.879100e-04</td>\n",
       "      <td>2.156128e-04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>False</td>\n",
       "      <td>CIFAR10</td>\n",
       "      <td>MLP</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>4.018413e-10</td>\n",
       "      <td>9.516042e-07</td>\n",
       "      <td>5.360029e-05</td>\n",
       "      <td>6.479042e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>False</td>\n",
       "      <td>CIFAR10</td>\n",
       "      <td>MLP</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>1.152081e-10</td>\n",
       "      <td>9.981256e-10</td>\n",
       "      <td>3.392675e-05</td>\n",
       "      <td>2.438592e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>False</td>\n",
       "      <td>CIFAR10</td>\n",
       "      <td>MLP</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>5.154715e-08</td>\n",
       "      <td>2.327078e-02</td>\n",
       "      <td>1.441507e-04</td>\n",
       "      <td>9.100498e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>False</td>\n",
       "      <td>CIFAR10</td>\n",
       "      <td>MLP</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>2.214766e-04</td>\n",
       "      <td>8.127249e-05</td>\n",
       "      <td>5.106075e-03</td>\n",
       "      <td>1.424048e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>MLP</td>\n",
       "      <td>fashion</td>\n",
       "      <td>8.638070e-14</td>\n",
       "      <td>1.170203e-02</td>\n",
       "      <td>4.189157e-07</td>\n",
       "      <td>4.430452e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>MLP</td>\n",
       "      <td>fashion</td>\n",
       "      <td>5.321419e-04</td>\n",
       "      <td>8.690679e-02</td>\n",
       "      <td>1.137731e-03</td>\n",
       "      <td>7.887562e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>MLP</td>\n",
       "      <td>fashion</td>\n",
       "      <td>1.253848e-10</td>\n",
       "      <td>1.328975e-04</td>\n",
       "      <td>4.913787e-06</td>\n",
       "      <td>5.829153e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>MLP</td>\n",
       "      <td>fashion</td>\n",
       "      <td>7.452428e-10</td>\n",
       "      <td>2.960862e-06</td>\n",
       "      <td>1.705334e-04</td>\n",
       "      <td>2.504779e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>MLP</td>\n",
       "      <td>fashion</td>\n",
       "      <td>3.357147e-17</td>\n",
       "      <td>2.971277e-02</td>\n",
       "      <td>7.282496e-06</td>\n",
       "      <td>7.988556e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>CNN</td>\n",
       "      <td>mnist</td>\n",
       "      <td>2.049312e-01</td>\n",
       "      <td>9.096757e-01</td>\n",
       "      <td>7.887406e-03</td>\n",
       "      <td>3.973505e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>CNN</td>\n",
       "      <td>mnist</td>\n",
       "      <td>2.826985e-01</td>\n",
       "      <td>4.270683e-01</td>\n",
       "      <td>1.291942e-01</td>\n",
       "      <td>4.462957e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>CNN</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.229532e-01</td>\n",
       "      <td>9.485461e-01</td>\n",
       "      <td>9.080628e-03</td>\n",
       "      <td>9.438898e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>CNN</td>\n",
       "      <td>mnist</td>\n",
       "      <td>4.827084e-03</td>\n",
       "      <td>1.375269e-02</td>\n",
       "      <td>1.939270e-02</td>\n",
       "      <td>4.702843e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>CNN</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.150873e-01</td>\n",
       "      <td>3.632654e-01</td>\n",
       "      <td>2.580437e-03</td>\n",
       "      <td>4.386698e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-CIFAR10</td>\n",
       "      <td>CNN</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>1.989727e-07</td>\n",
       "      <td>2.652537e-02</td>\n",
       "      <td>1.988050e-06</td>\n",
       "      <td>9.299938e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-CIFAR10</td>\n",
       "      <td>CNN</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>1.346172e-04</td>\n",
       "      <td>3.227004e-01</td>\n",
       "      <td>1.543318e-05</td>\n",
       "      <td>6.927857e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-CIFAR10</td>\n",
       "      <td>CNN</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>1.445778e-05</td>\n",
       "      <td>6.399512e-01</td>\n",
       "      <td>1.717826e-05</td>\n",
       "      <td>9.176365e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-CIFAR10</td>\n",
       "      <td>CNN</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>5.861737e-06</td>\n",
       "      <td>3.782186e-02</td>\n",
       "      <td>4.191547e-06</td>\n",
       "      <td>8.997839e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-CIFAR10</td>\n",
       "      <td>CNN</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>6.895551e-03</td>\n",
       "      <td>1.904097e-01</td>\n",
       "      <td>1.583224e-03</td>\n",
       "      <td>6.935081e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>CNN</td>\n",
       "      <td>fashion</td>\n",
       "      <td>9.016207e-01</td>\n",
       "      <td>6.399512e-01</td>\n",
       "      <td>4.241003e-02</td>\n",
       "      <td>4.473141e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>CNN</td>\n",
       "      <td>fashion</td>\n",
       "      <td>3.405106e-01</td>\n",
       "      <td>6.399512e-01</td>\n",
       "      <td>3.100053e-02</td>\n",
       "      <td>6.154523e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>CNN</td>\n",
       "      <td>fashion</td>\n",
       "      <td>2.649152e-01</td>\n",
       "      <td>8.773840e-01</td>\n",
       "      <td>1.571878e-02</td>\n",
       "      <td>3.460211e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>CNN</td>\n",
       "      <td>fashion</td>\n",
       "      <td>4.360552e-01</td>\n",
       "      <td>5.951581e-01</td>\n",
       "      <td>2.679142e-02</td>\n",
       "      <td>7.623858e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>CNN</td>\n",
       "      <td>fashion</td>\n",
       "      <td>6.985366e-02</td>\n",
       "      <td>9.891754e-01</td>\n",
       "      <td>3.093995e-02</td>\n",
       "      <td>2.887871e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>MLP</td>\n",
       "      <td>poly</td>\n",
       "      <td>5.183663e-06</td>\n",
       "      <td>8.556220e-09</td>\n",
       "      <td>1.784343e-03</td>\n",
       "      <td>1.355244e-05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>MLP</td>\n",
       "      <td>poly</td>\n",
       "      <td>1.508944e-07</td>\n",
       "      <td>1.992414e-11</td>\n",
       "      <td>3.329188e-04</td>\n",
       "      <td>7.338914e-07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>MLP</td>\n",
       "      <td>poly</td>\n",
       "      <td>3.683811e-04</td>\n",
       "      <td>2.540871e-09</td>\n",
       "      <td>4.434288e-02</td>\n",
       "      <td>1.296825e-06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>MLP</td>\n",
       "      <td>poly</td>\n",
       "      <td>2.361847e-07</td>\n",
       "      <td>1.463393e-15</td>\n",
       "      <td>1.663042e-03</td>\n",
       "      <td>1.142244e-08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>MLP</td>\n",
       "      <td>poly</td>\n",
       "      <td>5.388052e-10</td>\n",
       "      <td>4.312578e-15</td>\n",
       "      <td>6.668100e-04</td>\n",
       "      <td>1.164792e-06</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    is_unpruned    model_tag network  dataset  chi2_p_means   chi2_p_stds  \\\n",
       "0         False        MNIST     MLP    mnist  1.362690e-20  1.217738e-01   \n",
       "1         False        MNIST     MLP    mnist  6.569656e-07  5.160085e-02   \n",
       "2         False        MNIST     MLP    mnist  1.133536e-14  7.853003e-07   \n",
       "3         False        MNIST     MLP    mnist  1.047182e-13  1.002984e-02   \n",
       "4         False        MNIST     MLP    mnist  8.391256e-24  2.443453e-05   \n",
       "5         False      CIFAR10     MLP  cifar10  2.560037e-06  8.638070e-14   \n",
       "6         False      CIFAR10     MLP  cifar10  4.018413e-10  9.516042e-07   \n",
       "7         False      CIFAR10     MLP  cifar10  1.152081e-10  9.981256e-10   \n",
       "8         False      CIFAR10     MLP  cifar10  5.154715e-08  2.327078e-02   \n",
       "9         False      CIFAR10     MLP  cifar10  2.214766e-04  8.127249e-05   \n",
       "10        False      FASHION     MLP  fashion  8.638070e-14  1.170203e-02   \n",
       "11        False      FASHION     MLP  fashion  5.321419e-04  8.690679e-02   \n",
       "12        False      FASHION     MLP  fashion  1.253848e-10  1.328975e-04   \n",
       "13        False      FASHION     MLP  fashion  7.452428e-10  2.960862e-06   \n",
       "14        False      FASHION     MLP  fashion  3.357147e-17  2.971277e-02   \n",
       "15         True    CNN-MNIST     CNN    mnist  2.049312e-01  9.096757e-01   \n",
       "16         True    CNN-MNIST     CNN    mnist  2.826985e-01  4.270683e-01   \n",
       "17         True    CNN-MNIST     CNN    mnist  1.229532e-01  9.485461e-01   \n",
       "18         True    CNN-MNIST     CNN    mnist  4.827084e-03  1.375269e-02   \n",
       "19         True    CNN-MNIST     CNN    mnist  1.150873e-01  3.632654e-01   \n",
       "20         True  CNN-CIFAR10     CNN  cifar10  1.989727e-07  2.652537e-02   \n",
       "21         True  CNN-CIFAR10     CNN  cifar10  1.346172e-04  3.227004e-01   \n",
       "22         True  CNN-CIFAR10     CNN  cifar10  1.445778e-05  6.399512e-01   \n",
       "23         True  CNN-CIFAR10     CNN  cifar10  5.861737e-06  3.782186e-02   \n",
       "24         True  CNN-CIFAR10     CNN  cifar10  6.895551e-03  1.904097e-01   \n",
       "25         True  CNN-FASHION     CNN  fashion  9.016207e-01  6.399512e-01   \n",
       "26         True  CNN-FASHION     CNN  fashion  3.405106e-01  6.399512e-01   \n",
       "27         True  CNN-FASHION     CNN  fashion  2.649152e-01  8.773840e-01   \n",
       "28         True  CNN-FASHION     CNN  fashion  4.360552e-01  5.951581e-01   \n",
       "29         True  CNN-FASHION     CNN  fashion  6.985366e-02  9.891754e-01   \n",
       "30        False         POLY     MLP     poly  5.183663e-06  8.556220e-09   \n",
       "31        False         POLY     MLP     poly  1.508944e-07  1.992414e-11   \n",
       "32        False         POLY     MLP     poly  3.683811e-04  2.540871e-09   \n",
       "33        False         POLY     MLP     poly  2.361847e-07  1.463393e-15   \n",
       "34        False         POLY     MLP     poly  5.388052e-10  4.312578e-15   \n",
       "\n",
       "    combined_p_means  combined_p_stds  \n",
       "0       3.436946e-09     8.425660e-01  \n",
       "1       1.241211e-05     8.630156e-01  \n",
       "2       3.308756e-07     5.022651e-01  \n",
       "3       3.284628e-09     9.939614e-01  \n",
       "4       9.473283e-08     7.941233e-01  \n",
       "5       7.879100e-04     2.156128e-04  \n",
       "6       5.360029e-05     6.479042e-01  \n",
       "7       3.392675e-05     2.438592e-01  \n",
       "8       1.441507e-04     9.100498e-02  \n",
       "9       5.106075e-03     1.424048e-02  \n",
       "10      4.189157e-07     4.430452e-01  \n",
       "11      1.137731e-03     7.887562e-01  \n",
       "12      4.913787e-06     5.829153e-01  \n",
       "13      1.705334e-04     2.504779e-01  \n",
       "14      7.282496e-06     7.988556e-01  \n",
       "15      7.887406e-03     3.973505e-01  \n",
       "16      1.291942e-01     4.462957e-02  \n",
       "17      9.080628e-03     9.438898e-01  \n",
       "18      1.939270e-02     4.702843e-01  \n",
       "19      2.580437e-03     4.386698e-01  \n",
       "20      1.988050e-06     9.299938e-01  \n",
       "21      1.543318e-05     6.927857e-01  \n",
       "22      1.717826e-05     9.176365e-01  \n",
       "23      4.191547e-06     8.997839e-01  \n",
       "24      1.583224e-03     6.935081e-01  \n",
       "25      4.241003e-02     4.473141e-01  \n",
       "26      3.100053e-02     6.154523e-01  \n",
       "27      1.571878e-02     3.460211e-01  \n",
       "28      2.679142e-02     7.623858e-01  \n",
       "29      3.093995e-02     2.887871e-01  \n",
       "30      1.784343e-03     1.355244e-05  \n",
       "31      3.329188e-04     7.338914e-07  \n",
       "32      4.434288e-02     1.296825e-06  \n",
       "33      1.663042e-03     1.142244e-08  \n",
       "34      6.668100e-04     1.164792e-06  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Lesion experiments for every model tag x {use_activations} x {do_local}.\n",
    "# Each combination pickles its raw per-network results into results_dir and\n",
    "# contributes one summary row; the summary table is written to savepath.\n",
    "savepath = '../results/lesion_results_cnn_vgg.csv'\n",
    "\n",
    "# This cell can run for many hours, so create both output directories before\n",
    "# any experiment starts instead of failing at the first write.\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "os.makedirs(os.path.dirname(savepath), exist_ok=True)\n",
    "\n",
    "all_results = []\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    # Every warning raised while the experiments run is silenced.\n",
    "    warnings.simplefilter('ignore')\n",
    "    for tag in tqdm(models):\n",
    "        # The network list depends only on the tag, so it is looked up once\n",
    "        # per model rather than once per (use_activations, do_local) pair.\n",
    "        paths = get_model_path(tag, filter_='all')[-n_reps:]\n",
    "        # The slice may return fewer than n_reps entries; the summary below\n",
    "        # must use the number of networks that were actually averaged.\n",
    "        n_paths = len(paths)\n",
    "        if n_paths == 0:\n",
    "            raise ValueError(f'no trained networks found for model tag {tag!r}')\n",
    "\n",
    "        for use_activations in [False, True]:\n",
    "            for do_local in [False, True]:\n",
    "                # Raw results of every network, pickled below.\n",
    "                net_pkl_results = []\n",
    "\n",
    "                # Per-network statistics returned by do_lesion_hypo_tests.\n",
    "                # The chi2 lists are collected but not part of the summary row.\n",
    "                fisher_p_means, chi2_p_means, effect_means = [], [], []\n",
    "                fisher_p_ranges, chi2_p_ranges, effect_ranges = [], [], []\n",
    "\n",
    "                for path in paths:\n",
    "                    # NOTE(review): the '.' prefix presumably makes the DATA_PATHS\n",
    "                    # entry relative to the notebooks directory - confirm.\n",
    "                    (true_results,\n",
    "                     all_random_results,\n",
    "                     metadata,\n",
    "                     evaluation) = perform_lesion_experiment(\n",
    "                        '.' + DATA_PATHS[dataset_name], path,\n",
    "                        n_clusters=n_clust,\n",
    "                        n_shuffles=n_shuffles,\n",
    "                        unpruned=is_unpruned,\n",
    "                        depth=3,\n",
    "                        n_side=32,\n",
    "                        activations=use_activations,\n",
    "                        local=do_local)\n",
    "                    net_pkl_results.append({'true_results': true_results,\n",
    "                                            'all_random_results': all_random_results,\n",
    "                                            'metadata': metadata,\n",
    "                                            'evaluation': evaluation})\n",
    "                    hypo_results = do_lesion_hypo_tests(evaluation, true_results, all_random_results)\n",
    "\n",
    "                    fisher_p_means.append(hypo_results['fisher_p_means'])\n",
    "                    chi2_p_means.append(hypo_results['chi2_p_means'])\n",
    "                    effect_means.append(hypo_results['effect_factors_means'])\n",
    "                    fisher_p_ranges.append(hypo_results['fisher_p_ranges'])\n",
    "                    chi2_p_ranges.append(hypo_results['chi2_p_ranges'])\n",
    "                    effect_ranges.append(hypo_results['effect_factors_range'])\n",
    "                print(tag, use_activations, do_local)\n",
    "\n",
    "                # os.path.join avoids the doubled separator that plain string\n",
    "                # concatenation produced when results_dir ends with '/'.\n",
    "                pkl_name = f'lesion_data_{tag}_activations={use_activations}_local={do_local}.pkl'\n",
    "                with open(os.path.join(results_dir, pkl_name), 'wb') as f:\n",
    "                    pickle.dump(net_pkl_results, f)\n",
    "\n",
    "                # Effect factors are pooled over all networks before averaging.\n",
    "                # NOTE(review): the factor of 2 is kept from the original analysis;\n",
    "                # its rationale is not visible in this notebook.\n",
    "                pooled_effect_means = np.concatenate(effect_means) * 2\n",
    "                pooled_effect_ranges = np.concatenate(effect_ranges) * 2\n",
    "\n",
    "                # NOTE(review): bates_quantile presumably evaluates the mean p-value\n",
    "                # against the distribution of a mean of n uniforms - confirm in\n",
    "                # src.utils. n is the number of networks actually averaged.\n",
    "                model_results = {'is_unpruned': is_unpruned,\n",
    "                                 'model_tag': tag,\n",
    "                                 'activations': use_activations,\n",
    "                                 'local': do_local,\n",
    "                                 'fisher_p_means': bates_quantile(np.mean(np.array(fisher_p_means)), n_paths),\n",
    "                                 'effect_means': np.mean(pooled_effect_means),\n",
    "                                 'effect_means_sem': sem(pooled_effect_means, axis=None),\n",
    "                                 'fisher_p_ranges': bates_quantile(np.mean(np.array(fisher_p_ranges)), n_paths),\n",
    "                                 'effect_ranges': np.mean(pooled_effect_ranges),\n",
    "                                 'effect_ranges_sem': sem(pooled_effect_ranges, axis=None)}\n",
    "                all_results.append(pd.Series(model_results))\n",
    "\n",
    "result_df = pd.DataFrame(all_results)\n",
    "result_df.to_csv(savepath)\n",
    "result_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 78799.116494,
   "end_time": "2020-09-04T01:55:15.612084",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/lesion_tables.ipynb",
   "output_path": "./notebooks/lesion_tables.ipynb",
   "parameters": {},
   "start_time": "2020-09-03T04:01:56.495590",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}