{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "09cd42c1",
   "metadata": {},
   "source": [
    "\n",
    "# ML-JET — Loss-Weight Sweep Aggregator (Notebook)\n",
    "\n",
    "This notebook scans your sweep outputs under:\n",
    "\n",
    "```\n",
    "experiments/exp_loss_weight_sweep/training_output/\n",
    "```\n",
    "\n",
    "It collects per-run metrics from `training_summary.json`, builds a table with the columns needed for your paper, and then reports the **best run per model family** (ConvNeXt, EfficientNet, ViT, Swin, Mamba).\n",
    "\n",
    "**Columns generated:**\n",
    "\n",
    "- **Model (Initialization)** — inferred from `model_tag` / folder name\n",
    "- **Batch** — `batch_size`\n",
    "- **LR** — `learning_rate`\n",
    "- **Total_Acc** — `best_accuracy`\n",
    "- **Energy (Acc, Prec, Rec, F1)**\n",
    "- **αs (Acc, Prec, Rec, F1)**\n",
    "- **Q0 (Acc, Prec, Rec, F1)**\n",
    "\n",
    "You can customize the ranking logic if needed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d52c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# --- CONFIG ---\n",
    "ROOT = \"training_output/\"   # path to run folders\n",
    "OUT_DIR = \"\"                # where to write CSVs\n",
    "\n",
    "# If running from a different CWD, you can set absolute paths:\n",
    "# ROOT = \"/home/.../experiments/exp_loss_weight_sweep/training_output\"\n",
    "# OUT_DIR = \"/home/.../experiments/exp_loss_weight_sweep\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbb1e155",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os, json, re\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "\n",
    "def load_json(p):\n",
    "    try:\n",
    "        with open(p, \"r\") as f:\n",
    "            return json.load(f)\n",
    "    except Exception as e:\n",
    "        print(f\"[WARN] Failed to read {p}: {e}\")\n",
    "        return None\n",
    "\n",
    "FAMILY_PATTERNS = [\n",
    "    (re.compile(r'efficientnet', re.I), 'EfficientNet'),\n",
    "    (re.compile(r'convnext', re.I),     'ConvNeXt'),\n",
    "    (re.compile(r'\\bvit\\b', re.I),      'ViT'),\n",
    "    (re.compile(r'comer', re.I),        'ViT'),      # ViT-CoMer\n",
    "    (re.compile(r'swin', re.I),         'Swin'),\n",
    "    (re.compile(r'mamba', re.I),        'Mamba'),\n",
    "]\n",
    "\n",
    "def infer_family(model_tag: str, fallback: str) -> str:\n",
    "    key = (model_tag or fallback or \"\").lower()\n",
    "    for pat, fam in FAMILY_PATTERNS:\n",
    "        if pat.search(key):\n",
    "            return fam\n",
    "    return \"Other\"\n",
    "\n",
    "def get_nested(d, path, default=None):\n",
    "    cur = d\n",
    "    for k in path.split(\".\"):\n",
    "        if not isinstance(cur, dict) or k not in cur:\n",
    "            return default\n",
    "        cur = cur[k]\n",
    "    return cur\n",
    "\n",
    "def tuple4(d, base):\n",
    "    \"\"\"Return (acc, prec, rec, f1) for a head or (None,..) if missing.\"\"\"\n",
    "    return (\n",
    "        get_nested(d, f\"{base}.accuracy\"),\n",
    "        get_nested(d, f\"{base}.precision\"),\n",
    "        get_nested(d, f\"{base}.recall\"),\n",
    "        get_nested(d, f\"{base}.f1\"),\n",
    "    )\n",
    "\n",
    "ROOT = Path(ROOT)\n",
    "OUT_DIR = Path(OUT_DIR)\n",
    "OUT_DIR.mkdir(parents=True, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a2856a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "rows = []\n",
    "for run_dir in sorted(ROOT.iterdir()):\n",
    "    if not run_dir.is_dir():\n",
    "        continue\n",
    "    summary_path = run_dir / \"training_summary.json\"\n",
    "    if not summary_path.exists():\n",
    "        continue\n",
    "\n",
    "    js = load_json(summary_path)\n",
    "    if js is None:\n",
    "        continue\n",
    "\n",
    "    model_tag = js.get(\"model_tag\") or run_dir.name\n",
    "    family = infer_family(model_tag, run_dir.name)\n",
    "\n",
    "    # Scalars\n",
    "    batch = js.get(\"batch_size\")\n",
    "    lr    = js.get(\"learning_rate\")\n",
    "    total_acc = js.get(\"best_accuracy\")\n",
    "\n",
    "    # Per-head tuples\n",
    "    e_acc, e_prec, e_rec, e_f1 = tuple4(js, \"best_model_metrics.energy\")\n",
    "    a_acc, a_prec, a_rec, a_f1 = tuple4(js, \"best_model_metrics.alpha\")\n",
    "    q_acc, q_prec, q_rec, q_f1 = tuple4(js, \"best_model_metrics.q0\")\n",
    "\n",
    "    rows.append({\n",
    "        \"family\": family,\n",
    "        \"model_tag\": model_tag,\n",
    "        \"run_dir\": str(run_dir),\n",
    "        \"batch\": batch,\n",
    "        \"lr\": lr,\n",
    "        \"Total_Acc\": total_acc,\n",
    "        \"Energy_Acc\": e_acc, \"Energy_Prec\": e_prec, \"Energy_Rec\": e_rec, \"Energy_F1\": e_f1,\n",
    "        \"Alpha_Acc\": a_acc,  \"Alpha_Prec\": a_prec,  \"Alpha_Rec\": a_rec,  \"Alpha_F1\": a_f1,\n",
    "        \"Q0_Acc\": q_acc,     \"Q0_Prec\": q_prec,     \"Q0_Rec\": q_rec,     \"Q0_F1\": q_f1,\n",
    "    })\n",
    "\n",
    "import pandas as pd\n",
    "df = pd.DataFrame(rows)\n",
    "print(f\"[OK] Collected {len(df)} runs from {ROOT}\")\n",
    "df.head()\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b5fb55",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from os import path\n",
    "\n",
    "\n",
    "full_csv = path.join(OUT_DIR, \"aggregate_results_detailed.csv\")\n",
    "df.to_csv(full_csv, index=False)\n",
    "print(f\"[OK] wrote {full_csv}\")\n",
    "\n",
    "# Display (if running in a notebook UI that supports DataFrames)\n",
    "try:\n",
    "    from caas_jupyter_tools import display_dataframe_to_user\n",
    "    display_dataframe_to_user(\"Loss-Weight Sweep — Full Results\", df)\n",
    "except Exception as e:\n",
    "    print(\"[INFO] display_dataframe_to_user not available in this environment.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880b33cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Rank by Total_Acc, tie-break on Q0_F1 (most challenging head)\n",
    "df_numeric = df.copy()\n",
    "for c in [\"Total_Acc\",\"Q0_F1\"]:\n",
    "    df_numeric[c] = pd.to_numeric(df_numeric[c], errors=\"coerce\")\n",
    "\n",
    "best_rows = []\n",
    "for fam, sub in df_numeric.groupby(\"family\"):\n",
    "    sub = sub.copy()\n",
    "    sub[\"_rank_key\"] = list(zip(sub[\"Total_Acc\"].fillna(-1), sub[\"Q0_F1\"].fillna(-1)))\n",
    "    best_idx = sub[\"_rank_key\"].idxmax()\n",
    "    best_rows.append(df.loc[best_idx])\n",
    "\n",
    "best_df = pd.DataFrame(best_rows).sort_values(\"family\")\n",
    "best_csv = path.join(OUT_DIR, \"best_per_family_detailed.csv\")\n",
    "best_df.to_csv(best_csv, index=False)\n",
    "print(f\"[OK] wrote {best_csv}\")\n",
    "best_df[[\"family\",\"model_tag\",\"batch\",\"lr\",\"Total_Acc\",\n",
    "         \"Energy_Acc\",\"Energy_Prec\",\"Energy_Rec\",\"Energy_F1\",\n",
    "         \"Alpha_Acc\",\"Alpha_Prec\",\"Alpha_Rec\",\"Alpha_F1\",\n",
    "         \"Q0_Acc\",\"Q0_Prec\",\"Q0_Rec\",\"Q0_F1\"]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fbedf4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Add loss weights from run_dir name ---\n",
    "import re\n",
    "\n",
    "def parse_weights_from_dir(name: str):\n",
    "    m = re.search(\n",
    "        r'energy[_]?loss[_]?output[_-]?([0-9.]+).*?alpha[_]?output[_-]?([0-9.]+).*?q0[_]?output[_-]?([0-9.]+)',\n",
    "        name,\n",
    "        flags=re.I,\n",
    "    )\n",
    "    if m:\n",
    "        return float(m.group(1)), float(m.group(2)), float(m.group(3))\n",
    "    return None, None, None\n",
    "\n",
    "best_df = best_df.copy()\n",
    "best_df[\"w_energy\"], best_df[\"w_alpha\"], best_df[\"w_q0\"] = zip(\n",
    "    *best_df[\"run_dir\"].map(parse_weights_from_dir)\n",
    ")\n",
    "\n",
    "# Save again with weights included\n",
    "best_csv = path.join(OUT_DIR, \"best_per_family_detailed_with_weights.csv\")\n",
    "best_df.to_csv(best_csv, index=False)\n",
    "print(f\"[OK] wrote {best_csv}\")\n",
    "\n",
    "# best_df[\n",
    "#     [\"family\",\"model_tag\",\"batch\",\"lr\",\"Total_Acc\",\n",
    "#      \"Energy_Acc\",\"Energy_Prec\",\"Energy_Rec\",\"Energy_F1\",\n",
    "#      \"Alpha_Acc\",\"Alpha_Prec\",\"Alpha_Rec\",\"Alpha_F1\",\n",
    "#      \"Q0_Acc\",\"Q0_Prec\",\"Q0_Rec\",\"Q0_F1\",\n",
    "#      \"w_energy\",\"w_alpha\",\"w_q0\"]\n",
    "# ]\n",
    "best_df[\n",
    "    [\"family\",\n",
    "     \"Energy_Prec\",\"Energy_Rec\",\n",
    "     \"Alpha_Prec\",\"Alpha_Rec\",\n",
    "     \"Q0_Prec\",\"Q0_Rec\",\n",
    "     \"w_energy\",\"w_alpha\",\"w_q0\"]\n",
    "]\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
