{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2020-09-26T21:09:21.130953Z",
     "iopub.status.busy": "2020-09-26T21:09:21.129574Z",
     "iopub.status.idle": "2020-09-26T21:09:25.691696Z",
     "shell.execute_reply": "2020-09-26T21:09:25.690136Z"
    },
    "papermill": {
     "duration": 4.585278,
     "end_time": "2020-09-26T21:09:25.692159",
     "exception": false,
     "start_time": "2020-09-26T21:09:21.106881",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "sys.path.append('..')\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "from scipy.stats import beta\n",
    "import warnings\n",
    "from src.lesion import perform_lesion_experiment, do_lesion_hypo_tests\n",
    "from src.pointers import DATA_PATHS\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import bates_quantile\n",
    "from scipy.stats import sem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "models = [('mnist', 'MNIST'),\n",
    "          # ('fashion', 'FASHION'),\n",
    "          # ('poly', 'POLY'),\n",
    "          # ('cifar10', 'CIFAR10'),\n",
    "          # ('mnist', 'MNIST+DROPOUT'), ('cifar10', 'CIFAR10+DROPOUT'),\n",
    "          # ('fashion', 'FASHION+DROPOUT'),\n",
    "          ]\n",
    "\n",
    "n_clust = 16\n",
    "n_shuffles = 19\n",
    "n_workers = 5\n",
    "n_reps = 5\n",
    "is_unpruned = True\n",
    "results_dir = '/project/nn_clustering/results/'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-26T21:09:25.753190Z",
     "iopub.status.busy": "2020-09-26T21:09:25.752000Z",
     "iopub.status.idle": "2020-09-26T22:45:30.394813Z",
     "shell.execute_reply": "2020-09-26T22:45:30.396186Z"
    },
    "papermill": {
     "duration": 5764.669788,
     "end_time": "2020-09-26T22:45:30.396747",
     "exception": false,
     "start_time": "2020-09-26T21:09:25.726959",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/3 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 33%|███▎      | 1/3 [30:17<1:00:35, 1817.97s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 67%|██████▋   | 2/3 [1:04:01<31:19, 1879.75s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 3/3 [1:36:04<00:00, 1892.64s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 3/3 [1:36:04<00:00, 1921.53s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_unpruned</th>\n",
       "      <th>model_tag</th>\n",
       "      <th>dataset</th>\n",
       "      <th>chi2_p_means</th>\n",
       "      <th>chi2_p_ranges</th>\n",
       "      <th>combined_p_means</th>\n",
       "      <th>effect_ranges</th>\n",
       "      <th>effect_means</th>\n",
       "      <th>combined_p_ranges</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>2.616959e-12</td>\n",
       "      <td>1.972133e-01</td>\n",
       "      <td>3.723754e-07</td>\n",
       "      <td>5.141354e+00</td>\n",
       "      <td>1.054148</td>\n",
       "      <td>7.724322e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>4.397937e-14</td>\n",
       "      <td>3.850862e-04</td>\n",
       "      <td>1.341372e-08</td>\n",
       "      <td>9.705171e+00</td>\n",
       "      <td>1.033126</td>\n",
       "      <td>9.846316e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.467892e-16</td>\n",
       "      <td>4.110047e-03</td>\n",
       "      <td>1.194488e-08</td>\n",
       "      <td>1.009853e+01</td>\n",
       "      <td>1.024116</td>\n",
       "      <td>9.677028e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>8.338514e-13</td>\n",
       "      <td>8.775936e-02</td>\n",
       "      <td>5.802777e-08</td>\n",
       "      <td>3.035395e+00</td>\n",
       "      <td>1.023603</td>\n",
       "      <td>5.111085e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>1.985740e-17</td>\n",
       "      <td>3.470003e-02</td>\n",
       "      <td>2.653024e-09</td>\n",
       "      <td>1.107928e+01</td>\n",
       "      <td>1.025438</td>\n",
       "      <td>9.762319e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>1.759935e-11</td>\n",
       "      <td>8.556220e-09</td>\n",
       "      <td>7.730191e-06</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.202914</td>\n",
       "      <td>8.615767e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>3.306690e-11</td>\n",
       "      <td>1.504932e-03</td>\n",
       "      <td>9.302282e-05</td>\n",
       "      <td>7.300478e+11</td>\n",
       "      <td>1.090869</td>\n",
       "      <td>2.198356e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>4.910232e-22</td>\n",
       "      <td>3.408331e-02</td>\n",
       "      <td>1.627973e-07</td>\n",
       "      <td>3.978180e+11</td>\n",
       "      <td>1.147517</td>\n",
       "      <td>7.137569e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>4.660709e-18</td>\n",
       "      <td>3.776269e-02</td>\n",
       "      <td>5.270228e-06</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.126919</td>\n",
       "      <td>1.453712e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>False</td>\n",
       "      <td>FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>9.311098e-08</td>\n",
       "      <td>2.586911e-03</td>\n",
       "      <td>4.420465e-05</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.153860</td>\n",
       "      <td>5.131659e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>poly</td>\n",
       "      <td>2.910050e-07</td>\n",
       "      <td>3.804319e-06</td>\n",
       "      <td>3.740229e-02</td>\n",
       "      <td>8.105610e-01</td>\n",
       "      <td>13.022902</td>\n",
       "      <td>1.185158e-05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>poly</td>\n",
       "      <td>2.018893e-06</td>\n",
       "      <td>3.528865e-16</td>\n",
       "      <td>5.168376e-02</td>\n",
       "      <td>8.173905e-01</td>\n",
       "      <td>4.388534</td>\n",
       "      <td>1.057573e-06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>poly</td>\n",
       "      <td>7.539385e-06</td>\n",
       "      <td>9.311152e-12</td>\n",
       "      <td>2.276826e-03</td>\n",
       "      <td>8.645244e-01</td>\n",
       "      <td>7.511498</td>\n",
       "      <td>6.353361e-07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>poly</td>\n",
       "      <td>2.214766e-04</td>\n",
       "      <td>6.197747e-11</td>\n",
       "      <td>8.247474e-01</td>\n",
       "      <td>9.289072e-01</td>\n",
       "      <td>18.803897</td>\n",
       "      <td>3.259691e-05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>False</td>\n",
       "      <td>POLY</td>\n",
       "      <td>poly</td>\n",
       "      <td>7.101178e-04</td>\n",
       "      <td>1.039481e-15</td>\n",
       "      <td>3.457820e-01</td>\n",
       "      <td>1.197040e+00</td>\n",
       "      <td>7.763134</td>\n",
       "      <td>9.403570e-07</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    is_unpruned model_tag  dataset  chi2_p_means  chi2_p_ranges  \\\n",
       "0         False     MNIST    mnist  2.616959e-12   1.972133e-01   \n",
       "1         False     MNIST    mnist  4.397937e-14   3.850862e-04   \n",
       "2         False     MNIST    mnist  1.467892e-16   4.110047e-03   \n",
       "3         False     MNIST    mnist  8.338514e-13   8.775936e-02   \n",
       "4         False     MNIST    mnist  1.985740e-17   3.470003e-02   \n",
       "5         False   FASHION  fashion  1.759935e-11   8.556220e-09   \n",
       "6         False   FASHION  fashion  3.306690e-11   1.504932e-03   \n",
       "7         False   FASHION  fashion  4.910232e-22   3.408331e-02   \n",
       "8         False   FASHION  fashion  4.660709e-18   3.776269e-02   \n",
       "9         False   FASHION  fashion  9.311098e-08   2.586911e-03   \n",
       "10        False      POLY     poly  2.910050e-07   3.804319e-06   \n",
       "11        False      POLY     poly  2.018893e-06   3.528865e-16   \n",
       "12        False      POLY     poly  7.539385e-06   9.311152e-12   \n",
       "13        False      POLY     poly  2.214766e-04   6.197747e-11   \n",
       "14        False      POLY     poly  7.101178e-04   1.039481e-15   \n",
       "\n",
       "    combined_p_means  effect_ranges  effect_means  combined_p_ranges  \n",
       "0       3.723754e-07   5.141354e+00      1.054148       7.724322e-01  \n",
       "1       1.341372e-08   9.705171e+00      1.033126       9.846316e-01  \n",
       "2       1.194488e-08   1.009853e+01      1.024116       9.677028e-01  \n",
       "3       5.802777e-08   3.035395e+00      1.023603       5.111085e-01  \n",
       "4       2.653024e-09   1.107928e+01      1.025438       9.762319e-01  \n",
       "5       7.730191e-06            inf      1.202914       8.615767e-01  \n",
       "6       9.302282e-05   7.300478e+11      1.090869       2.198356e-01  \n",
       "7       1.627973e-07   3.978180e+11      1.147517       7.137569e-01  \n",
       "8       5.270228e-06            inf      1.126919       1.453712e-01  \n",
       "9       4.420465e-05            inf      1.153860       5.131659e-01  \n",
       "10      3.740229e-02   8.105610e-01     13.022902       1.185158e-05  \n",
       "11      5.168376e-02   8.173905e-01      4.388534       1.057573e-06  \n",
       "12      2.276826e-03   8.645244e-01      7.511498       6.353361e-07  \n",
       "13      8.247474e-01   9.289072e-01     18.803897       3.259691e-05  \n",
       "14      3.457820e-01   1.197040e+00      7.763134       9.403570e-07  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_results = []\n",
    "all_pkl_results = {}\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    for dataset_name, tag in tqdm(models):\n",
    "        for use_activations in [False, True]:\n",
    "            for do_local in [False, True]:\n",
    "\n",
    "                net_pkl_results = []\n",
    "                paths = get_model_path(tag, filter_='all')[-n_reps:]\n",
    "                fisher_p_means, chi2_p_means, effect_means = [], [], []\n",
    "                fisher_p_ranges, chi2_p_ranges, effect_ranges = [], [], []\n",
    "                for path in paths:\n",
    "\n",
    "                    (true_results,\n",
    "                     all_random_results,\n",
    "                     metadata,\n",
    "                     evaluation) = perform_lesion_experiment('.' + DATA_PATHS[dataset_name], path, n_clusters=n_clust,\n",
    "                                                             n_shuffles=n_shuffles, unpruned=is_unpruned,\n",
    "                                                             activations=use_activations, local=do_local)\n",
    "                    net_pkl_results.append({'true_results': true_results,\n",
    "                                            'all_random_results': all_random_results,\n",
    "                                            'metadata': metadata,\n",
    "                                            'evaluation': evaluation})\n",
    "                    hypo_results = do_lesion_hypo_tests(evaluation, true_results, all_random_results)\n",
    "\n",
    "                    fisher_p_means.append(hypo_results['fisher_p_means'])\n",
    "                    chi2_p_means.append(hypo_results['chi2_p_means'])\n",
    "                    effect_means.append(hypo_results['effect_factors_means'])\n",
    "                    fisher_p_ranges.append(hypo_results['fisher_p_ranges'])\n",
    "                    chi2_p_ranges.append(hypo_results['chi2_p_ranges'])\n",
    "                    effect_ranges.append(hypo_results['effect_factors_range'])\n",
    "                with open(results_dir + '/lesion_data_' + tag + f'_activations={use_activations}_local={do_local}.pkl', 'wb') as f:\n",
    "                    pickle.dump(net_pkl_results, f)\n",
    "\n",
    "                model_results = {'is_unpruned': is_unpruned,\n",
    "                                 'model_tag': tag,\n",
    "                                 'activations': use_activations,\n",
    "                                 'local': do_local,\n",
    "                                 'fisher_p_means': bates_quantile(np.mean(np.array(fisher_p_means)), n_reps),\n",
    "                                 # 'chi2_p_means': bates_quantile(np.mean(np.array(chi2_p_means)), n_reps),\n",
    "                                 'effect_means': np.mean(np.concatenate(effect_means))*2,\n",
    "                                 'effect_means_sem': sem(np.concatenate(effect_means)*2),\n",
    "                                 'fisher_p_ranges': bates_quantile(np.mean(np.array(fisher_p_ranges)), n_reps),\n",
    "                                 # 'chi2_p_ranges': bates_quantile(np.mean(np.array(chi2_p_ranges)), n_reps),\n",
    "                                 'effect_ranges': np.mean(np.concatenate(effect_ranges))*2,\n",
    "                                 'effect_ranges_sem': sem(np.concatenate(effect_ranges)*2),}\n",
    "                all_results.append(pd.Series(model_results))\n",
    "\n",
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/lesion_results_mlp.csv'\n",
    "result_df.to_csv(savepath)\n",
    "result_df\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 5773.022779,
   "end_time": "2020-09-26T22:45:33.109329",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/lesion_results_mlp.ipynb",
   "output_path": "./notebooks/lesion_results_mlp.ipynb",
   "parameters": {},
   "start_time": "2020-09-26T21:09:20.086550",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}