{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-04T17:53:34.695019Z",
     "iopub.status.busy": "2021-07-04T17:53:34.694591Z",
     "iopub.status.idle": "2021-07-04T17:53:43.582847Z",
     "shell.execute_reply": "2021-07-04T17:53:43.582358Z"
    },
    "papermill": {
     "duration": 8.895851,
     "end_time": "2021-07-04T17:53:43.582999",
     "exception": false,
     "start_time": "2021-07-04T17:53:34.687148",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'\n",
    "from src.lucid import evaluate_visualizations, evaluate_imagenet_visualizations\n",
    "from src.visualization import run_spectral_cluster, run_activations_cluster\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import benjamini_hochberg, bates_quantile\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import warnings\n",
    "from scipy.stats import sem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-04T17:53:43.593475Z",
     "iopub.status.busy": "2021-07-04T17:53:43.593035Z",
     "iopub.status.idle": "2021-07-04T17:53:43.595024Z",
     "shell.execute_reply": "2021-07-04T17:53:43.594670Z"
    },
    "papermill": {
     "duration": 0.008564,
     "end_time": "2021-07-04T17:53:43.595126",
     "exception": false,
     "start_time": "2021-07-04T17:53:43.586562",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_tags = ('MNIST+LUCID',  # 'MNIST+DROPOUT+LUCID',\n",
    "              'CNN-MNIST+LUCID',  # 'CNN-MNIST+DROPOUT+LUCID',\n",
    "              'CNN-VGG-CIFAR10', 'CNN-VGG-CIFAR10+DROPOUT+L2REG',)\n",
    "              # 'CNN-VGG-UNTRAINED')\n",
    "imagenet_nets = ['vgg16', 'vgg19']\n",
    "n_reps = 5\n",
    "all_results = []\n",
    "n_clust = 16\n",
    "fisher_pvalues = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-04T17:53:43.611472Z",
     "iopub.status.busy": "2021-07-04T17:53:43.609800Z",
     "iopub.status.idle": "2021-07-04T23:00:32.412417Z",
     "shell.execute_reply": "2021-07-04T23:00:32.412791Z"
    },
    "papermill": {
     "duration": 18408.814585,
     "end_time": "2021-07-04T23:00:32.412935",
     "exception": false,
     "start_time": "2021-07-04T17:53:43.598350",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID False True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST+LUCID True True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID False True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True True 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/init_ops.py:97: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-MNIST+LUCID True True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 False True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10 True True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG False True 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True False 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True False 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True False 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True False 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True False 4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True True 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True True 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True True 2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True True 3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN-VGG-CIFAR10+DROPOUT+L2REG True True 4\n"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    for tag in model_tags:\n",
    "        for use_activations in [False, True]:\n",
    "            for local in [False, True]:\n",
    "\n",
    "                model_dirs = get_model_path(tag, filter_='all')[-n_reps:]\n",
    "                is_unpruned = True\n",
    "                if 'CNN-VGG' in tag:\n",
    "                    network = 'CNN-VGG'\n",
    "                elif 'CNN' in tag:\n",
    "                    network = 'CNN'\n",
    "                else:\n",
    "                    network = 'MLP'\n",
    "\n",
    "                fisher_ps_score, chi2_ps_score, effect_factor_score = [], [], []\n",
    "                fisher_ps_entropy, chi2_ps_entropy, effect_factor_entropy = [], [], []\n",
    "                fisher_ps_dispersion, chi2_ps_dispersion, effect_factor_dispersion = [], [], []\n",
    "\n",
    "                for i in range(n_reps):\n",
    "                    # if not is_unpruned:\n",
    "                    #     weight_path = str(next(model_dirs[i].glob('*-pruned-weights.pckl')))\n",
    "                    #     activations_path = str(next(model_dirs[i].glob('*-unpruned-activations.pckl')))\n",
    "                    #     activations_mask_path = str(next(model_dirs[i].glob('*-unpruned-activations_mask.pckl')))\n",
    "                    # else:\n",
    "                    #     weight_path = str(next(model_dirs[i].glob('*-unpruned-weights.pckl')))\n",
    "                    #     activations_path = str(next(model_dirs[i].glob('*-unpruned-activations.pckl')))\n",
    "                    #     activations_mask_path = str(next(model_dirs[i].glob('*-unpruned-activations_mask.pckl')))\n",
    "\n",
    "                    # if use_activations:\n",
    "                    #     labels, _ = run_activations_cluster(activations_path, activations_mask_path, n_clusters=n_clust)\n",
    "                    # else:\n",
    "                    #     labels, _ = run_spectral_cluster(weight_path, n_clusters=n_clust, with_shuffle=False)\n",
    "\n",
    "                    # tag_sfx = '_activations' if use_activations else ''\n",
    "                    tag_sfx = f'_acts={use_activations}_local={local}_'\n",
    "                    results = evaluate_visualizations(tag, tag_sfx, i, is_unpruned)\n",
    "                    print(tag, use_activations, local, i)\n",
    "\n",
    "                    fisher_ps = results['fisher_ps']\n",
    "                    chi2_ps = results['chi2_ps']\n",
    "                    effect_factors = results['effect_factors']  # mean of the mean of random results / true result\n",
    "\n",
    "                    fisher_ps_score.append(fisher_ps[1])\n",
    "                    fisher_ps_entropy.append(fisher_ps[0])\n",
    "                    fisher_ps_dispersion.append(fisher_ps[2])\n",
    "                    chi2_ps_score.append(chi2_ps[1])\n",
    "                    chi2_ps_entropy.append(chi2_ps[0])\n",
    "                    chi2_ps_dispersion.append(chi2_ps[2])\n",
    "                    effect_factor_score.append(effect_factors[1])\n",
    "                    effect_factor_entropy.append(effect_factors[0])\n",
    "                    effect_factor_dispersion.append(effect_factors[2])\n",
    "\n",
    "                model_results = {'model_tag': tag,\n",
    "                                 'network': network,\n",
    "                                 'activations': use_activations,\n",
    "                                 'local': local,\n",
    "                                 'is_unpruned': is_unpruned,\n",
    "                                 'fisher_ps_score': bates_quantile(np.mean(np.array(fisher_ps_score)), n_reps),\n",
    "                                 # 'chi2_ps_score': bates_quantile(np.mean(np.array(chi2_ps_score)), n_reps),\n",
    "                                 'effect_factor_score': np.mean(np.concatenate(effect_factor_score))*2,\n",
    "                                 'effect_factor_score_sem': sem(np.concatenate(effect_factor_score).flatten()*2),\n",
    "                                 'fisher_ps_entropy': bates_quantile(np.mean(np.array(fisher_ps_entropy)), n_reps),\n",
    "                                 # 'chi2_ps_entropy': bates_quantile(np.mean(np.array(chi2_ps_entropy)), n_reps),\n",
    "                                 'effect_factor_entropy': np.mean(np.concatenate(effect_factor_entropy))*2,\n",
    "                                 'effect_factor_entropy_sem': sem(np.concatenate(effect_factor_entropy).flatten()*2),}\n",
    "                                 # 'fisher_ps_dispersion': bates_quantile(np.mean(np.array(fisher_ps_dispersion)), n_reps),\n",
    "                                 # 'chi2_ps_dispersion': bates_quantile(np.mean(np.array(chi2_ps_dispersion)), n_reps),\n",
    "                                 # 'effect_factor_dispersion': np.mean(np.array(effect_factor_dispersion)),\n",
    "                                 # 'cov_quartiles': results['cov_quartiles']}\n",
    "                fisher_pvalues.append(model_results['fisher_ps_score'])\n",
    "                fisher_pvalues.append(model_results['fisher_ps_entropy'])\n",
    "                all_results.append(pd.Series(model_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-04T23:00:32.487712Z",
     "iopub.status.busy": "2021-07-04T23:00:32.487253Z",
     "iopub.status.idle": "2021-07-05T01:36:26.032151Z",
     "shell.execute_reply": "2021-07-05T01:36:26.032488Z"
    },
    "papermill": {
     "duration": 9353.601467,
     "end_time": "2021-07-05T01:36:26.032619",
     "exception": false,
     "start_time": "2021-07-04T23:00:32.431152",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg16 False False\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg16 False True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg16 True False\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg16 True True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg19 False False\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg19 False True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg19 True False\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg19 True True\n"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    for net in imagenet_nets:\n",
    "        for use_activations in [False, True]:\n",
    "            for local in [False, True]:\n",
    "\n",
    "                results = evaluate_imagenet_visualizations(net, use_activations, local=local)\n",
    "                percentiles = results['percentiles']\n",
    "                chi2_ps = results['chi2_ps']\n",
    "                fisher_ps = results['fisher_ps']\n",
    "                effect_factors = results['effect_factors']  # mean of the mean of random results / true result\n",
    "                model_results = {'is_unpruned': True, 'model_tag': net,\n",
    "                                 'activations': use_activations,\n",
    "                                 'local': local,\n",
    "                                 'fisher_ps_score': fisher_ps[1],\n",
    "                                 # 'chi2_ps_score': chi2_ps[1],\n",
    "                                 'effect_factor_score': np.mean(effect_factors[1]*2),\n",
    "                                 'effect_factor_score_sem': sem(effect_factors[1]*2, axis=None),\n",
    "                                 'fisher_ps_entropy': fisher_ps[0],\n",
    "                                 # 'chi2_ps_entropy': chi2_ps[0],\n",
    "                                 'effect_factor_entropy': np.mean(effect_factors[0]*2),\n",
    "                                 'effect_factor_entropy_sem': sem(effect_factors[0]*2, axis=None),}\n",
    "                                 # 'fisher_ps_dispersion': fisher_ps[2],\n",
    "                                 # 'chi2_ps_dispersion': chi2_ps[2],\n",
    "                                 # 'effect_factor_dispersion': effect_factors[2],\n",
    "                                 # 'cov_quartiles': results['cov_quartiles']}\n",
    "                all_results.append(pd.Series(model_results))\n",
    "                fisher_pvalues.append(model_results['fisher_ps_score'])\n",
    "                fisher_pvalues.append(model_results['fisher_ps_entropy'])\n",
    "                print(net, use_activations, local)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-05T01:36:26.086239Z",
     "iopub.status.busy": "2021-07-05T01:36:26.085742Z",
     "iopub.status.idle": "2021-07-05T01:36:26.093667Z",
     "shell.execute_reply": "2021-07-05T01:36:26.093323Z"
    },
    "papermill": {
     "duration": 0.041513,
     "end_time": "2021-07-05T01:36:26.093760",
     "exception": false,
     "start_time": "2021-07-05T01:36:26.052247",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cricial_pvalue: 0.024620756852149532\n"
     ]
    }
   ],
   "source": [
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/lucid_results_all.csv'\n",
    "result_df.to_csv(savepath)\n",
    "result_df\n",
    "\n",
    "critical_p = benjamini_hochberg(fisher_pvalues)\n",
    "print(f'cricial_pvalue: {critical_p}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 27778.697113,
   "end_time": "2021-07-05T01:36:32.569771",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/lucid_results_all.ipynb",
   "output_path": "./notebooks/lucid_results_all.ipynb",
   "parameters": {},
   "start_time": "2021-07-04T17:53:33.872658",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}