{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f343e7c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save as: plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# ------------- config -------------\n",
    "# Glob (relative to the current working directory) selecting the per-run metric logs.\n",
    "pattern = \"epoch_metrics*.json\"   # e.g., epoch_metrics_run1.json, epoch_metrics_run2.json\n",
    "# Output figure paths; their parent directory is created before saving.\n",
    "out_pdf = \"figures/learning_curves_placeholder.pdf\"\n",
    "out_png = \"figures/learning_curves_placeholder.png\"\n",
    "model_label = \"ViT (tiny, gaussian, RLRP)\"  # shown in title\n",
    "# ----------------------------------\n",
    "\n",
    "# 1) collect files (add more later; rerun to update)\n",
    "# sorted() is lexicographic and decides the order in which runs are stitched together\n",
    "# below (e.g. run10 sorts before run2 unless run numbers are zero-padded).\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# 2) helper to parse one file into lists\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into (epochs, train_acc, val_acc) lists.\n",
    "\n",
    "    Accepts either:\n",
    "      - a JSON list of epoch dicts, or\n",
    "      - a dict with key 'epochs' (or, failing that, 'history') that is the list.\n",
    "    Any other top-level value yields empty lists; non-dict records are ignored.\n",
    "    Expected keys per-epoch (best effort): 'epoch', 'train_acc', 'val_acc'\n",
    "    Fallbacks: 'train_accuracy' / train.acc and 'val_accuracy' / val.acc.\n",
    "    Records lacking either accuracy are skipped.\n",
    "\n",
    "    Acc values can be [0,1] or [%]. The scale is decided once per series: if every\n",
    "    value lies in [0, 1] the series is treated as fractions and multiplied by 100,\n",
    "    otherwise it is assumed to already be in %. (Scaling value-by-value would\n",
    "    inflate sub-1% entries of a percentage log.)\n",
    "\n",
    "    Returns (ep, tr, va): epoch numbers (running count when 'epoch' is absent),\n",
    "    train accuracy in %, val accuracy in %.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        # some logs wrap the list; 'history' is a common alternative key\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    ep, tr, va = [], [], []\n",
    "    for r in records:\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        ta = r.get(\"train_acc\", r.get(\"train_accuracy\"))\n",
    "        va_ = r.get(\"val_acc\", r.get(\"val_accuracy\"))\n",
    "        # alt names used earlier: {\"train\": {\"acc\": ...}, \"val\": {\"acc\": ...}};\n",
    "        # the isinstance checks avoid crashing on e.g. \"val\": null\n",
    "        if ta is None and isinstance(r.get(\"train\"), dict):\n",
    "            ta = r[\"train\"].get(\"acc\")\n",
    "        if va_ is None and isinstance(r.get(\"val\"), dict):\n",
    "            va_ = r[\"val\"].get(\"acc\")\n",
    "        if ta is None or va_ is None:\n",
    "            continue\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else len(ep))\n",
    "        tr.append(float(ta))\n",
    "        va.append(float(va_))\n",
    "\n",
    "    # normalize to % per series (see docstring)\n",
    "    if all(0.0 <= v <= 1.0 for v in tr):\n",
    "        tr = [v * 100.0 for v in tr]\n",
    "    if all(0.0 <= v <= 1.0 for v in va):\n",
    "        va = [v * 100.0 for v in va]\n",
    "    return ep, tr, va\n",
    "\n",
    "# 3) concatenate runs with continuous epoch indexing\n",
    "# Epochs are re-numbered by position (0..N-1 per file, offset by all earlier\n",
    "# files), regardless of each file's own 'epoch' values.\n",
    "all_epochs, all_train, all_val = [], [], []\n",
    "epoch_offset = 0\n",
    "for path in files:\n",
    "    ep, tr, va = parse_epoch_file(path)\n",
    "    n = len(ep)\n",
    "    all_epochs.extend(range(epoch_offset, epoch_offset + n))\n",
    "    all_train.extend(tr)\n",
    "    all_val.extend(va)\n",
    "    epoch_offset += n\n",
    "\n",
    "if not all_epochs:\n",
    "    raise ValueError(f\"No usable epoch records (train + val accuracy) found in: {files}\")\n",
    "\n",
    "# 4) plot (single chart, two lines; no explicit colors)\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback.\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "ax.plot(all_epochs, all_train, marker='o', linewidth=1.2, markersize=3, label=\"Train Accuracy\")\n",
    "ax.plot(all_epochs, all_val,   marker='s', linewidth=1.2, markersize=3, label=\"Val Accuracy\")\n",
    "\n",
    "ax.set_xlabel(\"Epoch\")\n",
    "ax.set_ylabel(\"Accuracy (%)\")\n",
    "ax.set_title(f\"Learning Curves — {model_label}\")\n",
    "ax.set_ylim(0, 105)\n",
    "ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "ax.legend(frameon=True, framealpha=0.5)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved:\\n  {out_pdf}\\n  {out_png}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f880f37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save as: plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# -------- configuration --------\n",
    "# Glob is relative to the current working directory.\n",
    "pattern = \"epoch_metrics*.json\"   # e.g., epoch_metrics_run1.json, epoch_metrics_run2.json\n",
    "# NOTE: same output paths as the previous cell, so running this cell overwrites that figure.\n",
    "out_pdf = \"figures/learning_curves_placeholder.pdf\"\n",
    "out_png = \"figures/learning_curves_placeholder.png\"\n",
    "model_label = r\"ViT (tiny, gaussian, RLRP)\"  # shown in title\n",
    "# --------------------------------\n",
    "\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into parallel per-epoch lists.\n",
    "\n",
    "    Accepts either:\n",
    "      - a JSON list of epoch dicts, or\n",
    "      - a dict with key 'epochs' that is the list, or\n",
    "      - a dict with key 'history' that is the list.\n",
    "    Any other layout yields empty lists; non-dict records are ignored.\n",
    "    Per-epoch keys (best effort; the first non-null candidate wins):\n",
    "      accuracy: 'train_acc' / 'train_accuracy' / train.acc,\n",
    "                'val_acc' / 'val_accuracy' / val.acc\n",
    "      loss:     'train_loss' / 'loss' / train.loss,\n",
    "                'val_loss' / val.loss\n",
    "      epoch:    'epoch', else the record's position in the file.\n",
    "    Records carrying none of the four metrics are skipped.\n",
    "    Acc values can be [0,1] or already in %. The scale is decided once per\n",
    "    series (all known values in [0, 1] -> fractions, multiplied by 100), so\n",
    "    sub-1% entries of a percentage log are not inflated.\n",
    "\n",
    "    Returns (ep, tr_acc, va_acc, tr_loss, va_loss); missing metrics are None.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    def first(*candidates):\n",
    "        # first non-None candidate, in priority order\n",
    "        return next((c for c in candidates if c is not None), None)\n",
    "\n",
    "    def nested(rec, group, key):\n",
    "        # {\"train\": {\"loss\": ...}}-style value; a missing, null or non-dict group gives None\n",
    "        sub = rec.get(group)\n",
    "        return sub.get(key) if isinstance(sub, dict) else None\n",
    "\n",
    "    def to_float(v):\n",
    "        return None if v is None else float(v)\n",
    "\n",
    "    def to_percent(series):\n",
    "        known = [v for v in series if v is not None]\n",
    "        if known and all(0.0 <= v <= 1.0 for v in known):\n",
    "            return [None if v is None else v * 100.0 for v in series]\n",
    "        return series\n",
    "\n",
    "    ep, tr_acc, va_acc, tr_loss, va_loss = [], [], [], [], []\n",
    "    for i, r in enumerate(records):\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        # accuracy fields\n",
    "        ta = first(r.get(\"train_acc\"), r.get(\"train_accuracy\"), nested(r, \"train\", \"acc\"))\n",
    "        va = first(r.get(\"val_acc\"), r.get(\"val_accuracy\"), nested(r, \"val\", \"acc\"))\n",
    "        # loss fields\n",
    "        tl = first(r.get(\"train_loss\"), r.get(\"loss\"), nested(r, \"train\", \"loss\"))\n",
    "        vl = first(r.get(\"val_loss\"), nested(r, \"val\", \"loss\"))\n",
    "\n",
    "        # skip if both accs and losses are missing\n",
    "        if ta is None and va is None and tl is None and vl is None:\n",
    "            continue\n",
    "\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else i)\n",
    "        tr_acc.append(to_float(ta))\n",
    "        va_acc.append(to_float(va))\n",
    "        tr_loss.append(to_float(tl))\n",
    "        va_loss.append(to_float(vl))\n",
    "\n",
    "    return ep, to_percent(tr_acc), to_percent(va_acc), tr_loss, va_loss\n",
    "\n",
    "# 1) collect and sort files\n",
    "# sorted() is lexicographic and decides the order in which runs are stitched together\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# 2) concatenate runs with continuous epoch indexing\n",
    "# rebase epochs by position so each file continues after the previous one\n",
    "all_ep, all_tr_acc, all_va_acc, all_tr_loss, all_va_loss = [], [], [], [], []\n",
    "offset = 0\n",
    "for p in files:\n",
    "    ep, tra, vaa, trl, val = parse_epoch_file(p)\n",
    "    n = len(ep)  # parse_epoch_file returns five equal-length lists\n",
    "    all_ep.extend(range(offset, offset + n))\n",
    "    all_tr_acc.extend(tra)\n",
    "    all_va_acc.extend(vaa)\n",
    "    all_tr_loss.extend(trl)\n",
    "    all_va_loss.extend(val)\n",
    "    offset += n\n",
    "\n",
    "if not all_ep:\n",
    "    raise ValueError(f\"No usable epoch records found in: {files}\")\n",
    "\n",
    "# 3) plot: single figure, accuracy on left y-axis, loss on right y-axis\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax1 = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# Accuracy (left axis)\n",
    "acc_train_line, = ax1.plot(all_ep, all_tr_acc, marker='o', linewidth=1.2, markersize=3, label=\"Train Accuracy\")\n",
    "acc_val_line,   = ax1.plot(all_ep, all_va_acc, marker='s', linewidth=1.2, markersize=3, label=\"Val Accuracy\")\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Accuracy (%)\")\n",
    "ax1.set_ylim(0, 105)\n",
    "ax1.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Loss (right axis)\n",
    "ax2 = ax1.twinx()\n",
    "loss_train_line, = ax2.plot(all_ep, all_tr_loss, marker='^', linewidth=1.2, markersize=3, label=\"Train Loss\")\n",
    "loss_val_line,   = ax2.plot(all_ep, all_va_loss, marker='v', linewidth=1.2, markersize=3, label=\"Val Loss\")\n",
    "ax2.set_ylabel(\"Loss\")\n",
    "\n",
    "# Title and legend (combine both axes)\n",
    "fig.suptitle(f\"Learning Curves — {model_label}\", y=0.98)\n",
    "\n",
    "lines = [acc_train_line, acc_val_line, loss_train_line, loss_val_line]\n",
    "labels = [l.get_label() for l in lines]\n",
    "# The legend lives on ax2: the twin Axes is drawn after ax1, so a legend owned by\n",
    "# ax1 would be painted over by the loss lines. `ncol` is accepted by all\n",
    "# matplotlib versions (the `ncols` spelling only exists from 3.6 on).\n",
    "ax2.legend(lines, labels, ncol=2, fontsize=9, frameon=True, framealpha=0.5, loc=\"upper right\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved:\\n  {out_pdf}\\n  {out_png}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86607f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Glob is relative to the current working directory. The output paths are the same as\n",
    "# the earlier cells', so re-running this cell overwrites those figures.\n",
    "pattern = \"epoch_metrics*.json\"\n",
    "out_pdf = \"figures/learning_curves_placeholder.pdf\"\n",
    "out_png = \"figures/learning_curves_placeholder.png\"\n",
    "\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into parallel per-epoch lists.\n",
    "\n",
    "    Accepts either:\n",
    "      - a JSON list of epoch dicts, or\n",
    "      - a dict with key 'epochs' that is the list, or\n",
    "      - a dict with key 'history' that is the list.\n",
    "    Any other layout yields empty lists; non-dict records are ignored.\n",
    "    Per-epoch keys (best effort; the first non-null candidate wins):\n",
    "      accuracy: 'train_acc' / 'train_accuracy' / train.acc,\n",
    "                'val_acc' / 'val_accuracy' / val.acc\n",
    "      loss:     'train_loss' / 'loss' / train.loss,\n",
    "                'val_loss' / val.loss\n",
    "      epoch:    'epoch', else the record's position in the file.\n",
    "    Records carrying none of the four metrics are skipped.\n",
    "    Acc values can be [0,1] or already in %. The scale is decided once per\n",
    "    series (all known values in [0, 1] -> fractions, multiplied by 100), so\n",
    "    sub-1% entries of a percentage log are not inflated.\n",
    "\n",
    "    Returns (ep, tr_acc, va_acc, tr_loss, va_loss); missing metrics are None.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    def first(*candidates):\n",
    "        # first non-None candidate, in priority order\n",
    "        return next((c for c in candidates if c is not None), None)\n",
    "\n",
    "    def nested(rec, group, key):\n",
    "        # {\"train\": {\"loss\": ...}}-style value; a missing, null or non-dict group gives None\n",
    "        sub = rec.get(group)\n",
    "        return sub.get(key) if isinstance(sub, dict) else None\n",
    "\n",
    "    def to_float(v):\n",
    "        return None if v is None else float(v)\n",
    "\n",
    "    def to_percent(series):\n",
    "        known = [v for v in series if v is not None]\n",
    "        if known and all(0.0 <= v <= 1.0 for v in known):\n",
    "            return [None if v is None else v * 100.0 for v in series]\n",
    "        return series\n",
    "\n",
    "    ep, tr_acc, va_acc, tr_loss, va_loss = [], [], [], [], []\n",
    "    for i, r in enumerate(records):\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        ta = first(r.get(\"train_acc\"), r.get(\"train_accuracy\"), nested(r, \"train\", \"acc\"))\n",
    "        va = first(r.get(\"val_acc\"), r.get(\"val_accuracy\"), nested(r, \"val\", \"acc\"))\n",
    "        tl = first(r.get(\"train_loss\"), r.get(\"loss\"), nested(r, \"train\", \"loss\"))\n",
    "        vl = first(r.get(\"val_loss\"), nested(r, \"val\", \"loss\"))\n",
    "        if ta is None and va is None and tl is None and vl is None:\n",
    "            continue\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else i)\n",
    "        tr_acc.append(to_float(ta))\n",
    "        va_acc.append(to_float(va))\n",
    "        tr_loss.append(to_float(tl))\n",
    "        va_loss.append(to_float(vl))\n",
    "    return ep, to_percent(tr_acc), to_percent(va_acc), tr_loss, va_loss\n",
    "\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# Concatenate runs; epochs are re-numbered by position so each file continues after the previous one.\n",
    "all_ep, all_tr_acc, all_va_acc, all_tr_loss, all_va_loss = [], [], [], [], []\n",
    "offset = 0\n",
    "for p in files:\n",
    "    ep, tra, vaa, trl, val = parse_epoch_file(p)\n",
    "    n = len(ep)\n",
    "    all_ep.extend(range(offset, offset + n))\n",
    "    all_tr_acc.extend(tra)\n",
    "    all_va_acc.extend(vaa)\n",
    "    all_tr_loss.extend(trl)\n",
    "    all_va_loss.extend(val)\n",
    "    offset += n\n",
    "\n",
    "if not all_ep:\n",
    "    raise ValueError(f\"No usable epoch records found in: {files}\")\n",
    "\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax1 = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# Accuracy curves\n",
    "acc_train_line, = ax1.plot(all_ep, all_tr_acc, marker='o', linewidth=1.2, markersize=3, label=\"Train Accuracy\")\n",
    "acc_val_line,   = ax1.plot(all_ep, all_va_acc, marker='s', linewidth=1.2, markersize=3, label=\"Val Accuracy\")\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Accuracy (%)\")\n",
    "ax1.set_ylim(0, 105)\n",
    "ax1.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Loss curves\n",
    "ax2 = ax1.twinx()\n",
    "loss_train_line, = ax2.plot(all_ep, all_tr_loss, marker='^', linewidth=1.2, markersize=3, label=\"Train Loss\")\n",
    "loss_val_line,   = ax2.plot(all_ep, all_va_loss, marker='v', linewidth=1.2, markersize=3, label=\"Val Loss\")\n",
    "ax2.set_ylabel(\"Loss\")\n",
    "\n",
    "# Combined legend, attached to ax2: the twin Axes is drawn after ax1, so a legend\n",
    "# owned by ax1 would be painted over by the loss lines. `ncol` works on all\n",
    "# matplotlib versions (`ncols` only from 3.6 on).\n",
    "lines = [acc_train_line, acc_val_line, loss_train_line, loss_val_line]\n",
    "labels = [l.get_label() for l in lines]\n",
    "ax2.legend(lines, labels, ncol=2, fontsize=9, frameon=True, framealpha=0.5, loc=\"upper right\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved: {out_pdf}, {out_png}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b7bba45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Glob is relative to the current working directory. The output paths are the same as\n",
    "# the earlier cells', so re-running this cell overwrites those figures.\n",
    "pattern = \"epoch_metrics*.json\"\n",
    "out_pdf = \"figures/learning_curves_placeholder.pdf\"\n",
    "out_png = \"figures/learning_curves_placeholder.png\"\n",
    "\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into parallel per-epoch lists.\n",
    "\n",
    "    Accepts either:\n",
    "      - a JSON list of epoch dicts, or\n",
    "      - a dict with key 'epochs' that is the list, or\n",
    "      - a dict with key 'history' that is the list.\n",
    "    Any other layout yields empty lists; non-dict records are ignored.\n",
    "    Per-epoch keys (best effort; the first non-null candidate wins):\n",
    "      accuracy: 'train_acc' / 'train_accuracy' / train.acc,\n",
    "                'val_acc' / 'val_accuracy' / val.acc\n",
    "      loss:     'train_loss' / 'loss' / train.loss,\n",
    "                'val_loss' / val.loss\n",
    "      epoch:    'epoch', else the record's position in the file.\n",
    "    Records carrying none of the four metrics are skipped.\n",
    "    Acc values can be [0,1] or already in %. The scale is decided once per\n",
    "    series (all known values in [0, 1] -> fractions, multiplied by 100), so\n",
    "    sub-1% entries of a percentage log are not inflated.\n",
    "\n",
    "    Returns (ep, tr_acc, va_acc, tr_loss, va_loss); missing metrics are None.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    def first(*candidates):\n",
    "        # first non-None candidate, in priority order\n",
    "        return next((c for c in candidates if c is not None), None)\n",
    "\n",
    "    def nested(rec, group, key):\n",
    "        # {\"train\": {\"loss\": ...}}-style value; a missing, null or non-dict group gives None\n",
    "        sub = rec.get(group)\n",
    "        return sub.get(key) if isinstance(sub, dict) else None\n",
    "\n",
    "    def to_float(v):\n",
    "        return None if v is None else float(v)\n",
    "\n",
    "    def to_percent(series):\n",
    "        known = [v for v in series if v is not None]\n",
    "        if known and all(0.0 <= v <= 1.0 for v in known):\n",
    "            return [None if v is None else v * 100.0 for v in series]\n",
    "        return series\n",
    "\n",
    "    ep, tr_acc, va_acc, tr_loss, va_loss = [], [], [], [], []\n",
    "    for i, r in enumerate(records):\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        ta = first(r.get(\"train_acc\"), r.get(\"train_accuracy\"), nested(r, \"train\", \"acc\"))\n",
    "        va = first(r.get(\"val_acc\"), r.get(\"val_accuracy\"), nested(r, \"val\", \"acc\"))\n",
    "        tl = first(r.get(\"train_loss\"), r.get(\"loss\"), nested(r, \"train\", \"loss\"))\n",
    "        vl = first(r.get(\"val_loss\"), nested(r, \"val\", \"loss\"))\n",
    "        if ta is None and va is None and tl is None and vl is None:\n",
    "            continue\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else i)\n",
    "        tr_acc.append(to_float(ta))\n",
    "        va_acc.append(to_float(va))\n",
    "        tr_loss.append(to_float(tl))\n",
    "        va_loss.append(to_float(vl))\n",
    "    return ep, to_percent(tr_acc), to_percent(va_acc), tr_loss, va_loss\n",
    "\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# Concatenate runs; epochs are re-numbered by position so each file continues after the previous one.\n",
    "all_ep, all_tr_acc, all_va_acc, all_tr_loss, all_va_loss = [], [], [], [], []\n",
    "offset = 0\n",
    "for p in files:\n",
    "    ep, tra, vaa, trl, val = parse_epoch_file(p)\n",
    "    n = len(ep)\n",
    "    all_ep.extend(range(offset, offset + n))\n",
    "    all_tr_acc.extend(tra)\n",
    "    all_va_acc.extend(vaa)\n",
    "    all_tr_loss.extend(trl)\n",
    "    all_va_loss.extend(val)\n",
    "    offset += n\n",
    "\n",
    "if not all_ep:\n",
    "    raise ValueError(f\"No usable epoch records found in: {files}\")\n",
    "\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax1 = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# Accuracy curves (solid = train, dashed = val)\n",
    "acc_train_line, = ax1.plot(all_ep, all_tr_acc, linestyle='-', marker='o',\n",
    "                           linewidth=1.2, markersize=3, label=\"Train Accuracy\")\n",
    "acc_val_line,   = ax1.plot(all_ep, all_va_acc, linestyle='--', marker='s',\n",
    "                           linewidth=1.2, markersize=3, label=\"Val Accuracy\")\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Accuracy (%)\")\n",
    "ax1.set_ylim(0, 105)\n",
    "ax1.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Loss curves (solid = train, dashed = val)\n",
    "ax2 = ax1.twinx()\n",
    "loss_train_line, = ax2.plot(all_ep, all_tr_loss, linestyle='-', marker='^',\n",
    "                            linewidth=1.2, markersize=3, label=\"Train Loss\")\n",
    "loss_val_line,   = ax2.plot(all_ep, all_va_loss, linestyle='--', marker='v',\n",
    "                            linewidth=1.2, markersize=3, label=\"Val Loss\")\n",
    "ax2.set_ylabel(\"Loss\")\n",
    "\n",
    "# Combined legend (semi-transparent frame), attached to ax2: the twin Axes is drawn\n",
    "# after ax1, so a legend owned by ax1 would be painted over by the loss lines.\n",
    "# `ncol` works on all matplotlib versions (`ncols` only from 3.6 on).\n",
    "lines = [acc_train_line, acc_val_line, loss_train_line, loss_val_line]\n",
    "labels = [l.get_label() for l in lines]\n",
    "ax2.legend(lines, labels, ncol=2, fontsize=9,\n",
    "           frameon=True, framealpha=0.5, loc=\"upper right\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved: {out_pdf}, {out_png}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3b3ed4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Glob is relative to the current working directory.\n",
    "# NOTE(review): the outputs are named for EfficientNet, but the input glob is the same one\n",
    "# the ViT cells use -- confirm the epoch_metrics*.json files present belong to this model.\n",
    "pattern = \"epoch_metrics*.json\"\n",
    "out_pdf = \"figures/learning_curves_efficientnet.pdf\"\n",
    "out_png = \"figures/learning_curves_efficientnet.png\"\n",
    "\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into parallel per-epoch lists.\n",
    "\n",
    "    Accepts either:\n",
    "      - a JSON list of epoch dicts, or\n",
    "      - a dict with key 'epochs' that is the list, or\n",
    "      - a dict with key 'history' that is the list.\n",
    "    Any other layout yields empty lists; non-dict records are ignored.\n",
    "    Per-epoch keys (best effort; the first non-null candidate wins):\n",
    "      accuracy: 'train_acc' / 'train_accuracy' / train.acc,\n",
    "                'val_acc' / 'val_accuracy' / val.acc\n",
    "      loss:     'train_loss' / 'loss' / train.loss,\n",
    "                'val_loss' / val.loss\n",
    "      epoch:    'epoch', else the record's position in the file.\n",
    "    Records carrying none of the four metrics are skipped.\n",
    "    Acc values can be [0,1] or already in %. The scale is decided once per\n",
    "    series (all known values in [0, 1] -> fractions, multiplied by 100), so\n",
    "    sub-1% entries of a percentage log are not inflated.\n",
    "\n",
    "    Returns (ep, tr_acc, va_acc, tr_loss, va_loss); missing metrics are None.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    def first(*candidates):\n",
    "        # first non-None candidate, in priority order\n",
    "        return next((c for c in candidates if c is not None), None)\n",
    "\n",
    "    def nested(rec, group, key):\n",
    "        # {\"train\": {\"loss\": ...}}-style value; a missing, null or non-dict group gives None\n",
    "        sub = rec.get(group)\n",
    "        return sub.get(key) if isinstance(sub, dict) else None\n",
    "\n",
    "    def to_float(v):\n",
    "        return None if v is None else float(v)\n",
    "\n",
    "    def to_percent(series):\n",
    "        known = [v for v in series if v is not None]\n",
    "        if known and all(0.0 <= v <= 1.0 for v in known):\n",
    "            return [None if v is None else v * 100.0 for v in series]\n",
    "        return series\n",
    "\n",
    "    ep, tr_acc, va_acc, tr_loss, va_loss = [], [], [], [], []\n",
    "    for i, r in enumerate(records):\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        ta = first(r.get(\"train_acc\"), r.get(\"train_accuracy\"), nested(r, \"train\", \"acc\"))\n",
    "        va = first(r.get(\"val_acc\"), r.get(\"val_accuracy\"), nested(r, \"val\", \"acc\"))\n",
    "        tl = first(r.get(\"train_loss\"), r.get(\"loss\"), nested(r, \"train\", \"loss\"))\n",
    "        vl = first(r.get(\"val_loss\"), nested(r, \"val\", \"loss\"))\n",
    "        if ta is None and va is None and tl is None and vl is None:\n",
    "            continue\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else i)\n",
    "        tr_acc.append(to_float(ta))\n",
    "        va_acc.append(to_float(va))\n",
    "        tr_loss.append(to_float(tl))\n",
    "        va_loss.append(to_float(vl))\n",
    "    return ep, to_percent(tr_acc), to_percent(va_acc), tr_loss, va_loss\n",
    "\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# Concatenate runs; epochs are re-numbered by position so each file continues after the previous one.\n",
    "all_ep, all_tr_acc, all_va_acc, all_tr_loss, all_va_loss = [], [], [], [], []\n",
    "offset = 0\n",
    "for p in files:\n",
    "    ep, tra, vaa, trl, val = parse_epoch_file(p)\n",
    "    n = len(ep)\n",
    "    all_ep.extend(range(offset, offset + n))\n",
    "    all_tr_acc.extend(tra)\n",
    "    all_va_acc.extend(vaa)\n",
    "    all_tr_loss.extend(trl)\n",
    "    all_va_loss.extend(val)\n",
    "    offset += n\n",
    "\n",
    "if not all_ep:\n",
    "    raise ValueError(f\"No usable epoch records found in: {files}\")\n",
    "\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax1 = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# ---- explicit colors across both axes ----\n",
    "# The twin axes would otherwise pick colors from its own cycle and can repeat the\n",
    "# accuracy colors. The four are distinct only if the cycle has >= 4 entries (modulo wraps).\n",
    "cycle_colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3'])\n",
    "c0, c1, c2, c3 = (cycle_colors[k % len(cycle_colors)] for k in range(4))\n",
    "\n",
    "# Accuracy (left axis)\n",
    "acc_train_line, = ax1.plot(all_ep, all_tr_acc, linestyle='-', marker='o',\n",
    "                           linewidth=1.2, markersize=3, label=\"Train Accuracy\", color=c0)\n",
    "acc_val_line,   = ax1.plot(all_ep, all_va_acc, linestyle='--', marker='s',\n",
    "                           linewidth=1.2, markersize=3, label=\"Val Accuracy\",   color=c1)\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Accuracy (%)\")\n",
    "ax1.set_ylim(0, 105)\n",
    "ax1.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Loss (right axis)\n",
    "ax2 = ax1.twinx()\n",
    "loss_train_line, = ax2.plot(all_ep, all_tr_loss, linestyle='-', marker='^',\n",
    "                            linewidth=1.2, markersize=3, label=\"Train Loss\", color=c2)\n",
    "loss_val_line,   = ax2.plot(all_ep, all_va_loss, linestyle='--', marker='v',\n",
    "                            linewidth=1.2, markersize=3, label=\"Val Loss\",   color=c3)\n",
    "ax2.set_ylabel(\"Loss\")\n",
    "\n",
    "# Combined legend (semi-transparent frame), attached to ax2: the twin Axes is drawn\n",
    "# after ax1, so a legend owned by ax1 would be painted over by the loss lines.\n",
    "# `ncol` works on all matplotlib versions (`ncols` only from 3.6 on).\n",
    "lines = [acc_train_line, acc_val_line, loss_train_line, loss_val_line]\n",
    "labels = [l.get_label() for l in lines]\n",
    "ax2.legend(lines, labels, ncol=2, fontsize=9, frameon=True, framealpha=0.5, loc=\"upper right\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved: {out_pdf}, {out_png}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfd02c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_learning_curves.py\n",
    "import json, os\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Glob is relative to the current working directory. This figure puts training loss on\n",
    "# the left axis and validation accuracy on the right (hence \"swapped\" in the name).\n",
    "pattern = \"epoch_metrics*.json\"\n",
    "out_pdf = \"figures/learning_curves_trainloss_valacc_swapped.pdf\"\n",
    "out_png = \"figures/learning_curves_trainloss_valacc_swapped.png\"\n",
    "\n",
    "def parse_epoch_file(path):\n",
    "    \"\"\"\n",
    "    Parse one epoch-metrics JSON file into (epochs, train_loss, val_acc) lists.\n",
    "\n",
    "    Accepts a JSON list of epoch dicts, or a dict whose 'epochs' (else 'history')\n",
    "    key holds that list; any other layout yields empty lists and non-dict records\n",
    "    are ignored. Keys (best effort; the first non-null candidate wins):\n",
    "      train loss:   'train_loss' / 'loss' / train.loss\n",
    "      val accuracy: 'val_acc' / 'val_accuracy' / val.acc\n",
    "      epoch:        'epoch', else the record's position in the file.\n",
    "    A record is skipped only when it carries none of train/val accuracy and\n",
    "    train/val loss -- the same rule as the four-curve parsers in this notebook,\n",
    "    so epoch positions line up with those figures.\n",
    "    Val accuracy is converted to % once per series (all known values in [0, 1]\n",
    "    -> multiplied by 100), so sub-1% entries of a percentage log are not inflated.\n",
    "\n",
    "    Returns (ep, tr_loss, va_acc); missing values are None.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    if isinstance(data, dict):\n",
    "        data = data.get(\"epochs\", data.get(\"history\", []))\n",
    "    records = data if isinstance(data, list) else []\n",
    "\n",
    "    def first(*candidates):\n",
    "        # first non-None candidate, in priority order\n",
    "        return next((c for c in candidates if c is not None), None)\n",
    "\n",
    "    def nested(rec, group, key):\n",
    "        # {\"train\": {\"loss\": ...}}-style value; a missing, null or non-dict group gives None\n",
    "        sub = rec.get(group)\n",
    "        return sub.get(key) if isinstance(sub, dict) else None\n",
    "\n",
    "    ep, tr_loss, va_acc = [], [], []\n",
    "    for i, r in enumerate(records):\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        tl = first(r.get(\"train_loss\"), r.get(\"loss\"), nested(r, \"train\", \"loss\"))\n",
    "        va = first(r.get(\"val_acc\"), r.get(\"val_accuracy\"), nested(r, \"val\", \"acc\"))\n",
    "        # not plotted here; read only for the skip rule\n",
    "        ta = first(r.get(\"train_acc\"), r.get(\"train_accuracy\"), nested(r, \"train\", \"acc\"))\n",
    "        vl = first(r.get(\"val_loss\"), nested(r, \"val\", \"loss\"))\n",
    "        if ta is None and va is None and tl is None and vl is None:\n",
    "            continue\n",
    "        e = r.get(\"epoch\")\n",
    "        ep.append(int(e) if e is not None else i)\n",
    "        tr_loss.append(None if tl is None else float(tl))\n",
    "        va_acc.append(None if va is None else float(va))\n",
    "\n",
    "    # normalize val accuracy to % per series\n",
    "    known = [v for v in va_acc if v is not None]\n",
    "    if known and all(0.0 <= v <= 1.0 for v in known):\n",
    "        va_acc = [None if v is None else v * 100.0 for v in va_acc]\n",
    "    return ep, tr_loss, va_acc\n",
    "\n",
    "files = sorted(glob(pattern))\n",
    "if not files:\n",
    "    raise FileNotFoundError(f\"No files match {pattern} in current directory.\")\n",
    "\n",
    "# Concatenate runs; epochs are re-numbered by position so each file continues after the previous one.\n",
    "all_ep, all_tr_loss, all_va_acc = [], [], []\n",
    "offset = 0\n",
    "for p in files:\n",
    "    ep, trl, vaa = parse_epoch_file(p)\n",
    "    n = len(ep)\n",
    "    all_ep.extend(range(offset, offset + n))\n",
    "    all_tr_loss.extend(trl)\n",
    "    all_va_acc.extend(vaa)\n",
    "    offset += n\n",
    "\n",
    "if not all_ep:\n",
    "    raise ValueError(f\"No usable epoch records found in: {files}\")\n",
    "\n",
    "# dirname() is \"\" for a bare filename and os.makedirs(\"\") raises, hence the \".\" fallback\n",
    "for out_path in (out_pdf, out_png):\n",
    "    os.makedirs(os.path.dirname(out_path) or \".\", exist_ok=True)\n",
    "fig, ax1 = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# ---- consistent color cycle ----\n",
    "# Same cycle indices as the four-curve figure (train loss = 3rd color, val acc = 2nd),\n",
    "# so each metric keeps its color across figures.\n",
    "cycle_colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3'])\n",
    "c_trloss = cycle_colors[2 % len(cycle_colors)]  # train loss color\n",
    "c_valacc = cycle_colors[1 % len(cycle_colors)]  # val accuracy color\n",
    "\n",
    "# Left axis: training loss\n",
    "train_loss_line, = ax1.plot(all_ep, all_tr_loss, linestyle='-', marker='^',\n",
    "                            linewidth=1.2, markersize=3, label=\"Train Loss\", color=c_trloss)\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Training Loss\")\n",
    "ax1.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Right axis: validation accuracy\n",
    "ax2 = ax1.twinx()\n",
    "val_acc_line, = ax2.plot(all_ep, all_va_acc, linestyle='--', marker='s',\n",
    "                         linewidth=1.2, markersize=3, label=\"Val Accuracy\", color=c_valacc)\n",
    "ax2.set_ylabel(\"Validation Accuracy (%)\")\n",
    "ax2.set_ylim(0, 105)\n",
    "\n",
    "# Legend on ax2: the twin Axes is drawn after ax1, so a legend owned by ax1 would sit\n",
    "# underneath the validation-accuracy line. `ncol` works on all matplotlib versions.\n",
    "lines = [train_loss_line, val_acc_line]\n",
    "labels = [l.get_label() for l in lines]\n",
    "ax2.legend(lines, labels, ncol=2, fontsize=9, frameon=True, framealpha=0.5, loc=\"upper right\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "fig.savefig(out_png, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved: {out_pdf}, {out_png}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
