{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2020-09-26T21:11:46.532901Z",
     "iopub.status.busy": "2020-09-26T21:11:46.531461Z",
     "iopub.status.idle": "2020-09-26T21:11:51.199572Z",
     "shell.execute_reply": "2020-09-26T21:11:51.197948Z"
    },
    "papermill": {
     "duration": 4.691383,
     "end_time": "2020-09-26T21:11:51.200081",
     "exception": false,
     "start_time": "2020-09-26T21:11:46.508698",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "from scipy.stats import beta\n",
    "import warnings\n",
    "from src.lesion import perform_lesion_experiment, do_lesion_hypo_tests\n",
    "from src.pointers import DATA_PATHS\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import bates_quantile\n",
    "from scipy.stats import sem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-26T21:11:51.245147Z",
     "iopub.status.busy": "2020-09-26T21:11:51.243529Z",
     "iopub.status.idle": "2020-09-26T21:11:51.248443Z",
     "shell.execute_reply": "2020-09-26T21:11:51.247070Z"
    },
    "papermill": {
     "duration": 0.023674,
     "end_time": "2020-09-26T21:11:51.248758",
     "exception": false,
     "start_time": "2020-09-26T21:11:51.225084",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Experiment configuration for the CNN lesion tests.\n",
    "\n",
    "# (dataset key, model tag) pairs to evaluate. The dataset key indexes\n",
    "# DATA_PATHS and the tag is handed to get_model_path in the experiment cell.\n",
    "models = [\n",
    "    ('mnist', 'CNN-MNIST'),\n",
    "    # Currently disabled configurations:\n",
    "    # ('fashion', 'CNN-FASHION'),\n",
    "    # ('cifar10', 'CNN-CIFAR10'),\n",
    "    # ('mnist', 'CNN-MNIST+DROPOUT'), ('cifar10', 'CNN-CIFAR10+DROPOUT'),\n",
    "    # ('fashion', 'CNN-FASHION+DROPOUT')\n",
    "]\n",
    "\n",
    "# At most this many model paths per tag are used (the last n_reps returned).\n",
    "n_reps = 5\n",
    "# Forwarded to perform_lesion_experiment as n_clusters / n_shuffles.\n",
    "n_clust = 16\n",
    "n_shuffles = 19\n",
    "# NOTE(review): n_workers is not referenced by the cells visible in this notebook.\n",
    "n_workers = 5\n",
    "# Forwarded as unpruned= and recorded in the summary table.\n",
    "is_unpruned = True\n",
    "# Directory that receives the per-condition pickles of raw lesion results.\n",
    "results_dir = '/project/nn_clustering/results/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-26T21:11:51.282434Z",
     "iopub.status.busy": "2020-09-26T21:11:51.281493Z",
     "iopub.status.idle": "2020-09-26T22:58:04.687448Z",
     "shell.execute_reply": "2020-09-26T22:58:04.688778Z"
    },
    "papermill": {
     "duration": 6373.432564,
     "end_time": "2020-09-26T22:58:04.689198",
     "exception": false,
     "start_time": "2020-09-26T21:11:51.256634",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 50%|█████     | 1/2 [50:37<50:37, 3037.83s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 2/2 [1:46:13<00:00, 3127.14s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 2/2 [1:46:13<00:00, 3186.68s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_unpruned</th>\n",
       "      <th>model_tag</th>\n",
       "      <th>dataset</th>\n",
       "      <th>chi2_p_means</th>\n",
       "      <th>chi2_p_ranges</th>\n",
       "      <th>combined_p_means</th>\n",
       "      <th>effect_ranges</th>\n",
       "      <th>effect_means</th>\n",
       "      <th>combined_p_ranges</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>0.002164</td>\n",
       "      <td>0.892888</td>\n",
       "      <td>0.017104</td>\n",
       "      <td>5.508326e+00</td>\n",
       "      <td>1.000243</td>\n",
       "      <td>0.707797</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>0.264915</td>\n",
       "      <td>0.322700</td>\n",
       "      <td>0.036717</td>\n",
       "      <td>4.033645e+00</td>\n",
       "      <td>1.000161</td>\n",
       "      <td>0.479884</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>0.036334</td>\n",
       "      <td>0.297542</td>\n",
       "      <td>0.013149</td>\n",
       "      <td>2.412596e+00</td>\n",
       "      <td>1.000218</td>\n",
       "      <td>0.094653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>0.026125</td>\n",
       "      <td>0.561076</td>\n",
       "      <td>0.000428</td>\n",
       "      <td>6.390199e+00</td>\n",
       "      <td>1.000393</td>\n",
       "      <td>0.986759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST</td>\n",
       "      <td>mnist</td>\n",
       "      <td>0.035189</td>\n",
       "      <td>0.877384</td>\n",
       "      <td>0.009793</td>\n",
       "      <td>5.987729e+00</td>\n",
       "      <td>1.000299</td>\n",
       "      <td>0.305394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>0.483188</td>\n",
       "      <td>0.787735</td>\n",
       "      <td>0.041504</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.001231</td>\n",
       "      <td>0.737606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>0.000030</td>\n",
       "      <td>0.282699</td>\n",
       "      <td>0.000167</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.001120</td>\n",
       "      <td>0.808584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>0.075078</td>\n",
       "      <td>0.223803</td>\n",
       "      <td>0.025246</td>\n",
       "      <td>1.676627e+11</td>\n",
       "      <td>1.001132</td>\n",
       "      <td>0.934899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>0.595642</td>\n",
       "      <td>0.282699</td>\n",
       "      <td>0.011908</td>\n",
       "      <td>4.411689e+11</td>\n",
       "      <td>1.001306</td>\n",
       "      <td>0.580515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-FASHION</td>\n",
       "      <td>fashion</td>\n",
       "      <td>0.752382</td>\n",
       "      <td>0.879894</td>\n",
       "      <td>0.041674</td>\n",
       "      <td>inf</td>\n",
       "      <td>1.001280</td>\n",
       "      <td>0.304067</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   is_unpruned    model_tag  dataset  chi2_p_means  chi2_p_ranges  \\\n",
       "0         True    CNN-MNIST    mnist      0.002164       0.892888   \n",
       "1         True    CNN-MNIST    mnist      0.264915       0.322700   \n",
       "2         True    CNN-MNIST    mnist      0.036334       0.297542   \n",
       "3         True    CNN-MNIST    mnist      0.026125       0.561076   \n",
       "4         True    CNN-MNIST    mnist      0.035189       0.877384   \n",
       "5         True  CNN-FASHION  fashion      0.483188       0.787735   \n",
       "6         True  CNN-FASHION  fashion      0.000030       0.282699   \n",
       "7         True  CNN-FASHION  fashion      0.075078       0.223803   \n",
       "8         True  CNN-FASHION  fashion      0.595642       0.282699   \n",
       "9         True  CNN-FASHION  fashion      0.752382       0.879894   \n",
       "\n",
       "   combined_p_means  effect_ranges  effect_means  combined_p_ranges  \n",
       "0          0.017104   5.508326e+00      1.000243           0.707797  \n",
       "1          0.036717   4.033645e+00      1.000161           0.479884  \n",
       "2          0.013149   2.412596e+00      1.000218           0.094653  \n",
       "3          0.000428   6.390199e+00      1.000393           0.986759  \n",
       "4          0.009793   5.987729e+00      1.000299           0.305394  \n",
       "5          0.041504            inf      1.001231           0.737606  \n",
       "6          0.000167            inf      1.001120           0.808584  \n",
       "7          0.025246   1.676627e+11      1.001132           0.934899  \n",
       "8          0.011908   4.411689e+11      1.001306           0.580515  \n",
       "9          0.041674            inf      1.001280           0.304067  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the lesion experiment for every model and every combination of\n",
    "# (activations, local), pickle the raw per-network results, and collect one\n",
    "# summary row per condition into result_df, which is also written to CSV.\n",
    "all_results = []  # one summary dict per (model, activations, local) condition\n",
    "\n",
    "# NOTE(review): every warning raised inside this block is silenced, which can\n",
    "# hide numerical problems in the statistics - consider narrowing the filter.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    for dataset_name, tag in tqdm(models):\n",
    "        # The list of trained networks does not depend on the activations/local\n",
    "        # settings, so it is looked up once per model.\n",
    "        paths = get_model_path(tag, filter_='all')[-n_reps:]\n",
    "        n_networks = len(paths)\n",
    "        if n_networks == 0:\n",
    "            # Fail early with a clear message instead of writing an empty pickle\n",
    "            # and crashing later inside np.concatenate.\n",
    "            raise ValueError(f'no trained networks found for model tag {tag!r}')\n",
    "\n",
    "        for use_activations in [False, True]:\n",
    "            for do_local in [False, True]:\n",
    "                net_pkl_results = []\n",
    "                fisher_p_means, fisher_p_ranges = [], []\n",
    "                effect_means, effect_ranges = [], []\n",
    "\n",
    "                for path in paths:\n",
    "                    (true_results,\n",
    "                     all_random_results,\n",
    "                     metadata,\n",
    "                     evaluation) = perform_lesion_experiment(\n",
    "                         '.' + DATA_PATHS[dataset_name], path,\n",
    "                         n_clusters=n_clust,\n",
    "                         n_shuffles=n_shuffles,\n",
    "                         unpruned=is_unpruned,\n",
    "                         activations=use_activations,\n",
    "                         local=do_local)\n",
    "                    # Keep the raw inputs of the hypothesis tests so that further\n",
    "                    # statistics can be recomputed later without rerunning the lesions.\n",
    "                    net_pkl_results.append({'true_results': true_results,\n",
    "                                            'all_random_results': all_random_results,\n",
    "                                            'metadata': metadata,\n",
    "                                            'evaluation': evaluation})\n",
    "                    hypo_results = do_lesion_hypo_tests(evaluation, true_results,\n",
    "                                                        all_random_results)\n",
    "\n",
    "                    fisher_p_means.append(hypo_results['fisher_p_means'])\n",
    "                    fisher_p_ranges.append(hypo_results['fisher_p_ranges'])\n",
    "                    effect_means.append(hypo_results['effect_factors_means'])\n",
    "                    effect_ranges.append(hypo_results['effect_factors_range'])\n",
    "\n",
    "                pkl_name = f'lesion_data_{tag}_activations={use_activations}_local={do_local}.pkl'\n",
    "                with open(os.path.join(results_dir, pkl_name), 'wb') as f:\n",
    "                    pickle.dump(net_pkl_results, f)\n",
    "\n",
    "                effect_means_all = np.concatenate(effect_means)\n",
    "                effect_ranges_all = np.concatenate(effect_ranges)\n",
    "\n",
    "                # The p-values are averaged over the networks actually evaluated, so\n",
    "                # that count (not the configured n_reps) is passed to bates_quantile.\n",
    "                # NOTE(review): assumes the second argument of bates_quantile is the\n",
    "                # number of averaged values - confirm against src.utils.\n",
    "                # NOTE(review): the factor 2 on the effect sizes is carried over\n",
    "                # unchanged; its rationale is not visible here - TODO document.\n",
    "                all_results.append({\n",
    "                    'is_unpruned': is_unpruned,\n",
    "                    'model_tag': tag,\n",
    "                    'activations': use_activations,\n",
    "                    'local': do_local,\n",
    "                    'fisher_p_means': bates_quantile(np.mean(fisher_p_means), n_networks),\n",
    "                    'effect_means': np.mean(effect_means_all)*2,\n",
    "                    'effect_means_sem': sem(effect_means_all*2, axis=None),\n",
    "                    'fisher_p_ranges': bates_quantile(np.mean(fisher_p_ranges), n_networks),\n",
    "                    'effect_ranges': np.mean(effect_ranges_all)*2,\n",
    "                    'effect_ranges_sem': sem(effect_ranges_all*2, axis=None),\n",
    "                })\n",
    "\n",
    "# Build the summary table once from the collected rows and persist it.\n",
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/lesion_results_cnn.csv'\n",
    "result_df.to_csv(savepath)\n",
    "result_df\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 6382.26112,
   "end_time": "2020-09-26T22:58:07.706427",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/lesion_results_cnn.ipynb",
   "output_path": "./notebooks/lesion_results_cnn.ipynb",
   "parameters": {},
   "start_time": "2020-09-26T21:11:45.445307",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}