{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:00.543717Z",
     "start_time": "2020-03-24T20:38:00.489316Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appended /home/XXX/PycharmProjects/entropy_distance_loss/src to paths\n",
      "Switched to directory /home/XXX/PycharmProjects/entropy_distance_loss\n",
      "%load_ext autoreload\n",
      "%autoreload 2\n"
     ]
    }
   ],
   "source": [
    "import XXX.notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:03.002798Z",
     "start_time": "2020-03-24T20:38:02.836299Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:01.531646Z",
     "start_time": "2020-03-24T20:38:00.634401Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mbaseline_cifar10\u001b[0m/              continuous_bounds.ipynb\n",
      "\u001b[01;34mbaseline_permutation_mnist\u001b[0m/    \u001b[01;34mdifferent_categorical_decoders\u001b[0m/\n",
      "\u001b[01;34mbounds_categorical_cifar10\u001b[0m/    \u001b[01;34mdvib_comparison\u001b[0m/\n",
      "\u001b[01;34mbounds_categorical_mnist\u001b[0m/      \u001b[01;34mentropy_minimization_and_noise\u001b[0m/\n",
      "\u001b[01;34mbounds_cifar10\u001b[0m/                \u001b[01;34mimagenet_measure_regularization\u001b[0m/\n",
      "\u001b[01;34mbounds_mnist\u001b[0m/                  __init__.py\n",
      "categorical_bounds.ipynb       \u001b[01;34mmeasure_regularization\u001b[0m/\n",
      "\u001b[01;34mcategorical_cifar2x5\u001b[0m/          \u001b[01;34msurrogates\u001b[0m/\n",
      "\u001b[01;34mcifar10_resnet_beta_ib_plots\u001b[0m/  \u001b[01;34munused\u001b[0m/\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import statistics\n",
    "\n",
    "from XXX import YYY\n",
    "\n",
    "from experiments.utils.jupyter import results_loader\n",
    "\n",
    "%ls {XXX.notebook.original_dir}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:02.719121Z",
     "start_time": "2020-03-24T20:38:01.533602Z"
    }
   },
   "outputs": [],
   "source": [
    "mnist_store = results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/bounds_mnist/results/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_store = results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/bounds_cifar10/results/')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_deterministic_store = results_loader.load_YYY_files(f\"{XXX.notebook.original_dir}/bounds_cifar10/results_deterministic/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_deterministic_no_bn_store = results_loader.load_YYY_files(f\"{XXX.notebook.original_dir}/bounds_cifar10/results_deterministic_no_bn/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:02.733683Z",
     "start_time": "2020-03-24T20:38:02.721114Z"
    }
   },
   "outputs": [],
   "source": [
    "pmnist_filtered_results = results_loader.filter_dict(mnist_store, k=lambda name: \"permutation\" in name)\n",
    "mnist_filtered_results = results_loader.filter_dict(mnist_store, k=lambda name: \"permutation\" not in name)\n",
    "cifar10_filtered_results = results_loader.filter_dict(cifar10_store, v=lambda result: True)\n",
    "cifar10_deterministic_filtered_results = results_loader.filter_dict(cifar10_deterministic_store, v=lambda result: True)\n",
    "cifar10_deterministic_no_bn_filtered_results = results_loader.filter_dict(cifar10_deterministic_no_bn_store, v=lambda result: True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-24T20:38:02.754990Z",
     "start_time": "2020-03-24T20:38:02.735380Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(16, 20, 16, 16, 16)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(mnist_filtered_results), len(pmnist_filtered_results), len(cifar10_filtered_results), len(cifar10_deterministic_filtered_results), len(cifar10_deterministic_no_bn_filtered_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_data = []\n",
    "\n",
    "for dataset, results in ((\"mnist\", mnist_filtered_results), (\"pmnist\", pmnist_filtered_results), (\"cifar10\", cifar10_filtered_results), (\"det_cifar10\", cifar10_deterministic_filtered_results), (\"det_no_bn_cifar10\", cifar10_deterministic_no_bn_filtered_results)):\n",
    "\n",
    "    for ri, result in enumerate(results.values()):\n",
    "        if \"categorical\" in result.actual_name:\n",
    "            objective = \"decoder_uncertainty\"\n",
    "        elif 'prediction' in result.config.cross_entropy_type:\n",
    "            objective = \"prediction\"\n",
    "        else:\n",
    "            objective = \"decoder\"\n",
    "            \n",
    "        for (source, epochs) in ((\"train\", result.log.training_epochs), (\"test\", result.log.test_epochs)):\n",
    "            for i, epoch in enumerate(epochs):\n",
    "                d = epoch._asdict()\n",
    "                d['error'] = 1 - d['accuracy']\n",
    "                d['error_p'] = 1 - d['correct_prob']\n",
    "                d['error_bound_decoder_xe'] = 1 - np.exp(-d['xe_decoder'])\n",
    "                d['epoch'] = i\n",
    "                d['objective'] = objective\n",
    "                d['experiment_index'] = ri\n",
    "                d['source'] = source\n",
    "                d['dataset'] = dataset\n",
    "                experiment_data.append(d)\n",
    "\n",
    "df=pd.DataFrame(experiment_data)\n",
    "dfm = df.melt(id_vars=['dataset', 'source', 'experiment_index', 'epoch', 'objective'], value_vars=['error','error_p', 'xe_decoder', 'xe_prediction', 'error_bound_decoder_xe'])\n",
    "dfm.to_csv(f\"plots/bounds_continuous.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>source</th>\n",
       "      <th>experiment_index</th>\n",
       "      <th>epoch</th>\n",
       "      <th>objective</th>\n",
       "      <th>variable</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mnist</td>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>decoder</td>\n",
       "      <td>error</td>\n",
       "      <td>0.069933</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>mnist</td>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>decoder</td>\n",
       "      <td>error</td>\n",
       "      <td>0.025917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>mnist</td>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>decoder</td>\n",
       "      <td>error</td>\n",
       "      <td>0.015900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>mnist</td>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>decoder</td>\n",
       "      <td>error</td>\n",
       "      <td>0.011050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>mnist</td>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>decoder</td>\n",
       "      <td>error</td>\n",
       "      <td>0.008233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83995</th>\n",
       "      <td>det_no_bn_cifar10</td>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>95</td>\n",
       "      <td>prediction</td>\n",
       "      <td>error_bound_decoder_xe</td>\n",
       "      <td>0.973538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83996</th>\n",
       "      <td>det_no_bn_cifar10</td>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>96</td>\n",
       "      <td>prediction</td>\n",
       "      <td>error_bound_decoder_xe</td>\n",
       "      <td>0.974008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83997</th>\n",
       "      <td>det_no_bn_cifar10</td>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>97</td>\n",
       "      <td>prediction</td>\n",
       "      <td>error_bound_decoder_xe</td>\n",
       "      <td>0.973573</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83998</th>\n",
       "      <td>det_no_bn_cifar10</td>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>98</td>\n",
       "      <td>prediction</td>\n",
       "      <td>error_bound_decoder_xe</td>\n",
       "      <td>0.973142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83999</th>\n",
       "      <td>det_no_bn_cifar10</td>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>99</td>\n",
       "      <td>prediction</td>\n",
       "      <td>error_bound_decoder_xe</td>\n",
       "      <td>0.973907</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>84000 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                 dataset source  experiment_index  epoch   objective  \\\n",
       "0                  mnist  train                 0      0     decoder   \n",
       "1                  mnist  train                 0      1     decoder   \n",
       "2                  mnist  train                 0      2     decoder   \n",
       "3                  mnist  train                 0      3     decoder   \n",
       "4                  mnist  train                 0      4     decoder   \n",
       "...                  ...    ...               ...    ...         ...   \n",
       "83995  det_no_bn_cifar10   test                15     95  prediction   \n",
       "83996  det_no_bn_cifar10   test                15     96  prediction   \n",
       "83997  det_no_bn_cifar10   test                15     97  prediction   \n",
       "83998  det_no_bn_cifar10   test                15     98  prediction   \n",
       "83999  det_no_bn_cifar10   test                15     99  prediction   \n",
       "\n",
       "                     variable     value  \n",
       "0                       error  0.069933  \n",
       "1                       error  0.025917  \n",
       "2                       error  0.015900  \n",
       "3                       error  0.011050  \n",
       "4                       error  0.008233  \n",
       "...                       ...       ...  \n",
       "83995  error_bound_decoder_xe  0.973538  \n",
       "83996  error_bound_decoder_xe  0.974008  \n",
       "83997  error_bound_decoder_xe  0.973573  \n",
       "83998  error_bound_decoder_xe  0.973142  \n",
       "83999  error_bound_decoder_xe  0.973907  \n",
       "\n",
       "[84000 rows x 7 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "objective   dataset          \n",
       "decoder     cifar10              0.936715\n",
       "            det_cifar10          0.821622\n",
       "            det_no_bn_cifar10    0.742425\n",
       "            mnist                0.991457\n",
       "            pmnist               0.986266\n",
       "prediction  cifar10              0.936150\n",
       "            det_cifar10          0.803050\n",
       "            det_no_bn_cifar10    0.728347\n",
       "            mnist                0.991572\n",
       "            pmnist               0.986402\n",
       "Name: accuracy, dtype: float64"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = df\n",
    "data = data[data.source == \"test\"]\n",
    "data = data[data.epoch >= 95]\n",
    "data.groupby([\"dataset\",]).accuracy.mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.6 64-bit ('uib': conda)",
   "language": "python",
   "name": "python37664bituibconda101b67e03a19488ca1455e98b804290f"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": false,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
