{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6aeb2d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %load_ext autoreload\n",
    "# %autoreload 2\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from umfavi.experiments.utils import load_experiment_data, compute_aggregated_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d200e052",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory with the Acrobot experiment metrics (relative to the kernel's working directory)\n",
    "ACROBOT_METRICS_DIR = Path(\"../../acrobot-metrics\")\n",
    "df_acrobot = load_experiment_data(ACROBOT_METRICS_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "df9ceeec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge all per-environment dataframes (append further frames to env_frames as environments are added)\n",
    "env_frames = [df_acrobot]\n",
    "df_full = pd.concat(env_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f449b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean mask of runs whose imitation-learning flag equals True\n",
    "is_imitation = df_full[\"config.use_imitation_learning\"].eq(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f28699",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relabel imitation-learning runs so they form their own feedback_combo group.\n",
    "# Use .loc, not chained indexing: under pandas Copy-on-Write, df[col][mask] = v never updates df_full.\n",
    "df_full.loc[is_imitation, \"feedback_combo\"] = \"imitation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "622041cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['imitation', 'demo_only', 'rating_only', 'demo+rating',\n",
       "       'pref_only', 'demo+pref', 'pref+rating', 'demo+pref+rating'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct feedback combinations, in order of first appearance\n",
    "pd.unique(df_full[\"feedback_combo\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "391816e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask of grid environments: env ids containing the literal (case-sensitive) substring \"grid\".\n",
    "# na=False keeps the mask strictly boolean; missing env ids would otherwise give NaN and break ~is_grid.\n",
    "is_grid = df_full[\"config.env_id\"].str.contains(\"grid\", regex=False, na=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06c0b289",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first rows instead of rendering the whole merged frame\n",
    "df_full.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c5f72b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative filter, kept for reference; the next cell (df = df_full.copy()) is the active selection.\n",
    "# If re-enabled it would keep every imitation run plus a single sample budget per env family:\n",
    "#   grid envs (is_grid):   n_pref / n_rating in {0, 256}, n_demo in {0, 2}\n",
    "#   other envs (~is_grid): n_pref / n_rating in {0, 512}, n_demo in {0, 4}\n",
    "# NOTE(review): 0 presumably means that feedback type is absent from the run - confirm.\n",
    "# Apply filters\n",
    "# df = df_full[\n",
    "#     is_imitation | (\n",
    "#         # Grid environments\n",
    "#         (\n",
    "#             (is_grid & df_full[\"config.n_pref_samples\"].isin([0, 256])) &\n",
    "#             (is_grid & df_full[\"config.n_rating_samples\"].isin([0, 256])) &\n",
    "#             (is_grid & df_full[\"config.n_demo_samples\"].isin([0, 2]))\n",
    "#         ) | \n",
    "#         # Box2D environments  \n",
    "#         (\n",
    "#             (~is_grid & df_full[\"config.n_pref_samples\"].isin([0, 512])) &\n",
    "#             (~is_grid & df_full[\"config.n_rating_samples\"].isin([0, 512])) &\n",
    "#             (~is_grid & df_full[\"config.n_demo_samples\"].isin([0, 4]))\n",
    "#         )\n",
    "#     )\n",
    "# ].copy()\n",
    "#df = df_full.copy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82413c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Independent deep copy: later column additions on df leave df_full untouched\n",
    "df = df_full.copy(deep=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f373d3a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([32,  4,  8, 16,  0])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct demonstration budgets present in df, in order of first appearance\n",
    "pd.unique(df[\"config.n_demo_samples\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2df9b564",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the \"regret\" metric with \"min\" via compute_aggregated_metric (presumably one value per run)\n",
    "df = df.assign(min_regret=compute_aggregated_metric(df, \"regret\", \"min\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a00d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the \"mean_rew\" metric with \"max\" via compute_aggregated_metric (presumably one value per run)\n",
    "df = df.assign(max_return=compute_aggregated_metric(df, \"mean_rew\", \"max\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b5987a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct environment ids in df, in order of first appearance\n",
    "pd.unique(df[\"config.env_id\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea7e4f8b",
   "metadata": {},
   "source": [
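    "## Get best model paths for Acrobot-v1\n",
    "\n",
    "For each (environment, feedback combination), select the config with the highest mean `max_return` and collect the model paths of all runs in that config."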
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632bdbc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Mean max_return over all rows sharing (env_id, feedback_combo, config_hash)\n",
    "group_keys = ['config.env_id', 'feedback_combo']\n",
    "config_keys = [*group_keys, 'config_hash']\n",
    "config_means = df.groupby(config_keys)['max_return'].mean().reset_index(name='mean_return')\n",
    "\n",
    "# Step 2: Best config_hash (highest mean return) per (env_id, feedback_combo).\n",
    "# NaN means are dropped first: groupby idxmax over an all-NaN group is deprecated and raises in recent pandas.\n",
    "ranked_configs = config_means.dropna(subset=['mean_return'])\n",
    "best_configs = ranked_configs.loc[ranked_configs.groupby(group_keys)['mean_return'].idxmax()]\n",
    "\n",
    "# Step 3: Keep only the rows of df that belong to a best config\n",
    "filtered = df.merge(best_configs[config_keys], on=config_keys)\n",
    "\n",
    "# Step 4: For each (env_id, feedback_combo), collect the model paths and max returns of ALL rows\n",
    "# in its best config as parallel lists (ranking by return happens in generate_transfer_config_multi)\n",
    "result_all = filtered.groupby(group_keys).agg({\n",
    "    'best_model_path': list,\n",
    "    'max_return': list\n",
    "}).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb36ddd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_transfer_config_multi(result_df):\n",
    "    \"\"\"Generate grid.add_conditional code with multiple paths per feedback combo.\n",
    "\n",
    "    For every row of ``result_df`` this emits one\n",
    "    ``grid.add_conditional(\"fb_model_path\", [...], condition=...)`` statement whose\n",
    "    list holds that row's model paths ordered by return (highest first, NaN last),\n",
    "    each followed by a ``# return: <value>`` comment. The same ranking is also\n",
    "    printed as a human-readable summary.\n",
    "\n",
    "    Args:\n",
    "        result_df: DataFrame with columns ``feedback_combo``, ``best_model_path``\n",
    "            (list of paths) and ``max_return`` (list of returns, aligned\n",
    "            element-wise with ``best_model_path``).\n",
    "\n",
    "    Returns:\n",
    "        The generated code as one newline-joined string.\n",
    "    \"\"\"\n",
    "    import json  # local import: used to quote feedback_combo as a Python string literal\n",
    "\n",
    "    def _return_key(pair):\n",
    "        # NaN is unorderable and scrambles sorted(); rank NaN returns below everything else.\n",
    "        return float('-inf') if pd.isna(pair[1]) else pair[1]\n",
    "\n",
    "    lines = []\n",
    "    for _, row in result_df.iterrows():\n",
    "        fb_combo = row['feedback_combo']\n",
    "        # Highest return first; sorted() is stable, so equal returns keep their input order.\n",
    "        # Empty lists produce an empty path list instead of failing on unpacking.\n",
    "        ranked = sorted(zip(row['best_model_path'], row['max_return']), key=_return_key, reverse=True)\n",
    "\n",
    "        lines.append('grid.add_conditional(\"fb_model_path\", [')\n",
    "        for path, return_val in ranked:\n",
    "            # repr() gives a valid Python literal even if the path contains quotes or backslashes.\n",
    "            lines.append(f\"    {str(path)!r},  # return: {return_val}\")\n",
    "        lines.append(f'], condition=lambda c: c.get(\"feedback_combo\") == {json.dumps(fb_combo)})')\n",
    "        lines.append('')\n",
    "\n",
    "        print(fb_combo + \":\")\n",
    "        for path, return_val in ranked:\n",
    "            print(f\"  Path: {path}, Return: {return_val}\")\n",
    "\n",
    "    return '\\n'.join(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8cbe6f8",
   "metadata": {},
   "source": [
    "### Acrobot-v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "041daecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print() renders the generated code with real newlines so it can be copied as-is\n",
    "# (a bare expression would display the string's repr with escaped \\n).\n",
    "transfer_config_acrobot = generate_transfer_config_multi(result_all[result_all[\"config.env_id\"] == \"Acrobot-v1\"])\n",
    "print(transfer_config_acrobot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5904a940",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
