"""
toy_gmm_labelshift.py

Proof-of-concept toy experiment showing robustness under distribution shift.
Implements ERM (SGD with momentum; an Adam alternative is left commented out), SAM, and IRS (fragility with KL path) on a K-class 2D Gaussian mixture.
Designed to be compatible with typical PyTorch training code.

It will print end-of-training test accuracies and save a PDF plot and a JSON of the results.

Author: Anonymous (for review)
"""

import argparse
import json
import math
import random
from dataclasses import dataclass
from typing import Tuple, List, Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from copy import deepcopy
import os
import sys
# Add ../utils (relative to this file's directory) to sys.path for the import below
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'utils')))

from utils import kl_cat, priors_from_strength, strength_for_target_kl, gaussian_kl, mixture_kl


# ------------------------------
# Reproducibility
# ------------------------------
def set_seed(seed: int = 1337):
    """Seed every random number generator this script draws from.

    Covers Python's `random`, NumPy's global RNG, PyTorch's default RNG and,
    when CUDA is available, the RNG of every GPU.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ------------------------------
# Synthetic Gaussian Mixture (2D) with K classes
# ------------------------------
@dataclass
class GMMConfig:
    """Settings for a K-component Gaussian-mixture classification dataset.

    `means` and `train_priors` default to None and nothing in this class
    fills them in, so callers are expected to supply them before sampling.
    """
    K: int = 3  # number of mixture components / classes
    d: int = 2  # input dimensionality
    train_N: int = 6000  # number of training points to draw
    test_N: int = 6000  # number of test points to draw
    cov_scale: float = 0.9  # multiplier applied to standard-normal noise when sampling
    means: np.ndarray = None  # component means; consumed as a (K, d) array by the samplers in this file
    train_priors: np.ndarray = None  # class priors of the training distribution


def default_means(K: int, d: int, radius: float = 3.2) -> np.ndarray:
    """Return K class means spaced evenly on a circle of the given radius.

    The circle lives in the first two coordinates; for d > 2 the remaining
    coordinates are zero. The result is a float32 array.
    """
    theta = np.linspace(0, 2 * np.pi, K, endpoint=False)
    xs = radius * np.cos(theta)
    ys = radius * np.sin(theta)
    centers = np.stack([xs, ys], axis=1)
    if d > 2:
        # Pad the extra dimensions with zeros so the means stay on a 2D circle.
        centers = np.concatenate([centers, np.zeros((K, d - 2))], axis=1)
    return centers.astype(np.float32)


def sample_gmm(N: int, priors: np.ndarray, means: np.ndarray, cov_scale: float) -> Tuple[np.ndarray, np.ndarray]:
    """Draw N labelled points from an isotropic Gaussian mixture.

    Labels are drawn from `priors`; each point is its component mean plus
    standard-normal noise scaled by `cov_scale`. Uses NumPy's global RNG.
    Returns (X, y) with X float32 of shape (N, d) and y int64 of shape (N,).
    """
    n_components, dim = means.shape
    labels = np.random.choice(n_components, size=N, p=priors)
    points = np.zeros((N, dim), dtype=np.float32)
    # Components are filled in index order so the RNG stream is consumed deterministically.
    for comp in range(n_components):
        members = np.flatnonzero(labels == comp)
        if members.size > 0:
            noise = np.random.randn(members.size, dim).astype(np.float32)
            points[members] = noise * cov_scale + means[comp]
    return points, labels.astype(np.int64)


class GMMDataset(Dataset):
    """Map-style dataset wrapping NumPy features and labels as torch tensors."""

    def __init__(self, X: np.ndarray, y: np.ndarray):
        n_points, n_labels = X.shape[0], y.shape[0]
        assert n_points == n_labels
        # from_numpy shares memory with the source arrays (no copy is made).
        self.X, self.y = torch.from_numpy(X), torch.from_numpy(y)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        features = self.X[idx]
        label = self.y[idx]
        return features, label


def make_label_shift_priors(base_priors: np.ndarray, shift_vector: np.ndarray, strength: float) -> np.ndarray:
    """
    Tilt `base_priors` exponentially along `shift_vector`.

    The result is proportional to base_priors * exp(strength * shift_vector),
    computed in log-space and normalised to sum to 1.
    - base_priors: shape (K,)
    - shift_vector: shape (K,), any real numbers
    - strength: scalar controlling how strong the tilt is (0 leaves the priors unchanged up to the 1e-12 offset)
    """
    tilted = np.log(base_priors + 1e-12) + strength * shift_vector
    # Subtract the max before exponentiating to avoid overflow.
    weights = np.exp(tilted - np.max(tilted))
    return weights / weights.sum()


# ------------------------------
# Simple MLP model for 2D -> K classification
# ------------------------------
class MLP(nn.Module):
    """Fully connected ReLU network with two hidden layers: d inputs -> K logits."""

    def __init__(self, d: int, K: int, width: int = 64):
        super().__init__()
        hidden_dims = [d, width, width]
        layers = []
        # Hidden stack: Linear + in-place ReLU for each consecutive pair of sizes.
        for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        # Output layer produces raw logits (no softmax).
        layers.append(nn.Linear(width, K))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits


# ------------------------------
# Utilities: accuracy, per-class losses
# ------------------------------
@torch.no_grad()
def accuracy(model: nn.Module, loader: DataLoader, device) -> float:
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    Puts the model in eval mode (and leaves it there). Returns 0.0 when the
    loader yields no samples.
    """
    model.eval()
    hits, seen = 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predicted = model(inputs).argmax(dim=1)
        hits += (predicted == targets).sum().item()
        seen += targets.numel()
    if seen > 0:
        return hits / seen
    return 0.0

def build_loaders_hard(
    K=4,                   # number of classes
    radius_train=2.0,      # radius of the circle holding the train means
    radius_test=1.8,       # radius of the circle holding the test means
    cov_scale_train=1,     # noise scale of the train Gaussians
    cov_scale_test=1,      # noise scale of the test Gaussians
    rotate_test_deg=25.0,  # rotate class means at test (covariate shift)
    label_strength=6.0,    # strength of the label tilt applied to the train priors
    asymmetric_shift=None,  # translation added to the class-0 test mean; None -> (-0.8, 0.6)
    batch_size=256,
    train_N=8000, test_N=8000,
):
    """
    Build train/test loaders for a 2D GMM with covariate and label shift at test time.

    Train: K means on a circle of radius `radius_train`, priors
    (0.45, 0.25, 0.2, 0.10) when K == 4 and uniform otherwise.
    Test: means on a circle of radius `radius_test`, rotated by
    `rotate_test_deg` degrees, with `asymmetric_shift` added to class 0;
    priors are the train priors tilted toward higher class indices with
    strength `label_strength`.

    Returns (tr_loader, te_loader, train_priors, test_priors, means_tr, means_te).
    The train loader shuffles; the test loader does not, so evaluation passes
    do not draw from the global torch RNG (matching build_loaders).
    """
    # Resolve the default here rather than in the signature: an ndarray default
    # is a single object shared by every call.
    if asymmetric_shift is None:
        asymmetric_shift = np.array([-0.8, 0.6], dtype=np.float32)

    # Train means on a circle; test means on a second circle, rotated, with class 0 translated.
    means_tr = default_means(K=K, d=2, radius=radius_train)
    means_te = rotate_2d(default_means(K=K, d=2, radius=radius_test), rotate_test_deg)
    means_te[0] += asymmetric_shift

    # Train priors: imbalanced for K == 4, uniform otherwise.
    if K == 4:
        train_priors = np.array([0.45, 0.25, 0.2, 0.10], dtype=np.float32)
    else:
        train_priors = np.ones(K, dtype=np.float32) / K
    train_priors = train_priors / train_priors.sum()

    # Test priors: tilt toward higher class index.
    shift_vec = np.linspace(0, 1, K).astype(np.float32)
    test_priors = make_label_shift_priors(train_priors, shift_vec, strength=label_strength)

    # Sample train first, then test, so the NumPy RNG stream is consumed in a fixed order.
    Xtr, ytr = sample_gmm(train_N, train_priors, means_tr, cov_scale_train)
    Xte, yte = sample_gmm(test_N, test_priors, means_te, cov_scale_test)

    tr_loader = DataLoader(GMMDataset(Xtr, ytr), batch_size=batch_size, shuffle=True)
    te_loader = DataLoader(GMMDataset(Xte, yte), batch_size=batch_size, shuffle=False)
    return tr_loader, te_loader, train_priors, test_priors, means_tr, means_te


def per_class_losses(logits: torch.Tensor, y: torch.Tensor, K: int) -> torch.Tensor:
    """
    Average the per-sample cross-entropy within each class of the current batch.

    Returns a tensor of shape (K,); classes with no sample in the batch get 0.
    """
    sample_losses = nn.functional.cross_entropy(logits, y, reduction="none")  # (B,)
    class_means = torch.zeros(K, device=logits.device, dtype=sample_losses.dtype)
    for cls in range(K):
        members = sample_losses[y == cls]
        # Absent classes keep the zero they were initialised with.
        if members.numel() > 0:
            class_means[cls] = members.mean()
    return class_means


# ------------------------------
# ERM Trainer (SGD or Adam)
# ------------------------------
def train_erm(model, train_loader, test_loader, epochs=20, lr=1e-3, weight_decay=0.0):
    """Train `model` by plain empirical risk minimisation.

    Uses SGD with momentum 0.9 on the mean cross-entropy and records train and
    test accuracy after every epoch.
    Returns (model, train_acc_history, test_acc_history).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()
    train_curve, test_curve = [], []
    for _ in range(epochs):
        # accuracy() below switches to eval mode, so re-enter train mode each epoch.
        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            criterion(model(inputs), targets).backward()
            optimizer.step()
        train_curve.append(accuracy(model, train_loader, device))
        test_curve.append(accuracy(model, test_loader, device))
    return model, train_curve, test_curve


# ------------------------------
# SAM Trainer (Sharpness-Aware Minimization)
# ------------------------------
class SAM:
    """
    Minimal Sharpness-Aware Minimization wrapper around a base optimizer.

    Usage per batch: backward -> first_step (move the weights by rho along the
    normalised gradient) -> backward again -> second_step (restore the weights
    and let the base optimizer step with the gradient from the perturbed point).
    """
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        self.params = list(params)
        # Extra keyword arguments (lr, momentum, ...) go to the base optimizer.
        self.base_optimizer = base_optimizer(self.params, **kwargs)
        self.rho = rho
        self._backup = {}

    @torch.no_grad()
    def _grad_norm(self):
        """Return the global L2 norm over all parameters that currently have a gradient."""
        per_param = [p.grad.norm(p=2) for p in self.params if p.grad is not None]
        if len(per_param) == 0:
            return torch.tensor(0.0)
        return torch.stack(per_param).norm(p=2)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Save the current weights and perturb them by rho * grad / ||grad||."""
        step_scale = self.rho / (self._grad_norm() + 1e-12)
        with_grad = [p for p in self.params if p.grad is not None]
        self._backup = {p: p.data.clone() for p in with_grad}
        for p in with_grad:
            p.add_(p.grad, alpha=step_scale)
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the saved weights, then apply the base optimizer's update."""
        for p, saved in self._backup.items():
            p.data.copy_(saved)
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def zero_grad(self):
        """Clear gradients through the base optimizer."""
        self.base_optimizer.zero_grad()

def gaussian_kl(mu1: np.ndarray, mu2: np.ndarray, sigma1: float, sigma2: float, d: int) -> float:
    """Closed-form KL( N(mu1, sigma1^2 I) || N(mu2, sigma2^2 I) ) in d dimensions."""
    var1, var2 = sigma1**2, sigma2**2
    trace_part = d * (var1 / var2)
    # Squared distance between the means, scaled by the second variance.
    mean_part = float(np.sum((mu2 - mu1)**2)) / var2
    logdet_part = -d + d * np.log(var2 / var1)
    return 0.5 * (trace_part + mean_part + logdet_part)

def mixture_kl_joint(q_test: np.ndarray, q_train: np.ndarray,
                     means_te: np.ndarray, means_tr: np.ndarray,
                     sigma_te: float, sigma_tr: float) -> float:
    """Joint KL( p_te(x,y) || p_tr(x,y) ) with component k of test matched to component k of train.

    Each class contributes q_test[k] * ( log(q_test[k] / q_train[k]) + KL of its
    two Gaussians ); priors are clipped to [1e-12, 1] before the log.
    """
    dim = means_tr.shape[1]

    def class_term(k):
        geom = gaussian_kl(means_te[k], means_tr[k], sigma_te, sigma_tr, dim)
        label = np.log(np.clip(q_test[k],1e-12,1.0)) - np.log(np.clip(q_train[k],1e-12,1.0))
        return q_test[k] * (label + geom)

    return float(sum(class_term(k) for k in range(len(q_train))))

def rotate_2d(points: np.ndarray, degrees: float) -> np.ndarray:
    """Rotate row-vector 2D points counter-clockwise about the origin; returns float32."""
    theta = np.deg2rad(degrees)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    # Points are rows, so multiply by the transpose of the rotation matrix.
    rotated = points @ rotation.T
    return rotated.astype(np.float32)

def priors_from_strength(q_train: np.ndarray, shift_vec: np.ndarray, strength: float) -> np.ndarray:
    """Exponentially tilted priors p ∝ q * exp(strength * v), returned as float32.

    q_train is clipped to [1e-12, 1] before taking the log.
    """
    tilted = np.log(np.clip(q_train, 1e-12, 1.0)) + strength * shift_vec
    # Max-subtraction keeps the exponentials from overflowing.
    weights = np.exp(tilted - tilted.max())
    normalised = weights / weights.sum()
    return normalised.astype(np.float32)

# ============================================================
# (B) Train once on train distribution, then sweep *general* shift on test
# ============================================================
def run_general_shift_sweep(
    K: int,
    tr_loader, te_loader,          # from your base/hard builder (training distro)
    q_train: np.ndarray,            # training priors (numpy)
    means_tr: np.ndarray,           # training means (K x d)
    cov_scale_train: float,         # sigma_train
    *,
    # --- model/train hyperparams ---
    epochs_erm: int = 20, lr_erm: float = 1e-3,
    epochs_sam: int = 20, lr_sam: float = 5e-3, rho_sam: float = 0.05,
    epochs_irs: int = 20, lr_irs: float = 1e-4, warmup_irs: int = 3,
    tau0: float = 0.1, tau_eps: float = 0.05,
    IRS_BATCH_REFERENCE: bool = True,   # True = instance-level IRS (data-shift robust)
    # --- shift schedule (covariate + label) ---
    strengths: List[float] = None,      # abstract shift parameter s
    rot_deg_per_s: float = 8.0,         # rotation per unit s (degrees)
    trans_per_s: Tuple[float,float] = (-0.15, 0.10),   # translation added to class 0 per s
    sigma_scale_per_s: float = 0.06,    # sigma_test = sigma_train * (1 + sigma_scale_per_s * s)
    label_tilt_per_s: float = 1.0,      # label-strength = s * label_tilt_per_s
    # --- eval/saving ---
    N_test_each: int = 6000, batch_size_eval: int = 512,
    out_dir: str = "runs",
    fig_name_acc: str = "general_shift_acc_vs_KL.pdf",
    fig_name_ce: str = "general_shift_ce_vs_KL.pdf",
    json_name: str = "general_shift_results.json",
):
    """
    Trains ERM / SAM / IRS on the TRAIN distribution; then, for s in strengths,
    constructs a TEST distribution with:
      - rotation by rot_deg_per_s * s,
      - class-0 translation by s * trans_per_s,
      - covariance scale sigma_test = sigma_train * (1 + sigma_scale_per_s * s),
      - label tilt with strength = s * label_tilt_per_s along v = linspace(0,1,K).
    Evaluates accuracy and CE, and plots them vs joint KL(p_te || p_tr).

    `te_loader` is only used for the per-epoch accuracy tracking inside the
    trainers; the sweep itself samples a fresh test set for every s.
    `strengths` defaults to 16 evenly spaced values in [0, 1.5].

    Side effects: creates `out_dir`, writes one JSON file and two figures into
    it, shows the figures, and prints one line per shift strength.
    Returns None.
    """
    os.makedirs(out_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build models and train on TRAIN distribution (ERM, then SAM, then IRS).
    arch_args = {"d": 2, "K": K, "width": 64}
    model_erm = MLP(**arch_args).to(device)
    train_erm(model_erm, tr_loader, te_loader, epochs=epochs_erm, lr=lr_erm)

    model_sam = MLP(**arch_args).to(device)
    train_sam(model_sam, tr_loader, te_loader, epochs=epochs_sam, lr=lr_sam, rho=rho_sam)

    model_irs = MLP(**arch_args).to(device)
    # q_train torch (for label-space IRS; harmless if batch-reference mode)
    q_train_t = torch.tensor(q_train, dtype=torch.float32, device=device)
    train_irs(
        model_irs, tr_loader, te_loader,
        epochs=epochs_irs, lr=lr_irs, warmup_epochs=warmup_irs,
        tau=tau0, tau_eps=tau_eps,
        base_priors=q_train_t, K=K,
        use_batch_reference=IRS_BATCH_REFERENCE
    )

    # Sweep over general shift strengths
    if strengths is None:
        strengths = np.linspace(0.0, 1.5, 16)  # 0 → 1.5 units, tweak as you like

    # Fixed shift direction for labels
    shift_vec = np.linspace(0, 1, K).astype(np.float32)

    # Summed CE so the dataset mean is weighted per sample; averaging per-batch
    # means would over-weight a smaller final batch.
    ce_sum = nn.CrossEntropyLoss(reduction="sum")

    def compute_metrics(model, loader):
        """Return (accuracy, mean cross-entropy) of `model` over all samples in `loader`."""
        model.eval()
        correct = total = 0
        loss_sum = 0.0
        with torch.no_grad():
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                loss_sum += ce_sum(logits, yb).item()
                correct += (logits.argmax(dim=1) == yb).sum().item()
                total += yb.numel()
        denom = max(total, 1)
        return correct / denom, loss_sum / denom

    # Storage
    KLs, accs_erm, accs_sam, accs_irs = [], [], [], []
    ces_erm, ces_sam, ces_irs = [], [], []

    for s in strengths:
        # --- define TEST distribution parameters for this s ---
        rot_deg = rot_deg_per_s * s
        tx, ty = trans_per_s
        trans = np.array([tx * s, ty * s], dtype=np.float32)
        sigma_te = cov_scale_train * (1.0 + sigma_scale_per_s * s)
        means_te = rotate_2d(means_tr, rot_deg).copy()
        means_te[0] = means_te[0] + trans  # asymm translation on class 0
        q_test = priors_from_strength(q_train, shift_vec, s * label_tilt_per_s)

        # --- compute analytic joint KL between TEST and TRAIN ---
        KL = mixture_kl_joint(q_test, q_train, means_te, means_tr, sigma_te, cov_scale_train)
        KLs.append(KL)

        # --- build a TEST loader with these params (geometry + priors) ---
        # train_N=0 makes build_loaders skip the train loader entirely.
        cfg = GMMConfig(K=K, d=2, train_N=0, test_N=N_test_each, cov_scale=cov_scale_train, means=means_tr)
        _, te_loader_s = build_loaders(cfg, q_train, q_test, batch_size=batch_size_eval,
                                       means_test=means_te, cov_scale_test=sigma_te)

        # --- evaluate models ---
        a_e, ce_e = compute_metrics(model_erm, te_loader_s)
        a_s, ce_s = compute_metrics(model_sam, te_loader_s)
        a_i, ce_i = compute_metrics(model_irs, te_loader_s)
        accs_erm.append(a_e); accs_sam.append(a_s); accs_irs.append(a_i)
        ces_erm.append(ce_e);  ces_sam.append(ce_s);  ces_irs.append(ce_i)

        print(f"s={s:.2f} | KL={KL:.3f} | ACC  ERM:{a_e:.3f} SAM:{a_s:.3f} IRS:{a_i:.3f} | "
              f"CE  ERM:{ce_e:.3f} SAM:{ce_s:.3f} IRS:{ce_i:.3f}")

    # Sort by KL on x-axis (in case monotonicity not perfect)
    order = np.argsort(np.array(KLs))

    def reorder(values):
        # Plain Python floats keep the JSON dump independent of NumPy scalar types.
        return [float(values[i]) for i in order]

    strengths_sorted = reorder(np.asarray(strengths))
    KLs = reorder(KLs)
    accs_erm, accs_sam, accs_irs = reorder(accs_erm), reorder(accs_sam), reorder(accs_irs)
    ces_erm, ces_sam, ces_irs = reorder(ces_erm), reorder(ces_sam), reorder(ces_irs)

    # --- Save JSON ---
    out_json = {
        "KLs": KLs,
        "strengths": strengths_sorted,
        "acc_ERM": accs_erm, "acc_SAM": accs_sam, "acc_IRS": accs_irs,
        "ce_ERM": ces_erm,   "ce_SAM": ces_sam,   "ce_IRS": ces_irs,
        "config": {
            "rot_deg_per_s": float(rot_deg_per_s),
            "trans_per_s": [float(t) for t in trans_per_s],
            "sigma_scale_per_s": float(sigma_scale_per_s),
            "label_tilt_per_s": float(label_tilt_per_s),
            "IRS_BATCH_REFERENCE": bool(IRS_BATCH_REFERENCE)
        }
    }
    json_path = os.path.join(out_dir, json_name)
    with open(json_path, "w") as f:
        json.dump(out_json, f, indent=2)
    print(f"[saved] {json_path}")

    def plot_vs_kl(series, ylabel, path):
        """Plot one curve per method against the sorted KL values, save it, show it, close it."""
        plt.figure(figsize=(7,5))
        for label, values in series:
            plt.plot(KLs, values, marker='o', label=label)
        plt.xlabel("KL( p_test(x,y) || p_train(x,y) )")
        plt.ylabel(ylabel)
        plt.legend(); plt.tight_layout()
        plt.savefig(path, dpi=160); plt.show()
        plt.close()  # release the figure so repeated sweeps do not accumulate open figures
        print(f"[saved] {path}")

    # --- Plot Accuracy vs KL ---
    plot_vs_kl(
        [("ERM", accs_erm), ("SAM", accs_sam), ("IRS", accs_irs)],
        "Accuracy (test)",
        os.path.join(out_dir, fig_name_acc),
    )

    # --- Plot CE vs KL ---
    plot_vs_kl(
        [("ERM", ces_erm), ("SAM", ces_sam), ("IRS", ces_irs)],
        "Cross-entropy loss (test)",
        os.path.join(out_dir, fig_name_ce),
    )

def train_sam(model, train_loader, test_loader, epochs=20, lr=5e-3, weight_decay=0.0, rho=0.05):
    """Train `model` with SAM on top of SGD (momentum 0.9).

    Every batch takes two forward/backward passes: one at the current weights
    to find the perturbation, one at the perturbed weights to get the gradient
    that is actually applied. Train and test accuracy are recorded per epoch.
    Returns (model, train_acc_history, test_acc_history).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    sam = SAM(
        model.parameters(), torch.optim.SGD,
        lr=lr, weight_decay=weight_decay, momentum=0.9, rho=rho,
    )
    criterion = nn.CrossEntropyLoss()
    train_curve, test_curve = [], []

    def backward_pass(inputs, targets):
        # Forward + backward at whatever weights the model currently holds.
        criterion(model(inputs), targets).backward()

    for _ in range(epochs):
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            sam.zero_grad()
            backward_pass(inputs, targets)
            sam.first_step(zero_grad=True)   # move to the perturbed weights
            backward_pass(inputs, targets)
            sam.second_step(zero_grad=True)  # restore weights, then base-optimizer step

        train_curve.append(accuracy(model, train_loader, device))
        test_curve.append(accuracy(model, test_loader, device))
    return model, train_curve, test_curve


# ------------------------------
# IRS Trainer (Fragility with KL path over class priors)
# ------------------------------
def kl_div(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Return KL(p || q) = sum_i p_i * (log p_i - log q_i) as a 0-dim tensor.

    Both inputs are clamped to [1e-12, 1] first so the logs stay finite.
    """
    floor = 1e-12
    p_safe = p.clamp(min=floor, max=1.0)
    q_safe = q.clamp(min=floor, max=1.0)
    log_ratio = torch.log(p_safe) - torch.log(q_safe)
    return (p_safe * log_ratio).sum()


def kl_path_p_of_h(log_q: torch.Tensor, losses: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """
    Exponentially tilted distribution along the KL path:
        p_i(h) ∝ q_i * exp(h * loss_i)
    Evaluated in log-space with max-subtraction so large h * loss values do
    not overflow; the result is normalised to sum to 1.
    """
    scores = log_q + h * losses
    unnormalised = torch.exp(scores - torch.max(scores))
    return unnormalised / unnormalised.sum()


def maximize_kappa_secant(log_q: torch.Tensor, losses: torch.Tensor, tau: torch.Tensor,
                          h_init: float = 5.0, h_step: float = 5.0, max_iter: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    1-D derivative-free search for an h that makes
        kappa(h) = ( <p(h), losses> - tau ) / KL( p(h) || q )
    large, where p(h) is the tilted distribution from kl_path_p_of_h.

    The search has two phases:
      1. Slide a three-point bracket (h_left, h_init, h_right) in steps of
         h_step until the middle value is at least as large as both ends
         (at most 200 moves).
      2. Run up to max_iter secant-style updates on the bracket endpoints,
         remembering the best h evaluated.

    Args:
        log_q: log of the reference distribution q.
        losses: loss vector with the same length as log_q.
        tau: threshold subtracted from the expected loss.
        h_init: starting point of the bracket.
        h_step: spacing of the bracket points in phase 1.
        max_iter: iteration cap for phase 2.

    Returns (h_star, p_star): h_star is a 0-dim tensor holding the best h found
    (the phase-1 midpoint unless a phase-2 candidate scored higher) and
    p_star = p(h_star).
    """
    def kappa(h_scalar: float) -> float:
        # Evaluate kappa at a Python-float h; .item() returns a Python float.
        h_t = torch.tensor(h_scalar, dtype=losses.dtype, device=losses.device)
        p = kl_path_p_of_h(log_q, losses, h_t)
        num = torch.dot(p, losses) - tau
        den = kl_div(p, torch.exp(log_q))
        # 1e-12 guards the division when the KL is zero.
        # NOTE(review): if p(h) equals q (presumably at h = 0, which is the default
        # h_left) the KL is 0 and kappa is num / 1e-12, i.e. huge in magnitude — confirm intended.
        return (num / (den + 1e-12)).item()

    # Bracket expansion
    h_left = h_init - h_step
    h_right = h_init + h_step
    f_left = kappa(h_left); f_mid = kappa(h_init); f_right = kappa(h_right)

    for _ in range(200):
        # Stop once the middle point is no worse than both neighbours.
        if f_mid >= f_left and f_mid >= f_right:
            break
        if f_right > f_left:
            # Right end is the better one: shift the whole bracket one step right.
            h_left, f_left = h_init, f_mid
            h_init, f_mid = h_right, f_right
            h_right = h_init + h_step; f_right = kappa(h_right)
        else:
            # Otherwise shift the whole bracket one step left.
            h_right, f_right = h_init, f_mid
            h_init, f_mid = h_left, f_left
            h_left = h_init - h_step; f_left = kappa(h_left)

    a, b = h_left, h_right
    fa, fb = f_left, f_right
    # Start from the bracket midpoint; only phase-2 candidates can replace it.
    best_h, best_f = h_init, f_mid
    for _ in range(max_iter):
        if abs(fb - fa) < 1e-9:
            # Endpoint values (almost) equal: the secant slope is degenerate, so bisect.
            h_new = 0.5 * (a + b)
        else:
            # Secant update: x-intercept of the line through (a, fa) and (b, fb).
            # NOTE(review): h_new is not clamped to [a, b].
            h_new = b - (fb * (b - a) / (fb - fa))
        f_new = kappa(h_new)
        if f_new > best_f:
            best_h, best_f = h_new, f_new
        # Replace whichever endpoint lies on the same side of the midpoint as h_new.
        if h_new < (a + b) / 2:
            b, fb = h_new, f_new
        else:
            a, fa = h_new, f_new
        if abs(b - a) < 1e-3:
            break

    h_star = torch.tensor(best_h, dtype=losses.dtype, device=losses.device)
    p_star = kl_path_p_of_h(log_q, losses, h_star)
    return h_star, p_star

@torch.no_grad()
def estimate_kappa_over_loader(model, loader, log_q: torch.Tensor, tau: torch.Tensor, K: int):
    """
    Evaluate kappa = ( <p*, losses> - tau ) / KL(p* || q) on every batch of `loader`.

    For each batch the per-class losses are computed, p* is found with the
    same secant search used in training, and kappa is evaluated at p*.
    If `log_q` does not have K entries it is replaced by a uniform reference
    of that length (and stays replaced for the remaining batches).

    Returns (mean_kappa, max_kappa) over batches, or (0.0, 0.0) for an empty loader.
    """
    model.eval()
    device = next(model.parameters()).device
    batch_kappas = []

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        class_losses = per_class_losses(model(inputs), targets, K=K)  # (K,)
        n_entries = class_losses.shape[0]

        # Length mismatch (e.g. a batch-reference log_q): fall back to a uniform reference.
        if log_q.shape[0] != n_entries:
            log_n = torch.log(torch.tensor(float(n_entries), device=class_losses.device))
            log_q = -log_n * torch.ones_like(class_losses)

        # Inner maximisation along the KL path (no gradients needed here).
        _, p_star = maximize_kappa_secant(log_q, class_losses, tau)

        excess = torch.dot(p_star, class_losses) - tau
        divergence = kl_div(p_star, torch.exp(log_q)) + 1e-12
        batch_kappas.append((excess / divergence).item())

    if not batch_kappas:
        return 0.0, 0.0
    return float(np.mean(batch_kappas)), float(np.max(batch_kappas))

def train_irs(model: nn.Module,
              train_loader,
              test_loader,
              epochs: int = 20,
              lr: float = 1e-4,
              warmup_epochs: int = 3,
              tau: float = 0.1,
              tau_eps: float = 0.05,
              base_priors: torch.Tensor = None,
              K: int = None,
              use_batch_reference: bool = False,
              device: torch.device = None):
    """
    Iterative Robust Satisficing (IRS) trainer.
    Supports two modes:
      - use_batch_reference=False  -> label-space IRS (per-class losses, priors q)
      - use_batch_reference=True   -> batch-space IRS (per-sample losses, uniform q)

    Arguments:
        model: nn.Module classifier
        train_loader, test_loader: DataLoaders
        epochs: total epochs
        lr: learning rate (Adam)
        warmup_epochs: epochs during which tau is updated via EMA before it is frozen
        tau: initial satisficing threshold
        tau_eps: weight of the KL(p*||q) term in the training loss
        base_priors: training class priors (tensor or numpy array); required
            only if use_batch_reference=False, may be None otherwise
        K: number of classes; required only if use_batch_reference=False
        use_batch_reference: if True, reference distribution = uniform over batch
        device: torch.device (defaults to CUDA when available, else CPU)
    Returns:
        model, train_acc_hist, test_acc_hist, base_priors, metrics
        where base_priors is the device tensor (or None if none was given) and
        metrics has keys "tau", "kappa_mean", "kappa_max". The kappa values are
        computed on test_loader with the same reference mode used in training.
    Raises:
        ValueError: label-space mode was requested without base_priors or K.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if isinstance(base_priors, np.ndarray):
        base_priors = torch.tensor(base_priors, dtype=torch.float32)
    # base_priors is optional in batch-reference mode, so only move it when present.
    if base_priors is not None:
        base_priors = base_priors.to(device)

    # Label-space reference: validate once up front and hoist the log out of the batch loop.
    log_q_prior = None
    if not use_batch_reference:
        if base_priors is None or K is None:
            raise ValueError("base_priors and K must be provided for label-space IRS")
        log_q_prior = torch.log(base_priors)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss(reduction="none")  # used for per-sample losses

    # warmup tau tracking (kept in float32 throughout)
    running_tau = torch.tensor(float(tau), dtype=torch.float32, device=device)
    tau_fixed = None

    tr_hist, te_hist = [], []
    log_q_uniform_cache = None  # reused when using batch reference

    for ep in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad(set_to_none=True)

            logits = model(xb)

            # === Reference distribution and loss vector ===
            if use_batch_reference:
                # Uniform reference q over samples in the batch
                losses_vec = ce(logits, yb)  # shape [b]
                b = losses_vec.shape[0]
                # Rebuild the cache only when the batch size changes (e.g. last batch).
                if (log_q_uniform_cache is None) or (len(log_q_uniform_cache) != b):
                    log_q_uniform_cache = -torch.log(
                        torch.tensor(float(b), device=device)
                    ) * torch.ones(b, device=device)
                log_q = log_q_uniform_cache
            else:
                # Standard per-class IRS (label-space)
                losses_vec = per_class_losses(logits, yb, K=K)
                log_q = log_q_prior

            # === Maximize kappa wrt distribution p (no grad wrt theta) ===
            with torch.no_grad():
                _, p_star = maximize_kappa_secant(log_q, losses_vec, running_tau)

            # === Compute fragility-adjusted loss ===
            # L = <p*, losses> - tau + tau_eps * KL(p*||q)
            # p* carries no gradient, so only the <p*, losses> term drives the update.
            frag_loss = torch.dot(p_star, losses_vec) - running_tau \
                        + tau_eps * kl_div(p_star, torch.exp(log_q))
            frag_loss.backward()
            optimizer.step()

        # --- Tau schedule (warmup then fix) ---
        if ep < warmup_epochs:
            # Exponential moving average of the CE on one training batch.
            with torch.no_grad():
                for xb, yb in train_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    # .item() yields a Python float, so tau stays float32 instead of
                    # being promoted to float64 by a NumPy scalar.
                    ce_batch_mean = ce(model(xb), yb).mean().item()
                    running_tau = 0.9 * running_tau + 0.1 * ce_batch_mean
                    break  # one batch sufficient for EMA
        elif tau_fixed is None:
            tau_fixed = running_tau.clone().detach()

        # --- Logging ---
        tr_hist.append(accuracy(model, train_loader, device=device))
        te_hist.append(accuracy(model, test_loader, device=device))

        print(f"[Epoch {ep+1}/{epochs}] Train acc={tr_hist[-1]:.3f} "
              f"Test acc={te_hist[-1]:.3f}  tau={float(running_tau):.4f}")

    # === Final metrics ===
    final_tau = tau_fixed if tau_fixed is not None else running_tau
    if use_batch_reference:
        # Same reference as in training: per-sample losses, uniform q over the batch.
        mean_kappa, max_kappa = _batch_reference_kappa(model, test_loader, final_tau, device)
    else:
        mean_kappa, max_kappa = estimate_kappa_over_loader(
            model, test_loader, log_q_prior, final_tau, K
        )
    metrics = {
        "tau": float(final_tau.item()),
        "kappa_mean": mean_kappa,
        "kappa_max": max_kappa,
    }

    return model, tr_hist, te_hist, base_priors, metrics


@torch.no_grad()
def _batch_reference_kappa(model, loader, tau, device):
    """Mean and max kappa over `loader` using per-sample losses and a uniform in-batch reference."""
    model.eval()
    ce = nn.CrossEntropyLoss(reduction="none")
    kappas = []
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        losses = ce(model(xb), yb)
        b = losses.shape[0]
        log_q = -torch.log(torch.tensor(float(b), device=device)) * torch.ones(b, device=device)
        _, p_star = maximize_kappa_secant(log_q, losses, tau)
        num = torch.dot(p_star, losses) - tau
        den = kl_div(p_star, torch.exp(log_q)) + 1e-12
        kappas.append((num / den).item())
    if not kappas:
        return 0.0, 0.0
    return float(np.mean(kappas)), float(np.max(kappas))



# ------------------------------
# Experiment harness
# ------------------------------
def build_loaders(
    cfg: GMMConfig,
    train_priors: np.ndarray,
    test_priors: np.ndarray,
    batch_size: int = 256,
    means_test: np.ndarray = None,
    cov_scale_test: float = None,
):
    """
    Sample GMM data and wrap it in train and test DataLoaders.

    The train set uses cfg.means / cfg.cov_scale; it is shuffled and drops the
    last incomplete batch. When cfg.train_N is 0 no train data is sampled and
    None is returned in its place. The test set uses `means_test` /
    `cov_scale_test` when given, falling back to the cfg values, and is
    iterated in a fixed order.

    Returns (tr_loader, te_loader).
    """
    if cfg.train_N > 0:
        train_X, train_y = sample_gmm(cfg.train_N, train_priors, cfg.means, cfg.cov_scale)
        tr_loader = DataLoader(
            GMMDataset(train_X, train_y),
            batch_size=batch_size, shuffle=True, drop_last=True,
        )
    else:
        tr_loader = None

    # Test-only overrides for the geometry.
    test_means = cfg.means if means_test is None else means_test
    test_cov = cfg.cov_scale if cov_scale_test is None else cov_scale_test
    test_X, test_y = sample_gmm(cfg.test_N, test_priors, test_means, test_cov)
    te_loader = DataLoader(GMMDataset(test_X, test_y), batch_size=batch_size, shuffle=False)

    return tr_loader, te_loader

# Module-level mean-reduced criterion used by compute_ce.
ce = nn.CrossEntropyLoss(reduction="mean")
def compute_ce(model, loader):
    """Return the sample-weighted mean cross-entropy of `model` over `loader`.

    Each batch mean is multiplied by its batch size before averaging, so a
    smaller final batch is weighted correctly. Puts the model in eval mode and
    returns 0.0 for an empty loader.
    """
    model.eval()
    device = next(model.parameters()).device
    weighted_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_n = targets.numel()
            weighted_sum += ce(model(inputs), targets).item() * batch_n
            n_seen += batch_n
    if n_seen == 0:
        return 0.0
    return weighted_sum / n_seen

def run_learning_curves_under_label_shift(
    *,
    K: int = 4,
    kl_target: float = 0.8,
    epochs: int = 30,
    n_runs: int = 10,
    base_seed: int = 1337,
    train_N: int = 8000,
    test_N: int = 5000,
    batch_size: int = 256,
    radius_train: float = 2.0,
    cov_scale_train: float = 1.0,
    out_dir: str = "runs",
    out_pdf: str = "appendix_learning_curves_labelshift_meanstd.pdf",
    out_json: str = "appendix_learning_curves_labelshift_meanstd.json",
):
    """
    Repeat the experiment n_runs times (different random seeds) and plot mean ± std
    for train and shifted-test accuracy across epochs, for ERM / SAM / IRS.

    Args:
        K: Number of classes. Only K=4 is supported because the training
            priors are hard-coded for four classes.
        kl_target: Target value handed to strength_for_target_kl to calibrate
            the test priors.
        epochs: Number of training epochs; also the length of every curve.
        n_runs: Number of independent repetitions (must be >= 1).
        base_seed: Run i is seeded with base_seed + i.
        train_N: Training-set size per run.
        test_N: Test-set size per run.
        batch_size: Batch size for both loaders.
        radius_train: Radius passed to default_means for the class means.
        cov_scale_train: cov_scale stored in the GMMConfig of every run.
        out_dir: Output directory (created if missing).
        out_pdf: File name of the saved figure inside out_dir.
        out_json: File name of the saved results inside out_dir.

    Returns:
        None. Writes a JSON file (per-run curves and mean/std summary) and a
        PDF figure into out_dir, and prints progress to stdout.

    Raises:
        ValueError: If K != 4 or n_runs < 1.
    """
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- training distribution (match Fig.1 setup) ---
    if K != 4:
        raise ValueError("This helper assumes K=4 (to match your priors). Adjust train_priors if needed.")
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}.")

    train_priors = np.array([0.45, 0.25, 0.20, 0.10], dtype=np.float32)
    train_priors = train_priors / train_priors.sum()

    means_tr = default_means(K=K, d=2, radius=radius_train)

    # --- construct test priors with KL calibration ---
    shift_vec = np.linspace(0, 1, K).astype(np.float32)
    s = strength_for_target_kl(train_priors, shift_vec, kl_target)
    test_priors = priors_from_strength(train_priors, shift_vec, s)

    print(f"[LC] target KL(priors)={kl_target:.4f}  achieved strength s={s:.4f}")
    print("[LC] train priors:", train_priors)
    print("[LC] test  priors:", test_priors)
    print(f"[LC] n_runs={n_runs}, epochs={epochs}, base_seed={base_seed}")

    # --- storage: arrays shaped (n_runs, epochs) ---
    def alloc():
        return np.zeros((n_runs, epochs), dtype=np.float32)

    curves = {
        "ERM": {"train": alloc(), "test": alloc()},
        "SAM": {"train": alloc(), "test": alloc()},
        "IRS": {"train": alloc(), "test": alloc()},
    }

    run_seeds = []

    for r in range(n_runs):
        seed = int(base_seed + r)
        run_seeds.append(seed)
        set_seed(seed)

        # Build a fresh dataset per run (consistent with the other experiments in this file).
        # For *fixed data, varying init only*, build the loaders once outside the loop instead.
        cfg = GMMConfig(
            K=K, d=2,
            train_N=train_N, test_N=test_N,
            cov_scale=cov_scale_train,
            means=means_tr,
            train_priors=train_priors
        )
        tr_loader, te_loader_shift = build_loaders(
            cfg, train_priors, test_priors, batch_size=batch_size
        )

        # --- ERM ---
        model_erm = MLP(d=2, K=K, width=64).to(device)
        _, tr_erm, te_erm = train_erm(model_erm, tr_loader, te_loader_shift, epochs=epochs, lr=1e-3)
        curves["ERM"]["train"][r, :] = np.array(tr_erm, dtype=np.float32)
        curves["ERM"]["test"][r, :] = np.array(te_erm, dtype=np.float32)

        # --- SAM ---
        model_sam = MLP(d=2, K=K, width=64).to(device)
        _, tr_sam, te_sam = train_sam(model_sam, tr_loader, te_loader_shift, epochs=epochs, lr=5e-3, rho=0.05)
        curves["SAM"]["train"][r, :] = np.array(tr_sam, dtype=np.float32)
        curves["SAM"]["test"][r, :] = np.array(te_sam, dtype=np.float32)

        # --- IRS ---
        model_irs = MLP(d=2, K=K, width=64).to(device)
        _, tr_irs, te_irs, _, _ = train_irs(
            model_irs, tr_loader, te_loader_shift,
            epochs=epochs, lr=1e-4,
            warmup_epochs=3, tau=0.1, tau_eps=0.05,
            base_priors=train_priors, K=K,
            use_batch_reference=False
        )
        curves["IRS"]["train"][r, :] = np.array(tr_irs, dtype=np.float32)
        curves["IRS"]["test"][r, :] = np.array(te_irs, dtype=np.float32)

        print(f"[LC] finished run {r+1}/{n_runs} (seed={seed})")

    # --- aggregate ---
    def mean_std(arr):
        """Return (mean, std) over runs for arr of shape (n_runs, epochs)."""
        mu = arr.mean(axis=0)
        # The sample std (ddof=1) is undefined for a single run; report zeros.
        # Written as explicit branches: the one-line form
        # `return a, b if cond else (c, d)` binds the conditional to `b` only
        # and hands back a nested tuple when n_runs == 1.
        if n_runs > 1:
            sd = arr.std(axis=0, ddof=1)
        else:
            sd = np.zeros(arr.shape[1], dtype=arr.dtype)
        return mu, sd

    summary = {}
    for m in ["ERM", "SAM", "IRS"]:
        tr_mu, tr_sd = mean_std(curves[m]["train"])
        te_mu, te_sd = mean_std(curves[m]["test"])
        summary[m] = {
            "train_mean": tr_mu.tolist(),
            "train_std": tr_sd.tolist(),
            "test_mean": te_mu.tolist(),
            "test_std": te_sd.tolist(),
        }

    # --- save JSON (both per-run and summary) ---
    payload = {
        "K": K,
        "epochs": epochs,
        "n_runs": n_runs,
        "run_seeds": run_seeds,
        "train_N": train_N,
        "test_N": test_N,
        "batch_size": batch_size,
        "radius_train": radius_train,
        "cov_scale_train": cov_scale_train,
        "kl_target_priors": float(kl_target),
        "strength_s": float(s),
        "train_priors": train_priors.tolist(),
        "test_priors": test_priors.tolist(),
        "curves_per_run": {
            m: {
                "train_acc": curves[m]["train"].tolist(),
                "test_acc": curves[m]["test"].tolist(),
            } for m in curves
        },
        "curves_summary": summary,
    }
    json_path = os.path.join(out_dir, out_json)
    with open(json_path, "w") as f:
        json.dump(payload, f, indent=2)
    print("[saved]", json_path)

    # --- plot mean ± std (shaded) ---
    xs = np.arange(1, epochs + 1)

    # NOTE: this updates matplotlib's global rcParams for the rest of the process.
    plt.rcParams.update({
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
        "axes.grid": True,
        "grid.alpha": 0.35,
        "grid.linestyle": ":",
        "grid.linewidth": 0.8,
    })

    fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)

    def shaded(ax, mu, sd, label, linestyle="-"):
        """Draw the mean curve with a ±1 std band around it."""
        ax.plot(xs, mu, linewidth=2.2, linestyle=linestyle, label=label)
        ax.fill_between(xs, mu - sd, mu + sd, alpha=0.18)

    def panel(ax, mname):
        """Fill one subplot with the train / shifted-test curves of method mname."""
        tr_mu = np.array(summary[mname]["train_mean"], dtype=np.float32)
        tr_sd = np.array(summary[mname]["train_std"], dtype=np.float32)
        te_mu = np.array(summary[mname]["test_mean"], dtype=np.float32)
        te_sd = np.array(summary[mname]["test_std"], dtype=np.float32)

        shaded(ax, tr_mu, tr_sd, "Train", linestyle="-")
        shaded(ax, te_mu, te_sd, "Test (shifted)", linestyle="--")
        ax.set_title(mname)
        ax.set_xlabel("Epoch")
        ax.margins(x=0.02)

    panel(axes[0], "ERM")
    panel(axes[1], "SAM")
    panel(axes[2], "IRS")
    axes[0].set_ylabel("Accuracy")
    axes[2].legend(loc="best")

    fig.suptitle(rf"Learning Curves under Label Shift (target KL(priors)={kl_target:.2f}, N={n_runs})", y=1.05)
    plt.tight_layout()

    pdf_path = os.path.join(out_dir, out_pdf)
    plt.savefig(pdf_path, dpi=220, bbox_inches="tight")
    plt.close(fig)
    print("[saved]", pdf_path)




if __name__ == "__main__":
    # Script entry point: pick ONE experiment by (un)commenting the lines below.
    # experiment_name = "toy_gmm_featureshift"
    # experiment_name = "toy_gmm_labelshift"
    experiment_name = "toy_gmm_learningcurves_labelshift"

    if experiment_name == "toy_gmm_labelshift":
        print("Running toy_gmm_labelshift experiment...")
        # Set hard loader parameters
        set_seed(1337)
        K = 4
        radius_train = 2.0
        radius_test = 2.0
        cov_scale_train = 1.0
        cov_scale_test = 1.0
        rotate_test_deg = 0.0
        label_strength = 6.0
        asymmetric_shift = np.array([-0.0, 0.0], dtype=np.float32)
        batch_size = 256
        train_N = 8000
        test_N = 5000

        # Harder loaders (built once here only to print the experiment setup;
        # they are rebuilt per run inside the loop below).
        tr_loader, te_loader, q_train, q_test, means_tr, means_te = build_loaders_hard(
            K=K,
            radius_train=radius_train,
            radius_test=radius_test,
            cov_scale_train=cov_scale_train,
            cov_scale_test=cov_scale_test,
            rotate_test_deg=rotate_test_deg,
            label_strength=label_strength,
            asymmetric_shift=asymmetric_shift,
            batch_size=batch_size,
            train_N=train_N,
            test_N=test_N
        )

        # Architecture description; not referenced again in this block.
        arch = "MLP"
        arch_args = {"d": 2, "K": K, "width": 64}  # <-- match the model construction below

        os.makedirs("runs", exist_ok=True)

        # Print experiment setup
        print("Train priors:", q_train)
        print("Test priors: ", q_test)
        print("Train means:\n", means_tr)
        print("Test means:\n", means_te)
        print("Train samples:", len(tr_loader.dataset))
        print("Test samples: ", len(te_loader.dataset))
        print("Batch size:   ", batch_size)
        print("K =", K)
        print("Cov scale (train/test):", cov_scale_train, "/", cov_scale_test)
        print("Label shift strength:", label_strength)
        print("Asymmetric shift of class 0:", asymmetric_shift)
        print("Rotate test means by (deg):", rotate_test_deg)

        # =========================
        # N-run experiment: mean ± std accuracy vs KL(label shift)
        # =========================
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Sweep only LABEL SHIFT (keep geometry equal to TRAINING geometry)
        shift_vec = np.linspace(0, 1, K).astype(np.float32)

        # Decide KL range by probing a large tilt; a fixed KL_max could be hardcoded instead
        s_max_probe = 12.0
        kl_cap = kl_cat(priors_from_strength(q_train, shift_vec, s_max_probe), q_train)
        num_points = 20
        kl_targets = np.linspace(0.0, kl_cap, num_points)
        # Print all the prior distributions we will probe
        print("Probing label shift priors:")
        for kl_t in kl_targets:
            s = strength_for_target_kl(q_train, shift_vec, kl_t)
            p = priors_from_strength(q_train, shift_vec, s)
            print(f"  KL={kl_t:.3f} (s={s:.3f}):", p)
        # N independent runs (fresh data + retraining each time)
        N_runs = 100  # <--- set the number of runs you want

        # Storage: (N_runs, num_points)
        acc_erm_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        acc_sam_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        acc_irs_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        acc_klrs_runs = np.zeros((N_runs, num_points), dtype=np.float32)  # only used by the commented-out KLRS lines

        # Optionally track CE too (comment out if not needed)
        ce_erm_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        ce_sam_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        ce_irs_runs = np.zeros((N_runs, num_points), dtype=np.float32)
        ce_klrs_runs = np.zeros((N_runs, num_points), dtype=np.float32)  # only used by the commented-out KLRS lines

        for run_idx in range(N_runs):
            # Different seed per run for fresh data and init
            set_seed(1337 + run_idx)

            # -------- Rebuild TRAIN/TEST loaders for this run (hard setup) --------
            tr_loader, te_loader, q_train, q_test, means_tr, means_te = build_loaders_hard(
                K=K,
                radius_train=radius_train,
                radius_test=radius_test,
                cov_scale_train=cov_scale_train,
                cov_scale_test=cov_scale_test,
                rotate_test_deg=rotate_test_deg,
                label_strength=label_strength,
                asymmetric_shift=asymmetric_shift,
                batch_size=batch_size,
                train_N=train_N,
                test_N=test_N
            )
            print(f"[run {run_idx+1}/{N_runs}] loaders built.")
            # -------- Retrain models from scratch on this run --------
            # Models are placed on `device` because the evaluation calls below
            # pass device=device; this matches the model placement used in
            # run_learning_curves_under_label_shift.
            model_erm = MLP(d=2, K=K, width=64).to(device)
            print("Training ERM...")
            _, _, _ = train_erm(model_erm, tr_loader, te_loader, epochs=20, lr=1e-3)

            model_sam = MLP(d=2, K=K, width=64).to(device)
            print("Training SAM...")
            _, _, _ = train_sam(model_sam, tr_loader, te_loader, epochs=20, lr=5e-3, rho=0.05)

            model_irs = MLP(d=2, K=K, width=64).to(device)
            print("Training IRS...")
            _, _, _, _, _ = train_irs(
                model_irs, tr_loader, te_loader,
                epochs=20, lr=1e-4,
                warmup_epochs=3, tau=0.1, tau_eps=0.05,
                base_priors=q_train, K=K,
                use_batch_reference=False     # optional, makes the choice explicit
            )

            print(f"[run {run_idx+1}/{N_runs}] models trained.")

            # -------- For this run, sweep label shift on the TRAIN GEOMETRY (label shift only) --------
            # Recompute KL cap for this run in case priors differ slightly
            kl_cap_run = kl_cat(priors_from_strength(q_train, shift_vec, s_max_probe), q_train)
            # Use the same kl_targets grid scaled to the run's cap (so endpoints align across runs)
            # If you want identical KL grid across runs, comment the next line.
            kl_targets_run = np.linspace(0.0, kl_cap_run, num_points)

            for j, kl_t in enumerate(kl_targets_run):
                s = strength_for_target_kl(q_train, shift_vec, kl_t)
                test_priors = priors_from_strength(q_train, shift_vec, s)

                # Build a test loader with ONLY label shift (means/cov = train geometry)
                # NOTE(review): train_N > 0 here means build_loaders also samples a
                # training set that is discarded; kept as-is so the random stream
                # (and therefore the sampled test sets) is unchanged.
                _, te_loader_s = build_loaders(
                    GMMConfig(K=K, d=2, train_N=train_N, test_N=test_N,
                            cov_scale=cov_scale_train, means=means_tr),
                    q_train, test_priors, batch_size=batch_size
                )

                # Accuracies
                acc_erm_runs[run_idx, j] = accuracy(model_erm, te_loader_s, device=device)
                acc_sam_runs[run_idx, j] = accuracy(model_sam, te_loader_s, device=device)
                acc_irs_runs[run_idx, j] = accuracy(model_irs, te_loader_s, device=device)
                # acc_klrs_runs[run_idx, j] = accuracy(model_klrs, te_loader_s, device=device)

                # Cross-entropy (optional)
                ce_erm_runs[run_idx, j] = compute_ce(model_erm, te_loader_s)
                ce_sam_runs[run_idx, j] = compute_ce(model_sam, te_loader_s)
                ce_irs_runs[run_idx, j] = compute_ce(model_irs, te_loader_s)
                # ce_klrs_runs[run_idx, j] = compute_ce(model_klrs, te_loader_s)

            print(f"[run {run_idx+1}/{N_runs}] done.")

        # Exit before plotting
        # import sys;
        # sys.exit("[exiting before plotting]")

        # x-axis: the grid computed before the loop. NOTE(review): this assumes the
        # per-run grids (kl_targets_run) coincide with it, i.e. q_train is the same
        # in every run -- confirm against build_loaders_hard.
        KL = kl_targets

        def mean_std(arr_runs):
            # Mean and std over runs (axis 0); np.std defaults to ddof=0.
            return arr_runs.mean(axis=0), arr_runs.std(axis=0)

        erm_mean, erm_std = mean_std(acc_erm_runs)
        sam_mean, sam_std = mean_std(acc_sam_runs)
        irs_mean, irs_std = mean_std(acc_irs_runs)
        # klrs_mean, klrs_std = mean_std(acc_klrs_runs)

        # Optionally for CE:
        ce_erm_mean, ce_erm_std = mean_std(ce_erm_runs)
        ce_sam_mean, ce_sam_std = mean_std(ce_sam_runs)
        ce_irs_mean, ce_irs_std = mean_std(ce_irs_runs)
        # ce_klrs_mean, ce_klrs_std = mean_std(ce_klrs_runs)

        # ---------------------------
        # Save raw arrays for reproducibility
        # ---------------------------
        save_dict = {
            "N_runs": N_runs,
            "kl_targets": KL.tolist(),
            "acc_erm_runs": acc_erm_runs.tolist(),
            "acc_sam_runs": acc_sam_runs.tolist(),
            "acc_irs_runs": acc_irs_runs.tolist(),
            # "acc_klrs_runs": acc_klrs_runs.tolist(),
            "ce_erm_runs": ce_erm_runs.tolist(),
            "ce_sam_runs": ce_sam_runs.tolist(),
            "ce_irs_runs": ce_irs_runs.tolist(),
            # "ce_klrs_runs": ce_klrs_runs.tolist(),
        }
        # File name follows N_runs (identical to the previous hard-coded name for N_runs = 100).
        raw_json_path = f"runs/label_shift_runs_raw_N{N_runs}.json"
        with open(raw_json_path, "w") as f:
            json.dump(save_dict, f, indent=2)
        print("[saved]", raw_json_path)
        # ---------------------------

        plt.figure(figsize=(7,5))
        for mean, std, label, color in [
            (erm_mean, erm_std, "ERM", "C0"),
            (sam_mean, sam_std, "SAM", "C1"),
            (irs_mean, irs_std, "IRS", "C2"),
            # (klrs_mean, klrs_std, "KLRS", "C3"),
        ]:
            # Use the listed color explicitly so line and band always match.
            plt.plot(KL, mean, label=label, color=color)
            plt.fill_between(KL, mean - std, mean + std, alpha=0.2, facecolor=color)
        plt.xlabel(r"KL($\mathbb{P}_\text{test}$ || $\mathbb{P}_\text{train}$)  (label shift amount)")
        plt.ylabel("Accuracy (test)")
        plt.legend()
        os.makedirs("runs", exist_ok=True)
        plt.tight_layout()
        acc_pdf_path = "runs/label_shift_acc_mean_std.pdf"
        plt.savefig(acc_pdf_path, dpi=160)
        # Report the file only after it has actually been written.
        print("[saved]", acc_pdf_path)
        plt.show()

        # ---------------------------
        # (Optional) Plot: CE mean ± std vs KL
        # ---------------------------
        plt.figure(figsize=(7,5))
        for mean, std, label, color in [
            (ce_erm_mean, ce_erm_std, "ERM", "C0"),
            (ce_sam_mean, ce_sam_std, "SAM", "C1"),
            (ce_irs_mean, ce_irs_std, "IRS", "C2"),
            # (ce_klrs_mean, ce_klrs_std, "KLRS", "C3"),
        ]:
            plt.plot(KL, mean, label=label, color=color)
            plt.fill_between(KL, mean - std, mean + std, alpha=0.2, facecolor=color)
        plt.xlabel(r"KL($\mathbb{P}_\text{test}$ || $\mathbb{P}_\text{train}$)  (label shift amount)")
        plt.ylabel("Cross-entropy loss (test)")
        plt.title(f"Loss vs Label Shift (mean ± std over {N_runs} runs)")
        plt.legend()
        plt.tight_layout()
        ce_pdf_path = "runs/label_shift_ce_mean_std.pdf"
        plt.savefig(ce_pdf_path, dpi=160)
        print("[saved]", ce_pdf_path)
        plt.show()

    elif experiment_name == "toy_gmm_featureshift":
        print("Running toy_gmm_featureshift experiment...")
        # Set hard loader parameters
        set_seed(1337)
        K = 4
        radius_train = 2.0
        radius_test = 1.8
        cov_scale_train = 1.0
        cov_scale_test = 1.5
        rotate_test_deg = 25.0
        label_strength = 6.0
        asymmetric_shift = np.array([-0.8, 0.6], dtype=np.float32)
        batch_size = 256
        train_N = 8000
        test_N = 5000

        # Harder loaders
        tr_loader, te_loader, q_train, q_test, means_tr, means_te = build_loaders_hard(
            K=K,
            radius_train=radius_train,
            radius_test=radius_test,
            cov_scale_train=cov_scale_train,
            cov_scale_test=cov_scale_test,
            rotate_test_deg=rotate_test_deg,
            label_strength=label_strength,
            asymmetric_shift=asymmetric_shift,
            batch_size=batch_size,
            train_N=train_N,
            test_N=test_N
        )

        os.makedirs("runs", exist_ok=True)

        # Print experiment setup
        print("Train priors:", q_train)
        print("Test priors: ", q_test)
        print("Train means:\n", means_tr)
        print("Test means:\n", means_te)
        print("Train samples:", len(tr_loader.dataset))
        print("Test samples: ", len(te_loader.dataset))
        print("Batch size:   ", batch_size)
        print("K =", K)
        print("Cov scale (train/test):", cov_scale_train, "/", cov_scale_test)
        print("Label shift strength:", label_strength)
        print("Asymmetric shift of class 0:", asymmetric_shift)
        print("Rotate test means by (deg):", rotate_test_deg)

        # Everything built above (loaders, priors, train means, train cov scale)
        # is handed to the general-shift sweep.
        run_general_shift_sweep(
            K=K,
            tr_loader=tr_loader, te_loader=te_loader,
            q_train=q_train, means_tr=means_tr, cov_scale_train=cov_scale_train,
            IRS_BATCH_REFERENCE=False,            # <-- instance-level IRS (recommended for general shift)
            strengths=np.linspace(0.0, 1.5, 40), # abstract shift knob
            rot_deg_per_s=1.0,                   # rotation per 1.0 of s
            trans_per_s=(-0.18, 0.12),           # translate class 0 mean per 1.0 of s
            sigma_scale_per_s=0.1,              # inflate sigma per 1.0 of s
            label_tilt_per_s=1.5,                # label-tilt strength per 1.0 of s
            out_dir="runs",
        )

    elif experiment_name == "toy_gmm_learningcurves_labelshift":
        run_learning_curves_under_label_shift(
            kl_target=0.5,
            epochs=30,
            n_runs=10,
            base_seed=42,
            out_pdf="appendix_learning_curves_labelshift_KL0p5_meanstd.pdf",
            out_json="appendix_learning_curves_labelshift_KL0p5_meanstd.json",
        )
