{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appended /home/XXX/PycharmProjects/entropy_distance_loss/src to paths\n",
      "Switched to directory /home/XXX/PycharmProjects/entropy_distance_loss\n",
      "%load_ext autoreload\n",
      "%autoreload 2\n",
      "cifar10_no_dropout_surrogates_training.py  \u001b[0m\u001b[01;34mresults\u001b[0m/\n",
      "combined_iq_evaluation.py                  \u001b[01;34mresults_no_dropout\u001b[0m/\n",
      "evaluate_snapshot_robustness.py            \u001b[01;34mresults_no_dropout_15\u001b[0m/\n",
      "export_results_to_csv.ipynb                \u001b[01;34mresults_no_dropout_evaluations\u001b[0m/\n",
      "foolbox_spike.ipynb                        \u001b[01;34mresults_no_dropout_evaluations_15\u001b[0m/\n",
      "\u001b[01;34m__pycache__\u001b[0m/                               simple_iq_evaluation.py\n"
     ]
    }
   ],
   "source": [
    "import XXX.notebook\n",
    "\n",
    "import YYY\n",
    "\n",
    "from experiments.utils.jupyter import results_loader\n",
    "\n",
    "%ls {XXX.notebook.original_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "|                                 | Description                                  |\n",
    "|---------------------------------|----------------------------------------------|\n",
    "| `dropout`                       | Latest with dropout                          |\n",
    "| `w_dropout_weaker_lr`           | Dropout with a less strong LR scheduler      |\n",
    "| `no_dropout`                    | No dropout (but zero-entropy noise injected) |\n",
    "| `no_dropout_no_noise`           | No dropout and no noise injected             |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result-directory suffixes that can be exported; `index` selects the active one.\n",
    "data_sources = [\"no_dropout\"]\n",
    "index = 0\n",
    "data_source = data_sources[index]\n",
    "\n",
    "# The same suffix is used for the input result directories and for the exported CSV name.\n",
    "input_suffix = output_suffix = data_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments = results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/results_{input_suffix}/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluations = results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/results_{input_suffix}_evaluations/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "82"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(experiments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "246"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(evaluations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evaluations_by_job_id(kind):\n",
    "    \"\"\"Select the evaluations whose `experiment` contains `kind`.\n",
    "\n",
    "    The selection is passed through `map_dict` with a `(job_id, result)` key/value\n",
    "    function -- presumably this re-keys the results by job id (verify in results_loader).\n",
    "    \"\"\"\n",
    "    selected = results_loader.filter_dict(evaluations, v=lambda result: kind in result.experiment)\n",
    "    return results_loader.map_dict(selected, kv=lambda key, result: (result.job_id, result))\n",
    "\n",
    "\n",
    "# One table per kind of evaluation.\n",
    "robustness_results = _evaluations_by_job_id(\"robustness\")\n",
    "simple_iq_results = _evaluations_by_job_id(\"simple_iq\")\n",
    "combined_iq_results = _evaluations_by_job_id(\"combined_iq\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(82, 82, 82)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(robustness_results), len(simple_iq_results), len(combined_iq_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = results_loader.map_dict(experiments, kv=lambda key, result: (result.job_id, result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A snafu in the result storage: apparently some metrics are stored wrapped in an estimate object.\n",
    "# Grab the wrapper class from one stored sample so that `get_value` can recognise it.\n",
    "_sample_epoch = results_loader.get_any(simple_iq_results).log.epochs[0].training_set\n",
    "Estimate = _sample_epoch.continuous_decoder_uncertainty.__class__\n",
    "\n",
    "\n",
    "def get_value(value):\n",
    "    \"\"\"Unwrap an estimate: return `value.mean` for `Estimate` instances, otherwise `value` unchanged.\"\"\"\n",
    "    if isinstance(value, Estimate):\n",
    "        return value.mean\n",
    "    return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def estimate_entropy_from_var(var, capacity=100):\n",
    "    \"\"\"Return `log(var)/2 + capacity * 0.5 * log(2*pi*e)`.\n",
    "\n",
    "    NOTE(review): this has the form of a Gaussian differential entropy with `capacity`\n",
    "    dimensions and `var` as the (generalised) variance -- confirm the intended meaning.\n",
    "\n",
    "    Args:\n",
    "        var: Scalar or array; `np.log` is applied element-wise, so values <= 0\n",
    "            produce `-inf`/`nan`.\n",
    "        capacity: Multiplier of the constant `0.5 * log(2*pi*e)` term.\n",
    "\n",
    "    Returns:\n",
    "        The estimate as computed by the formula above.\n",
    "    \"\"\"\n",
    "    return np.log(var)/2 + capacity*0.5*np.log(2*np.pi*np.e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple IQ Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hack: a gamma = 0 run is copied into each of these regularizers.\n",
    "IQ_BASELINE_REGULARIZERS = ['entropy_via_variance_Z', 'mean_squared_Z',\n",
    "                            'entropy_via_variance_Z__Y', 'weight_decay']  # 'kraskov_Z__Y', 'kraskov_Z',\n",
    "\n",
    "# Runs that stopped early are padded in steps of IQ_EPOCH_STRIDE until IQ_FINAL_EPOCH is reached.\n",
    "IQ_FINAL_EPOCH = 150\n",
    "IQ_EPOCH_STRIDE = 5\n",
    "\n",
    "# Metric that holds the value of each regularizer ('weight_decay' is handled separately).\n",
    "IQ_REGULARIZER_METRICS = {\n",
    "    'mean_squared_Z': 'mean_squared_Z',\n",
    "    'entropy_via_variance_Z': 'mb_H_Z',\n",
    "    'entropy_via_variance_Z__Y': 'mb_H_Z__Y',\n",
    "    'kraskov_Z': 'continuous_H_Z',\n",
    "    'kraskov_Z__Y': 'continuous_H_Z__Y',\n",
    "}\n",
    "\n",
    "\n",
    "def _regularizer_value(regularizer, row):\n",
    "    \"\"\"Look up the value of `regularizer` in an already populated `row`.\n",
    "\n",
    "    'weight_decay' is reported as 0. An unknown regularizer yields NaN instead of\n",
    "    silently carrying over the value computed for a previous row.\n",
    "    \"\"\"\n",
    "    if regularizer == 'weight_decay':\n",
    "        return 0.\n",
    "    metric = IQ_REGULARIZER_METRICS.get(regularizer)\n",
    "    return row[metric] if metric is not None else float('nan')\n",
    "\n",
    "\n",
    "def _iq_row(epoch_index, metrics, regularizer, inject_noise, job_id, gamma, source):\n",
    "    \"\"\"Build one record from the metrics logged for a single epoch.\"\"\"\n",
    "    row = {\n",
    "        'epoch': epoch_index,\n",
    "        'regularizer': regularizer,\n",
    "        'inject_noise': inject_noise,\n",
    "        'experiment_index': job_id,\n",
    "        # It is \\gamma from the paper.\n",
    "        'gamma': gamma,\n",
    "        'source': source,\n",
    "        'error': 1 - metrics.accuracy,\n",
    "        'error_p': 1 - metrics.correct_prob,\n",
    "    }\n",
    "    # Copy every logged metric, unwrapping estimates to their mean.\n",
    "    row.update({key: get_value(value) for key, value in metrics._asdict().items()})\n",
    "    row['regularizer_value'] = _regularizer_value(regularizer, row)\n",
    "    return row\n",
    "\n",
    "\n",
    "def _iq_rows(epochs, select_metrics, **row_fields):\n",
    "    \"\"\"Return one record per logged epoch, padded up to IQ_FINAL_EPOCH.\n",
    "\n",
    "    Args:\n",
    "        epochs: Mapping from epoch index to the log entry of that epoch.\n",
    "        select_metrics: Callable that picks the metrics object out of a log entry.\n",
    "        **row_fields: Remaining keyword arguments of `_iq_row`.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    last_epoch = None\n",
    "    for last_epoch, epoch_log in epochs.items():\n",
    "        rows.append(_iq_row(last_epoch, select_metrics(epoch_log), **row_fields))\n",
    "\n",
    "    # Nothing was logged, so there is no row that could be carried forward.\n",
    "    if not rows:\n",
    "        return rows\n",
    "\n",
    "    # Impute missing epochs from early out (thresholding if acc is < 20% after 20 epochs or so)\n",
    "    while last_epoch < IQ_FINAL_EPOCH:\n",
    "        last_epoch += IQ_EPOCH_STRIDE\n",
    "        padding = dict(rows[-1])\n",
    "        padding['epoch'] = last_epoch\n",
    "        rows.append(padding)\n",
    "\n",
    "    return rows\n",
    "\n",
    "\n",
    "# (source label, evaluation results keyed by job id, accessor for the per-epoch metrics)\n",
    "iq_sources = (\n",
    "    ('train', simple_iq_results, lambda epoch_log: epoch_log.training_set),\n",
    "    ('test', simple_iq_results, lambda epoch_log: epoch_log.test_set),\n",
    "    ('combined', combined_iq_results, lambda epoch_log: epoch_log.evaluation),\n",
    ")\n",
    "\n",
    "experiment_iq_data = []\n",
    "\n",
    "for job_id, result in results.items():\n",
    "    regularizer = result.experiment.regularizer\n",
    "    gamma = result.experiment.gamma\n",
    "    inject_noise = result.experiment.inject_noise\n",
    "\n",
    "    dst_regularizers = [regularizer] if gamma != 0 else IQ_BASELINE_REGULARIZERS\n",
    "\n",
    "    for dst_regularizer in dst_regularizers:\n",
    "        for source, source_results, select_metrics in iq_sources:\n",
    "            # Not every job has every kind of evaluation.\n",
    "            if job_id not in source_results:\n",
    "                continue\n",
    "\n",
    "            experiment_iq_data.extend(_iq_rows(\n",
    "                source_results[job_id].log.epochs,\n",
    "                select_metrics,\n",
    "                regularizer=dst_regularizer,\n",
    "                inject_noise=inject_noise,\n",
    "                job_id=job_id,\n",
    "                gamma=gamma,\n",
    "                source=source,\n",
    "            ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame.from_records(experiment_iq_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.77827941e-01, 1.62377674e-01, 7.94328235e-02, 2.33572147e-03,\n",
       "       7.07945784e-03, 1.58489319e-02, 5.62341325e-05, 4.83293024e+00,\n",
       "       6.15848211e-02, 7.84759970e-04, 6.30957344e-04, 3.16227766e-03,\n",
       "       1.41253754e-03, 1.12201845e-05, 2.81838293e-04, 1.27427499e-03,\n",
       "       1.99526231e+00, 3.79269019e-04, 2.97635144e-01, 2.97635144e-02,\n",
       "       6.15848211e-05, 4.83293024e-02, 1.27427499e-01, 1.62377674e-03,\n",
       "       1.25892541e-04, 1.12883789e+00, 2.51188643e-05, 1.00000000e-05,\n",
       "       3.98107171e-01, 7.84759970e-03, 1.83298071e-04, 8.91250938e-01,\n",
       "       5.01187234e-06, 1.00000000e-06, 3.35981829e-05, 2.33572147e+00,\n",
       "       3.54813389e-02, 2.63665090e-02, 5.45559478e-01, 6.95192796e-04,\n",
       "       2.06913808e-05, 0.00000000e+00, 2.23872114e-06, 4.28133240e-03,\n",
       "       4.28133240e-05, 1.00000000e+01, 8.85866790e-02, 1.43844989e-02,\n",
       "       6.95192796e-03, 4.46683592e+00, 2.06913808e-04, 8.85866790e-05,\n",
       "       1.00000000e+00, 1.12883789e-04, 2.63665090e-01, 1.83298071e-05,\n",
       "       3.35981829e-03])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.gamma.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gamma</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4557</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6789</th>\n",
       "      <td>0.000010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7626</th>\n",
       "      <td>0.000018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3441</th>\n",
       "      <td>0.000034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1860</th>\n",
       "      <td>0.000062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6882</th>\n",
       "      <td>0.000113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5859</th>\n",
       "      <td>0.000207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7533</th>\n",
       "      <td>0.000379</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4278</th>\n",
       "      <td>0.000695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1395</th>\n",
       "      <td>0.001274</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>279</th>\n",
       "      <td>0.002336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5115</th>\n",
       "      <td>0.004281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2883</th>\n",
       "      <td>0.007848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5580</th>\n",
       "      <td>0.014384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3720</th>\n",
       "      <td>0.026367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2046</th>\n",
       "      <td>0.048329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5394</th>\n",
       "      <td>0.088587</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>93</th>\n",
       "      <td>0.162378</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1674</th>\n",
       "      <td>0.297635</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3813</th>\n",
       "      <td>0.545559</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6696</th>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         gamma\n",
       "4557  0.000000\n",
       "6789  0.000010\n",
       "7626  0.000018\n",
       "3441  0.000034\n",
       "1860  0.000062\n",
       "6882  0.000113\n",
       "5859  0.000207\n",
       "7533  0.000379\n",
       "4278  0.000695\n",
       "1395  0.001274\n",
       "279   0.002336\n",
       "5115  0.004281\n",
       "2883  0.007848\n",
       "5580  0.014384\n",
       "3720  0.026367\n",
       "2046  0.048329\n",
       "5394  0.088587\n",
       "93    0.162378\n",
       "1674  0.297635\n",
       "3813  0.545559\n",
       "6696  1.000000"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also\n",
    "    display(df[df.regularizer==\"entropy_via_variance_Z\"][[\"gamma\"]].drop_duplicates().sort_values([\"gamma\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['epoch', 'regularizer', 'inject_noise', 'experiment_index', 'gamma',\n",
       "       'source', 'error', 'error_p', 'epoch_duration', 'accuracy',\n",
       "       'correct_prob', 'xe_decoder', 'xe_prediction', 'mean_squared_Z',\n",
       "       'mb_H_Z', 'mb_global_H_Z__X', 'mb_global_H_Z__Y', 'mb_H_mean_Z__X',\n",
       "       'mb_H_Z__X', 'mb_H_Z__Y', 'loss', 'continuous_iq_base',\n",
       "       'continuous_decoder_uncertainty', 'continuous_encoding_entropy',\n",
       "       'continuous_preserved_information', 'continuous_H_Z__X',\n",
       "       'continuous_H_Z__Y', 'continuous_H_Z', 'regularizer_value'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['mean_squared_Z', 'entropy_via_variance_Z', 'weight_decay',\n",
       "       'entropy_via_variance_Z__Y'], dtype=object)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.regularizer.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>inject_noise</th>\n",
       "      <th>source</th>\n",
       "      <th>experiment_index</th>\n",
       "      <th>regularizer</th>\n",
       "      <th>gamma</th>\n",
       "      <th>epoch</th>\n",
       "      <th>variable</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>True</td>\n",
       "      <td>train</td>\n",
       "      <td>15</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>0</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.104060</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>True</td>\n",
       "      <td>train</td>\n",
       "      <td>15</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>5</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.781140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>True</td>\n",
       "      <td>train</td>\n",
       "      <td>15</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>10</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.839440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>True</td>\n",
       "      <td>train</td>\n",
       "      <td>15</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>15</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.904460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>True</td>\n",
       "      <td>train</td>\n",
       "      <td>15</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>20</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.924840</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114571</th>\n",
       "      <td>True</td>\n",
       "      <td>combined</td>\n",
       "      <td>9</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>130</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>-7.039563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114572</th>\n",
       "      <td>True</td>\n",
       "      <td>combined</td>\n",
       "      <td>9</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>135</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>-7.221201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114573</th>\n",
       "      <td>True</td>\n",
       "      <td>combined</td>\n",
       "      <td>9</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>140</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>-7.405535</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114574</th>\n",
       "      <td>True</td>\n",
       "      <td>combined</td>\n",
       "      <td>9</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>145</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>-7.333575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114575</th>\n",
       "      <td>True</td>\n",
       "      <td>combined</td>\n",
       "      <td>9</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>150</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>-7.620086</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>114576 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        inject_noise    source  experiment_index     regularizer     gamma  \\\n",
       "0               True     train                15  mean_squared_Z  0.177828   \n",
       "1               True     train                15  mean_squared_Z  0.177828   \n",
       "2               True     train                15  mean_squared_Z  0.177828   \n",
       "3               True     train                15  mean_squared_Z  0.177828   \n",
       "4               True     train                15  mean_squared_Z  0.177828   \n",
       "...              ...       ...               ...             ...       ...   \n",
       "114571          True  combined                 9  mean_squared_Z  0.001413   \n",
       "114572          True  combined                 9  mean_squared_Z  0.001413   \n",
       "114573          True  combined                 9  mean_squared_Z  0.001413   \n",
       "114574          True  combined                 9  mean_squared_Z  0.001413   \n",
       "114575          True  combined                 9  mean_squared_Z  0.001413   \n",
       "\n",
       "        epoch   variable     value  \n",
       "0           0   accuracy  0.104060  \n",
       "1           5   accuracy  0.781140  \n",
       "2          10   accuracy  0.839440  \n",
       "3          15   accuracy  0.904460  \n",
       "4          20   accuracy  0.924840  \n",
       "...       ...        ...       ...  \n",
       "114571    130  mb_H_Z__Y -7.039563  \n",
       "114572    135  mb_H_Z__Y -7.221201  \n",
       "114573    140  mb_H_Z__Y -7.405535  \n",
       "114574    145  mb_H_Z__Y -7.333575  \n",
       "114575    150  mb_H_Z__Y -7.620086  \n",
       "\n",
       "[114576 rows x 8 columns]"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Columns that identify a measurement; they are kept as-is in the long-format table.\n",
    "id_columns = ['inject_noise', 'source', 'experiment_index', 'regularizer', 'gamma', 'epoch']\n",
    "\n",
    "# Metrics that are unpivoted into (variable, value) pairs.\n",
    "fields = [\n",
    "    'accuracy', 'correct_prob', 'error', 'error_p',\n",
    "    'loss', 'xe_decoder', 'xe_prediction',\n",
    "    'regularizer_value', 'mean_squared_Z',\n",
    "    'continuous_decoder_uncertainty', 'continuous_H_Z', 'continuous_H_Z__Y',\n",
    "    'mb_H_Z', 'mb_H_Z__Y',\n",
    "]\n",
    "\n",
    "dfm = df.melt(id_vars=id_columns, value_vars=fields)\n",
    "dfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Create the output directory first: `to_csv` fails if it does not exist (e.g. on a fresh checkout).\n",
    "output_path = Path(\"./plots\") / f\"iclr_surrogates_{output_suffix}.csv\"\n",
    "output_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "dfm.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>experiment_index</th>\n",
       "      <th>gamma</th>\n",
       "      <th>epoch</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>114576.000000</td>\n",
       "      <td>114576.000000</td>\n",
       "      <td>114576.000000</td>\n",
       "      <td>114576.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>43.227273</td>\n",
       "      <td>0.428403</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>63.924722</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>24.974517</td>\n",
       "      <td>1.371802</td>\n",
       "      <td>44.721555</td>\n",
       "      <td>203.904937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-257.193276</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>21.750000</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>0.121606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>43.500000</td>\n",
       "      <td>0.001518</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>0.928117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>65.250000</td>\n",
       "      <td>0.081721</td>\n",
       "      <td>115.000000</td>\n",
       "      <td>54.825005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>81.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>150.000000</td>\n",
       "      <td>2602.617552</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       experiment_index          gamma          epoch          value\n",
       "count     114576.000000  114576.000000  114576.000000  114576.000000\n",
       "mean          43.227273       0.428403      75.000000      63.924722\n",
       "std           24.974517       1.371802      44.721555     203.904937\n",
       "min            0.000000       0.000000       0.000000    -257.193276\n",
       "25%           21.750000       0.000031      35.000000       0.121606\n",
       "50%           43.500000       0.001518      75.000000       0.928117\n",
       "75%           65.250000       0.081721     115.000000      54.825005\n",
       "max           81.000000      10.000000     150.000000    2602.617552"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfm.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Robustness results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hack: a gamma = 0 run is copied into each of these regularizers.\n",
    "ROBUSTNESS_BASELINE_REGULARIZERS = ['entropy_via_variance_Z', 'mean_squared_Z',\n",
    "                                    'entropy_via_variance_Z__Y', 'kraskov_Z__Y', 'kraskov_Z', 'weight_decay']\n",
    "\n",
    "# Epoch whose robustness and IQ logs are exported.\n",
    "ROBUSTNESS_EPOCH = 150\n",
    "\n",
    "experiment_robustness_data = []\n",
    "\n",
    "for job_id, result in results.items():\n",
    "    regularizer = result.experiment.regularizer\n",
    "    gamma = result.experiment.gamma\n",
    "    inject_noise = result.experiment.inject_noise\n",
    "\n",
    "    dst_regularizers = [regularizer] if gamma != 0 else ROBUSTNESS_BASELINE_REGULARIZERS\n",
    "\n",
    "    # None of these lookups depend on the destination regularizer, so do them once per job.\n",
    "    # A job without a robustness/IQ result for ROBUSTNESS_EPOCH raises KeyError here.\n",
    "    robustness_result = robustness_results[job_id]\n",
    "    epsilons = robustness_result.config.epsilons\n",
    "    robustness_log = robustness_result.log.epochs[ROBUSTNESS_EPOCH]\n",
    "    iqs = simple_iq_results[job_id].log.epochs[ROBUSTNESS_EPOCH].test_set\n",
    "\n",
    "    preserved_info = iqs.continuous_preserved_information.mean\n",
    "    xe_decoder = iqs.xe_decoder\n",
    "\n",
    "    # (attack label, accuracy per epsilon); `robust_accs` is exported last under the\n",
    "    # attack label \"Robustness\".\n",
    "    accuracy_curves = list(robustness_log.attack_accs._asdict().items())\n",
    "    accuracy_curves.append((\"Robustness\", robustness_log.robust_accs))\n",
    "\n",
    "    for dst_regularizer in dst_regularizers:\n",
    "        config = dict(inject_noise=inject_noise, regularizer=dst_regularizer, gamma=gamma,\n",
    "                      preserved_info=preserved_info, xe_decoder=xe_decoder)\n",
    "\n",
    "        for attack, accuracies in accuracy_curves:\n",
    "            for epsilon, accuracy in zip(epsilons, accuracies):\n",
    "                row = dict(attack=attack, epsilon=epsilon, accuracy=accuracy)\n",
    "                row.update(config)\n",
    "\n",
    "                experiment_robustness_data.append(row)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>attack</th>\n",
       "      <th>epsilon</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>inject_noise</th>\n",
       "      <th>regularizer</th>\n",
       "      <th>gamma</th>\n",
       "      <th>preserved_info</th>\n",
       "      <th>xe_decoder</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>FGSM</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>0.9233</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>60.213699</td>\n",
       "      <td>0.463161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>FGSM</td>\n",
       "      <td>0.0005</td>\n",
       "      <td>0.8823</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>60.213699</td>\n",
       "      <td>0.463161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>FGSM</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.8489</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>60.213699</td>\n",
       "      <td>0.463161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>FGSM</td>\n",
       "      <td>0.0015</td>\n",
       "      <td>0.8207</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>60.213699</td>\n",
       "      <td>0.463161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>FGSM</td>\n",
       "      <td>0.0020</td>\n",
       "      <td>0.8045</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.177828</td>\n",
       "      <td>60.213699</td>\n",
       "      <td>0.463161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7355</th>\n",
       "      <td>Robustness</td>\n",
       "      <td>0.1000</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>62.733806</td>\n",
       "      <td>0.408646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7356</th>\n",
       "      <td>Robustness</td>\n",
       "      <td>0.2000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>62.733806</td>\n",
       "      <td>0.408646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7357</th>\n",
       "      <td>Robustness</td>\n",
       "      <td>0.3000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>62.733806</td>\n",
       "      <td>0.408646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7358</th>\n",
       "      <td>Robustness</td>\n",
       "      <td>0.5000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>62.733806</td>\n",
       "      <td>0.408646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7359</th>\n",
       "      <td>Robustness</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>True</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.001413</td>\n",
       "      <td>62.733806</td>\n",
       "      <td>0.408646</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7360 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          attack  epsilon  accuracy  inject_noise     regularizer     gamma  \\\n",
       "0           FGSM   0.0000    0.9233          True  mean_squared_Z  0.177828   \n",
       "1           FGSM   0.0005    0.8823          True  mean_squared_Z  0.177828   \n",
       "2           FGSM   0.0010    0.8489          True  mean_squared_Z  0.177828   \n",
       "3           FGSM   0.0015    0.8207          True  mean_squared_Z  0.177828   \n",
       "4           FGSM   0.0020    0.8045          True  mean_squared_Z  0.177828   \n",
       "...          ...      ...       ...           ...             ...       ...   \n",
       "7355  Robustness   0.1000    0.0001          True  mean_squared_Z  0.001413   \n",
       "7356  Robustness   0.2000    0.0000          True  mean_squared_Z  0.001413   \n",
       "7357  Robustness   0.3000    0.0000          True  mean_squared_Z  0.001413   \n",
       "7358  Robustness   0.5000    0.0000          True  mean_squared_Z  0.001413   \n",
       "7359  Robustness   1.0000    0.0000          True  mean_squared_Z  0.001413   \n",
       "\n",
       "      preserved_info  xe_decoder  \n",
       "0          60.213699    0.463161  \n",
       "1          60.213699    0.463161  \n",
       "2          60.213699    0.463161  \n",
       "3          60.213699    0.463161  \n",
       "4          60.213699    0.463161  \n",
       "...              ...         ...  \n",
       "7355       62.733806    0.408646  \n",
       "7356       62.733806    0.408646  \n",
       "7357       62.733806    0.408646  \n",
       "7358       62.733806    0.408646  \n",
       "7359       62.733806    0.408646  \n",
       "\n",
       "[7360 rows x 8 columns]"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `experiment_robustness_data` is a list of row dicts, so build the frame\n",
    "# with the DataFrame constructor (`from_dict` is meant for dict input).\n",
    "df = pd.DataFrame(experiment_robustness_data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.77827941e-01, 1.62377674e-01, 7.94328235e-02, 2.33572147e-03,\n",
       "       7.07945784e-03, 1.58489319e-02, 5.62341325e-05, 4.83293024e+00,\n",
       "       6.15848211e-02, 7.84759970e-04, 6.30957344e-04, 3.16227766e-03,\n",
       "       1.41253754e-03, 1.12201845e-05, 2.81838293e-04, 1.27427499e-03,\n",
       "       1.99526231e+00, 3.79269019e-04, 2.97635144e-01, 2.97635144e-02,\n",
       "       6.15848211e-05, 4.83293024e-02, 1.27427499e-01, 1.62377674e-03,\n",
       "       1.25892541e-04, 1.12883789e+00, 2.51188643e-05, 1.00000000e-05,\n",
       "       3.98107171e-01, 7.84759970e-03, 1.83298071e-04, 8.91250938e-01,\n",
       "       5.01187234e-06, 1.00000000e-06, 3.35981829e-05, 2.33572147e+00,\n",
       "       3.54813389e-02, 2.63665090e-02, 5.45559478e-01, 6.95192796e-04,\n",
       "       2.06913808e-05, 0.00000000e+00, 2.23872114e-06, 4.28133240e-03,\n",
       "       4.28133240e-05, 1.00000000e+01, 8.85866790e-02, 1.43844989e-02,\n",
       "       6.95192796e-03, 4.46683592e+00, 2.06913808e-04, 8.85866790e-05,\n",
       "       1.00000000e+00, 1.12883789e-04, 2.63665090e-01, 1.83298071e-05,\n",
       "       3.35981829e-03])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct gamma values present in the exported rows.\n",
    "df[\"gamma\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the robustness table to CSV without the positional index column.\n",
    "df.to_csv(\n",
    "    f\"./plots/iclr_surrogates_{output_suffix}_robustness.csv\",\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.6 64-bit ('uib': conda)",
   "language": "python",
   "name": "python37664bituibconda101b67e03a19488ca1455e98b804290f"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
