{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appended /home/XXX/PycharmProjects/entropy_distance_loss/src to paths\n",
      "Switched to directory /home/XXX/PycharmProjects/entropy_distance_loss\n",
      "%load_ext autoreload\n",
      "%autoreload 2\n",
      "export_results_to_csv.ipynb  \u001b[0m\u001b[01;34mresults_no_zero_too_many_splits\u001b[0m/\n",
      "imagenette_surrogates.py     \u001b[01;34mresults_rebuttal\u001b[0m/\n",
      "\u001b[01;34mresults_combined_iq\u001b[0m/\n"
     ]
    }
   ],
   "source": [
    "# All imports for this notebook live in this cell.\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Importing this appears to extend the import path and change the working\n",
    "# directory (see the cell output).\n",
    "import XXX.notebook\n",
    "\n",
    "# NOTE(review): YYY is not referenced directly below; presumably it is needed so the\n",
    "# stored result objects can be loaded -- confirm before removing.\n",
    "import YYY\n",
    "\n",
    "from experiments.utils.jupyter import results_loader\n",
    "\n",
    "# List the results directories available for loading.\n",
    "%ls {XXX.notebook.original_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "|                                 | Description                                  |\n",
    "|---------------------------------|----------------------------------------------|\n",
    "| `dropout`                       | Latest with dropout                          |\n",
    "| `w_dropout_weaker_lr`           | Dropout with a less strong LR scheduler      |\n",
    "| `no_dropout`                    | No dropout (but zero-entropy noise injected) |\n",
    "| `no_dropout_no_noise`           | No dropout and no noise injected             |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which results directory (results_<name>/) to export; pick one via `index`.\n",
    "data_sources = [\"rebuttal\", \"no_zero_too_many_splits\"]\n",
    "index = 0\n",
    "data_source = data_sources[index]\n",
    "\n",
    "# Inputs are read from, and the exported CSV is named after, the same data source.\n",
    "input_suffix = output_suffix = data_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected results directory, then the combined-IQ results.\n",
    "store, combined_iq_store = (\n",
    "    results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/results_{suffix}/')\n",
    "    for suffix in (input_suffix, 'combined_iq')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(17, 17)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of runs loaded into each store.\n",
    "tuple(map(len, (store, combined_iq_store)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only runs recorded at or after this cutoff. Presumably a Unix timestamp in\n",
    "# seconds (1583070930 == 2020-03-01 13:55:30 UTC) -- confirm against result.timestamp.\n",
    "MIN_TIMESTAMP = 1583070930\n",
    "\n",
    "filtered_store = results_loader.filter_dict(store, v=lambda result: result.timestamp >= MIN_TIMESTAMP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "17"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(filtered_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Workaround for a snafu in the result storage: recover the Estimate class from a\n",
    "# stored value (the first train-eval epoch's continuous_decoder_uncertainty).\n",
    "Estimate = (\n",
    "    results_loader.get_any(filtered_store)\n",
    "    .log.train_eval_epochs[0]\n",
    "    .continuous_decoder_uncertainty\n",
    "    .__class__\n",
    ")\n",
    "\n",
    "\n",
    "def get_value(value):\n",
    "    \"\"\"Return `value.mean` when `value` is an Estimate, otherwise `value` unchanged.\"\"\"\n",
    "    if isinstance(value, Estimate):\n",
    "        return value.mean\n",
    "    return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def estimate_entropy_from_var(var, capacity=100):\n",
    "    \"\"\"Gaussian differential entropy (in nats) from a variance term.\n",
    "\n",
    "    Evaluates log(var)/2 + capacity/2 * log(2*pi*e), which is the entropy of a\n",
    "    `capacity`-dimensional Gaussian whose covariance determinant is `var`.\n",
    "    NOTE(review): this reading assumes `var` is the determinant (product of the\n",
    "    per-dimension variances), not a single per-dimension variance -- confirm.\n",
    "\n",
    "    Args:\n",
    "        var: variance term; a scalar or numpy array (applied elementwise).\n",
    "        capacity: presumably the latent dimensionality (default 100).\n",
    "\n",
    "    Returns:\n",
    "        The entropy estimate, with the same shape as `var`.\n",
    "    \"\"\"\n",
    "    return np.log(var) / 2 + capacity * 0.5 * np.log(2 * np.pi * np.e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulates one dict per (run, regularizer, source, epoch); filled by the next two cells.\n",
    "experiment_data = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appends to `experiment_data`; re-run the cell that creates it before re-running this one,\n",
    "# otherwise rows are duplicated.\n",
    "\n",
    "# Early-stopped runs are padded with their last logged row up to this epoch.\n",
    "LAST_EPOCH = 100\n",
    "\n",
    "for ri, result in enumerate(filtered_store.values()):\n",
    "    regularizer = result.experiment.regularizer\n",
    "    gamma = result.experiment.gamma\n",
    "\n",
    "    if gamma != 0:\n",
    "        dst_regularizers = [regularizer]\n",
    "    else:\n",
    "        # A gamma == 0 run is presumably unregularized and can serve as the baseline\n",
    "        # for every regularizer; it is currently only copied into 'mean_squared_Z'.\n",
    "        # Other candidates: 'entropy_via_variance_Z', 'entropy_via_variance_Z__Y'.\n",
    "        dst_regularizers = ['mean_squared_Z']\n",
    "\n",
    "    # Hack to copy gamma = 0 into each regularizer.\n",
    "    for regularizer in dst_regularizers:\n",
    "        for source, epochs in ((\"train\", result.log.train_eval_epochs), (\"test\", result.log.test_epochs)):\n",
    "            last_epoch = None\n",
    "            for i, epoch in enumerate(epochs):\n",
    "                d = {}\n",
    "                d['epoch'] = i\n",
    "                d['regularizer'] = regularizer\n",
    "                d['experiment_index'] = ri\n",
    "\n",
    "                # It is \\gamma from the paper.\n",
    "                d['gamma'] = gamma\n",
    "                d['source'] = source\n",
    "\n",
    "                d['error'] = 1 - epoch.accuracy\n",
    "                d['error_p'] = 1 - epoch.correct_prob\n",
    "\n",
    "                # Copy every logged field, unwrapping Estimate values to their mean.\n",
    "                d.update({key: get_value(value) for key, value in epoch._asdict().items()})\n",
    "\n",
    "                # Pick the logged metric that corresponds to this regularizer.\n",
    "                if regularizer == 'mean_squared_Z':\n",
    "                    regularizer_value = d['mean_squared_Z']\n",
    "                elif regularizer == 'entropy_via_variance_Z':\n",
    "                    regularizer_value = d['mb_H_Z']\n",
    "                elif regularizer == 'entropy_via_variance_Z__Y':\n",
    "                    regularizer_value = d['mb_H_Z__Y']\n",
    "                else:\n",
    "                    # Fail loudly instead of silently reusing the previous row's value.\n",
    "                    raise ValueError(f\"Unknown regularizer: {regularizer!r}\")\n",
    "                d['regularizer_value'] = regularizer_value\n",
    "\n",
    "                experiment_data.append(d)\n",
    "                last_epoch = i\n",
    "\n",
    "            if last_epoch is None:\n",
    "                # Nothing logged for this source: there is no row to impute from, and\n",
    "                # experiment_data[-1] would belong to a different source or run.\n",
    "                continue\n",
    "\n",
    "            # Impute missing epochs from early out (thresholding if acc is < 20% after 20 epochs or so)\n",
    "            # by repeating the last logged row for every epoch up to LAST_EPOCH.\n",
    "            i = last_epoch\n",
    "            while i < LAST_EPOCH:\n",
    "                i += 1\n",
    "                d = dict(experiment_data[-1])\n",
    "                d['epoch'] = i\n",
    "                experiment_data.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appends to `experiment_data`; re-run the cell that creates it before re-running this one,\n",
    "# otherwise rows are duplicated.\n",
    "\n",
    "# Early-stopped runs are padded with their last logged row up to this epoch.\n",
    "LAST_EPOCH = 100\n",
    "\n",
    "for ri, result in enumerate(combined_iq_store.values()):\n",
    "    regularizer = result.experiment.regularizer\n",
    "    gamma = result.experiment.gamma\n",
    "\n",
    "    if gamma != 0:\n",
    "        dst_regularizers = [regularizer]\n",
    "    else:\n",
    "        # A gamma == 0 run is presumably unregularized and can serve as the baseline\n",
    "        # for every regularizer; it is currently only copied into 'mean_squared_Z'.\n",
    "        # Other candidates: 'entropy_via_variance_Z', 'entropy_via_variance_Z__Y'.\n",
    "        dst_regularizers = ['mean_squared_Z']\n",
    "\n",
    "    # Hack to copy gamma = 0 into each regularizer.\n",
    "    for regularizer in dst_regularizers:\n",
    "        source = \"combined\"\n",
    "        last_epoch = None\n",
    "        for i, combined_eval in result.log.combined_evals.items():\n",
    "            epoch = combined_eval.evaluation\n",
    "\n",
    "            d = {}\n",
    "            d['epoch'] = i\n",
    "            d['regularizer'] = regularizer\n",
    "            # NOTE(review): ri enumerates combined_iq_store from 0, independently of the\n",
    "            # train/test cell, so equal experiment_index values across sources do not\n",
    "            # necessarily refer to the same run -- confirm before joining on it.\n",
    "            d['experiment_index'] = ri\n",
    "\n",
    "            # It is \\gamma from the paper.\n",
    "            d['gamma'] = gamma\n",
    "            d['source'] = source\n",
    "\n",
    "            d['error'] = 1 - epoch.accuracy\n",
    "            d['error_p'] = 1 - epoch.correct_prob\n",
    "\n",
    "            # Copy every logged field, unwrapping Estimate values to their mean.\n",
    "            d.update({key: get_value(value) for key, value in epoch._asdict().items()})\n",
    "\n",
    "            # Pick the logged metric that corresponds to this regularizer.\n",
    "            if regularizer == 'mean_squared_Z':\n",
    "                regularizer_value = d['mean_squared_Z']\n",
    "            elif regularizer == 'entropy_via_variance_Z':\n",
    "                regularizer_value = d['mb_H_Z']\n",
    "            elif regularizer == 'entropy_via_variance_Z__Y':\n",
    "                regularizer_value = d['mb_H_Z__Y']\n",
    "            else:\n",
    "                # Fail loudly instead of silently reusing the previous row's value.\n",
    "                raise ValueError(f\"Unknown regularizer: {regularizer!r}\")\n",
    "            d['regularizer_value'] = regularizer_value\n",
    "\n",
    "            experiment_data.append(d)\n",
    "            last_epoch = i\n",
    "\n",
    "        if last_epoch is None:\n",
    "            # No combined evaluations for this run: there is no row to impute from, and\n",
    "            # experiment_data[-1] would belong to a different run.\n",
    "            continue\n",
    "\n",
    "        # Impute missing epochs from early out (thresholding if acc is < 20% after 20 epochs or so)\n",
    "        # by repeating the last logged row for every epoch up to LAST_EPOCH.\n",
    "        # NOTE(review): combined evals look sparse (the stored output shows every 10th epoch),\n",
    "        # but this pads every single epoch -- confirm that is intended.\n",
    "        i = last_epoch\n",
    "        while i < LAST_EPOCH:\n",
    "            i += 1\n",
    "            d = dict(experiment_data[-1])\n",
    "            d['epoch'] = i\n",
    "            experiment_data.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pandas is used below to tabulate, reshape and export the collected records.\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per collected record; columns follow the record keys.\n",
    "df = pd.DataFrame(experiment_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.00000000e-06, 6.49381632e-02, 1.77827941e-01, 8.65964323e-03,\n",
       "       1.33352143e+00, 4.21696503e-04, 2.05352503e-05, 1.53992653e-04,\n",
       "       0.00000000e+00, 2.73841963e-06, 3.65174127e+00, 7.49894209e-06,\n",
       "       2.37137371e-02, 3.16227766e-03, 5.62341325e-05, 4.86967525e-01,\n",
       "       1.15478198e-03])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct gamma values across all rows, in order of first appearance.\n",
    "df['gamma'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gamma</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2416</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2718</th>\n",
       "      <td>0.000003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3322</th>\n",
       "      <td>0.000007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1812</th>\n",
       "      <td>0.000021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4228</th>\n",
       "      <td>0.000056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2114</th>\n",
       "      <td>0.000154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1510</th>\n",
       "      <td>0.000422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4832</th>\n",
       "      <td>0.001155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3926</th>\n",
       "      <td>0.003162</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>906</th>\n",
       "      <td>0.008660</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3624</th>\n",
       "      <td>0.023714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>302</th>\n",
       "      <td>0.064938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>604</th>\n",
       "      <td>0.177828</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4530</th>\n",
       "      <td>0.486968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1208</th>\n",
       "      <td>1.333521</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3020</th>\n",
       "      <td>3.651741</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         gamma\n",
       "2416  0.000000\n",
       "0     0.000001\n",
       "2718  0.000003\n",
       "3322  0.000007\n",
       "1812  0.000021\n",
       "4228  0.000056\n",
       "2114  0.000154\n",
       "1510  0.000422\n",
       "4832  0.001155\n",
       "3926  0.003162\n",
       "906   0.008660\n",
       "3624  0.023714\n",
       "302   0.064938\n",
       "604   0.177828\n",
       "4530  0.486968\n",
       "1208  1.333521\n",
       "3020  3.651741"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Every distinct gamma used with the mean_squared_Z regularizer, ascending, with all rows shown.\n",
    "with pd.option_context('display.max_rows', None, 'display.max_columns', None):\n",
    "    display(\n",
    "        df.loc[df['regularizer'] == 'mean_squared_Z', ['gamma']]\n",
    "        .drop_duplicates()\n",
    "        .sort_values('gamma')\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['epoch', 'regularizer', 'experiment_index', 'gamma', 'source', 'error',\n",
       "       'error_p', 'epoch_duration', 'accuracy', 'correct_prob', 'xe_decoder',\n",
       "       'xe_prediction', 'mean_squared_Z', 'mb_H_Z', 'mb_global_H_Z__X',\n",
       "       'mb_global_H_Z__Y', 'mb_H_mean_Z__X', 'mb_H_Z__X', 'mb_H_Z__Y', 'loss',\n",
       "       'continuous_iq_base', 'continuous_decoder_uncertainty',\n",
       "       'continuous_encoding_entropy', 'continuous_preserved_information',\n",
       "       'continuous_H_Z__X', 'continuous_H_Z__Y', 'continuous_H_Z',\n",
       "       'regularizer_value'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column names of the collected frame.\n",
    "df.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>experiment_index</th>\n",
       "      <th>regularizer</th>\n",
       "      <th>gamma</th>\n",
       "      <th>epoch</th>\n",
       "      <th>variable</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.097683</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>1</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.427083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>2</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.548895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>3</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.534864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>4</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.585247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74489</th>\n",
       "      <td>combined</td>\n",
       "      <td>16</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>60</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>338.677424</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74490</th>\n",
       "      <td>combined</td>\n",
       "      <td>16</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>70</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>351.236511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74491</th>\n",
       "      <td>combined</td>\n",
       "      <td>16</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>80</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>355.939681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74492</th>\n",
       "      <td>combined</td>\n",
       "      <td>16</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>90</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>356.062206</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74493</th>\n",
       "      <td>combined</td>\n",
       "      <td>16</td>\n",
       "      <td>mean_squared_Z</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>100</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>363.564311</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>74494 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         source  experiment_index     regularizer     gamma  epoch   variable  \\\n",
       "0         train                 0  mean_squared_Z  0.000001      0   accuracy   \n",
       "1         train                 0  mean_squared_Z  0.000001      1   accuracy   \n",
       "2         train                 0  mean_squared_Z  0.000001      2   accuracy   \n",
       "3         train                 0  mean_squared_Z  0.000001      3   accuracy   \n",
       "4         train                 0  mean_squared_Z  0.000001      4   accuracy   \n",
       "...         ...               ...             ...       ...    ...        ...   \n",
       "74489  combined                16  mean_squared_Z  0.000003     60  mb_H_Z__Y   \n",
       "74490  combined                16  mean_squared_Z  0.000003     70  mb_H_Z__Y   \n",
       "74491  combined                16  mean_squared_Z  0.000003     80  mb_H_Z__Y   \n",
       "74492  combined                16  mean_squared_Z  0.000003     90  mb_H_Z__Y   \n",
       "74493  combined                16  mean_squared_Z  0.000003    100  mb_H_Z__Y   \n",
       "\n",
       "            value  \n",
       "0        0.097683  \n",
       "1        0.427083  \n",
       "2        0.548895  \n",
       "3        0.534864  \n",
       "4        0.585247  \n",
       "...           ...  \n",
       "74489  338.677424  \n",
       "74490  351.236511  \n",
       "74491  355.939681  \n",
       "74492  356.062206  \n",
       "74493  363.564311  \n",
       "\n",
       "[74494 rows x 7 columns]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-epoch metrics to export, melted into long format (variable/value columns).\n",
    "fields = [\n",
    "    'accuracy', 'correct_prob', 'error', 'error_p',\n",
    "    'loss', 'xe_decoder', 'xe_prediction',\n",
    "    'regularizer_value', 'mean_squared_Z',\n",
    "    'continuous_decoder_uncertainty', 'continuous_H_Z', 'continuous_H_Z__Y',\n",
    "    'mb_H_Z', 'mb_H_Z__Y',\n",
    "]\n",
    "\n",
    "dfm = df.melt(\n",
    "    id_vars=['source', 'experiment_index', 'regularizer', 'gamma', 'epoch'],\n",
    "    value_vars=fields,\n",
    ")\n",
    "dfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# to_csv does not create missing directories, so make sure ./plots exists on a fresh checkout.\n",
    "Path(\"./plots\").mkdir(parents=True, exist_ok=True)\n",
    "dfm.to_csv(f\"./plots/imagenette_surrogates_{output_suffix}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.6 64-bit ('uib': conda)",
   "language": "python",
   "name": "python37664bituibconda101b67e03a19488ca1455e98b804290f"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
