{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appended /home/XXX/PycharmProjects/entropy_distance_loss/src to paths\n",
      "Switched to directory /home/XXX/PycharmProjects/entropy_distance_loss\n",
      "%load_ext autoreload\n",
      "%autoreload 2\n",
      "cifar10_save_model.py                       \u001b[0m\u001b[01;34mresults_more_dropout\u001b[0m/\n",
      "cifar10_surrogates_batchsize.py             \u001b[01;34mresults_no_dropout\u001b[0m/\n",
      "cifar10_surrogates_more_dropout.py          \u001b[01;34mresults_no_dropout_batchsize_256\u001b[0m/\n",
      "cifar10_surrogates_no_dropout_batchsize.py  \u001b[01;34mresults_no_dropout_no_noise\u001b[0m/\n",
      "cifar10_surrogates_no_dropout_no_noise.py   \u001b[01;34mresults_no_noise\u001b[0m/\n",
      "cifar10_surrogates_no_dropout.py            \u001b[01;34mresults_save_model\u001b[0m/\n",
      "cifar10_surrogates_no_noise.py              \u001b[01;34mresults_strong_lr_scheduler\u001b[0m/\n",
      "cifar10_surrogates.py                       \u001b[01;34mresults_w_dropout_weaker_lr\u001b[0m/\n",
      "export_results_to_csv.ipynb                 visualization-gamma.ipynb\n",
      "\u001b[01;34mresults_batchsize_256\u001b[0m/                      visualization.ipynb\n",
      "\u001b[01;34mresults_dropout\u001b[0m/\n"
     ]
    }
   ],
   "source": [
    "import XXX.notebook\n",
    "\n",
    "import YYY\n",
    "\n",
    "from experiments.utils.jupyter import results_loader\n",
    "\n",
    "%ls {XXX.notebook.original_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "|                                 | Description                                  |\n",
    "|---------------------------------|----------------------------------------------|\n",
    "| `dropout`                       | Latest with dropout                          |\n",
    "| `w_dropout_weaker_lr`           | Dropout with a less strong LR scheduler      |\n",
    "| `no_dropout`                    | No dropout (but zero-entropy noise injected) |\n",
    "| `no_dropout_no_noise`           | No dropout and no noise injected             |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suffixes of the `results_<suffix>/` directories that can be exported.\n",
    "data_sources = [\n",
    "    \"dropout\",\n",
    "    #\"w_dropout_weaker_lr\",\n",
    "    \"more_dropout\",\n",
    "    \"no_noise\",\n",
    "    \"no_dropout\",\n",
    "    \"no_dropout_no_noise\",\n",
    "    \"no_dropout_batchsize_256\",\n",
    "]\n",
    "\n",
    "# Position in `data_sources` of the run to export.\n",
    "index = 1\n",
    "data_source = data_sources[index]\n",
    "\n",
    "# The same suffix selects the input results directory and names the output CSV.\n",
    "input_suffix = output_suffix = data_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored results from `results_<input_suffix>/` under the notebook's original directory.\n",
    "# NOTE(review): `store` is presumably a dict-like with one entry per result file -- confirm in `results_loader`.\n",
    "store = results_loader.load_YYY_files(f'{XXX.notebook.original_dir}/results_{input_suffix}/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries that were loaded into `store`.\n",
    "len(store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the results whose `timestamp` is at least 1583070930.\n",
    "# NOTE(review): if `timestamp` is Unix epoch seconds, the cutoff is 2020-03-01 13:55:30 UTC -- confirm the unit.\n",
    "# NOTE(review): `v=` is presumably the predicate `filter_dict` applies to the dict values -- confirm.\n",
    "filtered_store = results_loader.filter_dict(store, v=lambda result: result.timestamp >= 1583070930)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries left after the timestamp filter.\n",
    "len(filtered_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A snafu in the result storage: the estimate class is recovered from a stored value\n",
    "# (the first train-eval epoch's `continuous_decoder_uncertainty` of any result).\n",
    "Estimate = (\n",
    "    results_loader.get_any(filtered_store)\n",
    "    .log.train_eval_epochs[0]\n",
    "    .continuous_decoder_uncertainty.__class__\n",
    ")\n",
    "\n",
    "\n",
    "def get_value(value):\n",
    "    \"\"\"Unwrap an estimate: return `value.mean` for `Estimate` instances, anything else unchanged.\"\"\"\n",
    "    if not isinstance(value, Estimate):\n",
    "        return value\n",
    "    return value.mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_entropy_from_var(var, capacity=100):\n",
    "    \"\"\"Estimate an entropy (in nats) from a variance via the Gaussian entropy formula.\n",
    "\n",
    "    Computes ``log(var) / 2 + capacity * log(2 * pi * e) / 2``.\n",
    "\n",
    "    Args:\n",
    "        var: Variance; a scalar or anything ``np.log`` accepts (evaluated element-wise).\n",
    "        capacity: Multiplier of the constant ``log(2 * pi * e) / 2`` term\n",
    "            (presumably the latent dimensionality -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        The entropy estimate, with the same shape as ``var``.\n",
    "\n",
    "    NOTE(review): ``log(var)`` is not scaled by ``capacity``, so ``var`` is presumably a\n",
    "    total (determinant-like) variance rather than a per-dimension one -- confirm.\n",
    "    \"\"\"\n",
    "    # Imported here because no other cell of this notebook imports numpy;\n",
    "    # without it every call failed with a NameError on `np`.\n",
    "    import numpy as np\n",
    "\n",
    "    gaussian_const = 0.5 * np.log(2 * np.pi * np.e)\n",
    "    return 0.5 * np.log(var) + capacity * gaussian_const"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last epoch index (inclusive) up to which early-stopped runs are padded.\n",
    "PAD_TO_EPOCH = 100\n",
    "\n",
    "# Epoch field that holds the value of each regularizer. The key order is also the\n",
    "# order in which a gamma = 0 run is copied into the regularizers.\n",
    "REGULARIZER_VALUE_KEYS = {\n",
    "    'entropy_via_variance_Z': 'mb_H_Z',\n",
    "    'mean_squared_Z': 'mean_squared_Z',\n",
    "    'entropy_via_variance_Z__Y': 'mb_H_Z__Y',\n",
    "}\n",
    "\n",
    "# One flat record (dict) per result, regularizer, source and epoch.\n",
    "experiment_data = []\n",
    "\n",
    "for ri, result in enumerate(filtered_store.values()):\n",
    "    gamma = result.experiment.gamma\n",
    "\n",
    "    if gamma != 0:\n",
    "        dst_regularizers = [result.experiment.regularizer]\n",
    "    else:\n",
    "        # Hack to copy gamma = 0 into each regularizer.\n",
    "        dst_regularizers = list(REGULARIZER_VALUE_KEYS)\n",
    "\n",
    "    for dst_regularizer in dst_regularizers:\n",
    "        # Fail loudly on an unknown regularizer instead of silently reusing the\n",
    "        # `regularizer_value` left over from a previous record.\n",
    "        if dst_regularizer not in REGULARIZER_VALUE_KEYS:\n",
    "            raise ValueError(f\"Unknown regularizer {dst_regularizer!r} in experiment {ri}\")\n",
    "        value_key = REGULARIZER_VALUE_KEYS[dst_regularizer]\n",
    "\n",
    "        for source, epochs in ((\"train\", result.log.train_eval_epochs), (\"test\", result.log.test_epochs)):\n",
    "            last_record = None\n",
    "\n",
    "            for i, epoch in enumerate(epochs):\n",
    "                d = {\n",
    "                    'epoch': i,\n",
    "                    'regularizer': dst_regularizer,\n",
    "                    'experiment_index': ri,\n",
    "                    # It is \\gamma from the paper.\n",
    "                    'gamma': gamma,\n",
    "                    'source': source,\n",
    "                    'error': 1 - epoch.accuracy,\n",
    "                    'error_p': 1 - epoch.correct_prob,\n",
    "                }\n",
    "                # Copy every logged epoch field, unwrapping estimates to plain values.\n",
    "                d.update({key: get_value(value) for key, value in epoch._asdict().items()})\n",
    "                d['regularizer_value'] = d[value_key]\n",
    "\n",
    "                experiment_data.append(d)\n",
    "                last_record = d\n",
    "\n",
    "            # Impute missing epochs from early out (thresholding if acc is < 20% after 20 epochs or so)\n",
    "            # by repeating this source's last record. This is done per source so that train and\n",
    "            # test are both padded, each from its own last epoch; a source without any logged\n",
    "            # epoch has nothing to repeat and is skipped.\n",
    "            if last_record is not None:\n",
    "                for pad_epoch in range(i + 1, PAD_TO_EPOCH + 1):\n",
    "                    experiment_data.append({**last_record, 'epoch': pad_epoch})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the wide frame: one row per record collected in `experiment_data`.\n",
    "df = pd.DataFrame.from_records(experiment_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2.78255940e-01, 2.15443469e-02, 1.00000000e-03, 7.74263683e-02,\n",
       "       1.00000000e-06, 1.00000000e-05, 7.94328235e-02, 2.15443469e-04,\n",
       "       3.98107171e-01, 5.99484250e-03, 2.15443469e+00, 2.51188643e-05,\n",
       "       2.15443469e-02, 1.58489319e-02, 3.16227766e-03, 4.64158883e-04,\n",
       "       1.00000000e+00, 1.25892541e-04, 6.30957344e-04, 1.99526231e+00,\n",
       "       4.64158883e-01, 5.01187234e-06, 1.66810054e-03, 4.64158883e-03,\n",
       "       1.00000000e-01, 4.64158883e-05, 0.00000000e+00, 1.29154967e-04,\n",
       "       3.59381366e-05, 1.00000000e+01])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct gamma values present in `df`, in order of first appearance.\n",
    "df.gamma.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gamma</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>8154</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1510</th>\n",
       "      <td>0.000010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9362</th>\n",
       "      <td>0.000036</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9060</th>\n",
       "      <td>0.000129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4530</th>\n",
       "      <td>0.000464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6644</th>\n",
       "      <td>0.001668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2718</th>\n",
       "      <td>0.005995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>302</th>\n",
       "      <td>0.021544</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>906</th>\n",
       "      <td>0.077426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.278256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4832</th>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         gamma\n",
       "8154  0.000000\n",
       "1510  0.000010\n",
       "9362  0.000036\n",
       "9060  0.000129\n",
       "4530  0.000464\n",
       "6644  0.001668\n",
       "2718  0.005995\n",
       "302   0.021544\n",
       "906   0.077426\n",
       "0     0.278256\n",
       "4832  1.000000"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# List each distinct gamma of the `entropy_via_variance_Z` rows in ascending order,\n",
    "# with the row/column display limits lifted so that nothing gets truncated.\n",
    "with pd.option_context('display.max_rows', None, 'display.max_columns', None):\n",
    "    display(\n",
    "        df.loc[df.regularizer == \"entropy_via_variance_Z\", [\"gamma\"]]\n",
    "        .drop_duplicates()\n",
    "        .sort_values(\"gamma\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['epoch', 'regularizer', 'experiment_index', 'gamma', 'source', 'error',\n",
       "       'error_p', 'epoch_duration', 'accuracy', 'correct_prob', 'xe_decoder',\n",
       "       'xe_prediction', 'mean_squared_Z', 'mb_H_Z', 'mb_global_H_Z__X',\n",
       "       'mb_global_H_Z__Y', 'mb_H_mean_Z__X', 'mb_H_Z__X', 'mb_H_Z__Y', 'loss',\n",
       "       'continuous_iq_base', 'continuous_decoder_uncertainty',\n",
       "       'continuous_encoding_entropy', 'continuous_preserved_information',\n",
       "       'continuous_H_Z__X', 'continuous_H_Z__Y', 'continuous_H_Z',\n",
       "       'regularizer_value'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Columns of the wide frame, i.e. everything that can be chosen for the export.\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>experiment_index</th>\n",
       "      <th>regularizer</th>\n",
       "      <th>gamma</th>\n",
       "      <th>epoch</th>\n",
       "      <th>variable</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>entropy_via_variance_Z</td>\n",
       "      <td>0.278256</td>\n",
       "      <td>0</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.102120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>entropy_via_variance_Z</td>\n",
       "      <td>0.278256</td>\n",
       "      <td>1</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.097980</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>entropy_via_variance_Z</td>\n",
       "      <td>0.278256</td>\n",
       "      <td>2</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.101860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>entropy_via_variance_Z</td>\n",
       "      <td>0.278256</td>\n",
       "      <td>3</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.100420</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>entropy_via_variance_Z</td>\n",
       "      <td>0.278256</td>\n",
       "      <td>4</td>\n",
       "      <td>accuracy</td>\n",
       "      <td>0.102000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139519</th>\n",
       "      <td>test</td>\n",
       "      <td>30</td>\n",
       "      <td>entropy_via_variance_Z__Y</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>146</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>4.600474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139520</th>\n",
       "      <td>test</td>\n",
       "      <td>30</td>\n",
       "      <td>entropy_via_variance_Z__Y</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>147</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>5.689423</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139521</th>\n",
       "      <td>test</td>\n",
       "      <td>30</td>\n",
       "      <td>entropy_via_variance_Z__Y</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>148</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>4.629138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139522</th>\n",
       "      <td>test</td>\n",
       "      <td>30</td>\n",
       "      <td>entropy_via_variance_Z__Y</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>149</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>5.039477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139523</th>\n",
       "      <td>test</td>\n",
       "      <td>30</td>\n",
       "      <td>entropy_via_variance_Z__Y</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>150</td>\n",
       "      <td>mb_H_Z__Y</td>\n",
       "      <td>4.504184</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>139524 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       source  experiment_index                regularizer      gamma  epoch  \\\n",
       "0       train                 0     entropy_via_variance_Z   0.278256      0   \n",
       "1       train                 0     entropy_via_variance_Z   0.278256      1   \n",
       "2       train                 0     entropy_via_variance_Z   0.278256      2   \n",
       "3       train                 0     entropy_via_variance_Z   0.278256      3   \n",
       "4       train                 0     entropy_via_variance_Z   0.278256      4   \n",
       "...       ...               ...                        ...        ...    ...   \n",
       "139519   test                30  entropy_via_variance_Z__Y  10.000000    146   \n",
       "139520   test                30  entropy_via_variance_Z__Y  10.000000    147   \n",
       "139521   test                30  entropy_via_variance_Z__Y  10.000000    148   \n",
       "139522   test                30  entropy_via_variance_Z__Y  10.000000    149   \n",
       "139523   test                30  entropy_via_variance_Z__Y  10.000000    150   \n",
       "\n",
       "         variable     value  \n",
       "0        accuracy  0.102120  \n",
       "1        accuracy  0.097980  \n",
       "2        accuracy  0.101860  \n",
       "3        accuracy  0.100420  \n",
       "4        accuracy  0.102000  \n",
       "...           ...       ...  \n",
       "139519  mb_H_Z__Y  4.600474  \n",
       "139520  mb_H_Z__Y  5.689423  \n",
       "139521  mb_H_Z__Y  4.629138  \n",
       "139522  mb_H_Z__Y  5.039477  \n",
       "139523  mb_H_Z__Y  4.504184  \n",
       "\n",
       "[139524 rows x 7 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metrics to export; `melt` turns each of them into `variable`/`value` rows.\n",
    "fields = [\n",
    "    'accuracy', 'correct_prob', \"error\", \"error_p\",\n",
    "    'loss', 'xe_decoder', 'xe_prediction',\n",
    "    'regularizer_value', 'mean_squared_Z',\n",
    "    'continuous_decoder_uncertainty', 'continuous_H_Z', 'continuous_H_Z__Y',\n",
    "    'mb_H_Z', 'mb_H_Z__Y',\n",
    "]\n",
    "\n",
    "# Long format: one row per (source, experiment_index, regularizer, gamma, epoch) and metric.\n",
    "dfm = df.melt(\n",
    "    id_vars=['source', 'experiment_index', 'regularizer', 'gamma', 'epoch'],\n",
    "    value_vars=fields,\n",
    ")\n",
    "dfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Write the long-format frame to ./plots/surrogates_<output_suffix>.csv. The output\n",
    "# directory is created first, because `to_csv` fails when it does not exist.\n",
    "output_path = f\"./plots/surrogates_{output_suffix}.csv\"\n",
    "os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
    "dfm.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.6 64-bit ('uib': conda)",
   "language": "python",
   "name": "python37664bituibconda101b67e03a19488ca1455e98b804290f"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
