{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-27T00:31:58.568263Z",
     "iopub.status.busy": "2020-09-27T00:31:58.566891Z",
     "iopub.status.idle": "2020-09-27T00:32:03.206194Z",
     "shell.execute_reply": "2020-09-27T00:32:03.204492Z"
    },
    "papermill": {
     "duration": 4.665959,
     "end_time": "2020-09-27T00:32:03.206697",
     "exception": false,
     "start_time": "2020-09-27T00:31:58.540738",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'\n",
    "from src.lucid import evaluate_visualizations, evaluate_imagenet_visualizations\n",
    "import pandas as pd\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-27T00:32:03.235673Z",
     "iopub.status.busy": "2020-09-27T00:32:03.234995Z",
     "iopub.status.idle": "2020-09-27T00:32:03.238582Z",
     "shell.execute_reply": "2020-09-27T00:32:03.239269Z"
    },
    "papermill": {
     "duration": 0.019278,
     "end_time": "2020-09-27T00:32:03.239486",
     "exception": false,
     "start_time": "2020-09-27T00:32:03.220208",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_tags = ('MNIST+LUCID', 'MNIST+DROPOUT+LUCID', 'CNN-MNIST+LUCID', 'CNN-MNIST+DROPOUT+LUCID',\n",
    "              'CNN-VGG-CIFAR10', 'CNN-VGG-CIFAR10+DROPOUT+L2REG')\n",
    "imagenet_nets = ['vgg16', 'vgg19', 'resnet50']\n",
    "n_reps = 5\n",
    "all_results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-27T00:32:03.270198Z",
     "iopub.status.busy": "2020-09-27T00:32:03.269344Z",
     "iopub.status.idle": "2020-09-27T00:37:18.175959Z",
     "shell.execute_reply": "2020-09-27T00:37:18.177237Z"
    },
    "papermill": {
     "duration": 314.930513,
     "end_time": "2020-09-27T00:37:18.177683",
     "exception": false,
     "start_time": "2020-09-27T00:32:03.247170",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:205: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:205: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py:240: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Deprecated in favor of operator or tf.math.divide.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py:240: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Deprecated in favor of operator or tf.math.divide.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "\n",
    "    for tag in model_tags:\n",
    "\n",
    "        is_unpruned = 'CNN' in tag\n",
    "        if 'CNN-VGG' in tag:\n",
    "            network = 'CNN-VGG'\n",
    "        elif 'CNN' in tag:\n",
    "            network = 'CNN'\n",
    "        else:\n",
    "            network = 'MLP'\n",
    "\n",
    "        for i in range(n_reps):\n",
    "\n",
    "            results = evaluate_visualizations(tag, i, is_unpruned)\n",
    "            chi2_ps = results['chi2_ps']\n",
    "            combined_ps = results['combined_ps']\n",
    "            effect_factors = results['effect_factors']  # mean of the mean of random results / true result\n",
    "            model_results = {'is_unpruned': is_unpruned, 'model_tag': tag, 'network': network,\n",
    "                             'chi2_categorical_ps_entropy': chi2_ps[0],\n",
    "                             'chi2_categorical_ps_loss': chi2_ps[1],\n",
    "                             'combined_ps_entropy': combined_ps[0],\n",
    "                             'combined_ps_loss': combined_ps[1],\n",
    "                             'effect_factor_entropy': effect_factors[0],\n",
    "                             'effect_factor_loss': effect_factors[1]}\n",
    "\n",
    "            all_results.append(pd.Series(model_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-27T00:37:18.227368Z",
     "iopub.status.busy": "2020-09-27T00:37:18.226702Z",
     "iopub.status.idle": "2020-09-27T00:38:17.712783Z",
     "shell.execute_reply": "2020-09-27T00:38:17.712172Z"
    },
    "papermill": {
     "duration": 59.514978,
     "end_time": "2020-09-27T00:38:17.712918",
     "exception": false,
     "start_time": "2020-09-27T00:37:18.197940",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /root/.local/share/virtualenvs/nn_clustering-Lo7V74L4/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "\n",
    "    for net in imagenet_nets:\n",
    "\n",
    "        results = evaluate_imagenet_visualizations(net)\n",
    "        percentiles = results['percentiles']\n",
    "        chi2_ps = results['chi2_ps']\n",
    "        combined_ps = results['combined_ps']\n",
    "        effect_factors = results['effect_factors']  # mean of the mean of random results / true result\n",
    "\n",
    "        model_results = {'is_unpruned': True, 'model_tag': net, 'network': net,\n",
    "                         'chi2_categorical_ps_entropy': chi2_ps[0],\n",
    "                         'chi2_categorical_ps_loss': chi2_ps[1],\n",
    "                         'combined_ps_entropy': combined_ps[0],\n",
    "                         'combined_ps_loss': combined_ps[1],\n",
    "                         'effect_factor_entropy': effect_factors[0],\n",
    "                         'effect_factor_loss': effect_factors[1]}\n",
    "\n",
    "        all_results.append(pd.Series(model_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-27T00:38:17.749200Z",
     "iopub.status.busy": "2020-09-27T00:38:17.748254Z",
     "iopub.status.idle": "2020-09-27T00:38:17.777206Z",
     "shell.execute_reply": "2020-09-27T00:38:17.777637Z"
    },
    "papermill": {
     "duration": 0.055094,
     "end_time": "2020-09-27T00:38:17.777786",
     "exception": false,
     "start_time": "2020-09-27T00:38:17.722692",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>is_unpruned</th>\n",
       "      <th>model_tag</th>\n",
       "      <th>network</th>\n",
       "      <th>chi2_categorical_ps_entropy</th>\n",
       "      <th>chi2_categorical_ps_loss</th>\n",
       "      <th>combined_ps_entropy</th>\n",
       "      <th>combined_ps_loss</th>\n",
       "      <th>effect_factor_entropy</th>\n",
       "      <th>effect_factor_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>8.343083e-01</td>\n",
       "      <td>9.114125e-01</td>\n",
       "      <td>6.415488e-01</td>\n",
       "      <td>0.575253</td>\n",
       "      <td>5.220406e+30</td>\n",
       "      <td>1.109227</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>3.654041e-01</td>\n",
       "      <td>6.276549e-01</td>\n",
       "      <td>3.768298e-01</td>\n",
       "      <td>0.821081</td>\n",
       "      <td>5.425442e+28</td>\n",
       "      <td>1.191112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>1.842753e-01</td>\n",
       "      <td>8.099154e-01</td>\n",
       "      <td>7.804189e-01</td>\n",
       "      <td>0.337454</td>\n",
       "      <td>4.901449e+26</td>\n",
       "      <td>1.686758</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>3.818729e-02</td>\n",
       "      <td>4.372742e-01</td>\n",
       "      <td>1.949422e-02</td>\n",
       "      <td>0.470037</td>\n",
       "      <td>5.883315e+26</td>\n",
       "      <td>1.061164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>8.043368e-01</td>\n",
       "      <td>8.961869e-01</td>\n",
       "      <td>3.763245e-01</td>\n",
       "      <td>0.619988</td>\n",
       "      <td>1.634509e+21</td>\n",
       "      <td>1.347749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+DROPOUT+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>3.543540e-01</td>\n",
       "      <td>9.518825e-01</td>\n",
       "      <td>3.052160e-01</td>\n",
       "      <td>0.770845</td>\n",
       "      <td>7.588679e+30</td>\n",
       "      <td>3.284653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+DROPOUT+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>5.341462e-01</td>\n",
       "      <td>8.099154e-01</td>\n",
       "      <td>3.528122e-01</td>\n",
       "      <td>0.850054</td>\n",
       "      <td>2.384608e+21</td>\n",
       "      <td>3.750950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+DROPOUT+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>1.010572e-02</td>\n",
       "      <td>2.345941e-01</td>\n",
       "      <td>9.639149e-01</td>\n",
       "      <td>0.760806</td>\n",
       "      <td>4.117816e+24</td>\n",
       "      <td>1.451880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+DROPOUT+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>3.765794e-09</td>\n",
       "      <td>4.846459e-01</td>\n",
       "      <td>5.851108e-04</td>\n",
       "      <td>0.681073</td>\n",
       "      <td>1.940843e+28</td>\n",
       "      <td>19.092426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>False</td>\n",
       "      <td>MNIST+DROPOUT+LUCID</td>\n",
       "      <td>MLP</td>\n",
       "      <td>9.745547e-01</td>\n",
       "      <td>7.110172e-01</td>\n",
       "      <td>6.139546e-01</td>\n",
       "      <td>0.495174</td>\n",
       "      <td>3.232920e+31</td>\n",
       "      <td>1.281552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>1.346864e-01</td>\n",
       "      <td>9.093598e-02</td>\n",
       "      <td>7.814156e-01</td>\n",
       "      <td>0.726158</td>\n",
       "      <td>3.451579e+00</td>\n",
       "      <td>2.174174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>8.609222e-01</td>\n",
       "      <td>4.399437e-02</td>\n",
       "      <td>2.437335e-01</td>\n",
       "      <td>0.296673</td>\n",
       "      <td>7.409971e+02</td>\n",
       "      <td>1.014918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>1.346864e-01</td>\n",
       "      <td>9.957112e-01</td>\n",
       "      <td>4.969971e-01</td>\n",
       "      <td>0.517611</td>\n",
       "      <td>2.174070e+01</td>\n",
       "      <td>1.244136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>1.034010e-01</td>\n",
       "      <td>4.505644e-01</td>\n",
       "      <td>5.043146e-01</td>\n",
       "      <td>0.671475</td>\n",
       "      <td>6.729382e+02</td>\n",
       "      <td>1.201458</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>7.399183e-01</td>\n",
       "      <td>3.278541e-01</td>\n",
       "      <td>8.754545e-01</td>\n",
       "      <td>0.790018</td>\n",
       "      <td>1.423348e+01</td>\n",
       "      <td>1.329242</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+DROPOUT+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>5.009344e-01</td>\n",
       "      <td>1.951631e-01</td>\n",
       "      <td>5.105824e-01</td>\n",
       "      <td>0.249073</td>\n",
       "      <td>3.208313e+00</td>\n",
       "      <td>1.239347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+DROPOUT+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>2.429858e-01</td>\n",
       "      <td>6.890190e-01</td>\n",
       "      <td>3.159778e-01</td>\n",
       "      <td>0.399179</td>\n",
       "      <td>2.881206e+01</td>\n",
       "      <td>1.503284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+DROPOUT+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>8.543181e-01</td>\n",
       "      <td>6.248720e-01</td>\n",
       "      <td>8.998283e-01</td>\n",
       "      <td>0.400648</td>\n",
       "      <td>2.736860e+00</td>\n",
       "      <td>1.188363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+DROPOUT+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>5.714647e-02</td>\n",
       "      <td>8.119933e-01</td>\n",
       "      <td>1.578825e-01</td>\n",
       "      <td>0.094943</td>\n",
       "      <td>2.081008e+01</td>\n",
       "      <td>1.138360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-MNIST+DROPOUT+LUCID</td>\n",
       "      <td>CNN</td>\n",
       "      <td>3.950156e-01</td>\n",
       "      <td>3.006789e-01</td>\n",
       "      <td>9.571449e-01</td>\n",
       "      <td>0.636037</td>\n",
       "      <td>1.083563e+00</td>\n",
       "      <td>1.500785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>1.292553e-23</td>\n",
       "      <td>3.751238e-11</td>\n",
       "      <td>2.000798e-07</td>\n",
       "      <td>0.000345</td>\n",
       "      <td>2.268583e+01</td>\n",
       "      <td>6.396876</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>2.309317e-25</td>\n",
       "      <td>5.190681e-17</td>\n",
       "      <td>3.084317e-08</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>7.282656e-01</td>\n",
       "      <td>12.137091</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>2.309317e-25</td>\n",
       "      <td>5.695479e-07</td>\n",
       "      <td>6.182423e-08</td>\n",
       "      <td>0.011986</td>\n",
       "      <td>1.974278e-01</td>\n",
       "      <td>67.001663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>1.226882e-11</td>\n",
       "      <td>8.973292e-09</td>\n",
       "      <td>1.488822e-04</td>\n",
       "      <td>0.001746</td>\n",
       "      <td>1.520406e+15</td>\n",
       "      <td>40.483116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>1.923616e-20</td>\n",
       "      <td>2.192059e-13</td>\n",
       "      <td>3.062468e-07</td>\n",
       "      <td>0.000185</td>\n",
       "      <td>1.613116e+17</td>\n",
       "      <td>24.390331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10+DROPOUT+L2REG</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>1.168184e-01</td>\n",
       "      <td>1.733085e-09</td>\n",
       "      <td>9.254234e-01</td>\n",
       "      <td>0.027421</td>\n",
       "      <td>6.543447e+04</td>\n",
       "      <td>4.286005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10+DROPOUT+L2REG</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>6.158542e-02</td>\n",
       "      <td>3.560205e-05</td>\n",
       "      <td>9.669103e-01</td>\n",
       "      <td>0.033528</td>\n",
       "      <td>5.249682e+03</td>\n",
       "      <td>5.030357</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10+DROPOUT+L2REG</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>8.168641e-02</td>\n",
       "      <td>3.192441e-05</td>\n",
       "      <td>8.908880e-01</td>\n",
       "      <td>0.065585</td>\n",
       "      <td>8.581334e+04</td>\n",
       "      <td>7.648377</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10+DROPOUT+L2REG</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>5.047996e-01</td>\n",
       "      <td>5.161830e-05</td>\n",
       "      <td>5.411119e-01</td>\n",
       "      <td>0.034970</td>\n",
       "      <td>2.358279e+03</td>\n",
       "      <td>5.617299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>True</td>\n",
       "      <td>CNN-VGG-CIFAR10+DROPOUT+L2REG</td>\n",
       "      <td>CNN-VGG</td>\n",
       "      <td>6.158542e-02</td>\n",
       "      <td>6.982450e-08</td>\n",
       "      <td>7.777845e-01</td>\n",
       "      <td>0.064055</td>\n",
       "      <td>6.387778e+07</td>\n",
       "      <td>8.267369</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>True</td>\n",
       "      <td>vgg16</td>\n",
       "      <td>vgg16</td>\n",
       "      <td>2.810185e-01</td>\n",
       "      <td>1.405777e-07</td>\n",
       "      <td>1.181684e-01</td>\n",
       "      <td>0.000254</td>\n",
       "      <td>1.009512e+00</td>\n",
       "      <td>0.928339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>True</td>\n",
       "      <td>vgg19</td>\n",
       "      <td>vgg19</td>\n",
       "      <td>8.881366e-01</td>\n",
       "      <td>5.692416e-04</td>\n",
       "      <td>6.329974e-01</td>\n",
       "      <td>0.002895</td>\n",
       "      <td>9.987466e-01</td>\n",
       "      <td>0.947881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>True</td>\n",
       "      <td>resnet50</td>\n",
       "      <td>resnet50</td>\n",
       "      <td>8.119933e-01</td>\n",
       "      <td>4.372742e-01</td>\n",
       "      <td>4.351022e-01</td>\n",
       "      <td>0.354857</td>\n",
       "      <td>1.001712e+00</td>\n",
       "      <td>0.991791</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    is_unpruned                      model_tag   network  \\\n",
       "0         False                    MNIST+LUCID       MLP   \n",
       "1         False                    MNIST+LUCID       MLP   \n",
       "2         False                    MNIST+LUCID       MLP   \n",
       "3         False                    MNIST+LUCID       MLP   \n",
       "4         False                    MNIST+LUCID       MLP   \n",
       "5         False            MNIST+DROPOUT+LUCID       MLP   \n",
       "6         False            MNIST+DROPOUT+LUCID       MLP   \n",
       "7         False            MNIST+DROPOUT+LUCID       MLP   \n",
       "8         False            MNIST+DROPOUT+LUCID       MLP   \n",
       "9         False            MNIST+DROPOUT+LUCID       MLP   \n",
       "10         True                CNN-MNIST+LUCID       CNN   \n",
       "11         True                CNN-MNIST+LUCID       CNN   \n",
       "12         True                CNN-MNIST+LUCID       CNN   \n",
       "13         True                CNN-MNIST+LUCID       CNN   \n",
       "14         True                CNN-MNIST+LUCID       CNN   \n",
       "15         True        CNN-MNIST+DROPOUT+LUCID       CNN   \n",
       "16         True        CNN-MNIST+DROPOUT+LUCID       CNN   \n",
       "17         True        CNN-MNIST+DROPOUT+LUCID       CNN   \n",
       "18         True        CNN-MNIST+DROPOUT+LUCID       CNN   \n",
       "19         True        CNN-MNIST+DROPOUT+LUCID       CNN   \n",
       "20         True                CNN-VGG-CIFAR10   CNN-VGG   \n",
       "21         True                CNN-VGG-CIFAR10   CNN-VGG   \n",
       "22         True                CNN-VGG-CIFAR10   CNN-VGG   \n",
       "23         True                CNN-VGG-CIFAR10   CNN-VGG   \n",
       "24         True                CNN-VGG-CIFAR10   CNN-VGG   \n",
       "25         True  CNN-VGG-CIFAR10+DROPOUT+L2REG   CNN-VGG   \n",
       "26         True  CNN-VGG-CIFAR10+DROPOUT+L2REG   CNN-VGG   \n",
       "27         True  CNN-VGG-CIFAR10+DROPOUT+L2REG   CNN-VGG   \n",
       "28         True  CNN-VGG-CIFAR10+DROPOUT+L2REG   CNN-VGG   \n",
       "29         True  CNN-VGG-CIFAR10+DROPOUT+L2REG   CNN-VGG   \n",
       "30         True                          vgg16     vgg16   \n",
       "31         True                          vgg19     vgg19   \n",
       "32         True                       resnet50  resnet50   \n",
       "\n",
       "    chi2_categorical_ps_entropy  chi2_categorical_ps_loss  \\\n",
       "0                  8.343083e-01              9.114125e-01   \n",
       "1                  3.654041e-01              6.276549e-01   \n",
       "2                  1.842753e-01              8.099154e-01   \n",
       "3                  3.818729e-02              4.372742e-01   \n",
       "4                  8.043368e-01              8.961869e-01   \n",
       "5                  3.543540e-01              9.518825e-01   \n",
       "6                  5.341462e-01              8.099154e-01   \n",
       "7                  1.010572e-02              2.345941e-01   \n",
       "8                  3.765794e-09              4.846459e-01   \n",
       "9                  9.745547e-01              7.110172e-01   \n",
       "10                 1.346864e-01              9.093598e-02   \n",
       "11                 8.609222e-01              4.399437e-02   \n",
       "12                 1.346864e-01              9.957112e-01   \n",
       "13                 1.034010e-01              4.505644e-01   \n",
       "14                 7.399183e-01              3.278541e-01   \n",
       "15                 5.009344e-01              1.951631e-01   \n",
       "16                 2.429858e-01              6.890190e-01   \n",
       "17                 8.543181e-01              6.248720e-01   \n",
       "18                 5.714647e-02              8.119933e-01   \n",
       "19                 3.950156e-01              3.006789e-01   \n",
       "20                 1.292553e-23              3.751238e-11   \n",
       "21                 2.309317e-25              5.190681e-17   \n",
       "22                 2.309317e-25              5.695479e-07   \n",
       "23                 1.226882e-11              8.973292e-09   \n",
       "24                 1.923616e-20              2.192059e-13   \n",
       "25                 1.168184e-01              1.733085e-09   \n",
       "26                 6.158542e-02              3.560205e-05   \n",
       "27                 8.168641e-02              3.192441e-05   \n",
       "28                 5.047996e-01              5.161830e-05   \n",
       "29                 6.158542e-02              6.982450e-08   \n",
       "30                 2.810185e-01              1.405777e-07   \n",
       "31                 8.881366e-01              5.692416e-04   \n",
       "32                 8.119933e-01              4.372742e-01   \n",
       "\n",
       "    combined_ps_entropy  combined_ps_loss  effect_factor_entropy  \\\n",
       "0          6.415488e-01          0.575253           5.220406e+30   \n",
       "1          3.768298e-01          0.821081           5.425442e+28   \n",
       "2          7.804189e-01          0.337454           4.901449e+26   \n",
       "3          1.949422e-02          0.470037           5.883315e+26   \n",
       "4          3.763245e-01          0.619988           1.634509e+21   \n",
       "5          3.052160e-01          0.770845           7.588679e+30   \n",
       "6          3.528122e-01          0.850054           2.384608e+21   \n",
       "7          9.639149e-01          0.760806           4.117816e+24   \n",
       "8          5.851108e-04          0.681073           1.940843e+28   \n",
       "9          6.139546e-01          0.495174           3.232920e+31   \n",
       "10         7.814156e-01          0.726158           3.451579e+00   \n",
       "11         2.437335e-01          0.296673           7.409971e+02   \n",
       "12         4.969971e-01          0.517611           2.174070e+01   \n",
       "13         5.043146e-01          0.671475           6.729382e+02   \n",
       "14         8.754545e-01          0.790018           1.423348e+01   \n",
       "15         5.105824e-01          0.249073           3.208313e+00   \n",
       "16         3.159778e-01          0.399179           2.881206e+01   \n",
       "17         8.998283e-01          0.400648           2.736860e+00   \n",
       "18         1.578825e-01          0.094943           2.081008e+01   \n",
       "19         9.571449e-01          0.636037           1.083563e+00   \n",
       "20         2.000798e-07          0.000345           2.268583e+01   \n",
       "21         3.084317e-08          0.000001           7.282656e-01   \n",
       "22         6.182423e-08          0.011986           1.974278e-01   \n",
       "23         1.488822e-04          0.001746           1.520406e+15   \n",
       "24         3.062468e-07          0.000185           1.613116e+17   \n",
       "25         9.254234e-01          0.027421           6.543447e+04   \n",
       "26         9.669103e-01          0.033528           5.249682e+03   \n",
       "27         8.908880e-01          0.065585           8.581334e+04   \n",
       "28         5.411119e-01          0.034970           2.358279e+03   \n",
       "29         7.777845e-01          0.064055           6.387778e+07   \n",
       "30         1.181684e-01          0.000254           1.009512e+00   \n",
       "31         6.329974e-01          0.002895           9.987466e-01   \n",
       "32         4.351022e-01          0.354857           1.001712e+00   \n",
       "\n",
       "    effect_factor_loss  \n",
       "0             1.109227  \n",
       "1             1.191112  \n",
       "2             1.686758  \n",
       "3             1.061164  \n",
       "4             1.347749  \n",
       "5             3.284653  \n",
       "6             3.750950  \n",
       "7             1.451880  \n",
       "8            19.092426  \n",
       "9             1.281552  \n",
       "10            2.174174  \n",
       "11            1.014918  \n",
       "12            1.244136  \n",
       "13            1.201458  \n",
       "14            1.329242  \n",
       "15            1.239347  \n",
       "16            1.503284  \n",
       "17            1.188363  \n",
       "18            1.138360  \n",
       "19            1.500785  \n",
       "20            6.396876  \n",
       "21           12.137091  \n",
       "22           67.001663  \n",
       "23           40.483116  \n",
       "24           24.390331  \n",
       "25            4.286005  \n",
       "26            5.030357  \n",
       "27            7.648377  \n",
       "28            5.617299  \n",
       "29            8.267369  \n",
       "30            0.928339  \n",
       "31            0.947881  \n",
       "32            0.991791  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/lucid_results_all.csv'\n",
    "# Create the results directory if needed so a fresh Restart & Run All does not\n",
    "# fail on the final write after all the evaluations above have finished.\n",
    "os.makedirs(os.path.dirname(savepath), exist_ok=True)\n",
    "result_df.to_csv(savepath)\n",
    "result_df\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 384.412618,
   "end_time": "2020-09-27T00:38:22.028950",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/lucid_results_all.ipynb",
   "output_path": "./notebooks/lucid_results_all.ipynb",
   "parameters": {},
   "start_time": "2020-09-27T00:31:57.616332",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}