{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "077cc867",
   "metadata": {},
   "source": [
    "# Model Addition - Figures\n",
    "\n",
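    "Experiment showing how BALROG handles models added dynamically during execution. In the plots, the dashed vertical lines mark the time steps at which models 4 and 5 are added ($T/3$ and $2T/3$)."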
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a0c8dd7",
   "metadata": {},
   "source": [
    "## OtB Comparison (MS-COCO and Carrot-bowl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d790b3ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- Target datasets ---\n",
    "# Horizon T and number of independent runs per dataset; both values\n",
    "# appear in the results file names loaded by the figure code.\n",
    "datasets = {\n",
    "    \"ms-coco\": dict(T=5000, num_runs=10),\n",
    "    \"carrot-bowl\": dict(T=2000, num_runs=10),\n",
    "}\n",
    "# Display names for the per-panel dataset captions.\n",
    "pretty_names = dict([(\"ms-coco\", \"MS-COCO\"), (\"carrot-bowl\", \"Carrot-bowl\")])\n",
    "\n",
    "# --- Global font sizes ---\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"font.size\": 26,\n",
    "        \"axes.titlesize\": 26,\n",
    "        \"axes.labelsize\": 24,\n",
    "        \"xtick.labelsize\": 16,\n",
    "        \"ytick.labelsize\": 22,\n",
    "        \"legend.fontsize\": 22,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Line style per algorithm label: BALROG is solid and thickest, Optimal is\n",
    "# dash-dot, and the remaining baselines are dashed.\n",
    "styles = {\n",
    "    \"Optimal\": dict(linestyle=\"-.\", color=\"green\", linewidth=2.0),\n",
    "    \"Always\": dict(linestyle=\"--\", color=\"blue\", linewidth=1.8),\n",
    "    \"Random\": dict(linestyle=\"--\", color=\"orange\", linewidth=1.8),\n",
    "    \"PAK-UCB\": dict(linestyle=\"--\", color=\"red\", linewidth=1.8),\n",
    "    \"BALROG\": dict(linestyle=\"-\", color=\"indigo\", linewidth=2.4),\n",
    "    \"KNN-UCB\": dict(linestyle=\"--\", color=\"magenta\", linewidth=1.8),\n",
    "    \"LinUCB\": dict(linestyle=\"--\", color=\"gray\", linewidth=1.8),\n",
    "    \"neuronal-s\": dict(linestyle=\"--\", color=\"cyan\", linewidth=1.8),\n",
    "}\n",
    "\n",
    "def annotate_model_addition(ax, T, x_pos, label, y_frac=0.75):\n",
    "    \"\"\"Mark the time step at which a new model is added.\n",
    "\n",
    "    Draws a dashed black vertical line at ``x_pos`` and a rotated crimson\n",
    "    label just to its right.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib Axes to draw on.\n",
    "        T: Horizon of the run; the label is shifted right by ``0.02 * T``.\n",
    "        x_pos: x data coordinate (time step) of the vertical line.\n",
    "        label: Text drawn next to the line.\n",
    "        y_frac: Vertical position of the label as a fraction of the axes\n",
    "            height (0 = bottom, 1 = top). Defaults to 0.75.\n",
    "    \"\"\"\n",
    "    ax.axvline(x=x_pos, color='black', linestyle='--', linewidth=2.0, alpha=0.8)\n",
    "    # Blended transform: x in data coordinates, y in axes fraction. Unlike a\n",
    "    # data-space y derived from get_ylim(), this keeps the label at y_frac of\n",
    "    # the panel height even if the y-limits change after this call (e.g. when\n",
    "    # another subplot sharing the y-axis is plotted later).\n",
    "    ax.text(\n",
    "        x_pos + 0.02 * T, y_frac, label,\n",
    "        transform=ax.get_xaxis_transform(),\n",
    "        rotation=90, color='crimson', fontsize=18, fontweight='bold',\n",
    "        ha='left', va='center', clip_on=False\n",
    "    )\n",
    "\n",
    "# --- Figure ---\n",
    "# Input pickles and output directory, relative to this notebook.\n",
    "RESULTS_DIR = \"../experiments/results/model_addition/data\"\n",
    "PLOTS_DIR = \"plots/model_addition\"\n",
    "\n",
    "\n",
    "def _sliding_average_points(series, w, n_points=100):\n",
    "    \"\"\"Return (x, y) for a ``w``-step sliding average of ``series``, subsampled.\n",
    "\n",
    "    With ``mode=\"valid\"`` entry k averages ``series[k:k + w]`` and is plotted\n",
    "    at x = k + w (the last step of the window, counting from 1). ``n_points``\n",
    "    evenly spaced entries are kept. Indices are computed per series, so curves\n",
    "    of different lengths are each sampled within their own bounds.\n",
    "    \"\"\"\n",
    "    mov = np.convolve(series, np.ones(w) / w, mode=\"valid\")\n",
    "    idx = np.linspace(0, len(mov) - 1, n_points, dtype=int)\n",
    "    return np.arange(w, w + len(mov))[idx], mov[idx]\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharey=True)\n",
    "# Legend entries keyed by label: algorithms that only appear in a later panel\n",
    "# are still listed, and shared ones are listed once.\n",
    "legend_entries = {}\n",
    "\n",
    "for i, (dataset, params) in enumerate(datasets.items()):\n",
    "    T, num_runs = params[\"T\"], params[\"num_runs\"]\n",
    "    # Smoothing window: 10% of the horizon, at least one step.\n",
    "    w = max(1, T // 10)\n",
    "    file_path = f\"{RESULTS_DIR}/raw_data_model_addition_{dataset}_{T}_{num_runs}runs.pkl\"\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "    with open(file_path, \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "    all_OtB = data[\"all_o2b\"]\n",
    "\n",
    "    # Mean OtB curve over runs per algorithm (np.stack requires all runs of an\n",
    "    # algorithm to have the same length); algorithms with no runs are skipped.\n",
    "    avg_OtB = {a: np.mean(np.stack(runs), axis=0) for a, runs in all_OtB.items() if len(runs) > 0}\n",
    "\n",
    "    ax = axes[i]\n",
    "    for a, series in avg_OtB.items():\n",
    "        if len(series) < w:\n",
    "            continue  # shorter than one smoothing window\n",
    "        x_vals, y_vals = _sliding_average_points(series, w)\n",
    "        ln, = ax.plot(x_vals, y_vals, label=a, **styles.get(a, {}))\n",
    "        legend_entries.setdefault(a, ln)\n",
    "\n",
    "    ax.set_title(f\"{w}-Sliding Avg OtB\")\n",
    "    ax.annotate(pretty_names.get(dataset, dataset),\n",
    "                xy=(0.5, -0.18), xycoords='axes fraction',\n",
    "                ha='center', va='top', fontsize=34)\n",
    "    ax.grid(True)\n",
    "\n",
    "    # Model-addition markers at T/3 and 2T/3. NOTE(review): these must match the\n",
    "    # schedule used by the experiment that produced the pickle - confirm there.\n",
    "    annotate_model_addition(ax, T, T // 3,  \"+ model 4\")\n",
    "    annotate_model_addition(ax, T, 2 * T // 3, \"+ model 5\")\n",
    "\n",
    "# Shared y-axis label\n",
    "fig.supylabel(\"OtB\", x=0.02, y=0.53, fontsize=24)\n",
    "\n",
    "# Margins: the enlarged bottom margin holds the dataset captions and the legend.\n",
    "fig.subplots_adjust(left=0.09, right=0.99, bottom=0.22, top=0.88, wspace=0.08)\n",
    "\n",
    "# One legend for both panels, in the bottom margin.\n",
    "legend_labels = list(legend_entries)\n",
    "legend_handles = list(legend_entries.values())\n",
    "fig.legend(\n",
    "    legend_handles, legend_labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 0.19),\n",
    "    ncol=max(1, len(legend_labels)),\n",
    "    frameon=False,\n",
    "    handlelength=2.4, columnspacing=1.0, handletextpad=0.6\n",
    ")\n",
    "\n",
    "os.makedirs(PLOTS_DIR, exist_ok=True)\n",
    "out_path = f\"{PLOTS_DIR}/OtB_model_addition.pdf\"\n",
    "# Save this figure explicitly instead of relying on pyplot's current figure.\n",
    "fig.savefig(out_path, dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "print(\"Saved to:\", out_path)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
