{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/EnesAgirman/RS_gradient/blob/main/cifar10lt_updated_final_hopefully_ver2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nAybllql6HQB",
    "outputId": "b46927fe-74c9-49c0-9d2c-a53e37cf2f1c"
   },
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RCwMj-arFf1B",
    "outputId": "2e36eb03-239b-427d-b868-fad879d4decd"
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "# v6_cifar10lt — cleaned layout\n",
    "\n",
    "# === CUDA allocator config (must precede CUDA initialisation) ===\n",
    "# PYTORCH_CUDA_ALLOC_CONF is read once, when the CUDA caching allocator is first\n",
    "# initialised. Setting it after `import torch` and after torch.cuda.set_device (which\n",
    "# may already initialise CUDA) can silently have no effect, so set it before torch is\n",
    "# imported. Re-running this cell in a kernel that already initialised CUDA cannot apply it.\n",
    "import os\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:128,expandable_segments:True\"\n",
    "\n",
    "# === Imports ===\n",
    "import sys, time, math, random, warnings, gc, csv, json, datetime\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from torchvision import datasets, transforms\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from tqdm import tqdm\n",
    "from contextlib import nullcontext\n",
    "\n",
    "# === Quiet console ===\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*GradScaler is enabled.*\")\n",
    "\n",
    "QUIET_DEVICE_INFO = True  # set False to print CUDA info\n",
    "\n",
    "if not QUIET_DEVICE_INFO:\n",
    "    print(\"CUDA available:\", torch.cuda.is_available())\n",
    "    if torch.cuda.is_available():\n",
    "        print(\"Device:\", torch.cuda.get_device_name(0))\n",
    "        torch.cuda.set_device(0)\n",
    "    else:\n",
    "        print(\"⚠️  No GPU detected (Runtime ▸ Change runtime type ▸ GPU).\")\n",
    "elif torch.cuda.is_available():\n",
    "    try:\n",
    "        torch.cuda.set_device(0)\n",
    "    except Exception:\n",
    "        pass  # quiet mode: selecting GPU 0 is best-effort\n",
    "\n",
    "# === Colab Drive (no-op outside Colab) ===\n",
    "try:\n",
    "    from google.colab import drive\n",
    "except ImportError:\n",
    "    drive = None  # not running in Colab\n",
    "if drive is not None:\n",
    "    try:\n",
    "        drive.mount('/content/drive', force_remount=True)\n",
    "    except Exception as e:\n",
    "        # Run artefacts are later written under /content/drive; make a failed mount visible.\n",
    "        print(f\"⚠️  Google Drive mount failed: {e}\")\n",
    "\n",
    "# === Device & cuDNN setup ===\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# cudnn.benchmark=True times candidate kernels and can pick different algorithms from run\n",
    "# to run, which undermines cudnn.deterministic=True. Derive both from a single switch.\n",
    "DETERMINISTIC = True  # False -> faster autotuned kernels, non-reproducible results\n",
    "torch.backends.cudnn.deterministic = DETERMINISTIC\n",
    "torch.backends.cudnn.benchmark = not DETERMINISTIC\n",
    "\n",
    "# Global toggles\n",
    "AMP = True      # mixed precision\n",
    "ACCUM_STEPS = 1 # gradient accumulation (unused by default)\n",
    "RS_SHARED_TAU = 0.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F3CiI8QImVH-"
   },
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# === Constants and early helper functions ===\n",
    "# ==========================================\n",
    "# Defined before the data-loading cell, which uses them.\n",
    "\n",
    "# CIFAR-10 class names; position i holds the name of label i.\n",
    "CIFAR10_CLASSES = [\n",
    "    'airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "    'dog', 'frog', 'horse', 'ship', 'truck',\n",
    "]\n",
    "\n",
    "class IndexedFromBalancedIRM(Dataset):\n",
    "    \"\"\"\n",
    "    View of `base_dataset` at positions `indices` (duplicates allowed).\n",
    "\n",
    "    Yields (x, y, i) where i is the position inside this view, mirroring the\n",
    "    IndexedFromBalanced wrapper. Defined at module level rather than inside the\n",
    "    builder so instances stay picklable for DataLoader workers started with 'spawn'.\n",
    "    \"\"\"\n",
    "    def __init__(self, base_dataset, indices):\n",
    "        self.base = base_dataset\n",
    "        self.indices = indices\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.indices)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        x, y = self.base[self.indices[i]]\n",
    "        return x, y, i\n",
    "\n",
    "\n",
    "def build_irm_environments_probabilistic(full_dataset, lt_indices, seed=42, p_env1=0.7, p_env2=0.3):\n",
    "    \"\"\"\n",
    "    Split the LT training subset into two IRM environments with different class mixes.\n",
    "\n",
    "    The earlier sampler re-drew a *uniform* index with probability p, which leaves the\n",
    "    final index uniform for any p: both environments were i.i.d. bootstrap samples with\n",
    "    identical per-class sizes, i.e. indistinguishable in distribution.\n",
    "\n",
    "    Each environment now gets per-class size\n",
    "        size_c(p) = round(p * n_c + (1 - p) * N / C_present)\n",
    "    interpolating between the LT count n_c (p = 1) and a balanced share (p = 0).\n",
    "    Env1 (p_env1 = 0.7): close to the LT mix (majority-heavy).\n",
    "    Env2 (p_env2 = 0.3): closer to balanced (under-samples majority, over-samples minority).\n",
    "    Items within a class are drawn uniformly with replacement. On a balanced subset\n",
    "    both environments keep size_c = n_c.\n",
    "\n",
    "    Args:\n",
    "        full_dataset: dataset with `.targets` and `__getitem__(i) -> (x, y)`.\n",
    "        lt_indices: indices into `full_dataset` forming the LT training subset.\n",
    "        seed: seed for the resampling RNG.\n",
    "        p_env1, p_env2: weight on the LT class mix for each environment, in [0, 1].\n",
    "    Returns: (env1_dataset, env2_dataset)\n",
    "    Raises: ValueError if a mixing weight lies outside [0, 1].\n",
    "    \"\"\"\n",
    "    for p in (p_env1, p_env2):\n",
    "        if not 0.0 <= p <= 1.0:\n",
    "            raise ValueError(f\"mixing weight must be in [0, 1], got {p}\")\n",
    "    rng = np.random.default_rng(seed)\n",
    "    lt_indices = np.asarray(lt_indices, dtype=np.int64)\n",
    "    targets_lt = np.asarray(full_dataset.targets, dtype=np.int64)[lt_indices]\n",
    "    if lt_indices.size == 0:\n",
    "        return (IndexedFromBalancedIRM(full_dataset, np.array([], dtype=np.int64)),\n",
    "                IndexedFromBalancedIRM(full_dataset, np.array([], dtype=np.int64)))\n",
    "\n",
    "    num_classes = int(targets_lt.max()) + 1\n",
    "    class_indices = [lt_indices[targets_lt == c] for c in range(num_classes)]\n",
    "    class_counts = np.array([len(ci) for ci in class_indices], dtype=np.float64)\n",
    "    present = class_counts > 0\n",
    "    # Balanced share for each class that occurs; absent classes stay at 0.\n",
    "    balanced = np.where(present, class_counts.sum() / present.sum(), 0.0)\n",
    "\n",
    "    def sample_env(p):\n",
    "        \"\"\"Draw one environment's indices for mixing weight p.\"\"\"\n",
    "        sizes = np.rint(p * class_counts + (1.0 - p) * balanced).astype(np.int64)\n",
    "        parts = [rng.choice(ci, size=int(k), replace=True)\n",
    "                 for ci, k in zip(class_indices, sizes) if len(ci) > 0 and k > 0]\n",
    "        if not parts:\n",
    "            return np.array([], dtype=np.int64)\n",
    "        return np.concatenate(parts).astype(np.int64)\n",
    "\n",
    "    env1_indices = sample_env(p_env1)\n",
    "    env2_indices = sample_env(p_env2)\n",
    "    return (IndexedFromBalancedIRM(full_dataset, env1_indices),\n",
    "            IndexedFromBalancedIRM(full_dataset, env2_indices))\n",
    "\n",
    "\n",
    "class CIFAR10GroupsForDRO:\n",
    "    \"\"\"\n",
    "    Treat each class as a DRO group for GroupDRO.\n",
    "    Compatible with the LossComputer interface; exposes n_groups, group_counts and group_str.\n",
    "\n",
    "    Args:\n",
    "        targets_full: labels of the full training set (array-like of ints).\n",
    "        lt_indices: indices into targets_full forming the LT training subset.\n",
    "        class_names: one name per class, in label order.\n",
    "    Raises:\n",
    "        ValueError: if the subset contains a label >= len(class_names).\n",
    "    \"\"\"\n",
    "    def __init__(self, targets_full, lt_indices, class_names):\n",
    "        self.targets_full = targets_full\n",
    "        self.lt_indices = lt_indices\n",
    "        self.class_names = class_names\n",
    "        self.n_classes = len(class_names)\n",
    "        self.n_groups = self.n_classes  # each class is a group\n",
    "        # np.asarray so plain Python lists work too (list fancy-indexing would fail).\n",
    "        targets_lt = np.asarray(targets_full, dtype=np.int64)[np.asarray(lt_indices, dtype=np.int64)]\n",
    "        # minlength gives classes absent from the subset an explicit zero count.\n",
    "        self.group_counts = np.bincount(targets_lt, minlength=self.n_classes)\n",
    "        if len(self.group_counts) > self.n_groups:\n",
    "            # Otherwise group_counts would silently be longer than n_groups.\n",
    "            raise ValueError(f\"found label {len(self.group_counts) - 1} but only {self.n_groups} class names\")\n",
    "\n",
    "    def group_str(self, group_idx):\n",
    "        \"\"\"Display name of group `group_idx` (its class name).\"\"\"\n",
    "        return self.class_names[group_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1solB9Kf3QC4",
    "outputId": "40bb812b-9b46-4324-ea55-7a0f9773b7f5"
   },
   "outputs": [],
   "source": [
    "# ------------------- knobs -------------------\n",
    "VAL_FRACTION       = 0.1          # LT-train → LT-val fraction (if not using test-as-val)\n",
    "USE_TEST_AS_VAL    = False        # DEBUG: use balanced test as validation\n",
    "BATCH              = 512\n",
    "WORKERS            = 2\n",
    "PIN_MEM            = True\n",
    "PERSISTENT         = True\n",
    "PREFETCH           = 2\n",
    "NUM_CLASSES        = 10\n",
    "DATA_ROOT          = \"/content/data\"\n",
    "SEED               = 42\n",
    "IMB_FACTOR         = 1.0        # long-tail imbalance ratio (max/min); 1.0 = balanced\n",
    "\n",
    "# ------------------- global seeding -------------------\n",
    "# The LT/split/IRM builders use explicit np.random.default_rng(seed) generators, but\n",
    "# DataLoader shuffling and model initialisation draw from the global RNGs. Seed them so\n",
    "# runs are repeatable (torch.manual_seed seeds the CPU and all CUDA generators).\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# ------------------- transforms -------------------\n",
    "mean = (0.4914, 0.4822, 0.4465)\n",
    "std  = (0.2470, 0.2435, 0.2616)\n",
    "\n",
    "# No augmentation: training and evaluation share the same normalise-only pipeline.\n",
    "train_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "eval_tf  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "# ------------------- base datasets -------------------\n",
    "full_train_bal = datasets.CIFAR10(DATA_ROOT, train=True,  download=True, transform=train_tf)\n",
    "full_test      = datasets.CIFAR10(DATA_ROOT, train=False, download=True, transform=eval_tf)\n",
    "targets = np.array(full_train_bal.targets, dtype=np.int64)\n",
    "\n",
    "# ------------------- LT index build (exp schedule) -------------------\n",
    "def img_num_per_cls(total, cls_num, imb_factor_inv):\n",
    "    \"\"\"\n",
    "    Per-class sample counts for an exponentially decaying long-tail profile.\n",
    "\n",
    "    Class i keeps int(img_max * imb_factor_inv ** (i / (cls_num - 1))) samples, where\n",
    "    img_max = total / cls_num: class 0 keeps img_max, the last class img_max * imb_factor_inv.\n",
    "    \"\"\"\n",
    "    img_max = total / cls_num\n",
    "    if cls_num == 1:\n",
    "        # A single class has no tail, and the exponent below would divide by zero.\n",
    "        return np.array([int(img_max)], dtype=np.int64)\n",
    "    return np.array([\n",
    "        int(img_max * (imb_factor_inv ** (i / (cls_num - 1.0))))\n",
    "        for i in range(cls_num)\n",
    "    ], dtype=np.int64)\n",
    "\n",
    "def build_cifarLT_indices(targets: np.ndarray, num_classes: int, imbalance_ratio: float, seed: int = 0) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Choose a long-tailed subset of a labelled dataset.\n",
    "\n",
    "    Class c keeps a random subset of size img_num_per_cls(...)[c] of its indices, so class 0\n",
    "    is the head and class num_classes - 1 the tail. A class with fewer samples than requested\n",
    "    keeps all of them.\n",
    "\n",
    "    Returns: int64 indices into `targets`, grouped by class (class 0 first).\n",
    "    Raises: ValueError if imbalance_ratio < 1 or NaN. A ratio of 0 used to divide by zero,\n",
    "            and ratios in (0, 1) made counts grow with the class index instead of decaying.\n",
    "    \"\"\"\n",
    "    if not imbalance_ratio >= 1.0:\n",
    "        raise ValueError(f\"imbalance_ratio must be >= 1, got {imbalance_ratio}\")\n",
    "    rng = np.random.default_rng(seed)\n",
    "    targets = np.asarray(targets)\n",
    "    imb_factor_inv = 1.0 / float(imbalance_ratio)\n",
    "    per_cls_count = img_num_per_cls(total=len(targets), cls_num=num_classes, imb_factor_inv=imb_factor_inv)\n",
    "    chosen = []\n",
    "    for c in range(num_classes):\n",
    "        idxs = np.flatnonzero(targets == c)\n",
    "        chosen.extend(rng.permutation(idxs)[:per_cls_count[c]])\n",
    "    return np.array(chosen, dtype=np.int64)\n",
    "\n",
    "# Indices into the balanced CIFAR-10 train set that form the long-tailed subset.\n",
    "lt_indices = build_cifarLT_indices(\n",
    "    targets=targets, num_classes=NUM_CLASSES, imbalance_ratio=IMB_FACTOR, seed=SEED,\n",
    ")\n",
    "\n",
    "# ------------------- stratified train/val split (on LT subset) -------------------\n",
    "def stratified_split(indices, labels, num_classes, val_fraction=0.1, seed=42):\n",
    "    \"\"\"\n",
    "    Per-class train/val split of `indices` that preserves class proportions.\n",
    "\n",
    "    Within each class the indices are shuffled and the first int((1 - val_fraction) * n_c)\n",
    "    go to train, the rest to val. A non-empty class always keeps at least one training\n",
    "    sample. Previously a class with a single example (with the default fraction) went\n",
    "    entirely to val, so training never saw it.\n",
    "\n",
    "    Args:\n",
    "        indices: indices into `labels` to split.\n",
    "        labels: labels of the full dataset.\n",
    "        num_classes: classes 0..num_classes-1 are split; other labels are dropped.\n",
    "        val_fraction: fraction of each class sent to validation.\n",
    "        seed: seed for the per-class shuffles.\n",
    "    Returns:\n",
    "        (train_idx, val_idx) as int64 arrays, grouped by class.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    indices = np.asarray(indices, dtype=np.int64)\n",
    "    labels = np.asarray(labels)\n",
    "    train_idx, val_idx = [], []\n",
    "    for c in range(num_classes):\n",
    "        c_idx = indices[labels[indices] == c]  # boolean indexing copies, so shuffling is safe\n",
    "        rng.shuffle(c_idx)\n",
    "        split = int((1.0 - val_fraction) * len(c_idx))\n",
    "        if len(c_idx) > 0:\n",
    "            split = max(split, 1)\n",
    "        train_idx.extend(c_idx[:split]); val_idx.extend(c_idx[split:])\n",
    "    return np.array(train_idx, dtype=np.int64), np.array(val_idx, dtype=np.int64)\n",
    "\n",
    "# ------------------- dataset wrappers -------------------\n",
    "class IndexedFromBalanced(Dataset):\n",
    "    \"\"\"Subset view of `base_dataset` at `indices`; item i is (x, y, i).\n",
    "\n",
    "    The third element is the position inside this view, not the index into `base`.\n",
    "    \"\"\"\n",
    "    def __init__(self, base_dataset, indices):\n",
    "        self.base = base_dataset\n",
    "        self.indices = indices\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.indices)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        source_idx = self.indices[i]\n",
    "        x, y = self.base[source_idx]\n",
    "        return x, y, i\n",
    "\n",
    "class WrapTest(Dataset):\n",
    "    \"\"\"Adapt an (x, y) dataset to the (x, y, index) triples the train/eval loops unpack.\"\"\"\n",
    "    def __init__(self, ds):\n",
    "        self.ds = ds\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.ds)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        sample, label = self.ds[i]\n",
    "        return sample, label, i\n",
    "\n",
    "test_subset = WrapTest(full_test)\n",
    "if USE_TEST_AS_VAL:\n",
    "    # DEBUG: train on the entire LT subset and validate on the balanced test set.\n",
    "    train_idx = lt_indices\n",
    "    val_idx = np.arange(len(full_test), dtype=np.int64)\n",
    "    val_subset = WrapTest(full_test)\n",
    "else:\n",
    "    # Proper protocol: stratified train/val split carved out of the LT subset.\n",
    "    train_idx, val_idx = stratified_split(lt_indices, targets, NUM_CLASSES, VAL_FRACTION, SEED)\n",
    "    val_subset = IndexedFromBalanced(full_train_bal, val_idx)\n",
    "train_subset = IndexedFromBalanced(full_train_bal, train_idx)\n",
    "\n",
    "# ------------------- smoothed global prior from the *training* split -------------------\n",
    "DIRICHLET_ALPHA_RATIO = 1e-6  # tiny pseudo-count ratio: keeps the prior faithful but never exactly zero\n",
    "num_classes = NUM_CLASSES\n",
    "\n",
    "# Class counts of the training split only (never val/test). In debug mode train_idx is\n",
    "# lt_indices itself, so this single expression covers both settings.\n",
    "prior_counts = np.bincount(targets[train_idx], minlength=num_classes).astype(np.float64)\n",
    "N_train = float(prior_counts.sum())\n",
    "alpha = DIRICHLET_ALPHA_RATIO * (N_train / max(1, num_classes))  # smoothing mass added to every class\n",
    "prior_smoothed = prior_counts + alpha\n",
    "prior_smoothed_sum = prior_smoothed.sum()\n",
    "P_GLOBAL_TRAIN = (prior_smoothed / max(prior_smoothed_sum, 1.0)).astype(np.float64)  # length-C numpy vector\n",
    "\n",
    "def get_P_global_tensor(device=device, dtype=torch.float32):\n",
    "    \"\"\"Return a fresh copy of P_GLOBAL_TRAIN as a `dtype` tensor on `device`.\"\"\"\n",
    "    prior = torch.tensor(P_GLOBAL_TRAIN, dtype=dtype, device=device)\n",
    "    return prior\n",
    "\n",
    "\n",
    "def make_loader(ds, bs, shuffle, drop_last=False):\n",
    "    \"\"\"\n",
    "    DataLoader for `ds` using the notebook-wide WORKERS / PIN_MEM / PERSISTENT / PREFETCH knobs.\n",
    "\n",
    "    Pinning now depends on CUDA being available instead of on WORKERS > 0. Pinned host\n",
    "    memory is what lets the later `.to(device, non_blocking=True)` copies run asynchronously,\n",
    "    it also works with num_workers=0, and on a CPU-only runtime it only triggers a warning.\n",
    "    \"\"\"\n",
    "    kwargs = dict(dataset=ds,\n",
    "                  batch_size=bs,\n",
    "                  shuffle=shuffle,\n",
    "                  num_workers=WORKERS,\n",
    "                  pin_memory=bool(PIN_MEM and torch.cuda.is_available()),\n",
    "                  drop_last=drop_last)\n",
    "    if WORKERS:\n",
    "        # DataLoader rejects these options when num_workers == 0.\n",
    "        kwargs.update(persistent_workers=PERSISTENT, prefetch_factor=PREFETCH)\n",
    "    return DataLoader(**kwargs)\n",
    "\n",
    "train_loader = make_loader(train_subset, bs=BATCH, shuffle=True)\n",
    "val_loader   = make_loader(val_subset,   bs=BATCH, shuffle=False)\n",
    "test_loader  = make_loader(test_subset,  bs=BATCH, shuffle=False)\n",
    "\n",
    "print(\n",
    "    f\"[CIFAR-10-LT IF={IMB_FACTOR:.0f}] train={len(train_subset)}  \"\n",
    "    f\"val={'TEST' if USE_TEST_AS_VAL else 'LT'}={len(val_subset)}  \"\n",
    "    f\"test(bal)={len(test_subset)}\"\n",
    ")\n",
    "\n",
    "# ==========================================\n",
    "# === IRM Environment Setup ===\n",
    "# ==========================================\n",
    "# Two resampled environments built from the LT *training* split only.\n",
    "irm_ds1, irm_ds2 = build_irm_environments_probabilistic(full_train_bal, train_idx, seed=SEED)\n",
    "\n",
    "# Half-size batches per environment (presumably one batch from each is combined per\n",
    "# IRM step, totalling BATCH samples — confirm in the IRM training loop).\n",
    "irm_loader1, irm_loader2 = (make_loader(env_ds, bs=BATCH // 2, shuffle=True) for env_ds in (irm_ds1, irm_ds2))\n",
    "\n",
    "print(f\"[IRM Envs] env1={len(irm_ds1)}  env2={len(irm_ds2)}\")\n",
    "\n",
    "# ==========================================\n",
    "# === GroupDRO Dataset Setup ===\n",
    "# ==========================================\n",
    "# One GroupDRO group per class, with counts taken from the LT train split.\n",
    "groupdro_dataset = CIFAR10GroupsForDRO(targets, train_idx, CIFAR10_CLASSES)\n",
    "print(f\"[GroupDRO] dataset ready with {groupdro_dataset.n_groups} groups\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ubYpb6XO7iiC"
   },
   "outputs": [],
   "source": [
    "PRINT_CLASS_STATS = True   # print per-class accuracy + confusion matrix during training\n",
    "CLASS_STATS_EVERY = 5      # ...every N epochs (None or <= 0 disables)\n",
    "\n",
    "def _fmt_hms(s):\n",
    "    \"\"\"\n",
    "    Format a duration in seconds as 'Hh MMm SS.SSs', 'Mm SS.SSs' or 'S.SSs'.\n",
    "\n",
    "    The value is rounded to centiseconds *before* being split into h/m/s, so e.g.\n",
    "    119.996 s prints as '2m 00.00s' instead of the inconsistent '1m 60.00s'.\n",
    "    \"\"\"\n",
    "    s = round(float(s), 2)\n",
    "    h, rem = divmod(s, 3600)\n",
    "    m, sec = divmod(rem, 60)\n",
    "    h, m = int(h), int(m)\n",
    "    if h:\n",
    "        return f\"{h:d}h {m:02d}m {sec:05.2f}s\"\n",
    "    if m:\n",
    "        return f\"{m:d}m {sec:05.2f}s\"\n",
    "    return f\"{sec:.2f}s\"\n",
    "\n",
    "def amp_ctx():\n",
    "    \"\"\"Autocast context for forward passes; enabled only when AMP is on and `device` is CUDA.\"\"\"\n",
    "    use_amp = AMP and device.type == \"cuda\"\n",
    "    return torch.amp.autocast('cuda', enabled=use_amp)\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_ce_loss(model, loader):\n",
    "    \"\"\"Mean per-sample cross-entropy of `model` on `loader` (eval mode, autocast via amp_ctx).\"\"\"\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    loss_sum = 0.0\n",
    "    n_seen = 0\n",
    "    with amp_ctx():\n",
    "        for inputs, labels, *_ in loader:\n",
    "            inputs = inputs.to(device, non_blocking=True)\n",
    "            labels = labels.to(device, non_blocking=True)\n",
    "            loss_sum += criterion(model(inputs), labels).item()\n",
    "            n_seen += labels.size(0)\n",
    "    return loss_sum / max(1, n_seen)\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_accuracy(model, loader, return_cm=False):\n",
    "    \"\"\"\n",
    "    Top-1 accuracy of `model` on `loader` (eval mode, autocast via amp_ctx).\n",
    "\n",
    "    With return_cm=True, returns (acc, cm), where cm is the confusion matrix over labels\n",
    "    0..len(CIFAR10_CLASSES)-1 (rows = true label, columns = prediction).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    pred_chunks, label_chunks = [], []\n",
    "    with amp_ctx():\n",
    "        for inputs, labels, *_ in loader:\n",
    "            inputs = inputs.to(device, non_blocking=True)\n",
    "            labels = labels.to(device, non_blocking=True)\n",
    "            pred_chunks.append(model(inputs).argmax(1).cpu().numpy())\n",
    "            label_chunks.append(labels.cpu().numpy())\n",
    "    y_pred = np.concatenate(pred_chunks)\n",
    "    y_true = np.concatenate(label_chunks)\n",
    "    acc = (y_pred == y_true).mean()\n",
    "    if not return_cm:\n",
    "        return acc\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CIFAR10_CLASSES))))\n",
    "    return acc, cm\n",
    "\n",
    "@torch.no_grad()\n",
    "def per_class_metrics(model, loader, num_classes=len(CIFAR10_CLASSES)):\n",
    "    \"\"\"\n",
    "    Per-class accuracy of `model` on `loader`, plus head/mid/tail averages.\n",
    "\n",
    "    Classes are bucketed into thirds by label index: head = first k = max(1, C // 3)\n",
    "    labels, mid = next k, tail = the rest (mid/tail are NaN when C is too small). For the\n",
    "    CIFAR-LT subset built in this notebook, lower labels are the more frequent ones.\n",
    "    Classes absent from `loader` get accuracy 0. Runs under amp_ctx(), like eval_accuracy(),\n",
    "    so both helpers see the same predictions.\n",
    "\n",
    "    Returns: (per_cls_acc, head, mid, tail, cm) with cm rows = true, cols = pred.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    all_preds, all_labels = [], []\n",
    "    with amp_ctx():\n",
    "        for x, y, *_ in loader:\n",
    "            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "            preds = model(x).argmax(1)\n",
    "            all_preds.append(preds.cpu().numpy()); all_labels.append(y.cpu().numpy())\n",
    "    y_pred = np.concatenate(all_preds); y_true = np.concatenate(all_labels)\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))\n",
    "    # clip(min=1) avoids 0/0 for classes with no samples (their accuracy becomes 0).\n",
    "    per_cls_acc = (cm.diagonal() / cm.sum(axis=1).clip(min=1)).astype(float)\n",
    "    k = max(1, num_classes // 3)\n",
    "    head = per_cls_acc[:k].mean()\n",
    "    mid  = per_cls_acc[k:2*k].mean() if num_classes >= 2*k else np.nan\n",
    "    tail = per_cls_acc[2*k:].mean()   if num_classes >  2*k else np.nan\n",
    "    return per_cls_acc, float(head), float(mid), float(tail), cm\n",
    "\n",
    "def maybe_print_class_stats(tag, ep, model, loader):\n",
    "    \"\"\"\n",
    "    Every CLASS_STATS_EVERY epochs (if PRINT_CLASS_STATS), print the overall accuracy,\n",
    "    per-class accuracy and confusion matrix of `model` on `loader`.\n",
    "\n",
    "    Overall accuracy is derived from the same confusion matrix as the per-class numbers,\n",
    "    so the three printouts always agree and `loader` is traversed only once.\n",
    "    \"\"\"\n",
    "    if not PRINT_CLASS_STATS or CLASS_STATS_EVERY is None or CLASS_STATS_EVERY <= 0:\n",
    "        return\n",
    "    if ep % CLASS_STATS_EVERY != 0:\n",
    "        return\n",
    "    per_cls, _, _, _, cm_full = per_class_metrics(model, loader, num_classes=len(CIFAR10_CLASSES))\n",
    "    acc = float(np.trace(cm_full)) / max(1, int(cm_full.sum()))\n",
    "    print(f\"[{tag}] ep{ep:03d} overallAcc={acc:.4f}\")\n",
    "    print(f\"[{tag}] per-class acc:\")\n",
    "    for i, a in enumerate(per_cls):\n",
    "        cname = CIFAR10_CLASSES[i] if i < len(CIFAR10_CLASSES) else f\"class{i}\"\n",
    "        print(f\"  {i:02d} ({cname:10s}): {float(a):.3f}\")\n",
    "    print(f\"[{tag}] confusion matrix (rows=true, cols=pred):\\n{cm_full}\")\n",
    "\n",
    "def log_epoch(tag, ep, tr, va, acc, t0, extra=\"\"):\n",
    "    \"\"\"Print one epoch summary: train/val CE, val accuracy and time elapsed since `t0`.\n",
    "\n",
    "    `t0` is a time.perf_counter() timestamp; a non-empty `extra` is inserted after the epoch tag.\n",
    "    \"\"\"\n",
    "    elapsed = _fmt_hms(time.perf_counter() - t0)\n",
    "    prefix = f\"[{tag}] ep{ep:03d} {extra}\" if extra else f\"[{tag}] ep{ep:03d}\"\n",
    "    print(f\"{prefix} trainCE={tr:.4f}  valCE={va:.4f}  valAcc={acc:.4f}  t={elapsed}\")\n",
    "\n",
    "# CUDA helpers (single source of truth)\n",
    "def cuda_mem(fmt=True):\n",
    "    \"\"\"\n",
    "    Current CUDA memory usage in MiB.\n",
    "\n",
    "    Returns 'alloc=... reserved=... peak=...' when fmt is True, else the tuple\n",
    "    (allocated, reserved, peak). Without CUDA: \"\" or (0.0, 0.0, 0.0).\n",
    "    \"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        mib = 1024**2\n",
    "        alloc = torch.cuda.memory_allocated() / mib\n",
    "        reserv = torch.cuda.memory_reserved() / mib\n",
    "        peak = torch.cuda.max_memory_allocated() / mib\n",
    "        if fmt:\n",
    "            return f\"alloc={alloc:.1f}MB  reserved={reserv:.1f}MB  peak={peak:.1f}MB\"\n",
    "        return (alloc, reserv, peak)\n",
    "    return \"\" if fmt else (0.0, 0.0, 0.0)\n",
    "\n",
    "def cuda_clear_cache(sync=True):\n",
    "    \"\"\"Run Python GC, then release cached CUDA blocks and reset the peak-memory counter.\"\"\"\n",
    "    gc.collect()\n",
    "    if not torch.cuda.is_available():\n",
    "        return\n",
    "    torch.cuda.empty_cache()\n",
    "    torch.cuda.ipc_collect()\n",
    "    if sync:\n",
    "        torch.cuda.synchronize()\n",
    "    torch.cuda.reset_peak_memory_stats()\n",
    "\n",
    "def start_phase(tag: str):\n",
    "    \"\"\"Hook run before a training phase (`tag` is currently unused).\"\"\"\n",
    "    cuda_clear_cache(sync=True)\n",
    "\n",
    "def end_phase(tag: str):\n",
    "    \"\"\"Hook run after a training phase (`tag` is currently unused).\"\"\"\n",
    "    cuda_clear_cache(sync=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LrVvrcSKy1xA"
   },
   "outputs": [],
   "source": [
    "RUNS_ROOT = \"/content/drive/My Drive/cifarlt_runs\"\n",
    "\n",
    "def _now_tag():\n",
    "    \"\"\"Local-time stamp 'YYYYmmdd-HHMMSS' used as the run-directory prefix.\"\"\"\n",
    "    return datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "\n",
    "def make_run_dir(imb_factor: float, algo: str, extra_tag: str = \"\", root=None) -> str:\n",
    "    \"\"\"\n",
    "    Create a new run directory <root>/IF_<imb_factor>/<timestamp>_<algo>[_<extra_tag>].\n",
    "\n",
    "    The timestamp has one-second resolution, so two runs of the same algorithm started\n",
    "    within one second used to share (and overwrite) a directory. The leaf directory is\n",
    "    now created exclusively; on a clash a suffix _2, _3, ... is appended.\n",
    "\n",
    "    Args:\n",
    "        imb_factor: imbalance factor; integral values are written without a decimal part.\n",
    "        algo: algorithm name placed after the timestamp.\n",
    "        extra_tag: optional suffix after the algorithm name.\n",
    "        root: parent directory (defaults to RUNS_ROOT).\n",
    "    Returns:\n",
    "        Path of the newly created, previously non-existent directory.\n",
    "    \"\"\"\n",
    "    root = RUNS_ROOT if root is None else root\n",
    "    if_str = f\"IF_{int(imb_factor) if float(imb_factor).is_integer() else imb_factor}\"\n",
    "    base = os.path.join(root, if_str)\n",
    "    os.makedirs(base, exist_ok=True)\n",
    "    name = f\"{_now_tag()}_{algo}\"\n",
    "    if extra_tag:\n",
    "        name += f\"_{extra_tag}\"\n",
    "    candidate, attempt = os.path.join(base, name), 1\n",
    "    while True:\n",
    "        try:\n",
    "            os.mkdir(candidate)  # exclusive: raises if another run already owns this name\n",
    "            return candidate\n",
    "        except FileExistsError:\n",
    "            attempt += 1\n",
    "            candidate = os.path.join(base, f\"{name}_{attempt}\")\n",
    "\n",
    "def save_confusion_matrices(algo_name: str, out_dir: str, model, train_loader, val_loader, test_loader):\n",
    "    \"\"\"\n",
    "    Evaluate `model` on train, val and test (in that order). Each split's confusion matrix\n",
    "    is saved to `out_dir` as cm_<split>.npy and cm_<split>.csv, and the three accuracies\n",
    "    go to final_summary.json.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    summary = {\"algo\": algo_name}\n",
    "    for split, split_loader in ((\"train\", train_loader), (\"val\", val_loader), (\"test\", test_loader)):\n",
    "        split_acc, split_cm = eval_accuracy(model, split_loader, return_cm=True)\n",
    "        np.save(os.path.join(out_dir, f\"cm_{split}.npy\"), split_cm)\n",
    "        np.savetxt(os.path.join(out_dir, f\"cm_{split}.csv\"), split_cm, fmt=\"%d\", delimiter=\",\")\n",
    "        summary[f\"{split}_acc\"] = float(split_acc)\n",
    "    with open(os.path.join(out_dir, \"final_summary.json\"), \"w\") as f:\n",
    "        json.dump(summary, f, indent=2)\n",
    "\n",
    "class EpochCSVLogger:\n",
    "    \"\"\"Logs CE + Acc for TRAIN, VAL, TEST each epoch.\n",
    "\n",
    "    Target file: <out_dir>/<algo_name>_epoch_metrics.csv. The first write() truncates it\n",
    "    and emits the header; later writes append. Missing keys become empty cells and keys\n",
    "    outside the header are ignored.\n",
    "    \"\"\"\n",
    "    def __init__(self, out_dir: str, algo_name: str):\n",
    "        self.path = os.path.join(out_dir, f\"{algo_name}_epoch_metrics.csv\")\n",
    "        self._init = False\n",
    "\n",
    "    def write(self, row: dict):\n",
    "        fields = [\"epoch\", \"train_ce\", \"train_acc\", \"val_ce\", \"val_acc\", \"test_ce\", \"test_acc\"]\n",
    "        first_write = not self._init\n",
    "        with open(self.path, \"w\" if first_write else \"a\", newline=\"\") as fh:\n",
    "            writer = csv.DictWriter(fh, fieldnames=fields)\n",
    "            if first_write:\n",
    "                writer.writeheader()\n",
    "                self._init = True\n",
    "            writer.writerow({name: row.get(name, \"\") for name in fields})\n",
    "\n",
    "@torch.no_grad()\n",
    "def _train_accuracy_epoch(model, loader):\n",
    "    \"\"\"Exact train-set accuracy in one pass (for epoch logging); 0.0 for an empty loader.\n",
    "\n",
    "    Leaves `model` in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    with amp_ctx():\n",
    "        for inputs, labels, *_ in loader:\n",
    "            inputs = inputs.to(device, non_blocking=True)\n",
    "            labels = labels.to(device, non_blocking=True)\n",
    "            n_correct += (model(inputs).argmax(1) == labels).sum().item()\n",
    "            n_seen += labels.numel()\n",
    "    return n_correct / n_seen if n_seen else 0.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zuJV-njQ4LnO"
   },
   "outputs": [],
   "source": [
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-activation residual unit: two (BN -> ReLU -> optional dropout -> 3x3 conv) stages.\n",
    "\n",
    "    The shortcut is the identity when in_planes == out_planes. Otherwise it is a strided\n",
    "    1x1 conv applied to the raw input x (not to the pre-activated tensor).\n",
    "    \"\"\"\n",
    "    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):\n",
    "        super().__init__()\n",
    "        self.equalInOut = in_planes == out_planes\n",
    "        # Keep this creation order: it fixes the state_dict key order and the sequence of\n",
    "        # RNG draws made by the layers' default initialisation.\n",
    "        self.bn1 = nn.BatchNorm2d(in_planes)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_planes)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.drop_rate = drop_rate\n",
    "        if self.equalInOut:\n",
    "            self.shortcut = None\n",
    "        else:\n",
    "            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = x\n",
    "        stages = ((self.bn1, self.relu1, self.conv1), (self.bn2, self.relu2, self.conv2))\n",
    "        for norm, act, conv in stages:\n",
    "            out = act(norm(out))  # norm() returns a new tensor, so the in-place ReLU never touches x\n",
    "            if self.drop_rate > 0:\n",
    "                out = F.dropout(out, p=self.drop_rate, training=self.training)\n",
    "            out = conv(out)\n",
    "        identity = x if self.equalInOut else self.shortcut(x)\n",
    "        return out + identity\n",
    "\n",
    "class NetworkBlock(nn.Module):\n",
    "    \"\"\"One network stage: `nb_layers` blocks in sequence; only the first changes width and uses `stride`.\"\"\"\n",
    "    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate):\n",
    "        super().__init__()\n",
    "        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, drop_rate)\n",
    "\n",
    "    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate):\n",
    "        \"\"\"Instantiate the blocks: first in_planes->out_planes with `stride`, then out->out with stride 1.\"\"\"\n",
    "        blocks = [\n",
    "            block(in_planes if pos == 0 else out_planes, out_planes, stride if pos == 0 else 1, drop_rate)\n",
    "            for pos in range(nb_layers)\n",
    "        ]\n",
    "        return nn.Sequential(*blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layer(x)\n",
    "\n",
    "class WideResNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Wide ResNet of depth 6n+4: 3x3 stem conv, three stages of n pre-activation BasicBlocks\n",
    "    (widths 16k / 32k / 64k, strides 1 / 2 / 2), BN-ReLU, global average pool, linear head.\n",
    "    \"\"\"\n",
    "    def __init__(self, depth=28, widen_factor=10, num_classes=NUM_CLASSES, drop_rate=0.0):\n",
    "        super().__init__()\n",
    "        assert (depth - 4) % 6 == 0, 'WRN depth should be 6n+4'\n",
    "        blocks_per_stage = (depth - 4) // 6\n",
    "        widths = [16] + [base * widen_factor for base in (16, 32, 64)]\n",
    "\n",
    "        # Submodule creation order is kept: it determines the modules() order used below.\n",
    "        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.block1 = NetworkBlock(blocks_per_stage, widths[0], widths[1], BasicBlock, 1, drop_rate)\n",
    "        self.block2 = NetworkBlock(blocks_per_stage, widths[1], widths[2], BasicBlock, 2, drop_rate)\n",
    "        self.block3 = NetworkBlock(blocks_per_stage, widths[2], widths[3], BasicBlock, 2, drop_rate)\n",
    "        self.bn = nn.BatchNorm2d(widths[3])\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.fc = nn.Linear(widths[3], num_classes)\n",
    "\n",
    "        # Re-initialise every layer (in modules() order, which fixes the RNG draw sequence).\n",
    "        for module in self.modules():\n",
    "            if isinstance(module, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')\n",
    "            elif isinstance(module, nn.BatchNorm2d):\n",
    "                nn.init.constant_(module.weight, 1.0)\n",
    "                nn.init.constant_(module.bias, 0.0)\n",
    "            elif isinstance(module, nn.Linear):\n",
    "                nn.init.kaiming_normal_(module.weight)\n",
    "                nn.init.constant_(module.bias, 0.0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        for stage in (self.block1, self.block2, self.block3):\n",
    "            out = stage(out)\n",
    "        out = self.relu(self.bn(out))\n",
    "        return self.fc(torch.flatten(self.avgpool(out), 1))\n",
    "\n",
    "class SimpleCNN(nn.Module):\n",
    "    \"\"\"Compatibility wrapper (despite the name, a WRN-28-10) exposing the head as `.fc2`.\n",
    "\n",
    "    `.fc2` is the same Linear module as `backbone.fc`, registered under both names, so\n",
    "    state_dict() contains both backbone.fc.* and fc2.* entries.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes=NUM_CLASSES):\n",
    "        super().__init__()\n",
    "        self.backbone = WideResNet(depth=28, widen_factor=10, num_classes=num_classes, drop_rate=0.0)\n",
    "        self.fc2 = self.backbone.fc\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.backbone(x)\n",
    "\n",
    "def fresh_model():\n",
    "    \"\"\"Build a new SimpleCNN on `device` with the classifier (fc2) bias set to zero.\"\"\"\n",
    "    model = SimpleCNN(num_classes=NUM_CLASSES).to(device)\n",
    "    # The getattr chain tolerates models without .fc2 or with a bias-free head.\n",
    "    head_bias = getattr(getattr(model, \"fc2\", None), \"bias\", None)\n",
    "    if head_bias is not None:\n",
    "        with torch.no_grad():\n",
    "            head_bias.zero_()\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qHBrIemB4GyV"
   },
   "outputs": [],
   "source": [
    "# Generic one-epoch vanilla trainer (used for IRS warmup, etc.)\n",
    "def train_epoch(model, loader, optimizer, criterion):\n",
    "    \"\"\"\n",
    "    One full-precision training epoch; returns the sample-weighted mean of `criterion`.\n",
    "\n",
    "    Batches may be (x, y) or (x, y, *extras). Extras such as dataset positions are ignored,\n",
    "    matching the `for x, y, *_ in loader` unpacking used by the eval helpers.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total, n = 0.0, 0\n",
    "    for x, y, *_ in tqdm(loader, leave=False):\n",
    "        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        loss = criterion(model(x), y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        bs = x.size(0)\n",
    "        # Assumes criterion returns a batch *mean* (e.g. default CrossEntropyLoss) — TODO confirm for custom losses.\n",
    "        total += loss.item() * bs\n",
    "        n += bs\n",
    "    return total / max(1, n)\n",
    "\n",
    "# === Phi-KL machinery ===\n",
    "class PhiKL:\n",
    "    \"\"\"\n",
    "    KL divergence viewed as a phi-divergence with phi(t) = t * log(t).\n",
    "\n",
    "    T(p) = phi'(p) = 1 + log p; T_inv(y) = exp(y - 1) is its inverse; and\n",
    "    D(p, q) = sum p * (log p - log q), with both inputs floored at `eps` first.\n",
    "    All three operate on torch tensors.\n",
    "    \"\"\"\n",
    "    name = \"kl\"\n",
    "\n",
    "    @staticmethod\n",
    "    def T(p):\n",
    "        return torch.log(p) + 1.0\n",
    "\n",
    "    @staticmethod\n",
    "    def T_inv(y):\n",
    "        return torch.exp(y - 1.0)\n",
    "\n",
    "    @staticmethod\n",
    "    def D(p, q, eps=1e-12):\n",
    "        p_safe = p.clamp_min(eps)\n",
    "        q_safe = q.clamp_min(eps)\n",
    "        return torch.sum(p_safe * (torch.log(p_safe) - torch.log(q_safe)))\n",
    "\n",
    "# === IRS(KL) helpers — KL path + fast 1D solver ===\n",
    "\n",
    "@torch.no_grad()\n",
    "def _kl_path_p_of_h(F_vals: torch.Tensor, P_hat: torch.Tensor, h: float, exp_clip: float = 40.0):\n",
    "    \"\"\"\n",
    "    Point on the KL (exponential-tilting) path: p_i(h) = softmax(log P_hat_i + h * F_i).\n",
    "\n",
    "    P_hat is floored at 1e-30 before the log. Logits are then floored at\n",
    "    max(logits) - exp_clip, so no entry's probability drops below roughly exp(-exp_clip)\n",
    "    times the largest one. The upper bound max(logits) + 1e-6 never binds, so the clamp\n",
    "    is one-sided. The softmax itself is numerically stable.\n",
    "    \"\"\"\n",
    "    tilted = torch.log(P_hat.clamp_min(1e-30)) + (F_vals * float(h))\n",
    "    peak = tilted.max()\n",
    "    tilted = torch.clamp(tilted, min=peak - exp_clip, max=peak + 1e-6)\n",
    "    return torch.softmax(tilted, dim=0)\n",
    "\n",
    "@torch.no_grad()\n",
    "def _kappa_for_h(F_vals: torch.Tensor, P_hat: torch.Tensor, tau_t: torch.Tensor,\n",
    "                 h: float, distance_scale: float = 1.0, min_div: float = 1e-2):\n",
    "    \"\"\"\n",
    "    Evaluate kappa at the KL-path point p(h) = _kl_path_p_of_h(F_vals, P_hat, h):\n",
    "        kappa(h) = (E_p[F] - tau) / (distance_scale * KL(p || P_hat) + min_div)\n",
    "    The KL term is clipped at 0 and the denominator floored at 1e-20.\n",
    "    Returns: (kappa, numerator, D_kl, P) — the first three as Python floats.\n",
    "    \"\"\"\n",
    "    P = _kl_path_p_of_h(F_vals, P_hat, h)\n",
    "    numerator = (torch.dot(P, F_vals) - tau_t).item()\n",
    "    log_ratio = torch.log(P.clamp_min(1e-30)) - torch.log(P_hat.clamp_min(1e-30))\n",
    "    d_kl = (P * log_ratio).sum().item()\n",
    "    # For a normalised P_hat, KL >= 0 in exact arithmetic; the clip removes round-off negatives.\n",
    "    denominator = distance_scale * max(d_kl, 0.0) + float(min_div)\n",
    "    kappa = numerator / max(denominator, 1e-20)\n",
    "    return float(kappa), float(numerator), float(d_kl), P\n",
    "\n",
    "@torch.no_grad()\n",
    "def _secant_maximize_kappa(F_vals: torch.Tensor, P_hat: torch.Tensor, tau_t: torch.Tensor,\n",
    "                           h_init_left: float, h_init_right: float,\n",
    "                           distance_scale: float = 1.0, min_div: float = 1e-2,\n",
    "                           max_iter: int = 12, expand_steps: int = 4):\n",
    "    \"\"\"\n",
    "    Maximize kappa(h) on the real line with a cheap derivative-free 1D search (name kept for callers).\n",
    "\n",
    "    1) Bracket expansion: starting from [h_init_left, h_init_right], move `expand_steps` times toward\n",
    "       the endpoint with the higher kappa, growing the step by 1.5x each time.\n",
    "    2) Successive parabolic interpolation: fit a parabola through the three best distinct points seen\n",
    "       so far and jump to its vertex when it is concave (an interpolated maximum, kappa'(h) ~ 0);\n",
    "       otherwise step past the best point in the uphill direction. (A plain secant step\n",
    "       x1 - y1 / slope looks for a root of kappa, not for its maximum.)\n",
    "\n",
    "    Every evaluation is cached and the best one is returned, so the result is never worse than any\n",
    "    probed point, including both initial endpoints. NaN kappa values are ranked as -inf.\n",
    "    Returns: (h_best, kappa, numerator, D_kl, P) evaluated at h_best.\n",
    "    \"\"\"\n",
    "    probes = {}  # h -> kappa for every evaluated point\n",
    "\n",
    "    def K(h):\n",
    "        h = float(h)\n",
    "        if h not in probes:\n",
    "            k = _kappa_for_h(F_vals, P_hat, tau_t, h, distance_scale, min_div)[0]\n",
    "            probes[h] = k if k == k else float('-inf')  # NaN -> -inf keeps the ranking well-defined\n",
    "        return probes[h]\n",
    "\n",
    "    # --- 1) bracket expansion (heuristic) ---\n",
    "    a, b = float(h_init_left), float(h_init_right)\n",
    "    Ka, Kb = K(a), K(b)\n",
    "    step = (b - a)\n",
    "    for _ in range(expand_steps):\n",
    "        if Kb > Ka:\n",
    "            a, Ka = b, Kb\n",
    "            b = a + step\n",
    "            Kb = K(b)\n",
    "        else:\n",
    "            b, Kb = a, Ka\n",
    "            a = b - step\n",
    "            Ka = K(a)\n",
    "        step *= 1.5\n",
    "\n",
    "    # --- 2) successive parabolic interpolation on the best points seen so far ---\n",
    "    for _ in range(max_iter):\n",
    "        ranked = sorted(probes.items(), key=lambda t: t[1], reverse=True)  # highest kappa first\n",
    "        if len(ranked) < 2:\n",
    "            break  # all probes at one h (degenerate bracket): nothing to interpolate\n",
    "        (x1, y1), (x0, y0) = ranked[0], ranked[1]  # best and runner-up\n",
    "        curv, den = 0.0, 0.0\n",
    "        if len(ranked) >= 3:\n",
    "            x2, y2 = ranked[2]\n",
    "            # leading coefficient of the interpolating parabola (second divided difference)\n",
    "            curv = ((y2 - y1) / (x2 - x1) - (y1 - y0) / (x1 - x0)) / (x2 - x0)\n",
    "            den  = (x1 - x0) * (y1 - y2) - (x1 - x2) * (y1 - y0)\n",
    "        if curv < 0.0 and den != 0.0:\n",
    "            # concave fit: jump to its vertex (standard three-point parabolic step)\n",
    "            num = (x1 - x0) * (x1 - x0) * (y1 - y2) - (x1 - x2) * (x1 - x2) * (y1 - y0)\n",
    "            x_new = x1 - 0.5 * num / den\n",
    "        else:\n",
    "            # convex / linear fit, or only two points: keep walking uphill past the best point\n",
    "            x_new = x1 + (x1 - x0)\n",
    "        if not math.isfinite(x_new) or abs(x_new - x1) < 1e-8 or x_new in probes:\n",
    "            break\n",
    "        K(x_new)\n",
    "\n",
    "    # Re-evaluate at the best point ever probed to get its numerator, KL and distribution\n",
    "    h_best = max(probes.items(), key=lambda t: t[1])[0]\n",
    "    kh, num, Dkl, P = _kappa_for_h(F_vals, P_hat, tau_t, h_best, distance_scale, min_div)\n",
    "    return h_best, kh, num, Dkl, P\n",
    "\n",
    "# === Class-wise helpers for IRS(KL)-CW with *global* reference prior ===\n",
    "from typing import Tuple\n",
    "\n",
    "def classwise_F_with_global_P(ce_per_sample: torch.Tensor,\n",
    "                              labels: torch.Tensor,\n",
    "                              num_classes: int,\n",
    "                              P_global: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Inputs:\n",
    "      ce_per_sample : [B] per-sample CE (float tensor, on device)\n",
    "      labels        : [B] int64 labels\n",
    "      num_classes   : int\n",
    "      P_global      : [C] fixed global class prior from the LT *training* split (full support, smoothed)\n",
    "\n",
    "    Returns:\n",
    "      F      : [C] mean CE per class for classes present in the current batch; 0 for absent\n",
    "      P_hatS : [C] reference prior *conditioned on the present classes S* (0 outside S, sums to 1 over S)\n",
    "      mask   : [C] boolean mask of present classes in the batch\n",
    "    \"\"\"\n",
    "    device_ = ce_per_sample.device\n",
    "    dtype   = ce_per_sample.dtype\n",
    "\n",
    "    # Counts and present-class mask on the batch\n",
    "    counts = torch.bincount(labels, minlength=num_classes).to(device_)   # [C]\n",
    "    mask   = counts > 0\n",
    "\n",
    "    # Batch-mean CE per present class\n",
    "    ce_sum = torch.zeros(num_classes, device=device_, dtype=dtype)\n",
    "    ce_sum.scatter_add_(0, labels, ce_per_sample)\n",
    "    F = torch.zeros(num_classes, device=device_, dtype=dtype)\n",
    "    F[mask] = ce_sum[mask] / counts[mask].to(dtype)\n",
    "\n",
    "    # Reference prior: condition the *global* prior onto the present-class support S\n",
    "    P_hatS = torch.zeros_like(F)\n",
    "    P_g = P_global.to(device=device_, dtype=dtype)\n",
    "    mass_S = P_g[mask].sum().clamp_min(1e-12)\n",
    "    P_hatS[mask] = P_g[mask] / mass_S  # sums to 1 over S; zero outside S\n",
    "\n",
    "    return F, P_hatS, mask\n",
    "\n",
    "\n",
    "# === KLRS feasibility estimator (used inside training loop) ===\n",
    "@torch.no_grad()\n",
    "def _estimate_feasibility(model, data_loader, tau, lam, amp_enabled=True, exp_clamp=10.0):\n",
    "    \"\"\"\n",
    "    Estimate E[ exp((ℓ-τ)/λ) ] over every sample yielded by `data_loader` (e.g. the shuffled\n",
    "    training set), where ℓ is the per-sample cross-entropy of `model`.\n",
    "\n",
    "    The exponent is capped at `exp_clamp` before exponentiation. Extra fields after (x, y) in each\n",
    "    batch are ignored. Returns 0.0 when the loader yields no samples.\n",
    "    NOTE(review): the model's train/eval mode is not changed here — if the caller leaves it in\n",
    "    train mode, BatchNorm running statistics get updated by these forward passes.\n",
    "    \"\"\"\n",
    "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
    "    total, count = 0.0, 0\n",
    "    for inputs, targets, *_extra in data_loader:\n",
    "        inputs = inputs.to(device, non_blocking=True)\n",
    "        targets = targets.to(device, non_blocking=True)\n",
    "        with torch.amp.autocast('cuda', enabled=amp_enabled):\n",
    "            exponent = ((per_sample_ce(model(inputs), targets) - tau) / lam).clamp(max=exp_clamp)\n",
    "            weights = exponent.exp()\n",
    "        total += weights.sum().item()\n",
    "        count += weights.numel()\n",
    "    return total / max(1, count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mMwbqn6xmgEu"
   },
   "outputs": [],
   "source": [
    "from typing import Iterable\n",
    "\n",
    "class SAMAdam(torch.optim.Adam):\n",
    "    def __init__(self, params: Iterable[torch.Tensor], lr=1e-3, rho=0.05, **kwargs):\n",
    "        if rho <= 0:\n",
    "            raise ValueError(f\"rho must be positive, got {rho}\")\n",
    "        self.rho = float(rho)\n",
    "        super().__init__(params, lr=lr, **kwargs)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _grad_norm(self):\n",
    "        device_ = self.param_groups[0][\"params\"][0].device\n",
    "        norms = []\n",
    "        for group in self.param_groups:\n",
    "            for p in group[\"params\"]:\n",
    "                if p.grad is not None:\n",
    "                    norms.append(p.grad.detach().norm(2))\n",
    "        if not norms:\n",
    "            return torch.tensor(0., device=device_)\n",
    "        return torch.norm(torch.stack(norms), p=2)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _epsilon(self, scale, max_scale=1e3):\n",
    "        \"\"\"Apply perturbation with a safety cap on scale to prevent NaN collapse.\"\"\"\n",
    "        # Clamp scale to prevent explosion when gradient norm is tiny\n",
    "        scale = min(scale, max_scale) if isinstance(scale, (int, float)) else scale.clamp(max=max_scale)\n",
    "        epsilons = []\n",
    "        for group in self.param_groups:\n",
    "            eps_group = []\n",
    "            for p in group[\"params\"]:\n",
    "                if p.grad is None:\n",
    "                    eps_group.append(None); continue\n",
    "                e = p.grad * scale\n",
    "                # Check for NaN/Inf and skip if found\n",
    "                if not torch.isfinite(e).all():\n",
    "                    eps_group.append(torch.zeros_like(p))\n",
    "                    continue\n",
    "                p.add_(e)\n",
    "                eps_group.append(e)\n",
    "            epsilons.append(eps_group)\n",
    "        return epsilons\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _restore(self, epsilons):\n",
    "        for group, eps_group in zip(self.param_groups, epsilons):\n",
    "            for p, e in zip(group[\"params\"], eps_group):\n",
    "                if e is not None:\n",
    "                    p.sub_(e)\n",
    "\n",
    "import torch\n",
    "import contextlib\n",
    "\n",
    "def disable_running_stats(model):\n",
    "    # Freeze BN running stats updates, but keep BN in train mode (uses batch stats).\n",
    "    for m in model.modules():\n",
    "        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):\n",
    "            m._sam_backup_momentum = m.momentum\n",
    "            m.momentum = 0.0\n",
    "\n",
    "def enable_running_stats(model):\n",
    "    # Restore BN momentum\n",
    "    for m in model.modules():\n",
    "        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm) and hasattr(m, \"_sam_backup_momentum\"):\n",
    "            m.momentum = m._sam_backup_momentum\n",
    "            delattr(m, \"_sam_backup_momentum\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k7RLnkyR8RRw"
   },
   "outputs": [],
   "source": [
    "# --- ERM (Empirical Risk Minimization) ---\n",
    "def train_erm(model, train_loader, val_loader, *, epochs=80, lr=1e-3,\n",
    "              weight_decay=0.0, print_every=1, test_loader=None,\n",
    "              log_dir=None, algo_name=\"ERM\"):\n",
    "    \"\"\"\n",
    "    ERM baseline: minimize the mean cross-entropy with SGD (momentum 0.9) under `amp_ctx()`,\n",
    "    using a GradScaler that is enabled when AMP is on and the device is CUDA.\n",
    "\n",
    "    After every epoch: train CE (sample-weighted average of batch means), val CE / accuracy,\n",
    "    train accuracy and, if `test_loader` is given, test CE / accuracy (NaN otherwise). A CSV row\n",
    "    is written when `log_dir` is set; `log_epoch` runs every `print_every` epochs.\n",
    "    Returns (train_ce_history, val_ce_history, val_acc_history), one entry per epoch.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    train_ce_hist, val_ce_hist, val_acc_hist = [], [], []\n",
    "    t_start = time.perf_counter()\n",
    "    logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        ce_accum, n_samples = 0.0, 0\n",
    "        for xb, yb, _ in tqdm(train_loader, leave=False):\n",
    "            xb = xb.to(device, non_blocking=True)\n",
    "            yb = yb.to(device, non_blocking=True)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            with amp_ctx():\n",
    "                batch_losses = per_sample_ce(model(xb), yb)\n",
    "                loss = batch_losses.mean()\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(opt)\n",
    "            scaler.update()\n",
    "            batch_size = xb.size(0)\n",
    "            ce_accum += batch_losses.detach().mean().item() * batch_size\n",
    "            n_samples += batch_size\n",
    "\n",
    "        train_ce = ce_accum / max(1, n_samples)\n",
    "        val_ce = eval_ce_loss(model, val_loader)\n",
    "        val_acc = eval_accuracy(model, val_loader)\n",
    "        train_ce_hist.append(train_ce)\n",
    "        val_ce_hist.append(val_ce)\n",
    "        val_acc_hist.append(val_acc)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        if test_loader is not None:\n",
    "            test_ce = eval_ce_loss(model, test_loader)\n",
    "            test_acc = eval_accuracy(model, test_loader)\n",
    "        else:\n",
    "            test_ce = test_acc = float('nan')\n",
    "        if logger:\n",
    "            logger.write({\n",
    "                \"epoch\": epoch,\n",
    "                \"train_ce\": train_ce, \"train_acc\": train_acc,\n",
    "                \"val_ce\": val_ce, \"val_acc\": val_acc,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc,\n",
    "            })\n",
    "\n",
    "        if epoch % print_every == 0:\n",
    "            log_epoch(\"ERM\", epoch, train_ce, val_ce, val_acc, t_start)\n",
    "        maybe_print_class_stats(\"ERM\", epoch, model, val_loader)\n",
    "\n",
    "    return train_ce_hist, val_ce_hist, val_acc_hist\n",
    "\n",
    "# --- ERM with Adam optimizer ---\n",
    "def train_erm_adam(model, train_loader, val_loader, *, epochs=80, lr=1e-3,\n",
    "                   weight_decay=0.0, print_every=1, test_loader=None,\n",
    "                   log_dir=None, algo_name=\"ERM-Adam\"):\n",
    "    \"\"\"\n",
    "    ERM using the Adam optimizer instead of SGD; otherwise the same loop as `train_erm`.\n",
    "\n",
    "    Mean cross-entropy is minimized under `amp_ctx()` with a GradScaler enabled when AMP is on and\n",
    "    the device is CUDA. After every epoch: train CE (sample-weighted average of batch means),\n",
    "    val CE / accuracy, train accuracy and, if `test_loader` is given, test CE / accuracy (NaN\n",
    "    otherwise); a CSV row is written when `log_dir` is set.\n",
    "    Returns (train_ce_history, val_ce_history, val_acc_history), one entry per epoch.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    grad_scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    loss_fn = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    hist_train_ce, hist_val_ce, hist_val_acc = [], [], []\n",
    "    started = time.perf_counter()\n",
    "    epoch_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        loss_total, count = 0.0, 0\n",
    "        for inputs, targets, _ in tqdm(train_loader, leave=False):\n",
    "            inputs = inputs.to(device, non_blocking=True)\n",
    "            targets = targets.to(device, non_blocking=True)\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            with amp_ctx():\n",
    "                per_sample = loss_fn(model(inputs), targets)\n",
    "                loss = per_sample.mean()\n",
    "            grad_scaler.scale(loss).backward()\n",
    "            grad_scaler.step(optimizer)\n",
    "            grad_scaler.update()\n",
    "            n = inputs.size(0)\n",
    "            loss_total += per_sample.detach().mean().item() * n\n",
    "            count += n\n",
    "\n",
    "        train_ce = loss_total / max(1, count)\n",
    "        val_ce = eval_ce_loss(model, val_loader)\n",
    "        val_acc = eval_accuracy(model, val_loader)\n",
    "        for hist, value in ((hist_train_ce, train_ce), (hist_val_ce, val_ce), (hist_val_acc, val_acc)):\n",
    "            hist.append(value)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        has_test = test_loader is not None\n",
    "        test_ce = eval_ce_loss(model, test_loader) if has_test else float('nan')\n",
    "        test_acc = eval_accuracy(model, test_loader) if has_test else float('nan')\n",
    "        if epoch_logger:\n",
    "            epoch_logger.write({\n",
    "                \"epoch\": epoch,\n",
    "                \"train_ce\": train_ce, \"train_acc\": train_acc,\n",
    "                \"val_ce\": val_ce, \"val_acc\": val_acc,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc,\n",
    "            })\n",
    "\n",
    "        if epoch % print_every == 0:\n",
    "            log_epoch(\"ERM-Adam\", epoch, train_ce, val_ce, val_acc, started)\n",
    "        maybe_print_class_stats(\"ERM-Adam\", epoch, model, val_loader)\n",
    "\n",
    "    return hist_train_ce, hist_val_ce, hist_val_acc\n",
    "\n",
    "# --- SAM (Sharpness-Aware Minimization) ---\n",
    "def train_sam(model, train_loader, val_loader,\n",
    "              epochs=80, lr=1e-3, rho=0.05, weight_decay=0.0,\n",
    "              momentum=None, use_cosine=False, label_smoothing=0.0,\n",
    "              print_every=1, test_loader=None, log_dir=None, algo_name=\"SAM\",\n",
    "              grad_clip=1.0, max_scale=1e3):\n",
    "    \"\"\"\n",
    "    Paper-faithful SAM: always perform the two forward/backward passes per iteration.\n",
    "    Ref: Foret et al., ICLR 2021 (\"Sharpness-Aware Minimization\") — two gradient comps/iter.\n",
    "\n",
    "    Per batch:\n",
    "      1) loss and gradient g at w; e = rho * g / ||g|| (scale capped at `max_scale`); move to w + e.\n",
    "      2) loss and gradient at w + e with BN running stats frozen; move back to w and take the\n",
    "         SAMAdam step with the gradient computed at w + e.\n",
    "    Stability fallbacks: non-finite loss at w -> batch skipped; non-finite or tiny ||g|| -> standard\n",
    "    step with g; non-finite loss at w + e -> standard step at w. Whenever a step is taken, the gradient\n",
    "    being applied is clipped to `grad_clip` (if not None). Training stops early once val CE is non-finite.\n",
    "\n",
    "    AMP: GradScaler allows one unscale_() per optimizer between update() calls, and step() does not\n",
    "    unscale again after an explicit unscale_(). The first pass is therefore closed with scaler.update()\n",
    "    before the second backward, and the second-pass gradient is unscaled (and inf/NaN-checked) right\n",
    "    before the step, so the update never uses still-scaled gradients. As a side effect the scaler's\n",
    "    growth tracker advances twice per SAM iteration.\n",
    "\n",
    "    `momentum` is accepted for interface compatibility and is unused (the optimizer is Adam-based).\n",
    "    Returns (train_ce_history, val_ce_history, val_acc_history), one entry per completed epoch.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    opt = SAMAdam(model.parameters(), lr=lr, rho=rho, weight_decay=weight_decay)\n",
    "    ce  = nn.CrossEntropyLoss(label_smoothing=label_smoothing) if label_smoothing>0 else nn.CrossEntropyLoss()\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    sched  = (torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=1e-6)\n",
    "              if use_cosine else None)\n",
    "\n",
    "    def _clip():\n",
    "        # clip the (already unscaled) gradient that is about to be used\n",
    "        if grad_clip is not None:\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)\n",
    "\n",
    "    tr_hist, va_hist, va_acc_hist = [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for ep in range(1, epochs+1):\n",
    "        model.train(); total_ce_sum, total_n = 0.0, 0\n",
    "        for x, y, _ in tqdm(train_loader, leave=False):\n",
    "            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "            bs = x.size(0)\n",
    "\n",
    "            # 1) forward-backward at w (BN normal)\n",
    "            enable_running_stats(model)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            with amp_ctx():\n",
    "                logits = model(x)\n",
    "                loss = ce(logits, y)\n",
    "\n",
    "            # Skip batch if loss is NaN/Inf\n",
    "            if not torch.isfinite(loss):\n",
    "                continue\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.unscale_(opt)   # true first-pass gradient; records inf/NaN for the scaler\n",
    "            _clip()                # uniform rescale: the direction of e is unchanged\n",
    "            gnorm = opt._grad_norm()\n",
    "\n",
    "            # Tiny or non-finite gradient: fall back to a standard step with the first-pass gradient\n",
    "            if not torch.isfinite(gnorm) or gnorm < 1e-8:\n",
    "                scaler.step(opt); scaler.update()\n",
    "                total_ce_sum += loss.detach().item() * bs\n",
    "                total_n      += bs\n",
    "                continue\n",
    "\n",
    "            # Close the first pass for the scaler so the second pass can be unscaled and checked\n",
    "            scaler.update()\n",
    "\n",
    "            scale = rho / (gnorm + 1e-12)\n",
    "            epsilons = opt._epsilon(scale, max_scale=max_scale)  # move to w + e (scale capped)\n",
    "\n",
    "            # 2) forward-backward at w + e (freeze BN running stats)\n",
    "            disable_running_stats(model)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            with amp_ctx():\n",
    "                logits_pert = model(x)\n",
    "                loss_pert   = ce(logits_pert, y)\n",
    "\n",
    "            if not torch.isfinite(loss_pert):\n",
    "                # Perturbed loss blew up: go back to w and take a standard step there\n",
    "                opt._restore(epsilons)\n",
    "                enable_running_stats(model)\n",
    "                opt.zero_grad(set_to_none=True)\n",
    "                with amp_ctx():\n",
    "                    logits = model(x)\n",
    "                    loss = ce(logits, y)\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.unscale_(opt)\n",
    "                _clip()\n",
    "                scaler.step(opt); scaler.update()\n",
    "                total_ce_sum += loss.detach().item() * bs\n",
    "                total_n      += bs\n",
    "                continue\n",
    "\n",
    "            scaler.scale(loss_pert).backward()\n",
    "            scaler.unscale_(opt)   # the SAM gradient is the one applied: unscale + inf-check it\n",
    "            _clip()\n",
    "\n",
    "            # restore weights and BN, then step\n",
    "            opt._restore(epsilons)\n",
    "            enable_running_stats(model)\n",
    "            scaler.step(opt); scaler.update()\n",
    "\n",
    "            total_ce_sum += loss.detach().item() * bs\n",
    "            total_n      += bs\n",
    "\n",
    "        if sched is not None:\n",
    "            sched.step()\n",
    "\n",
    "        tr = total_ce_sum / max(1, total_n)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "\n",
    "        # Check if val CE is NaN - model has collapsed\n",
    "        if not np.isfinite(va):\n",
    "            print(f\"[SAM] WARNING: valCE became NaN at epoch {ep}. Model collapsed.\")\n",
    "            print(f\"[SAM] Last good epoch was {ep-1}. Stopping training early.\")\n",
    "            break\n",
    "\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce   = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc  = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\n",
    "                \"epoch\": ep,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va,  \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(\"SAM\", ep, tr, va, acc_val, t0)\n",
    "        maybe_print_class_stats(\"SAM\", ep, model, val_loader)\n",
    "\n",
    "    return tr_hist, va_hist, va_acc_hist\n",
    "\n",
    "\n",
    "# --- IRS(KL) — class-wise variant (with global reference prior) ---\n",
    "def train_irs_cw(model, train_loader, val_loader, *,\n",
    "                 epochs=80, lr=1e-3, warmup_epochs=1,\n",
    "                 base_tau=RS_SHARED_TAU,        # shared τ\n",
    "                 distance_scale=1.0, min_div=1e-2,\n",
    "                 print_every=1, weight_decay=0.0, grad_clip=None,\n",
    "                 tau_lb_factor=1.01, tau_lb_eps=1e-8, use_gate=True,\n",
    "                 num_classes=NUM_CLASSES, P_global: torch.Tensor = None,\n",
    "                 test_loader=None, log_dir=None, algo_name=\"IRS(KL)-CW\"):\n",
    "    \"\"\"\n",
    "    IRS (our method) — KL version, class-wise, with a global reference prior.\n",
    "\n",
    "    1) `warmup_epochs` epochs of plain CE training via `train_epoch` (kappa and h logged as 0).\n",
    "    2) Remaining epochs, per batch (Adam, no autocast):\n",
    "       - F_b : mean CE of each class present in the batch (0 for absent classes);\n",
    "         P_b : P_global conditioned on the present classes S (classwise_F_with_global_P).\n",
    "       - tau_b = base_tau, unless base_tau <= m_b + tau_lb_eps with m_b = dot(F_b, P_b),\n",
    "         in which case tau_b = m_b * tau_lb_factor (just above the reference mean).\n",
    "       - h* maximizes kappa(h) along the KL path p(h) ~ softmax(log P_b + h * F_b)\n",
    "         (_secant_maximize_kappa, bracket warm-started at [prev_h - 2, prev_h + 2]).\n",
    "       - Gate: the step is taken only if numerator = dot(p(h*), F_b) - tau_b >= 0 (always when\n",
    "         use_gate=False). The loss is (dot(p(h*), F_b) - tau_b) / (distance_scale * KL + min_div)\n",
    "         with p(h*) and the denominator held constant, so the gradient flows only through F_b.\n",
    "    P_global defaults to get_P_global_tensor(...) (defined elsewhere) when None.\n",
    "\n",
    "    Returns (kappa_hist, h_hist, tr_hist, va_hist, va_acc_hist), one entry per epoch including\n",
    "    warmup; per-epoch kappa/h are averages over the batches that passed the gate (0 if none did).\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    opt   = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    ce_el = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    # default reference: the global class prior returned by get_P_global_tensor (defined elsewhere)\n",
    "    if P_global is None:\n",
    "        P_global = get_P_global_tensor(device=device, dtype=torch.float32)\n",
    "\n",
    "    kappa_hist, h_hist, tr_hist, va_hist, va_acc_hist = [], [], [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    # --- warmup: plain CE epochs; kappa and h are recorded as 0 ---\n",
    "    # NOTE(review): warmup epochs are logged every epoch regardless of print_every\n",
    "    for ep in range(1, warmup_epochs+1):\n",
    "        tr = train_epoch(model, train_loader, opt, nn.CrossEntropyLoss())\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc = eval_accuracy(model, val_loader)\n",
    "        kappa_hist.append(0.0); h_hist.append(0.0)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce   = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc  = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\"epoch\": ep,\"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                              \"val_ce\": va, \"val_acc\": acc, \"test_ce\": test_ce, \"test_acc\": test_acc})\n",
    "        log_epoch(\"IRS-CW-warm\", ep, tr, va, acc, t0)\n",
    "\n",
    "    # --- main loop: KL path + 1D h-search per batch ---\n",
    "    prev_h = 0.0  # warm start for the h-search; only updated on batches that pass the gate\n",
    "    for ep in range(warmup_epochs+1, epochs+1):\n",
    "        model.train()\n",
    "        total_ce_sum, n_seen = 0.0, 0\n",
    "        kappa_vals, h_vals = [], []  # collected only for gated-in batches\n",
    "\n",
    "        for x, y, *_ in train_loader:\n",
    "            x = x.to(device, non_blocking=True)\n",
    "            y = y.to(device, non_blocking=True)\n",
    "\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            logits = model(x)\n",
    "            ce_vec = ce_el(logits, y)\n",
    "\n",
    "            # --- class-wise batch losses + global prior conditioned on the present classes S ---\n",
    "            F_b, P_b, _mask = classwise_F_with_global_P(ce_vec, y, num_classes, P_global)\n",
    "\n",
    "            # --- tau scheduling: keep tau_b strictly above the reference mean m_b ---\n",
    "            m_b   = torch.dot(F_b, P_b)  # expectation under reference over S\n",
    "            tau_b = float(m_b.item()) * float(tau_lb_factor) if base_tau <= float(m_b.item()) + float(tau_lb_eps) else float(base_tau)\n",
    "            tau_b_t = F_b.new_tensor(tau_b)\n",
    "\n",
    "            # --- Maximize kappa(h) along the KL path, bracket warm-started around prev_h ---\n",
    "            h_left, h_right = prev_h - 2.0, prev_h + 2.0\n",
    "            h_star, k_star, numerator, D_kl = 0.0, 0.0, 0.0, 0.0\n",
    "            try:\n",
    "                h_star, k_star, numerator, D_kl, P_star_b = _secant_maximize_kappa(\n",
    "                    F_b, P_b, tau_b_t,\n",
    "                    h_init_left=h_left, h_init_right=h_right,\n",
    "                    distance_scale=distance_scale, min_div=min_div,\n",
    "                    max_iter=10, expand_steps=3\n",
    "                )\n",
    "            except Exception:\n",
    "                # fallback: evaluate at prev_h\n",
    "                # NOTE(review): any solver exception is swallowed silently here\n",
    "                k_star, numerator, D_kl, P_star_b = _kappa_for_h(F_b, P_b, tau_b_t, prev_h, distance_scale, min_div)\n",
    "                h_star = prev_h\n",
    "\n",
    "            # gating: skip the step when even p(h*) does not lift E_p[F_b] to tau_b (numerator < 0);\n",
    "            # a skipped batch gets no optimizer step\n",
    "            do_update = (numerator >= 0.0) if use_gate else True\n",
    "            if do_update:\n",
    "                # p(h*) and the denominator are constants: the gradient flows only through F_b\n",
    "                P_const = P_star_b.detach()\n",
    "                numerator_t = torch.dot(P_const, F_b) - tau_b_t\n",
    "                D_eff = F_b.new_tensor(distance_scale * max(D_kl, 0.0) + float(min_div))\n",
    "                kappa = numerator_t / D_eff.clamp_min(1e-20)\n",
    "                kappa.backward()\n",
    "                if grad_clip is not None:\n",
    "                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)\n",
    "                opt.step()\n",
    "\n",
    "                kappa_vals.append(float(kappa.detach().item()))\n",
    "                h_vals.append(float(h_star))\n",
    "                prev_h = float(h_star)\n",
    "\n",
    "            # train CE is averaged over all batches, including gated-out ones\n",
    "            total_ce_sum += ce_vec.mean().item() * x.size(0)\n",
    "            n_seen += x.size(0)\n",
    "\n",
    "        tr = total_ce_sum / max(1, n_seen)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc = eval_accuracy(model, val_loader)\n",
    "\n",
    "        kappa_bar = float(sum(kappa_vals)/max(1, len(kappa_vals))) if kappa_vals else 0.0\n",
    "        h_bar     = float(sum(h_vals)/max(1, len(h_vals)))         if h_vals     else 0.0\n",
    "\n",
    "        kappa_hist.append(kappa_bar); h_hist.append(h_bar)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce   = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc  = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\"epoch\": ep,\"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                              \"val_ce\": va, \"val_acc\": acc, \"test_ce\": test_ce, \"test_acc\": test_acc})\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(\"IRS-CW\", ep, tr, va, acc, t0, extra=f\"κ̄={kappa_bar:+.2e} h̄={h_bar:+.2e}\")\n",
    "        maybe_print_class_stats(\"IRS-CW\", ep, model, val_loader)\n",
    "\n",
    "    return kappa_hist, h_hist, tr_hist, va_hist, va_acc_hist\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# === REx: Loss Machinery ===\n",
    "# ==========================================\n",
    "\n",
    "def get_rex_penalty(losses, mode=\"vrex\", lam=1.0):\n",
    "    \"\"\"\n",
    "    Compute variance-based or Min-Max objective from environment losses.\n",
    "\n",
    "    Reference: Krueger et al. \\\"Out-of-Distribution Generalization via Risk Extrapolation (REx)\\\"\n",
    "    ICML 2021 (https://arxiv.org/abs/2003.00688)\n",
    "\n",
    "    Args:\n",
    "        losses: list/tuple of tensors, each shape [] (per-environment mean losses).\n",
    "        mode: \\\"vrex\\\" → variance penalty (added to ERM)\n",
    "              \\\"mmrex\\\" → extrapolation objective (replaces ERM loss)\n",
    "        lam:\n",
    "          - For vrex: unused here (beta is applied in train_rex)\n",
    "          - For mmrex: λ for extrapolation.\n",
    "            λ > 1: extrapolate beyond worst-case (recommended: 1.5-2.0)\n",
    "            λ = 1: worst-case (equivalent to DRO)\n",
    "            0 < λ < 1: interpolate between min and max\n",
    "\n",
    "    Returns:\n",
    "        For vrex: variance penalty tensor (scalar)\n",
    "        For mmrex: full extrapolated loss (replaces ERM, not added to it)\n",
    "    \"\"\"\n",
    "    if mode == \"vrex\":\n",
    "        # V-REx: Var(R_e) = E[(R_e - E[R_e])^2]\n",
    "        # Paper Eq. 5: L_vrex = ERM + β * Var(R_e)\n",
    "        # The variance is the penalty; β is applied externally in train_rex()\n",
    "        mean_loss = sum(losses) / len(losses)\n",
    "        var_loss = sum((l - mean_loss) ** 2 for l in losses) / len(losses)\n",
    "        return var_loss\n",
    "    elif mode == \"mmrex\":\n",
    "        # MM-REx: Paper Eq. 4\n",
    "        # L_mmrex = λ * max_e(R_e) + (1 - λ) * min_e(R_e)\n",
    "        # For λ > 1: extrapolates beyond the worst environment\n",
    "        # This REPLACES the ERM objective (not a penalty added to ERM)\n",
    "        loss_max = max(losses)\n",
    "        loss_min = min(losses)\n",
    "        return lam * loss_max + (1 - lam) * loss_min\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown REx mode: {mode}\")\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# === IRM: Environment & Loss Machinery ===\n",
    "# ==========================================\n",
    "\n",
    "def irm_penalty(logits, y):\n",
    "    \"\"\"\n",
    "    IRMv1 penalty: ||∇_w [ loss(w ∘ Φ(x); y) ] ||^2 at w=1.0\n",
    "    (gradient of the loss w.r.t a scalar dummy classifier weight).\n",
    "    \"\"\"\n",
    "    scale = torch.ones(1, requires_grad=True, device=logits.device, dtype=logits.dtype)\n",
    "    loss  = F.cross_entropy(logits * scale, y)\n",
    "    grad  = torch.autograd.grad(loss, [scale], create_graph=True)[0]\n",
    "    return (grad ** 2).sum()\n",
    "\n",
    "\n",
    "def build_irm_environments_probabilistic(full_dataset, lt_indices, seed=42):\n",
    "    \"\"\"\n",
    "    Build two resampled IRM environments from the LT training subset.\n",
    "\n",
    "    For every class c present in the LT subset, each environment draws\n",
    "    class_counts[c] indices with replacement from that class's LT indices,\n",
    "    so both environments reproduce the LT class counts exactly. Env1 is drawn\n",
    "    with p=0.7 and env2 with p=0.3, in that order, from one seeded RNG.\n",
    "\n",
    "    NOTE(review): geometric_sample re-draws a fresh uniform index each time the\n",
    "    p-coin succeeds, and the stopping rule never looks at the index, so every\n",
    "    pick is uniform regardless of p. p therefore only changes how many random\n",
    "    numbers are consumed: the two environments are independent bootstrap\n",
    "    resamples with identical class composition, not a majority-heavy vs.\n",
    "    minority-friendly split. Confirm whether that is intended.\n",
    "\n",
    "    Args:\n",
    "        full_dataset: dataset with ``.targets`` whose items are (x, y) pairs.\n",
    "        lt_indices: indices of the LT subset into ``full_dataset``\n",
    "            (list or 1-D integer array).\n",
    "        seed: seed for ``np.random.default_rng``.\n",
    "    Returns:\n",
    "        (env1_dataset, env2_dataset); items are (x, y, i) with i the position\n",
    "        inside that environment (not the index into ``full_dataset``).\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Boolean masking below needs an array; a plain list would raise TypeError.\n",
    "    lt_indices = np.asarray(lt_indices, dtype=np.int64)\n",
    "    targets_full = np.array(full_dataset.targets, dtype=np.int64)\n",
    "    targets_lt = targets_full[lt_indices]\n",
    "\n",
    "    # An empty subset has no max(); it simply yields two empty environments.\n",
    "    num_classes = int(targets_lt.max()) + 1 if targets_lt.size else 0\n",
    "    class_indices = [lt_indices[targets_lt == c] for c in range(num_classes)]\n",
    "    class_counts  = np.array([len(ci) for ci in class_indices], dtype=np.int64)\n",
    "\n",
    "    def geometric_sample(indices, p, size):\n",
    "        \"\"\"Draw ``size`` items from ``indices`` with replacement (see the NOTE on p above).\"\"\"\n",
    "        if len(indices) == 0:\n",
    "            return np.array([], dtype=np.int64)\n",
    "        sampled = []\n",
    "        for _ in range(size):\n",
    "            idx = rng.integers(0, len(indices))\n",
    "            # keep re-drawing with probability p (each re-draw is again uniform)\n",
    "            while rng.random() < p and len(indices) > 1:\n",
    "                idx = rng.integers(0, len(indices))\n",
    "            sampled.append(indices[idx])\n",
    "        return np.array(sampled, dtype=np.int64)\n",
    "\n",
    "    def build_env(p):\n",
    "        \"\"\"Concatenate the per-class samples of one environment, class 0 first.\"\"\"\n",
    "        parts = [geometric_sample(class_indices[c], p=p, size=class_counts[c])\n",
    "                 for c in range(num_classes)]\n",
    "        return np.concatenate(parts) if parts else np.array([], dtype=np.int64)\n",
    "\n",
    "    # Env1 is drawn completely before env2 from the same RNG stream.\n",
    "    env1_indices = build_env(0.7)\n",
    "    env2_indices = build_env(0.3)\n",
    "\n",
    "    class IndexedFromBalanced(Dataset):\n",
    "        \"\"\"View of ``base_dataset`` restricted to ``indices``; yields (x, y, i).\"\"\"\n",
    "        def __init__(self, base_dataset, indices):\n",
    "            self.base = base_dataset\n",
    "            self.indices = indices\n",
    "        def __len__(self):\n",
    "            return len(self.indices)\n",
    "        def __getitem__(self, i):\n",
    "            real_i = self.indices[i]\n",
    "            x, y = self.base[real_i]\n",
    "            return x, y, i\n",
    "\n",
    "    env1_ds = IndexedFromBalanced(full_dataset, env1_indices)\n",
    "    env2_ds = IndexedFromBalanced(full_dataset, env2_indices)\n",
    "\n",
    "    return env1_ds, env2_ds\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# === GroupDRO: Dataset and LossComputer ===\n",
    "# ==========================================\n",
    "\n",
    "class CIFAR10GroupsForDRO:\n",
    "    \"\"\"\n",
    "    Treat each class as a DRO group for GroupDRO.\n",
    "\n",
    "    Exposes the attributes LossComputer reads: ``n_groups`` and\n",
    "    ``group_counts`` (per-class sample counts inside the LT subset).\n",
    "\n",
    "    Args:\n",
    "        targets_full: labels of the full training set (list or array of ints).\n",
    "        lt_indices: indices of the LT subset into ``targets_full`` (list or array).\n",
    "        class_names: sequence of class names; its length fixes the number of groups.\n",
    "    \"\"\"\n",
    "    def __init__(self, targets_full, lt_indices, class_names):\n",
    "        self.targets_full = targets_full\n",
    "        self.lt_indices = lt_indices\n",
    "        self.class_names = class_names\n",
    "        self.n_classes = len(class_names)\n",
    "        self.n_groups = self.n_classes  # each class is a group\n",
    "        # Coerce to arrays: fancy indexing a plain Python list (such as a\n",
    "        # dataset's ``.targets`` list) with an index array raises TypeError.\n",
    "        targets_arr = np.asarray(targets_full, dtype=np.int64)\n",
    "        targets_lt = targets_arr[np.asarray(lt_indices, dtype=np.int64)]\n",
    "        self.group_counts = np.bincount(targets_lt, minlength=self.n_classes)\n",
    "\n",
    "    def group_str(self, group_idx):\n",
    "        \"\"\"Human-readable name of a group (its class name).\"\"\"\n",
    "        return self.class_names[group_idx]\n",
    "\n",
    "\n",
    "class LossComputer:\n",
    "    \"\"\"\n",
    "    Class-as-group GroupDRO loss computer.\n",
    "\n",
    "    Keeps an exponential moving average (EMA) of every group's mean loss and\n",
    "    weights the groups with q = softmax(alpha * EMA), i.e. q_g ∝ exp(α * R_g).\n",
    "\n",
    "    NOTE(review): this is a softmax-of-EMA weighting rather than the online\n",
    "    exponentiated-gradient update q_g <- q_g * exp(η * loss_g) of Sagawa et al.\n",
    "    (2020); ``step_size`` is stored but never used here. ``adj``,\n",
    "    ``normalize_loss``, ``btl`` and ``min_var_weight`` are likewise only stored\n",
    "    for interface compatibility.\n",
    "\n",
    "    Args:\n",
    "        dataset: object exposing ``n_groups`` and ``group_counts``.\n",
    "        alpha: temperature applied to the EMA group losses before the softmax.\n",
    "        gamma: EMA rate; ema = gamma * batch_group_loss + (1 - gamma) * ema.\n",
    "        device: device holding the internal state tensors.\n",
    "    \"\"\"\n",
    "    def __init__(self, dataset, alpha=0.2, gamma=0.1, adj=None,\n",
    "                 min_var_weight=0, step_size=0.01, normalize_loss=False,\n",
    "                 btl=False, device='cuda'):\n",
    "        self.dataset = dataset\n",
    "        self.alpha = alpha\n",
    "        self.gamma = gamma\n",
    "        self.step_size = step_size\n",
    "        self.device = device\n",
    "        self.n_groups = dataset.n_groups\n",
    "\n",
    "        # Initialize group weights uniformly\n",
    "        self.adv_probs = torch.ones(self.n_groups, device=device) / self.n_groups\n",
    "\n",
    "        # EMA accumulators for group losses\n",
    "        self.exp_avg_loss = torch.zeros(self.n_groups, device=device)\n",
    "        self.exp_avg_initialized = torch.zeros(self.n_groups, device=device).bool()\n",
    "\n",
    "        # Group counts over the whole training subset\n",
    "        self.group_counts = torch.tensor(dataset.group_counts, dtype=torch.float32, device=device)\n",
    "        self.group_frac = self.group_counts / self.group_counts.sum()\n",
    "\n",
    "        # Processed counts (for logging)\n",
    "        self.processed_data_counts = torch.zeros(self.n_groups, device=device)\n",
    "        self.update_data_counts = torch.zeros(self.n_groups, device=device)\n",
    "        self.update_batch_counts = torch.zeros(self.n_groups, device=device)\n",
    "\n",
    "        # Stored for interface compatibility; not read by this class.\n",
    "        self.adj = adj\n",
    "        self.normalize_loss = normalize_loss\n",
    "        self.btl = btl\n",
    "        self.min_var_weight = min_var_weight\n",
    "\n",
    "    def loss(self, yhat, y, group_idx=None, is_training=True):\n",
    "        \"\"\"\n",
    "        Compute per-sample CE and the DRO-weighted batch loss.\n",
    "\n",
    "        Args:\n",
    "            yhat: [B, C] logits\n",
    "            y: [B] labels\n",
    "            group_idx: [B] group indices (if None, use y as group indices).\n",
    "                Ids outside [0, n_groups) are ignored.\n",
    "            is_training: when True, update the EMA losses, counters and\n",
    "                adv_probs before weighting.\n",
    "        Returns:\n",
    "            (weighted_loss, loss_dict): scalar tensor sum_g q_g * mean_loss_g and a\n",
    "            dict of Python floats (avg_loss, weighted_loss, loss_group_{g},\n",
    "            weight_group_{g}). Groups absent from the batch contribute loss 0.\n",
    "        \"\"\"\n",
    "        # Per-sample CE\n",
    "        per_sample_losses = F.cross_entropy(yhat, y, reduction='none')\n",
    "\n",
    "        # Group indices (each class is a group)\n",
    "        if group_idx is None:\n",
    "            group_idx = y\n",
    "        group_idx = group_idx.to(per_sample_losses.device).long()\n",
    "\n",
    "        # Vectorized per-group means: a one-hot [B, G] membership mask replaces a\n",
    "        # Python loop that forced several host syncs per group per batch.\n",
    "        in_range = (group_idx >= 0) & (group_idx < self.n_groups)\n",
    "        membership = F.one_hot(group_idx[in_range], num_classes=self.n_groups).float()\n",
    "        group_counts_batch = membership.sum(dim=0)\n",
    "        group_sums = (membership.to(per_sample_losses.dtype)\n",
    "                      * per_sample_losses[in_range].unsqueeze(1)).sum(dim=0)\n",
    "        # clamp(min=1) keeps absent groups at 0 instead of 0/0\n",
    "        group_losses = group_sums / group_counts_batch.clamp(min=1).to(per_sample_losses.dtype)\n",
    "\n",
    "        if is_training:\n",
    "            state_device = self.exp_avg_loss.device\n",
    "            batch_losses = group_losses.detach().to(device=state_device, dtype=self.exp_avg_loss.dtype)\n",
    "            counts = group_counts_batch.to(state_device)\n",
    "            present = counts > 0\n",
    "\n",
    "            # EMA update: the first observation initializes, later ones blend with\n",
    "            # rate gamma; groups absent from the batch keep their previous value.\n",
    "            first_seen = present & ~self.exp_avg_initialized\n",
    "            seen_before = present & self.exp_avg_initialized\n",
    "            blended = self.gamma * batch_losses + (1 - self.gamma) * self.exp_avg_loss\n",
    "            self.exp_avg_loss.copy_(torch.where(first_seen, batch_losses,\n",
    "                                                torch.where(seen_before, blended, self.exp_avg_loss)))\n",
    "            self.exp_avg_initialized |= present\n",
    "\n",
    "            self.update_data_counts += counts\n",
    "            self.update_batch_counts += present.float()\n",
    "            self.processed_data_counts += counts\n",
    "\n",
    "            # Update adversarial weights: q_g ∝ exp(α * R_g). softmax is the\n",
    "            # overflow-safe form of exp(x) / sum(exp(x)).\n",
    "            self.adv_probs = torch.softmax(self.alpha * self.exp_avg_loss, dim=0)\n",
    "\n",
    "        # DRO-weighted loss\n",
    "        group_weights = self.adv_probs.to(group_losses.device)\n",
    "        weighted_loss = (group_weights * group_losses).sum()\n",
    "\n",
    "        # Loss dict for logging (one device->host transfer per tensor)\n",
    "        loss_values = group_losses.detach().tolist()\n",
    "        weight_values = group_weights.detach().tolist()\n",
    "        loss_dict = {\n",
    "            \"avg_loss\": per_sample_losses.mean().item(),\n",
    "            \"weighted_loss\": weighted_loss.item(),\n",
    "        }\n",
    "        for g in range(self.n_groups):\n",
    "            loss_dict[f\"loss_group_{g}\"] = loss_values[g]\n",
    "            loss_dict[f\"weight_group_{g}\"] = weight_values[g]\n",
    "\n",
    "        return weighted_loss, loss_dict\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# === χ²-DRO (Duchi & Namkoong 2021) ===\n",
    "# ==========================================\n",
    "\n",
    "def chi2_dro_loss(loss_vec: torch.Tensor, rho: float,\n",
    "                  bisect_tol: float = 1e-5, max_iter: int = 50) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Chi-squared DRO loss (dual form):\n",
    "        R_ρ(θ) = inf_η { √(1+2ρ) · √(E_P[(ℓ(θ;X) - η)_+²]) + η }\n",
    "    R is convex in η, so the optimal η is located by bisection on dR/dη using\n",
    "    detached losses. R is then re-evaluated at that fixed η with autograd on,\n",
    "    so gradients reach ``loss_vec`` only through the loss terms.\n",
    "    Args:\n",
    "        loss_vec: [B] per-sample losses\n",
    "        rho: Chi-squared constraint radius (≥0). Larger rho → more robust.\n",
    "            rho = 0 reduces to the plain mean.\n",
    "        bisect_tol: absolute tolerance on η for the bisection\n",
    "        max_iter: max bisection iterations\n",
    "    Returns:\n",
    "        scalar DRO loss (0 for an empty batch)\n",
    "    \"\"\"\n",
    "    B = loss_vec.size(0)\n",
    "    if B == 0:\n",
    "        return loss_vec.new_tensor(0.0)\n",
    "\n",
    "    # Clamp rho to non-negative\n",
    "    rho = max(float(rho), 0.0)\n",
    "    if rho == 0.0:\n",
    "        # With c = 1, dR/dη ≥ 0 everywhere and R(η) → mean as η → -∞,\n",
    "        # so the infimum is the ERM mean.\n",
    "        return loss_vec.mean()\n",
    "    c = math.sqrt(1.0 + 2.0 * rho)\n",
    "\n",
    "    # The bracket search only needs values, never a graph.\n",
    "    losses = loss_vec.detach()\n",
    "    loss_min = losses.min().item()\n",
    "    loss_max = losses.max().item()\n",
    "\n",
    "    # If all losses are equal, return the mean\n",
    "    if abs(loss_max - loss_min) < 1e-12:\n",
    "        return loss_vec.mean()\n",
    "\n",
    "    # Lower bracket: for η ≤ min(ℓ) the objective is c·sqrt(Var + (mean - η)²) + η,\n",
    "    # minimized at η0 = mean - std / sqrt(2ρ). When η0 < min(ℓ) the optimum lies\n",
    "    # below the smallest loss, so min(ℓ) alone is not a valid lower bound.\n",
    "    loss_mean = losses.mean().item()\n",
    "    loss_std = math.sqrt(max(((losses - loss_mean) ** 2).mean().item(), 0.0))\n",
    "    eta_lo = min(loss_min, loss_mean - loss_std / math.sqrt(2.0 * rho))\n",
    "    # Upper bracket: at η = max(ℓ) the positive part vanishes and dR/dη = 1 > 0.\n",
    "    eta_hi = loss_max\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        if (eta_hi - eta_lo) < bisect_tol:\n",
    "            break\n",
    "        eta_mid = 0.5 * (eta_lo + eta_hi)\n",
    "        # derivative test: dR/dη = 1 - c·E[(ℓ-η)_+] / sqrt(E[(ℓ-η)_+²])\n",
    "        positive_part = torch.clamp(losses - eta_mid, min=0.0)\n",
    "        second_moment = (positive_part ** 2).mean()\n",
    "        slope = 1.0 - c * positive_part.mean() / torch.sqrt(second_moment + 1e-12)\n",
    "        if slope.item() < 0:\n",
    "            eta_lo = eta_mid\n",
    "        else:\n",
    "            eta_hi = eta_mid\n",
    "\n",
    "    eta_star = 0.5 * (eta_lo + eta_hi)\n",
    "    positive_part = torch.clamp(loss_vec - eta_star, min=0.0)\n",
    "    return c * torch.sqrt((positive_part ** 2).mean() + 1e-12) + eta_star\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# === CVaR-DRO (Levy et al. 2020) ===\n",
    "# ==========================================\n",
    "\n",
    "def cvar_loss_from_batch(loss_vec: torch.Tensor, alpha: float) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    CVaR-DRO loss (Levy et al., 2020):\n",
    "        CVaR_α(ℓ) = inf_t { t + (1/α) E[(ℓ - t)_+] }\n",
    "    Approximated on a batch as the mean of the k = ceil(α·B) largest losses\n",
    "    (exact when α·B is an integer).\n",
    "    Args:\n",
    "        loss_vec: per-sample losses; any shape, flattened to [B].\n",
    "        alpha: CVaR level, clamped to [1e-6, 1]. Smaller α → more robust (focus on worst).\n",
    "    Returns:\n",
    "        scalar CVaR loss (0 for an empty batch)\n",
    "    \"\"\"\n",
    "    alpha = max(min(float(alpha), 1.0), 1e-6)\n",
    "    flat = loss_vec.reshape(-1)\n",
    "    B = flat.numel()\n",
    "    if B == 0:\n",
    "        # topk with k=1 on an empty tensor would raise\n",
    "        return flat.new_tensor(0.0)\n",
    "    k = max(1, int(math.ceil(alpha * B)))\n",
    "    # Top-k losses (worst)\n",
    "    topk_losses, _ = torch.topk(flat, k, largest=True, sorted=False)\n",
    "    return topk_losses.mean()\n",
    "\n",
    "\n",
    "# --- KLRS ---\n",
    "def train_klrs(model, train_loader, val_loader, *,\n",
    "               epochs: int = 80,\n",
    "               lr: float = 1e-3,\n",
    "               weight_decay: float = 0.0,\n",
    "               base_tau: float = None,             # <-- if None, we estimate once (toy behavior)\n",
    "               warmup_epochs: int = 1,            # ERM warmup to bring losses down before KL-RS constraint\n",
    "               lam_lo: float = 1e-3,\n",
    "               lam_hi_init: float = 1.0,\n",
    "               bisect_tol: float = 1e-3,\n",
    "               expand_factor: float = 2.0,\n",
    "               max_expand: int = 12,\n",
    "               inner_epochs_probe: int = 1,        # epochs per λ probe (as in toy)\n",
    "               alt_iters: int = 5,                 # #alternations between θ-update & λ-search\n",
    "               inner_epochs_theta: int = 1,        # #epochs to improve θ between bisections\n",
    "               exp_clamp: float = 50.0,            # clamp for numerical stability (toy: [-50, 50])\n",
    "               grad_clip_max_norm: float = None,\n",
    "               optimizer: str = \"adam\",            # kept for signature compatibility; we use Adam (toy)\n",
    "               print_every: int = 1,               # <= 0 disables periodic printing\n",
    "               test_loader=None, log_dir=None, algo_name: str = \"KLRS\"):\n",
    "    \"\"\"\n",
    "    KL-RS trainer matching the toy implementation of τ and the alternating λ search:\n",
    "      • τ is FIXED: provided via base_tau or estimated once from up to 5 training\n",
    "        batches (1.05× mean CE, measured in eval mode).\n",
    "      • λ* is found by doubling + bisection; feasibility probes RUN SHORT θ-UPDATES\n",
    "        (inner_epochs_probe epochs, each probe with a fresh Adam optimizer).\n",
    "      • Objective: minimize E[exp((ℓ - τ)/λ)], exponent clamped to ±exp_clamp.\n",
    "\n",
    "    Phases:\n",
    "      1. warmup_epochs of plain ERM (mean CE). CRITICAL: without warmup, if\n",
    "         loss >> tau the KL-RS constraint becomes infeasible and lambda explodes.\n",
    "      2. alt_iters alternations: inner_epochs_theta θ-epochs at the current λ,\n",
    "         then a λ* search (which itself updates θ).\n",
    "      3. max(0, epochs - alt_iters * inner_epochs_theta) polishing epochs at\n",
    "         fixed λ*. Warmup and probe epochs are not counted against ``epochs``.\n",
    "\n",
    "    One history entry (and one CSV row when log_dir is set) is recorded per\n",
    "    alternation and per polishing epoch. Train CE is always measured in eval mode,\n",
    "    so logging passes do not update BatchNorm running statistics. Val/test CE use\n",
    "    ``eval_ce_loss`` when it is defined globally (NaN otherwise); accuracy uses\n",
    "    ``eval_accuracy`` when defined, else a local helper. ``optimizer`` is ignored.\n",
    "\n",
    "    Returns: (train_ce_hist, val_ce_hist, val_acc_hist)\n",
    "    \"\"\"\n",
    "\n",
    "    import time\n",
    "    from typing import Tuple\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = model.to(device)\n",
    "    ce_nored = nn.CrossEntropyLoss(reduction='none')\n",
    "    ce_mean  = nn.CrossEntropyLoss(reduction='mean')\n",
    "\n",
    "    # --- helpers (toy-style) ---\n",
    "    def _should_print(step: int) -> bool:\n",
    "        \"\"\"True when ``step`` is a print step; print_every <= 0 disables printing.\"\"\"\n",
    "        return print_every > 0 and step % print_every == 0\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _accuracy(m: nn.Module, loader) -> float:\n",
    "        \"\"\"Top-1 accuracy of m on loader (NaN if loader is None); leaves m in eval mode.\"\"\"\n",
    "        if loader is None: return float('nan')\n",
    "        m.eval()\n",
    "        correct = total = 0\n",
    "        for xb, yb, *rest in loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            preds = m(xb).argmax(dim=1)\n",
    "            correct += (preds == yb).sum().item()\n",
    "            total += yb.numel()\n",
    "        return correct / max(total, 1)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _mean_train_ce(m: nn.Module) -> float:\n",
    "        \"\"\"Sample-weighted mean CE of m over train_loader, measured in eval mode.\"\"\"\n",
    "        m.eval()  # in train mode this logging pass would update BatchNorm running stats\n",
    "        ce_sum, n_seen = 0.0, 0\n",
    "        for xb, yb, *rest in train_loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            ce_sum += ce_mean(m(xb), yb).item() * xb.size(0)\n",
    "            n_seen += xb.size(0)\n",
    "        return ce_sum / max(1, n_seen)\n",
    "\n",
    "    def _ce_and_acc(m: nn.Module, loader) -> Tuple[float, float]:\n",
    "        \"\"\"(CE, accuracy) on loader via the global eval helpers when present; NaNs if loader is None.\"\"\"\n",
    "        if loader is None:\n",
    "            return float('nan'), float('nan')\n",
    "        ce = eval_ce_loss(m, loader) if 'eval_ce_loss' in globals() else float('nan')\n",
    "        acc = eval_accuracy(m, loader) if 'eval_accuracy' in globals() else _accuracy(m, loader)\n",
    "        return ce, acc\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _estimate_mean_f(m: nn.Module, loader, tau: float, lam: float) -> float:\n",
    "        \"\"\"Average over batches of the batch mean of exp(clamp((ℓ - τ)/λ)), in eval mode.\"\"\"\n",
    "        m.eval()\n",
    "        tot = 0.0\n",
    "        n_batches = 0\n",
    "        for xb, yb, *rest in loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            logits = m(xb)\n",
    "            a = (ce_nored(logits, yb) - tau) / lam\n",
    "            a = torch.clamp(a, min=-exp_clamp, max=exp_clamp)\n",
    "            f = torch.exp(a).mean()\n",
    "            tot += f.item(); n_batches += 1\n",
    "        return tot / max(n_batches, 1)\n",
    "\n",
    "    def _feasibility_train_step(m: nn.Module, loader, opt, tau: float, lam: float):\n",
    "        \"\"\"One epoch minimizing E[exp(clamp((ℓ - τ)/λ))]; leaves m in train mode.\"\"\"\n",
    "        m.train()\n",
    "        for xb, yb, *rest in loader:\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            logits = m(xb)\n",
    "            a = (ce_nored(logits, yb) - tau) / lam\n",
    "            a = torch.clamp(a, min=-exp_clamp, max=exp_clamp)\n",
    "            f = torch.exp(a).mean()\n",
    "            f.backward()\n",
    "            if grad_clip_max_norm is not None:\n",
    "                torch.nn.utils.clip_grad_norm_(m.parameters(), grad_clip_max_norm)\n",
    "            opt.step()\n",
    "\n",
    "    def _feasibility_oracle(m: nn.Module,\n",
    "                            loader,\n",
    "                            tau: float,\n",
    "                            lam: float,\n",
    "                            inner_epochs: int,\n",
    "                            lr: float,\n",
    "                            wd: float) -> Tuple[nn.Module, float]:\n",
    "        \"\"\"\n",
    "        For a given λ: run a few epochs to minimize E[exp((ℓ-τ)/λ)] (updates θ), then return mean f.\n",
    "        \"\"\"\n",
    "        m = m.to(device)\n",
    "        opt = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wd)\n",
    "        for _ in range(inner_epochs):\n",
    "            _feasibility_train_step(m, loader, opt, tau, lam)\n",
    "        mean_f = _estimate_mean_f(m, loader, tau, lam)\n",
    "        return m, mean_f\n",
    "\n",
    "    def _find_min_feasible_lambda(m: nn.Module,\n",
    "                                  loader,\n",
    "                                  tau: float,\n",
    "                                  lam_lo_init: float,\n",
    "                                  lam_hi_init: float,\n",
    "                                  lr: float,\n",
    "                                  wd: float,\n",
    "                                  inner_epochs_probe: int,\n",
    "                                  bisect_tol: float,\n",
    "                                  expand_factor: float,\n",
    "                                  max_expand: int) -> Tuple[nn.Module, float]:\n",
    "        \"\"\"\n",
    "        Doubling (if needed) + bisection to find the smallest feasible λ (mean f <= 1).\n",
    "        θ is UPDATED during probes (alternating minimization, like the toy).\n",
    "        \"\"\"\n",
    "        lam_lo_ = lam_lo_init\n",
    "        lam_hi_ = lam_hi_init\n",
    "\n",
    "        # Ensure lam_hi is feasible; expand if needed.\n",
    "        m, mean_f = _feasibility_oracle(m, loader, tau, lam_hi_, inner_epochs_probe, lr, wd)\n",
    "        expands = 0\n",
    "        while mean_f > 1.0 and expands < max_expand:\n",
    "            lam_lo_ = lam_hi_\n",
    "            lam_hi_ *= expand_factor\n",
    "            m, mean_f = _feasibility_oracle(m, loader, tau, lam_hi_, inner_epochs_probe, lr, wd)\n",
    "            expands += 1\n",
    "\n",
    "        # If still infeasible, return current state.\n",
    "        if mean_f > 1.0:\n",
    "            print(f\"[KLRS] Failed to find feasible λ after {max_expand} expansions (λ={lam_hi_:.2e}, E[f]={mean_f:.3f})\")\n",
    "            print(f\"[KLRS] This usually means tau ({tau:.4f}) is too low relative to current loss. Consider increasing warmup_epochs.\")\n",
    "            return m, lam_hi_\n",
    "\n",
    "        # Try to make lam_lo infeasible to tighten bracket (optional refinement).\n",
    "        for _ in range(max_expand):\n",
    "            m, mean_f_lo = _feasibility_oracle(m, loader, tau, lam_lo_, inner_epochs_probe, lr, wd)\n",
    "            if mean_f_lo > 1.0:\n",
    "                break\n",
    "            lam_hi_ = lam_lo_\n",
    "            lam_lo_ = max(lam_lo_ / expand_factor, 1e-6)  # Prevent collapse to near-zero\n",
    "            if lam_lo_ <= 1e-6:\n",
    "                break\n",
    "\n",
    "        # Bisection.\n",
    "        while (lam_hi_ - lam_lo_) > bisect_tol * max(1.0, lam_hi_):\n",
    "            lam_mid = 0.5 * (lam_lo_ + lam_hi_)\n",
    "            m, mean_f_mid = _feasibility_oracle(m, loader, tau, lam_mid, inner_epochs_probe, lr, wd)\n",
    "            if mean_f_mid <= 1.0:\n",
    "                lam_hi_ = lam_mid\n",
    "            else:\n",
    "                lam_lo_ = lam_mid\n",
    "\n",
    "        return m, lam_hi_\n",
    "\n",
    "    # --- τ: fixed (toy behavior) ---\n",
    "    if base_tau is None:\n",
    "        # one-time estimate from up to 5 batches; eval mode keeps BatchNorm stats untouched\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            est, n = 0.0, 0\n",
    "            for xb, yb, *rest in train_loader:\n",
    "                xb, yb = xb.to(device), yb.to(device)\n",
    "                est += ce_mean(model(xb), yb).item()\n",
    "                n += 1\n",
    "                if n >= 5:\n",
    "                    break\n",
    "            tau = 1.05 * (est / max(n, 1))\n",
    "    else:\n",
    "        tau = float(base_tau)\n",
    "\n",
    "    # --- histories & logging ---\n",
    "    tr_hist, va_hist, acc_hist = [], [], []\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "    t0 = time.perf_counter()\n",
    "\n",
    "    # Optimizer used during θ-improvement steps outside the feasibility probes\n",
    "    opt_theta = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "\n",
    "    # === WARMUP PHASE: ERM training to bring losses down before KL-RS constraint ===\n",
    "    # This is CRITICAL: if loss >> tau, the KL-RS constraint is infeasible and lambda explodes\n",
    "    warmup_loss = float('nan')\n",
    "    if warmup_epochs > 0:\n",
    "        print(f\"[KLRS] Starting {warmup_epochs} epochs of ERM warmup (tau={tau:.4f})...\")\n",
    "        for ep in range(1, warmup_epochs + 1):\n",
    "            model.train()\n",
    "            for xb, yb, *rest in train_loader:\n",
    "                xb, yb = xb.to(device), yb.to(device)\n",
    "                opt_theta.zero_grad(set_to_none=True)\n",
    "                loss = ce_mean(model(xb), yb)\n",
    "                loss.backward()\n",
    "                if grad_clip_max_norm is not None:\n",
    "                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)\n",
    "                opt_theta.step()\n",
    "\n",
    "            # Check loss after warmup epoch (always after the last one)\n",
    "            if ep == warmup_epochs or _should_print(ep):\n",
    "                warmup_loss = _mean_train_ce(model)\n",
    "                print(f\"[KLRS-warmup] ep{ep:02d} trainCE={warmup_loss:.4f} (target tau={tau:.4f})\")\n",
    "\n",
    "        print(f\"[KLRS] Warmup complete. Final loss={warmup_loss:.4f}, tau={tau:.4f}\")\n",
    "        if warmup_loss > tau * 5:\n",
    "            print(f\"[KLRS WARNING] Loss ({warmup_loss:.4f}) still >> tau ({tau:.4f}). Lambda may explode.\")\n",
    "\n",
    "    lam = float(lam_hi_init)\n",
    "\n",
    "    # Alternations: (1) improve θ at current λ, (2) find smallest feasible λ\n",
    "    for it in range(alt_iters):\n",
    "        # (1) θ-improvement epochs at current λ\n",
    "        for _ in range(inner_epochs_theta):\n",
    "            _feasibility_train_step(model, train_loader, opt_theta, tau, lam)\n",
    "\n",
    "        # (2) find λ* (θ is also updated inside the probes)\n",
    "        model, lam = _find_min_feasible_lambda(model, train_loader, tau,\n",
    "                                               lam_lo_init=lam_lo,\n",
    "                                               lam_hi_init=lam,\n",
    "                                               lr=lr, wd=weight_decay,\n",
    "                                               inner_epochs_probe=inner_epochs_probe,\n",
    "                                               bisect_tol=bisect_tol,\n",
    "                                               expand_factor=expand_factor,\n",
    "                                               max_expand=max_expand)\n",
    "\n",
    "        # evaluate after each alternation (CE for comparability, plus accuracy)\n",
    "        tr = _mean_train_ce(model)\n",
    "        va, acc = _ce_and_acc(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); acc_hist.append(acc)\n",
    "        test_ce, test_acc = _ce_and_acc(model, test_loader)\n",
    "\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\"epoch\": it+1,\n",
    "                              \"train_ce\": tr, \"val_ce\": va, \"val_acc\": acc,\n",
    "                              \"test_ce\": test_ce, \"test_acc\": test_acc,\n",
    "                              \"lambda\": lam, \"tau\": tau})\n",
    "\n",
    "        if _should_print(it+1):\n",
    "            if 'log_epoch' in globals():\n",
    "                log_epoch(\"KLRS-alt\", it+1, tr, va, acc, t0, extra=f\"λ*={lam:.3e} τ={tau:.4f}\")\n",
    "            else:\n",
    "                print(f\"[KLRS-alt][it {it+1:03d}] trainCE={tr:.4f} valCE={va:.4f} valAcc={acc:.4f}  λ*={lam:.3e} τ={tau:.4f}\")\n",
    "\n",
    "        if 'maybe_print_class_stats' in globals():\n",
    "            maybe_print_class_stats(\"KLRS\", it+1, model, val_loader)\n",
    "\n",
    "    # Final polishing epochs at fixed λ*\n",
    "    remaining = max(0, epochs - alt_iters*inner_epochs_theta)\n",
    "    for ep in range(remaining):\n",
    "        _feasibility_train_step(model, train_loader, opt_theta, tau, lam)\n",
    "\n",
    "        # periodic logging; _mean_train_ce switches to eval mode first, so this pass\n",
    "        # does not update BatchNorm running stats with train-mode batches\n",
    "        tr = _mean_train_ce(model)\n",
    "        va, acc = _ce_and_acc(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); acc_hist.append(acc)\n",
    "\n",
    "        if csv_logger:\n",
    "            # epoch numbers continue after the alternations so CSV rows stay unique\n",
    "            test_ce, test_acc = _ce_and_acc(model, test_loader)\n",
    "            csv_logger.write({\"epoch\": alt_iters + ep + 1,\n",
    "                              \"train_ce\": tr, \"val_ce\": va, \"val_acc\": acc,\n",
    "                              \"test_ce\": test_ce, \"test_acc\": test_acc,\n",
    "                              \"lambda\": lam, \"tau\": tau})\n",
    "\n",
    "        if _should_print(ep+1):\n",
    "            if 'log_epoch' in globals():\n",
    "                log_epoch(\"KLRS\", ep+1, tr, va, acc, t0, extra=f\"λ={lam:.3e} τ={tau:.4f}\")\n",
    "            else:\n",
    "                print(f\"[KLRS][ep {ep+1:03d}] trainCE={tr:.4f} valCE={va:.4f} valAcc={acc:.4f}  λ={lam:.3e} τ={tau:.4f}\")\n",
    "\n",
    "    return tr_hist, va_hist, acc_hist\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7GB2_56FmVIF"
   },
   "outputs": [],
   "source": [
    "# === REx Training Functions ===\n",
    "\n",
    "def train_rex(model, train_loader_env1, train_loader_env2, val_loader, *,\n",
    "              epochs=80, lr=1e-3, weight_decay=0.0,\n",
    "              rex_mode=\"vrex\",       # \"vrex\" or \"mmrex\"\n",
    "              rex_beta=10.0,         # variance penalty weight (V-REx)\n",
    "              rex_lambda=1.5,        # extrapolation coefficient (MM-REx), λ > 1 for extrapolation\n",
    "              penalty_anneal_epochs=10,\n",
    "              print_every=1, test_loader=None, log_dir=None, algo_name=\"V-REx\"):\n",
    "    \"\"\"\n",
    "    Risk Extrapolation (REx) training over two environments.\n",
    "\n",
    "    Reference: Krueger et al., Out-of-Distribution Generalization via Risk\n",
    "    Extrapolation (REx), ICML 2021 (https://arxiv.org/abs/2003.00688).\n",
    "\n",
    "    Each step takes one batch from each environment (min of the two loader\n",
    "    lengths per epoch) and computes the per-environment mean CE R_1, R_2:\n",
    "      • rex_mode='vrex':  L = mean(R_e) + s·rex_beta·P, where P is the variance\n",
    "        penalty returned by get_rex_penalty(..., mode='vrex').\n",
    "      • rex_mode='mmrex': L = λ_s·max(R_e) + (1-λ_s)·min(R_e) via get_rex_penalty,\n",
    "        with λ_s = 0.5 + s·(rex_lambda - 0.5). λ = 0.5 is plain ERM, λ = 1 uses\n",
    "        only the worst environment, λ > 1 extrapolates beyond it.\n",
    "      • any other mode: plain ERM, L = mean(R_e).\n",
    "    s ramps linearly (epoch / penalty_anneal_epochs) over the first\n",
    "    penalty_anneal_epochs epochs and is 1 afterwards.\n",
    "\n",
    "    Relies on notebook globals: device, AMP, amp_ctx, get_rex_penalty,\n",
    "    eval_ce_loss, eval_accuracy, _train_accuracy_epoch, EpochCSVLogger,\n",
    "    log_epoch, maybe_print_class_stats.\n",
    "\n",
    "    Returns: (train_ce_hist, val_ce_hist, val_acc_hist); train CE is the\n",
    "    sample-count-weighted average of mean(R_1, R_2) over the epoch's steps.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    grad_scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    train_ce_hist, val_ce_hist, val_acc_hist = [], [], []\n",
    "    start = time.perf_counter()\n",
    "    logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        model.train()\n",
    "        ce_weighted_sum = 0.0\n",
    "        n_samples = 0\n",
    "\n",
    "        # Ramp factor s for the REx term (see docstring).\n",
    "        ramp = float(epoch) / float(penalty_anneal_epochs) if epoch <= penalty_anneal_epochs else 1.0\n",
    "\n",
    "        env1_batches = iter(train_loader_env1)\n",
    "        env2_batches = iter(train_loader_env2)\n",
    "        n_steps = min(len(train_loader_env1), len(train_loader_env2))\n",
    "\n",
    "        # zip checks range first, so after n_steps no extra batch is pulled; it also\n",
    "        # stops early if either environment runs out.\n",
    "        for _step, batch1, batch2 in zip(range(n_steps), env1_batches, env2_batches):\n",
    "            x1, y1, _ = batch1\n",
    "            x2, y2, _ = batch2\n",
    "            x1 = x1.to(device, non_blocking=True)\n",
    "            y1 = y1.to(device, non_blocking=True)\n",
    "            x2 = x2.to(device, non_blocking=True)\n",
    "            y2 = y2.to(device, non_blocking=True)\n",
    "\n",
    "            optim.zero_grad(set_to_none=True)\n",
    "\n",
    "            with amp_ctx():\n",
    "                out1 = model(x1)\n",
    "                out2 = model(x2)\n",
    "                risk1 = criterion(out1, y1)\n",
    "                risk2 = criterion(out2, y2)\n",
    "                mean_loss = (risk1 + risk2) / 2.0  # also the logged training CE\n",
    "\n",
    "                if rex_mode == \"vrex\":\n",
    "                    rex_term = get_rex_penalty([risk1, risk2], mode=\"vrex\")\n",
    "                    objective = mean_loss + ramp * rex_beta * rex_term\n",
    "                elif rex_mode == \"mmrex\":\n",
    "                    # λ moves from 0.5 (0.5·max + 0.5·min = mean, i.e. ERM) towards rex_lambda\n",
    "                    lam_now = 0.5 + ramp * (rex_lambda - 0.5)\n",
    "                    objective = get_rex_penalty([risk1, risk2], mode=\"mmrex\", lam=lam_now)\n",
    "                else:\n",
    "                    objective = mean_loss\n",
    "\n",
    "            grad_scaler.scale(objective).backward()\n",
    "            grad_scaler.step(optim)\n",
    "            grad_scaler.update()\n",
    "\n",
    "            pair_size = x1.size(0) + x2.size(0)\n",
    "            ce_weighted_sum += mean_loss.detach().item() * pair_size\n",
    "            n_samples += pair_size\n",
    "\n",
    "        tr = ce_weighted_sum / max(1, n_samples)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "        train_ce_hist.append(tr)\n",
    "        val_ce_hist.append(va)\n",
    "        val_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader_env1)\n",
    "        if test_loader is None:\n",
    "            test_ce = float('nan')\n",
    "            test_acc = float('nan')\n",
    "        else:\n",
    "            test_ce = eval_ce_loss(model, test_loader)\n",
    "            test_acc = eval_accuracy(model, test_loader)\n",
    "\n",
    "        if logger:\n",
    "            logger.write({\n",
    "                \"epoch\": epoch,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va, \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if epoch % print_every == 0:\n",
    "            log_epoch(algo_name, epoch, tr, va, acc_val, start)\n",
    "        maybe_print_class_stats(algo_name, epoch, model, val_loader)\n",
    "\n",
    "    return train_ce_hist, val_ce_hist, val_acc_hist\n",
    "\n",
    "\n",
    "# === IRM Training Function ===\n",
    "\n",
    "def train_irm(model, train_loader_env1, train_loader_env2, val_loader, *,\n",
    "              epochs=80, lr=1e-3, weight_decay=0.0,\n",
    "              penalty_weight=10000.0, penalty_anneal_epochs=20,\n",
    "              print_every=1, test_loader=None, log_dir=None, algo_name=\"IRMv1\"):\n",
    "\n",
    "    model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "\n",
    "    tr_hist, va_hist, va_acc_hist = [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for ep in range(1, epochs+1):\n",
    "        model.train()\n",
    "        total_ce_sum, total_n = 0.0, 0\n",
    "\n",
    "        # Anneal penalty\n",
    "        if ep <= penalty_anneal_epochs:\n",
    "            current_penalty_weight = penalty_weight * (float(ep) / float(penalty_anneal_epochs))\n",
    "        else:\n",
    "            current_penalty_weight = penalty_weight\n",
    "\n",
    "        iter_env1 = iter(train_loader_env1)\n",
    "        iter_env2 = iter(train_loader_env2)\n",
    "\n",
    "        for _ in range(min(len(train_loader_env1), len(train_loader_env2))):\n",
    "            try:\n",
    "                x1, y1, _ = next(iter_env1)\n",
    "                x2, y2, _ = next(iter_env2)\n",
    "            except StopIteration:\n",
    "                break\n",
    "\n",
    "            x1, y1 = x1.to(device, non_blocking=True), y1.to(device, non_blocking=True)\n",
    "            x2, y2 = x2.to(device, non_blocking=True), y2.to(device, non_blocking=True)\n",
    "\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "\n",
    "            with amp_ctx():\n",
    "                logits1 = model(x1)\n",
    "                logits2 = model(x2)\n",
    "\n",
    "                # ERM loss\n",
    "                loss1 = F.cross_entropy(logits1, y1)\n",
    "                loss2 = F.cross_entropy(logits2, y2)\n",
    "                mean_loss = (loss1 + loss2) / 2.0\n",
    "\n",
    "                # IRM penalty\n",
    "                penalty1 = irm_penalty(logits1, y1)\n",
    "                penalty2 = irm_penalty(logits2, y2)\n",
    "                penalty = (penalty1 + penalty2) / 2.0\n",
    "\n",
    "                # Total loss\n",
    "                loss = mean_loss + current_penalty_weight * penalty\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(opt); scaler.update()\n",
    "\n",
    "            bs = x1.size(0) + x2.size(0)\n",
    "            total_ce_sum += mean_loss.detach().item() * bs\n",
    "            total_n += bs\n",
    "\n",
    "        tr = total_ce_sum / max(1, total_n)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader_env1)\n",
    "        test_ce = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\n",
    "                \"epoch\": ep,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va, \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(algo_name, ep, tr, va, acc_val, t0)\n",
    "        maybe_print_class_stats(algo_name, ep, model, val_loader)\n",
    "\n",
    "    return tr_hist, va_hist, va_acc_hist\n",
    "\n",
    "\n",
    "# === GroupDRO Training Function ===\n",
    "\n",
    "def train_groupdro(model, train_loader, val_loader, *,\n",
    "                   epochs=80,\n",
    "                   lr=1e-3,\n",
    "                   weight_decay=5e-5,      # paper default\n",
    "                   alpha=0.2,              # group weight step size\n",
    "                   gamma=0.1,              # EMA decay for group losses\n",
    "                   robust_step_size=0.01,  # not used in this implementation\n",
    "                   print_every=1,\n",
    "                   test_loader=None,\n",
    "                   log_dir=None,\n",
    "                   algo_name=\"GroupDRO\"):\n",
    "\n",
    "    model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "\n",
    "    # Initialize LossComputer\n",
    "    loss_computer = LossComputer(groupdro_dataset, alpha=alpha, gamma=gamma, device=device)\n",
    "\n",
    "    tr_hist, va_hist, va_acc_hist = [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for ep in range(1, epochs+1):\n",
    "        model.train()\n",
    "        total_ce_sum, total_n = 0.0, 0\n",
    "\n",
    "        for x, y, _ in tqdm(train_loader, leave=False):\n",
    "            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "\n",
    "            with amp_ctx():\n",
    "                logits = model(x)\n",
    "                # LossComputer handles DRO weighting\n",
    "                loss, loss_dict = loss_computer.loss(logits, y, group_idx=y, is_training=True)\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(opt); scaler.update()\n",
    "\n",
    "            bs = x.size(0)\n",
    "            total_ce_sum += loss_dict[\"avg_loss\"] * bs\n",
    "            total_n += bs\n",
    "\n",
    "        tr = total_ce_sum / max(1, total_n)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\n",
    "                \"epoch\": ep,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va, \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(algo_name, ep, tr, va, acc_val, t0)\n",
    "        maybe_print_class_stats(algo_name, ep, model, val_loader)\n",
    "\n",
    "    return tr_hist, va_hist, va_acc_hist\n",
    "\n",
    "\n",
    "# === χ²-DRO Training Function ===\n",
    "\n",
    "def train_chi2_dro(model, train_loader, val_loader, *,\n",
    "                   epochs=80, lr=1e-3, rho=0.1,\n",
    "                   weight_decay=0.0, print_every=1,\n",
    "                   test_loader=None, log_dir=None,\n",
    "                   algo_name=\"χ²-DRO\"):\n",
    "\n",
    "    model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    ce_fn = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    tr_hist, va_hist, va_acc_hist = [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for ep in range(1, epochs+1):\n",
    "        model.train()\n",
    "        total_ce_sum, total_n = 0.0, 0\n",
    "\n",
    "        for x, y, _ in tqdm(train_loader, leave=False):\n",
    "            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "\n",
    "            with amp_ctx():\n",
    "                logits = model(x)\n",
    "                loss_vec = ce_fn(logits, y)\n",
    "                # Chi-squared DRO loss\n",
    "                loss = chi2_dro_loss(loss_vec, rho=rho)\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(opt); scaler.update()\n",
    "\n",
    "            bs = x.size(0)\n",
    "            total_ce_sum += loss_vec.mean().detach().item() * bs\n",
    "            total_n += bs\n",
    "\n",
    "        tr = total_ce_sum / max(1, total_n)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\n",
    "                \"epoch\": ep,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va, \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(algo_name, ep, tr, va, acc_val, t0)\n",
    "        maybe_print_class_stats(algo_name, ep, model, val_loader)\n",
    "\n",
    "    return tr_hist, va_hist, va_acc_hist\n",
    "\n",
    "\n",
    "# === CVaR-DRO Training Function ===\n",
    "\n",
    "def train_cvar_dro(model, train_loader, val_loader, *,\n",
    "                   epochs=80, lr=1e-3, alpha=0.1,\n",
    "                   weight_decay=0.0, print_every=1,\n",
    "                   test_loader=None, log_dir=None,\n",
    "                   algo_name=\"CVaR-DRO\"):\n",
    "\n",
    "    model.to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    scaler = torch.cuda.amp.GradScaler(enabled=(AMP and device.type==\"cuda\"))\n",
    "    ce_fn = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    tr_hist, va_hist, va_acc_hist = [], [], []\n",
    "    t0 = time.perf_counter()\n",
    "    csv_logger = EpochCSVLogger(log_dir, algo_name) if log_dir else None\n",
    "\n",
    "    for ep in range(1, epochs+1):\n",
    "        model.train()\n",
    "        total_ce_sum, total_n = 0.0, 0\n",
    "\n",
    "        for x, y, _ in tqdm(train_loader, leave=False):\n",
    "            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "\n",
    "            with amp_ctx():\n",
    "                logits = model(x)\n",
    "                loss_vec = ce_fn(logits, y)\n",
    "                # CVaR-DRO loss\n",
    "                loss = cvar_loss_from_batch(loss_vec, alpha=alpha)\n",
    "\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.step(opt); scaler.update()\n",
    "\n",
    "            bs = x.size(0)\n",
    "            total_ce_sum += loss_vec.mean().detach().item() * bs\n",
    "            total_n += bs\n",
    "\n",
    "        tr = total_ce_sum / max(1, total_n)\n",
    "        va = eval_ce_loss(model, val_loader)\n",
    "        acc_val = eval_accuracy(model, val_loader)\n",
    "        tr_hist.append(tr); va_hist.append(va); va_acc_hist.append(acc_val)\n",
    "\n",
    "        train_acc = _train_accuracy_epoch(model, train_loader)\n",
    "        test_ce = eval_ce_loss(model, test_loader) if test_loader is not None else float('nan')\n",
    "        test_acc = eval_accuracy(model, test_loader) if test_loader is not None else float('nan')\n",
    "        if csv_logger:\n",
    "            csv_logger.write({\n",
    "                \"epoch\": ep,\n",
    "                \"train_ce\": tr, \"train_acc\": train_acc,\n",
    "                \"val_ce\": va, \"val_acc\": acc_val,\n",
    "                \"test_ce\": test_ce, \"test_acc\": test_acc\n",
    "            })\n",
    "\n",
    "        if ep % print_every == 0:\n",
    "            log_epoch(algo_name, ep, tr, va, acc_val, t0)\n",
    "        maybe_print_class_stats(algo_name, ep, model, val_loader)\n",
    "\n",
    "    return tr_hist, va_hist, va_acc_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FRSWTqSr3uKE"
   },
   "outputs": [],
   "source": [
    "# ---------- small helpers ----------\n",
    "def reset_seeds(seed: int = 42):\n",
    "    import random\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = False\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "\n",
    "def select_lr(best_lr, grid, default=None):\n",
    "    if isinstance(best_lr, (float, int)) and np.isfinite(best_lr) and best_lr > 0:\n",
    "        return float(best_lr)\n",
    "    if grid and len(grid) > 0 and np.isfinite(grid[0]) and grid[0] > 0:\n",
    "        print(f\"[LR] Using grid[0] fallback for {grid}: {grid[0]}\")\n",
    "        return float(grid[0])\n",
    "    if default is not None:\n",
    "        print(f\"[LR] Using explicit default fallback: {default}\")\n",
    "        return float(default)\n",
    "    print(\"[LR] Fallback to 1e-3\")\n",
    "    return 1e-3\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_on_loader(model, loader):\n",
    "    \"\"\"\n",
    "    Returns (acc, per_class_acc, confusion_matrix).\n",
    "    Robust to both versions of per_class_metrics:\n",
    "      - (per_cls, cm)\n",
    "      - (per_cls, head, mid, tail, cm)\n",
    "    \"\"\"\n",
    "    acc, _ = eval_accuracy(model, loader, return_cm=True)\n",
    "    out = per_class_metrics(model, loader, num_classes=NUM_CLASSES)\n",
    "    if isinstance(out, (tuple, list)):\n",
    "        if len(out) == 2:\n",
    "            per_cls, cm_full = out\n",
    "        elif len(out) == 5:\n",
    "            per_cls, _, _, _, cm_full = out\n",
    "        else:\n",
    "            raise ValueError(\"Unexpected per_class_metrics return shape.\")\n",
    "    else:\n",
    "        raise TypeError(\"per_class_metrics must return a tuple/list.\")\n",
    "    return float(acc), np.asarray(per_cls, dtype=float), cm_full\n",
    "\n",
    "# ---------- generic LR sweep (robust to different trainer return shapes) ----------\n",
    "def lr_sweep(algo_name: str, lr_list, build_model_fn, train_fn,\n",
    "             train_kwargs: dict, train_epochs: int = 50, seed: int = 42,\n",
    "             print_every_sweep: int = 10):\n",
    "    results = []\n",
    "    for lr in lr_list:\n",
    "        reset_seeds(seed)\n",
    "        model = build_model_fn()\n",
    "        local_kwargs = dict(train_kwargs)\n",
    "        if (\"print_every\" in getattr(train_fn, \"__code__\", type(\"\", (), {\"co_varnames\": ()})).co_varnames\n",
    "            and \"print_every\" not in local_kwargs):\n",
    "            local_kwargs[\"print_every\"] = print_every_sweep\n",
    "        try:\n",
    "            out = train_fn(model, train_loader, val_loader, epochs=train_epochs, lr=lr, **local_kwargs)\n",
    "        except Exception as e:\n",
    "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ {e}\")\n",
    "            continue\n",
    "\n",
    "        if not isinstance(out, (list, tuple)):\n",
    "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ trainer returned non-tuple\")\n",
    "            continue\n",
    "\n",
    "        # Accept common shapes:\n",
    "        # 3: (tr, va, acc)\n",
    "        # 4: (..., tr, va, acc)\n",
    "        # 5: (..., ..., tr, va, acc)\n",
    "        if len(out) == 3:\n",
    "            tr, va, acc = out\n",
    "        elif len(out) in (4, 5):\n",
    "            tr, va, acc = out[-3], out[-2], out[-1]\n",
    "        else:\n",
    "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ unexpected return len={len(out)}\")\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            val_acc_final = float(acc[-1]) if len(acc) else float(\"nan\")\n",
    "        except Exception:\n",
    "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ acc not indexable\")\n",
    "            continue\n",
    "\n",
    "        if not np.isfinite(val_acc_final):\n",
    "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ valAcc is NaN/Inf\")\n",
    "            continue\n",
    "\n",
    "        results.append({\"lr\": lr, \"val_acc\": val_acc_final})\n",
    "        print(f\"[SWEEP:{algo_name}] lr={lr:g}  valAcc(final)={val_acc_final:.4f}\")\n",
    "\n",
    "    best = max(results, key=lambda r: r[\"val_acc\"]) if results else {\"lr\": None, \"val_acc\": float(\"nan\")}\n",
    "    print(f\"[SWEEP:{algo_name}] best_lr={best['lr']}  best_valAcc={best['val_acc']:.4f}\")\n",
    "    return best[\"lr\"], results\n",
    "\n",
    "# ---------- hyperparam sweep (for algorithms with multiple hyperparameters) ----------\n",
    "def hyperparam_sweep(algo_name: str, config_grid, build_model_fn, train_fn,\n",
    "                     base_train_kwargs: dict, train_epochs: int = 50, seed: int = 42,\n",
    "                     print_every_sweep: int = 10):\n",
    "    \"\"\"\n",
    "    Sweep over a grid of hyperparameter configurations.\n",
    "    Args:\n",
    "        config_grid: list of dicts, each containing hyperparameters to test\n",
    "        base_train_kwargs: dict of base arguments passed to train_fn\n",
    "    Returns:\n",
    "        (best_config, all_results)\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    for config in config_grid:\n",
    "        reset_seeds(seed)\n",
    "        model = build_model_fn()\n",
    "\n",
    "        # Merge config into base_train_kwargs\n",
    "        local_kwargs = dict(base_train_kwargs)\n",
    "        local_kwargs.update(config)\n",
    "\n",
    "        # Add print_every if not present\n",
    "        if (\"print_every\" in getattr(train_fn, \"__code__\", type(\"\", (), {\"co_varnames\": ()})).co_varnames\n",
    "            and \"print_every\" not in local_kwargs):\n",
    "            local_kwargs[\"print_every\"] = print_every_sweep\n",
    "\n",
    "        # Extract epochs and lr (assume they're in config or base_train_kwargs)\n",
    "        epochs = local_kwargs.pop(\"epochs\", train_epochs)\n",
    "\n",
    "        config_str = \", \".join([f\"{k}={v}\" for k, v in config.items()])\n",
    "\n",
    "        try:\n",
    "            out = train_fn(model, train_loader, val_loader, epochs=epochs, **local_kwargs)\n",
    "        except Exception as e:\n",
    "            print(f\"[SWEEP:{algo_name}] {config_str}  ❌ {e}\")\n",
    "            continue\n",
    "\n",
    "        if not isinstance(out, (list, tuple)):\n",
    "            print(f\"[SWEEP:{algo_name}] {config_str}  ❌ trainer returned non-tuple\")\n",
    "            continue\n",
    "\n",
    "        # Accept common shapes\n",
    "        if len(out) == 3:\n",
    "            tr, va, acc = out\n",
    "        elif len(out) in (4, 5):\n",
    "            tr, va, acc = out[-3], out[-2], out[-1]\n",
    "        else:\n",
    "            print(f\"[SWEEP:{algo_name}] {config_str}  ❌ unexpected return len={len(out)}\")\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            val_acc_final = float(acc[-1]) if len(acc) else float(\"nan\")\n",
    "        except Exception:\n",
    "            print(f\"[SWEEP:{algo_name}] {config_str}  ❌ acc not indexable\")\n",
    "            continue\n",
    "\n",
    "        if not np.isfinite(val_acc_final):\n",
    "            print(f\"[SWEEP:{algo_name}] {config_str}  ❌ valAcc is NaN/Inf\")\n",
    "            continue\n",
    "\n",
    "        result_entry = dict(config)\n",
    "        result_entry[\"val_acc\"] = val_acc_final\n",
    "        results.append(result_entry)\n",
    "        print(f\"[SWEEP:{algo_name}] {config_str}  valAcc(final)={val_acc_final:.4f}\")\n",
    "\n",
    "    if not results:\n",
    "        print(f\"[SWEEP:{algo_name}] No valid results!\")\n",
    "        return None, []\n",
    "\n",
    "    best = max(results, key=lambda r: r[\"val_acc\"])\n",
    "    best_config = {k: v for k, v in best.items() if k != \"val_acc\"}\n",
    "    best_str = \", \".join([f\"{k}={v}\" for k, v in best_config.items()])\n",
    "    print(f\"[SWEEP:{algo_name}] BEST: {best_str}  valAcc={best['val_acc']:.4f}\")\n",
    "    return best_config, results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VRVykaY-3uHZ",
    "outputId": "edeea2ee-cbd8-48c0-f3e4-160da1b2d579"
   },
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# === Hyperparameter Grids ===\n",
    "# ==========================================\n",
    "\n",
    "# Base LR grid for all algorithms\n",
    "LR_GRID = [1e-5, 1e-4, 1e-3, 1e-2]\n",
    "\n",
    "# Algorithm-specific LR grids\n",
    "ERM_LR_GRID = LR_GRID\n",
    "SAM_LR_GRID = LR_GRID\n",
    "IRS_CW_LR_GRID = LR_GRID\n",
    "KLRS_LR_GRID = LR_GRID\n",
    "\n",
    "# IRM penalty grid\n",
    "IRM_PENALTY_GRID = [100, 1000, 10000, 100000]\n",
    "\n",
    "# REx grids\n",
    "VREX_LR_GRID = LR_GRID\n",
    "VREX_BETA_GRID = [0.1, 0.5, 1.0, 5.0, 10.0]  # Variance penalty weights\n",
    "\n",
    "MMREX_LR_GRID = LR_GRID\n",
    "# MM-REx λ values (Krueger et al. 2021):\n",
    "# λ = 0.5: ERM (interpolate equally between min/max)\n",
    "# λ = 1.0: Worst-case (equivalent to GroupDRO)\n",
    "# λ > 1.0: Extrapolation beyond worst-case (the key insight of MM-REx)\n",
    "MMREX_LAMBDA_GRID = [1.0, 1.5, 2.0, 3.0]  # λ > 1 for extrapolation\n",
    "\n",
    "# GroupDRO grids\n",
    "GROUPDRO_LR_GRID = LR_GRID\n",
    "GROUPDRO_ALPHA_GRID = [0.1, 0.2, 0.5]  # Group weight step size\n",
    "\n",
    "# χ²-DRO grids\n",
    "# rho = Chi-squared constraint radius. Larger rho = more robust (mean + larger std penalty)\n",
    "CHI2_RHO_GRID = [0.01, 0.1, 0.5, 1.0, 2.0]\n",
    "\n",
    "# CVaR-DRO grids (Levy et al. 2020)\n",
    "# alpha = CVaR level. Smaller alpha = more robust (focus on worst cases)\n",
    "CVAR_ALPHA_GRID = [0.05, 0.1, 0.2, 0.3, 0.5]\n",
    "\n",
    "# Training settings\n",
    "VAL_SWEEP_EPOCHS = 50 # 50\n",
    "EPOCHS = 100 # 100\n",
    "WEIGHT_DECAY = 0.0\n",
    "SAM_RHO = 0.05\n",
    "\n",
    "# ==========================================\n",
    "# === SWEEP CONTROL FLAGS ===\n",
    "# ==========================================\n",
    "# Set to False to skip sweep and use pre-computed hyperparameters\n",
    "\n",
    "RUN_SWEEP = {\n",
    "    \"ERM\": True,\n",
    "    \"ERM-Adam\": True,\n",
    "    \"SAM\": True,\n",
    "    \"IRS(KL)-CW\": True,\n",
    "    \"KL-RS\": True,\n",
    "    \"V-REx\": True,\n",
    "    \"MM-REx\": True,\n",
    "    \"IRMv1\": True,\n",
    "    \"GroupDRO\": True,\n",
    "    \"χ²-DRO\": True,\n",
    "    \"CVaR-DRO\": True,\n",
    "}\n",
    "\n",
    "# ==========================================\n",
    "# === FULL TRAINING CONTROL FLAGS ===\n",
    "# ==========================================\n",
    "\n",
    "RUN_FULL_TRAINING = {\n",
    "    \"ERM\": True,\n",
    "    \"ERM-Adam\": True,\n",
    "    \"SAM\": True,\n",
    "    \"IRS(KL)-CW\": True,\n",
    "    \"KL-RS\": True,\n",
    "    \"V-REx\": True,\n",
    "    \"MM-REx\": True,\n",
    "    \"IRMv1\": True,\n",
    "    \"GroupDRO\": True,\n",
    "    \"χ²-DRO\": True,\n",
    "    \"CVaR-DRO\": True,\n",
    "}\n",
    "\n",
    "PRECOMPUTED_HYPERPARAMS = {\n",
    "    \"ERM\": {\"lr\": 0.01},\n",
    "    \"ERM-Adam\": {\"lr\": 0.001},\n",
    "    \"SAM\": {\"lr\": 0.001},\n",
    "    \"IRS(KL)-CW\": {\"lr\": 1e-3},  # Default, will be swept\n",
    "    \"KL-RS\": {\"lr\": 1e-3, \"alt_iters\": 5},  # Default, will be swept\n",
    "    \"V-REx\": {\"lr\": 1e-3, \"rex_beta\": 0.5},  # Default, will be swept\n",
    "    \"MM-REx\": {\"lr\": 0.001, \"rex_lambda\": 1.5},\n",
    "    \"IRMv1\": {\"lr\": 0.001, \"penalty_weight\": 10},\n",
    "    \"GroupDRO\": {\"lr\": 0.001, \"alpha\": 0.1},\n",
    "    \"χ²-DRO\": {\"lr\": 0.001, \"rho\": 0.01},\n",
    "    \"CVaR-DRO\": {\"lr\": 1e-3, \"alpha\": 0.5},  # Default, will be swept\n",
    "}\n",
    "\n",
    "# ==========================================\n",
    "# === Grid Search Execution ===\n",
    "# ==========================================\n",
    "\n",
    "import itertools\n",
    "\n",
    "# --- 1. ERM LR Sweep ---\n",
    "if RUN_SWEEP[\"ERM\"]:\n",
    "    print(\"\\n▶ LR sweep: ERM (SGD)\")\n",
    "    best_erm_lr, erm_sweep = lr_sweep(\n",
    "        algo_name=\"ERM\",\n",
    "        lr_list=ERM_LR_GRID,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=train_erm,\n",
    "        train_kwargs=dict(weight_decay=WEIGHT_DECAY),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "        seed=SEED,\n",
    "        print_every_sweep=max(VAL_SWEEP_EPOCHS//5, 10),\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] ERM sweep — using precomputed lr=\", PRECOMPUTED_HYPERPARAMS[\"ERM\"][\"lr\"])\n",
    "    best_erm_lr = PRECOMPUTED_HYPERPARAMS[\"ERM\"][\"lr\"]\n",
    "    erm_sweep = []\n",
    "\n",
    "# --- 1b. ERM-Adam LR Sweep ---\n",
    "if RUN_SWEEP[\"ERM-Adam\"]:\n",
    "    print(\"\\n▶ LR sweep: ERM-Adam\")\n",
    "    best_erm_adam_lr, erm_adam_sweep = lr_sweep(\n",
    "        algo_name=\"ERM-Adam\",\n",
    "        lr_list=ERM_LR_GRID,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=train_erm_adam,\n",
    "        train_kwargs=dict(weight_decay=WEIGHT_DECAY),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "        seed=SEED,\n",
    "        print_every_sweep=max(VAL_SWEEP_EPOCHS//5, 10),\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] ERM-Adam sweep — using precomputed lr=\", PRECOMPUTED_HYPERPARAMS[\"ERM-Adam\"][\"lr\"])\n",
    "    best_erm_adam_lr = PRECOMPUTED_HYPERPARAMS[\"ERM-Adam\"][\"lr\"]\n",
    "    erm_adam_sweep = []\n",
    "\n",
    "# --- 2. SAM LR Sweep ---\n",
    "if RUN_SWEEP[\"SAM\"]:\n",
    "    print(\"\\n▶ LR sweep: SAM (rho fixed at 0.05)\")\n",
    "    best_sam_lr, sam_sweep = lr_sweep(\n",
    "        algo_name=\"SAM\",\n",
    "        lr_list=SAM_LR_GRID,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=train_sam,\n",
    "        train_kwargs=dict(rho=SAM_RHO, weight_decay=WEIGHT_DECAY),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "        seed=SEED,\n",
    "        print_every_sweep=max(VAL_SWEEP_EPOCHS//5, 10),\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] SAM sweep — using precomputed lr=\", PRECOMPUTED_HYPERPARAMS[\"SAM\"][\"lr\"])\n",
    "    best_sam_lr = PRECOMPUTED_HYPERPARAMS[\"SAM\"][\"lr\"]\n",
    "    sam_sweep = []\n",
    "\n",
    "# --- 3. IRS(KL)-CW LR Sweep ---\n",
    "if RUN_SWEEP[\"IRS(KL)-CW\"]:\n",
    "    print(\"\\n▶ LR sweep: IRS(KL)-CW (class-wise)\")\n",
    "    best_irs_cw_lr, irs_cw_sweep = lr_sweep(\n",
    "        algo_name=\"IRS(KL)-CW\",\n",
    "        lr_list=IRS_CW_LR_GRID,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=train_irs_cw,\n",
    "        train_kwargs=dict(\n",
    "            warmup_epochs=3,\n",
    "            base_tau=RS_SHARED_TAU,\n",
    "            distance_scale=1.0,\n",
    "            min_div=1e-2,\n",
    "            weight_decay=WEIGHT_DECAY,\n",
    "            tau_lb_factor=1.01,\n",
    "            tau_lb_eps=1e-8,\n",
    "            use_gate=True,\n",
    "            num_classes=NUM_CLASSES,\n",
    "            P_global=get_P_global_tensor(),\n",
    "        ),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "        seed=SEED,\n",
    "        print_every_sweep=max(VAL_SWEEP_EPOCHS//5, 10),\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] IRS(KL)-CW sweep — using precomputed lr=\", PRECOMPUTED_HYPERPARAMS[\"IRS(KL)-CW\"][\"lr\"])\n",
    "    best_irs_cw_lr = PRECOMPUTED_HYPERPARAMS[\"IRS(KL)-CW\"][\"lr\"]\n",
    "    irs_cw_sweep = []\n",
    "\n",
    "# --- 4. KL-RS Grid Sweep (LR × alt_iters) ---\n",
    "alt_iters_grid = [5, 10, 15]  # Different alternation counts to try\n",
    "\n",
    "def klrs_train_wrapper(model, train_loader_, val_loader_, epochs, lr, alt_iters, **kwargs):\n",
    "    \"\"\"Wrapper that adjusts inner_epochs_theta to keep total epochs constant.\"\"\"\n",
    "    # Total training iterations = alt_iters * inner_epochs_theta\n",
    "    # We want this to equal VAL_SWEEP_EPOCHS regardless of alt_iters\n",
    "    inner_epochs_theta = max(1, epochs // alt_iters)\n",
    "    return train_klrs(\n",
    "        model, train_loader_, val_loader_,\n",
    "        epochs=epochs,\n",
    "        lr=lr,\n",
    "        alt_iters=alt_iters,\n",
    "        inner_epochs_theta=inner_epochs_theta,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"KL-RS\",\n",
    "        print_every=999,  # Suppress per-alternation prints during sweep\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "if RUN_SWEEP[\"KL-RS\"]:\n",
    "    print(\"\\n▶ Hyperparam sweep: KL-RS (LR × alt_iters)\")\n",
    "    # alt_iters × (VAL_SWEEP_EPOCHS // alt_iters) equals VAL_SWEEP_EPOCHS only when it divides evenly.\n",
    "    print(\"NOTE: Total epochs kept ≈ constant across alt_iters values (exact only when VAL_SWEEP_EPOCHS is divisible by alt_iters)\")\n",
    "\n",
    "    # Create grid: LR × alt_iters\n",
    "    klrs_config_grid = [\n",
    "        {\"lr\": lr, \"alt_iters\": alt_iters}\n",
    "        for lr, alt_iters in itertools.product(KLRS_LR_GRID, alt_iters_grid)\n",
    "    ]\n",
    "\n",
    "    best_klrs_config, klrs_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"KL-RS\",\n",
    "        config_grid=klrs_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=klrs_train_wrapper,\n",
    "        base_train_kwargs={\n",
    "            \"weight_decay\": WEIGHT_DECAY,\n",
    "            \"base_tau\": RS_SHARED_TAU,  # Use same tau as IRS\n",
    "            \"warmup_epochs\": 3  # ERM warmup before KL-RS constraint\n",
    "        },\n",
    "        train_epochs=VAL_SWEEP_EPOCHS\n",
    "    )\n",
    "\n",
    "    # The selection cell tolerates a falsy config; indexing it here would crash first.\n",
    "    if not best_klrs_config:\n",
    "        print(\"[KL-RS] sweep returned no config — falling back to PRECOMPUTED_HYPERPARAMS\")\n",
    "        best_klrs_config = PRECOMPUTED_HYPERPARAMS[\"KL-RS\"]\n",
    "    best_klrs_lr = best_klrs_config[\"lr\"]\n",
    "    best_klrs_alt_iters = best_klrs_config[\"alt_iters\"]\n",
    "    print(f\"Best KL-RS config: lr={best_klrs_lr}, alt_iters={best_klrs_alt_iters}\")\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] KL-RS sweep — using precomputed lr=\", PRECOMPUTED_HYPERPARAMS[\"KL-RS\"][\"lr\"], \", alt_iters=\", PRECOMPUTED_HYPERPARAMS[\"KL-RS\"][\"alt_iters\"])\n",
    "    best_klrs_config = PRECOMPUTED_HYPERPARAMS[\"KL-RS\"]\n",
    "    best_klrs_lr = best_klrs_config[\"lr\"]\n",
    "    best_klrs_alt_iters = best_klrs_config[\"alt_iters\"]\n",
    "    klrs_sweep_results = []\n",
    "\n",
    "# --- 5. V-REx Hyperparam Sweep ---\n",
    "def vrex_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_rex` in V-REx mode.\n",
    "\n",
    "    `train_loader_` is intentionally unused: REx trains on the two environment\n",
    "    loaders `irm_loader1` / `irm_loader2`. Validation uses the `val_loader_`\n",
    "    argument supplied by the sweep driver rather than the global `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_rex(\n",
    "        model, irm_loader1, irm_loader2, val_loader_,\n",
    "        epochs=epochs,\n",
    "        rex_mode=\"vrex\",\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"V-REx\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "if RUN_SWEEP[\"V-REx\"]:\n",
    "    print(\"\\n▶ Hyperparam sweep: V-REx\")\n",
    "    vrex_config_grid = [\n",
    "        {\"lr\": lr, \"rex_beta\": beta}\n",
    "        for lr, beta in itertools.product(VREX_LR_GRID, VREX_BETA_GRID)\n",
    "    ]\n",
    "\n",
    "    best_vrex_config, vrex_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"V-REx\",\n",
    "        config_grid=vrex_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=vrex_train_wrapper,\n",
    "        # Tune under the same weight decay the full V-REx run uses (weight_decay=WEIGHT_DECAY).\n",
    "        # NOTE(review): full training also sets penalty_anneal_epochs=10 — confirm whether\n",
    "        # the (shorter) sweep should anneal as well.\n",
    "        base_train_kwargs={\n",
    "            \"weight_decay\": WEIGHT_DECAY,\n",
    "        },\n",
    "        train_epochs=VAL_SWEEP_EPOCHS\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] V-REx sweep — using precomputed config\")\n",
    "    best_vrex_config = PRECOMPUTED_HYPERPARAMS[\"V-REx\"]\n",
    "    vrex_sweep_results = []\n",
    "\n",
    "# --- 6. MM-REx Hyperparam Sweep ---\n",
    "def mmrex_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_rex` in MM-REx mode.\n",
    "\n",
    "    `train_loader_` is intentionally unused: REx trains on the two environment\n",
    "    loaders `irm_loader1` / `irm_loader2`. Validation uses the `val_loader_`\n",
    "    argument supplied by the sweep driver rather than the global `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_rex(\n",
    "        model, irm_loader1, irm_loader2, val_loader_,\n",
    "        epochs=epochs,\n",
    "        rex_mode=\"mmrex\",\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"MM-REx\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "if RUN_SWEEP[\"MM-REx\"]:\n",
    "    print(\"\\n▶ Hyperparam sweep: MM-REx\")\n",
    "    mmrex_config_grid = [\n",
    "        {\"lr\": lr, \"rex_lambda\": lam}\n",
    "        for lr, lam in itertools.product(MMREX_LR_GRID, MMREX_LAMBDA_GRID)\n",
    "    ]\n",
    "\n",
    "    best_mmrex_config, mmrex_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"MM-REx\",\n",
    "        config_grid=mmrex_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=mmrex_train_wrapper,\n",
    "        # Tune under the same weight decay the full MM-REx run uses (weight_decay=WEIGHT_DECAY).\n",
    "        # NOTE(review): full training also sets penalty_anneal_epochs=10 — confirm whether\n",
    "        # the (shorter) sweep should anneal as well.\n",
    "        base_train_kwargs={\n",
    "            \"weight_decay\": WEIGHT_DECAY,\n",
    "        },\n",
    "        train_epochs=VAL_SWEEP_EPOCHS\n",
    "    )\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] MM-REx sweep — using precomputed config\")\n",
    "    best_mmrex_config = PRECOMPUTED_HYPERPARAMS[\"MM-REx\"]\n",
    "    mmrex_sweep_results = []\n",
    "\n",
    "# --- 7. IRMv1 Hyperparam Sweep ---\n",
    "def irm_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_irm`.\n",
    "\n",
    "    `train_loader_` is intentionally unused: IRMv1 trains on the two environment\n",
    "    loaders `irm_loader1` / `irm_loader2`. Validation uses the `val_loader_`\n",
    "    argument supplied by the sweep driver rather than the global `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_irm(\n",
    "        model, irm_loader1, irm_loader2, val_loader_,\n",
    "        epochs=epochs,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"IRMv1\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "# IRMv1: sweep LR × penalty weight, or reuse the precomputed config.\n",
    "if not RUN_SWEEP[\"IRMv1\"]:\n",
    "    print(\"\\n▶ [SKIP] IRMv1 sweep — using precomputed config\")\n",
    "    best_irm_config = PRECOMPUTED_HYPERPARAMS[\"IRMv1\"]\n",
    "    irm_sweep_results = []\n",
    "else:\n",
    "    print(\"\\n▶ Hyperparam sweep: IRMv1 (LR x Penalty)\")\n",
    "    irm_config_grid = [\n",
    "        {\"lr\": lr_value, \"penalty_weight\": penalty}\n",
    "        for lr_value, penalty in itertools.product(LR_GRID, IRM_PENALTY_GRID)\n",
    "    ]\n",
    "    # Short sweep runs anneal the IRM penalty over 5 epochs.\n",
    "    best_irm_config, irm_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"IRMv1\",\n",
    "        config_grid=irm_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=irm_train_wrapper,\n",
    "        base_train_kwargs=dict(weight_decay=WEIGHT_DECAY, penalty_anneal_epochs=5),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "    )\n",
    "\n",
    "# --- 8. GroupDRO Hyperparam Sweep ---\n",
    "def groupdro_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_groupdro`.\n",
    "\n",
    "    Trains and validates on the `train_loader_` / `val_loader_` arguments supplied\n",
    "    by the sweep driver instead of the global `train_loader` / `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_groupdro(\n",
    "        model, train_loader_, val_loader_,\n",
    "        epochs=epochs,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"GroupDRO\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "# GroupDRO: sweep LR × alpha (gamma / robust_step_size fixed), or reuse the precomputed config.\n",
    "if not RUN_SWEEP[\"GroupDRO\"]:\n",
    "    print(\"\\n▶ [SKIP] GroupDRO sweep — using precomputed config\")\n",
    "    best_groupdro_config = PRECOMPUTED_HYPERPARAMS[\"GroupDRO\"]\n",
    "    groupdro_sweep_results = []\n",
    "else:\n",
    "    print(\"\\n▶ Hyperparam sweep: GroupDRO (LR × Alpha)\")\n",
    "    groupdro_config_grid = [\n",
    "        {\"lr\": lr_value, \"alpha\": alpha_value}\n",
    "        for lr_value, alpha_value in itertools.product(GROUPDRO_LR_GRID, GROUPDRO_ALPHA_GRID)\n",
    "    ]\n",
    "    best_groupdro_config, groupdro_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"GroupDRO\",\n",
    "        config_grid=groupdro_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=groupdro_train_wrapper,\n",
    "        base_train_kwargs=dict(weight_decay=WEIGHT_DECAY, gamma=0.1, robust_step_size=0.01),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "    )\n",
    "\n",
    "# --- 9. χ²-DRO Hyperparam Sweep ---\n",
    "def chi2_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_chi2_dro`.\n",
    "\n",
    "    Trains and validates on the `train_loader_` / `val_loader_` arguments supplied\n",
    "    by the sweep driver instead of the global `train_loader` / `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_chi2_dro(\n",
    "        model, train_loader_, val_loader_,\n",
    "        epochs=epochs,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"χ²-DRO\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "# χ²-DRO: sweep LR × rho, or reuse the precomputed config.\n",
    "if not RUN_SWEEP[\"χ²-DRO\"]:\n",
    "    print(\"\\n▶ [SKIP] χ²-DRO sweep — using precomputed config\")\n",
    "    best_chi2_config = PRECOMPUTED_HYPERPARAMS[\"χ²-DRO\"]\n",
    "    chi2_sweep_results = []\n",
    "else:\n",
    "    print(\"\\n▶ Hyperparam sweep: χ²-DRO (LR × Rho)\")\n",
    "    chi2_config_grid = [\n",
    "        {\"lr\": lr_value, \"rho\": rho_value}\n",
    "        for lr_value, rho_value in itertools.product(LR_GRID, CHI2_RHO_GRID)\n",
    "    ]\n",
    "    best_chi2_config, chi2_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"χ²-DRO\",\n",
    "        config_grid=chi2_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=chi2_train_wrapper,\n",
    "        base_train_kwargs=dict(weight_decay=WEIGHT_DECAY),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "    )\n",
    "\n",
    "# --- 10. CVaR-DRO Hyperparam Sweep ---\n",
    "def cvar_train_wrapper(model, train_loader_, val_loader_, epochs, **kwargs):\n",
    "    \"\"\"Sweep adapter for `train_cvar_dro`.\n",
    "\n",
    "    Trains and validates on the `train_loader_` / `val_loader_` arguments supplied\n",
    "    by the sweep driver instead of the global `train_loader` / `val_loader`.\n",
    "    \"\"\"\n",
    "    return train_cvar_dro(\n",
    "        model, train_loader_, val_loader_,\n",
    "        epochs=epochs,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=None,\n",
    "        algo_name=\"CVaR-DRO\",\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "# CVaR-DRO: sweep LR × alpha, or reuse the precomputed config.\n",
    "if not RUN_SWEEP[\"CVaR-DRO\"]:\n",
    "    print(\"\\n▶ [SKIP] CVaR-DRO sweep — using precomputed config\")\n",
    "    best_cvar_config = PRECOMPUTED_HYPERPARAMS[\"CVaR-DRO\"]\n",
    "    cvar_sweep_results = []\n",
    "else:\n",
    "    print(\"\\n▶ Hyperparam sweep: CVaR-DRO (LR × Alpha) [Levy et al. 2020]\")\n",
    "    cvar_config_grid = [\n",
    "        {\"lr\": lr_value, \"alpha\": alpha_value}\n",
    "        for lr_value, alpha_value in itertools.product(LR_GRID, CVAR_ALPHA_GRID)\n",
    "    ]\n",
    "    best_cvar_config, cvar_sweep_results = hyperparam_sweep(\n",
    "        algo_name=\"CVaR-DRO\",\n",
    "        config_grid=cvar_config_grid,\n",
    "        build_model_fn=fresh_model,\n",
    "        train_fn=cvar_train_wrapper,\n",
    "        base_train_kwargs=dict(weight_decay=WEIGHT_DECAY),\n",
    "        train_epochs=VAL_SWEEP_EPOCHS,\n",
    "    )\n",
    "\n",
    "# ==========================================\n",
    "# === Best Hyperparameter Selection ===\n",
    "# ==========================================\n",
    "\n",
    "def _cfg_value(cfg, key, default):\n",
    "    \"\"\"Return `cfg.get(key, default)` when `cfg` is truthy (e.g. a non-empty dict), else `default`.\"\"\"\n",
    "    return cfg.get(key, default) if cfg else default\n",
    "\n",
    "# Extract best hyperparameters (with fallbacks)\n",
    "# LR-only sweeps go through select_lr; grid sweeps read their best config dict via _cfg_value.\n",
    "best_erm_lr = select_lr(best_erm_lr, ERM_LR_GRID, default=1e-3)\n",
    "# NOTE(review): ERM-Adam is checked against ERM_LR_GRID — confirm it shouldn't use its own grid.\n",
    "best_erm_adam_lr = select_lr(best_erm_adam_lr, ERM_LR_GRID, default=1e-3)\n",
    "best_sam_lr = select_lr(best_sam_lr, SAM_LR_GRID, default=1e-3)\n",
    "best_irs_cw_lr = select_lr(best_irs_cw_lr, IRS_CW_LR_GRID, default=1e-3)\n",
    "\n",
    "best_klrs_lr = _cfg_value(best_klrs_config, \"lr\", 1e-3)\n",
    "# Resolve alt_iters through the same path so it cannot drift from best_klrs_config.\n",
    "best_klrs_alt_iters = _cfg_value(best_klrs_config, \"alt_iters\", best_klrs_alt_iters)\n",
    "\n",
    "best_vrex_lr = _cfg_value(best_vrex_config, \"lr\", 1e-3)\n",
    "best_vrex_beta = _cfg_value(best_vrex_config, \"rex_beta\", 1.0)\n",
    "\n",
    "best_mmrex_lr = _cfg_value(best_mmrex_config, \"lr\", 1e-3)\n",
    "best_mmrex_lambda = _cfg_value(best_mmrex_config, \"rex_lambda\", 1.5)\n",
    "\n",
    "best_irm_lr = _cfg_value(best_irm_config, \"lr\", 1e-3)\n",
    "best_irm_penalty = _cfg_value(best_irm_config, \"penalty_weight\", 1.0)\n",
    "\n",
    "best_groupdro_lr = _cfg_value(best_groupdro_config, \"lr\", 1e-3)\n",
    "best_groupdro_alpha = _cfg_value(best_groupdro_config, \"alpha\", 0.2)\n",
    "\n",
    "best_chi2_lr = _cfg_value(best_chi2_config, \"lr\", 1e-3)\n",
    "best_chi2_rho = _cfg_value(best_chi2_config, \"rho\", 0.1)\n",
    "\n",
    "best_cvar_lr = _cfg_value(best_cvar_config, \"lr\", 1e-3)\n",
    "best_cvar_alpha = _cfg_value(best_cvar_config, \"alpha\", 0.1)\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"CHOSEN HYPERPARAMETERS (after fallback):\")\n",
    "print(\"=\"*60)\n",
    "print(f\"ERM (SGD):  lr={best_erm_lr}\")\n",
    "print(f\"ERM-Adam:   lr={best_erm_adam_lr}\")\n",
    "print(f\"SAM:        lr={best_sam_lr}, rho={SAM_RHO}\")\n",
    "print(f\"IRS(KL)-CW: lr={best_irs_cw_lr}\")\n",
    "print(f\"KL-RS:      lr={best_klrs_lr}, alt_iters={best_klrs_alt_iters}\")\n",
    "print(f\"V-REx:      lr={best_vrex_lr}, beta={best_vrex_beta}\")\n",
    "print(f\"MM-REx:     lr={best_mmrex_lr}, lambda={best_mmrex_lambda}\")\n",
    "print(f\"IRMv1:      lr={best_irm_lr}, penalty_weight={best_irm_penalty}\")\n",
    "print(f\"GroupDRO:   lr={best_groupdro_lr}, alpha={best_groupdro_alpha}\")\n",
    "print(f\"χ²-DRO:     lr={best_chi2_lr}, rho={best_chi2_rho}\")\n",
    "print(f\"CVaR-DRO:   lr={best_cvar_lr}, alpha={best_cvar_alpha}\")\n",
    "print(\"=\"*60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NvV8Fjw57-kP"
   },
   "outputs": [],
   "source": [
    "# Where to save sweep CSVs\n",
    "# (also the destination of chosen_lrs.json written at the end of this cell)\n",
    "sweeps_dir = make_run_dir(IMB_FACTOR, \"SWEEPS\")\n",
    "print(\"Sweep results will be saved under:\", sweeps_dir)\n",
    "\n",
    "def _write_csv(path, rows, header):\n",
    "    \"\"\"Write `rows` (an iterable of dicts) to `path` as a UTF-8 CSV with columns `header`.\n",
    "\n",
    "    Keys missing from a row become empty cells; keys not in `header` are ignored.\n",
    "    The parent directory is created when `path` has one.\n",
    "    \"\"\"\n",
    "    parent = os.path.dirname(path)\n",
    "    if parent:  # os.makedirs(\"\") raises for a bare filename\n",
    "        os.makedirs(parent, exist_ok=True)\n",
    "    # Explicit UTF-8: algorithm names in this notebook include non-ASCII (e.g. χ²).\n",
    "    with open(path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "        w = csv.DictWriter(f, fieldnames=header)\n",
    "        w.writeheader()\n",
    "        w.writerows({k: r.get(k, \"\") for k in header} for r in rows)\n",
    "\n",
    "def save_lr_sweep_table(algo_name, sweep_list, out_dir):\n",
    "    \"\"\"Save one algorithm's LR sweep as `<out_dir>/<algo_name>_lr_sweep.csv`.\n",
    "\n",
    "    sweep_list: list of dicts with at least \"lr\" and \"val_acc\" keys (values\n",
    "    are cast to float). Prints the best-val_acc row and returns the CSV path,\n",
    "    or prints a notice and returns None when sweep_list is empty/None.\n",
    "    \"\"\"\n",
    "    if sweep_list:\n",
    "        table = []\n",
    "        for entry in sweep_list:\n",
    "            table.append({\"lr\": float(entry[\"lr\"]), \"val_acc\": float(entry[\"val_acc\"])})\n",
    "        csv_path = os.path.join(out_dir, f\"{algo_name}_lr_sweep.csv\")\n",
    "        _write_csv(csv_path, table, header=[\"lr\", \"val_acc\"])\n",
    "        top = max(table, key=lambda row: row[\"val_acc\"])  # first row wins ties\n",
    "        print(f\"[SWEEP:{algo_name}] saved → {csv_path} (best lr={top['lr']}, val_acc={top['val_acc']:.4f})\")\n",
    "        return csv_path\n",
    "    print(f\"[SWEEP:{algo_name}] nothing to save\")\n",
    "    return None\n",
    "\n",
    "# Save per-algo CSVs (every LR-only sweep that also appears in the combined table)\n",
    "erm_csv      = save_lr_sweep_table(\"ERM\",        erm_sweep,      sweeps_dir)\n",
    "erm_adam_csv = save_lr_sweep_table(\"ERM-Adam\",   erm_adam_sweep, sweeps_dir)\n",
    "sam_csv      = save_lr_sweep_table(\"SAM\",        sam_sweep,      sweeps_dir)\n",
    "irs_csv      = save_lr_sweep_table(\"IRS(KL)-CW\", irs_cw_sweep,   sweeps_dir)\n",
    "# NOTE(review): grid sweeps (KL-RS, REx, IRMv1, DRO) are not exported here; their\n",
    "# result schema comes from hyperparam_sweep — add an exporter once it is confirmed.\n",
    "\n",
    "# Combined sweep table\n",
    "combined = []\n",
    "for algo, sweep in [\n",
    "    (\"ERM\", erm_sweep),\n",
    "    (\"ERM-Adam\", erm_adam_sweep),\n",
    "    (\"SAM\", sam_sweep),\n",
    "    (\"IRS(KL)-CW\", irs_cw_sweep),\n",
    "]:\n",
    "    for r in (sweep or []):\n",
    "        combined.append({\"algo\": algo, \"lr\": float(r[\"lr\"]), \"val_acc\": float(r[\"val_acc\"])})\n",
    "\n",
    "combined_path = os.path.join(sweeps_dir, \"ALL_algos_lr_sweep.csv\")\n",
    "_write_csv(combined_path, combined, header=[\"algo\", \"lr\", \"val_acc\"])\n",
    "print(\"Combined sweep table saved →\", combined_path)\n",
    "\n",
    "# (Optional) chosen-LR JSON\n",
    "chosen = {\n",
    "    \"ERM\":        float(best_erm_lr),\n",
    "    \"ERM-Adam\":   float(best_erm_adam_lr),\n",
    "    \"SAM\":        float(best_sam_lr),\n",
    "    \"IRS(KL)-CW\": float(best_irs_cw_lr),\n",
    "}\n",
    "with open(os.path.join(sweeps_dir, \"chosen_lrs.json\"), \"w\") as f:\n",
    "    json.dump(chosen, f, indent=2)\n",
    "print(\"Chosen LRs saved →\", os.path.join(sweeps_dir, \"chosen_lrs.json\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AEOCzJd53uFC"
   },
   "outputs": [],
   "source": [
    "results = {}          # algo name -> dict of training histories\n",
    "wallclock_times = {}  # algo name -> wall-clock seconds of its full training run\n",
    "\n",
    "# One run directory per algorithm inside Drive, under IF_<IMB_FACTOR>/... (created in this order).\n",
    "(irs_dir, sam_dir, erm_dir, erm_adam_dir, klrs_dir, vrex_dir,\n",
    " mmrex_dir, irm_dir, groupdro_dir, chi2_dir, cvar_dir) = [\n",
    "    make_run_dir(IMB_FACTOR, run_name)\n",
    "    for run_name in (\"IRS(KL)-CW\", \"SAM\", \"ERM\", \"ERM-Adam\", \"KL-RS\", \"V-REx\",\n",
    "                     \"MM-REx\", \"IRMv1\", \"GroupDRO\", \"χ²-DRO\", \"CVaR-DRO\")\n",
    "]\n",
    "\n",
    "# Full training runs. Each enabled algorithm: build a fresh model, time an EPOCHS-long\n",
    "# run with the hyperparameters chosen above (perf_counter around the train call only),\n",
    "# store its histories in `results`, then export confusion matrices to its run dir.\n",
    "if RUN_FULL_TRAINING[\"IRS(KL)-CW\"]:\n",
    "    print(\"\\n▶ Full training — IRS (KL) Class-wise\")\n",
    "    model_irs_cw = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    # Returns the kappa and h histories plus the tr/val/acc curves plotted later.\n",
    "    kappa_hist, h_hist, tr_irs, val_irs, acc_irs = train_irs_cw(\n",
    "        model_irs_cw, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_irs_cw_lr,\n",
    "        warmup_epochs=3,\n",
    "        base_tau=RS_SHARED_TAU,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        tau_lb_factor=1.01, tau_lb_eps=1e-8,\n",
    "        distance_scale=1.0, min_div=1e-2,\n",
    "        use_gate=True,\n",
    "        print_every=5,\n",
    "        num_classes=NUM_CLASSES,\n",
    "        P_global=get_P_global_tensor(),\n",
    "        test_loader=test_loader,\n",
    "        log_dir=irs_dir,\n",
    "        algo_name=\"IRS(KL)-CW\"\n",
    "    )\n",
    "    wallclock_times[\"IRS(KL)-CW\"] = time.perf_counter() - t_start\n",
    "    print(f\"[IRS(KL)-CW] Total runtime: {_fmt_hms(wallclock_times['IRS(KL)-CW'])}\")\n",
    "    results[\"IRS(KL)-CW\"] = {\"kappa\": kappa_hist, \"h\": h_hist, \"tr\": tr_irs, \"val\": val_irs, \"acc\": acc_irs}\n",
    "    save_confusion_matrices(\"IRS(KL)-CW\", irs_dir, model_irs_cw, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — IRS(KL)-CW\")\n",
    "\n",
    "# SAM: only lr was swept; rho stays at SAM_RHO.\n",
    "if RUN_FULL_TRAINING[\"SAM\"]:\n",
    "    print(\"\\n▶ Full training — SAM\")\n",
    "    model_sam = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_sam, val_sam, acc_sam = train_sam(\n",
    "        model_sam, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_sam_lr,\n",
    "        rho=SAM_RHO, weight_decay=WEIGHT_DECAY, print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=sam_dir,\n",
    "        algo_name=\"SAM\"\n",
    "    )\n",
    "    wallclock_times[\"SAM\"] = time.perf_counter() - t_start\n",
    "    print(f\"[SAM] Total runtime: {_fmt_hms(wallclock_times['SAM'])}\")\n",
    "    results[\"SAM\"] = {\"tr\": tr_sam, \"val\": val_sam, \"acc\": acc_sam}\n",
    "    save_confusion_matrices(\"SAM\", sam_dir, model_sam, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — SAM\")\n",
    "\n",
    "# ERM baseline (SGD).\n",
    "if RUN_FULL_TRAINING[\"ERM\"]:\n",
    "    print(\"\\n▶ Full training — ERM (SGD)\")\n",
    "    model_erm = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_erm, val_erm, acc_erm = train_erm(\n",
    "        model_erm, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_erm_lr,\n",
    "        weight_decay=WEIGHT_DECAY, print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=erm_dir,\n",
    "        algo_name=\"ERM\"\n",
    "    )\n",
    "    wallclock_times[\"ERM\"] = time.perf_counter() - t_start\n",
    "    print(f\"[ERM] Total runtime: {_fmt_hms(wallclock_times['ERM'])}\")\n",
    "    results[\"ERM\"] = {\"tr\": tr_erm, \"val\": val_erm, \"acc\": acc_erm}\n",
    "    save_confusion_matrices(\"ERM\", erm_dir, model_erm, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — ERM\")\n",
    "\n",
    "# ERM with Adam; its lr was selected against ERM_LR_GRID in the selection step.\n",
    "if RUN_FULL_TRAINING[\"ERM-Adam\"]:\n",
    "    print(\"\\n▶ Full training — ERM-Adam\")\n",
    "    model_erm_adam = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_erm_adam, val_erm_adam, acc_erm_adam = train_erm_adam(\n",
    "        model_erm_adam, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_erm_adam_lr,\n",
    "        weight_decay=WEIGHT_DECAY, print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=erm_adam_dir,\n",
    "        algo_name=\"ERM-Adam\"\n",
    "    )\n",
    "    wallclock_times[\"ERM-Adam\"] = time.perf_counter() - t_start\n",
    "    print(f\"[ERM-Adam] Total runtime: {_fmt_hms(wallclock_times['ERM-Adam'])}\")\n",
    "    results[\"ERM-Adam\"] = {\"tr\": tr_erm_adam, \"val\": val_erm_adam, \"acc\": acc_erm_adam}\n",
    "    save_confusion_matrices(\"ERM-Adam\", erm_adam_dir, model_erm_adam, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — ERM-Adam\")\n",
    "\n",
    "if RUN_FULL_TRAINING[\"KL-RS\"]:\n",
    "    print(\"\\n▶ Full training — KL-RS\")\n",
    "    model_klrs = fresh_model()\n",
    "    # Same split as klrs_train_wrapper: alt_iters alternations of max(1, EPOCHS // alt_iters)\n",
    "    # epochs each. Floor division means the printed total need not equal EPOCHS.\n",
    "    inner_epochs_theta_klrs = max(1, EPOCHS // best_klrs_alt_iters)\n",
    "    print(f\"Using alt_iters={best_klrs_alt_iters}, inner_epochs_theta={inner_epochs_theta_klrs}\")\n",
    "    print(f\"Total training: {best_klrs_alt_iters} × {inner_epochs_theta_klrs} = {best_klrs_alt_iters * inner_epochs_theta_klrs} epochs\")\n",
    "    t_start = time.perf_counter()\n",
    "    tr_klrs, val_klrs, acc_klrs = train_klrs(\n",
    "        model_klrs, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_klrs_lr,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        base_tau=RS_SHARED_TAU,\n",
    "        warmup_epochs=3,\n",
    "        alt_iters=best_klrs_alt_iters,\n",
    "        inner_epochs_theta=inner_epochs_theta_klrs,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=klrs_dir,\n",
    "        algo_name=\"KL-RS\"\n",
    "    )\n",
    "    wallclock_times[\"KL-RS\"] = time.perf_counter() - t_start\n",
    "    print(f\"[KL-RS] Total runtime: {_fmt_hms(wallclock_times['KL-RS'])}\")\n",
    "    results[\"KL-RS\"] = {\"tr\": tr_klrs, \"val\": val_klrs, \"acc\": acc_klrs}\n",
    "    save_confusion_matrices(\"KL-RS\", klrs_dir, model_klrs, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — KL-RS\")\n",
    "\n",
    "# REx / IRMv1 train on the two environment loaders (irm_loader1, irm_loader2), not train_loader.\n",
    "# NOTE(review): penalty_anneal_epochs here (10 for REx, 20 for IRMv1) differs from the sweep\n",
    "# settings (IRMv1 sweep used 5) — confirm this is the intended schedule for EPOCHS-long runs.\n",
    "if RUN_FULL_TRAINING[\"V-REx\"]:\n",
    "    print(\"\\n▶ Full training — V-REx\")\n",
    "    model_vrex = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_vrex, val_vrex, acc_vrex = train_rex(\n",
    "        model_vrex, irm_loader1, irm_loader2, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_vrex_lr,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        rex_mode=\"vrex\",\n",
    "        rex_beta=best_vrex_beta,\n",
    "        penalty_anneal_epochs=10,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=vrex_dir,\n",
    "        algo_name=\"V-REx\"\n",
    "    )\n",
    "    wallclock_times[\"V-REx\"] = time.perf_counter() - t_start\n",
    "    print(f\"[V-REx] Total runtime: {_fmt_hms(wallclock_times['V-REx'])}\")\n",
    "    results[\"V-REx\"] = {\"tr\": tr_vrex, \"val\": val_vrex, \"acc\": acc_vrex}\n",
    "    save_confusion_matrices(\"V-REx\", vrex_dir, model_vrex, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — V-REx\")\n",
    "\n",
    "# MM-REx: same trainer as V-REx with rex_mode=\"mmrex\" and the swept rex_lambda.\n",
    "if RUN_FULL_TRAINING[\"MM-REx\"]:\n",
    "    print(\"\\n▶ Full training — MM-REx\")\n",
    "    model_mmrex = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_mmrex, val_mmrex, acc_mmrex = train_rex(\n",
    "        model_mmrex, irm_loader1, irm_loader2, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_mmrex_lr,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        rex_mode=\"mmrex\",\n",
    "        rex_lambda=best_mmrex_lambda,\n",
    "        penalty_anneal_epochs=10,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=mmrex_dir,\n",
    "        algo_name=\"MM-REx\"\n",
    "    )\n",
    "    wallclock_times[\"MM-REx\"] = time.perf_counter() - t_start\n",
    "    print(f\"[MM-REx] Total runtime: {_fmt_hms(wallclock_times['MM-REx'])}\")\n",
    "    results[\"MM-REx\"] = {\"tr\": tr_mmrex, \"val\": val_mmrex, \"acc\": acc_mmrex}\n",
    "    save_confusion_matrices(\"MM-REx\", mmrex_dir, model_mmrex, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — MM-REx\")\n",
    "\n",
    "if RUN_FULL_TRAINING[\"IRMv1\"]:\n",
    "    print(\"\\n▶ Full training — IRMv1\")\n",
    "    model_irm = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_irm, val_irm, acc_irm = train_irm(\n",
    "        model_irm, irm_loader1, irm_loader2, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_irm_lr,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        penalty_weight=best_irm_penalty,\n",
    "        penalty_anneal_epochs=20,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=irm_dir,\n",
    "        algo_name=\"IRMv1\"\n",
    "    )\n",
    "    wallclock_times[\"IRMv1\"] = time.perf_counter() - t_start\n",
    "    print(f\"[IRMv1] Total runtime: {_fmt_hms(wallclock_times['IRMv1'])}\")\n",
    "    results[\"IRMv1\"] = {\"tr\": tr_irm, \"val\": val_irm, \"acc\": acc_irm}\n",
    "    save_confusion_matrices(\"IRMv1\", irm_dir, model_irm, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — IRMv1\")\n",
    "\n",
    "# GroupDRO: gamma=0.1 and robust_step_size=0.01 are the same fixed values used in its sweep.\n",
    "if RUN_FULL_TRAINING[\"GroupDRO\"]:\n",
    "    print(\"\\n▶ Full training — GroupDRO\")\n",
    "    model_groupdro = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_groupdro, val_groupdro, acc_groupdro = train_groupdro(\n",
    "        model_groupdro, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_groupdro_lr,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        alpha=best_groupdro_alpha,\n",
    "        gamma=0.1,\n",
    "        robust_step_size=0.01,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=groupdro_dir,\n",
    "        algo_name=\"GroupDRO\"\n",
    "    )\n",
    "    wallclock_times[\"GroupDRO\"] = time.perf_counter() - t_start\n",
    "    print(f\"[GroupDRO] Total runtime: {_fmt_hms(wallclock_times['GroupDRO'])}\")\n",
    "    results[\"GroupDRO\"] = {\"tr\": tr_groupdro, \"val\": val_groupdro, \"acc\": acc_groupdro}\n",
    "    save_confusion_matrices(\"GroupDRO\", groupdro_dir, model_groupdro, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — GroupDRO\")\n",
    "\n",
    "# χ²-DRO with the swept rho.\n",
    "if RUN_FULL_TRAINING[\"χ²-DRO\"]:\n",
    "    print(\"\\n▶ Full training — χ²-DRO\")\n",
    "    model_chi2 = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_chi2, val_chi2, acc_chi2 = train_chi2_dro(\n",
    "        model_chi2, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_chi2_lr,\n",
    "        rho=best_chi2_rho,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=chi2_dir,\n",
    "        algo_name=\"χ²-DRO\"\n",
    "    )\n",
    "    wallclock_times[\"χ²-DRO\"] = time.perf_counter() - t_start\n",
    "    print(f\"[χ²-DRO] Total runtime: {_fmt_hms(wallclock_times['χ²-DRO'])}\")\n",
    "    results[\"χ²-DRO\"] = {\"tr\": tr_chi2, \"val\": val_chi2, \"acc\": acc_chi2}\n",
    "    save_confusion_matrices(\"χ²-DRO\", chi2_dir, model_chi2, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — χ²-DRO\")\n",
    "\n",
    "# CVaR-DRO with the swept alpha.\n",
    "if RUN_FULL_TRAINING[\"CVaR-DRO\"]:\n",
    "    print(\"\\n▶ Full training — CVaR-DRO (Levy et al. 2020)\")\n",
    "    model_cvar = fresh_model()\n",
    "    t_start = time.perf_counter()\n",
    "    tr_cvar, val_cvar, acc_cvar = train_cvar_dro(\n",
    "        model_cvar, train_loader, val_loader,\n",
    "        epochs=EPOCHS,\n",
    "        lr=best_cvar_lr,\n",
    "        alpha=best_cvar_alpha,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        print_every=5,\n",
    "        test_loader=test_loader,\n",
    "        log_dir=cvar_dir,\n",
    "        algo_name=\"CVaR-DRO\"\n",
    "    )\n",
    "    wallclock_times[\"CVaR-DRO\"] = time.perf_counter() - t_start\n",
    "    print(f\"[CVaR-DRO] Total runtime: {_fmt_hms(wallclock_times['CVaR-DRO'])}\")\n",
    "    results[\"CVaR-DRO\"] = {\"tr\": tr_cvar, \"val\": val_cvar, \"acc\": acc_cvar}\n",
    "    save_confusion_matrices(\"CVaR-DRO\", cvar_dir, model_cvar, train_loader, val_loader, test_loader)\n",
    "else:\n",
    "    print(\"\\n▶ [SKIP] Full training — CVaR-DRO\")\n",
    "\n",
    "# ==========================================\n",
    "# === Runtime Summary ===\n",
    "# ==========================================\n",
    "total_time = sum(wallclock_times.values())  # seconds, summed over every algorithm that ran\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(f\"WALLCLOCK RUNTIME SUMMARY ({EPOCHS} epochs)\")\n",
    "print(\"=\" * 60)\n",
    "for algo in wallclock_times:\n",
    "    runtime = wallclock_times[algo]\n",
    "    print(f\"  {algo:15s}: {_fmt_hms(runtime)}\")\n",
    "print(\"=\" * 60)\n",
    "print(f\"  {'TOTAL':15s}: {_fmt_hms(total_time)}\")\n",
    "print(\"=\" * 60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5LwEOw-4c_cJ"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "74LRA7Gzc-7g"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zX10n5b73uCo"
   },
   "outputs": [],
   "source": [
    "print(\"\\n▶ Final test evaluation (balanced test set)\")\n",
    "# (algo key, getter) pairs; a getter is only called when that algorithm was trained,\n",
    "# so model_* names of skipped algorithms are never looked up.\n",
    "all_models = [\n",
    "    (algo_key, get_model())\n",
    "    for algo_key, get_model in (\n",
    "        (\"ERM\", lambda: model_erm),\n",
    "        (\"ERM-Adam\", lambda: model_erm_adam),\n",
    "        (\"SAM\", lambda: model_sam),\n",
    "        (\"IRS(KL)-CW\", lambda: model_irs_cw),\n",
    "        (\"KL-RS\", lambda: model_klrs),\n",
    "        (\"V-REx\", lambda: model_vrex),\n",
    "        (\"MM-REx\", lambda: model_mmrex),\n",
    "        (\"IRMv1\", lambda: model_irm),\n",
    "        (\"GroupDRO\", lambda: model_groupdro),\n",
    "        (\"χ²-DRO\", lambda: model_chi2),\n",
    "        (\"CVaR-DRO\", lambda: model_cvar),\n",
    "    )\n",
    "    if RUN_FULL_TRAINING[algo_key]\n",
    "]\n",
    "\n",
    "for name, model in all_models:\n",
    "    test_acc, test_percls, test_cm = evaluate_on_loader(model, test_loader)\n",
    "    print(f\"\\n[TEST] {name}: acc={test_acc:.4f}\")\n",
    "    print(f\"[TEST] {name}: per-class acc:\")\n",
    "    for i, a in enumerate(test_percls):\n",
    "        # Fall back to a generic label if there are more classes than CIFAR10_CLASSES names.\n",
    "        if i < len(CIFAR10_CLASSES):\n",
    "            cname = CIFAR10_CLASSES[i]\n",
    "        else:\n",
    "            cname = f\"class{i}\"\n",
    "        print(f\"  {i:02d} ({cname:10s}): {a:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VH1M5zR83uAA"
   },
   "outputs": [],
   "source": [
    "def _ep(n):\n",
    "    \"\"\"1-based epoch indices [1, 2, ..., n] used as the x-axis for a history of length n.\"\"\"\n",
    "    return [epoch for epoch in range(1, n + 1)]\n",
    "\n",
    "def _plot_curve_if(key, field, label=None, style=None):\n",
    "    \"\"\"Plot results[key][field] against 1-based epochs on the current axes.\n",
    "\n",
    "    Does nothing when the algorithm/field is missing or the history is empty.\n",
    "    label defaults to '<key> <field>'; style, if given, is a dict of extra plt.plot kwargs.\n",
    "    \"\"\"\n",
    "    if key not in results or field not in results[key]:\n",
    "        return\n",
    "    series = results[key][field]\n",
    "    if not series:\n",
    "        return\n",
    "    if style is None:\n",
    "        style = {}\n",
    "    curve_label = f\"{key} {field}\" if label is None else label\n",
    "    plt.plot(_ep(len(series)), series, label=curve_label, **style)\n",
    "\n",
    "display_name = {\n",
    "    \"ERM\": \"ERM (SGD)\",\n",
    "    \"ERM-Adam\": \"ERM-Adam\",\n",
    "    \"SAM\": \"SAM\",\n",
    "    \"IRS(KL)-CW\": \"IRS(KL)-CW\",\n",
    "    \"KL-RS\": \"KL-RS\",\n",
    "    \"V-REx\": \"V-REx\",\n",
    "    \"MM-REx\": \"MM-REx\",\n",
    "    \"IRMv1\": \"IRMv1\",\n",
    "    \"GroupDRO\": \"GroupDRO\",\n",
    "    \"χ²-DRO\": \"χ²-DRO\",\n",
    "    \"CVaR-DRO\": \"CVaR-DRO\",\n",
    "}\n",
    "available = set(results.keys())\n",
    "\n",
    "# All algorithms to plot\n",
    "all_algos = [\"ERM\", \"ERM-Adam\", \"SAM\", \"IRS(KL)-CW\", \"KL-RS\", \"V-REx\", \"MM-REx\",\n",
    "             \"IRMv1\", \"GroupDRO\", \"χ²-DRO\", \"CVaR-DRO\"]\n",
    "\n",
    "def _plot_metric_vs_epoch(field, title, ylabel, ylim=None):\n",
    "    \"\"\"Draw one figure with results[algo][field] for every algo in all_algos that has results.\n",
    "\n",
    "    ylim, if given, is a (bottom, top) pair passed to plt.ylim.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(10,6))\n",
    "    for k in all_algos:\n",
    "        if k in available:\n",
    "            _plot_curve_if(k, field, display_name.get(k, k))\n",
    "    plt.title(title)\n",
    "    plt.xlabel(\"Epoch\"); plt.ylabel(ylabel)\n",
    "    if ylim is not None:\n",
    "        plt.ylim(*ylim)\n",
    "    # Only add a legend if something was drawn (avoids matplotlib's 'No artists' warning).\n",
    "    if plt.gca().get_legend_handles_labels()[0]:\n",
    "        plt.legend(ncol=3, fontsize=7, loc='best')\n",
    "    plt.grid(True, alpha=0.25); plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Train / Val CE share a fixed y-range of [0, 6].\n",
    "_plot_metric_vs_epoch(\"tr\", \"Cross-Entropy vs Epoch (Train) — CIFAR-10-LT\", \"Cross-Entropy\", ylim=(0.0, 6.0))\n",
    "_plot_metric_vs_epoch(\"val\", \"Cross-Entropy vs Epoch (Validation) — CIFAR-10-LT\", \"Cross-Entropy\", ylim=(0.0, 6.0))\n",
    "\n",
    "# Val Acc (most important)\n",
    "_plot_metric_vs_epoch(\"acc\", \"Validation Accuracy vs Epoch — CIFAR-10-LT\", \"Accuracy\")\n",
    "\n",
    "# IRS-CW extras\n",
    "if \"IRS(KL)-CW\" in available:\n",
    "    if \"kappa\" in results[\"IRS(KL)-CW\"]:\n",
    "        plt.figure(figsize=(8,4.2))\n",
    "        _plot_curve_if(\"IRS(KL)-CW\", \"kappa\", \"IRS(KL)-CW κ̄\")\n",
    "        plt.title(\"IRS(KL)-CW — κ̄ per epoch\")\n",
    "        plt.xlabel(\"Epoch\"); plt.ylabel(\"κ̄\")\n",
    "        plt.grid(True, alpha=0.25); plt.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "    if \"h\" in results[\"IRS(KL)-CW\"]:\n",
    "        plt.figure(figsize=(8,4.2))\n",
    "        _plot_curve_if(\"IRS(KL)-CW\", \"h\", \"IRS(KL)-CW h̄\")\n",
    "        plt.title(\"IRS(KL)-CW — h̄ per epoch\")\n",
    "        plt.xlabel(\"Epoch\"); plt.ylabel(\"h̄\")\n",
    "        plt.grid(True, alpha=0.25); plt.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "# Print final accuracies summary\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"FINAL VALIDATION ACCURACIES:\")\n",
    "print(\"=\"*60)\n",
    "for k in all_algos:\n",
    "    if k in available and \"acc\" in results[k]:\n",
    "        acc_final = results[k][\"acc\"][-1] if results[k][\"acc\"] else float('nan')\n",
    "        print(f\"{k:15s}: {acc_final:.4f}\")\n",
    "print(\"=\"*60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7ZlpKEKL342R"
   },
   "outputs": [],
   "source": [
    "# Checkpoint directory. Expects Google Drive to be mounted at /content/drive (Colab).\n",
    "DRIVE_ROOT = \"/content/drive/My Drive\"\n",
    "if os.path.isdir(DRIVE_ROOT):\n",
    "    save_dir = os.path.join(DRIVE_ROOT, \"cifarlt_checkpoints\")\n",
    "else:\n",
    "    # Without a mounted Drive, makedirs under /content/drive would silently write to the\n",
    "    # ephemeral local disk and can make a later drive.mount() fail on a non-empty mountpoint.\n",
    "    save_dir = os.path.abspath(\"cifarlt_checkpoints\")\n",
    "    print(f\"⚠️  Google Drive not mounted; saving checkpoints locally to {save_dir}\")\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# (global variable name, checkpoint file stem). Stems are kept as before;\n",
    "# note model_irs_cw is saved as \"irs_kl\".\n",
    "_CHECKPOINTS = [\n",
    "    (\"model_erm\", \"erm\"),\n",
    "    (\"model_erm_adam\", \"erm_adam\"),\n",
    "    (\"model_sam\", \"sam\"),\n",
    "    (\"model_irs_cw\", \"irs_kl\"),\n",
    "    (\"model_klrs\", \"klrs\"),\n",
    "    (\"model_vrex\", \"vrex\"),\n",
    "    (\"model_mmrex\", \"mmrex\"),\n",
    "    (\"model_irm\", \"irm\"),\n",
    "    (\"model_groupdro\", \"groupdro\"),\n",
    "    (\"model_chi2\", \"chi2_dro\"),\n",
    "    (\"model_cvar\", \"cvar_dro\"),\n",
    "]\n",
    "\n",
    "# Saves every listed model variable that exists in the kernel and is not None. This checks\n",
    "# globals(), not RUN_FULL_TRAINING, so a model left over from an earlier run is saved too.\n",
    "final_models = {\n",
    "    stem: globals()[var]\n",
    "    for var, stem in _CHECKPOINTS\n",
    "    if globals().get(var) is not None\n",
    "}\n",
    "\n",
    "for name, mdl in final_models.items():\n",
    "    path = os.path.join(save_dir, f\"{name}.pth\")\n",
    "    try:\n",
    "        torch.save(mdl.state_dict(), path)\n",
    "        print(f\"✔️  Saved {name} → {path}\")\n",
    "    except Exception as e:  # best-effort: one failed save should not stop the rest\n",
    "        print(f\"⚠️  Could not save {name}: {e}\")\n",
    "\n",
    "print(\"Done saving.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b04JQtuv356v"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
