{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b8857800",
   "metadata": {},
   "source": [
    "# N Models Queries (BALROG K-query) - Figures\n",
    "\n",
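    "Plots the run-averaged OtB of each K configuration (smoothed with a sliding window), saves the figure as a PDF, and prints the mean final cumulative regret against the optimal runs for the n-models queries experiment."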
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba017af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# ---------------- Dataset parameters ----------------\n",
    "dataset     = \"ms-coco\"\n",
    "pretty_name = \"MS-COCO\"  # human-readable name, used as figure/table title\n",
    "params      = dict(T=5000, num_runs=5)\n",
    "\n",
    "# ---------------- Global plot style ----------------\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        'font.size': 18,\n",
    "        'axes.titlesize': 20,\n",
    "        'axes.labelsize': 18,\n",
    "        'xtick.labelsize': 16,\n",
    "        'ytick.labelsize': 16,\n",
    "        'legend.fontsize': 14,\n",
    "    }\n",
    ")\n",
    "\n",
    "# One colour per K value; the \"full\" key covers the non-numeric K configuration.\n",
    "color_map = {\n",
    "    2: \"tab:orange\",\n",
    "    3: \"tab:green\",\n",
    "    4: \"tab:red\",\n",
    "    5: \"tab:purple\",\n",
    "    \"full\": \"tab:blue\",\n",
    "}\n",
    "linew = 2.2  # line width shared by every plotted curve\n",
    "\n",
    "def sliding_avg(x, w):\n",
    "    \"\"\"Moving average of `x` over windows of `w` consecutive samples.\n",
    "\n",
    "    Only fully-covered windows are kept (``mode=\"valid\"``): element i is the\n",
    "    mean of ``x[i:i + w]``, so the result has ``len(x) - w + 1`` entries, or is\n",
    "    an empty array when `x` is shorter than `w`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `w` < 1 (np.convolve would otherwise fail on an empty\n",
    "            kernel, e.g. when the window is derived as T // 10 with T < 10).\n",
    "    \"\"\"\n",
    "    if w < 1:\n",
    "        raise ValueError(f\"window size must be >= 1, got {w}\")\n",
    "    if len(x) < w:\n",
    "        return np.array([])\n",
    "    return np.convolve(x, np.ones(w) / w, mode=\"valid\")\n",
    "\n",
    "# ---------------- Loading + preparation ----------------\n",
    "T        = params[\"T\"]\n",
    "num_runs = params[\"num_runs\"]\n",
    "in_path  = f\"../experiments/results/n_models_queries/data/n_models_queries_compute_equiv_{dataset}_{T}_{num_runs}runs.pkl\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load result files\n",
    "# produced by our own experiment scripts, never files from untrusted sources.\n",
    "with open(in_path, \"rb\") as f:\n",
    "    saved = pickle.load(f)\n",
    "\n",
    "configs      = saved[\"configs\"]\n",
    "all_results  = saved[\"all_results\"]\n",
    "optimal_runs = saved.get(\"optimal_runs\", [])\n",
    "\n",
    "# K value -> display label of each configuration.\n",
    "label_map = {c[\"K\"]: c[\"label\"] for c in configs}\n",
    "\n",
    "# OtB runs of each K stacked once (axis 0 = run) and reused below for both\n",
    "# the run averages and the regret. Assumes equal-length runs per K.\n",
    "o2b_stacks = {c[\"K\"]: np.stack(all_results[c[\"K\"]][\"o2b_runs\"]) for c in configs}\n",
    "\n",
    "# OtB averaged over runs, keyed by label.\n",
    "avg_o2b = {label_map[K]: np.mean(stack, axis=0) for K, stack in o2b_stacks.items()}\n",
    "\n",
    "# Mean over runs of the final cumulative regret w.r.t. the optimal runs\n",
    "# (NaN for every configuration when the pickle holds no optimal runs).\n",
    "regret_means = {}\n",
    "if len(optimal_runs) > 0:\n",
    "    opt_stack = np.stack(optimal_runs)  # presumably (R, T)\n",
    "    for K, o2b_stack in o2b_stacks.items():\n",
    "        n_opt, n_o2b = opt_stack.shape[0], o2b_stack.shape[0]\n",
    "        if n_opt != n_o2b and 1 not in (n_opt, n_o2b):\n",
    "            # Explicit message instead of an opaque numpy broadcasting error.\n",
    "            raise ValueError(f\"K={K}: {n_o2b} OtB runs vs {n_opt} optimal runs\")\n",
    "        # Compare only the common horizon when run lengths differ.\n",
    "        L = min(opt_stack.shape[1], o2b_stack.shape[1])\n",
    "        if L == 0:\n",
    "            # Empty horizon: reg[:, -1] would raise IndexError.\n",
    "            regret_means[label_map[K]] = float(\"nan\")\n",
    "            continue\n",
    "        reg = np.cumsum(opt_stack[:, :L] - o2b_stack[:, :L], axis=1)\n",
    "        regret_means[label_map[K]] = float(np.mean(reg[:, -1]))\n",
    "else:\n",
    "    regret_means = {label_map[c[\"K\"]]: float(\"nan\") for c in configs}\n",
    "\n",
    "# ---------------- Plot: smoothed average OtB per K ----------------\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "w = max(1, T // 10)  # smoothing window: 10% of the horizon, at least 1 step\n",
    "n_points = 120       # number of subsampled points drawn per curve\n",
    "order_K = [2, 3, 4, 5, \"full\"]\n",
    "\n",
    "for K in order_K:\n",
    "    if K not in label_map:\n",
    "        continue\n",
    "    lab = label_map[K]\n",
    "    series = avg_o2b.get(lab, None)\n",
    "    if series is None or len(series) < w:\n",
    "        continue\n",
    "    mov = sliding_avg(series, w)\n",
    "    if mov.size == 0:\n",
    "        continue\n",
    "    # Subsample each curve from its own length: one index array shared across\n",
    "    # series of different lengths could run past the end of a shorter one.\n",
    "    idx = np.linspace(0, len(mov) - 1, n_points, dtype=int)\n",
    "    # mov[i] is plotted at iteration w + i.\n",
    "    x = np.arange(w, w + len(mov))[idx]\n",
    "    ax.plot(x, mov[idx], label=lab,\n",
    "            color=color_map.get(K, \"black\"), linewidth=linew)\n",
    "\n",
    "ax.set_title(pretty_name)\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Avg OtB\")\n",
    "ax.legend(loc=\"lower right\", frameon=False)\n",
    "ax.grid(True)\n",
    "\n",
    "out_dir = \"plots/n_models_queries\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "# File name derived from the dataset so other datasets do not overwrite it\n",
    "# (\"ms-coco\" -> otb_ms_coco_classic.pdf).\n",
    "out_plot = f\"{out_dir}/otb_{dataset.replace('-', '_')}_classic.pdf\"\n",
    "fig.savefig(out_plot, dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "print(f\"[OK] Figure enregistrée : {out_plot}\")\n",
    "\n",
    "# ---------------- Table: mean cumulative regret ----------------\n",
    "print(f\"\\n=== Regret cumulé moyen — {pretty_name} ===\")\n",
    "def extract_K(label):\n",
    "    \"\"\"Return the K encoded in `label` between \"K=\" and the next comma.\n",
    "\n",
    "    Decimal text is returned as an int, any other non-empty text as a stripped\n",
    "    string (e.g. \"full\"), and None when the label holds no parsable K.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        k_str = label.split(\"K=\")[1].split(\",\")[0].strip()\n",
    "    except (AttributeError, IndexError, TypeError):\n",
    "        return None  # not a string, or no \"K=\" marker\n",
    "    if k_str.isdecimal():  # isdecimal() guarantees int() succeeds\n",
    "        return int(k_str)\n",
    "    return k_str or None\n",
    "\n",
    "def _regret_sort_key(item):\n",
    "    \"\"\"Rank rows: numeric K ascending, then textual K, then unparsable labels.\n",
    "\n",
    "    The type rank keeps ints and strings from ever being compared directly,\n",
    "    which raised TypeError when both kinds of K were present.\n",
    "    \"\"\"\n",
    "    k = extract_K(item[0])\n",
    "    if isinstance(k, int):\n",
    "        return (0, k, \"\")\n",
    "    if isinstance(k, str):\n",
    "        return (1, 0, k)\n",
    "    return (2, 0, \"\")  # unparsable: keep original relative order (stable sort)\n",
    "\n",
    "items = sorted(regret_means.items(), key=_regret_sort_key)\n",
    "print(\"{:<28} | {:>12}\".format(\"K_query (label)\", \"Mean Regret\"))\n",
    "print(\"-\"*45)\n",
    "for lab, val in items:\n",
    "    val_str = f\"{val:.2f}\" if np.isfinite(val) else \"NA\"\n",
    "    print(f\"{lab:<28} | {val_str:>12}\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
