{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Use**: Identifies the top-k hyper-parameter configurations with the highest Auto-Labeling coverage (Coverage-Mean), keeping only configurations whose mean Auto-Labeling error (Auto-Labeling-Err-Mean) is within the error threshold eps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import itertools\n",
    "sys.path.append('../')\n",
    "sys.path.append('../../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options\n",
    "root_pfx = \"../../../outputs/\"\n",
    "file_identifier = \"mnist_lenet_calib_eval_hyp_2_std_xent_tbal_eval_hyp/tbal_train_time_search_mnist_lenet_calib_eval_hyp_2_std_xent_tbal_eval_hyp__01-25-2024__16-23-53\"\n",
    "# Auto-labeling error threshold: eps = 0.05 as a fraction, expressed in percent\n",
    "# because it is compared against the Auto-Labeling-Err-Mean column below.\n",
    "eps_percent = 0.05 * 100\n",
    "\n",
    "# Load the first sheet of the results workbook and drop the header-less leading\n",
    "# column that pandas names 'Unnamed: 0' (presumably a saved index - verify).\n",
    "df = pd.read_excel(f\"{root_pfx}{file_identifier}.xlsx\", sheet_name=0).drop(columns='Unnamed: 0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
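    "Filter records: keep only configurations whose mean Auto-Labeling error is within eps, then sort them by Coverage-Mean (highest first)."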
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retain all records with Auto-Labeling-Err-Mean <= eps\n",
    "# (eps_percent is eps expressed in percent, the same scale as Auto-Labeling-Err-Mean)\n",
    "df_1 = df.query(\"`Auto-Labeling-Err-Mean` <= @eps_percent\")\n",
    "\n",
    "# Sort by col: Coverage-Mean in descending order, and then by col: calib_conf in ascending order.\n",
    "# This runs before the NaN fill below, so NaN calib_conf values sort last among ties\n",
    "# (pandas default na_position='last').\n",
    "df_1 = df_1.sort_values([\"Coverage-Mean\", \"calib_conf\"], ascending=[False, True])\n",
    "\n",
    "# Change all calib_conf with NaN to \"None\" so they can be matched by string equality later.\n",
    "# Fill from df_1 itself (not the raw df) so correctness does not depend on pandas\n",
    "# aligning the two frames by index; .assign returns a new frame instead of mutating a derived one.\n",
    "df_1 = df_1.assign(calib_conf=df_1[\"calib_conf\"].fillna(\"None\"))\n",
    "\n",
    "# Sanity checks\n",
    "assert df_1.query(\"`Auto-Labeling-Err-Mean` > @eps_percent\").empty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>calib_conf</th>\n",
       "      <th>training_conf</th>\n",
       "      <th>C_1</th>\n",
       "      <th>N_t</th>\n",
       "      <th>N_v</th>\n",
       "      <th>N_hyp_v</th>\n",
       "      <th>Auto-Labeling-Err-Mean</th>\n",
       "      <th>Coverage-Mean</th>\n",
       "      <th>Avg-ECE-Val-Mean</th>\n",
       "      <th>Auto-Labeling-Err-Std</th>\n",
       "      <th>...</th>\n",
       "      <th>max_epochs</th>\n",
       "      <th>method</th>\n",
       "      <th>momentum</th>\n",
       "      <th>num_runs</th>\n",
       "      <th>optimizer</th>\n",
       "      <th>query_batch_frac</th>\n",
       "      <th>rank_target</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>seed_frac</th>\n",
       "      <th>weight_decay</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.4312</td>\n",
       "      <td>56.44</td>\n",
       "      <td>6.0722</td>\n",
       "      <td>2.1461</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.5818</td>\n",
       "      <td>56.40</td>\n",
       "      <td>6.0660</td>\n",
       "      <td>3.0198</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.3675</td>\n",
       "      <td>55.76</td>\n",
       "      <td>4.4521</td>\n",
       "      <td>1.7586</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.2372</td>\n",
       "      <td>55.64</td>\n",
       "      <td>4.1498</td>\n",
       "      <td>1.6806</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.0315</td>\n",
       "      <td>54.56</td>\n",
       "      <td>3.8795</td>\n",
       "      <td>1.0393</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.8115</td>\n",
       "      <td>54.08</td>\n",
       "      <td>8.3615</td>\n",
       "      <td>1.6278</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.0654</td>\n",
       "      <td>54.04</td>\n",
       "      <td>5.0408</td>\n",
       "      <td>1.9636</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.5651</td>\n",
       "      <td>53.68</td>\n",
       "      <td>3.1165</td>\n",
       "      <td>1.3038</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.5188</td>\n",
       "      <td>53.28</td>\n",
       "      <td>5.7080</td>\n",
       "      <td>1.6277</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.5960</td>\n",
       "      <td>50.96</td>\n",
       "      <td>3.5022</td>\n",
       "      <td>2.1684</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.1968</td>\n",
       "      <td>48.20</td>\n",
       "      <td>8.4083</td>\n",
       "      <td>1.8236</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.6767</td>\n",
       "      <td>45.76</td>\n",
       "      <td>9.1505</td>\n",
       "      <td>1.0801</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.9116</td>\n",
       "      <td>32.52</td>\n",
       "      <td>15.3833</td>\n",
       "      <td>3.0198</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.3681</td>\n",
       "      <td>32.48</td>\n",
       "      <td>16.0310</td>\n",
       "      <td>2.2533</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.4268</td>\n",
       "      <td>17.84</td>\n",
       "      <td>10.7908</td>\n",
       "      <td>3.5767</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>3.2746</td>\n",
       "      <td>15.44</td>\n",
       "      <td>14.1374</td>\n",
       "      <td>1.1967</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.5591</td>\n",
       "      <td>15.32</td>\n",
       "      <td>14.4885</td>\n",
       "      <td>0.6448</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.3182</td>\n",
       "      <td>15.24</td>\n",
       "      <td>14.1636</td>\n",
       "      <td>0.8821</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.5801</td>\n",
       "      <td>15.24</td>\n",
       "      <td>13.0013</td>\n",
       "      <td>1.1101</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.9014</td>\n",
       "      <td>15.20</td>\n",
       "      <td>13.5733</td>\n",
       "      <td>0.4239</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.0842</td>\n",
       "      <td>14.76</td>\n",
       "      <td>13.0589</td>\n",
       "      <td>0.8703</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.7233</td>\n",
       "      <td>10.84</td>\n",
       "      <td>15.1670</td>\n",
       "      <td>2.8228</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.9</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.1370</td>\n",
       "      <td>10.68</td>\n",
       "      <td>15.0018</td>\n",
       "      <td>2.5850</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>2.3367</td>\n",
       "      <td>10.44</td>\n",
       "      <td>15.2674</td>\n",
       "      <td>1.9881</td>\n",
       "      <td>...</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>24 rows × 25 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   calib_conf training_conf   C_1  N_t  N_v  N_hyp_v  Auto-Labeling-Err-Mean  \\\n",
       "13       None           crl  0.25  500  500      500                  4.4312   \n",
       "22       None           crl  0.25  500  500      500                  4.5818   \n",
       "25       None           crl  0.25  500  500      500                  4.3675   \n",
       "3        None           crl  0.25  500  500      500                  4.2372   \n",
       "12       None           crl  0.25  500  500      500                  4.0315   \n",
       "15       None           crl  0.25  500  500      500                  4.8115   \n",
       "21       None           crl  0.25  500  500      500                  4.0654   \n",
       "7        None           crl  0.25  500  500      500                  3.5651   \n",
       "4        None           crl  0.25  500  500      500                  3.5188   \n",
       "16       None           crl  0.25  500  500      500                  3.5960   \n",
       "6        None           crl  0.25  500  500      500                  4.1968   \n",
       "24       None           crl  0.25  500  500      500                  3.6767   \n",
       "14       None           crl  0.25  500  500      500                  4.9116   \n",
       "5        None           crl  0.25  500  500      500                  4.3681   \n",
       "8        None           crl  0.25  500  500      500                  3.4268   \n",
       "18       None           crl  0.25  500  500      500                  3.2746   \n",
       "9        None           crl  0.25  500  500      500                  2.5591   \n",
       "0        None           crl  0.25  500  500      500                  2.3182   \n",
       "10       None           crl  0.25  500  500      500                  2.5801   \n",
       "19       None           crl  0.25  500  500      500                  2.9014   \n",
       "1        None           crl  0.25  500  500      500                  2.0842   \n",
       "20       None           crl  0.25  500  500      500                  2.7233   \n",
       "11       None           crl  0.25  500  500      500                  2.1370   \n",
       "2        None           crl  0.25  500  500      500                  2.3367   \n",
       "\n",
       "    Coverage-Mean  Avg-ECE-Val-Mean  Auto-Labeling-Err-Std  ...  max_epochs  \\\n",
       "13          56.44            6.0722                 2.1461  ...          50   \n",
       "22          56.40            6.0660                 3.0198  ...          50   \n",
       "25          55.76            4.4521                 1.7586  ...          50   \n",
       "3           55.64            4.1498                 1.6806  ...          50   \n",
       "12          54.56            3.8795                 1.0393  ...          50   \n",
       "15          54.08            8.3615                 1.6278  ...          50   \n",
       "21          54.04            5.0408                 1.9636  ...          50   \n",
       "7           53.68            3.1165                 1.3038  ...          50   \n",
       "4           53.28            5.7080                 1.6277  ...          50   \n",
       "16          50.96            3.5022                 2.1684  ...          50   \n",
       "6           48.20            8.4083                 1.8236  ...          50   \n",
       "24          45.76            9.1505                 1.0801  ...          50   \n",
       "14          32.52           15.3833                 3.0198  ...          50   \n",
       "5           32.48           16.0310                 2.2533  ...          50   \n",
       "8           17.84           10.7908                 3.5767  ...          50   \n",
       "18          15.44           14.1374                 1.1967  ...          50   \n",
       "9           15.32           14.4885                 0.6448  ...          50   \n",
       "0           15.24           14.1636                 0.8821  ...          50   \n",
       "10          15.24           13.0013                 1.1101  ...          50   \n",
       "19          15.20           13.5733                 0.4239  ...          50   \n",
       "1           14.76           13.0589                 0.8703  ...          50   \n",
       "20          10.84           15.1670                 2.8228  ...          50   \n",
       "11          10.68           15.0018                 2.5850  ...          50   \n",
       "2           10.44           15.2674                 1.9881  ...          50   \n",
       "\n",
       "    method  momentum  num_runs  optimizer  query_batch_frac rank_target  \\\n",
       "13    tbal       0.9         5        sgd              0.04     softmax   \n",
       "22    tbal       0.9         5        sgd              0.04     softmax   \n",
       "25    tbal       0.9         5        sgd              0.04     softmax   \n",
       "3     tbal       0.9         5        sgd              0.04     softmax   \n",
       "12    tbal       0.9         5        sgd              0.04     softmax   \n",
       "15    tbal       0.9         5        sgd              0.04     softmax   \n",
       "21    tbal       0.9         5        sgd              0.04     softmax   \n",
       "7     tbal       0.9         5        sgd              0.04     softmax   \n",
       "4     tbal       0.9         5        sgd              0.04     softmax   \n",
       "16    tbal       0.9         5        sgd              0.04     softmax   \n",
       "6     tbal       0.9         5        sgd              0.04     softmax   \n",
       "24    tbal       0.9         5        sgd              0.04     softmax   \n",
       "14    tbal       0.9         5        sgd              0.04     softmax   \n",
       "5     tbal       0.9         5        sgd              0.04     softmax   \n",
       "8     tbal       0.9         5        sgd              0.04     softmax   \n",
       "18    tbal       0.9         5        sgd              0.04     softmax   \n",
       "9     tbal       0.9         5        sgd              0.04     softmax   \n",
       "0     tbal       0.9         5        sgd              0.04     softmax   \n",
       "10    tbal       0.9         5        sgd              0.04     softmax   \n",
       "19    tbal       0.9         5        sgd              0.04     softmax   \n",
       "1     tbal       0.9         5        sgd              0.04     softmax   \n",
       "20    tbal       0.9         5        sgd              0.04     softmax   \n",
       "11    tbal       0.9         5        sgd              0.04     softmax   \n",
       "2     tbal       0.9         5        sgd              0.04     softmax   \n",
       "\n",
       "    rank_weight  seed_frac weight_decay  \n",
       "13          0.8        0.4        0.010  \n",
       "22          0.9        0.4        0.010  \n",
       "25          0.9        0.4        0.010  \n",
       "3           0.7        0.4        0.001  \n",
       "12          0.8        0.4        0.001  \n",
       "15          0.8        0.4        0.001  \n",
       "21          0.9        0.4        0.001  \n",
       "7           0.7        0.4        0.010  \n",
       "4           0.7        0.4        0.010  \n",
       "16          0.8        0.4        0.010  \n",
       "6           0.7        0.4        0.001  \n",
       "24          0.9        0.4        0.001  \n",
       "14          0.8        0.4        0.100  \n",
       "5           0.7        0.4        0.100  \n",
       "8           0.7        0.4        0.100  \n",
       "18          0.9        0.4        0.001  \n",
       "9           0.8        0.4        0.001  \n",
       "0           0.7        0.4        0.001  \n",
       "10          0.8        0.4        0.010  \n",
       "19          0.9        0.4        0.010  \n",
       "1           0.7        0.4        0.010  \n",
       "20          0.9        0.4        0.100  \n",
       "11          0.8        0.4        0.100  \n",
       "2           0.7        0.4        0.100  \n",
       "\n",
       "[24 rows x 25 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the configurations that passed the eps filter, sorted by Coverage-Mean (descending)\n",
    "df_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set display option to show all columns\n",
    "pd.set_option('display.max_columns', None)\n",
    "\n",
    "train_time_methods = ['std_cross_entropy', 'crl', 'fmfp', 'squentropy']\n",
    "\n",
    "# Find best training-time hyper-parameters\n",
    "def print_top_k_train_time_methods(methods, methods_col_name, top_k, source_df=None, calib_conf='None'):\n",
    "    \"\"\"Display the top-k configurations by Coverage-Mean for each method.\n",
    "\n",
    "    Args:\n",
    "        methods: Method names to look up in column ``methods_col_name``.\n",
    "        methods_col_name: Name of the column holding the method name (e.g. 'training_conf').\n",
    "        top_k: Rows to show per method. For top_k == 1, every row tied at the maximum\n",
    "            Coverage-Mean is shown; otherwise the first top_k rows of ``source_df`` are\n",
    "            shown, so ``source_df`` must already be sorted by Coverage-Mean (descending).\n",
    "        source_df: Frame to search. Defaults to the global ``df_1`` built above\n",
    "            (eps-filtered and sorted), preserving the original call behavior.\n",
    "        calib_conf: Only rows with this calib_conf value are considered. Defaults to\n",
    "            'None', the placeholder that missing calib_conf values were filled with.\n",
    "    \"\"\"\n",
    "    if source_df is None:\n",
    "        source_df = df_1\n",
    "    for method in methods:\n",
    "        # Rows for this method, restricted to the requested calibration setting\n",
    "        df_method = source_df[\n",
    "            (source_df[methods_col_name] == method) & (source_df['calib_conf'] == calib_conf)\n",
    "        ]\n",
    "        if top_k == 1:\n",
    "            # Retain ties for top-1 (an empty selection gives max() = NaN, hence an empty frame)\n",
    "            best_coverage = df_method['Coverage-Mean'].max()\n",
    "            df_top = df_method[df_method['Coverage-Mean'] == best_coverage]\n",
    "        else:\n",
    "            df_top = df_method.head(top_k)\n",
    "        display(df_top)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>calib_conf</th>\n",
       "      <th>training_conf</th>\n",
       "      <th>C_1</th>\n",
       "      <th>N_t</th>\n",
       "      <th>N_v</th>\n",
       "      <th>N_hyp_v</th>\n",
       "      <th>Auto-Labeling-Err-Mean</th>\n",
       "      <th>Coverage-Mean</th>\n",
       "      <th>Avg-ECE-Val-Mean</th>\n",
       "      <th>Auto-Labeling-Err-Std</th>\n",
       "      <th>Coverage-Std</th>\n",
       "      <th>Avg-ECE-Val-Std</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>eps</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>max_epochs</th>\n",
       "      <th>method</th>\n",
       "      <th>momentum</th>\n",
       "      <th>num_runs</th>\n",
       "      <th>optimizer</th>\n",
       "      <th>query_batch_frac</th>\n",
       "      <th>rank_target</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>seed_frac</th>\n",
       "      <th>weight_decay</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: [calib_conf, training_conf, C_1, N_t, N_v, N_hyp_v, Auto-Labeling-Err-Mean, Coverage-Mean, Avg-ECE-Val-Mean, Auto-Labeling-Err-Std, Coverage-Std, Avg-ECE-Val-Std, batch_size, eps, learning_rate, max_epochs, method, momentum, num_runs, optimizer, query_batch_frac, rank_target, rank_weight, seed_frac, weight_decay]\n",
       "Index: []"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>calib_conf</th>\n",
       "      <th>training_conf</th>\n",
       "      <th>C_1</th>\n",
       "      <th>N_t</th>\n",
       "      <th>N_v</th>\n",
       "      <th>N_hyp_v</th>\n",
       "      <th>Auto-Labeling-Err-Mean</th>\n",
       "      <th>Coverage-Mean</th>\n",
       "      <th>Avg-ECE-Val-Mean</th>\n",
       "      <th>Auto-Labeling-Err-Std</th>\n",
       "      <th>Coverage-Std</th>\n",
       "      <th>Avg-ECE-Val-Std</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>eps</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>max_epochs</th>\n",
       "      <th>method</th>\n",
       "      <th>momentum</th>\n",
       "      <th>num_runs</th>\n",
       "      <th>optimizer</th>\n",
       "      <th>query_batch_frac</th>\n",
       "      <th>rank_target</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>seed_frac</th>\n",
       "      <th>weight_decay</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>None</td>\n",
       "      <td>crl</td>\n",
       "      <td>0.25</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>500</td>\n",
       "      <td>4.4312</td>\n",
       "      <td>56.44</td>\n",
       "      <td>6.0722</td>\n",
       "      <td>2.1461</td>\n",
       "      <td>5.518</td>\n",
       "      <td>0.3404</td>\n",
       "      <td>32</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.01</td>\n",
       "      <td>50</td>\n",
       "      <td>tbal</td>\n",
       "      <td>0.9</td>\n",
       "      <td>5</td>\n",
       "      <td>sgd</td>\n",
       "      <td>0.04</td>\n",
       "      <td>softmax</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   calib_conf training_conf   C_1  N_t  N_v  N_hyp_v  Auto-Labeling-Err-Mean  \\\n",
       "13       None           crl  0.25  500  500      500                  4.4312   \n",
       "\n",
       "    Coverage-Mean  Avg-ECE-Val-Mean  Auto-Labeling-Err-Std  Coverage-Std  \\\n",
       "13          56.44            6.0722                 2.1461         5.518   \n",
       "\n",
       "    Avg-ECE-Val-Std  batch_size   eps  learning_rate  max_epochs method  \\\n",
       "13           0.3404          32  0.05           0.01          50   tbal   \n",
       "\n",
       "    momentum  num_runs optimizer  query_batch_frac rank_target  rank_weight  \\\n",
       "13       0.9         5       sgd              0.04     softmax          0.8   \n",
       "\n",
       "    seed_frac  weight_decay  \n",
       "13        0.4          0.01  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>calib_conf</th>\n",
       "      <th>training_conf</th>\n",
       "      <th>C_1</th>\n",
       "      <th>N_t</th>\n",
       "      <th>N_v</th>\n",
       "      <th>N_hyp_v</th>\n",
       "      <th>Auto-Labeling-Err-Mean</th>\n",
       "      <th>Coverage-Mean</th>\n",
       "      <th>Avg-ECE-Val-Mean</th>\n",
       "      <th>Auto-Labeling-Err-Std</th>\n",
       "      <th>Coverage-Std</th>\n",
       "      <th>Avg-ECE-Val-Std</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>eps</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>max_epochs</th>\n",
       "      <th>method</th>\n",
       "      <th>momentum</th>\n",
       "      <th>num_runs</th>\n",
       "      <th>optimizer</th>\n",
       "      <th>query_batch_frac</th>\n",
       "      <th>rank_target</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>seed_frac</th>\n",
       "      <th>weight_decay</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: [calib_conf, training_conf, C_1, N_t, N_v, N_hyp_v, Auto-Labeling-Err-Mean, Coverage-Mean, Avg-ECE-Val-Mean, Auto-Labeling-Err-Std, Coverage-Std, Avg-ECE-Val-Std, batch_size, eps, learning_rate, max_epochs, method, momentum, num_runs, optimizer, query_batch_frac, rank_target, rank_weight, seed_frac, weight_decay]\n",
       "Index: []"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>calib_conf</th>\n",
       "      <th>training_conf</th>\n",
       "      <th>C_1</th>\n",
       "      <th>N_t</th>\n",
       "      <th>N_v</th>\n",
       "      <th>N_hyp_v</th>\n",
       "      <th>Auto-Labeling-Err-Mean</th>\n",
       "      <th>Coverage-Mean</th>\n",
       "      <th>Avg-ECE-Val-Mean</th>\n",
       "      <th>Auto-Labeling-Err-Std</th>\n",
       "      <th>Coverage-Std</th>\n",
       "      <th>Avg-ECE-Val-Std</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>eps</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>max_epochs</th>\n",
       "      <th>method</th>\n",
       "      <th>momentum</th>\n",
       "      <th>num_runs</th>\n",
       "      <th>optimizer</th>\n",
       "      <th>query_batch_frac</th>\n",
       "      <th>rank_target</th>\n",
       "      <th>rank_weight</th>\n",
       "      <th>seed_frac</th>\n",
       "      <th>weight_decay</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: [calib_conf, training_conf, C_1, N_t, N_v, N_hyp_v, Auto-Labeling-Err-Mean, Coverage-Mean, Avg-ECE-Val-Mean, Auto-Labeling-Err-Std, Coverage-Std, Avg-ECE-Val-Std, batch_size, eps, learning_rate, max_epochs, method, momentum, num_runs, optimizer, query_batch_frac, rank_target, rank_weight, seed_frac, weight_decay]\n",
       "Index: []"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Number of top configurations to show per method\n",
    "# (with top_k = 1, every configuration tied at the best Coverage-Mean is shown)\n",
    "top_k = 1\n",
    "\n",
    "print_top_k_train_time_methods(train_time_methods, \"training_conf\", top_k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-kernel-tbal",
   "language": "python",
   "name": "ml-kernel-tbal"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
