{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d42e27b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Standard library\n",
    "from pathlib import Path\n",
    "\n",
    "# Third-party\n",
    "import pandas as pd\n",
    "\n",
    "# Project code: %autoreload 2 picks up edits to it without restarting the kernel\n",
    "from umfavi.experiments.utils import load_experiment_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cf376e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep results directory, resolved against the kernel's working directory (normally this notebook's folder).\n",
    "# Edit this constant to analyse a different sweep.\n",
    "EXPERIMENT_DIR = Path(\"../experiments/sweep_lander_21012026_val-loss\")\n",
    "assert EXPERIMENT_DIR.is_dir(), f\"Sweep directory not found: {EXPERIMENT_DIR.resolve()}\"\n",
    "\n",
    "df_lander = load_experiment_data(EXPERIMENT_DIR)\n",
    "# Fail here rather than silently producing empty tables further down\n",
    "assert len(df_lander) > 0, f\"No runs loaded from {EXPERIMENT_DIR}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a2039a8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['experiment_id', 'config_hash', 'seed', 'status', 'worker_id',\n",
       "       'started_at', 'completed_at', 'error_message', 'best_model_path',\n",
       "       'best_epoch', 'wandb_run_id', 'final.regret', 'final.mean_rew',\n",
       "       'config.env_id', 'config.vis_every_n_epochs', 'config.subsample_factor',\n",
       "       'config.step_offset', 'config.obs_transform', 'config.act_transform',\n",
       "       'config.reward_domain', 'config.encoder_hidden_sizes',\n",
       "       'config.log_wandb', 'config.num_epochs', 'config.batch_size',\n",
       "       'config.gamma', 'config.retrain_verbose', 'config.retrain_pbar',\n",
       "       'config.log_every_n_steps', 'config.val_every_n_epochs',\n",
       "       'config.n_regret_samples', 'config.optimal_policy_path',\n",
       "       'config.use_importance_weights', 'config.lr',\n",
       "       'config.use_imitation_learning', 'config.td_error_weight',\n",
       "       'config.kl_weight', 'config.n_pref_samples', 'config.n_pref_episodes',\n",
       "       'config.pref_seg_len', 'config.pref_trajectory_rationality',\n",
       "       'config.pref_rationality', 'config.pref_policy_path',\n",
       "       'config.n_rating_samples', 'config.n_rating_episodes',\n",
       "       'config.rating_seg_len', 'config.rating_trajectory_rationality',\n",
       "       'config.rating_policy_path', 'config.n_demo_samples',\n",
       "       'config.min_reward_demo', 'epochs', 'num_epochs_completed',\n",
       "       'series.spearman_pref', 'series.spearman_rating', 'feedback_combo',\n",
       "       'config.demo_policy_path', 'series.spearman_demo'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Schema check: run metadata plus config.*, final.* and series.* columns\n",
    "df_lander.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "64d9dd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give imitation-learning runs their own feedback_combo label so they can be told apart.\n",
    "# `.eq(True)` maps missing flags (NaN, e.g. runs whose config lacks the key) to False;\n",
    "# masking .loc directly with a NaN-containing column would raise a ValueError instead.\n",
    "uses_imitation = df_lander[\"config.use_imitation_learning\"].eq(True)\n",
    "df_lander.loc[uses_imitation, \"feedback_combo\"] = \"imitation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "21527c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the reward-learning feedback combinations.\n",
    "# .copy() turns the filtered slice into an independent frame, so later in-place edits\n",
    "# (e.g. re-running the labelling cell above) don't raise SettingWithCopyWarning.\n",
    "df_lander = df_lander[df_lander[\"feedback_combo\"] != \"imitation\"].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "036e8cf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average final.mean_rew over all runs sharing the same (kl_weight, td_error_weight, feedback_combo),\n",
    "# e.g. the different seeds of one configuration\n",
    "group_keys = [\"config.kl_weight\", \"config.td_error_weight\", \"feedback_combo\"]\n",
    "df_grouped = df_lander.groupby(group_keys)[[\"final.mean_rew\"]].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a1e338e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>final.mean_rew</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>config.kl_weight</th>\n",
       "      <th>config.td_error_weight</th>\n",
       "      <th>feedback_combo</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"12\" valign=\"top\">0.5</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.5</th>\n",
       "      <th>demo+pref</th>\n",
       "      <td>-241.102244</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo+rating</th>\n",
       "      <td>-227.542529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo_only</th>\n",
       "      <td>-123.125147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref+rating</th>\n",
       "      <td>-528.716200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref_only</th>\n",
       "      <td>-632.496851</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rating_only</th>\n",
       "      <td>-333.433511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">1.0</th>\n",
       "      <th>demo+pref</th>\n",
       "      <td>-614.513789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo+rating</th>\n",
       "      <td>-624.046465</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo_only</th>\n",
       "      <td>-543.790767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref+rating</th>\n",
       "      <td>-476.261088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref_only</th>\n",
       "      <td>-484.395194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rating_only</th>\n",
       "      <td>-420.797296</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"12\" valign=\"top\">1.0</th>\n",
       "      <th rowspan=\"6\" valign=\"top\">0.5</th>\n",
       "      <th>demo+pref</th>\n",
       "      <td>-205.020079</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo+rating</th>\n",
       "      <td>-261.578355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo_only</th>\n",
       "      <td>-105.000588</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref+rating</th>\n",
       "      <td>-495.114421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref_only</th>\n",
       "      <td>-434.133365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rating_only</th>\n",
       "      <td>-324.491543</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">1.0</th>\n",
       "      <th>demo+pref</th>\n",
       "      <td>-173.789678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo+rating</th>\n",
       "      <td>-198.842467</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>demo_only</th>\n",
       "      <td>-232.989421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref+rating</th>\n",
       "      <td>-457.032101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pref_only</th>\n",
       "      <td>-514.172508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rating_only</th>\n",
       "      <td>-360.243092</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                        final.mean_rew\n",
       "config.kl_weight config.td_error_weight feedback_combo                \n",
       "0.5              0.5                    demo+pref          -241.102244\n",
       "                                        demo+rating        -227.542529\n",
       "                                        demo_only          -123.125147\n",
       "                                        pref+rating        -528.716200\n",
       "                                        pref_only          -632.496851\n",
       "                                        rating_only        -333.433511\n",
       "                 1.0                    demo+pref          -614.513789\n",
       "                                        demo+rating        -624.046465\n",
       "                                        demo_only          -543.790767\n",
       "                                        pref+rating        -476.261088\n",
       "                                        pref_only          -484.395194\n",
       "                                        rating_only        -420.797296\n",
       "1.0              0.5                    demo+pref          -205.020079\n",
       "                                        demo+rating        -261.578355\n",
       "                                        demo_only          -105.000588\n",
       "                                        pref+rating        -495.114421\n",
       "                                        pref_only          -434.133365\n",
       "                                        rating_only        -324.491543\n",
       "                 1.0                    demo+pref          -173.789678\n",
       "                                        demo+rating        -198.842467\n",
       "                                        demo_only          -232.989421\n",
       "                                        pref+rating        -457.032101\n",
       "                                        pref_only          -514.172508\n",
       "                                        rating_only        -360.243092"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Long format: one row per (kl_weight, td_error_weight, feedback_combo) group\n",
    "df_grouped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a2b9150d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_be406_row0_col0 {\n",
       "  background-color: #58c765;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row0_col1 {\n",
       "  background-color: #67cc5c;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row0_col2 {\n",
       "  background-color: #eae51a;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row0_col3 {\n",
       "  background-color: #414287;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row0_col4 {\n",
       "  background-color: #440154;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row0_col5 {\n",
       "  background-color: #1fa188;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col0 {\n",
       "  background-color: #470d60;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col1 {\n",
       "  background-color: #46075a;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col2 {\n",
       "  background-color: #443a83;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col3 {\n",
       "  background-color: #355e8d;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col4 {\n",
       "  background-color: #375a8c;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row1_col5 {\n",
       "  background-color: #2a788e;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row2_col0 {\n",
       "  background-color: #81d34d;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row2_col1 {\n",
       "  background-color: #46c06f;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row2_col2 {\n",
       "  background-color: #fde725;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row2_col3 {\n",
       "  background-color: #3a548c;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row2_col4 {\n",
       "  background-color: #2c728e;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row2_col5 {\n",
       "  background-color: #20a486;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row3_col0 {\n",
       "  background-color: #a8db34;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row3_col1 {\n",
       "  background-color: #89d548;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row3_col2 {\n",
       "  background-color: #60ca60;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_be406_row3_col3 {\n",
       "  background-color: #31688e;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row3_col4 {\n",
       "  background-color: #3e4a89;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_be406_row3_col5 {\n",
       "  background-color: #1f948c;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_be406\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_be406_level0_col0\" class=\"col_heading level0 col0\" colspan=\"6\">final.mean_rew</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"index_name level1\" >feedback_combo</th>\n",
       "      <th id=\"T_be406_level1_col0\" class=\"col_heading level1 col0\" >demo+pref</th>\n",
       "      <th id=\"T_be406_level1_col1\" class=\"col_heading level1 col1\" >demo+rating</th>\n",
       "      <th id=\"T_be406_level1_col2\" class=\"col_heading level1 col2\" >demo_only</th>\n",
       "      <th id=\"T_be406_level1_col3\" class=\"col_heading level1 col3\" >pref+rating</th>\n",
       "      <th id=\"T_be406_level1_col4\" class=\"col_heading level1 col4\" >pref_only</th>\n",
       "      <th id=\"T_be406_level1_col5\" class=\"col_heading level1 col5\" >rating_only</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >config.kl_weight</th>\n",
       "      <th class=\"index_name level1\" >config.td_error_weight</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "      <th class=\"blank col4\" >&nbsp;</th>\n",
       "      <th class=\"blank col5\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_be406_level0_row0\" class=\"row_heading level0 row0\" rowspan=\"2\">0.500000</th>\n",
       "      <th id=\"T_be406_level1_row0\" class=\"row_heading level1 row0\" >0.500000</th>\n",
       "      <td id=\"T_be406_row0_col0\" class=\"data row0 col0\" >-241.1022</td>\n",
       "      <td id=\"T_be406_row0_col1\" class=\"data row0 col1\" >-227.5425</td>\n",
       "      <td id=\"T_be406_row0_col2\" class=\"data row0 col2\" >-123.1251</td>\n",
       "      <td id=\"T_be406_row0_col3\" class=\"data row0 col3\" >-528.7162</td>\n",
       "      <td id=\"T_be406_row0_col4\" class=\"data row0 col4\" >-632.4969</td>\n",
       "      <td id=\"T_be406_row0_col5\" class=\"data row0 col5\" >-333.4335</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_be406_level1_row1\" class=\"row_heading level1 row1\" >1.000000</th>\n",
       "      <td id=\"T_be406_row1_col0\" class=\"data row1 col0\" >-614.5138</td>\n",
       "      <td id=\"T_be406_row1_col1\" class=\"data row1 col1\" >-624.0465</td>\n",
       "      <td id=\"T_be406_row1_col2\" class=\"data row1 col2\" >-543.7908</td>\n",
       "      <td id=\"T_be406_row1_col3\" class=\"data row1 col3\" >-476.2611</td>\n",
       "      <td id=\"T_be406_row1_col4\" class=\"data row1 col4\" >-484.3952</td>\n",
       "      <td id=\"T_be406_row1_col5\" class=\"data row1 col5\" >-420.7973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_be406_level0_row2\" class=\"row_heading level0 row2\" rowspan=\"2\">1.000000</th>\n",
       "      <th id=\"T_be406_level1_row2\" class=\"row_heading level1 row2\" >0.500000</th>\n",
       "      <td id=\"T_be406_row2_col0\" class=\"data row2 col0\" >-205.0201</td>\n",
       "      <td id=\"T_be406_row2_col1\" class=\"data row2 col1\" >-261.5784</td>\n",
       "      <td id=\"T_be406_row2_col2\" class=\"data row2 col2\" >-105.0006</td>\n",
       "      <td id=\"T_be406_row2_col3\" class=\"data row2 col3\" >-495.1144</td>\n",
       "      <td id=\"T_be406_row2_col4\" class=\"data row2 col4\" >-434.1334</td>\n",
       "      <td id=\"T_be406_row2_col5\" class=\"data row2 col5\" >-324.4915</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_be406_level1_row3\" class=\"row_heading level1 row3\" >1.000000</th>\n",
       "      <td id=\"T_be406_row3_col0\" class=\"data row3 col0\" >-173.7897</td>\n",
       "      <td id=\"T_be406_row3_col1\" class=\"data row3 col1\" >-198.8425</td>\n",
       "      <td id=\"T_be406_row3_col2\" class=\"data row3 col2\" >-232.9894</td>\n",
       "      <td id=\"T_be406_row3_col3\" class=\"data row3 col3\" >-457.0321</td>\n",
       "      <td id=\"T_be406_row3_col4\" class=\"data row3 col4\" >-514.1725</td>\n",
       "      <td id=\"T_be406_row3_col5\" class=\"data row3 col5\" >-360.2431</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x12ae49510>"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pivot feedback_combo into columns: one row per (kl_weight, td_error_weight) pair\n",
    "table = df_grouped.unstack(level=\"feedback_combo\")\n",
    "\n",
    "# axis=None puts every cell on one shared colour scale, so cells are comparable across rows and columns\n",
    "(\n",
    "    table.style\n",
    "    .background_gradient(cmap=\"viridis\", axis=None)\n",
    "    .format(\"{:.4f}\")\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "umfavi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
