{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aDtB-5S1SVeL",
        "outputId": "13daedc9-db28-4159-8d9f-9b79965b53b0"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# Download Waterbirds to Google Drive\n",
        "# ==========================================\n",
        "import os\n",
        "import shutil\n",
        "import tarfile\n",
        "from google.colab import drive\n",
        "\n",
        "# Mount Google Drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# Set paths\n",
        "DRIVE_ROOT = \"/content/drive/MyDrive/datasets\"\n",
        "TGZ_PATH = f\"{DRIVE_ROOT}/waterbird.tar.gz\"\n",
        "LOCAL_ROOT = \"/content/waterbirds_local\"\n",
        "\n",
        "# Final DATA_ROOT will be set after extraction\n",
        "DATA_ROOT = f\"{DRIVE_ROOT}/waterbirds\"  # Default - will be updated\n",
        "\n",
        "# Create directories\n",
        "os.makedirs(DRIVE_ROOT, exist_ok=True)\n",
        "\n",
        "# Download using wget (more reliable than urllib for problematic URLs)\n",
        "if not os.path.exists(TGZ_PATH) and not os.path.exists(DATA_ROOT):\n",
        "    print(\"Attempting to download Waterbirds dataset to Google Drive...\")\n",
        "\n",
        "    # Try multiple sources\n",
        "    sources = [\n",
        "        \"https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz\",\n",
        "        \"https://worksheets.codalab.org/rest/bundles/0xb922b6149f9043789ae932c02fe39043/contents/blob/\",\n",
        "    ]\n",
        "\n",
        "    success = False\n",
        "    for i, url in enumerate(sources):\n",
        "        print(f\"\\nTrying source {i+1}...\")\n",
        "        result = os.system(f'wget -O \"{TGZ_PATH}\" \"{url}\" 2>&1')\n",
        "        # Require a zero exit status AND a non-trivial file: wget -O creates the file even on failure\n",
        "        if result == 0 and os.path.exists(TGZ_PATH) and os.path.getsize(TGZ_PATH) > 1000:\n",
        "            print(\"Download successful!\")\n",
        "            success = True\n",
        "            break\n",
        "        else:\n",
        "            print(f\"Source {i+1} failed\")\n",
        "            if os.path.exists(TGZ_PATH):\n",
        "                os.remove(TGZ_PATH)\n",
        "\n",
        "    if not success:\n",
        "        print(\"\\n❌ Automatic download failed.\")\n",
        "        print(\"Please download manually from: https://github.com/kohpangwei/group_DRO\")\n",
        "\n",
        "# Extract if we have the tar.gz and haven't extracted yet\n",
        "if os.path.exists(TGZ_PATH) and not os.path.exists(DATA_ROOT):\n",
        "    print(\"\\nExtracting dataset to Google Drive...\")\n",
        "    with tarfile.open(TGZ_PATH, \"r:gz\") as archive:\n",
        "        archive.extractall(DRIVE_ROOT)\n",
        "\n",
        "    # The tar extracts to waterbird_complete95_forest2water2, rename to waterbirds\n",
        "    extracted_path = f\"{DRIVE_ROOT}/waterbird_complete95_forest2water2\"\n",
        "    if os.path.exists(extracted_path):\n",
        "        os.rename(extracted_path, DATA_ROOT)\n",
        "        print(f\"✓ Renamed to {DATA_ROOT}\")\n",
        "\n",
        "# Copy to local disk for faster I/O\n",
        "if os.path.exists(DATA_ROOT) and not os.path.exists(LOCAL_ROOT):\n",
        "    print(f\"\\n🚀 Copying dataset to local disk ({LOCAL_ROOT}) for faster I/O...\")\n",
        "    os.makedirs(LOCAL_ROOT, exist_ok=True)\n",
        "\n",
        "    copied = False\n",
        "    if os.path.exists(TGZ_PATH):\n",
        "        # Extract from tar directly to local (faster)\n",
        "        temp_tgz = \"/content/temp_waterbirds.tar.gz\"\n",
        "        status = os.system(f\"cp '{TGZ_PATH}' '{temp_tgz}'\")\n",
        "        if status == 0:\n",
        "            status = os.system(f\"tar -xf '{temp_tgz}' -C '{LOCAL_ROOT}'\")\n",
        "        if os.path.exists(temp_tgz):\n",
        "            os.remove(temp_tgz)\n",
        "        copied = status == 0\n",
        "        if not copied:\n",
        "            print(\"⚠️ Extracting the archive locally failed; copying from Drive instead.\")\n",
        "            # Drop any partial extraction so it cannot be mistaken for the dataset below\n",
        "            shutil.rmtree(LOCAL_ROOT, ignore_errors=True)\n",
        "            os.makedirs(LOCAL_ROOT, exist_ok=True)\n",
        "\n",
        "    if not copied:\n",
        "        # Copy from Drive. distutils was removed in Python 3.12, so use shutil;\n",
        "        # dirs_exist_ok=True is needed because LOCAL_ROOT was created above.\n",
        "        shutil.copytree(DATA_ROOT, LOCAL_ROOT, dirs_exist_ok=True)\n",
        "    print(\"✅ Data ready on local disk!\")\n",
        "\n",
        "# Set final DATA_ROOT - prefer local for speed\n",
        "if os.path.exists(LOCAL_ROOT):\n",
        "    if \"waterbird_complete95_forest2water2\" in os.listdir(LOCAL_ROOT):\n",
        "        DATA_ROOT = os.path.join(LOCAL_ROOT, \"waterbird_complete95_forest2water2\")\n",
        "    elif os.path.exists(os.path.join(LOCAL_ROOT, \"metadata.csv\")):\n",
        "        DATA_ROOT = LOCAL_ROOT\n",
        "    else:\n",
        "        # Check subdirectories\n",
        "        for item in os.listdir(LOCAL_ROOT):\n",
        "            subpath = os.path.join(LOCAL_ROOT, item)\n",
        "            if os.path.isdir(subpath) and os.path.exists(os.path.join(subpath, \"metadata.csv\")):\n",
        "                DATA_ROOT = subpath\n",
        "                break\n",
        "\n",
        "# Check final status\n",
        "if os.path.exists(DATA_ROOT) and os.path.exists(os.path.join(DATA_ROOT, \"metadata.csv\")):\n",
        "    print(f\"\\n✅ Waterbirds dataset ready at: {DATA_ROOT}\")\n",
        "    print(f\"Contents: {os.listdir(DATA_ROOT)[:5]}...\")\n",
        "else:\n",
        "    print(f\"\\n❌ Dataset not properly set up at {DATA_ROOT}\")\n",
        "    print(\"Please check paths and try again.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iFQQPRWDW_dw",
        "outputId": "4c9377c5-10e4-4984-e2ba-878d5465027b"
      },
      "outputs": [],
      "source": [
        "# Diagnostic: inspect what is actually on disk for the dataset\n",
        "import os\n",
        "import pandas as pd\n",
        "\n",
        "# DATA_ROOT is expected to come from the previous cell; resolve everything once up front\n",
        "meta_path = os.path.join(DATA_ROOT, \"metadata.csv\")\n",
        "root_exists = os.path.exists(DATA_ROOT)\n",
        "meta_exists = os.path.exists(meta_path)\n",
        "\n",
        "print(f\"DATA_ROOT = {DATA_ROOT}\")\n",
        "\n",
        "# Report whether the basic structure exists\n",
        "print(\"\\nChecking dataset structure...\")\n",
        "print(f\"DATA_ROOT exists: {root_exists}\")\n",
        "print(f\"metadata.csv exists: {meta_exists}\")\n",
        "\n",
        "# List at most the first ten entries of the dataset directory\n",
        "if root_exists:\n",
        "    print(f\"\\nContents of {DATA_ROOT}:\")\n",
        "    contents = os.listdir(DATA_ROOT)\n",
        "    for entry in contents[:10]:\n",
        "        print(f\"  - {entry}\")\n",
        "\n",
        "# Summarize the metadata table when it is present\n",
        "if meta_exists:\n",
        "    meta = pd.read_csv(meta_path)\n",
        "    print(f\"\\nMetadata columns: {list(meta.columns)}\")\n",
        "    print(f\"Number of rows: {len(meta)}\")\n",
        "    print(\"\\nFirst few img_filename entries:\")\n",
        "    print(meta['img_filename'].head(5))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "otz1JiE84ZUQ",
        "outputId": "8ac041d6-c1fb-40f3-cfe4-146cafaa8f3b"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# Verify DATA_ROOT is set correctly\n",
        "# ==========================================\n",
        "# DATA_ROOT is expected to have been resolved by the first cell;\n",
        "# this cell only reports whether the dataset is usable.\n",
        "\n",
        "root_found = os.path.exists(DATA_ROOT)\n",
        "print(f\"Final DATA_ROOT: {DATA_ROOT}\")\n",
        "print(f\"Exists: {root_found}\")\n",
        "\n",
        "if not root_found:\n",
        "    print(\"❌ DATA_ROOT does not exist\")\n",
        "else:\n",
        "    contents = os.listdir(DATA_ROOT)\n",
        "    print(f\"Contents ({len(contents)} items): {contents[:5]}...\")\n",
        "\n",
        "    # metadata.csv is the file every later cell relies on\n",
        "    meta_exists = os.path.exists(os.path.join(DATA_ROOT, \"metadata.csv\"))\n",
        "    print(f\"metadata.csv exists: {meta_exists}\")\n",
        "\n",
        "    verdict = \"✅ Dataset is ready!\" if meta_exists else \"❌ metadata.csv not found - check DATA_ROOT path\"\n",
        "    print(verdict)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "id": "5ud532WabR-6",
        "outputId": "8139fbe9-476a-4d7e-9a5a-de333f8b14e4"
      },
      "outputs": [],
      "source": [
        "# -*- coding: utf-8 -*-\n",
        "\"\"\"\n",
        "IRS(KL)-Group vs CVaR DRO on Waterbirds (ResNet50)\n",
        "- Protocol: Strict Train/Val split. Test set only evaluated at the END.\n",
        "- Monitoring: Prints Train Loss, Validation Worst-Group Acc, and h_bar.\n",
        "- Optimization: Fast Vectorized Search, Fixed Seeds.\n",
        "- FIXED: Added missing PIL Image import.\n",
        "\"\"\"\n",
        "\n",
        "import os\n",
        "import time\n",
        "import math\n",
        "import warnings\n",
        "import random\n",
        "import copy\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from tqdm import tqdm\n",
        "from PIL import Image  # <--- FIXED: ADDED THIS IMPORT\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from torchvision import transforms\n",
        "from torchvision.models import resnet50, ResNet50_Weights\n",
        "\n",
        "# ==========================================\n",
        "# 1. CONFIGURATION\n",
        "# ==========================================\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "torch.backends.cudnn.benchmark = True\n",
        "\n",
        "# >>> UPDATE PATHS <<<\n",
        "# Keep the DATA_ROOT resolved by the download/setup cell (it prefers the fast local copy);\n",
        "# only fall back to the Drive location when no earlier cell defined it.\n",
        "DATA_ROOT = globals().get(\"DATA_ROOT\", \"/content/drive/MyDrive/datasets/waterbirds\")\n",
        "\n",
        "# Hyperparameters\n",
        "SEED = 42\n",
        "NUM_CLASSES = 2\n",
        "NUM_GROUPS = 4\n",
        "BATCH_SIZE = 256\n",
        "IMG_SIZE = 224\n",
        "\n",
        "# Training Settings\n",
        "EPOCHS = 50\n",
        "WARMUP_EPOCHS = 1\n",
        "LR_IRS = 1e-5\n",
        "LR_DRO = 1e-3\n",
        "WEIGHT_DECAY = 1e-4\n",
        "\n",
        "# IRS Parameters\n",
        "IRS_TAU = 0.1\n",
        "IRS_DISTANCE_SCALE = 1.0\n",
        "IRS_MIN_DIV = 1e-2\n",
        "IRS_H_MAX = 100.0       # prevents insane h\n",
        "IRS_TAU_EPS = 1e-4      # feasibility margin\n",
        "IRS_TAU_MULT = 1.01   # your “loss * 1.01 until final tau” rule\n",
        "\n",
        "\n",
        "\n",
        "# ==========================================\n",
        "# 2. UTILS\n",
        "# ==========================================\n",
        "def set_seed(seed):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "# ==========================================\n",
        "# 3. DATA & MODEL\n",
        "# ==========================================\n",
        "# Per-channel normalization statistics shared by the train and eval pipelines\n",
        "mean = (0.485, 0.456, 0.406)\n",
        "std = (0.229, 0.224, 0.225)\n",
        "\n",
        "# Training: resize, random horizontal flip, tensor conversion, normalization\n",
        "train_tf = transforms.Compose(\n",
        "    [\n",
        "        transforms.Resize((IMG_SIZE, IMG_SIZE)),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean, std),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Evaluation: the same steps without the random flip\n",
        "eval_tf = transforms.Compose(\n",
        "    [\n",
        "        transforms.Resize((IMG_SIZE, IMG_SIZE)),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean, std),\n",
        "    ]\n",
        ")\n",
        "\n",
        "class WaterbirdsDataset(Dataset):\n",
        "    def __init__(self, root, split, transform=None):\n",
        "        self.root = root\n",
        "        self.transform = transform\n",
        "        meta_path = os.path.join(root, \"metadata.csv\")\n",
        "        if not os.path.exists(meta_path): raise FileNotFoundError(f\"Metadata not found at {meta_path}\")\n",
        "        meta = pd.read_csv(meta_path)\n",
        "        split_map = {\"train\": 0, \"val\": 1, \"test\": 2}\n",
        "        self.meta = meta[meta[\"split\"] == split_map[split]].reset_index(drop=True)\n",
        "        self.meta[\"group\"] = self.meta[\"y\"] * 2 + self.meta[\"place\"]\n",
        "\n",
        "    def __len__(self): return len(self.meta)\n",
        "    def __getitem__(self, idx):\n",
        "        row = self.meta.iloc[idx]\n",
        "        img_path = os.path.join(self.root, \"images\", row[\"img_filename\"])\n",
        "        if not os.path.exists(img_path): img_path = os.path.join(self.root, row[\"img_filename\"])\n",
        "        img = Image.open(img_path).convert(\"RGB\")\n",
        "        if self.transform: img = self.transform(img)\n",
        "        return img, int(row[\"y\"]), int(idx), int(row[\"group\"])\n",
        "\n",
        "def get_loaders_and_priors():\n",
        "    \"\"\"Build the train/val/test DataLoaders and the training-set group prior.\n",
        "\n",
        "    Returns:\n",
        "        (loaders, P_group_tensor): ``loaders`` maps \"train\"/\"val\"/\"test\" to DataLoaders\n",
        "        over DATA_ROOT (only \"train\" is shuffled); ``P_group_tensor`` is a float32 tensor\n",
        "        on ``device`` with NUM_GROUPS entries holding each group's share of the train split.\n",
        "    \"\"\"\n",
        "    train_ds = WaterbirdsDataset(DATA_ROOT, \"train\", train_tf)\n",
        "    val_ds   = WaterbirdsDataset(DATA_ROOT, \"val\", eval_tf)\n",
        "    test_ds  = WaterbirdsDataset(DATA_ROOT, \"test\", eval_tf)\n",
        "\n",
        "    # bincount(minlength=NUM_GROUPS) keeps one slot per group id, so entry i always refers to\n",
        "    # group i even when a group has no training rows (value_counts would drop that group).\n",
        "    group_counts = np.bincount(train_ds.meta[\"group\"].to_numpy(), minlength=NUM_GROUPS)\n",
        "    P_group = group_counts / group_counts.sum()\n",
        "    P_group_tensor = torch.tensor(P_group, dtype=torch.float32, device=device)\n",
        "\n",
        "    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)\n",
        "    loaders = {\n",
        "        \"train\": DataLoader(train_ds, shuffle=True, **loader_kwargs),\n",
        "        \"val\":   DataLoader(val_ds, shuffle=False, **loader_kwargs),\n",
        "        \"test\":  DataLoader(test_ds, shuffle=False, **loader_kwargs),\n",
        "    }\n",
        "    return loaders, P_group_tensor\n",
        "\n",
        "def fresh_model():\n",
        "    \"\"\"Return a ResNet-50 with IMAGENET1K_V1 weights and a new NUM_CLASSES-way head, on ``device``.\"\"\"\n",
        "    net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)\n",
        "    head_in = net.fc.in_features\n",
        "    net.fc = nn.Linear(head_in, NUM_CLASSES)\n",
        "    return net.to(device)\n",
        "\n",
        "# ==========================================\n",
        "# 4. IRS MATH\n",
        "# ==========================================\n",
        "@torch.no_grad()\n",
        "def _kappa_for_h(F_vals, P_hat, tau_t, h, distance_scale=1.0, min_div=1e-2):\n",
        "    \"\"\"Evaluate the exponentially tilted distribution for tilt ``h`` and its kappa ratio.\n",
        "\n",
        "    P is softmax(log P_hat + h * F_vals), with logits clamped to within 40 of their maximum.\n",
        "    Returns ``(kappa, num, Dkl, P)`` where ``num = <P, F_vals> - tau_t``, ``Dkl`` is KL(P || P_hat)\n",
        "    and ``kappa = num / (distance_scale * max(Dkl, 0) + min_div)``.\n",
        "    \"\"\"\n",
        "    log_ref = torch.log(P_hat.clamp_min(1e-30))\n",
        "    tilted = log_ref + (h * F_vals)\n",
        "    peak = tilted.max()\n",
        "    tilted = torch.clamp(tilted, min=peak - 40.0, max=peak + 1e-6)\n",
        "    P = torch.softmax(tilted, dim=0)\n",
        "\n",
        "    num = torch.dot(P, F_vals) - tau_t\n",
        "    Dkl = (P * (torch.log(P.clamp_min(1e-30)) - log_ref)).sum()\n",
        "\n",
        "    # Floor the KL at zero and the denominator at 1e-20 before dividing\n",
        "    zero = torch.tensor(0.0, device=device)\n",
        "    tiny = torch.tensor(1e-20, device=device)\n",
        "    den = distance_scale * torch.max(Dkl, zero) + min_div\n",
        "    kappa = num / torch.max(den, tiny)\n",
        "    return kappa, num, Dkl, P\n",
        "\n",
        "@torch.no_grad()\n",
        "def _fast_secant_search(F_vals, P_hat, tau_t, h_prev,\n",
        "                        dist_scale=1.0, min_div=1e-2, max_iter=2):\n",
        "    \"\"\"Run ``max_iter`` secant steps on kappa(h), started from a bracket around ``h_prev``.\n",
        "\n",
        "    Every iterate is kept inside [0, IRS_H_MAX]. Returns ``(h, num, Dkl, P)`` evaluated at the\n",
        "    final iterate, with ``h`` converted to a Python float.\n",
        "    \"\"\"\n",
        "    def kappa_at(h):\n",
        "        return _kappa_for_h(F_vals, P_hat, tau_t, h, dist_scale, min_div)\n",
        "\n",
        "    # Initial bracket: h_prev -/+ 2, floored at 0 / 0.1, then clamped into [0, IRS_H_MAX]\n",
        "    lo = torch.clamp(torch.tensor(max(0.0, h_prev - 2.0), device=device), 0.0, IRS_H_MAX)\n",
        "    hi = torch.clamp(torch.tensor(max(0.1, h_prev + 2.0), device=device), 0.0, IRS_H_MAX)\n",
        "    k_lo = kappa_at(lo)[0]\n",
        "    k_hi = kappa_at(hi)[0]\n",
        "\n",
        "    for _ in range(max_iter):\n",
        "        slope = (k_hi - k_lo) / ((hi - lo) + 1e-8)\n",
        "        # Secant step; when the slope is numerically flat, just move one unit to the right\n",
        "        candidate = torch.where(torch.abs(slope) > 1e-12, hi - k_hi / slope, hi + 1.0)\n",
        "        candidate = torch.max(candidate, torch.tensor(0.0, device=device))\n",
        "        candidate = torch.clamp(candidate, 0.0, IRS_H_MAX)\n",
        "        k_candidate = kappa_at(candidate)[0]\n",
        "        lo, k_lo = hi, k_hi\n",
        "        hi, k_hi = candidate, k_candidate\n",
        "\n",
        "    _, num, Dkl, P = kappa_at(hi)\n",
        "    return hi.item(), num, Dkl, P\n",
        "\n",
        "def compute_group_stats(ce_vec, groups, num_groups, P_global):\n",
        "    \"\"\"Per-group mean loss and the reference distribution restricted to groups in the batch.\n",
        "\n",
        "    Returns ``(F, P_hatS, mask)``: ``F[g]`` is the mean of ``ce_vec`` over samples of group g\n",
        "    (0 for absent groups), ``P_hatS`` is ``P_global`` zeroed on absent groups and renormalized\n",
        "    over the present ones, and ``mask`` marks which groups occur in ``groups``.\n",
        "    \"\"\"\n",
        "    group_sizes = torch.bincount(groups, minlength=num_groups).float()\n",
        "    mask = group_sizes > 0\n",
        "\n",
        "    loss_totals = torch.zeros(num_groups, device=device).scatter_add_(0, groups, ce_vec)\n",
        "    F = torch.zeros(num_groups, device=device)\n",
        "    F[mask] = loss_totals[mask] / group_sizes[mask]\n",
        "\n",
        "    # Reference distribution on the present groups only, renormalized to sum to one\n",
        "    P_hatS = torch.zeros_like(F)\n",
        "    present_mass = P_global[mask].sum().clamp_min(1e-12)\n",
        "    P_hatS[mask] = P_global[mask] / present_mass\n",
        "    return F, P_hatS, mask\n",
        "\n",
        "\n",
        "# ==========================================\n",
        "# 5. TRAINING\n",
        "# ==========================================\n",
        "@torch.no_grad()\n",
        "def eval_metrics(model, loader):\n",
        "    \"\"\"Return ``(mean cross-entropy, worst-group accuracy)`` of ``model`` over ``loader``.\n",
        "\n",
        "    ``loader`` must yield ``(x, y, idx, group)`` batches. Groups with no samples in ``loader``\n",
        "    are left out of the worst-group minimum instead of being scored as 0% accuracy.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    total_loss, total_n = 0., 0\n",
        "    group_correct = torch.zeros(NUM_GROUPS)\n",
        "    group_total = torch.zeros(NUM_GROUPS)\n",
        "    ce_loss = nn.CrossEntropyLoss(reduction='sum')\n",
        "\n",
        "    for x, y, _, g in loader:\n",
        "        x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "        logits = model(x)\n",
        "        total_loss += ce_loss(logits, y).item()\n",
        "        correct = (logits.argmax(1) == y)\n",
        "        total_n += y.size(0)\n",
        "        # One bincount per batch replaces a Python loop with a host sync per group;\n",
        "        # the slice ignores group ids >= NUM_GROUPS, as the per-group loop did.\n",
        "        group_total += torch.bincount(g, minlength=NUM_GROUPS)[:NUM_GROUPS].cpu()\n",
        "        group_correct += torch.bincount(g[correct], minlength=NUM_GROUPS)[:NUM_GROUPS].cpu()\n",
        "\n",
        "    # Only groups that actually occurred can define the worst-group accuracy\n",
        "    seen = group_total > 0\n",
        "    group_accs = group_correct[seen] / group_total[seen]\n",
        "    return total_loss/total_n, group_accs.min().item()\n",
        "\n",
        "@torch.no_grad()\n",
        "def estimate_reference_risk_ce(model, loader, P_global_group):\n",
        "    \"\"\"\n",
        "    Estimate the reference risk sum_g P_global_group[g] * r_g, where r_g is the mean\n",
        "    cross-entropy of ``model`` on the samples of group g in ``loader`` (0 for empty groups).\n",
        "    Returns a Python float.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    per_sample_ce = nn.CrossEntropyLoss(reduction=\"none\")\n",
        "\n",
        "    loss_by_group = torch.zeros(NUM_GROUPS, device=device)\n",
        "    n_by_group = torch.zeros(NUM_GROUPS, device=device)\n",
        "\n",
        "    for x, y, _, g in loader:\n",
        "        x, y, g = (t.to(device) for t in (x, y, g))\n",
        "        ce = per_sample_ce(model(x), y)\n",
        "\n",
        "        # Accumulate per-group loss totals and sample counts\n",
        "        loss_by_group.scatter_add_(0, g, ce)\n",
        "        n_by_group.scatter_add_(0, g, torch.ones_like(ce))\n",
        "\n",
        "    group_mean_ce = loss_by_group / n_by_group.clamp_min(1.0)\n",
        "    return torch.dot(group_mean_ce, P_global_group).item()\n",
        "\n",
        "\n",
        "def train_algorithm(algo_name, model, loaders, P_global_group, epochs, lr, warmup_epochs=0, alpha=0.2):\n",
        "    \"\"\"Train ``model`` with the IRS-Group objective or a top-k loss objective and return it.\n",
        "\n",
        "    Args:\n",
        "        algo_name: \"IRS-Group\" selects Adam and the IRS group reweighting; any other name\n",
        "            selects SGD with momentum and the mean of the largest ``alpha`` fraction of\n",
        "            per-sample losses in each batch.\n",
        "        model: network to train (modified in place).\n",
        "        loaders: dict with \"train\" and \"val\" loaders yielding ``(x, y, idx, group)``.\n",
        "        P_global_group: reference group distribution passed to ``compute_group_stats``.\n",
        "        epochs: number of the last epoch; the main loop runs epochs warmup_epochs+1..epochs.\n",
        "        lr: optimizer learning rate.\n",
        "        warmup_epochs: plain cross-entropy epochs run first (IRS-Group only). The main loop\n",
        "            starts after this many epochs for every algorithm.\n",
        "        alpha: fraction of each batch kept by the top-k objective.\n",
        "\n",
        "    Returns:\n",
        "        ``model`` loaded with the weights from the main-loop epoch that had the best validation\n",
        "        worst-group accuracy; if the main loop never ran, the model is returned as trained.\n",
        "    \"\"\"\n",
        "    print(f\"\\n[{algo_name}] Training {epochs} eps...\")\n",
        "\n",
        "    is_irs = algo_name == \"IRS-Group\"\n",
        "    if is_irs:\n",
        "        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)\n",
        "    else:\n",
        "        opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY, momentum=0.9)\n",
        "\n",
        "    ce_none = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "    # Track Best Model\n",
        "    best_val_worst = -1.0\n",
        "    best_model_state = None\n",
        "\n",
        "    # 1. Warmup (IRS Only)\n",
        "    if is_irs:\n",
        "        ce_mean = nn.CrossEntropyLoss()  # built once instead of once per batch\n",
        "        for ep in range(1, warmup_epochs + 1):\n",
        "            model.train()\n",
        "            running_loss = 0.0\n",
        "            n_batches = 0\n",
        "            for x, y, _, _ in loaders['train']:\n",
        "                x, y = x.to(device), y.to(device)\n",
        "                opt.zero_grad()\n",
        "                loss = ce_mean(model(x), y)\n",
        "                loss.backward()\n",
        "                opt.step()\n",
        "                running_loss += loss.item()\n",
        "                n_batches += 1\n",
        "\n",
        "            # Eval\n",
        "            _, v_worst = eval_metrics(model, loaders['val'])\n",
        "            print(f\"[Warmup] Ep {ep}: Tr Loss {running_loss/max(1, n_batches):.4f} | Val Worst {v_worst:.4f}\")\n",
        "\n",
        "\n",
        "    # 2. Main Loop\n",
        "    prev_h = 0.0\n",
        "    for ep in range(warmup_epochs + 1, epochs + 1):\n",
        "        model.train()\n",
        "        sum_h, cnt = 0.0, 0\n",
        "        running_loss = 0.0\n",
        "        n_batches = 0\n",
        "\n",
        "        for x, y, _, g in tqdm(loaders['train'], desc=f\"{algo_name} Ep {ep}\", leave=False):\n",
        "            x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "            opt.zero_grad()\n",
        "            logits = model(x)\n",
        "            ce_vec = ce_none(logits, y)\n",
        "\n",
        "            if is_irs:\n",
        "                # IRS Logic\n",
        "                F_b, P_b, mask = compute_group_stats(ce_vec, g, NUM_GROUPS, P_global_group)\n",
        "\n",
        "                m_ref_b = torch.dot(F_b, P_b)              # reference risk on this batch support\n",
        "                maxF_b  = F_b[mask].max()\n",
        "\n",
        "                # final target tau (what you ultimately want to satisfice)\n",
        "                tau_final_t = F_b.new_tensor(float(IRS_TAU))\n",
        "\n",
        "                # tau used ONLY for root existence / worst-case construction until we reach final tau\n",
        "                tau_root = max(float(IRS_TAU), float(m_ref_b.item()) * IRS_TAU_MULT)\n",
        "\n",
        "                # keep tau_root strictly below maxF to avoid degenerate root issues\n",
        "                tau_root = min(tau_root, float(maxF_b.item()) - IRS_TAU_EPS)\n",
        "\n",
        "                # if batch is degenerate (all present groups identical so maxF≈m_ref), skip\n",
        "                if tau_root <= float(m_ref_b.item()) + IRS_TAU_EPS:\n",
        "                    continue\n",
        "\n",
        "                tau_root_t = F_b.new_tensor(tau_root)\n",
        "\n",
        "                try:\n",
        "                    h_star, num, D_kl, P_star_b = _fast_secant_search(\n",
        "                        F_b, P_b, tau_root_t, prev_h, IRS_DISTANCE_SCALE, IRS_MIN_DIV, max_iter=2\n",
        "                    )\n",
        "                except Exception:\n",
        "                    # Best-effort fallback: reuse the previous tilt. Catching Exception (not a bare\n",
        "                    # except) keeps KeyboardInterrupt/SystemExit able to stop training.\n",
        "                    h_star = prev_h\n",
        "                    _, _, D_kl, P_star_b = _kappa_for_h(F_b, P_b, tau_root_t, h_star, IRS_DISTANCE_SCALE, IRS_MIN_DIV)\n",
        "\n",
        "                r_wc = torch.dot(P_star_b, F_b)  # worst-case (under the constructed distribution)\n",
        "\n",
        "                # ---- YOUR gating: stop updating once we beat final tau on worst-case ----\n",
        "                if r_wc <= tau_final_t + IRS_TAU_EPS:\n",
        "                    continue  # objective is 0, no update\n",
        "\n",
        "                # otherwise do the robust-weighted CE update\n",
        "                P_const = P_star_b.detach()\n",
        "                loss = torch.dot(P_const, F_b)\n",
        "\n",
        "                sum_h += min(h_star, IRS_H_MAX)\n",
        "                prev_h = min(h_star, IRS_H_MAX)\n",
        "                cnt += 1\n",
        "\n",
        "            else:\n",
        "                # DRO Logic\n",
        "                k = max(1, int(alpha * ce_vec.size(0)))\n",
        "                top_losses, _ = torch.topk(ce_vec, k)\n",
        "                loss = top_losses.mean()\n",
        "\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "\n",
        "            running_loss += loss.item()\n",
        "            n_batches += 1\n",
        "\n",
        "        # Eval\n",
        "        val_loss, val_worst = eval_metrics(model, loaders['val'])\n",
        "\n",
        "        # Checkpoint Best Model based on VALIDATION only\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst = val_worst\n",
        "            best_model_state = copy.deepcopy(model.state_dict())\n",
        "            is_best = \"*NEW BEST*\"\n",
        "        else:\n",
        "            is_best = \"\"\n",
        "\n",
        "        h_info = f\"| h_bar: {sum_h/max(1, cnt):.1f}\" if is_irs else \"\"\n",
        "        avg_tr = running_loss / max(1, n_batches)\n",
        "        print(f\"[{algo_name} Ep {ep:02d}] Tr Loss: {avg_tr:.4f} | Val Worst: {val_worst:.4f} {h_info} {is_best}\")\n",
        "\n",
        "\n",
        "    # Load Best Model for Final Test Evaluation\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "    else:\n",
        "        # Happens when epochs <= warmup_epochs: no checkpoint was ever recorded\n",
        "        print(f\"[{algo_name}] No main-loop epoch ran; returning the model without restoring a checkpoint.\")\n",
        "    return model\n",
        "\n",
        "# ==========================================\n",
        "# 6. EXECUTION (disabled: the block below is wrapped in a string literal and does not run)\n",
        "# ==========================================\n",
        "'''\n",
        "if __name__ == \"__main__\":\n",
        "    set_seed(SEED)\n",
        "    print(f\"Seed set to {SEED}\")\n",
        "    print(\"Preparing Data...\")\n",
        "    loaders, P_global_group = get_loaders_and_priors()\n",
        "\n",
        "    # --------------------------------------\n",
        "    # RUN 1: IRS-Group\n",
        "    # --------------------------------------\n",
        "    print(\"\\n\" + \"=\"*40)\n",
        "    print(\" STARTING: IRS-Group\")\n",
        "    print(\"=\"*40)\n",
        "    t0 = time.time()\n",
        "    set_seed(SEED)\n",
        "    model_irs = fresh_model()\n",
        "    # Train and get best model\n",
        "    best_model_irs = train_algorithm(\"IRS-Group\", model_irs, loaders, P_global_group,\n",
        "                                     epochs=EPOCHS, lr=LR_IRS, warmup_epochs=WARMUP_EPOCHS)\n",
        "    time_irs = (time.time() - t0) / 60\n",
        "\n",
        "    # --------------------------------------\n",
        "    # FINAL EVALUATION (Rigorous)\n",
        "    # --------------------------------------\n",
        "    print(\"\\n\" + \"=\"*40)\n",
        "    print(\" FINAL EVALUATION ON TEST SET\")\n",
        "    print(\" (Evaluating Best Validation Models)\")\n",
        "    print(\"=\"*40)\n",
        "\n",
        "    _, test_worst_irs = eval_metrics(best_model_irs, loaders['test'])\n",
        "\n",
        "    print(f\"{'Method':<15} | {'Worst-Group Acc':<15} | {'Time (min)':<10}\")\n",
        "    print(\"-\" * 45)\n",
        "    print(f\"{'IRS-Group':<15} | {test_worst_irs:.4f}          | {time_irs:.2f}\")\n",
        "\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 106
        },
        "id": "87maa4uYCPpt",
        "outputId": "a6e48ac1-1c5b-445e-e4d0-d4b3378c7c36"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# DETAILED METRICS & CONFUSION MATRIX\n",
        "# ==========================================\n",
        "@torch.no_grad()\n",
        "def get_detailed_metrics(model, loader, model_name):\n",
        "    \"\"\"Print per-group accuracy and the confusion matrix of ``model`` on ``loader``.\n",
        "\n",
        "    ``loader`` must yield ``(x, y, idx, group)`` batches; ``model_name`` is only used in the\n",
        "    printed header. Returns ``(overall_acc, worst_group_acc)``. Groups with no samples are shown\n",
        "    with size 0 in the table but are left out of the worst-group minimum.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    group_correct = torch.zeros(NUM_GROUPS)\n",
        "    group_total = torch.zeros(NUM_GROUPS)\n",
        "    total_correct = 0\n",
        "    total_samples = 0\n",
        "\n",
        "    # Confusion Matrix: Rows = True Label, Cols = Pred Label\n",
        "    # Class 0: Landbird, Class 1: Waterbird\n",
        "    cm = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.long)\n",
        "\n",
        "    for x, y, _, g in loader:\n",
        "        x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "        preds = model(x).argmax(1)\n",
        "\n",
        "        correct_mask = (preds == y)\n",
        "        total_correct += correct_mask.sum().item()\n",
        "        total_samples += y.size(0)\n",
        "\n",
        "        # Update Group Stats: one bincount per batch instead of a Python loop over groups\n",
        "        group_total += torch.bincount(g, minlength=NUM_GROUPS)[:NUM_GROUPS].cpu()\n",
        "        group_correct += torch.bincount(g[correct_mask], minlength=NUM_GROUPS)[:NUM_GROUPS].cpu()\n",
        "\n",
        "        # Update Confusion Matrix: encode each (true, pred) pair as true * NUM_CLASSES + pred so a\n",
        "        # single bincount fills every cell, instead of indexing the matrix once per sample\n",
        "        pair_index = y.view(-1) * NUM_CLASSES + preds.view(-1)\n",
        "        pair_counts = torch.bincount(pair_index, minlength=NUM_CLASSES * NUM_CLASSES)\n",
        "        cm += pair_counts.view(NUM_CLASSES, NUM_CLASSES).cpu()\n",
        "\n",
        "    # Calculate Accuracies\n",
        "    group_accs = (group_correct / group_total.clamp(min=1)).cpu().numpy()\n",
        "    overall_acc = total_correct / total_samples\n",
        "    # Empty groups would otherwise report 0% and masquerade as the worst group\n",
        "    seen = (group_total > 0).numpy()\n",
        "    worst_group_acc = group_accs[seen].min()\n",
        "\n",
        "    # Print Group Results\n",
        "    print(f\"\\n>>> Detailed Results for {model_name} <<<\")\n",
        "    print(f\"{'Group Name':<20} | {'Size':<6} | {'Accuracy':<10}\")\n",
        "    print(\"-\" * 42)\n",
        "    group_names = [\"Land (Land bg)\", \"Land (Water bg)\", \"Water (Land bg)\", \"Water (Water bg)\"]\n",
        "    for i, name in enumerate(group_names):\n",
        "        print(f\"{name:<20} | {int(group_total[i]):<6} | {group_accs[i]*100:.2f}%\")\n",
        "    print(\"-\" * 42)\n",
        "    print(f\"{'Overall Accuracy':<20} | {overall_acc*100:.2f}%\")\n",
        "    print(f\"{'Worst-Group Acc':<20} | {worst_group_acc*100:.2f}%\")\n",
        "\n",
        "    # Print Confusion Matrix\n",
        "    print(f\"\\n[Confusion Matrix] (Rows=True, Cols=Pred)\")\n",
        "    print(f\"{'':<12} | {'Pred Land':<10} | {'Pred Water':<10}\")\n",
        "    print(\"-\" * 38)\n",
        "    print(f\"{'True Land':<12} | {cm[0,0]:<10} | {cm[0,1]:<10}\")\n",
        "    print(f\"{'True Water':<12} | {cm[1,0]:<10} | {cm[1,1]:<10}\")\n",
        "\n",
        "    return overall_acc, worst_group_acc\n",
        "\n",
        "# Run Evaluation\n",
        "# Only the banner below is printed; the evaluation call itself is currently disabled.\n",
        "print(\"\\n========================================\")\n",
        "print(\" FINAL DETAILED PAPER TABLE (With CM)\")\n",
        "print(\"========================================\")\n",
        "\n",
        "# 1. Evaluate IRS\n",
        "# NOTE(review): the call is wrapped in a string literal, so it never runs. In the visible\n",
        "# notebook, best_model_irs is only assigned inside another disabled block and loaders is\n",
        "# created in a later cell, so re-enabling this here would fail on a fresh kernel.\n",
        "'''\n",
        "get_detailed_metrics(best_model_irs, loaders['test'], \"IRS-Group\")\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ln14Irlrj1uE",
        "outputId": "ce199293-7c04-459e-e2e9-9327f63ea6e0"
      },
      "outputs": [],
      "source": [
        "# =========================\n",
        "# SETUP: Data Loading\n",
        "# =========================\n",
        "# This cell MUST be run before any training!\n",
        "# It sets up: loaders, P_global_group, group_counts\n",
        "\n",
        "# Seed first, then build the loaders and the global group prior.\n",
        "set_seed(SEED)\n",
        "\n",
        "print(\"Using DATA_ROOT:\", DATA_ROOT)\n",
        "print(\"Preparing data...\")\n",
        "loaders, P_global_group = get_loaders_and_priors()\n",
        "\n",
        "# Training-set size of each group, ordered by group id (used by GroupDRO).\n",
        "train_ds = loaders[\"train\"].dataset\n",
        "group_counts = (\n",
        "    train_ds.meta[\"group\"]\n",
        "    .value_counts()\n",
        "    .sort_index()\n",
        "    .values\n",
        "    .tolist()\n",
        ")\n",
        "\n",
        "print(f\"Group counts: {group_counts}\")\n",
        "print(f\"P_global_group: {P_global_group}\")\n",
        "print(\"\\n✅ Data ready! loaders, P_global_group, and group_counts are now available.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jV2Zu53jLpu5",
        "outputId": "d56eb43d-3626-43b6-d05f-90d8b54519db"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === ALGORITHM CONTROL FLAGS ===\n",
        "# ==========================================\n",
        "# Control which algorithms to run for each phase.\n",
        "# Set to True to run, False to skip.\n",
        "# This gives you full control - checkpoints are NOT used for gating.\n",
        "# ==========================================\n",
        "\n",
        "# --- HYPERPARAMETER SWEEP FLAGS ---\n",
        "# True  -> run the LR / hyperparameter sweep for that algorithm.\n",
        "# False -> skip it; the trailing comment records the outcome of the finished sweep.\n",
        "# An algorithm missing from any of the dictionaries below is treated as False\n",
        "# by the should_* helpers.\n",
        "RUN_SWEEP = {\n",
        "    \"ERM-SGD\": False,    # ✅ COMPLETED: best_lr=0.01, valWorst=0.6391\n",
        "    \"ERM\": False,        # ✅ COMPLETED: best_lr=1e-05, valWorst=0.5789 (Adam)\n",
        "    \"SAM\": False,        # ✅ COMPLETED: best_lr=1e-05, valWorst=0.5489\n",
        "    \"IRS-Group\": False,  # ✅ COMPLETED: best_lr=0.0001, valWorst=0.8421\n",
        "    \"KL-RS\": False,      # ✅ COMPLETED: best_lr=0.0001, alt_iters=2, valWorst=0.6502\n",
        "    \"V-REx\": False,      # ✅ COMPLETED: best_lr=0.0001, rex_beta=10.0, valWorst=0.6316\n",
        "    \"MM-REx\": False,     # ✅ COMPLETED: best_lr=0.0001, rex_lambda=1.0, valWorst=0.6015\n",
        "    \"IRMv1\": False,      # ✅ COMPLETED: best_lr=0.0001, penalty_weight=100, valWorst=0.6015\n",
        "    \"GroupDRO\": False,   # ✅ COMPLETED: best_lr=0.0001, alpha=0.2, valWorst=0.7218\n",
        "    \"χ²-DRO\": False,     # ✅ COMPLETED: best_lr=0.0001, rho=1.0, valWorst=0.6692\n",
        "    \"CVaR-DRO\": True,    # Run sweep - NOT YET COMPLETED\n",
        "}\n",
        "\n",
        "# --- FULL TRAINING FLAGS ---\n",
        "# True -> run full training (using best hyperparams from sweep or defaults).\n",
        "RUN_FULL_TRAINING = {\n",
        "    \"ERM-SGD\": True,\n",
        "    \"ERM\": True,\n",
        "    \"SAM\": True,\n",
        "    \"IRS-Group\": True,  # Keep IRS enabled by default\n",
        "    \"KL-RS\": True,\n",
        "    \"V-REx\": True,\n",
        "    \"MM-REx\": True,\n",
        "    \"IRMv1\": True,\n",
        "    \"GroupDRO\": True,\n",
        "    \"χ²-DRO\": True,\n",
        "    \"CVaR-DRO\": True,\n",
        "}\n",
        "\n",
        "# --- TEST EVALUATION FLAGS ---\n",
        "# True -> evaluate on the test set (only meaningful if the model was trained).\n",
        "RUN_TEST_EVAL = {\n",
        "    \"ERM-SGD\": True,\n",
        "    \"ERM\": True,\n",
        "    \"SAM\": True,\n",
        "    \"IRS-Group\": True,  # Keep IRS enabled by default\n",
        "    \"KL-RS\": True,\n",
        "    \"V-REx\": True,\n",
        "    \"MM-REx\": True,\n",
        "    \"IRMv1\": True,\n",
        "    \"GroupDRO\": True,\n",
        "    \"χ²-DRO\": True,\n",
        "    \"CVaR-DRO\": True,\n",
        "}\n",
        "\n",
        "# --- ANALYSIS FLAGS ---\n",
        "# True -> include the algorithm in the final analysis / visualization.\n",
        "RUN_ANALYSIS = {\n",
        "    \"ERM-SGD\": True,\n",
        "    \"ERM\": True,\n",
        "    \"SAM\": True,\n",
        "    \"IRS-Group\": True,  # Keep IRS enabled by default\n",
        "    \"KL-RS\": True,\n",
        "    \"V-REx\": True,\n",
        "    \"MM-REx\": True,\n",
        "    \"IRMv1\": True,\n",
        "    \"GroupDRO\": True,\n",
        "    \"χ²-DRO\": True,\n",
        "    \"CVaR-DRO\": True,\n",
        "}\n",
        "\n",
        "# === HELPER FUNCTIONS ===\n",
        "def should_sweep(algo: str) -> bool:\n",
        "    \"\"\"Return the RUN_SWEEP flag for `algo`; algorithms not listed there count as False.\"\"\"\n",
        "    if algo in RUN_SWEEP:\n",
        "        return RUN_SWEEP[algo]\n",
        "    return False\n",
        "\n",
        "def should_train(algo: str) -> bool:\n",
        "    \"\"\"Return the RUN_FULL_TRAINING flag for `algo`; unlisted algorithms count as False.\"\"\"\n",
        "    if algo in RUN_FULL_TRAINING:\n",
        "        return RUN_FULL_TRAINING[algo]\n",
        "    return False\n",
        "\n",
        "def should_test(algo: str) -> bool:\n",
        "    \"\"\"Return the RUN_TEST_EVAL flag for `algo`; unlisted algorithms count as False.\"\"\"\n",
        "    if algo in RUN_TEST_EVAL:\n",
        "        return RUN_TEST_EVAL[algo]\n",
        "    return False\n",
        "\n",
        "def should_analyze(algo: str) -> bool:\n",
        "    \"\"\"Return the RUN_ANALYSIS flag for `algo`; unlisted algorithms count as False.\"\"\"\n",
        "    if algo in RUN_ANALYSIS:\n",
        "        return RUN_ANALYSIS[algo]\n",
        "    return False\n",
        "\n",
        "def print_run_plan():\n",
        "    \"\"\"Print a table of which phases (sweep / train / test / analysis) run per algorithm.\n",
        "\n",
        "    Rows cover every algorithm named in any of the four flag dictionaries, in\n",
        "    first-seen order, so an entry that is missing from RUN_SWEEP is still reported.\n",
        "    Each cell comes from the matching should_* helper.\n",
        "    \"\"\"\n",
        "    flag_tables = (RUN_SWEEP, RUN_FULL_TRAINING, RUN_TEST_EVAL, RUN_ANALYSIS)\n",
        "    # dict.fromkeys de-duplicates while keeping insertion order.\n",
        "    algos = list(dict.fromkeys(name for table in flag_tables for name in table))\n",
        "\n",
        "    def mark(enabled) -> str:\n",
        "        return \"✅\" if enabled else \"❌\"\n",
        "\n",
        "    print(\"=\"*70)\n",
        "    print(\" ALGORITHM RUN PLAN\")\n",
        "    print(\"=\"*70)\n",
        "    print(f\"{'Algorithm':<15} | {'Sweep':<8} | {'Train':<8} | {'Test':<8} | {'Analysis':<8}\")\n",
        "    print(\"-\"*70)\n",
        "    for algo in algos:\n",
        "        sweep = mark(should_sweep(algo))\n",
        "        train = mark(should_train(algo))\n",
        "        test = mark(should_test(algo))\n",
        "        analyze = mark(should_analyze(algo))\n",
        "        print(f\"{algo:<15} | {sweep:<8} | {train:<8} | {test:<8} | {analyze:<8}\")\n",
        "    print(\"=\"*70)\n",
        "\n",
        "# Print the plan\n",
        "print_run_plan()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "83agbbIsLpu5"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Utility Functions ===\n",
        "# ==========================================\n",
        "\n",
        "import gc\n",
        "\n",
        "def cuda_clear_cache():\n",
        "    \"\"\"Run the garbage collector, then release cached CUDA memory when CUDA is available.\"\"\"\n",
        "    gc.collect()\n",
        "    if not torch.cuda.is_available():\n",
        "        return\n",
        "    torch.cuda.empty_cache()\n",
        "    torch.cuda.ipc_collect()\n",
        "\n",
        "def reset_seeds(seed: int = SEED):\n",
        "    \"\"\"Reset random seeds for reproducibility.\n",
        "\n",
        "    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) and puts\n",
        "    cuDNN into deterministic mode.\n",
        "\n",
        "    Args:\n",
        "        seed: value handed to every generator; defaults to the notebook-level SEED.\n",
        "    \"\"\"\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    # benchmark=True lets cuDNN auto-tune and possibly pick a different convolution\n",
        "    # algorithm from run to run, which defeats the purpose of this function.\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "def log_epoch(tag, ep, tr_loss, val_worst, val_overall, t0, extra=\"\"):\n",
        "    \"\"\"Print a one-line epoch summary, including wall-clock time since `t0` as HH:MM:SS.\n",
        "\n",
        "    `extra`, when non-empty, is appended after a ` | ` separator.\n",
        "    \"\"\"\n",
        "    hours, leftover = divmod(time.time() - t0, 3600)\n",
        "    minutes, seconds = divmod(leftover, 60)\n",
        "    clock = f\"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}\"\n",
        "    fields = [\n",
        "        f\"[{tag}] ep{ep:03d}\",\n",
        "        f\"trainLoss={tr_loss:.4f}\",\n",
        "        f\"valWorst={val_worst:.4f}\",\n",
        "        f\"valOverall={val_overall:.4f}\",\n",
        "        f\"t={clock}\",\n",
        "    ]\n",
        "    line = \" \".join(fields)\n",
        "    if extra:\n",
        "        line = f\"{line} | {extra}\"\n",
        "    print(line)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FU1SdrfDLpu6"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === SAM Optimizer (Foret et al., ICLR 2021) ===\n",
        "# ==========================================\n",
        "\n",
        "class SAMAdam(torch.optim.Adam):\n",
        "    \"\"\"\n",
        "    Adam optimizer carrying the helper steps of Sharpness-Aware Minimization (SAM).\n",
        "    Paper: \"Sharpness-Aware Minimization for Efficiently Improving Generalization\" (ICLR 2021)\n",
        "\n",
        "    `step()` is inherited from Adam unchanged; the caller drives SAM by combining\n",
        "    `_grad_norm`, `_epsilon` and `_restore` around two forward/backward passes.\n",
        "    \"\"\"\n",
        "    def __init__(self, params, lr=1e-3, rho=0.05, **kwargs):\n",
        "        \"\"\"Validate `rho` (must be > 0), store it, and forward everything else to Adam.\"\"\"\n",
        "        if rho <= 0:\n",
        "            raise ValueError(f\"rho must be positive, got {rho}\")\n",
        "        self.rho = float(rho)\n",
        "        super().__init__(params, lr=lr, **kwargs)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _grad_norm(self):\n",
        "        \"\"\"Return the global L2 norm over all parameter gradients (0 if none has a gradient).\"\"\"\n",
        "        fallback_device = self.param_groups[0][\"params\"][0].device\n",
        "        per_param = [\n",
        "            p.grad.detach().norm(2)\n",
        "            for group in self.param_groups\n",
        "            for p in group[\"params\"]\n",
        "            if p.grad is not None\n",
        "        ]\n",
        "        if len(per_param) == 0:\n",
        "            return torch.tensor(0., device=fallback_device)\n",
        "        return torch.norm(torch.stack(per_param), p=2)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _epsilon(self, scale):\n",
        "        \"\"\"Shift every parameter in place by `grad * scale`.\n",
        "\n",
        "        Returns the applied shifts as one list per param group, aligned with\n",
        "        `group[\"params\"]`; entries are None for parameters without a gradient.\n",
        "        \"\"\"\n",
        "        applied = []\n",
        "        for group in self.param_groups:\n",
        "            shifts = []\n",
        "            for p in group[\"params\"]:\n",
        "                shift = None if p.grad is None else p.grad * scale\n",
        "                if shift is not None:\n",
        "                    p.add_(shift)\n",
        "                shifts.append(shift)\n",
        "            applied.append(shifts)\n",
        "        return applied\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _restore(self, epsilons):\n",
        "        \"\"\"Undo `_epsilon` by subtracting each recorded shift from its parameter.\"\"\"\n",
        "        for group, shifts in zip(self.param_groups, epsilons):\n",
        "            for p, shift in zip(group[\"params\"], shifts):\n",
        "                if shift is None:\n",
        "                    continue\n",
        "                p.sub_(shift)\n",
        "\n",
        "def disable_running_stats(model):\n",
        "    \"\"\"Freeze BN running stats for SAM's second forward pass.\n",
        "\n",
        "    Sets `momentum` to 0.0 on every batch-norm module of `model` and stores the\n",
        "    previous value on the module as `_sam_backup_momentum`.\n",
        "\n",
        "    The backup is written only if none exists yet, so calling this twice without\n",
        "    restoring in between keeps the original momentum instead of overwriting it with 0.0.\n",
        "    \"\"\"\n",
        "    for m in model.modules():\n",
        "        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):\n",
        "            if not hasattr(m, \"_sam_backup_momentum\"):\n",
        "                m._sam_backup_momentum = m.momentum\n",
        "            m.momentum = 0.0\n",
        "\n",
        "def enable_running_stats(model):\n",
        "    \"\"\"Restore BN momentum after SAM.\n",
        "\n",
        "    For every batch-norm module that carries a `_sam_backup_momentum` attribute,\n",
        "    copy it back into `momentum` and remove the attribute. Other modules are untouched.\n",
        "    \"\"\"\n",
        "    bn_base = torch.nn.modules.batchnorm._BatchNorm\n",
        "    for module in model.modules():\n",
        "        if not isinstance(module, bn_base):\n",
        "            continue\n",
        "        if not hasattr(module, \"_sam_backup_momentum\"):\n",
        "            continue\n",
        "        module.momentum = module._sam_backup_momentum\n",
        "        del module._sam_backup_momentum"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4iDqpD8NLpu6"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === DRO Loss Functions ===\n",
        "# ==========================================\n",
        "\n",
        "def irm_penalty(logits, y):\n",
        "    \"\"\"\n",
        "    IRMv1 penalty: ||∇_w [ loss(w ∘ Φ(x); y) ] ||^2 at w=1.0\n",
        "    Paper: Arjovsky et al., \"Invariant Risk Minimization\" (2019)\n",
        "\n",
        "    A dummy scale `w` (a single 1.0 on the logits' device/dtype) multiplies the logits;\n",
        "    the squared gradient of the mean cross-entropy w.r.t. `w` is returned. The gradient\n",
        "    is taken with create_graph=True, so the penalty itself can be backpropagated.\n",
        "    \"\"\"\n",
        "    dummy_w = torch.ones(1, requires_grad=True, device=logits.device, dtype=logits.dtype)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    scaled_loss = criterion(logits * dummy_w, y)\n",
        "    (grad_w,) = torch.autograd.grad(scaled_loss, [dummy_w], create_graph=True)\n",
        "    return grad_w.pow(2).sum()\n",
        "\n",
        "\n",
        "def get_rex_penalty(losses, mode=\"vrex\", lam=1.0):\n",
        "    \"\"\"\n",
        "    Combine per-environment losses into a REx term.\n",
        "    Reference: Krueger et al. \"Out-of-Distribution Generalization via Risk Extrapolation (REx)\" ICML 2021\n",
        "\n",
        "    mode=\"vrex\":  population variance of `losses` (divides by len(losses)).\n",
        "    mode=\"mmrex\": lam * max(losses) + (1 - lam) * min(losses).\n",
        "    Any other mode raises ValueError. `lam` is only used by \"mmrex\".\n",
        "    \"\"\"\n",
        "    if mode not in (\"vrex\", \"mmrex\"):\n",
        "        raise ValueError(f\"Unknown REx mode: {mode}\")\n",
        "\n",
        "    if mode == \"mmrex\":\n",
        "        worst, best = max(losses), min(losses)\n",
        "        return lam * worst + (1 - lam) * best\n",
        "\n",
        "    n_envs = len(losses)\n",
        "    mean_loss = sum(losses) / n_envs\n",
        "    squared_dev = [(loss - mean_loss) ** 2 for loss in losses]\n",
        "    return sum(squared_dev) / n_envs\n",
        "\n",
        "\n",
        "def chi2_dro_loss(loss_vec: torch.Tensor, rho: float,\n",
        "                  bisect_tol: float = 1e-5, max_iter: int = 50) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Chi-squared DRO loss (dual form), minimised over the scalar η by bisection.\n",
        "    R_ρ(θ) = inf_η { √(1+2ρ) · √(E_P[(ℓ(θ;X) - η)_+²]) + η }\n",
        "\n",
        "    Args:\n",
        "        loss_vec: tensor of per-sample losses; size(0) is taken as the batch size.\n",
        "        rho: radius of the χ² ball; negative values are clamped to 0.\n",
        "        bisect_tol: stop once the η bracket is narrower than this.\n",
        "        max_iter: maximum number of bisection steps.\n",
        "\n",
        "    Returns:\n",
        "        Scalar tensor. η* is found without tracking gradients and then treated as a\n",
        "        constant, so gradients flow to `loss_vec` only through the final evaluation.\n",
        "\n",
        "    Notes:\n",
        "        * Empty batch -> 0; all losses equal -> their mean.\n",
        "        * ρ ≈ 0 (below 1e-8) -> the mean: for ρ = 0 the infimum is approached as\n",
        "          η → -∞ and equals E[ℓ]; for tiny ρ the minimiser sits so far below the losses\n",
        "          that the objective cannot be evaluated reliably in floating point.\n",
        "        * The minimiser can lie BELOW min(ℓ), so the bracket is\n",
        "          [ℓ_min - (ℓ_max - ℓ_min)/(c - 1), ℓ_max] with c = √(1+2ρ). At the lower end all\n",
        "          residuals ℓ - η lie in [a, c·a], hence mean/rms ≥ 1/c and dR/dη ≤ 0; at ℓ_max the\n",
        "          positive part vanishes and dR/dη = 1 > 0. R is convex in η, so the sign of the\n",
        "          derivative brackets the minimiser.\n",
        "    \"\"\"\n",
        "    B = loss_vec.size(0)\n",
        "    if B == 0:\n",
        "        return loss_vec.new_tensor(0.0)\n",
        "\n",
        "    rho = max(float(rho), 0.0)\n",
        "    if rho < 1e-8:\n",
        "        return loss_vec.mean()\n",
        "    c = math.sqrt(1.0 + 2.0 * rho)\n",
        "\n",
        "    loss_min = loss_vec.min().item()\n",
        "    loss_max = loss_vec.max().item()\n",
        "    spread = loss_max - loss_min\n",
        "\n",
        "    if abs(spread) < 1e-12:\n",
        "        return loss_vec.mean()\n",
        "\n",
        "    def positive_part(eta):\n",
        "        return torch.clamp(loss_vec - eta, min=0.0)\n",
        "\n",
        "    def eval_R(eta):\n",
        "        var_term = (positive_part(eta) ** 2).mean()\n",
        "        return c * torch.sqrt(var_term + 1e-12) + eta\n",
        "\n",
        "    eta_min = loss_min - spread / (c - 1.0)\n",
        "    eta_max = loss_max\n",
        "\n",
        "    for _ in range(max_iter):\n",
        "        if (eta_max - eta_min) < bisect_tol:\n",
        "            break\n",
        "        eta_mid = 0.5 * (eta_min + eta_max)\n",
        "        # Only the sign of dR/dη is needed here; no autograd graph required.\n",
        "        with torch.no_grad():\n",
        "            pos = positive_part(eta_mid)\n",
        "            var_term = (pos ** 2).mean()\n",
        "            grad = 1.0 - c * pos.mean() / torch.sqrt(var_term + 1e-12)\n",
        "        if grad.item() < 0:\n",
        "            eta_min = eta_mid\n",
        "        else:\n",
        "            eta_max = eta_mid\n",
        "\n",
        "    eta_star = 0.5 * (eta_min + eta_max)\n",
        "    return eval_R(eta_star)\n",
        "\n",
        "\n",
        "def cvar_loss_from_batch(loss_vec: torch.Tensor, alpha: float) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    CVaR-DRO loss from Levy et al., \"Large-Scale Methods for Distributionally Robust Optimization\" (NeurIPS 2020).\n",
        "    CVaR_α(ℓ) = mean of top-α quantile losses.\n",
        "\n",
        "    Args:\n",
        "        loss_vec: tensor of per-sample losses; size(0) is taken as the batch size.\n",
        "        alpha: fraction of the batch to keep, clamped into [1e-6, 1].\n",
        "\n",
        "    Returns:\n",
        "        Mean of the k = max(1, ceil(alpha * B)) largest losses, or a zero tensor\n",
        "        for an empty batch.\n",
        "    \"\"\"\n",
        "    B = loss_vec.size(0)\n",
        "    if B == 0:\n",
        "        # Same convention as chi2_dro_loss; without this guard k would be 1 and\n",
        "        # torch.topk would raise on a tensor with no elements.\n",
        "        return loss_vec.new_tensor(0.0)\n",
        "\n",
        "    alpha = max(min(float(alpha), 1.0), 1e-6)\n",
        "    k = max(1, int(math.ceil(alpha * B)))\n",
        "    topk_losses, _ = torch.topk(loss_vec, k, largest=True, sorted=False)\n",
        "    return topk_losses.mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h052KPdKLpu7"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === GroupDRO LossComputer (Sagawa et al., 2020) ===\n",
        "# ==========================================\n",
        "\n",
        "class LossComputer:\n",
        "    \"\"\"\n",
        "    GroupDRO-style loss computer after Sagawa et al., \"Distributionally Robust Neural Networks\n",
        "    for Group Shifts: On the Importance of Regularization for Worst-Case Generalization\" (ICLR 2020).\n",
        "\n",
        "    Keeps an exponential moving average (EMA) of each group's loss and weights the\n",
        "    per-group batch losses by `adv_probs`, which is exp(alpha * EMA) normalised to sum to 1.\n",
        "    \"\"\"\n",
        "    def __init__(self, group_counts, alpha=0.2, gamma=0.1, device='cuda'):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            group_counts: sequence with one count per group; its length sets n_groups.\n",
        "            alpha: multiplier on the EMA losses inside the exponential.\n",
        "            gamma: EMA weight given to the newest per-group loss.\n",
        "            device: device on which all state tensors are created.\n",
        "        \"\"\"\n",
        "        self.alpha, self.gamma, self.device = alpha, gamma, device\n",
        "        self.n_groups = len(group_counts)\n",
        "        n = self.n_groups\n",
        "\n",
        "        # Start from uniform group weights and an uninitialised EMA.\n",
        "        self.adv_probs = torch.ones(n, device=device) / n\n",
        "        self.exp_avg_loss = torch.zeros(n, device=device)\n",
        "        self.exp_avg_initialized = torch.zeros(n, device=device).bool()\n",
        "        self.group_counts = torch.tensor(group_counts, dtype=torch.float32, device=device)\n",
        "        self.group_frac = self.group_counts / self.group_counts.sum()\n",
        "\n",
        "    def loss(self, yhat, y, group_idx, is_training=True):\n",
        "        \"\"\"Return (weighted group loss, {\"avg_loss\": mean per-sample loss as float}).\n",
        "\n",
        "        Groups absent from the batch contribute a loss of 0.0 and leave their EMA\n",
        "        untouched. The EMA and `adv_probs` are only updated when `is_training` is True;\n",
        "        otherwise the stored `adv_probs` are used as they are.\n",
        "        \"\"\"\n",
        "        per_sample_losses = nn.CrossEntropyLoss(reduction='none')(yhat, y)\n",
        "\n",
        "        masks = [group_idx == g for g in range(self.n_groups)]\n",
        "        present = [bool(m.any()) for m in masks]\n",
        "        group_losses = torch.stack([\n",
        "            per_sample_losses[m].mean() if has_samples else torch.tensor(0.0, device=self.device)\n",
        "            for m, has_samples in zip(masks, present)\n",
        "        ])\n",
        "\n",
        "        if is_training:\n",
        "            for g, has_samples in enumerate(present):\n",
        "                if not has_samples:\n",
        "                    continue\n",
        "                current = group_losses[g].detach()\n",
        "                if self.exp_avg_initialized[g]:\n",
        "                    self.exp_avg_loss[g] = (\n",
        "                        self.gamma * current +\n",
        "                        (1 - self.gamma) * self.exp_avg_loss[g]\n",
        "                    )\n",
        "                else:\n",
        "                    # First time this group is seen: seed the EMA with its loss.\n",
        "                    self.exp_avg_loss[g] = current\n",
        "                    self.exp_avg_initialized[g] = True\n",
        "\n",
        "            unnormalised = torch.exp(self.alpha * self.exp_avg_loss.detach())\n",
        "            self.adv_probs = unnormalised / (unnormalised.sum() + 1e-12)\n",
        "\n",
        "        weighted_loss = (self.adv_probs * group_losses).sum()\n",
        "        return weighted_loss, {\"avg_loss\": per_sample_losses.mean().item()}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dnizothyLpu7"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Training Functions (ERM, SAM) ===\n",
        "# ==========================================\n",
        "\n",
        "def train_erm(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "              weight_decay=0.0, print_every=5):\n",
        "    \"\"\"Train `model` by ERM (mean cross-entropy) with the Adam optimizer.\n",
        "\n",
        "    After each epoch the model is scored on `val_loader`. The weights from the epoch\n",
        "    with the highest `val_worst` (second value returned by `eval_metrics`) are loaded\n",
        "    back into `model` before returning.\n",
        "\n",
        "    Returns:\n",
        "        (tr_hist, va_worst_hist, va_overall_hist): per-epoch training loss,\n",
        "        `val_worst` and overall validation accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    top_val_worst, top_state = -1.0, None\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    start_time = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        loss_sum, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, _group in tqdm(train_loader, leave=False, desc=f\"ERM Ep {ep}\"):\n",
        "            x = x.to(device, non_blocking=True)\n",
        "            y = y.to(device, non_blocking=True)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            batch_loss = criterion(model(x), y)\n",
        "            batch_loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            # Weight by batch size so the epoch figure is a per-sample mean.\n",
        "            batch_size = x.size(0)\n",
        "            loss_sum += batch_loss.item() * batch_size\n",
        "            seen += batch_size\n",
        "\n",
        "        tr_loss = loss_sum / max(1, seen)\n",
        "        _unused, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint in memory whenever the tracked validation metric improves.\n",
        "        improved = val_worst > top_val_worst\n",
        "        if improved:\n",
        "            top_val_worst = val_worst\n",
        "            top_state = copy.deepcopy(model.state_dict())\n",
        "        is_best = \" *BEST*\" if improved else \"\"\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"ERM\", ep, tr_loss, val_worst, val_overall, start_time)\n",
        "            print(f\"  {is_best}\")\n",
        "\n",
        "    if top_state is not None:\n",
        "        model.load_state_dict(top_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_erm_sgd(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "                  weight_decay=0.0, momentum=0.9, print_every=5):\n",
        "    \"\"\"Train `model` by ERM (mean cross-entropy) with SGD + momentum.\n",
        "\n",
        "    After each epoch the model is scored on `val_loader`. The weights from the epoch\n",
        "    with the highest `val_worst` (second value returned by `eval_metrics`) are loaded\n",
        "    back into `model` before returning.\n",
        "\n",
        "    Returns:\n",
        "        (tr_hist, va_worst_hist, va_overall_hist): per-epoch training loss,\n",
        "        `val_worst` and overall validation accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum\n",
        "    )\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    top_val_worst, top_state = -1.0, None\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    start_time = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        loss_sum, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, _group in tqdm(train_loader, leave=False, desc=f\"ERM-SGD Ep {ep}\"):\n",
        "            x = x.to(device, non_blocking=True)\n",
        "            y = y.to(device, non_blocking=True)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            batch_loss = criterion(model(x), y)\n",
        "            batch_loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            # Weight by batch size so the epoch figure is a per-sample mean.\n",
        "            batch_size = x.size(0)\n",
        "            loss_sum += batch_loss.item() * batch_size\n",
        "            seen += batch_size\n",
        "\n",
        "        tr_loss = loss_sum / max(1, seen)\n",
        "        _unused, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint in memory whenever the tracked validation metric improves.\n",
        "        improved = val_worst > top_val_worst\n",
        "        if improved:\n",
        "            top_val_worst = val_worst\n",
        "            top_state = copy.deepcopy(model.state_dict())\n",
        "        is_best = \" *BEST*\" if improved else \"\"\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"ERM-SGD\", ep, tr_loss, val_worst, val_overall, start_time)\n",
        "            print(f\"  {is_best}\")\n",
        "\n",
        "    if top_state is not None:\n",
        "        model.load_state_dict(top_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_sam(model, train_loader, val_loader, *, epochs=50, lr=1e-4, rho=0.05,\n",
        "              weight_decay=0.0, print_every=5):\n",
        "    \"\"\"SAM training with Adam base optimizer (SAMAdam).\n",
        "\n",
        "    Per batch:\n",
        "      1. forward/backward at the current weights;\n",
        "      2. shift the weights by rho * grad / (||grad|| + 1e-12);\n",
        "      3. forward/backward at the shifted weights with BN momentum set to 0;\n",
        "      4. undo the shift and take an Adam step with the gradients from step 3.\n",
        "    The reported training loss is the loss from step 1.\n",
        "\n",
        "    The weights from the epoch with the highest `val_worst` (second value returned by\n",
        "    `eval_metrics`) are loaded back into `model` before returning.\n",
        "\n",
        "    Returns:\n",
        "        (tr_hist, va_worst_hist, va_overall_hist), one entry per epoch.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = SAMAdam(model.parameters(), lr=lr, rho=rho, weight_decay=weight_decay)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    top_val_worst, top_state = -1.0, None\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    start_time = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        loss_sum, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, _group in tqdm(train_loader, leave=False, desc=f\"SAM Ep {ep}\"):\n",
        "            x = x.to(device, non_blocking=True)\n",
        "            y = y.to(device, non_blocking=True)\n",
        "\n",
        "            # Pass 1: gradient at the current weights.\n",
        "            enable_running_stats(model)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "            clean_loss = criterion(model(x), y)\n",
        "            clean_loss.backward()\n",
        "\n",
        "            # Move the weights along the normalised gradient direction.\n",
        "            step_scale = rho / (optimizer._grad_norm() + 1e-12)\n",
        "            shifts = optimizer._epsilon(step_scale)\n",
        "\n",
        "            # Pass 2: gradient at the shifted weights, BN momentum zeroed.\n",
        "            disable_running_stats(model)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "            shifted_loss = criterion(model(x), y)\n",
        "            shifted_loss.backward()\n",
        "\n",
        "            # Back to the original weights, then step with the pass-2 gradients.\n",
        "            optimizer._restore(shifts)\n",
        "            enable_running_stats(model)\n",
        "            optimizer.step()\n",
        "\n",
        "            batch_size = x.size(0)\n",
        "            loss_sum += clean_loss.item() * batch_size\n",
        "            seen += batch_size\n",
        "\n",
        "        tr_loss = loss_sum / max(1, seen)\n",
        "        _unused, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint in memory whenever the tracked validation metric improves.\n",
        "        improved = val_worst > top_val_worst\n",
        "        if improved:\n",
        "            top_val_worst = val_worst\n",
        "            top_state = copy.deepcopy(model.state_dict())\n",
        "        is_best = \" *BEST*\" if improved else \"\"\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"SAM\", ep, tr_loss, val_worst, val_overall, start_time)\n",
        "            print(f\"  {is_best}\")\n",
        "\n",
        "    if top_state is not None:\n",
        "        model.load_state_dict(top_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def eval_overall_accuracy(model, loader):\n",
        "    \"\"\"Return the fraction of samples in `loader` whose argmax prediction equals the label.\n",
        "\n",
        "    Switches `model` to eval mode and leaves it there. An empty loader yields 0.0\n",
        "    (the denominator is floored at 1).\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    n_correct, n_seen = 0, 0\n",
        "    for x, y, _, _group in loader:\n",
        "        x = x.to(device)\n",
        "        y = y.to(device)\n",
        "        predicted = model(x).argmax(1)\n",
        "        n_correct += predicted.eq(y).sum().item()\n",
        "        n_seen += y.size(0)\n",
        "    return n_correct / max(1, n_seen)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QDXKSDodLpu7",
        "outputId": "759582f0-3a86-47ef-ae2b-bf41e3f400ec"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Training Functions (V-REx, MM-REx, IRMv1) ===\n",
        "# ==========================================\n",
        "# Paper-faithful implementations with proper environment handling and penalty annealing\n",
        "\n",
        "def train_vrex(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "               rex_beta=1.0, penalty_anneal_epochs=10, weight_decay=0.0, print_every=5):\n",
        "    \"\"\"\n",
        "    V-REx training (Krueger et al., ICML 2021).\n",
        "\n",
        "    L = mean_e(R_e) + w(ep) * β * Var_e(R_e)\n",
        "\n",
        "    NOTE: Uses environment (background) labels, NOT all 4 groups.\n",
        "    Environment is extracted as: env = g % 2 (0=land, 1=water background)\n",
        "\n",
        "    The annealing weight w(ep) grows linearly as ep / penalty_anneal_epochs during the\n",
        "    first `penalty_anneal_epochs` epochs and is 1.0 afterwards. A batch that contains\n",
        "    only one environment falls back to the plain mean cross-entropy.\n",
        "\n",
        "    The logged training loss is the mean cross-entropy (without the penalty). The\n",
        "    weights from the epoch with the highest `val_worst` (second value returned by\n",
        "    `eval_metrics`) are loaded back into `model` before returning.\n",
        "\n",
        "    Returns:\n",
        "        (tr_hist, va_worst_hist, va_overall_hist), one entry per epoch.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "    top_val_worst, top_state = -1.0, None\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    start_time = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        loss_sum, seen = 0.0, 0\n",
        "\n",
        "        # Linear warm-up of the variance penalty.\n",
        "        if ep <= penalty_anneal_epochs:\n",
        "            penalty_weight = min(1.0, float(ep) / float(penalty_anneal_epochs))\n",
        "        else:\n",
        "            penalty_weight = 1.0\n",
        "\n",
        "        for x, y, _, g in tqdm(train_loader, leave=False, desc=f\"V-REx Ep {ep}\"):\n",
        "            x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "\n",
        "            # Environment id from the group id: 0 = land background, 1 = water background.\n",
        "            env = g % 2\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            ce_vec = per_sample_ce(model(x), y)\n",
        "\n",
        "            # One mean loss per environment that is present in this batch.\n",
        "            env_losses = []\n",
        "            for env_id in (0, 1):\n",
        "                in_env = env.eq(env_id)\n",
        "                if in_env.any():\n",
        "                    env_losses.append(ce_vec[in_env].mean())\n",
        "\n",
        "            if len(env_losses) != 2:\n",
        "                # Single-environment batch: the variance is undefined, use the mean loss.\n",
        "                loss = ce_vec.mean()\n",
        "            else:\n",
        "                mean_loss = sum(env_losses) / len(env_losses)\n",
        "                var_penalty = get_rex_penalty(env_losses, mode=\"vrex\")\n",
        "                loss = mean_loss + penalty_weight * rex_beta * var_penalty\n",
        "\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            # Track the plain ERM loss, weighted by batch size.\n",
        "            batch_size = x.size(0)\n",
        "            loss_sum += ce_vec.mean().item() * batch_size\n",
        "            seen += batch_size\n",
        "\n",
        "        tr_loss = loss_sum / max(1, seen)\n",
        "        _unused, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint in memory whenever the tracked validation metric improves.\n",
        "        if val_worst > top_val_worst:\n",
        "            top_val_worst = val_worst\n",
        "            top_state = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"V-REx\", ep, tr_loss, val_worst, val_overall, start_time)\n",
        "\n",
        "    if top_state is not None:\n",
        "        model.load_state_dict(top_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_mmrex(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "                rex_lambda=1.5, penalty_anneal_epochs=10, weight_decay=0.0, print_every=5):\n",
        "    \"\"\"\n",
        "    Train `model` with the MM-REx objective (Krueger et al., ICML 2021).\n",
        "\n",
        "        L = λ * max(R_e) + (1-λ) * min(R_e)\n",
        "\n",
        "    R_e is the mean cross-entropy of environment e within the batch, where the\n",
        "    environment is the background extracted as env = g % 2. The combination is\n",
        "    delegated to get_rex_penalty(..., mode='mmrex') and is used as the whole loss\n",
        "    (it is not added to an ERM term). A batch that contains only one environment\n",
        "    is trained with plain mean cross-entropy instead.\n",
        "\n",
        "    λ is ramped linearly from 0.5 towards `rex_lambda`, reaching it at epoch\n",
        "    `penalty_anneal_epochs` and staying there afterwards.\n",
        "\n",
        "    Returns:\n",
        "        (train_loss_history, val_worst_history, val_overall_history), one entry per\n",
        "        epoch. The train loss is the mean cross-entropy, not the MM-REx objective.\n",
        "        On return the model holds the weights of the epoch with the best\n",
        "        validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    best_val_worst, best_model_state = -1.0, None\n",
        "    t0 = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        loss_sum, n_seen = 0.0, 0\n",
        "\n",
        "        # Linear ramp of λ: starts near 0.5 (ERM-like) and is held at rex_lambda after the ramp.\n",
        "        if ep <= penalty_anneal_epochs:\n",
        "            ramp = min(1.0, float(ep) / float(penalty_anneal_epochs))\n",
        "        else:\n",
        "            ramp = 1.0\n",
        "        lam_now = 0.5 + ramp * (rex_lambda - 0.5)\n",
        "\n",
        "        for x, y, _, g in tqdm(train_loader, leave=False, desc=f\"MM-REx Ep {ep}\"):\n",
        "            x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "            env = g % 2  # background environment derived from the group id\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "            ce_vec = per_sample_ce(model(x), y)\n",
        "\n",
        "            # Mean loss of each environment that actually occurs in this batch.\n",
        "            env_masks = [env == e for e in (0, 1)]\n",
        "            env_losses = [ce_vec[m].mean() for m in env_masks if m.any()]\n",
        "\n",
        "            if len(env_losses) != 2:\n",
        "                # Single-environment batch: min/max are undefined, use plain mean CE.\n",
        "                loss = ce_vec.mean()\n",
        "            else:\n",
        "                loss = get_rex_penalty(env_losses, mode=\"mmrex\", lam=lam_now)\n",
        "\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            batch_n = x.size(0)\n",
        "            loss_sum += ce_vec.mean().item() * batch_n\n",
        "            n_seen += batch_n\n",
        "\n",
        "        tr_loss = loss_sum / max(1, n_seen)\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint on strict improvement of validation worst-group accuracy.\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst, best_model_state = val_worst, copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"MM-REx\", ep, tr_loss, val_worst, val_overall, t0)\n",
        "\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_irm(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "              penalty_weight=10000.0, penalty_anneal_epochs=20, weight_decay=0.0, print_every=5):\n",
        "    \"\"\"\n",
        "    IRMv1 training (Arjovsky et al., 2019).\n",
        "\n",
        "        L = ERM + λ * mean_e ||∇_w [R_e(w·Φ)]|w=1||²\n",
        "\n",
        "    Environments are the backgrounds, extracted as env = g % 2. For every batch the\n",
        "    cross-entropy and irm_penalty(...) are computed per environment and AVERAGED over\n",
        "    the two environments. A batch in which either environment has fewer than 2\n",
        "    samples falls back to plain cross-entropy on the whole batch (no penalty).\n",
        "\n",
        "    Args:\n",
        "        model: network to train; moved to the global `device`.\n",
        "        train_loader: yields (x, y, _, g) batches, g being the group id.\n",
        "        val_loader: loader handed to eval_metrics / eval_overall_accuracy.\n",
        "        epochs: number of training epochs.\n",
        "        lr, weight_decay: Adam hyperparameters.\n",
        "        penalty_weight: final penalty weight λ.\n",
        "        penalty_anneal_epochs: λ is ramped linearly and reaches penalty_weight at this\n",
        "            epoch. Values <= 0 disable annealing (full penalty_weight from epoch 1).\n",
        "        print_every: log every this many epochs; 0 or None disables logging.\n",
        "\n",
        "    Returns:\n",
        "        (train_loss_history, val_worst_history, val_overall_history), one entry per\n",
        "        epoch. The train loss is the plain batch cross-entropy, not the penalized\n",
        "        objective. On return the model holds the weights of the epoch with the best\n",
        "        validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    # Stateless criterion: build it once instead of several times per batch.\n",
        "    ce_mean = nn.CrossEntropyLoss()\n",
        "\n",
        "    best_val_worst = -1.0\n",
        "    best_model_state = None\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    t0 = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        total_loss, total_n = 0.0, 0\n",
        "\n",
        "        # Anneal penalty (critical for IRM convergence). Guard the division so that\n",
        "        # penalty_anneal_epochs <= 0 means \"no annealing\" instead of raising\n",
        "        # ZeroDivisionError or producing a negative penalty weight.\n",
        "        if penalty_anneal_epochs > 0:\n",
        "            anneal_factor = min(1.0, float(ep) / float(penalty_anneal_epochs))\n",
        "        else:\n",
        "            anneal_factor = 1.0\n",
        "        current_penalty_weight = penalty_weight * anneal_factor\n",
        "\n",
        "        for x, y, _, g in tqdm(train_loader, leave=False, desc=f\"IRMv1 Ep {ep}\"):\n",
        "            x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "\n",
        "            # Extract environment from group\n",
        "            env = g % 2\n",
        "\n",
        "            opt.zero_grad(set_to_none=True)\n",
        "\n",
        "            logits = model(x)\n",
        "\n",
        "            # Compute per-ENVIRONMENT losses and IRM penalties\n",
        "            env_losses = []\n",
        "            env_penalties = []\n",
        "            for e in (0, 1):\n",
        "                mask = (env == e)\n",
        "                if mask.sum() > 1:  # Need at least 2 samples for gradient\n",
        "                    env_logits = logits[mask]\n",
        "                    env_y = y[mask]\n",
        "                    env_losses.append(ce_mean(env_logits, env_y))\n",
        "                    env_penalties.append(irm_penalty(env_logits, env_y))\n",
        "\n",
        "            if len(env_losses) == 2:\n",
        "                mean_loss = sum(env_losses) / len(env_losses)\n",
        "                mean_penalty = sum(env_penalties) / len(env_penalties)\n",
        "                loss = mean_loss + current_penalty_weight * mean_penalty\n",
        "            else:\n",
        "                # An environment is missing or too small in this batch: plain ERM step.\n",
        "                loss = ce_mean(logits, y)\n",
        "\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "\n",
        "            # Track the un-penalized cross-entropy so curves are comparable across methods.\n",
        "            total_loss += ce_mean(logits.detach(), y).item() * x.size(0)\n",
        "            total_n += x.size(0)\n",
        "\n",
        "        tr_loss = total_loss / max(1, total_n)\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst = val_worst\n",
        "            best_model_state = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        # `print_every` of 0/None would otherwise crash on the modulo.\n",
        "        if print_every and ep % print_every == 0:\n",
        "            log_epoch(\"IRMv1\", ep, tr_loss, val_worst, val_overall, t0)\n",
        "\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "print(\"✅ V-REx, MM-REx, IRMv1 defined (paper-faithful with env labels & penalty annealing)\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ueEFNRbzLpu7"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Training Functions (GroupDRO, χ²-DRO, CVaR-DRO) ===\n",
        "# ==========================================\n",
        "\n",
        "def train_groupdro(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "                   alpha=0.2, weight_decay=0.0, print_every=5, group_counts=None):\n",
        "    \"\"\"\n",
        "    GroupDRO training (Sagawa et al., ICLR 2020).\n",
        "\n",
        "    The per-batch objective is whatever LossComputer.loss(logits, y, g, is_training=True)\n",
        "    returns as its first value. When `group_counts` is None, a count of 1 is used for\n",
        "    each of the NUM_GROUPS groups.\n",
        "\n",
        "    Returns:\n",
        "        (train_loss_history, val_worst_history, val_overall_history), one entry per\n",
        "        epoch. The train loss is the sample-weighted mean of the GroupDRO objective.\n",
        "        On return the model holds the weights of the epoch with the best\n",
        "        validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "    counts = [1] * NUM_GROUPS if group_counts is None else group_counts\n",
        "    loss_computer = LossComputer(counts, alpha=alpha, device=device)\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    best_val_worst, best_model_state = -1.0, None\n",
        "    t0 = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        running, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, g in tqdm(train_loader, leave=False, desc=f\"GroupDRO Ep {ep}\"):\n",
        "            x, y, g = x.to(device), y.to(device), g.to(device)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            # Second return value of the loss computer is not needed here.\n",
        "            loss, _ = loss_computer.loss(model(x), y, g, is_training=True)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            batch_n = x.size(0)\n",
        "            running += loss.item() * batch_n\n",
        "            seen += batch_n\n",
        "\n",
        "        tr_loss = running / max(1, seen)\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint on strict improvement of validation worst-group accuracy.\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst, best_model_state = val_worst, copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"GroupDRO\", ep, tr_loss, val_worst, val_overall, t0)\n",
        "\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_chi2_dro(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "                   rho=0.1, weight_decay=0.0, print_every=5):\n",
        "    \"\"\"\n",
        "    χ²-DRO training.\n",
        "\n",
        "    Each batch's per-sample cross-entropy vector is reduced to a scalar with\n",
        "    chi2_dro_loss(ce_vec, rho). Group labels yielded by the loader are not used\n",
        "    for training.\n",
        "\n",
        "    Returns:\n",
        "        (train_loss_history, val_worst_history, val_overall_history), one entry per\n",
        "        epoch. The train loss is the sample-weighted mean of the χ²-DRO objective.\n",
        "        On return the model holds the weights of the epoch with the best\n",
        "        validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    best_val_worst, best_model_state = -1.0, None\n",
        "    t0 = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        running, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, _ in tqdm(train_loader, leave=False, desc=f\"χ²-DRO Ep {ep}\"):\n",
        "            x, y = x.to(device), y.to(device)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            loss = chi2_dro_loss(per_sample_ce(model(x), y), rho)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            batch_n = x.size(0)\n",
        "            running += loss.item() * batch_n\n",
        "            seen += batch_n\n",
        "\n",
        "        tr_loss = running / max(1, seen)\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint on strict improvement of validation worst-group accuracy.\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst, best_model_state = val_worst, copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"χ²-DRO\", ep, tr_loss, val_worst, val_overall, t0)\n",
        "\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "\n",
        "def train_cvar_dro(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "                   alpha=0.1, weight_decay=0.0, print_every=5):\n",
        "    \"\"\"\n",
        "    CVaR-DRO training (Levy et al., NeurIPS 2020).\n",
        "\n",
        "    Each batch's per-sample cross-entropy vector is reduced to a scalar with\n",
        "    cvar_loss_from_batch(ce_vec, alpha). Group labels yielded by the loader are\n",
        "    not used for training.\n",
        "\n",
        "    Returns:\n",
        "        (train_loss_history, val_worst_history, val_overall_history), one entry per\n",
        "        epoch. The train loss is the sample-weighted mean of the CVaR objective.\n",
        "        On return the model holds the weights of the epoch with the best\n",
        "        validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    per_sample_ce = nn.CrossEntropyLoss(reduction='none')\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    best_val_worst, best_model_state = -1.0, None\n",
        "    t0 = time.time()\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        running, seen = 0.0, 0\n",
        "\n",
        "        for x, y, _, _ in tqdm(train_loader, leave=False, desc=f\"CVaR-DRO Ep {ep}\"):\n",
        "            x, y = x.to(device), y.to(device)\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            loss = cvar_loss_from_batch(per_sample_ce(model(x), y), alpha)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            batch_n = x.size(0)\n",
        "            running += loss.item() * batch_n\n",
        "            seen += batch_n\n",
        "\n",
        "        tr_loss = running / max(1, seen)\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint on strict improvement of validation worst-group accuracy.\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst, best_model_state = val_worst, copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if ep % print_every == 0:\n",
        "            log_epoch(\"CVaR-DRO\", ep, tr_loss, val_worst, val_overall, t0)\n",
        "\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UChbuLE-Lpu8",
        "outputId": "4a02b1df-07c9-4643-e8d0-ec72272cdf37"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === KL-RS Training Function ===\n",
        "# ==========================================\n",
        "\n",
        "def train_klrs(model, train_loader, val_loader, *, epochs=50, lr=1e-4,\n",
        "               weight_decay=0.0, base_tau=None, warmup_epochs=3, lam_lo=1e-3, lam_hi_init=1.0,\n",
        "               bisect_tol=1e-3, expand_factor=2.0, max_expand=12,\n",
        "               inner_epochs_probe=1, alt_iters=10, inner_epochs_theta=2,\n",
        "               exp_clamp=50.0, grad_clip_max_norm=None, print_every=5):\n",
        "    \"\"\"\n",
        "    KL-RS (Dual KL Risk-Sensitive) trainer:\n",
        "      • τ is FIXED: provided via base_tau or estimated once from a few batches (1.05× mean CE).\n",
        "      • λ* is found by doubling + bisection; feasibility probes RUN SHORT θ-UPDATES.\n",
        "      • Objective: minimize E[exp((ℓ - τ)/λ)].\n",
        "      • Includes ERM warmup phase before KL-RS constraint.\n",
        "      • Tracks worst-group accuracy for hyperparameter selection.\n",
        "\n",
        "    Phases, in order:\n",
        "      1. τ: float(base_tau) if given, otherwise 1.05 × the mean CE over at most 5 train\n",
        "         batches, measured in eval mode before any training.\n",
        "      2. Warmup: warmup_epochs passes of plain mean-CE training.\n",
        "      3. Alternation (alt_iters rounds): inner_epochs_theta passes on the clamped\n",
        "         exponential objective at the current λ, then a search for the smallest λ whose\n",
        "         measured mean objective is <= 1.0. One history entry is recorded per round.\n",
        "      4. Polishing: max(0, epochs - alt_iters * inner_epochs_theta - warmup_epochs)\n",
        "         extra passes at the final λ, one history entry per pass.\n",
        "\n",
        "    Args:\n",
        "        model: network to train; moved to the global `device`.\n",
        "        train_loader, val_loader: loaders whose batches start with (inputs, labels).\n",
        "        epochs: overall budget; only used to size the polishing phase.\n",
        "        lr, weight_decay: Adam settings for warmup, θ-updates and feasibility probes.\n",
        "        base_tau: fixed τ; None estimates it from the training data.\n",
        "        warmup_epochs: number of ERM warmup passes.\n",
        "        lam_lo: initial lower end of the λ bracket.\n",
        "        lam_hi_init: initial λ (upper end of the first bracket).\n",
        "        bisect_tol: bisection stops once (hi - lo) <= bisect_tol * max(1, hi).\n",
        "        expand_factor, max_expand: multiplier and iteration cap for the expand/contract phases.\n",
        "        inner_epochs_probe: training passes performed by each feasibility probe.\n",
        "        alt_iters, inner_epochs_theta: number of alternation rounds and θ passes per round.\n",
        "        exp_clamp: the exponent (ℓ - τ)/λ is clamped to [-exp_clamp, exp_clamp].\n",
        "        grad_clip_max_norm: if not None, gradient norm is clipped to this value.\n",
        "        print_every: logging period, in rounds (alternation) or passes (polishing).\n",
        "\n",
        "    Returns:\n",
        "        (train_ce_history, val_worst_history, val_overall_history). On return the model\n",
        "        holds the weights with the best recorded validation worst-group accuracy.\n",
        "    \"\"\"\n",
        "    model.to(device)\n",
        "    ce_nored = nn.CrossEntropyLoss(reduction='none')\n",
        "    ce_mean = nn.CrossEntropyLoss(reduction='mean')\n",
        "\n",
        "    best_val_worst = -1.0\n",
        "    best_model_state = None\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _estimate_mean_f(m, loader, tau, lam):\n",
        "        \"\"\"Average, over the batches of `loader`, of mean(exp(clamp((CE - tau) / lam))) with `m` in eval mode.\"\"\"\n",
        "        # NOTE(review): batch means are averaged unweighted, so a smaller final batch\n",
        "        # counts as much as a full one — confirm this is intended.\n",
        "        m.eval()\n",
        "        tot, n_batches = 0.0, 0\n",
        "        for xb, yb, *_ in loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            logits = m(xb)\n",
        "            a = (ce_nored(logits, yb) - tau) / lam\n",
        "            a = torch.clamp(a, min=-exp_clamp, max=exp_clamp)\n",
        "            f = torch.exp(a).mean()\n",
        "            tot += f.item()\n",
        "            n_batches += 1\n",
        "        return tot / max(n_batches, 1)\n",
        "\n",
        "    def _feasibility_train_step(m, loader, opt, tau, lam):\n",
        "        \"\"\"One full pass over `loader`, stepping `opt` on the clamped exponential objective.\"\"\"\n",
        "        m.train()\n",
        "        for xb, yb, *_ in loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            opt.zero_grad(set_to_none=True)\n",
        "            logits = m(xb)\n",
        "            a = (ce_nored(logits, yb) - tau) / lam\n",
        "            # Clamp the exponent before exp() to keep the objective finite.\n",
        "            a = torch.clamp(a, min=-exp_clamp, max=exp_clamp)\n",
        "            f = torch.exp(a).mean()\n",
        "            f.backward()\n",
        "            if grad_clip_max_norm is not None:\n",
        "                torch.nn.utils.clip_grad_norm_(m.parameters(), grad_clip_max_norm)\n",
        "            opt.step()\n",
        "\n",
        "    def _feasibility_oracle(m, loader, tau, lam, inner_epochs, lr_, wd):\n",
        "        \"\"\"Train `m` in place for `inner_epochs` passes with a fresh Adam optimizer, then return (m, mean objective).\"\"\"\n",
        "        m = m.to(device)\n",
        "        opt = torch.optim.Adam(m.parameters(), lr=lr_, weight_decay=wd)\n",
        "        for _ in range(inner_epochs):\n",
        "            _feasibility_train_step(m, loader, opt, tau, lam)\n",
        "        mean_f = _estimate_mean_f(m, loader, tau, lam)\n",
        "        return m, mean_f\n",
        "\n",
        "    def _find_min_feasible_lambda(m, loader, tau, lam_lo_init, lam_hi_init, lr_, wd,\n",
        "                                   inner_epochs_probe_, bisect_tol_, expand_factor_, max_expand_):\n",
        "        \"\"\"\n",
        "        Search for the smallest λ judged feasible (measured mean objective <= 1.0).\n",
        "\n",
        "        Expands the upper end until feasible, then lowers the bracket while the lower\n",
        "        end is still feasible, then bisects. Every probe trains `m` in place.\n",
        "        Returns (m, upper end of the final bracket).\n",
        "        \"\"\"\n",
        "        # NOTE(review): the caller passes the previous λ* as lam_hi_init on later rounds;\n",
        "        # if it has dropped below lam_lo_init the bracket starts inverted — confirm.\n",
        "        lam_lo_, lam_hi_ = lam_lo_init, lam_hi_init\n",
        "        m, mean_f = _feasibility_oracle(m, loader, tau, lam_hi_, inner_epochs_probe_, lr_, wd)\n",
        "        expands = 0\n",
        "        # Expand phase: grow the upper end until it is feasible or the cap is hit.\n",
        "        while mean_f > 1.0 and expands < max_expand_:\n",
        "            lam_lo_ = lam_hi_\n",
        "            lam_hi_ *= expand_factor_\n",
        "            m, mean_f = _feasibility_oracle(m, loader, tau, lam_hi_, inner_epochs_probe_, lr_, wd)\n",
        "            expands += 1\n",
        "        if mean_f > 1.0:\n",
        "            print(f\"[KLRS] Failed to find feasible λ after {max_expand_} expansions (λ={lam_hi_:.2e}, E[f]={mean_f:.3f})\")\n",
        "            return m, lam_hi_\n",
        "\n",
        "        # Contract phase\n",
        "        for _ in range(max_expand_):\n",
        "            m, mean_f_lo = _feasibility_oracle(m, loader, tau, lam_lo_, inner_epochs_probe_, lr_, wd)\n",
        "            if mean_f_lo > 1.0:\n",
        "                break\n",
        "            # Lower end is feasible too: it becomes the new upper end and the lower end shrinks.\n",
        "            lam_hi_ = lam_lo_\n",
        "            lam_lo_ = max(lam_lo_ / expand_factor_, 1e-6)\n",
        "            if lam_lo_ <= 1e-6:\n",
        "                break\n",
        "\n",
        "        # Bisection\n",
        "        while (lam_hi_ - lam_lo_) > bisect_tol_ * max(1.0, lam_hi_):\n",
        "            lam_mid = 0.5 * (lam_lo_ + lam_hi_)\n",
        "            m, mean_f_mid = _feasibility_oracle(m, loader, tau, lam_mid, inner_epochs_probe_, lr_, wd)\n",
        "            if mean_f_mid <= 1.0:\n",
        "                lam_hi_ = lam_mid\n",
        "            else:\n",
        "                lam_lo_ = lam_mid\n",
        "        return m, lam_hi_\n",
        "\n",
        "    # --- τ: fixed ---\n",
        "    if base_tau is None:\n",
        "        with torch.no_grad():\n",
        "            est, n = 0.0, 0\n",
        "            model.eval()\n",
        "            for xb, yb, *_ in train_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                est += ce_mean(model(xb), yb).item()\n",
        "                n += 1\n",
        "                # Only the first 5 batches are used for the estimate.\n",
        "                if n >= 5:\n",
        "                    break\n",
        "            tau = 1.05 * (est / max(n, 1))\n",
        "    else:\n",
        "        tau = float(base_tau)\n",
        "\n",
        "    # === WARMUP PHASE: ERM training to bring losses down before KL-RS constraint ===\n",
        "    if warmup_epochs > 0:\n",
        "        print(f\"[KLRS] Starting {warmup_epochs} epochs of ERM warmup (tau={tau:.4f})...\")\n",
        "        opt_theta = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "        for ep in range(1, warmup_epochs + 1):\n",
        "            model.train()\n",
        "            for xb, yb, *_ in train_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                opt_theta.zero_grad(set_to_none=True)\n",
        "                loss = ce_mean(model(xb), yb)\n",
        "                loss.backward()\n",
        "                if grad_clip_max_norm is not None:\n",
        "                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)\n",
        "                opt_theta.step()\n",
        "\n",
        "    tr_hist, va_worst_hist, va_overall_hist = [], [], []\n",
        "    t0 = time.time()\n",
        "    # A new optimizer is created here, so the warmup optimizer's state is not carried over.\n",
        "    opt_theta = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "    lam = float(lam_hi_init)\n",
        "\n",
        "    # Alternations: (1) improve θ at current λ, (2) find smallest feasible λ\n",
        "    for it in range(alt_iters):\n",
        "        for _ in range(inner_epochs_theta):\n",
        "            _feasibility_train_step(model, train_loader, opt_theta, tau, lam)\n",
        "\n",
        "        # The current λ is passed as the initial upper end of the search bracket.\n",
        "        model, lam = _find_min_feasible_lambda(\n",
        "            model, train_loader, tau, lam_lo, lam, lr, weight_decay,\n",
        "            inner_epochs_probe, bisect_tol, expand_factor, max_expand)\n",
        "\n",
        "        # Evaluate\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            tr_sum, tr_n = 0.0, 0\n",
        "            for xb, yb, *_ in train_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                tr_sum += ce_mean(model(xb), yb).item() * xb.size(0)\n",
        "                tr_n += xb.size(0)\n",
        "            tr_loss = tr_sum / max(1, tr_n)\n",
        "\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        # Checkpoint on strict improvement of validation worst-group accuracy.\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst = val_worst\n",
        "            best_model_state = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if (it + 1) % print_every == 0:\n",
        "            elapsed = time.time() - t0\n",
        "            print(f\"[KLRS-alt][it {it+1:03d}] trainCE={tr_loss:.4f} valWorst={val_worst:.4f} valOvr={val_overall:.4f}  λ*={lam:.3e} τ={tau:.4f}  [{elapsed:.1f}s]\")\n",
        "\n",
        "    # Final polishing epochs at fixed λ*\n",
        "    # Budget left after warmup and the θ passes of the alternation rounds\n",
        "    # (passes spent inside feasibility probes are not counted).\n",
        "    remaining = max(0, epochs - alt_iters * inner_epochs_theta - warmup_epochs)\n",
        "    for ep in range(remaining):\n",
        "        _feasibility_train_step(model, train_loader, opt_theta, tau, lam)\n",
        "\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            tr_sum, tr_n = 0.0, 0\n",
        "            for xb, yb, *_ in train_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                tr_sum += ce_mean(model(xb), yb).item() * xb.size(0)\n",
        "                tr_n += xb.size(0)\n",
        "            tr_loss = tr_sum / max(1, tr_n)\n",
        "\n",
        "        _, val_worst = eval_metrics(model, val_loader)\n",
        "        val_overall = eval_overall_accuracy(model, val_loader)\n",
        "\n",
        "        tr_hist.append(tr_loss)\n",
        "        va_worst_hist.append(val_worst)\n",
        "        va_overall_hist.append(val_overall)\n",
        "\n",
        "        if val_worst > best_val_worst:\n",
        "            best_val_worst = val_worst\n",
        "            best_model_state = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        if (ep + 1) % print_every == 0:\n",
        "            elapsed = time.time() - t0\n",
        "            print(f\"[KLRS][ep {ep+1:03d}] trainCE={tr_loss:.4f} valWorst={val_worst:.4f} valOvr={val_overall:.4f}  λ={lam:.3e} τ={tau:.4f}  [{elapsed:.1f}s]\")\n",
        "\n",
        "    # Restore the best checkpoint seen during alternation/polishing, if any was recorded.\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    return tr_hist, va_worst_hist, va_overall_hist\n",
        "\n",
        "print(\"✅ train_klrs defined\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h4VkNJyNLpu8",
        "outputId": "80a1623e-bed0-4716-8810-d6275621592c"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Hyperparameter Grids ===\n",
        "# ==========================================\n",
        "\n",
        "import itertools\n",
        "\n",
        "# --- Shared across all algorithms ---\n",
        "# Learning rates swept for every method (4 values).\n",
        "LR_GRID = [1e-5, 1e-4, 1e-3, 1e-2]\n",
        "\n",
        "# --- Method-specific settings ---\n",
        "# SAM: single fixed rho (SAM is swept over LR only).\n",
        "SAM_RHO = 0.05\n",
        "\n",
        "# KL-RS: number of alternation rounds (3 values).\n",
        "KLRS_ALT_ITERS_GRID = [2, 5, 10]\n",
        "\n",
        "# IRMv1: penalty weights swept around the paper default of 10000.\n",
        "IRM_PENALTY_GRID = [100, 1_000, 10_000, 100_000]\n",
        "IRM_PENALTY_ANNEAL_EPOCHS = 20  # paper default\n",
        "\n",
        "# REx (paper-faithful): V-REx β (5 values) and MM-REx λ (4 values; λ > 1\n",
        "# extrapolates beyond worst-case), sharing one annealing length.\n",
        "VREX_BETA_GRID = [0.1, 0.5, 1.0, 5.0, 10.0]\n",
        "MMREX_LAMBDA_GRID = [1.0, 1.5, 2.0, 3.0]\n",
        "REX_PENALTY_ANNEAL_EPOCHS = 10  # paper recommendation\n",
        "\n",
        "# GroupDRO: alpha values (3 values).\n",
        "GROUPDRO_ALPHA_GRID = [0.1, 0.2, 0.5]\n",
        "\n",
        "# χ²-DRO: rho values (5 values).\n",
        "CHI2_RHO_GRID = [0.01, 0.1, 0.5, 1.0, 2.0]\n",
        "\n",
        "# CVaR-DRO (Levy et al., NeurIPS 2020): alpha values (5 values).\n",
        "CVAR_ALPHA_GRID = [0.05, 0.1, 0.2, 0.3, 0.5]\n",
        "\n",
        "# --- Epoch budgets ---\n",
        "VAL_SWEEP_EPOCHS = 30   # per run during hyperparameter sweeps\n",
        "FULL_TRAIN_EPOCHS = 50  # for the final full training runs\n",
        "\n",
        "# τ shared by IRS and KL-RS.\n",
        "RS_SHARED_TAU = 0.1\n",
        "\n",
        "# Emit the whole summary with a single print; the text is identical to\n",
        "# printing each entry on its own line.\n",
        "print(\"\\n\".join([\n",
        "    \"=\" * 60,\n",
        "    \"Hyperparameter grids defined (paper-faithful).\",\n",
        "    \"=\" * 60,\n",
        "    f\"LR Grid: {LR_GRID}\",\n",
        "    f\"KL-RS alt_iters Grid: {KLRS_ALT_ITERS_GRID}\",\n",
        "    f\"IRM Penalty Grid: {IRM_PENALTY_GRID} (anneal_epochs={IRM_PENALTY_ANNEAL_EPOCHS})\",\n",
        "    f\"V-REx Beta Grid: {VREX_BETA_GRID} (anneal_epochs={REX_PENALTY_ANNEAL_EPOCHS})\",\n",
        "    f\"MM-REx Lambda Grid: {MMREX_LAMBDA_GRID} (anneal_epochs={REX_PENALTY_ANNEAL_EPOCHS})\",\n",
        "    f\"GroupDRO Alpha Grid: {GROUPDRO_ALPHA_GRID}\",\n",
        "    f\"χ²-DRO Rho Grid: {CHI2_RHO_GRID}\",\n",
        "    f\"CVaR-DRO Alpha Grid: {CVAR_ALPHA_GRID}\",\n",
        "    \"=\" * 60,\n",
        "    f\"Sweep Epochs: {VAL_SWEEP_EPOCHS}\",\n",
        "    f\"Full Training Epochs: {FULL_TRAIN_EPOCHS}\",\n",
        "    \"=\" * 60,\n",
        "    \"\\nGrid sizes (LR × Hyperparam):\",\n",
        "    f\"  ERM:      {len(LR_GRID)} runs (LR only)\",\n",
        "    f\"  SAM:      {len(LR_GRID)} runs (LR only)\",\n",
        "    f\"  IRS:      {len(LR_GRID)} runs (LR only)\",\n",
        "    f\"  KL-RS:    {len(LR_GRID)} × {len(KLRS_ALT_ITERS_GRID)} = {len(LR_GRID) * len(KLRS_ALT_ITERS_GRID)} runs\",\n",
        "    f\"  V-REx:    {len(LR_GRID)} × {len(VREX_BETA_GRID)} = {len(LR_GRID) * len(VREX_BETA_GRID)} runs\",\n",
        "    f\"  MM-REx:   {len(LR_GRID)} × {len(MMREX_LAMBDA_GRID)} = {len(LR_GRID) * len(MMREX_LAMBDA_GRID)} runs\",\n",
        "    f\"  IRMv1:    {len(LR_GRID)} × {len(IRM_PENALTY_GRID)} = {len(LR_GRID) * len(IRM_PENALTY_GRID)} runs\",\n",
        "    f\"  GroupDRO: {len(LR_GRID)} × {len(GROUPDRO_ALPHA_GRID)} = {len(LR_GRID) * len(GROUPDRO_ALPHA_GRID)} runs\",\n",
        "    f\"  χ²-DRO:   {len(LR_GRID)} × {len(CHI2_RHO_GRID)} = {len(LR_GRID) * len(CHI2_RHO_GRID)} runs\",\n",
        "    f\"  CVaR-DRO: {len(LR_GRID)} × {len(CVAR_ALPHA_GRID)} = {len(LR_GRID) * len(CVAR_ALPHA_GRID)} runs\",\n",
        "]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W1vzeysXLpu9",
        "outputId": "b835bd5c-7629-4852-f487-d36d76440f61"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === Hyperparameter Sweep Utilities ===\n",
        "# ==========================================\n",
        "\n",
        "def lr_sweep(algo_name: str, lr_list, train_fn, train_kwargs: dict,\n",
        "             train_epochs: int = None, seed: int = SEED, print_every_sweep: int = 10):\n",
        "    \"\"\"\n",
        "    Sweep over learning rates and select the best one by worst-group validation accuracy.\n",
        "\n",
        "    Every learning rate is trained on a new model (seeds reset via reset_seeds,\n",
        "    model from fresh_model) and scored by the LAST entry of the worst-group\n",
        "    validation history returned by `train_fn`.\n",
        "\n",
        "    Args:\n",
        "        algo_name: Label used in the log lines.\n",
        "        lr_list: Sized collection of learning rates to try.\n",
        "        train_fn: Called as train_fn(model, train_loader, val_loader, epochs=..., lr=..., **kwargs).\n",
        "            Its return value must be a sequence whose second element is the\n",
        "            worst-group validation accuracy history.\n",
        "        train_kwargs: Extra keyword arguments for `train_fn` (copied, never mutated).\n",
        "        train_epochs: Epochs per run; VAL_SWEEP_EPOCHS is used when None.\n",
        "        seed: Seed handed to reset_seeds() before every run.\n",
        "        print_every_sweep: Forwarded to `train_fn` as `print_every`.\n",
        "\n",
        "    Returns:\n",
        "        (best_lr, results). best_lr is None when no run produced a finite score;\n",
        "        results lists one dict with keys 'lr' and 'val_worst_acc' per successful run.\n",
        "    \"\"\"\n",
        "    if train_epochs is None:\n",
        "        train_epochs = VAL_SWEEP_EPOCHS  # Use global default\n",
        "\n",
        "    results = []\n",
        "    print(f\"[{algo_name}] Starting LR sweep: {len(lr_list)} configurations, {train_epochs} epochs each\")\n",
        "\n",
        "    for lr in lr_list:\n",
        "        reset_seeds(seed)\n",
        "        cuda_clear_cache()\n",
        "        model = fresh_model()\n",
        "        local_kwargs = dict(train_kwargs)\n",
        "        local_kwargs[\"print_every\"] = print_every_sweep\n",
        "\n",
        "        try:\n",
        "            out = train_fn(model, loaders[\"train\"], loaders[\"val\"], epochs=train_epochs, lr=lr, **local_kwargs)\n",
        "        except Exception as e:\n",
        "            # A failed run is logged and skipped so the remaining learning rates still get tried.\n",
        "            print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ {e}\")\n",
        "        else:\n",
        "            # Worst-group accuracy history is the second element of train_fn's return value.\n",
        "            va_worst = out[1] if len(out) >= 2 else []\n",
        "            # len() instead of truthiness so array-like histories do not raise.\n",
        "            val_worst_final = float(va_worst[-1]) if len(va_worst) > 0 else float(\"nan\")\n",
        "\n",
        "            if np.isfinite(val_worst_final):\n",
        "                results.append({\"lr\": lr, \"val_worst_acc\": val_worst_final})\n",
        "                print(f\"[SWEEP:{algo_name}] lr={lr:g}  valWorst(final)={val_worst_final:.4f}\")\n",
        "            else:\n",
        "                print(f\"[SWEEP:{algo_name}] lr={lr:g}  ❌ valWorst is NaN/Inf\")\n",
        "\n",
        "        # Release the model on every path (success, exception, non-finite score),\n",
        "        # the same way hyperparam_sweep does.\n",
        "        del model\n",
        "        cuda_clear_cache()\n",
        "\n",
        "    best = max(results, key=lambda r: r[\"val_worst_acc\"]) if results else {\"lr\": None, \"val_worst_acc\": float(\"nan\")}\n",
        "    print(f\"[SWEEP:{algo_name}] ★ best_lr={best['lr']}  best_valWorst={best['val_worst_acc']:.4f}\")\n",
        "    return best[\"lr\"], results\n",
        "\n",
        "\n",
        "def hyperparam_sweep(algo_name: str, config_grid, train_fn, base_train_kwargs: dict,\n",
        "                     train_epochs: int = None, seed: int = SEED, print_every_sweep: int = 10):\n",
        "    \"\"\"\n",
        "    Evaluate every configuration in `config_grid` and return the best one.\n",
        "\n",
        "    Each config dict is merged over `base_train_kwargs` and handed to `train_fn`\n",
        "    together with a new model (reset_seeds + fresh_model). Configurations are\n",
        "    ranked by the last entry of the worst-group validation history, i.e. the\n",
        "    second element of `train_fn`'s return value. An 'epochs' key in the merged\n",
        "    kwargs overrides `train_epochs`, which itself defaults to VAL_SWEEP_EPOCHS.\n",
        "\n",
        "    Returns:\n",
        "        (best_config, results), or (None, []) when no configuration produced a\n",
        "        finite score. Each results entry is the config plus 'val_worst_acc'.\n",
        "    \"\"\"\n",
        "    sweep_epochs = VAL_SWEEP_EPOCHS if train_epochs is None else train_epochs\n",
        "    n_configs = len(config_grid)\n",
        "\n",
        "    results = []\n",
        "    print(f\"[{algo_name}] Starting hyperparam sweep: {n_configs} configurations, {sweep_epochs} epochs each\")\n",
        "\n",
        "    for idx, config in enumerate(config_grid, start=1):\n",
        "        reset_seeds(seed)\n",
        "        cuda_clear_cache()\n",
        "        model = fresh_model()\n",
        "\n",
        "        # Config values win over the base kwargs; print_every is always forced last.\n",
        "        call_kwargs = {**base_train_kwargs, **config, \"print_every\": print_every_sweep}\n",
        "        epochs = call_kwargs.pop(\"epochs\", sweep_epochs)\n",
        "\n",
        "        config_str = \", \".join(f\"{k}={v}\" for k, v in config.items())\n",
        "        tag = f\"[SWEEP:{algo_name}] [{idx}/{n_configs}] {config_str}\"\n",
        "\n",
        "        try:\n",
        "            out = train_fn(model, loaders[\"train\"], loaders[\"val\"], epochs=epochs, **call_kwargs)\n",
        "        except Exception as e:\n",
        "            print(f\"{tag}  ❌ {e}\")\n",
        "        else:\n",
        "            va_worst = out[1] if len(out) >= 2 else []\n",
        "            val_worst_final = float(va_worst[-1]) if va_worst else float(\"nan\")\n",
        "\n",
        "            if np.isfinite(val_worst_final):\n",
        "                results.append({**config, \"val_worst_acc\": val_worst_final})\n",
        "                print(f\"{tag}  valWorst={val_worst_final:.4f}\")\n",
        "            else:\n",
        "                print(f\"{tag}  ❌ valWorst is NaN/Inf\")\n",
        "\n",
        "        # Every outcome ends by dropping the model and calling cuda_clear_cache().\n",
        "        del model\n",
        "        cuda_clear_cache()\n",
        "\n",
        "    if not results:\n",
        "        print(f\"[SWEEP:{algo_name}] No valid results!\")\n",
        "        return None, []\n",
        "\n",
        "    best = max(results, key=lambda entry: entry[\"val_worst_acc\"])\n",
        "    # The winning config is the result entry without its score.\n",
        "    best_config = dict(best)\n",
        "    best_config.pop(\"val_worst_acc\")\n",
        "    best_str = \", \".join(f\"{k}={v}\" for k, v in best_config.items())\n",
        "    print(f\"[SWEEP:{algo_name}] ★ BEST: {best_str}  valWorst={best['val_worst_acc']:.4f}\")\n",
        "    return best_config, results\n",
        "\n",
        "print(f\"✅ Sweep utilities defined (default sweep epochs: {VAL_SWEEP_EPOCHS})\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BCccM9JILpu9",
        "outputId": "8f109764-7ff8-4bfc-d4ea-48f8aaddd53c"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === RUN HYPERPARAMETER SWEEPS ===\n",
        "# ==========================================\n",
        "# Only runs algorithms where RUN_SWEEP[algo] = True\n",
        "# Grid search uses Cartesian product: LR × other_params (e.g., 4×3 = 12 runs)\n",
        "\n",
        "# Storage for best hyperparameters (with paper-faithful defaults)\n",
        "# Updated from completed sweeps\n",
        "best_hyperparams = {\n",
        "    \"ERM-SGD\": {\"lr\": 0.01},   # ✅ From sweep: valWorst=0.6391\n",
        "    \"ERM\": {\"lr\": 1e-5},       # ✅ From sweep: valWorst=0.5789 (Adam)\n",
        "    \"SAM\": {\"lr\": 1e-5, \"rho\": SAM_RHO},  # ✅ From sweep: valWorst=0.5489\n",
        "    \"IRS-Group\": {\"lr\": 1e-4},  # ✅ From sweep: valWorst=0.8421\n",
        "    \"KL-RS\": {\"lr\": 1e-4, \"alt_iters\": 2, \"base_tau\": RS_SHARED_TAU},  # ✅ From sweep: valWorst=0.6502\n",
        "    \"V-REx\": {\"lr\": 1e-4, \"rex_beta\": 10.0},     # ✅ From sweep: valWorst=0.6316\n",
        "    \"MM-REx\": {\"lr\": 1e-4, \"rex_lambda\": 1.0},   # ✅ From sweep: valWorst=0.6015\n",
        "    \"IRMv1\": {\"lr\": 1e-4, \"penalty_weight\": 100.0},  # ✅ From sweep: valWorst=0.6015\n",
        "    \"GroupDRO\": {\"lr\": 1e-4, \"alpha\": 0.2},      # ✅ From sweep: valWorst=0.7218\n",
        "    \"χ²-DRO\": {\"lr\": 1e-4, \"rho\": 1.0},          # ✅ From sweep: valWorst=0.6692\n",
        "    \"CVaR-DRO\": {\"lr\": 1e-4, \"alpha\": 0.1},      # Default - will be swept\n",
        "}\n",
        "\n",
        "print(\"=\"*60)\n",
        "print(\"SWEEP CONFIGURATION\")\n",
        "print(\"=\"*60)\n",
        "print(f\"Sweep epochs: {VAL_SWEEP_EPOCHS}\")\n",
        "print(f\"Full training epochs: {FULL_TRAIN_EPOCHS}\")\n",
        "print(f\"LR grid: {LR_GRID}\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# --- ERM-SGD LR Sweep ---\n",
        "# For the LR sweeps in this section: lr_sweep returns None when no run produced\n",
        "# a finite score. In that case the value already stored in best_hyperparams is\n",
        "# kept (the same policy the hyperparameter sweeps further down follow) instead\n",
        "# of being overwritten with an unrelated hard-coded learning rate.\n",
        "if should_sweep(\"ERM-SGD\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ LR sweep: ERM-SGD ({len(LR_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    best_erm_sgd_lr, _ = lr_sweep(\"ERM-SGD\", LR_GRID, train_erm_sgd, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_erm_sgd_lr is not None:\n",
        "        best_hyperparams[\"ERM-SGD\"][\"lr\"] = best_erm_sgd_lr\n",
        "\n",
        "# --- ERM LR Sweep ---\n",
        "if should_sweep(\"ERM\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ LR sweep: ERM ({len(LR_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    best_erm_lr, _ = lr_sweep(\"ERM\", LR_GRID, train_erm, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_erm_lr is not None:\n",
        "        best_hyperparams[\"ERM\"][\"lr\"] = best_erm_lr\n",
        "\n",
        "# --- SAM LR Sweep ---\n",
        "if should_sweep(\"SAM\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ LR sweep: SAM ({len(LR_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    best_sam_lr, _ = lr_sweep(\"SAM\", LR_GRID, train_sam, dict(rho=SAM_RHO, weight_decay=WEIGHT_DECAY))\n",
        "    if best_sam_lr is not None:\n",
        "        best_hyperparams[\"SAM\"][\"lr\"] = best_sam_lr\n",
        "\n",
        "# --- IRS-Group LR Sweep ---\n",
        "if should_sweep(\"IRS-Group\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ LR sweep: IRS-Group ({len(LR_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "\n",
        "    # Custom wrapper to adapt train_algorithm to lr_sweep interface\n",
        "    def train_irs_for_sweep(model, train_loader, val_loader, *, epochs, lr, weight_decay, print_every=10):\n",
        "        \"\"\"\n",
        "        Wrapper to use train_algorithm for sweeps.\n",
        "        train_algorithm: (algo_name, model, loaders, P_global_group, epochs, lr, warmup_epochs=0, alpha=0.2) -> model\n",
        "        lr_sweep expects: (model, train_loader, val_loader, epochs, lr, ...) -> (tr_hist, val_worst_hist, val_overall_hist)\n",
        "\n",
        "        `weight_decay` and `print_every` are accepted only because lr_sweep passes\n",
        "        them; they are NOT forwarded to train_algorithm.\n",
        "        \"\"\"\n",
        "        # Call train_algorithm with a loaders dict rebuilt from the two loaders given here\n",
        "        trained_model = train_algorithm(\n",
        "            algo_name=\"IRS-Group\",\n",
        "            model=model,\n",
        "            loaders={\"train\": train_loader, \"val\": val_loader},\n",
        "            P_global_group=P_global_group,  # Use global variable\n",
        "            epochs=epochs,\n",
        "            lr=lr,\n",
        "            warmup_epochs=0,  # No warmup for sweep\n",
        "            alpha=0.2  # Not used for IRS-Group\n",
        "        )\n",
        "\n",
        "        # train_algorithm returns only a model, so score it here; the loss is not needed.\n",
        "        _, val_worst = eval_metrics(trained_model, val_loader)\n",
        "\n",
        "        # Return format expected by lr_sweep: (tr_hist, val_worst_hist, val_overall_hist).\n",
        "        # The worst-group history has a single entry because no per-epoch history is available.\n",
        "        return [], [val_worst], []\n",
        "\n",
        "    best_irs_lr, _ = lr_sweep(\"IRS-Group\", LR_GRID, train_irs_for_sweep, dict(weight_decay=WEIGHT_DECAY))\n",
        "    # Keep the stored learning rate when the sweep produced no finite score,\n",
        "    # rather than replacing it with a different hard-coded value.\n",
        "    if best_irs_lr is not None:\n",
        "        best_hyperparams[\"IRS-Group\"][\"lr\"] = best_irs_lr\n",
        "\n",
        "# --- KL-RS Sweep (LR × alt_iters) ---\n",
        "if should_sweep(\"KL-RS\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: KL-RS (LR × alt_iters = {len(LR_GRID)} × {len(KLRS_ALT_ITERS_GRID)} = {len(LR_GRID) * len(KLRS_ALT_ITERS_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "\n",
        "    def train_klrs_for_sweep(model, train_loader, val_loader, *, epochs, lr, alt_iters, weight_decay, print_every=10):\n",
        "        \"\"\"\n",
        "        Call train_klrs with inner_epochs_theta derived from the epoch budget.\n",
        "\n",
        "        The epochs left after a fixed 3-epoch warmup are floor-divided by\n",
        "        `alt_iters` (never below 1); base_tau is always RS_SHARED_TAU.\n",
        "        \"\"\"\n",
        "        warmup_epochs = 3\n",
        "        theta_epochs = (epochs - warmup_epochs) // alt_iters\n",
        "        if theta_epochs < 1:\n",
        "            theta_epochs = 1\n",
        "        return train_klrs(\n",
        "            model,\n",
        "            train_loader,\n",
        "            val_loader,\n",
        "            epochs=epochs,\n",
        "            lr=lr,\n",
        "            alt_iters=alt_iters,\n",
        "            inner_epochs_theta=theta_epochs,\n",
        "            warmup_epochs=warmup_epochs,\n",
        "            base_tau=RS_SHARED_TAU,\n",
        "            weight_decay=weight_decay,\n",
        "            print_every=print_every,\n",
        "        )\n",
        "\n",
        "    # Cartesian product of learning rates and alt_iters values\n",
        "    klrs_config_grid = [\n",
        "        dict(lr=lr, alt_iters=ai)\n",
        "        for lr, ai in itertools.product(LR_GRID, KLRS_ALT_ITERS_GRID)\n",
        "    ]\n",
        "    best_klrs_config, _ = hyperparam_sweep(\n",
        "        \"KL-RS\", klrs_config_grid, train_klrs_for_sweep, dict(weight_decay=WEIGHT_DECAY)\n",
        "    )\n",
        "    if best_klrs_config:\n",
        "        best_hyperparams[\"KL-RS\"].update(best_klrs_config)\n",
        "\n",
        "# --- V-REx Sweep (LR × beta) ---\n",
        "if should_sweep(\"V-REx\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: V-REx (LR × Beta = {len(LR_GRID)} × {len(VREX_BETA_GRID)} = {len(LR_GRID) * len(VREX_BETA_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    vrex_config_grid = [\n",
        "        {\"lr\": lr, \"rex_beta\": beta}\n",
        "        for lr, beta in itertools.product(LR_GRID, VREX_BETA_GRID)\n",
        "    ]\n",
        "    best_vrex_config, _ = hyperparam_sweep(\"V-REx\", vrex_config_grid, train_vrex, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_vrex_config:\n",
        "        best_hyperparams[\"V-REx\"].update(best_vrex_config)\n",
        "\n",
        "# --- MM-REx Sweep (LR × lambda) ---\n",
        "if should_sweep(\"MM-REx\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: MM-REx (LR × Lambda = {len(LR_GRID)} × {len(MMREX_LAMBDA_GRID)} = {len(LR_GRID) * len(MMREX_LAMBDA_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    mmrex_config_grid = [\n",
        "        {\"lr\": lr, \"rex_lambda\": lam}\n",
        "        for lr, lam in itertools.product(LR_GRID, MMREX_LAMBDA_GRID)\n",
        "    ]\n",
        "    best_mmrex_config, _ = hyperparam_sweep(\"MM-REx\", mmrex_config_grid, train_mmrex, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_mmrex_config:\n",
        "        best_hyperparams[\"MM-REx\"].update(best_mmrex_config)\n",
        "\n",
        "# --- IRMv1 Sweep (LR × penalty) ---\n",
        "if should_sweep(\"IRMv1\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: IRMv1 (LR × Penalty = {len(LR_GRID)} × {len(IRM_PENALTY_GRID)} = {len(LR_GRID) * len(IRM_PENALTY_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    irm_config_grid = [\n",
        "        {\"lr\": lr, \"penalty_weight\": pw}\n",
        "        for lr, pw in itertools.product(LR_GRID, IRM_PENALTY_GRID)\n",
        "    ]\n",
        "    best_irm_config, _ = hyperparam_sweep(\"IRMv1\", irm_config_grid, train_irm, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_irm_config:\n",
        "        best_hyperparams[\"IRMv1\"].update(best_irm_config)\n",
        "\n",
        "# --- GroupDRO Sweep (LR × alpha) ---\n",
        "if should_sweep(\"GroupDRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: GroupDRO (LR × Alpha = {len(LR_GRID)} × {len(GROUPDRO_ALPHA_GRID)} = {len(LR_GRID) * len(GROUPDRO_ALPHA_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    groupdro_config_grid = [\n",
        "        {\"lr\": lr, \"alpha\": alpha}\n",
        "        for lr, alpha in itertools.product(LR_GRID, GROUPDRO_ALPHA_GRID)\n",
        "    ]\n",
        "    best_groupdro_config, _ = hyperparam_sweep(\"GroupDRO\", groupdro_config_grid, train_groupdro,\n",
        "                                                dict(weight_decay=WEIGHT_DECAY, group_counts=group_counts))\n",
        "    if best_groupdro_config:\n",
        "        best_hyperparams[\"GroupDRO\"].update(best_groupdro_config)\n",
        "\n",
        "# --- χ²-DRO Sweep (LR × rho) ---\n",
        "if should_sweep(\"χ²-DRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: χ²-DRO (LR × Rho = {len(LR_GRID)} × {len(CHI2_RHO_GRID)} = {len(LR_GRID) * len(CHI2_RHO_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    chi2_config_grid = [\n",
        "        {\"lr\": lr, \"rho\": rho}\n",
        "        for lr, rho in itertools.product(LR_GRID, CHI2_RHO_GRID)\n",
        "    ]\n",
        "    best_chi2_config, _ = hyperparam_sweep(\"χ²-DRO\", chi2_config_grid, train_chi2_dro, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_chi2_config:\n",
        "        best_hyperparams[\"χ²-DRO\"].update(best_chi2_config)\n",
        "\n",
        "# --- CVaR-DRO Sweep (LR × alpha) ---\n",
        "if should_sweep(\"CVaR-DRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Hyperparam sweep: CVaR-DRO (LR × Alpha = {len(LR_GRID)} × {len(CVAR_ALPHA_GRID)} = {len(LR_GRID) * len(CVAR_ALPHA_GRID)} runs)\")\n",
        "    print(\"=\"*60)\n",
        "    cvar_config_grid = [\n",
        "        {\"lr\": lr, \"alpha\": alpha}\n",
        "        for lr, alpha in itertools.product(LR_GRID, CVAR_ALPHA_GRID)\n",
        "    ]\n",
        "    best_cvar_config, _ = hyperparam_sweep(\"CVaR-DRO\", cvar_config_grid, train_cvar_dro, dict(weight_decay=WEIGHT_DECAY))\n",
        "    if best_cvar_config:\n",
        "        best_hyperparams[\"CVaR-DRO\"].update(best_cvar_config)\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"✅ Hyperparameter sweeps completed!\")\n",
        "print(\"=\"*60)\n",
        "print(\"\\nBest hyperparameters:\")\n",
        "for algo, params in best_hyperparams.items():\n",
        "    print(f\"  {algo}: {params}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HDetXkQBLpu9",
        "outputId": "b2dc951d-f2b0-4b36-875c-bdcbbd10b43e"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === RUN FULL TRAINING ===\n",
        "# ==========================================\n",
        "# Only runs algorithms where RUN_FULL_TRAINING[algo] = True\n",
        "# IRS uses the ORIGINAL working code (train_algorithm function)\n",
        "# Full training uses FULL_TRAIN_EPOCHS (50 epochs)\n",
        "\n",
        "# Storage for trained models and results\n",
        "trained_models = {}\n",
        "results = {}\n",
        "wallclock_times = {}  # Track wallclock time for each algorithm\n",
        "\n",
        "print(\"=\"*60)\n",
        "print(\"FULL TRAINING CONFIGURATION\")\n",
        "print(\"=\"*60)\n",
        "print(f\"Full training epochs: {FULL_TRAIN_EPOCHS}\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# --- ERM-SGD Full Training ---\n",
        "if should_train(\"ERM-SGD\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: ERM-SGD ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_erm_sgd = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_erm_sgd, va_worst_erm_sgd, va_overall_erm_sgd = train_erm_sgd(\n",
        "        model_erm_sgd, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"ERM-SGD\"][\"lr\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"ERM-SGD\"] = time.time() - t_start\n",
        "    trained_models[\"ERM-SGD\"] = model_erm_sgd\n",
        "    results[\"ERM-SGD\"] = {\"tr\": tr_erm_sgd, \"va_worst\": va_worst_erm_sgd, \"va_overall\": va_overall_erm_sgd}\n",
        "\n",
        "# --- ERM Full Training ---\n",
        "if should_train(\"ERM\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: ERM ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_erm = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_erm, va_worst_erm, va_overall_erm = train_erm(\n",
        "        model_erm, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"ERM\"][\"lr\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"ERM\"] = time.time() - t_start\n",
        "    trained_models[\"ERM\"] = model_erm\n",
        "    results[\"ERM\"] = {\"tr\": tr_erm, \"va_worst\": va_worst_erm, \"va_overall\": va_overall_erm}\n",
        "\n",
        "# --- SAM Full Training ---\n",
        "# rho is read from best_hyperparams (falling back to SAM_RHO) so the run uses\n",
        "# the same value that is stored and printed for SAM.\n",
        "if should_train(\"SAM\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: SAM ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_sam = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_sam, va_worst_sam, va_overall_sam = train_sam(\n",
        "        model_sam, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"SAM\"][\"lr\"],\n",
        "        rho=best_hyperparams[\"SAM\"].get(\"rho\", SAM_RHO), weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"SAM\"] = time.time() - t_start\n",
        "    trained_models[\"SAM\"] = model_sam\n",
        "    results[\"SAM\"] = {\"tr\": tr_sam, \"va_worst\": va_worst_sam, \"va_overall\": va_overall_sam}\n",
        "\n",
        "# --- IRS-Group Full Training (USES ORIGINAL WORKING CODE!) ---\n",
        "if should_train(\"IRS-Group\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: IRS-Group ({FULL_TRAIN_EPOCHS} epochs) [ORIGINAL WORKING CODE]\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_irs = fresh_model()\n",
        "    t_start = time.time()\n",
        "    # Uses the ORIGINAL train_algorithm function from the working IRS code!\n",
        "    best_model_irs = train_algorithm(\n",
        "        \"IRS-Group\", model_irs, loaders, P_global_group,\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"IRS-Group\"][\"lr\"], warmup_epochs=WARMUP_EPOCHS\n",
        "    )\n",
        "    wallclock_times[\"IRS-Group\"] = time.time() - t_start\n",
        "    trained_models[\"IRS-Group\"] = best_model_irs\n",
        "    results[\"IRS-Group\"] = {\"tr\": [], \"va_worst\": [], \"va_overall\": []}  # Logged during training\n",
        "\n",
        "# --- KL-RS Full Training ---\n",
        "# alt_iters and base_tau both come from best_hyperparams; RS_SHARED_TAU is only\n",
        "# the fallback when no base_tau is stored.\n",
        "if should_train(\"KL-RS\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: KL-RS ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_klrs = fresh_model()\n",
        "    t_start = time.time()\n",
        "    best_alt_iters = best_hyperparams[\"KL-RS\"].get(\"alt_iters\", 10)\n",
        "    warmup = 3\n",
        "    # Same split as in the sweep wrapper: epochs after warmup, floor-divided by alt_iters, at least 1.\n",
        "    inner_epochs_theta = max(1, (FULL_TRAIN_EPOCHS - warmup) // best_alt_iters)\n",
        "    tr_klrs, va_worst_klrs, va_overall_klrs = train_klrs(\n",
        "        model_klrs, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"KL-RS\"][\"lr\"],\n",
        "        alt_iters=best_alt_iters,\n",
        "        inner_epochs_theta=inner_epochs_theta,\n",
        "        warmup_epochs=warmup,\n",
        "        base_tau=best_hyperparams[\"KL-RS\"].get(\"base_tau\", RS_SHARED_TAU),\n",
        "        weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"KL-RS\"] = time.time() - t_start\n",
        "    trained_models[\"KL-RS\"] = model_klrs\n",
        "    results[\"KL-RS\"] = {\"tr\": tr_klrs, \"va_worst\": va_worst_klrs, \"va_overall\": va_overall_klrs}\n",
        "\n",
        "# --- V-REx Full Training ---\n",
        "if should_train(\"V-REx\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: V-REx ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_vrex = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_vrex, va_worst_vrex, va_overall_vrex = train_vrex(\n",
        "        model_vrex, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"V-REx\"][\"lr\"],\n",
        "        rex_beta=best_hyperparams[\"V-REx\"][\"rex_beta\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"V-REx\"] = time.time() - t_start\n",
        "    trained_models[\"V-REx\"] = model_vrex\n",
        "    results[\"V-REx\"] = {\"tr\": tr_vrex, \"va_worst\": va_worst_vrex, \"va_overall\": va_overall_vrex}\n",
        "\n",
        "# --- MM-REx Full Training ---\n",
        "if should_train(\"MM-REx\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: MM-REx ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_mmrex = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_mmrex, va_worst_mmrex, va_overall_mmrex = train_mmrex(\n",
        "        model_mmrex, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"MM-REx\"][\"lr\"],\n",
        "        rex_lambda=best_hyperparams[\"MM-REx\"][\"rex_lambda\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"MM-REx\"] = time.time() - t_start\n",
        "    trained_models[\"MM-REx\"] = model_mmrex\n",
        "    results[\"MM-REx\"] = {\"tr\": tr_mmrex, \"va_worst\": va_worst_mmrex, \"va_overall\": va_overall_mmrex}\n",
        "\n",
        "# --- IRMv1 Full Training ---\n",
        "if should_train(\"IRMv1\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: IRMv1 ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_irm = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_irm, va_worst_irm, va_overall_irm = train_irm(\n",
        "        model_irm, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"IRMv1\"][\"lr\"],\n",
        "        penalty_weight=best_hyperparams[\"IRMv1\"][\"penalty_weight\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"IRMv1\"] = time.time() - t_start\n",
        "    trained_models[\"IRMv1\"] = model_irm\n",
        "    results[\"IRMv1\"] = {\"tr\": tr_irm, \"va_worst\": va_worst_irm, \"va_overall\": va_overall_irm}\n",
        "\n",
        "# --- GroupDRO Full Training ---\n",
        "if should_train(\"GroupDRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: GroupDRO ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_groupdro = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_groupdro, va_worst_groupdro, va_overall_groupdro = train_groupdro(\n",
        "        model_groupdro, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"GroupDRO\"][\"lr\"],\n",
        "        alpha=best_hyperparams[\"GroupDRO\"][\"alpha\"], weight_decay=WEIGHT_DECAY,\n",
        "        group_counts=group_counts, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"GroupDRO\"] = time.time() - t_start\n",
        "    trained_models[\"GroupDRO\"] = model_groupdro\n",
        "    results[\"GroupDRO\"] = {\"tr\": tr_groupdro, \"va_worst\": va_worst_groupdro, \"va_overall\": va_overall_groupdro}\n",
        "\n",
        "# --- χ²-DRO Full Training ---\n",
        "if should_train(\"χ²-DRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: χ²-DRO ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_chi2 = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_chi2, va_worst_chi2, va_overall_chi2 = train_chi2_dro(\n",
        "        model_chi2, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"χ²-DRO\"][\"lr\"],\n",
        "        rho=best_hyperparams[\"χ²-DRO\"][\"rho\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"χ²-DRO\"] = time.time() - t_start\n",
        "    trained_models[\"χ²-DRO\"] = model_chi2\n",
        "    results[\"χ²-DRO\"] = {\"tr\": tr_chi2, \"va_worst\": va_worst_chi2, \"va_overall\": va_overall_chi2}\n",
        "\n",
        "# --- CVaR-DRO Full Training ---\n",
        "if should_train(\"CVaR-DRO\"):\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(f\"▶ Full training: CVaR-DRO ({FULL_TRAIN_EPOCHS} epochs)\")\n",
        "    print(\"=\"*60)\n",
        "    reset_seeds(SEED)\n",
        "    model_cvar = fresh_model()\n",
        "    t_start = time.time()\n",
        "    tr_cvar, va_worst_cvar, va_overall_cvar = train_cvar_dro(\n",
        "        model_cvar, loaders[\"train\"], loaders[\"val\"],\n",
        "        epochs=FULL_TRAIN_EPOCHS, lr=best_hyperparams[\"CVaR-DRO\"][\"lr\"],\n",
        "        alpha=best_hyperparams[\"CVaR-DRO\"][\"alpha\"], weight_decay=WEIGHT_DECAY, print_every=5\n",
        "    )\n",
        "    wallclock_times[\"CVaR-DRO\"] = time.time() - t_start\n",
        "    trained_models[\"CVaR-DRO\"] = model_cvar\n",
        "    results[\"CVaR-DRO\"] = {\"tr\": tr_cvar, \"va_worst\": va_worst_cvar, \"va_overall\": va_overall_cvar}\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"✅ Full training completed!\")\n",
        "print(f\"Trained models: {list(trained_models.keys())}\")\n",
        "print(\"=\"*60)\n",
        "\n",
        "# Print wallclock times\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"WALLCLOCK TIMES (Full Training)\")\n",
        "print(\"=\"*60)\n",
        "for algo, elapsed in wallclock_times.items():\n",
        "    mins, secs = divmod(elapsed, 60)\n",
        "    hrs, mins = divmod(mins, 60)\n",
        "    print(f\"  {algo:<15}: {int(hrs):02d}:{int(mins):02d}:{secs:05.2f}\")\n",
        "print(\"=\"*60)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DkS3c4Y-Lpu-",
        "outputId": "fa018c45-de27-48e7-af57-25b5e84dc8a8"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === TEST SET EVALUATION ===\n",
        "# ==========================================\n",
        "# Only evaluates algorithms where RUN_TEST_EVAL[algo] = True\n",
        "\n",
        "# Define helper function first\n",
        "@torch.no_grad()\n",
        "def get_per_group_accuracy(model, loader):\n",
        "    \"\"\"Compute classification accuracy separately for each group.\n",
        "\n",
        "    Switches ``model`` to eval mode, runs it over every batch of ``loader``\n",
        "    (batches unpack as ``x, y, _, g``; the third item is ignored) and tallies\n",
        "    correct and total predictions per group index in ``range(NUM_GROUPS)``.\n",
        "\n",
        "    Returns:\n",
        "        NumPy array of length ``NUM_GROUPS`` holding correct/total per group.\n",
        "        A group with no samples yields 0.0 because its total is clamped to 1.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    hits = torch.zeros(NUM_GROUPS)\n",
        "    counts = torch.zeros(NUM_GROUPS)\n",
        "\n",
        "    for inputs, labels, _, groups in loader:\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(device)\n",
        "        groups = groups.to(device)\n",
        "        is_correct = model(inputs).argmax(1) == labels\n",
        "        for k in range(NUM_GROUPS):\n",
        "            in_group = groups == k\n",
        "            n_in_group = in_group.sum().item()\n",
        "            # Skip groups absent from this batch.\n",
        "            if n_in_group:\n",
        "                hits[k] += is_correct[in_group].sum().item()\n",
        "                counts[k] += n_in_group\n",
        "\n",
        "    # clamp(min=1) avoids division by zero for groups that never appeared.\n",
        "    return (hits / counts.clamp(min=1)).numpy()\n",
        "\n",
        "# Run test evaluation\n",
        "# Each selected model is scored on loaders[\"test\"] with eval_metrics,\n",
        "# eval_overall_accuracy and get_per_group_accuracy; results are keyed by algorithm name.\n",
        "test_results = {}\n",
        "# Display names indexed by group id; assumes ids 0-3 follow this label/background\n",
        "# order — TODO confirm against the dataset's group encoding.\n",
        "GROUP_NAMES = [\"Land (Land bg)\", \"Land (Water bg)\", \"Water (Land bg)\", \"Water (Water bg)\"]\n",
        "_rule = \"=\" * 60\n",
        "\n",
        "print(f\"\\n{_rule}\")\n",
        "print(\" FINAL TEST EVALUATION\")\n",
        "print(\" (Best validation models evaluated on test set)\")\n",
        "print(_rule)\n",
        "\n",
        "# filter() is lazy, so should_test is consulted one algorithm at a time, in dict order.\n",
        "for algo_name in filter(should_test, trained_models):\n",
        "    model = trained_models[algo_name]\n",
        "    test_batches = loaders[\"test\"]\n",
        "\n",
        "    test_loss, test_worst = eval_metrics(model, test_batches)\n",
        "    test_overall = eval_overall_accuracy(model, test_batches)\n",
        "    # Get per-group accuracies\n",
        "    group_accs = get_per_group_accuracy(model, test_batches)\n",
        "\n",
        "    test_results[algo_name] = dict(\n",
        "        worst_group_acc=test_worst,\n",
        "        overall_acc=test_overall,\n",
        "        per_group_acc=group_accs,\n",
        "    )\n",
        "\n",
        "    print(f\"\\n>>> {algo_name} <<<\")\n",
        "    print(\"Overall Accuracy: {:.2f}%\".format(test_overall * 100))\n",
        "    print(\"Worst-Group Accuracy: {:.2f}%\".format(test_worst * 100))\n",
        "    print(\"Per-Group Accuracy:\")\n",
        "    for i, gname in enumerate(GROUP_NAMES):\n",
        "        print(\"  {}: {:.2f}%\".format(gname, group_accs[i] * 100))\n",
        "\n",
        "# Summary table (printed only when at least one algorithm was evaluated)\n",
        "if test_results:\n",
        "    print(f\"\\n{_rule}\")\n",
        "    print(\" SUMMARY TABLE (Test Set)\")\n",
        "    print(_rule)\n",
        "    print(\"{:<18} | {:<12} | {:<10}\".format(\"Method\", \"Worst-Group\", \"Overall\"))\n",
        "    print(\"-\" * 48)\n",
        "    for algo_name, r in test_results.items():\n",
        "        print(\"{:<18} | {:>10.2f}% | {:>8.2f}%\".format(\n",
        "            algo_name, r[\"worst_group_acc\"] * 100, r[\"overall_acc\"] * 100\n",
        "        ))\n",
        "    print(_rule)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "79vJAA7eLpu-",
        "outputId": "e39a26dd-bf7f-4c66-d578-3f3bf0e020fc"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# === VISUALIZATION & ANALYSIS ===\n",
        "# ==========================================\n",
        "# Only includes algorithms where RUN_ANALYSIS[algo] = True\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def _ep(n):\n",
        "    \"\"\"Return the 1-based epoch indices ``[1, ..., n]`` used as the x-axis.\"\"\"\n",
        "    return list(range(1, n + 1))\n",
        "\n",
        "\n",
        "def _plot_history(metric_key, title, ylabel, algos):\n",
        "    \"\"\"Draw one figure with a per-epoch curve of ``results[algo][metric_key]`` per algorithm.\n",
        "\n",
        "    Reads the module-level ``results`` dict. Algorithms whose history is missing\n",
        "    or empty are skipped. The legend is drawn only when at least one curve was\n",
        "    plotted, which avoids matplotlib's \"no artists with labels\" warning.\n",
        "\n",
        "    Args:\n",
        "        metric_key: key inside each ``results[algo]`` dict (e.g. ``\"tr\"``).\n",
        "        title: figure title.\n",
        "        ylabel: y-axis label.\n",
        "        algos: algorithm names to include, in plotting order.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots(figsize=(10, 6))\n",
        "    for algo in algos:\n",
        "        history = results[algo].get(metric_key)\n",
        "        if history:\n",
        "            ax.plot(_ep(len(history)), history, label=algo)\n",
        "    ax.set_title(title, fontsize=14)\n",
        "    ax.set_xlabel(\"Epoch\")\n",
        "    ax.set_ylabel(ylabel)\n",
        "    if ax.get_legend_handles_labels()[0]:\n",
        "        ax.legend(loc='best')\n",
        "    ax.grid(True, alpha=0.25)\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "# (results key, figure title, y-axis label) for each figure, in display order.\n",
        "# The second entry, validation worst-group accuracy, is the key metric.\n",
        "PLOT_SPECS = [\n",
        "    (\"tr\", \"Training Loss vs Epoch — Waterbirds\", \"Cross-Entropy Loss\"),\n",
        "    (\"va_worst\", \"Validation Worst-Group Accuracy vs Epoch — Waterbirds\", \"Worst-Group Accuracy\"),\n",
        "    (\"va_overall\", \"Validation Overall Accuracy vs Epoch — Waterbirds\", \"Overall Accuracy\"),\n",
        "]\n",
        "\n",
        "# Get algorithms to analyze\n",
        "algos_to_analyze = [algo for algo in results if should_analyze(algo)]\n",
        "\n",
        "if algos_to_analyze:\n",
        "    print(\"\\n\" + \"=\"*60)\n",
        "    print(\" VISUALIZATION\")\n",
        "    print(f\" Algorithms: {algos_to_analyze}\")\n",
        "    print(\"=\"*60)\n",
        "\n",
        "    # One figure per spec, replacing three copy-pasted plotting sections.\n",
        "    for metric_key, title, ylabel in PLOT_SPECS:\n",
        "        _plot_history(metric_key, title, ylabel, algos_to_analyze)\n",
        "\n",
        "else:\n",
        "    print(\"\\n[INFO] No algorithms selected for analysis. Set RUN_ANALYSIS[algo] = True to include.\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "include_colab_link": true,
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
