{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8553b165",
   "metadata": {},
   "source": [
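    "# Figuring out the best stop-feedback parameters\n",
    "\n",
    "Compares the mean per-run minimum regret across the `stop_c`, `stop_regret_discount` and `stop_regret_percentile` settings of the 18-01-2026 grid sweeps (cliff, sparse, trap), averaged over all other conditions. Lower is better."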
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ecfd207",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Reload edited modules (e.g. umfavi) automatically before each cell executes\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from umfavi.experiments.utils import compute_aggregated_metric, load_experiment_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6408560b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All three sweeps share one parent directory and run date; change these to analyse another sweep.\n",
    "EXPERIMENTS_DIR = Path(\"../experiments\")\n",
    "SWEEP_DATE = \"18-01-2026\"\n",
    "\n",
    "df_grid_cliff = load_experiment_data(EXPERIMENTS_DIR / f\"sweep_grid_cliff_{SWEEP_DATE}\")\n",
    "df_grid_sparse = load_experiment_data(EXPERIMENTS_DIR / f\"sweep_grid_sparse_{SWEEP_DATE}\")\n",
    "df_grid_trap = load_experiment_data(EXPERIMENTS_DIR / f\"sweep_grid_trap_{SWEEP_DATE}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "44cdd7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge all sweeps into one frame. ignore_index=True gives every row a unique label;\n",
    "# otherwise each sweep's own index labels (typically 0..n-1) would repeat, and later\n",
    "# index-aligned column assignments could silently match rows from different environments.\n",
    "df_full = pd.concat([df_grid_cliff, df_grid_sparse, df_grid_trap], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a5d4682f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['experiment_id', 'config_hash', 'seed', 'status', 'worker_id',\n",
       "       'started_at', 'completed_at', 'error_message', 'best_model_path',\n",
       "       'wandb_run_id', 'config.env_id', 'config.grid_size',\n",
       "       'config.vis_every_n_epochs', 'config.subsample_factor',\n",
       "       'config.step_offset', 'config.obs_transform', 'config.act_transform',\n",
       "       'config.reward_domain', 'config.kl_weight', 'config.td_error_weight',\n",
       "       'config.encoder_hidden_sizes', 'config.log_wandb', 'config.num_epochs',\n",
       "       'config.lr', 'config.batch_size', 'config.gamma',\n",
       "       'config.early_stop_metric', 'config.retrain_verbose',\n",
       "       'config.retrain_pbar', 'config.log_every_n_steps',\n",
       "       'config.val_every_n_epochs', 'config.early_stop_patience',\n",
       "       'config.use_importance_weights', 'config.n_pref_samples',\n",
       "       'config.n_pref_episodes', 'config.pref_seg_len',\n",
       "       'config.pref_trajectory_rationality', 'config.pref_rationality',\n",
       "       'config.n_rating_samples', 'config.n_rating_episodes',\n",
       "       'config.rating_seg_len', 'config.rating_trajectory_rationality',\n",
       "       'config.n_demo_samples', 'config.demo_rationality',\n",
       "       'config.n_stop_samples', 'config.n_stop_episodes',\n",
       "       'config.stop_seg_len', 'config.stop_regret_percentile',\n",
       "       'config.stop_regret_discount', 'config.stop_c',\n",
       "       'config.stop_trajectory_rationality', 'epochs', 'num_epochs_completed',\n",
       "       'series.spearman_demo', 'series.regret', 'series.spearman_pref',\n",
       "       'series.spearman_stop', 'series.spearman_rating',\n",
       "       'series.epic_distance', 'feedback_combo'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the available fields: run metadata, config.* hyperparameters and series.* logged metrics\n",
    "df_full.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0e4170c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run minimum of the regret series (presumably over `series.regret`; see compute_aggregated_metric).\n",
    "# Compute it on df_full itself: a Series built from a single sweep is aligned on index labels,\n",
    "# so rows from the other sweeps would receive that sweep's values (or NaN) instead of their own.\n",
    "df_full[\"min_regret\"] = compute_aggregated_metric(df_full, \"regret\", \"min\")\n",
    "# Kept so df_grid_cliff still carries its own min_regret column, as before.\n",
    "df_grid_cliff[\"min_regret\"] = compute_aggregated_metric(df_grid_cliff, \"regret\", \"min\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "525f03b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_fed61_row0_col0 {\n",
       "  background-color: #21918c;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row0_col1 {\n",
       "  background-color: #2c718e;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row0_col2 {\n",
       "  background-color: #404688;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row1_col0 {\n",
       "  background-color: #481467;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row1_col1, #T_fed61_row1_col2 {\n",
       "  background-color: #440154;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row2_col0 {\n",
       "  background-color: #c2df23;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row2_col1 {\n",
       "  background-color: #73d056;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row2_col2 {\n",
       "  background-color: #25ac82;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row3_col0 {\n",
       "  background-color: #27ad81;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row3_col1 {\n",
       "  background-color: #3c4f8a;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row3_col2 {\n",
       "  background-color: #482979;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_fed61_row4_col0 {\n",
       "  background-color: #fde725;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row4_col1, #T_fed61_row5_col0 {\n",
       "  background-color: #efe51c;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row4_col2 {\n",
       "  background-color: #d2e21b;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row5_col1 {\n",
       "  background-color: #aadc32;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_fed61_row5_col2 {\n",
       "  background-color: #60ca60;\n",
       "  color: #000000;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_fed61\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"index_name level0\" >config.stop_regret_percentile</th>\n",
       "      <th id=\"T_fed61_level0_col0\" class=\"col_heading level0 col0\" >50.000000</th>\n",
       "      <th id=\"T_fed61_level0_col1\" class=\"col_heading level0 col1\" >75.000000</th>\n",
       "      <th id=\"T_fed61_level0_col2\" class=\"col_heading level0 col2\" >90.000000</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >config.stop_c</th>\n",
       "      <th class=\"index_name level1\" >config.stop_regret_discount</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level0_row0\" class=\"row_heading level0 row0\" rowspan=\"2\">0.500000</th>\n",
       "      <th id=\"T_fed61_level1_row0\" class=\"row_heading level1 row0\" >0.100000</th>\n",
       "      <td id=\"T_fed61_row0_col0\" class=\"data row0 col0\" >6.2448</td>\n",
       "      <td id=\"T_fed61_row0_col1\" class=\"data row0 col1\" >7.3858</td>\n",
       "      <td id=\"T_fed61_row0_col2\" class=\"data row0 col2\" >8.8727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level1_row1\" class=\"row_heading level1 row1\" >0.500000</th>\n",
       "      <td id=\"T_fed61_row1_col0\" class=\"data row1 col0\" >10.2568</td>\n",
       "      <td id=\"T_fed61_row1_col1\" class=\"data row1 col1\" >10.7216</td>\n",
       "      <td id=\"T_fed61_row1_col2\" class=\"data row1 col2\" >10.7126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level0_row2\" class=\"row_heading level0 row2\" rowspan=\"2\">1.000000</th>\n",
       "      <th id=\"T_fed61_level1_row2\" class=\"row_heading level1 row2\" >0.100000</th>\n",
       "      <td id=\"T_fed61_row2_col0\" class=\"data row2 col0\" >2.6085</td>\n",
       "      <td id=\"T_fed61_row2_col1\" class=\"data row2 col1\" >3.7070</td>\n",
       "      <td id=\"T_fed61_row2_col2\" class=\"data row2 col2\" >5.2235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level1_row3\" class=\"row_heading level1 row3\" >0.500000</th>\n",
       "      <td id=\"T_fed61_row3_col0\" class=\"data row3 col0\" >5.1777</td>\n",
       "      <td id=\"T_fed61_row3_col1\" class=\"data row3 col1\" >8.5733</td>\n",
       "      <td id=\"T_fed61_row3_col2\" class=\"data row3 col2\" >9.7032</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level0_row4\" class=\"row_heading level0 row4\" rowspan=\"2\">2.000000</th>\n",
       "      <th id=\"T_fed61_level1_row4\" class=\"row_heading level1 row4\" >0.100000</th>\n",
       "      <td id=\"T_fed61_row4_col0\" class=\"data row4 col0\" >1.8010</td>\n",
       "      <td id=\"T_fed61_row4_col1\" class=\"data row4 col1\" >2.0390</td>\n",
       "      <td id=\"T_fed61_row4_col2\" class=\"data row4 col2\" >2.4020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_fed61_level1_row5\" class=\"row_heading level1 row5\" >0.500000</th>\n",
       "      <td id=\"T_fed61_row5_col0\" class=\"data row5 col0\" >2.0327</td>\n",
       "      <td id=\"T_fed61_row5_col1\" class=\"data row5 col1\" >2.9196</td>\n",
       "      <td id=\"T_fed61_row5_col2\" class=\"data row5 col2\" >3.9928</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x35d5f5490>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean per-run minimum regret for each (stop_c, stop_regret_discount, stop_regret_percentile)\n",
    "# combination, pooled over every other condition in the sweeps (environment, seed, feedback combo, ...).\n",
    "agg_df = (\n",
    "    df_full\n",
    "    .groupby([\"config.stop_c\", \"config.stop_regret_discount\", \"config.stop_regret_percentile\"])[\"min_regret\"]\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "# Percentiles become the columns; (stop_c, discount) remain a two-level row index.\n",
    "table = agg_df.unstack(\"config.stop_regret_percentile\")\n",
    "\n",
    "# A single colour scale across the whole table (axis=None); reversed viridis makes the lowest regret brightest.\n",
    "table.style.background_gradient(cmap=\"viridis_r\", axis=None).format(\"{:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10615a90",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "umfavi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
