{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "69cdc639",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from umfavi.experiments.utils import load_experiment_data, compute_aggregated_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "a841fa0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_grid_cliff = load_experiment_data(Path(\"../experiments/sweep_grid_cliff_25012026/\"))\n",
    "df_grid_sparse = load_experiment_data(Path(\"../experiments/sweep_grid_sparse_25012026/\"))\n",
    "df_grid_trap = load_experiment_data(Path(\"../experiments/sweep_grid_trap_25012026/\"))\n",
    "df_acrobot = load_experiment_data(Path(\"../experiments/sweep_acrobot_23012026/\"))\n",
    "df_cartpole = load_experiment_data(Path(\"../experiments/sweep_cartpole_23012026/\"))\n",
    "df_lander = load_experiment_data(Path(\"../experiments/sweep_lander_23012026/\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "f4cf828c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the three grid-world sweeps into one frame\n",
    "df_grids = pd.concat([df_grid_cliff, df_grid_sparse, df_grid_trap])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "7ceee837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the acrobot, cartpole and lander sweeps into one frame\n",
    "df_box2d = pd.concat([df_acrobot, df_cartpole, df_lander])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "024a1879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every non-imitation run; for imitation-learning runs keep only those with 2 demonstrations.\n",
    "# NOTE(review): the lander section further down selects 4 demonstrations instead -- confirm the intended count per env.\n",
    "df_box2d = df_box2d[(df_box2d[\"config.use_imitation_learning\"] == False) | (df_box2d[\"config.n_demo_samples\"] == 2)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "ec1c87eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([df_grids, df_box2d])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "56ebd2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give imitation-learning runs their own feedback_combo label\n",
    "df.loc[df[\"config.use_imitation_learning\"], \"feedback_combo\"] = \"imitation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb36ddd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_transfer_config_multi(result_df, metric: str):\n",
    "    \"\"\"Generate grid.add_conditional code listing every run path per feedback combo.\n",
    "\n",
    "    Args:\n",
    "        result_df: DataFrame with one row per feedback combo. Each row needs\n",
    "            'best_model_path' (a list of paths, or a single path), 'feedback_combo',\n",
    "            and the `metric` column (a list of values aligned with the paths, or a\n",
    "            single value).\n",
    "        metric: Name of the metric column whose values are echoed as comments.\n",
    "\n",
    "    Returns:\n",
    "        The generated code as one newline-joined string (one call per row).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a row's path list and metric list differ in length.\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    for _, row in result_df.iterrows():\n",
    "        paths = row['best_model_path']\n",
    "        metric_vals = row[metric]\n",
    "        fb_combo = row['feedback_combo']\n",
    "        # Wrap scalars: zipping a bare path string would iterate over its characters.\n",
    "        if not isinstance(paths, (list, tuple)):\n",
    "            paths, metric_vals = [paths], [metric_vals]\n",
    "\n",
    "        lines.append('grid.add_conditional(\"fb_model_path\", [')\n",
    "        # strict=True: paths and metric values come from the same runs, so their lengths must match.\n",
    "        for path, metric_val in zip(paths, metric_vals, strict=True):\n",
    "            # repr() emits a valid Python literal even if the path contains quotes or backslashes.\n",
    "            lines.append(f\"    {str(path)!r},  # {metric}: {metric_val}\")\n",
    "        lines.append(f'], condition=lambda c: c.get(\"feedback_combo\") == \"{fb_combo}\")')\n",
    "        lines.append('')\n",
    "\n",
    "    return '\\n'.join(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "497369f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_transfer_config_single(result_df, metric: str):\n",
    "    \"\"\"Generate grid.add_conditional code with only the best path per feedback combo.\n",
    "\n",
    "    Args:\n",
    "        result_df: DataFrame with one row per feedback combo; each row needs a single\n",
    "            'best_model_path', 'feedback_combo', and the `metric` column.\n",
    "        metric: Name of the metric column whose value is echoed as a comment.\n",
    "\n",
    "    Returns:\n",
    "        The generated code as one newline-joined string (one call per row).\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    for _, row in result_df.iterrows():\n",
    "        path = row['best_model_path']\n",
    "        metric_val = row[metric]\n",
    "        fb_combo = row['feedback_combo']\n",
    "\n",
    "        # repr() emits a valid Python literal even if the path contains quotes or backslashes.\n",
    "        lines.append(f'grid.add_conditional(\"fb_model_path\", [{str(path)!r}],  # {metric}: {metric_val}')\n",
    "        lines.append(f'    condition=lambda c: c.get(\"feedback_combo\") == \"{fb_combo}\")')\n",
    "        lines.append('')\n",
    "\n",
    "    return '\\n'.join(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8cbe6f8",
   "metadata": {},
   "source": [
    "### LunarLander-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "485137d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_lander = load_experiment_data(Path(\"../experiments/sweep_lander_23012026/\"))\n",
    "df_lander.loc[df_lander[\"config.use_imitation_learning\"], \"feedback_combo\"] = \"imitation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "57a7ddc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imitation baseline: keep only the 4-demonstration runs; all other feedback combos are kept.\n",
    "df_lander = df_lander[~(df_lander[\"feedback_combo\"] == \"imitation\") | (df_lander[\"config.n_demo_samples\"] == 4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "82035dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Mean final reward per config_hash (env_id and feedback_combo carried along)\n",
    "config_means_lander = df_lander.groupby([\"config_hash\"]).agg(\n",
    "        {\n",
    "            \"final.mean_rew\": \"mean\",\n",
    "            \"feedback_combo\": \"first\",\n",
    "            \"config.env_id\": \"first\"\n",
    "        }\n",
    ")\n",
    "config_means_lander = config_means_lander[~config_means_lander[\"final.mean_rew\"].isna()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "259204ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: Find the best config_hash (highest mean reward) per (env_id, feedback_combo)\n",
    "best_configs_lander = config_means_lander.loc[config_means_lander.groupby(['config.env_id', 'feedback_combo'])['final.mean_rew'].idxmax()].reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "bc1d5975",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 3: Keep only the runs that belong to the selected configs.\n",
    "# Merge against df_lander, the frame the configs were selected from: `df` only holds the\n",
    "# 2-demonstration imitation runs, so the selected 4-demonstration imitation config would be dropped.\n",
    "filtered_lander = df_lander.merge(best_configs_lander[['config.env_id', 'feedback_combo', 'config_hash']], on=['config.env_id', 'feedback_combo', 'config_hash'])\n",
    "assert set(filtered_lander['feedback_combo']) == set(best_configs_lander['feedback_combo']), \"some selected configs have no runs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf6a92d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 4: Pick the single best run (highest final reward) within each selected config\n",
    "result_lander = filtered_lander.loc[filtered_lander.groupby(['config.env_id', 'feedback_combo'])['final.mean_rew'].idxmax()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f5d8edf",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_lander"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "041daecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(generate_transfer_config_single(result_lander, metric=\"final.mean_rew\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0015af59",
   "metadata": {},
   "source": [
    "### grid_cliff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "3f6d842b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imitation baseline: keep only the 1-demonstration runs; all other feedback combos are kept.\n",
    "# Assigned to a new name so the merged grid frame `df_grids` is not overwritten.\n",
    "df_grid_runs = df[~(df[\"feedback_combo\"] == \"imitation\") | (df[\"config.n_demo_samples\"] == 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "3931bc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Mean final discounted value per config_hash (env_id and feedback_combo carried along)\n",
    "config_means_grid = df_grid_runs.groupby([\"config_hash\"]).agg(\n",
    "        {\n",
    "            \"final.discounted_value\": \"mean\",\n",
    "            \"feedback_combo\": \"first\",\n",
    "            \"config.env_id\": \"first\"\n",
    "        }\n",
    ")\n",
    "config_means_grid = config_means_grid[~config_means_grid[\"final.discounted_value\"].isna()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "223ba688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: Find the best config_hash (highest mean discounted value) per (env_id, feedback_combo)\n",
    "best_configs = config_means_grid.loc[config_means_grid.groupby(['config.env_id', 'feedback_combo'])['final.discounted_value'].idxmax()].reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "4c30d3de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 3: Keep only the runs that belong to the selected configs (same frame the configs were selected from)\n",
    "filtered = df_grid_runs.merge(best_configs[['config.env_id', 'feedback_combo', 'config_hash']], on=['config.env_id', 'feedback_combo', 'config_hash'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "c6ca1517",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 4: Collect every run of each selected config as aligned lists of model paths and discounted values\n",
    "result_all = filtered.groupby(['config.env_id', 'feedback_combo']).agg({\n",
    "    'best_model_path': list,\n",
    "    'final.discounted_value': list\n",
    "}).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3071b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(generate_transfer_config_multi(result_all[result_all[\"config.env_id\"] == \"grid_cliff\"], metric=\"final.discounted_value\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef6baa62",
   "metadata": {},
   "source": [
    "### grid_trap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fdb8fc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(generate_transfer_config_multi(result_all[result_all[\"config.env_id\"] == \"grid_trap\"], metric=\"final.discounted_value\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfcd9e07",
   "metadata": {},
   "source": [
    "### Acrobot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "5e988727",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_acrobot.loc[df_acrobot[\"config.use_imitation_learning\"], \"feedback_combo\"] = \"imitation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d3f565",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Mean final reward per config_hash (env_id and feedback_combo carried along)\n",
    "config_means_acrobot = df_acrobot.groupby([\"config_hash\"]).agg(\n",
    "        {\n",
    "            \"final.mean_rew\": \"mean\",\n",
    "            \"feedback_combo\": \"first\",\n",
    "            \"config.env_id\": \"first\"\n",
    "        }\n",
    ")\n",
    "# Drop configs without a reward so idxmax below never sees an all-NaN group\n",
    "config_means_acrobot = config_means_acrobot[~config_means_acrobot[\"final.mean_rew\"].isna()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "8c1d4c35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['rating_only', 'demo+pref+rating+stop', 'pref+rating', 'demo+stop',\n",
       "       'rating+stop', 'demo+pref', 'pref+stop', 'demo+rating',\n",
       "       'demo_only', 'pref_only', 'stop_only', 'imitation'], dtype=object)"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config_means_acrobot[\"feedback_combo\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "14ead500",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: Find the best config_hash (highest mean reward) per (env_id, feedback_combo)\n",
    "best_configs_acrobot = config_means_acrobot.loc[config_means_acrobot.groupby(['config.env_id', 'feedback_combo'])['final.mean_rew'].idxmax()].reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "054a1b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_configs_acrobot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "25847317",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 3: Keep only the runs that belong to the selected configs.\n",
    "# Merge against df_acrobot, the frame the configs were selected from: `df` only holds the\n",
    "# 2-demonstration imitation runs, so a selected imitation config with another demo count would be dropped.\n",
    "# NOTE(review): unlike the lander section, imitation runs are not restricted to one demo count here -- confirm.\n",
    "filtered_acrobot = df_acrobot.merge(best_configs_acrobot[['config.env_id', 'feedback_combo', 'config_hash']], on=['config.env_id', 'feedback_combo', 'config_hash'])\n",
    "assert set(filtered_acrobot['feedback_combo']) == set(best_configs_acrobot['feedback_combo']), \"some selected configs have no runs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6322fb3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_acrobot[\"feedback_combo\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "f6e1b11e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 4: Pick the single best run (highest final reward) within each selected config\n",
    "result_acrobot = filtered_acrobot.loc[filtered_acrobot.groupby(['config.env_id', 'feedback_combo'])['final.mean_rew'].idxmax()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71329bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(generate_transfer_config_single(result_acrobot, metric=\"final.mean_rew\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "211c887d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "umfavi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
