{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 565
        },
        "id": "0nviXjkKPZPd",
        "outputId": "bf1cca5c-d49d-4948-a91d-c4abdd8821e3"
      },
      "outputs": [],
      "source": [
        "# -*- coding: utf-8 -*-\n",
        "\"\"\"cifar10lt_with_all_no_wine.ipynb\"\"\"\n",
        "\n",
        "import os\n",
        "import io\n",
        "import zipfile\n",
        "import urllib.request\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from typing import Iterable\n",
        "import math\n",
        "import ssl\n",
        "import itertools\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "RS_SHARED_TAU = 0.5\n",
        "\n",
        "def download_content(url, verify=False):\n",
        "    \"\"\"Open ``url`` with a browser-like User-Agent and return the response.\n",
        "\n",
        "    The caller owns the returned response object and should close it (it can\n",
        "    be used in a ``with`` statement).\n",
        "\n",
        "    Args:\n",
        "        url: address to fetch.\n",
        "        verify: when False (the default, preserving the original behaviour)\n",
        "            TLS hostname and certificate checks are DISABLED, which leaves the\n",
        "            download open to man-in-the-middle tampering. Pass True to use the\n",
        "            verifying context from ``ssl.create_default_context()`` unchanged.\n",
        "\n",
        "    Returns:\n",
        "        The result of ``urllib.request.urlopen`` for the request.\n",
        "    \"\"\"\n",
        "    context = ssl.create_default_context()\n",
        "    if not verify:\n",
        "        # Insecure opt-out, kept as the default for backward compatibility.\n",
        "        context.check_hostname = False\n",
        "        context.verify_mode = ssl.CERT_NONE\n",
        "    req = urllib.request.Request(\n",
        "        url,\n",
        "        headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}\n",
        "    )\n",
        "    return urllib.request.urlopen(req, context=context)\n",
        "\n",
        "# ==========================================\n",
        "# === KL-RS Lambda Bisection Helpers     ===\n",
        "# ==========================================\n",
        "@torch.no_grad()\n",
        "def _bisection_for_lambda(losses, tau, lam_min=1e-4, lam_max=100.0, max_iter=15):\n",
        "    \"\"\"Find lambda that minimizes the KL-RS dual objective.\"\"\"\n",
        "    def obj(lam):\n",
        "        scaled = (losses - tau) / lam\n",
        "        lse = torch.logsumexp(scaled, dim=0)\n",
        "        return lam * (lse - math.log(len(losses)))\n",
        "\n",
        "    a, b = lam_min, lam_max\n",
        "    gr = (math.sqrt(5) + 1) / 2\n",
        "    c = b - (b - a) / gr\n",
        "    d = a + (b - a) / gr\n",
        "    while abs(c - d) > 1e-5 and max_iter > 0:\n",
        "        if obj(c) < obj(d):\n",
        "            b = d\n",
        "        else:\n",
        "            a = c\n",
        "        c = b - (b - a) / gr\n",
        "        d = a + (b - a) / gr\n",
        "        max_iter -= 1\n",
        "    return (b + a) / 2\n",
        "\n",
        "# ==========================================\n",
        "# === IRS (KL-RS with Class-wise) Helpers ===\n",
        "# ==========================================\n",
        "@torch.no_grad()\n",
        "def _kl_path_p_of_h(F_vals, P_hat, h, exp_clip=40.0):\n",
        "    log_base = torch.log(P_hat.clamp_min(1e-30))\n",
        "    logits = log_base + (F_vals * float(h))\n",
        "    logits = torch.clamp(logits, min=logits.max()-exp_clip, max=logits.max()+1e-6)\n",
        "    return torch.softmax(logits, dim=0)\n",
        "\n",
        "@torch.no_grad()\n",
        "def _secant_maximize_kappa(F_vals, P_hat, tau_t, h_left, h_right, max_iter=10):\n",
        "    def K(h):\n",
        "        P = _kl_path_p_of_h(F_vals, P_hat, h)\n",
        "        E_F = torch.dot(P, F_vals)\n",
        "        Dkl = (P * (torch.log(P.clamp_min(1e-30)) - torch.log(P_hat.clamp_min(1e-30)))).sum().item()\n",
        "        return (E_F - tau_t).item() / max(1.0 * max(Dkl, 0.0) + 1e-2, 1e-20), P\n",
        "\n",
        "    a, b = float(h_left), float(h_right)\n",
        "    Ka, _ = K(a); Kb, _ = K(b)\n",
        "    for _ in range(max_iter):\n",
        "        if abs(b - a) < 1e-8: break\n",
        "        m = (Kb - Ka) / (b - a)\n",
        "        c = b - Kb / m if abs(m) > 1e-12 else b - np.sign(Kb)*0.1\n",
        "        Kc, P_star = K(c)\n",
        "        if Kc > Kb: a, Ka = b, Kb; b, Kb = c, Kc\n",
        "        else: a, Ka = c, Kc\n",
        "    return b, P_star\n",
        "\n",
        "def classwise_F(loss_vec, g, n_groups, P_glob):\n",
        "    dev = loss_vec.device\n",
        "    counts = torch.bincount(g, minlength=n_groups).to(dev)\n",
        "    mask = counts > 0\n",
        "    lsum = torch.zeros(n_groups, device=dev).scatter_add_(0, g, loss_vec)\n",
        "    F = torch.zeros(n_groups, device=dev); F[mask] = lsum[mask] / counts[mask].float()\n",
        "    P_hatS = torch.zeros_like(F); Pg = P_glob.to(dev)\n",
        "    P_hatS[mask] = Pg[mask] / Pg[mask].sum().clamp_min(1e-12)\n",
        "    return F, P_hatS\n",
        "\n",
        "# ==========================================\n",
        "# === SAM Optimizer                      ===\n",
        "# ==========================================\n",
        "class SAMAdam(torch.optim.Adam):\n",
        "    def __init__(self, params, lr=1e-3, rho=0.05, **kwargs):\n",
        "        super().__init__(params, lr=lr, **kwargs); self.rho=rho\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def step(self, closure=None):\n",
        "        loss = None\n",
        "        if closure is not None:\n",
        "            loss = closure()\n",
        "        super().step()\n",
        "        return loss\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def first_step(self, zero_grad=False):\n",
        "        grad_norm = self._grad_norm()\n",
        "        scale = self.rho / (grad_norm + 1e-12)\n",
        "        self.e_w = []\n",
        "        for group in self.param_groups:\n",
        "            for p in group[\"params\"]:\n",
        "                if p.grad is None: self.e_w.append(None); continue\n",
        "                e_w = p.grad * scale.to(p)\n",
        "                p.add_(e_w); self.e_w.append(e_w)\n",
        "        if zero_grad: self.zero_grad()\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def second_step(self, zero_grad=False):\n",
        "        idx = 0\n",
        "        for group in self.param_groups:\n",
        "            for p in group[\"params\"]:\n",
        "                if p.grad is None: idx+=1; continue\n",
        "                p.sub_(self.e_w[idx]); idx+=1\n",
        "        super().step()\n",
        "        if zero_grad: self.zero_grad()\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _grad_norm(self):\n",
        "        norms = [p.grad.norm(p=2) for group in self.param_groups for p in group[\"params\"] if p.grad is not None]\n",
        "        return torch.norm(torch.stack(norms), p=2) if norms else torch.tensor(0.)\n",
        "\n",
        "# ==========================================\n",
        "# === Data Loaders                       ===\n",
        "# ==========================================\n",
        "class RobustDataset(Dataset):\n",
        "    def __init__(self, x, y, g):\n",
        "        self.x = torch.tensor(x, dtype=torch.float32)\n",
        "        self.y = torch.tensor(y, dtype=torch.float32).view(-1, 1)\n",
        "        self.g = torch.tensor(g, dtype=torch.long)\n",
        "    def __len__(self): return len(self.x)\n",
        "    def __getitem__(self, i): return self.x[i], self.y[i], torch.tensor(i).long(), self.g[i]\n",
        "\n",
        "def disc_g(y, bins=5): return pd.qcut(y.flatten(), q=bins, labels=False, duplicates='drop')\n",
        "\n",
        "def split_train_val(x, y, g, val_ratio=0.2):\n",
        "    \"\"\"Randomly split aligned arrays into (train, val, full) RobustDatasets.\n",
        "\n",
        "    The first int(n * val_ratio) entries of a permutation drawn from NumPy's\n",
        "    global RNG form the validation split and the remainder the training split;\n",
        "    the third dataset wraps all of x, y, g unsplit.\n",
        "    \"\"\"\n",
        "    order = np.random.permutation(len(x))\n",
        "    cut = int(len(x) * val_ratio)\n",
        "    val_idx, train_idx = order[:cut], order[cut:]\n",
        "    train_ds = RobustDataset(x[train_idx], y[train_idx], g[train_idx])\n",
        "    val_ds = RobustDataset(x[val_idx], y[val_idx], g[val_idx])\n",
        "    full_ds = RobustDataset(x, y, g)\n",
        "    return train_ds, val_ds, full_ds\n",
        "\n",
        "def get_bike():\n",
        "    \"\"\"Load UCI Bike Sharing (hourly) as two training environments and a test set.\n",
        "\n",
        "    Env 1: year 0, seasons 1-2. Env 2: year 0, seasons 3-4. Test: year 1.\n",
        "    The eleven listed feature columns are standardised and ``cnt`` is z-scored;\n",
        "    group ids are per-split quantile bins of the target (``disc_g``).\n",
        "\n",
        "    NOTE(review): the scaler and the target mean/std are fit on all rows,\n",
        "    including the test year, so test statistics leak into preprocessing.\n",
        "\n",
        "    Returns:\n",
        "        ``((t1, v1, f1), (t2, v2, f2), dt)`` -- train/val/full datasets for each\n",
        "        environment plus the test dataset -- or None if downloading or parsing\n",
        "        fails (the error is printed).\n",
        "    \"\"\"\n",
        "    url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip\"\n",
        "    try:\n",
        "        # Read the payload inside `with` so the HTTP response is always closed.\n",
        "        with download_content(url) as resp:\n",
        "            payload = resp.read()\n",
        "        with zipfile.ZipFile(io.BytesIO(payload)) as z:\n",
        "            with z.open(\"hour.csv\") as f: df = pd.read_csv(f)\n",
        "    except Exception as e: print(f\"Bike Error: {e}\"); return None\n",
        "    X = StandardScaler().fit_transform(df[['season','mnth','hr','holiday','weekday','workingday','weathersit','temp','atemp','hum','windspeed']].values)\n",
        "    y = (df['cnt'].values.astype(float) - df['cnt'].mean()) / df['cnt'].std()\n",
        "    idx1 = np.where((df['yr']==0) & (df['season'].isin([1,2])))[0]\n",
        "    idx2 = np.where((df['yr']==0) & (df['season'].isin([3,4])))[0]\n",
        "    idxt = np.where(df['yr']==1)[0]\n",
        "    t1, v1, f1 = split_train_val(X[idx1], y[idx1], disc_g(y[idx1]))\n",
        "    t2, v2, f2 = split_train_val(X[idx2], y[idx2], disc_g(y[idx2]))\n",
        "    dt = RobustDataset(X[idxt], y[idxt], disc_g(y[idxt]))\n",
        "    return (t1, v1, f1), (t2, v2, f2), dt\n",
        "\n",
        "def get_concrete():\n",
        "    \"\"\"Load UCI Concrete Compressive Strength, split into environments by age.\n",
        "\n",
        "    Env 1: age < 28, env 2: age == 28, test: age > 28, with age read from\n",
        "    column 7. If fewer than 10 rows have age < 28, a random 40/40/20 split\n",
        "    (global NumPy RNG) is used instead. The target (last column) is z-scored;\n",
        "    the features are all columns except the last two, standardised.\n",
        "\n",
        "    NOTE(review): assumes the sheet has 9 columns so that ``:-2`` drops exactly\n",
        "    age and the target -- confirm. Scaler/target statistics use all rows,\n",
        "    including the test split.\n",
        "\n",
        "    Returns:\n",
        "        ``((t1, v1, f1), (t2, v2, f2), dt)`` or None if downloading or parsing\n",
        "        fails (the error is printed).\n",
        "    \"\"\"\n",
        "    url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls\"\n",
        "    try:\n",
        "        # Read the payload inside `with` so the HTTP response is always closed.\n",
        "        with download_content(url) as resp:\n",
        "            df = pd.read_excel(io.BytesIO(resp.read()))\n",
        "    except Exception as e: print(f\"Concrete Error: {e}\"); return None\n",
        "    y = (df.iloc[:, -1].values - df.iloc[:, -1].mean()) / df.iloc[:, -1].std()\n",
        "    age = df.iloc[:, 7].values\n",
        "    X = StandardScaler().fit_transform(df.iloc[:, :-2].values)\n",
        "    idx1 = np.where(age<28)[0]; idx2 = np.where(age==28)[0]; idxt = np.where(age>28)[0]\n",
        "    if len(idx1)<10: p=np.random.permutation(len(y)); idx1,idx2,idxt = p[:int(0.4*len(y))], p[int(0.4*len(y)):int(0.8*len(y))], p[int(0.8*len(y)):]\n",
        "    t1, v1, f1 = split_train_val(X[idx1], y[idx1], disc_g(y[idx1]))\n",
        "    t2, v2, f2 = split_train_val(X[idx2], y[idx2], disc_g(y[idx2]))\n",
        "    dt = RobustDataset(X[idxt], y[idxt], disc_g(y[idxt]))\n",
        "    return (t1, v1, f1), (t2, v2, f2), dt\n",
        "\n",
        "def get_loader(name):\n",
        "    \"\"\"Return the environment splits for dataset ``name``.\n",
        "\n",
        "    Dispatches to get_bike / get_concrete. Returns None for an unknown name,\n",
        "    or whatever the loader returns (None on download/parse failure).\n",
        "    \"\"\"\n",
        "    loader_fn = {\"bike\": get_bike, \"concrete\": get_concrete}.get(name)\n",
        "    if loader_fn is None:\n",
        "        return None\n",
        "    return loader_fn()\n",
        "\n",
        "# ==========================================\n",
        "# === Models                             ===\n",
        "# ==========================================\n",
        "class MLP(nn.Module):\n",
        "    def __init__(self, d_in):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(d_in, 256), nn.ReLU(),\n",
        "            nn.Linear(256, 256), nn.ReLU(),\n",
        "            nn.Linear(256, 1)\n",
        "        )\n",
        "    def forward(self, x): return self.net(x)\n",
        "\n",
        "class Poly(nn.Module):\n",
        "    def __init__(self, d_in, degree=2):\n",
        "        super().__init__()\n",
        "        self.l = nn.Linear(d_in*degree, 1); self.d = degree\n",
        "    def forward(self, x): return self.l(torch.cat([x.pow(i) for i in range(1,self.d+1)],1))\n",
        "\n",
        "def get_model(name, d_in):\n",
        "    \"\"\"Model factory: quadratic Poly regressor for 'bike', MLP for any other name.\"\"\"\n",
        "    model_cls = Poly if name == \"bike\" else MLP\n",
        "    return model_cls(d_in)\n",
        "\n",
        "# ==========================================\n",
        "# === DRO Loss Functions                 ===\n",
        "# ==========================================\n",
        "def chi2_dro_loss(loss_vec: torch.Tensor, rho: float,\n",
        "                  bisect_tol: float = 1e-5, max_iter: int = 50) -> torch.Tensor:\n",
        "    B = loss_vec.size(0)\n",
        "    if B == 0:\n",
        "        return loss_vec.new_tensor(0.0)\n",
        "\n",
        "    rho = max(float(rho), 0.0)\n",
        "    c = math.sqrt(1.0 + 2.0 * rho)\n",
        "\n",
        "    eta_min = loss_vec.min().item()\n",
        "    eta_max = loss_vec.max().item()\n",
        "\n",
        "    if abs(eta_max - eta_min) < 1e-12:\n",
        "        return loss_vec.mean()\n",
        "\n",
        "    def eval_R(eta):\n",
        "        diff = loss_vec - eta\n",
        "        positive_part = torch.clamp(diff, min=0.0)\n",
        "        var_term = (positive_part ** 2).mean()\n",
        "        return c * torch.sqrt(var_term + 1e-12) + eta\n",
        "\n",
        "    for _ in range(max_iter):\n",
        "        if (eta_max - eta_min) < bisect_tol:\n",
        "            break\n",
        "        eta_mid = 0.5 * (eta_min + eta_max)\n",
        "        diff = loss_vec - eta_mid\n",
        "        positive_part = torch.clamp(diff, min=0.0)\n",
        "        var_term = (positive_part ** 2).mean()\n",
        "        grad = 1.0 - c * (positive_part.mean()) / torch.sqrt(var_term + 1e-12)\n",
        "        if grad.item() < 0:\n",
        "            eta_min = eta_mid\n",
        "        else:\n",
        "            eta_max = eta_mid\n",
        "\n",
        "    eta_star = 0.5 * (eta_min + eta_max)\n",
        "    return eval_R(eta_star)\n",
        "\n",
        "\n",
        "def cvar_loss_from_batch(loss_vec: torch.Tensor, alpha: float) -> torch.Tensor:\n",
        "    alpha = max(min(float(alpha), 1.0), 1e-6)\n",
        "    B = loss_vec.size(0)\n",
        "    k = max(1, int(math.ceil(alpha * B)))\n",
        "    topk_losses, _ = torch.topk(loss_vec, k, largest=True, sorted=False)\n",
        "    return topk_losses.mean()\n",
        "\n",
        "\n",
        "class GroupDROLossComputer:\n",
        "    def __init__(self, n_groups=5, alpha=0.2, gamma=0.1, device='cuda'):\n",
        "        self.alpha = alpha\n",
        "        self.gamma = gamma\n",
        "        self.device = device\n",
        "        self.n_groups = n_groups\n",
        "\n",
        "        self.adv_probs = torch.ones(self.n_groups, device=device) / self.n_groups\n",
        "        self.exp_avg_loss = torch.zeros(self.n_groups, device=device)\n",
        "        self.exp_avg_initialized = torch.zeros(self.n_groups, device=device).bool()\n",
        "\n",
        "    def loss(self, loss_vec, group_idx, is_training=True):\n",
        "        group_losses = []\n",
        "        group_counts_batch = []\n",
        "        for g in range(self.n_groups):\n",
        "            mask = (group_idx == g)\n",
        "            if mask.any():\n",
        "                group_loss = loss_vec[mask].mean()\n",
        "                group_losses.append(group_loss)\n",
        "                group_counts_batch.append(mask.sum().item())\n",
        "            else:\n",
        "                group_losses.append(torch.tensor(0.0, device=self.device))\n",
        "                group_counts_batch.append(0)\n",
        "\n",
        "        group_losses = torch.stack(group_losses)\n",
        "        group_counts_batch = torch.tensor(group_counts_batch, dtype=torch.float32, device=self.device)\n",
        "\n",
        "        if is_training:\n",
        "            for g in range(self.n_groups):\n",
        "                if group_counts_batch[g] > 0:\n",
        "                    if not self.exp_avg_initialized[g]:\n",
        "                        self.exp_avg_loss[g] = group_losses[g].detach()\n",
        "                        self.exp_avg_initialized[g] = True\n",
        "                    else:\n",
        "                        self.exp_avg_loss[g] = (\n",
        "                            self.gamma * group_losses[g].detach() +\n",
        "                            (1 - self.gamma) * self.exp_avg_loss[g]\n",
        "                        )\n",
        "\n",
        "            adv_probs = torch.exp(self.alpha * self.exp_avg_loss.detach())\n",
        "            adv_probs = adv_probs / (adv_probs.sum() + 1e-12)\n",
        "            self.adv_probs = adv_probs\n",
        "\n",
        "        weighted_loss = (self.adv_probs * group_losses).sum()\n",
        "        return weighted_loss\n",
        "\n",
        "# ==========================================\n",
        "# === Train Loop                         ===\n",
        "# ==========================================\n",
        "def train_loop(model, train_loader, env_loaders, algo, param_dict, epochs=100):\n",
        "    \"\"\"Train ``model`` in place with one of the supported algorithms and return it.\n",
        "\n",
        "    Args:\n",
        "        model: module already moved to ``device``.\n",
        "        train_loader: pooled loader of (x, y, idx, g) batches. It is scanned once\n",
        "            up front to estimate group frequencies P_g, then used for training\n",
        "            by every algorithm except V-REx / MM-REx / IRM.\n",
        "        env_loaders: pair (l1, l2) of per-environment loaders, zipped together\n",
        "            (so truncated to the shorter one) by V-REx / MM-REx / IRM.\n",
        "        algo: algorithm name; names outside the supported set raise ValueError.\n",
        "        param_dict: hyperparameters. 'lr' is required; optional keys are\n",
        "            'penalty' (IRM), 'rho' (SAM), 'lam' (MM-REx), 'rex_beta' (V-REx),\n",
        "            'chi2_rho', 'cvar_alpha', 'groupdro_alpha' and 'alt_iters'.\n",
        "        epochs: number of passes over the data.\n",
        "\n",
        "    Notes:\n",
        "        * IRS and KL-RS use an adaptive threshold tau = 1.01 x a batch statistic\n",
        "          (IRS: prior-weighted group mean loss; KL-RS: mean loss). Once it falls\n",
        "          to RS_SHARED_TAU it is frozen there for the rest of training.\n",
        "        * Group ids are assumed to lie in [0, 5) (P_g and GroupDRO use 5 groups).\n",
        "        * NOTE(review): 'alt_iters' (swept by get_grid for KL-RS) currently has\n",
        "          no effect on training.\n",
        "\n",
        "    Returns:\n",
        "        The trained model.\n",
        "    \"\"\"\n",
        "    known_algos = {\"ERM\", \"ERM-SGD\", \"SAM\", \"IRS\", \"KL-RS\", \"V-REx\", \"MM-REx\",\n",
        "                   \"IRM\", \"GroupDRO\", \"χ²-DRO\", \"CVaR-DRO\"}\n",
        "    if algo not in known_algos:\n",
        "        # Fail fast: an unknown name would otherwise leave `loss` undefined.\n",
        "        raise ValueError(f\"Unknown algorithm: {algo!r}\")\n",
        "\n",
        "    lr = param_dict['lr']\n",
        "    penalty = param_dict.get('penalty', 0.0)\n",
        "    rho_sam = param_dict.get('rho', 0.05)\n",
        "    lam_rex = param_dict.get('lam', 1.5)\n",
        "    chi2_rho = param_dict.get('chi2_rho', 0.1)\n",
        "    cvar_alpha = param_dict.get('cvar_alpha', 0.1)\n",
        "    groupdro_alpha = param_dict.get('groupdro_alpha', 0.2)\n",
        "    rex_beta = param_dict.get('rex_beta', 1.0)\n",
        "\n",
        "    # Adaptive-threshold state shared by IRS and KL-RS.\n",
        "    target_tau_factor = 1.01\n",
        "    tau_frozen = False\n",
        "    current_tau = None\n",
        "\n",
        "    if algo == \"SAM\":\n",
        "        opt = SAMAdam(model.parameters(), lr=lr, rho=rho_sam, weight_decay=0.0)\n",
        "    elif algo == \"ERM-SGD\":\n",
        "        opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0, momentum=0.9)\n",
        "    else:\n",
        "        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)\n",
        "\n",
        "    crit = nn.MSELoss(reduction='none')\n",
        "\n",
        "    # Group frequencies over one full pass of the pooled loader. With a shuffling\n",
        "    # loader this pass presumably also advances the global RNG, so it is kept for\n",
        "    # every algorithm to leave seeded results unchanged.\n",
        "    P_g = torch.zeros(5).to(device)\n",
        "    for _,_,_,g in train_loader:\n",
        "        g = g.to(device)\n",
        "        P_g.scatter_add_(0, g, torch.ones_like(g, dtype=torch.float, device=device))\n",
        "    P_g /= P_g.sum()\n",
        "\n",
        "    if algo == \"GroupDRO\":\n",
        "        groupdro_computer = GroupDROLossComputer(n_groups=5, alpha=groupdro_alpha, device=device)\n",
        "\n",
        "    l1, l2 = env_loaders\n",
        "\n",
        "    for ep in range(epochs):\n",
        "        model.train()\n",
        "\n",
        "        if algo in [\"V-REx\", \"MM-REx\", \"IRM\"]:\n",
        "            # Environment-based objectives: one batch from each environment per step.\n",
        "            for (x1,y1,_,_), (x2,y2,_,_) in zip(l1, l2):\n",
        "                x1,y1,x2,y2 = x1.to(device), y1.to(device), x2.to(device), y2.to(device)\n",
        "                l1_v, l2_v = crit(model(x1), y1).mean(), crit(model(x2), y2).mean()\n",
        "\n",
        "                if algo == \"V-REx\":\n",
        "                    penalty_weight = min(1.0, float(ep) / 10.0) if ep <= 10 else 1.0\n",
        "                    var_penalty = (l1_v - l2_v)**2\n",
        "                    loss = (l1_v+l2_v)/2 + penalty_weight * rex_beta * var_penalty\n",
        "                elif algo == \"MM-REx\":\n",
        "                    penalty_weight = min(1.0, float(ep) / 10.0) if ep <= 10 else 1.0\n",
        "                    current_lambda = 0.5 + penalty_weight * (lam_rex - 0.5)\n",
        "                    l_val = torch.stack([l1_v, l2_v])\n",
        "                    loss = (1.0 - 2*current_lambda)*l_val.max() + current_lambda*l_val.sum()\n",
        "                elif algo == \"IRM\":\n",
        "                    penalty_weight = min(1.0, float(ep) / 10.0) if ep <= 10 else 1.0\n",
        "                    dum = torch.tensor(1., requires_grad=True, device=device)\n",
        "                    g1 = torch.autograd.grad(crit(model(x1)*dum, y1).mean(), dum, create_graph=True)[0]\n",
        "                    g2 = torch.autograd.grad(crit(model(x2)*dum, y2).mean(), dum, create_graph=True)[0]\n",
        "                    loss = (l1_v+l2_v)/2 + penalty_weight * penalty * (g1**2 + g2**2)\n",
        "                opt.zero_grad(); loss.backward(); opt.step()\n",
        "        else:\n",
        "            for x,y,_,g in train_loader:\n",
        "                x,y,g = x.to(device), y.to(device), g.to(device)\n",
        "                # view(-1) rather than squeeze(): a final batch of size 1 must stay a\n",
        "                # 1-D vector for the size(0)/indexing done by the DRO losses.\n",
        "                loss_vec = crit(model(x), y).view(-1)\n",
        "\n",
        "                if algo in [\"ERM\", \"ERM-SGD\"]:\n",
        "                    loss = loss_vec.mean()\n",
        "                elif algo == \"SAM\":\n",
        "                    # SAM performs its own two-pass update and skips the shared step below.\n",
        "                    loss = loss_vec.mean(); loss.backward()\n",
        "                    opt.first_step(zero_grad=True); crit(model(x), y).mean().backward(); opt.second_step(zero_grad=True); continue\n",
        "\n",
        "                elif algo == \"IRS\":\n",
        "                    F, P = classwise_F(loss_vec.detach(), g, 5, P_g)\n",
        "                    if tau_frozen:\n",
        "                        target = current_tau\n",
        "                    else:\n",
        "                        adaptive_tau = torch.dot(F, P).item() * target_tau_factor\n",
        "                        if adaptive_tau <= RS_SHARED_TAU:\n",
        "                            tau_frozen = True\n",
        "                            current_tau = RS_SHARED_TAU\n",
        "                            target = RS_SHARED_TAU\n",
        "                        else:\n",
        "                            target = adaptive_tau\n",
        "                    _, P_star = _secant_maximize_kappa(F, P, torch.dot(F, P).new_tensor(target), -2, 2)\n",
        "                    # Importance weights P*/P per group, treated as constants.\n",
        "                    loss = (loss_vec * (P_star/(P+1e-12))[g].detach()).mean()\n",
        "\n",
        "                elif algo == \"KL-RS\":\n",
        "                    if tau_frozen:\n",
        "                        target_tau = current_tau\n",
        "                    else:\n",
        "                        adaptive_tau = loss_vec.mean().detach().item() * target_tau_factor\n",
        "                        if adaptive_tau <= RS_SHARED_TAU:\n",
        "                            tau_frozen = True\n",
        "                            current_tau = RS_SHARED_TAU\n",
        "                            target_tau = RS_SHARED_TAU\n",
        "                        else:\n",
        "                            target_tau = adaptive_tau\n",
        "                    best_lam = _bisection_for_lambda(loss_vec.detach(), target_tau)\n",
        "                    scaled = (loss_vec - target_tau) / best_lam\n",
        "                    loss = best_lam * (torch.logsumexp(scaled, dim=0) - math.log(len(x)))\n",
        "\n",
        "                elif algo == \"GroupDRO\":\n",
        "                    loss = groupdro_computer.loss(loss_vec, g, is_training=True)\n",
        "\n",
        "                elif algo == \"χ²-DRO\":\n",
        "                    loss = chi2_dro_loss(loss_vec, chi2_rho)\n",
        "\n",
        "                elif algo == \"CVaR-DRO\":\n",
        "                    loss = cvar_loss_from_batch(loss_vec, cvar_alpha)\n",
        "\n",
        "                opt.zero_grad(); loss.backward(); opt.step()\n",
        "    return model\n",
        "\n",
        "def evaluate(model, loader):\n",
        "    model.eval()\n",
        "    loss_sum = 0; tot = 0\n",
        "    with torch.no_grad():\n",
        "        for x,y,_,_ in loader:\n",
        "            x,y = x.to(device), y.to(device)\n",
        "            loss_sum += nn.MSELoss(reduction='sum')(model(x), y).item()\n",
        "            tot += x.size(0)\n",
        "    return loss_sum / tot\n",
        "\n",
        "# ==========================================\n",
        "# === Hyperparameter Grids (Waterbirds)  ===\n",
        "# ==========================================\n",
        "DATASETS = [\"bike\", \"concrete\"]\n",
        "ALGOS = [\"ERM\", \"ERM-SGD\", \"SAM\", \"IRS\", \"KL-RS\", \"V-REx\", \"MM-REx\", \"IRM\", \"GroupDRO\", \"χ²-DRO\", \"CVaR-DRO\"]\n",
        "\n",
        "# Hyperparameter grids (matching Waterbirds)\n",
        "LR_GRID = [1e-5, 1e-4, 1e-3, 1e-2]\n",
        "IRM_PENALTY_GRID = [1e-2, 1e-1, 1, 10]\n",
        "VREX_BETA_GRID = [0.1, 0.5, 1.0, 5.0, 10.0]\n",
        "MMREX_LAMBDA_GRID = [1.0, 1.5, 2.0, 3.0]\n",
        "GROUPDRO_ALPHA_GRID = [0.1, 0.2, 0.5]\n",
        "CHI2_RHO_GRID = [0.01, 0.1, 0.5, 1.0, 2.0]\n",
        "CVAR_ALPHA_GRID = [0.05, 0.1, 0.2, 0.3, 0.5]\n",
        "SAM_RHO_GRID = [0.05, 0.1, 0.2]\n",
        "KLRS_ALT_ITERS_GRID = [2, 5, 10]  # For KL-RS alternating optimization\n",
        "\n",
        "def get_grid(algo):\n",
        "    \"\"\"Hyperparameter configurations (a list of param dicts) to sweep for ``algo``.\n",
        "\n",
        "    ERM, ERM-SGD and IRS sweep only 'lr' over LR_GRID. Algorithms listed in\n",
        "    ``extra`` sweep the Cartesian product of LR_GRID with one algorithm-specific\n",
        "    grid (LR varies slowest). Unknown names get a single config with lr=1e-3.\n",
        "    \"\"\"\n",
        "    if algo in (\"ERM\", \"ERM-SGD\", \"IRS\"):\n",
        "        return [{\"lr\": lr} for lr in LR_GRID]\n",
        "    # algorithm -> (param_dict key, values swept alongside LR_GRID)\n",
        "    extra = {\n",
        "        \"KL-RS\": (\"alt_iters\", KLRS_ALT_ITERS_GRID),\n",
        "        \"V-REx\": (\"rex_beta\", VREX_BETA_GRID),\n",
        "        \"MM-REx\": (\"lam\", MMREX_LAMBDA_GRID),\n",
        "        \"IRM\": (\"penalty\", IRM_PENALTY_GRID),\n",
        "        \"SAM\": (\"rho\", SAM_RHO_GRID),\n",
        "        \"GroupDRO\": (\"groupdro_alpha\", GROUPDRO_ALPHA_GRID),\n",
        "        \"χ²-DRO\": (\"chi2_rho\", CHI2_RHO_GRID),\n",
        "        \"CVaR-DRO\": (\"cvar_alpha\", CVAR_ALPHA_GRID),\n",
        "    }\n",
        "    if algo not in extra:\n",
        "        return [{\"lr\": 1e-3}]\n",
        "    key, values = extra[algo]\n",
        "    return [{\"lr\": lr, key: v} for lr, v in itertools.product(LR_GRID, values)]\n",
        "\n",
        "# ==========================================\n",
        "# === Run Experiments                    ===\n",
        "# ==========================================\n",
        "# For each dataset and algorithm: train one model per grid config on the train\n",
        "# splits, keep the config with the lowest validation MSE, retrain it from the\n",
        "# same seed on the full (train + val) splits, and record its test MSE.\n",
        "SEED = 42\n",
        "\n",
        "print(f\"{'='*60}\")\n",
        "print(f\"RUNNING FINAL COMPARISON (with LR grid search)\")\n",
        "print(f\"Algorithms: {ALGOS}\")\n",
        "print(f\"LR Grid: {LR_GRID}\")\n",
        "print(f\"KL-RS alt_iters Grid: {KLRS_ALT_ITERS_GRID}\")\n",
        "print(f\"{'='*60}\")\n",
        "RESULTS = []\n",
        "\n",
        "for dname in DATASETS:\n",
        "    print(f\"\\n>>> Dataset: {dname}\")\n",
        "    # The train/val splits (and the concrete fallback split) draw from NumPy's\n",
        "    # global RNG; seed it per dataset so splits are reproducible regardless of\n",
        "    # which datasets were processed before.\n",
        "    np.random.seed(SEED)\n",
        "    try:\n",
        "        res = get_loader(dname)\n",
        "        if res is None: raise ValueError(\"None\")\n",
        "        (t1, v1, f1), (t2, v2, f2), dt = res\n",
        "    except Exception as e: print(f\"Load Fail: {e}\"); continue\n",
        "\n",
        "    train_l1 = DataLoader(t1, 256, True); train_l2 = DataLoader(t2, 256, True)\n",
        "    val_loader = DataLoader(torch.utils.data.ConcatDataset([v1, v2]), 256, False)\n",
        "    full_l1 = DataLoader(f1, 256, True); full_l2 = DataLoader(f2, 256, True)\n",
        "    train_all = DataLoader(torch.utils.data.ConcatDataset([t1, t2]), 256, True)\n",
        "    full_all = DataLoader(torch.utils.data.ConcatDataset([f1, f2]), 256, True)\n",
        "    test_loader = DataLoader(dt, 256, False)\n",
        "\n",
        "    d_in = t1.x.shape[1]\n",
        "    row = {\"Dataset\": dname}\n",
        "\n",
        "    for algo in ALGOS:\n",
        "        grid = get_grid(algo)\n",
        "        best_val = float('inf'); best_p = grid[0]\n",
        "\n",
        "        print(f\"  [{algo}] Sweeping {len(grid)} configs...\", end=\" \")\n",
        "        for p in grid:\n",
        "            torch.manual_seed(SEED)\n",
        "            m = get_model(dname, d_in).to(device)\n",
        "            m = train_loop(m, train_all, (train_l1, train_l2), algo, p, epochs=100)\n",
        "            v = evaluate(m, val_loader)\n",
        "            if v < best_val: best_val = v; best_p = p\n",
        "\n",
        "        print(f\"best={best_p}\")\n",
        "        torch.manual_seed(SEED)\n",
        "        m_fin = get_model(dname, d_in).to(device)\n",
        "        m_fin = train_loop(m_fin, full_all, (full_l1, full_l2), algo, best_p, epochs=100)\n",
        "        score = evaluate(m_fin, test_loader)\n",
        "        row[algo] = score\n",
        "        print(f\"    {algo}: {score:.4f}\")\n",
        "    RESULTS.append(row)\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"FINAL RESULTS (MSE):\")\n",
        "print(\"=\"*60)\n",
        "print(pd.DataFrame(RESULTS))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LaPldIK2PZMz"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TS8R3A-wPZKU"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b04JQtuv356v"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
