# -*- coding: utf-8 -*-
""" 
# HANCOCK - Outcome Prediction  

Main steps of this notebook:

1. Import dependencies  
   For numerical computing, data processing, visualization, and ML modeling  

2. Define the optimal Random Forest model  

   Configure hyperparameters (number of trees, depth, splitting criteria, etc.)  

3. Load multimodal features  

   Includes structured features and text/image features  

4. Load data splits + UMAP dimensionality reduction  

   Use train/test splits and apply UMAP for visualization  

5. Train & Evaluate  

   Fit the model on the training set, evaluate on the testing set  

6. Plotting (UMAP splits & ROC curves)  

   Show data distribution and classifier performance comparison

### MOSEK license setup
"""

import warnings

# Silence noisy solver warnings. warnings.filterwarnings() *prepends* each rule
# to the filter list, so applying the specs in this order leaves the ECOS rule
# in front of the SCS rule.
#   - SCS: only the "Converting A to a CSC ..." conversion notice.
#   - ECOS: every UserWarning (add a `message` regex to be more selective).
_SOLVER_WARNING_FILTERS = (
    {"message": r"Converting A to a CSC.*", "module": r"^scs(\.|$)"},
    {"module": r"^ecos(\.|$)"},
)
for _filter_kwargs in _SOLVER_WARNING_FILTERS:
    warnings.filterwarnings("ignore", category=UserWarning, **_filter_kwargs)
del _filter_kwargs, _SOLVER_WARNING_FILTERS

# ---- MOSEK license setup (put this at the very top, before importing cvxpy) ----
import os, shutil, stat

def setup_mosek_license(
    lic_dest=os.path.expanduser("~/.mosek/mosek.lic"),
    lic_candidates=None,
    set_fallback_env=True,
    verbose=True,
):
    """
    Ensure a MOSEK license is present at ``lic_dest`` and export env vars for it.

    Steps:
      1. Create the parent directory of ``lic_dest`` (if it has one).
      2. If ``lic_dest`` is missing, copy the first existing file listed in
         ``lic_candidates`` to it.
      3. Restrict the license file to owner read/write (chmod 600), best-effort.
      4. Point ``MOSEKLM_LICENSE_FILE`` (and ``MOSEK_FALLBACK_LIC`` when
         ``set_fallback_env``) at ``lic_dest``.

    If no license ends up at ``lic_dest``, env vars that are already set are
    left untouched, so a ``MOSEKLM_LICENSE_FILE`` the user configured
    themselves (as the "not found" message suggests) is not overwritten.
    Variables that are unset are still pointed at ``lic_dest``.

    Args:
        lic_dest: Destination license path. The default ``~/.mosek/mosek.lic``
            is expanded once, when this function is defined.
        lic_candidates: Existing license paths to copy from, tried in order;
            falsy entries are skipped. Defaults to
            ``['mosek.lic', os.path.join(os.getcwd(), 'mosek.lic')]``.
        set_fallback_env: Also set ``MOSEK_FALLBACK_LIC``.
        verbose: Print progress/diagnostic messages.
    """
    if lic_candidates is None:
        lic_candidates = [
            "mosek.lic",
            os.path.join(os.getcwd(), "mosek.lic"),
            # Your project path (you can add an absolute path if needed):
        ]

    # 1) Make sure the destination directory exists (a bare file name has none;
    #    os.makedirs("") would raise).
    lic_dir = os.path.dirname(lic_dest)
    if lic_dir:
        os.makedirs(lic_dir, exist_ok=True)

    # 2) If the target does not exist, copy it from the first existing candidate
    if not os.path.exists(lic_dest):
        src = next((p for p in lic_candidates if p and os.path.exists(p)), None)
        if src is None:
            if verbose:
                print("[MOSEK] No license file found in candidates, skip copy. "
                      "You can place mosek.lic next to the script or set MOSEKLM_LICENSE_FILE yourself.")
        else:
            shutil.copyfile(src, lic_dest)
            if verbose:
                print(f"[MOSEK] Copied license from: {src} -> {lic_dest}")

    have_license = os.path.exists(lic_dest)

    # 3) Permissions (600); failure is reported but not fatal
    if have_license:
        try:
            os.chmod(lic_dest, stat.S_IRUSR | stat.S_IWUSR)
        except OSError as e:
            if verbose:
                print(f"[MOSEK] chmod 600 failed (non-fatal): {e}")

    # 4) Environment variables: use only $HOME/.mosek/mosek.lic (recommended on HPC)
    env_names = ["MOSEKLM_LICENSE_FILE"]
    if set_fallback_env:
        env_names.append("MOSEK_FALLBACK_LIC")
    for name in env_names:
        if have_license:
            os.environ[name] = lic_dest
        else:
            # No license here: keep any location the user already configured.
            os.environ.setdefault(name, lic_dest)

    if verbose:
        print("[MOSEK] HOME =", os.path.expanduser("~"))
        print("[MOSEK] MOSEKLM_LICENSE_FILE =", os.environ.get("MOSEKLM_LICENSE_FILE"))
        print("[MOSEK] License exists:", have_license)

# Call once. Per the setup notes above, this must run before cvxpy is imported
# so that the MOSEK license environment variables are already in place.
setup_mosek_license()
# ------------------------------------------------------------------------------

import cvxpy as cp
# Smoke test: list the solvers cvxpy can see, then solve a 1-D quadratic whose
# minimum is at x = 2 using MOSEK. Presumably this raises (cvxpy SolverError)
# when MOSEK is unavailable or unlicensed -- TODO confirm.
# NOTE: `x` is a module-level name.
print("Solvers available:", cp.installed_solvers())
x = cp.Variable()
cp.Problem(cp.Minimize((x - 2)**2)).solve(solver=cp.MOSEK)
print("MOSEK test ok; x* =", x.value)

# Make MOSEK the default cvxpy solver: wrap Problem.solve so that calls which do
# not choose a solver get solver=cp.MOSEK.
# - A solver passed positionally (prob.solve(cp.ECOS)) is respected; also
#   injecting solver= would raise "got multiple values for argument 'solver'".
# - The unwrapped method is stored on the wrapper, so re-running this cell
#   does not stack wrappers.
_ORIG_SOLVE = getattr(cp.Problem.solve, "_mosek_default_orig", cp.Problem.solve)


def _solve_mosek_default(self, *args, **kwargs):
    """Call the original ``Problem.solve``, defaulting ``solver`` to MOSEK."""
    if not args and kwargs.get("solver") is None:
        kwargs["solver"] = cp.MOSEK
    return _ORIG_SOLVE(self, *args, **kwargs)


_solve_mosek_default._mosek_default_orig = _ORIG_SOLVE
cp.Problem.solve = _solve_mosek_default
      
def mosek_child_initializer():
    """
    Per-process initializer for worker pools.

    Repeats, inside the worker process, the setup the parent does at import:
    silences the SCS/ECOS UserWarnings, points ``MOSEKLM_LICENSE_FILE`` and
    ``MOSEK_FALLBACK_LIC`` at ``~/.mosek/mosek.lic``, runs a tiny MOSEK solve
    as a smoke test, and makes MOSEK the default solver for cvxpy
    ``Problem.solve`` calls that do not choose one.
    """
    import warnings
    warnings.filterwarnings(
        "ignore",
        message=r"Converting A to a CSC.*",
        category=UserWarning,
        module=r"^scs(\.|$)"
    )
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        module=r"^ecos(\.|$)"
    )

    lic = os.path.expanduser("~/.mosek/mosek.lic")
    os.environ["MOSEKLM_LICENSE_FILE"] = lic
    os.environ["MOSEK_FALLBACK_LIC"] = lic
    import cvxpy as _cp
    # Smoke test: fail fast in the worker if MOSEK cannot be used there.
    y = _cp.Variable()
    _cp.Problem(_cp.Minimize((y - 1)**2)).solve(solver=_cp.MOSEK)

    # Unwrap an earlier patch (e.g. when the initializer runs in a process
    # that already patched cvxpy) so wrappers do not stack.
    _ORIG = getattr(_cp.Problem.solve, "_mosek_default_orig", _cp.Problem.solve)

    def _solve_mosek_default(self, *args, **kwargs):
        # Respect a solver passed positionally; only inject MOSEK when none was
        # given (injecting alongside a positional solver raises TypeError).
        if not args and kwargs.get("solver") is None:
            kwargs["solver"] = _cp.MOSEK
        return _ORIG(self, *args, **kwargs)

    _solve_mosek_default._mosek_default_orig = _ORIG
    _cp.Problem.solve = _solve_mosek_default


# Alternative initializer: executed when each child process starts
def _init_mosek_env():
    """
    Point MOSEK at ``~/.mosek/mosek.lic`` or ``/root/mosek/mosek.lic``, then
    run a tiny MOSEK solve so a missing/invalid license fails fast in the child.

    ``MOSEKLM_LICENSE_FILE`` may hold several license locations. The list is
    joined with ``os.pathsep`` (':' on POSIX, ';' on Windows) instead of a
    hard-coded ':'.
    """
    import os
    license_paths = (os.path.expanduser("~/.mosek/mosek.lic"), "/root/mosek/mosek.lic")
    os.environ["MOSEKLM_LICENSE_FILE"] = os.pathsep.join(license_paths)
    import cvxpy as _cp
    # Optional: solve a tiny problem to make sure MOSEK works in the child
    y = _cp.Variable()
    _cp.Problem(_cp.Minimize((y - 1)**2)).solve(solver=_cp.MOSEK)

# # Later, when running Parallel, include the initializer
# from joblib import Parallel, delayed
# all_results = Parallel(n_jobs=2, verbose=10, initializer=_init_mosek_env)(
#     delayed(run_experiment_for_split)(*args) for args in tasks
# )

# !unzip data-open.zip

"""## 1. Import Dependencies

"""

import os

import numpy as np

import pandas as pd

from random import Random

from pathlib import Path
# Import Path class for cross-platform file path operations

from imblearn.over_sampling import SMOTE
# Import SMOTE for oversampling to handle class imbalance

from sklearn.metrics import roc_curve, roc_auc_score
# Import ROC curve and AUC score metrics for model evaluation

from sklearn.ensemble import RandomForestClassifier
# Import RandomForestClassifier for classification modeling

from matplotlib import rcParams
# Import rcParams for global matplotlib settings

import matplotlib.pyplot as plt

import seaborn as sns
# Import seaborn for enhanced statistical visualization

import sys
# Import sys module for interacting with the Python interpreter

# =========================
# Relative path with auto-discovery (privacy-safe)
# Target: starting at current Notebook dir (e.g., .../Experiment/Notebooks),
#         automatically find sibling data-open/ as project root
# =========================

def locate_project_root(start: Path | None = None, target_dir: str = "data-open", max_depth: int = 6) -> Path:
    """
    Recursively search upwards for a folder named target_dir; if not found,
    fall back to ../data-open relative to current working directory.
    """
    p = (start or Path.cwd()).resolve()
    for _ in range(max_depth + 1):
        candidate = (p / target_dir).resolve()
        if candidate.exists() and candidate.is_dir():
            return candidate
        if p.parent == p:
            break
        p = p.parent
    return (Path.cwd() / ".." / target_dir).resolve()

# Locate the sanitized project root (the data-open folder)
project_root = locate_project_root()

# Fail fast when the data folder is missing. locate_project_root() returns a
# fallback path (<cwd>/../data-open) without checking it, so existence is
# verified here.
# NOTE(review): despite the "no local path leakage" intent, the message below
# includes the resolved absolute path.
if not project_root.exists():
    raise FileNotFoundError(
        f"[Path error] Could not find 'data-open' near this notebook. "
        f"Expected at: {project_root}"
    )

# Append the HANCOCK repo root and its data_exploration/ folder to sys.path
# (only if they exist and are not already listed) so that local modules there
# can be imported -- presumably including `umap_embedding`, imported next.
# Appending means earlier sys.path entries win on module-name clashes.
hancock_root = project_root / "HANCOCK_MultimodalDataset-main"
hancock_explore = hancock_root / "data_exploration"

for pth in (hancock_explore, hancock_root):
    p_str = str(pth)
    if pth.exists() and p_str not in sys.path:
        sys.path.append(p_str)

from umap_embedding import setup_preprocessing_pipeline, get_umap_embedding

# Import functions from custom module umap_embedding:
# - setup_preprocessing_pipeline: preprocessing pipeline for data
# - get_umap_embedding: obtain UMAP embeddings

"""## 2. Define the Optimal Random Forest Model

"""

def return_optimal_random_forest(target: str, data_split: str, random_state=np.random.RandomState(42)):
    """
    Return the tuned (unfitted) RandomForestClassifier for a target/split pair.

    Args:
        target: ``"recurrence"`` or ``"survival_status"``.
        data_split: ``"In distribution"``, ``"Oropharynx"`` or
            ``"Out of distribution"``.
        random_state: Forwarded unchanged to ``RandomForestClassifier``.
            NOTE: the default ``RandomState(42)`` is created once, when this
            function is defined, and is shared by every call that relies on the
            default. Results may therefore depend on call order.

    Returns:
        A ``RandomForestClassifier`` configured with the tuned hyper-parameters.

    Raises:
        KeyError: if the (target, data_split) combination is not recognized.
            Unknown splits are now rejected for ``survival_status`` as well;
            previously they silently fell through to the Oropharynx/OOD config.
    """
    # Shared configuration: 1200 trees, depth 20, sqrt features, Gini.
    shallow_gini = dict(
        n_estimators=1200, min_samples_split=2, min_samples_leaf=1,
        max_leaf_nodes=1000, max_features='sqrt', max_depth=20,
        criterion='gini',
    )
    # Tuned hyper-parameters per (target, data split). A missing max_depth
    # means sklearn's default (unlimited depth).
    tuned_params = {
        ("recurrence", "In distribution"): dict(
            n_estimators=1600, min_samples_split=2, min_samples_leaf=1,
            max_leaf_nodes=1000, max_features='log2', max_depth=30,
            criterion='gini',
        ),
        ("recurrence", "Oropharynx"): shallow_gini,
        ("recurrence", "Out of distribution"): dict(
            n_estimators=800, min_samples_split=2, min_samples_leaf=1,
            max_leaf_nodes=100, max_features='sqrt', max_depth=80,
            criterion='log_loss',
        ),
        ("survival_status", "In distribution"): dict(
            n_estimators=1000, min_samples_split=5, min_samples_leaf=1,
            max_leaf_nodes=1000, max_features='log2',
            criterion='entropy',
        ),
        ("survival_status", "Oropharynx"): shallow_gini,
        ("survival_status", "Out of distribution"): shallow_gini,
    }

    params = tuned_params.get((target, data_split))
    if params is None:
        raise KeyError(f'Target {target} or data split {data_split} not recognized')
    return RandomForestClassifier(**params, random_state=random_state)

"""## 3. Load Multimodal Features

"""

# Input directories, resolved relative to the located project root

data_dir = project_root / "HANCOCK_MultimodalDataset-main" / "features"
# Per-modality feature CSVs read below (clinical, pathological, blood, ICD codes, TMA cell density)

split_dir = project_root / "Hancock_Dataset" / "DataSplits_DataDictionaries"
# Dataset split JSON files and data dictionaries

# results_dir = project_root / "results"
# # Define the output directory for saving results and visualizations

# results_dir.mkdir(exist_ok=True)
# # Create results directory if it does not exist; do nothing if it already exists
# Timestamped output directory (used on CSC/MAHTI). Note it is relative to the
# current working directory, not to project_root.
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir = Path("outputs3") / f"{timestamp}"
# NOTE(review): second-resolution timestamp -- two runs started within the same
# second share this directory (exist_ok=True below does not complain).

results_dir.mkdir(parents=True, exist_ok=True)
# Create outputs3/<timestamp>/ including missing parents; no error if it already exists

# Save run configuration to text file (will be written after variables are defined)

target = "recurrence"
# Prediction target: "recurrence" or "survival_status"

rng = np.random.RandomState(42)
# rng = np.random.default_rng(42)   # Generator
# Legacy NumPy RandomState seeded with 42 for reproducibility

# Load the five per-modality tables. patient_id is read as str in every file
# so the join key has one consistent type across all of them.
clinical = pd.read_csv(data_dir/"clinical.csv", dtype={"patient_id": str})
patho = pd.read_csv(data_dir/"pathological.csv", dtype={"patient_id": str})
blood = pd.read_csv(data_dir/"blood.csv", dtype={"patient_id": str})
icd = pd.read_csv(data_dir/"icd_codes.csv", dtype={"patient_id": str})
cell_density = pd.read_csv(data_dir/"tma_cell_density.csv", dtype={"patient_id": str})

# Outer-join all modalities on patient_id: patients present in any file are
# kept, with NaN in the columns of modalities they are missing.
df = clinical.merge(patho, on="patient_id", how="outer")
df = df.merge(blood, on="patient_id", how="outer")
df = df.merge(icd, on="patient_id", how="outer")
df = df.merge(cell_density, on="patient_id", how="outer")
df = df.reset_index(drop=True)

print("Merged feature shape:", df.shape)

# Preview the merged table (only displayed in a notebook; a no-op as a script)
df.head()

"""## 4. Load Data Splits + UMAP Dimensionality Reduction

"""

x_linspace = np.linspace(0, 1, 100)
# 100 evenly spaced points on [0, 1]; presumably a common grid for interpolating ROC curves -- TODO confirm

rcParams.update({"font.size": 6})
# Global matplotlib font size 6 for compact figures

rcParams["svg.fonttype"] = "none"
# Keep SVG text as editable text instead of converting glyphs to paths

umap_embeddings = get_umap_embedding(data_dir, umap_min_dist=0.1, umap_n_neighbors=15)

# get_umap_embedding comes from the local umap_embedding module (not shown here).
# Assumes it reads the feature files under data_dir and returns low-dimensional
# embeddings for plotting -- TODO confirm its exact return type.
# - umap_min_dist=0.1: how tightly UMAP packs points together
# - umap_n_neighbors=15: size of the local neighborhood UMAP preserves

data_split_paths = [
    split_dir/"dataset_split_in.json",
    split_dir/"dataset_split_out.json",
    split_dir/"dataset_split_Oropharynx.json"
]

# Split definition files:
# - dataset_split_in.json: in-distribution split
# - dataset_split_out.json: out-of-distribution split
# - dataset_split_Oropharynx.json: oropharynx-specific split

data_split_labels = ["In distribution", "Out of distribution", "Oropharynx"]
# Display labels, index-aligned with data_split_paths above
# Labels corresponding to dataset splits, used for plotting or reporting results

"""## 5. Define the Optimal Logistic Regression Model

"""

# Import LogisticRegression from scikit-learn to train a linear classifier
from sklearn.linear_model import LogisticRegression


# Factory for the classifier matching a model type / target / data split.
# NOTE: a later definition of return_optimal_model in this file replaces this one.
def return_optimal_model(model_type: str, target: str, data_split: str, random_state=np.random.RandomState(42)):
    """
    Build the classifier selected by ``model_type``.

    - ``"logistic_regression"``: liblinear LogisticRegression with
      ``max_iter=1000``, balanced class weights, ``random_state`` forwarded.
    - ``"random_forest"``: delegates to
      ``return_optimal_random_forest(target, data_split, random_state)``.

    Raises:
        KeyError: for any other ``model_type``.
    """
    if model_type == "logistic_regression":
        return LogisticRegression(
            solver="liblinear",
            max_iter=1000,
            class_weight="balanced",
            random_state=random_state,
        )
    if model_type == "random_forest":
        return return_optimal_random_forest(target, data_split, random_state)
    raise KeyError(f"Unsupported model_type={model_type}")

"""## 6. Define Wasserstein-DRO-Logistic Regression

"""

# Import LogisticRegression from scikit-learn.linear_model for training logistic regression models
from sklearn.linear_model import LogisticRegression

# Import RandomForestClassifier from scikit-learn.ensemble for training random forest models
from sklearn.ensemble import RandomForestClassifier

# Import NumPy library for numerical computations and random number handling
import numpy as np
 
# Final import of WassersteinDRO class for distributionally robust optimization in logistic regression
from dro.linear_model.wasserstein_dro import WassersteinDRO

# ============================================================
# WDRO Logistic Regression Wrapper (compatible with sklearn API)
# ============================================================
class WDROLogisticWrapper:
    """
    sklearn-style wrapper around ``dro``'s Wasserstein-DRO logistic regression.

    Provides ``fit`` / ``predict_proba`` / ``predict`` like an sklearn
    classifier. Labels are mapped to the {-1, +1} encoding WassersteinDRO
    expects: positive labels -> +1, everything else (0 or -1) -> -1. Both
    {0, 1} and {-1, +1} labels are therefore handled.

    Args:
        eps: Wasserstein ball radius.
        kappa: Regularization parameter passed to WDRO via ``model.update``.
        solver: cvxpy solver. The strings "MOSEK", "ECOS" and "SCS" (any case,
            surrounding whitespace ignored) are mapped to the cvxpy constants;
            non-string values are used as-is.
        random_state: During ``fit``, an int seed is used for NumPy's global RNG;
            any other non-None value seeds it with 42.

    Raises:
        ValueError: from ``__init__`` for an unrecognized solver string.
    """

    # Store WDRO hyperparameters and the seed, and normalize the solver name
    def __init__(self, eps=0.05, kappa=1.0, solver="MOSEK", random_state=None):
        self.eps = eps                    # Wasserstein ball radius
        self.kappa = kappa                # Regularization parameter
        self.random_state = random_state  # Seed applied in fit()
        self.model = None                 # Internal WassersteinDRO instance (set by fit)

        # Normalize solver strings to the corresponding cvxpy constants
        if isinstance(solver, str):
            name = solver.strip().upper()
            if name not in ("MOSEK", "ECOS", "SCS"):
                raise ValueError(f"Unknown solver string: {solver}")
            solver = getattr(cp, name)
        self.solver = solver

    def fit(self, X, y):
        """
        Fit WDRO logistic regression on features ``X`` and labels ``y``.

        Returns:
            self
        """
        # Seed NumPy's *global* RNG for reproducibility (int seeds used as-is,
        # any other non-None random_state falls back to 42).
        if self.random_state is not None:
            seed = self.random_state if isinstance(self.random_state, (int, np.integer)) else 42
            np.random.seed(seed)

        # Map labels to {-1, +1}. The old `y == 0 -> -1` mapping sent -1 labels to +1.
        y_wdro = np.where(np.asarray(y) > 0, 1, -1)

        # Create the WDRO model (logistic type, using the configured solver)
        self.model = WassersteinDRO(
            input_dim=X.shape[1],
            model_type="logistic",
            solver=self.solver
        )

        # Update WDRO parameters (eps and kappa)
        self.model.update({'eps': self.eps, 'kappa': self.kappa})

        # Prefer passing solver= to fit(); some dro versions do not accept it.
        # In that case, set the attribute and call fit() without it.
        # NOTE: a TypeError raised *inside* fit() also triggers this retry.
        try:
            self.model.fit(X, y_wdro, solver=self.solver)
        except TypeError:
            if hasattr(self.model, "solver"):
                self.model.solver = self.solver
            self.model.fit(X, y_wdro)

        return self

    def predict_proba(self, X):
        """
        Return an (n_samples, 2) array with columns [P(negative), P(positive)].

        The output of ``self.model.predict(X)`` is treated as logits (margins).
        NOTE(review): confirm WassersteinDRO.predict returns raw scores for
        model_type="logistic" rather than hard +/-1 labels; with hard labels the
        probabilities collapse to two values.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("WDROLogisticWrapper is not fitted yet; call fit() first.")
        logits = np.asarray(self.model.predict(X), dtype=float).ravel()
        # Overflow-free logistic sigmoid: 1 / (1 + exp(-z)) == 0.5 * (1 + tanh(z / 2))
        probs_pos = 0.5 * (1.0 + np.tanh(0.5 * logits))
        return np.column_stack([1.0 - probs_pos, probs_pos])

    def predict(self, X):
        """Predict 0/1 labels by thresholding the positive-class probability at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)
 


"""## 7. Define Wasserstein-DRO-MRO-Logistic Regression

"""

# Import NumPy for numerical computing
import numpy as np
 
from sklearn.linear_model import LogisticRegression



import numpy as np
import cvxpy as cp


class WDROMROLogisticGame:
    """
    WDRO--MRO Logistic Regression (Dual-Game Hybrid, single-λ), aligned with the paper.

    - Common λ_t for both learner and oracle.
    - Per-sample dual envelopes s_i(·, λ_t); nature uses Δ_i = s_i(w_t,λ_t) - s_i(w'_t,λ_t).
    - Learner/Oracle: minimize  λ_t * ρ + Σ_i π_{t+1}(i) * s_i(·, λ_t).
    - Radius-dual update: λ_{t+1} ← Π_[0,λ_max](λ_t + η_λ (ρ - ρ̂_t)).

    Implementable envelope (p=2 upper bound):
        s_i(w,λ) ≤ ℓ(y_i, x_i^T w) + (1/(4λ)) * Σ_k ||w_k||_2^2 / α_k
    which yields a convex learner/oracle subproblem.

    Implementation notes
    --------------------
    - The learner and oracle subproblems are identical here, so the oracle
      iterate is a copy of the learner iterate. Δ_i is then always zero and the
      nature weights π stay uniform.
      NOTE(review): confirm this against the paper's oracle definition.
    - Σ_k ||w_k||² / α_k is evaluated as Σ_j c_j w_j² with c_j = Σ_{k ∋ j} 1/α_k.
      This lets the cvxpy objective be built from vectorized expressions rather
      than one term per sample/group.

    Parameters
    ----------
    p : int
        Wasserstein norm parameter; this implementation supports p=2 (others raise NotImplementedError).
    alphas : array-like or None
        Modality weights α_k > 0. If None, default = np.ones(K).
    group_slices : list or None
        Optional feature groups. Each entry is either a ``(start, end)`` tuple
        (the contiguous block start..end-1), or a non-empty list / ndarray (or a
        tuple whose length is not 2) of feature indices, which need not be
        contiguous. Empty entries are skipped. Indices outside [0, d) are
        dropped, but the group keeps its position so alphas stay aligned.
        If None, default = one group per feature. If the number of groups
        mismatches alphas, alphas are auto-resized (no assertion).
    rho : float
        Wasserstein radius.
    eta : float
        Step size for nature's exponentiated-weights.
    eta_lambda : float
        Step size for λ update.
    T : int
        Number of dual-game iterations (must be >= 1).
    solver : str
        CVXPY solver, e.g. "MOSEK", "ECOS", "OSQP", "SCS".
    lam_init : float
        Initial λ.
    lam_max : float
        Projection upper bound for λ (keep large; just prevents blow-up).
    seed : int
        RNG seed (stored as ``self.rng``; not used by the algorithm itself).
    verbose : bool
        Print progress information.

    Labels passed to ``fit`` may be {+1, -1} or {1, 0}: positive -> +1, anything
    else -> -1. ``predict`` returns 0/1 labels.
    """

    def __init__(self,
                 p=2,
                 alphas=None,
                 group_slices=None,
                 rho=1.0,
                 eta=0.1,
                 eta_lambda=0.1,
                 T=50,
                 solver="ECOS",
                 lam_init=1.0,
                 lam_max=1e6,
                 seed=0,
                 verbose=False):
        self.p = int(p)
        self.alphas = alphas
        self.group_slices = group_slices
        self.rho = float(rho)
        self.eta = float(eta)
        self.eta_lambda = float(eta_lambda)
        self.T = int(T)
        self.solver = solver
        self.lam_init = float(lam_init)
        self.lam_max = float(lam_max)
        self.rng = np.random.default_rng(seed)
        self.verbose = verbose

        # Learned parameter (after fit): average of the learner iterates
        self.w_ = None

    # ----------------- helpers -----------------
    @staticmethod
    def _logistic_loss_vec(y, Xw):
        """Elementwise logistic loss log(1 + exp(-y * Xw)), overflow-free."""
        # logaddexp(0, z) == log(1 + exp(z)) without overflow for large |z|
        return np.logaddexp(0.0, -y * Xw)

    @staticmethod
    def _softmax_update(prev, delta, eta):
        """Exponentiated-weights step prev * exp(eta * delta), renormalized."""
        # mean-center for numerical stability, then exponentiated weights
        d = delta - np.mean(delta)
        w = prev * np.exp(eta * d)
        s = w.sum()
        return w / s if s > 0 else np.full_like(prev, 1.0 / len(prev))

    def _make_groups(self, d):
        """
        Build (groups, alphas_vec) for ``d`` features.

        groups is a list of sorted, de-duplicated int index arrays:
        - group_slices is None -> one group per feature [0], [1], ..., [d-1]
        - (start, end) tuple   -> arange(start, end)
        - other non-empty list / tuple / ndarray -> exactly those indices
        - empty entries are skipped; indices outside [0, d) are dropped
        alphas: ones(K) if None/empty, broadcast if scalar, otherwise resized
        (truncate/repeat) to K; clipped to >= 1e-12 to avoid division by zero.
        """
        if self.group_slices is None:
            groups = [np.array([j]) for j in range(d)]  # per-feature
        else:
            groups = []
            for group in self.group_slices:
                if isinstance(group, tuple) and len(group) == 2:
                    # Documented (start, end) slice. This is checked before the
                    # index-list case, which previously swallowed tuples and
                    # turned (s, e) into s..e (one feature too many).
                    idx = np.arange(int(group[0]), int(group[1]))
                elif isinstance(group, (list, tuple, np.ndarray)) and len(group) > 0:
                    # Explicit feature indices; kept exactly, so non-contiguous
                    # modality columns are no longer widened to (min, max+1).
                    idx = np.unique(np.asarray(group, dtype=int).ravel())
                else:
                    # Empty group, skip
                    continue
                groups.append(idx[(idx >= 0) & (idx < d)])

        K = len(groups)

        if self.alphas is None:
            alphas = np.ones(K, dtype=float)
        else:
            a = np.asarray(self.alphas, dtype=float).ravel()
            if a.size == 0:
                alphas = np.ones(K, dtype=float)
            elif a.size == 1:
                alphas = np.full(K, float(a[0]))
            else:
                # auto-resize to K (no assertions)
                alphas = np.resize(a, K).astype(float)

        # strictly positive to avoid divide-by-zero
        alphas = np.maximum(alphas, 1e-12)
        return groups, alphas

    @staticmethod
    def _group_norm2_sq(w, group_slices):
        """Squared L2 norm of ``w`` restricted to each index group."""
        return [float(np.dot(w[g], w[g])) for g in group_slices]

    @staticmethod
    def _feature_penalty_weights(d, group_slices, alphas):
        """Per-feature weights c such that Σ_k ||w[g_k]||² / α_k == Σ_j c_j w_j²."""
        coef = np.zeros(d)
        for g, a in zip(group_slices, alphas):
            coef[g] += 1.0 / a  # indices within a group are unique
        return coef

    # --------- envelope for p=2 (implementable upper bound) ---------
    def _envelope_p2(self, x_i, y_i, w, lam, group_slices, alphas):
        """Single-sample envelope s_i(w, λ) = ℓ(y_i, x_i·w) + Σ_k ||w_k||²/α_k / (4λ)."""
        lam_eff = max(float(lam), 1e-12)
        loss = self._logistic_loss_vec(np.array([y_i]), np.array([x_i @ w]))[0]
        group_sq = self._group_norm2_sq(w, group_slices)
        reg = sum(gs / a for gs, a in zip(group_sq, alphas)) / (4.0 * lam_eff)
        return loss + reg

    def _envelopes_p2(self, X, y, w, lam, coef):
        """Vectorized envelopes for all samples; the penalty term is shared by every i."""
        lam_eff = max(float(lam), 1e-12)
        reg = float(coef @ (w * w)) / (4.0 * lam_eff)
        return self._logistic_loss_vec(y, X @ w) + reg

    # ----------------- main API -----------------
    def fit(self, X, y):
        """
        Train with the dual-game hybrid (single-λ) specialized to logistic (paper's Algorithm).

        X: (n, d) array-like. y: labels in {+1, -1} or {1, 0}.
        Returns self.

        Raises:
            NotImplementedError: if p != 2.
            ValueError: if T < 1.
            RuntimeError: if the solver returns no solution.
        """
        if self.p != 2:
            raise NotImplementedError("This implementation currently supports p=2 only.")
        if self.T < 1:
            raise ValueError(f"T must be >= 1, got {self.T}")

        X = np.asarray(X, dtype=float)
        # Map labels to {-1, +1}. With raw {0, 1} labels the negatives would
        # contribute a constant log(2) loss and never shape w.
        y = np.where(np.asarray(y, dtype=float).ravel() > 0, 1.0, -1.0)
        n, d = X.shape

        if self.verbose:
            print(f"WDROMROLogisticGame.fit: n={n}, d={d}, T={self.T}, ρ={self.rho}")

        groups, alphas = self._make_groups(d)
        coef = self._feature_penalty_weights(d, groups, alphas)

        if self.verbose:
            print(f"Groups: {len(groups)} groups, alphas={alphas}")

        # initialize learner/oracle weights and λ
        w = np.zeros(d)        # learner
        w_b = np.zeros(d)      # oracle/baseline
        lam = float(self.lam_init)
        pi = np.ones(n) / n

        # store iterates for averaging
        Ws = []

        # cvxpy pieces that do not change across iterations
        w_var = cp.Variable(d)
        margins = cp.multiply(y, X @ w_var)
        sqrt_coef = np.sqrt(coef)

        for t in range(self.T):
            # ---- Dual envelopes at current (w, λ) and (w_b, λ) ----
            s_w = self._envelopes_p2(X, y, w, lam, coef)
            s_b = self._envelopes_p2(X, y, w_b, lam, coef)

            # ---- Nature (adversary) update with relative regret ----
            pi = self._softmax_update(pi, s_w - s_b, self.eta)

            # ---- Learner best-response at (pi, λ) ----
            lam_eff = max(lam, 1e-12)
            logistic_terms = cp.sum(cp.multiply(pi, cp.logistic(-margins)))
            reg_terms = cp.sum_squares(cp.multiply(sqrt_coef, w_var)) / (4.0 * lam_eff)
            obj = lam * self.rho + logistic_terms + reg_terms
            problem = cp.Problem(cp.Minimize(obj))
            problem.solve(solver=self.solver, verbose=False)
            if w_var.value is None:
                raise RuntimeError(
                    f"WDROMROLogisticGame: solver {self.solver!r} returned no solution "
                    f"at iteration {t} (status={problem.status})"
                )
            w = np.array(w_var.value, dtype=float)

            # ---- Oracle (benchmark) best-response at same (pi, λ) ----
            # The oracle's objective is exactly the learner's, so its minimizer is
            # the same point; copy it instead of solving the problem a second time.
            w_b = w.copy()

            # ---- Radius-dual update using envelope subgradient approximation ----
            # ρ̂ ≈ (1/(4 λ^2)) * Σ_k ||w_k||^2 / α_k  (learner's new w)
            rho_hat = float(coef @ (w * w)) / (4.0 * lam_eff * lam_eff)
            lam_old = lam
            lam = float(np.clip(lam + self.eta_lambda * (self.rho - rho_hat), 0.0, self.lam_max))

            # ---- Progress report every 5 iterations ----
            if self.verbose and (t % 5 == 0):
                print(f"t={t:3d}: λ={lam:.4f} (Δ={lam-lam_old:+.4f}), "
                      f"ρ̂={rho_hat:.4f}, ρ={self.rho:.4f}, "
                      f"‖w‖₂={np.linalg.norm(w):.4f}, "
                      f"π_entropy={-np.sum(pi * np.log(pi + 1e-12)):.4f}")

            Ws.append(w.copy())

        # averaged predictor (can replace with tail-averaging)
        self.w_ = np.mean(Ws, axis=0)

        if self.verbose:
            print(f"Training completed. Final ‖w‖₂={np.linalg.norm(self.w_):.4f}")

        return self

    def predict_proba(self, X):
        """
        Return an (n_samples, 2) array [P(y=0), P(y=1)] from the averaged weights.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.w_ is None:
            raise RuntimeError("WDROMROLogisticGame is not fitted yet; call fit() first.")
        X = np.asarray(X, dtype=float)
        logits = X @ self.w_
        # Overflow-free sigmoid: 1 / (1 + exp(-z)) == 0.5 * (1 + tanh(z / 2))
        p1 = 0.5 * (1.0 + np.tanh(0.5 * logits))
        return np.column_stack([1.0 - p1, p1])

    def predict(self, X):
        """Predict 0/1 labels by thresholding P(y=1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)



# Import NumPy for numerical computing
import numpy as np

# Import RandomForestClassifier from scikit-learn
from sklearn.ensemble import RandomForestClassifier

# Import LogisticRegression from scikit-learn
from sklearn.linear_model import LogisticRegression


# Safely extract an integer seed from RandomState/Generator/int (some sklearn models only accept int seeds)
def _seed_from(random_state):
    """
       Convert random_state into an int (return None if not extractable)
    """
    if isinstance(random_state, (int, np.integer)):
        return int(random_state)
    # Handle numpy Generator
    if hasattr(random_state, "integers"):  # np.random.Generator
        return int(random_state.integers(0, 2**31 - 1))
    # Handle numpy RandomState
    if hasattr(random_state, "randint"):   # np.random.RandomState
        return int(random_state.randint(0, 2**31 - 1))
    return None


# Factory function to return the corresponding "optimal model" instance based on model_type
def return_optimal_model(
    model_type: str,                  # Model type identifier
    target: str,                      # Target variable name (unused here)
    data_split: str,                  # Data split identifier (unused here)
    random_state=np.random.RandomState(42),  #  Random state generator
    # WDRO-specific parameters ---
    wdro_eps: float = 0.05,
    wdro_kappa: float = 1.0,
    wdro_solver: str = "MOSEK",
    # WDRO-MRO-specific parameters ---
    wdro_mro_groups=None,
    wdro_mro_alphas=None,
    wdro_mro_gamma: float = 0.5,
    wdro_mro_solver: str = "MOSEK",
    # WDRO-MRO Game-specific parameters ---
    wdro_mro_p: float = 2.0,
    wdro_mro_eta: float = 0.1,
    wdro_mro_eta_lambda: float = 0.1,
    # wdro_mro_T: int = 1000,
    wdro_mro_T: int = 200,
):
    """
    Return an optimal model instance; supported types:
      - "random_forest"        → Random Forest
      - "logistic_regression"  → Logistic Regression
      - "wdro_logistic"        → Wasserstein DRO Logistic Regression
      - "wdro_mro_logistic"    → Wasserstein DRO-MRO Logistic Regression

    Notes:
      * LogisticRegression only accepts int seeds, handled via _seed_from()
      * WDRO/WDRO-MRO solvers are upper-cased for easier checks and fallbacks
    """
    seed_int = _seed_from(random_state)              # Extract int random seed
    wdro_solver = (wdro_solver or "MOSEK").upper()   # nsure solver name is uppercase
    wdro_mro_solver = (wdro_mro_solver or "MOSEK").upper()

    if model_type == "random_forest":
        # RandomForest can accept int/RandomState/Generator, pass as-is
        return RandomForestClassifier(
            n_estimators=200,
            max_depth=None,
            class_weight="balanced",
            random_state=random_state,
        )

    elif model_type == "logistic_regression":
        # Logistic Regression only accepts int random_state, pass seed_int
        return LogisticRegression(
            solver="liblinear",
            max_iter=1000,
            class_weight="balanced",
            random_state=seed_int,
        )

    elif model_type == "wdro_logistic":
        # Return custom WDROLogisticWrapper (Wasserstein DRO logistic regression)
        return WDROLogisticWrapper(
            eps=wdro_eps,
            kappa=wdro_kappa,
            solver=wdro_solver,
            random_state=random_state,
        )

    elif model_type == "wdro_mro_logistic":
        # Return custom WDROMROLogistic (Wasserstein DRO logistic regression with MRO regularization)
        return WDROMROLogistic(
            groups=wdro_mro_groups,
            alphas=wdro_mro_alphas,
            gamma=wdro_mro_gamma,
            solver=wdro_mro_solver,
            random_state=random_state,
        )
    
    elif model_type == "wdro_mro_logistic_game":
        # Oracle-free Dual-Game Solver (Algorithm~\ref{alg:wdro-mro-logistic})
        return WDROMROLogisticGame(
            p=wdro_mro_p,
            group_slices=wdro_mro_groups,
            alphas=wdro_mro_alphas,
            rho=wdro_eps,
            eta=wdro_mro_eta,
            eta_lambda=wdro_mro_eta_lambda,
            T=wdro_mro_T,
            solver=wdro_mro_solver,
        )



    else:
        # Unsupported model type, raise an error
        raise KeyError(f"Unsupported model_type={model_type}")

"""## Build modality groups from HANCOCK"""

import re
from collections import defaultdict

def _strip_ohe_suffix(base_name: str, raw_cols: set) -> str:
    """
    Process column names after OneHot encoding: e.g., smoking_status_3 -> smoking_status
    Iteratively remove the last segment after '_' until it matches a raw column or can no longer be stripped.
    """
    if base_name in raw_cols:
        return base_name
    parts = base_name.split("_")
    while len(parts) > 1:
        parts = parts[:-1]
        cand = "_".join(parts)
        if cand in raw_cols:
            return cand
    return base_name  # Fallback: return as-is if not found 

def build_modalities_groups_from_hancock(preprocessor,
                                         clinical_df,
                                         pathological_df,
                                         blood_df,
                                         icd_df,
                                         tma_df):
    """
    Build a stable modalities→groups mapping from the column sets of the five HANCOCK source tables.

    Every output feature of ``preprocessor`` is assigned to exactly one modality:
      1. if the name (after the ``<transformer>__`` prefix) is an ICD column → icd;
      2. else one-hot suffixes are stripped (``_strip_ohe_suffix``) and the result is
         looked up in the raw columns of clinical, pathological, blood, icd, tma (in
         that order);
      3. else keyword heuristics decide, defaulting to clinical.

    Parameters
    ----------
    preprocessor : fitted transformer exposing ``get_feature_names_out()``
    clinical_df, pathological_df, blood_df, icd_df, tma_df : pandas.DataFrame
        Source tables; only their column names are used (``patient_id`` is ignored).

    Returns
    -------
    modalities : list[str]
        Non-empty modalities in the fixed order clinical, pathological, blood, icd, tma.
    groups : list[list[int]]
        Sorted output-feature indices, aligned with ``modalities`` (for WMR/noise injection).
    audit_df : pandas.DataFrame
        One row per output feature with columns out_idx, out_name, modality, raw_col,
        sorted by modality then out_idx. The columns exist even when there are no features.
    """
    import pandas as pd

    # Fixed modality order: lookup priority for raw columns AND output order.
    desired_order = ["clinical", "pathological", "blood", "icd", "tma"]
    fns = list(preprocessor.get_feature_names_out())

    # 1) Five sets of raw columns (patient_id is the join key, not a feature)
    source_dfs = {
        "clinical": clinical_df,
        "pathological": pathological_df,
        "blood": blood_df,
        "icd": icd_df,
        "tma": tma_df,
    }
    raw_cols = {
        m: {c for c in df.columns if c != "patient_id"}
        for m, df in source_dfs.items()
    }

    # 2) Union of all raw columns, used to undo one-hot suffixes
    all_raw = set().union(*raw_cols.values())

    def _heuristic_modality(base):
        """Keyword fallback for output names that match no raw column."""
        lowered = base.lower()
        # Loose match for ICD codes (e.g. 'c020')
        if re.fullmatch(r"[cC]\d{3}", base):
            return "icd"
        # Typical hematology keywords
        if any(k in lowered for k in ["blood", "erythro", "hemoglobin",
                                       "hematocrit", "platelet", "leukocyte"]):
            return "blood"
        # Typical pathology markers.
        # NOTE(review): "pn" and "v_" are very loose substrings and may capture
        # unrelated clinical columns ("pn_" is already implied by "pn").
        if any(k in lowered for k in ["invasion", "carcinoma", "pn", "pn_", "lv_", "v_"]):
            return "pathological"
        if "tma" in lowered or "cell_density" in lowered:
            return "tma"
        # Default bucket (could be changed to "unknown" and handled separately)
        return "clinical"

    # 3) Output feature → modality mapping and audit information
    idx_by_modality = defaultdict(list)
    audit_rows = []
    for j, fn in enumerate(fns):
        # Drop the "<transformer>__" prefix when present (non-standard names kept whole)
        base = fn.split("__", 1)[1] if "__" in fn else fn

        if base in raw_cols["icd"]:
            # Exact ICD column name (e.g. c020/c021)
            mod, raw_col = "icd", base
        else:
            # One-hot reverse mapping: smoking_status_3 -> smoking_status
            base_maybe = _strip_ohe_suffix(base, all_raw)
            mod = next((m for m in desired_order if base_maybe in raw_cols[m]), None)
            raw_col = base_maybe
            if mod is None:
                mod = _heuristic_modality(base)
                if mod == "icd":
                    # The ICD regex fallback records the unstripped name
                    raw_col = base

        idx_by_modality[mod].append(j)
        audit_rows.append({
            "out_idx": j,
            "out_name": fn,
            "modality": mod,
            "raw_col": raw_col,
        })

    # 4) Fixed output order; keep only non-empty modalities
    modalities, groups = [], []
    for m in desired_order:
        if idx_by_modality[m]:
            modalities.append(m)
            groups.append(sorted(idx_by_modality[m]))

    # 5) Audit table. Explicit columns so an empty feature list still yields a
    #    sortable frame instead of a KeyError in sort_values.
    audit_df = (
        pd.DataFrame(audit_rows, columns=["out_idx", "out_name", "modality", "raw_col"])
        .sort_values(["modality", "out_idx"])
        .reset_index(drop=True)
    )

    print("=== Modalities → #features ===")
    for m, g in zip(modalities, groups):
        print(f"{m:12s}: {len(g)}")
    return modalities, groups, audit_df
 
 
"""## 8. Add Noise, Train & Evaluate AUC & Robust AUC

### 8.1. Training
"""

def get_groups_from_preprocessor(preprocessor, modalities=None):
    """
    Generate WDRO-MRO feature group indices from the output column names of the preprocessor.

    Parameters
    ----------
    preprocessor : ColumnTransformer / Pipeline
        A fitted preprocessor; must support ``.get_feature_names_out()``.
    modalities : list[str] or None
        Prefixes to group by (e.g. ["numeric", "categorical", "encoded"]); a feature
        belongs to modality ``mod`` when its name starts with ``mod + "__"``.
        If None, prefixes are inferred from the feature names in order of first
        appearance, so the group order is deterministic across runs.

    Returns
    -------
    groups : list[list[int]]
        One list of feature indices per modality, in ``modalities`` order;
        modalities with no matching feature are skipped.
    """
    feature_names = list(preprocessor.get_feature_names_out())

    if modalities is None:
        # e.g. ColumnTransformer emits "clinical__age" -> prefix "clinical".
        # dict.fromkeys keeps first-appearance order; a set's order depends on
        # string hashing (PYTHONHASHSEED) and made the groups non-reproducible.
        modalities = list(dict.fromkeys(f.split("__")[0] for f in feature_names))

    groups = []
    for mod in modalities:
        prefix = mod + "__"
        idx = [j for j, f in enumerate(feature_names) if f.startswith(prefix)]
        if idx:  # Skip empty modalities
            groups.append(idx)
    return groups
 

import re
import numpy as np
from collections import defaultdict


# --- RNG compatibility layer ---
def _rng_random(rng, size=None):
    if hasattr(rng, "random"):
        return rng.random(size=size)
    return rng.random_sample(size=size)

def _rng_integers(rng, low, high=None, size=None):
    # Generator.integers(low, high, size)
    # RandomState.randint(low, high, size)
    if hasattr(rng, "integers"):
        return rng.integers(low, high, size=size)
    if high is None:
        return rng.randint(low, size=size)
    return rng.randint(low, high, size=size)

def _rng_choice(rng, a, size=None, replace=True):
    return rng.choice(a, size=size, replace=replace)


# Full-name pattern for encoded ICD-style features, e.g. "encoded__c020":
# the "encoded__" transformer prefix, one letter from c/d/r/t (either case),
# then two or more digits. _split_group_by_type uses it to separate ICD
# columns from other encoded (binary) columns.
_ICD_ENC_PAT = re.compile(r"^encoded__([cdrtCDRT]\d{2,})$")

def _ohe_base(name: str) -> str:
    """
    Restore the base field name from 'categorical__smoking_status_3' → 'smoking_status';
    if it is already a base name or cannot be stripped, return the longest prefix that matches.
    (Here it is only used for grouping, not for validating against the raw column set.)
    """
    # Remove prefix 
    base = name.split("__", 1)[1] if "__" in name else name
    # For names like xxx_123 / xxx_foo_3 → remove the last segment 
    parts = base.split("_")
    if len(parts) > 1:
        return "_".join(parts[:-1])
    return base

def _split_group_by_type(feature_names, g_indices):
    """
    Partition the column indices of one group by feature-name prefix.

    Returns a 4-tuple:
      numeric_cols       : list[int]   names starting with "numeric__"
      encoded_icd_cols   : list[int]   "encoded__" names matching _ICD_ENC_PAT
      encoded_bin_cols   : list[int]   other "encoded__" names, plus any column
                                       with an unrecognised prefix (treated as a
                                       single binary bit)
      categorical_blocks : dict[str, list[int]]
                                       "categorical__" columns keyed by _ohe_base,
                                       each index list sorted for reproducibility
    Within each list, indices keep the order in which they appear in ``g_indices``.
    """
    numeric_cols, encoded_icd_cols, encoded_bin_cols = [], [], []
    blocks = defaultdict(list)

    for col in g_indices:
        col_name = feature_names[col]
        if col_name.startswith("numeric__"):
            numeric_cols.append(col)
        elif col_name.startswith("categorical__"):
            blocks[_ohe_base(col_name)].append(col)
        elif col_name.startswith("encoded__") and _ICD_ENC_PAT.match(col_name):
            encoded_icd_cols.append(col)
        else:
            # Non-ICD encoded columns and rare unknown prefixes: conservative single-bit flip
            encoded_bin_cols.append(col)

    sorted_blocks = {key: sorted(members) for key, members in blocks.items()}
    return numeric_cols, encoded_icd_cols, encoded_bin_cols, sorted_blocks


def inject_group_noise(
    X,
    groups,
    feature_names,
    noise_rate,
    rng,
    *,
    sigma_numeric=0.10,         # Noise scale for continuous values (blood/path/tma/clinical)
    sigma_tma=None,             # Special scale for tma; if None, use sigma_numeric
    flip_prob_encoded=1.0,      # Probability that the randomly chosen encoded bit is flipped
    keep_one_hot_for_icd=True,  # If False, ICD columns receive no noise
    keep_one_hot_for_categorical=True,  # If False, categorical one-hot blocks receive no noise
):
    """
    Inject type-aware feature noise into a random subset of rows, group by group.

    For each group g in ``groups`` (a list of column indices):
      1) draw n_noisy = int(noise_rate * n_samples) distinct row indices
         (drawn independently for every group);
      2) split g's columns by name prefix (``_split_group_by_type``) and, for the
         selected rows:
         * numeric: add N(0, sigma) to every numeric column of g, where sigma is
           ``sigma_tma`` if all of them look like TMA columns, else ``sigma_numeric``;
         * categorical: pick one one-hot sub-block uniformly, then set a uniformly
           chosen category to 1 and the rest of that block to 0 (the new category
           may equal the old one);
         * encoded ICD: the same one-hot re-draw over all ICD columns of the group;
         * other encoded / unrecognised columns: pick one column uniformly and, with
           probability ``flip_prob_encoded``, replace x by 1 - x.
    Empty subtypes are skipped. The RNG draw order (rows, numeric, categorical,
    ICD, binary — per group) is fixed, so a seeded ``rng`` gives reproducible output.

    Parameters
    ----------
    X : 2-D numpy array
        Presumably float dtype (Gaussian noise is added in place) — confirm with caller.
    groups : list[list[int]]
        Column-index groups (e.g. one per modality).
    feature_names : sequence[str]
        Output feature names aligned with X's columns.
    noise_rate : float
        Fraction of rows perturbed per group; must satisfy noise_rate * n_samples <= n_samples.
    rng : numpy RandomState or Generator

    Returns
    -------
    X_noisy : a copy of X with noise applied; X itself is not modified.

    Raises
    ------
    ValueError
        If ``int(noise_rate * n_samples)`` exceeds the number of rows.
    """
    X_noisy = X.copy()
    n_samples = X.shape[0]
    n_noisy = int(noise_rate * n_samples)
    if n_noisy <= 0:
        return X_noisy
    if n_noisy > n_samples:
        raise ValueError(
            f"inject_group_noise: noise_rate={noise_rate} selects {n_noisy} rows "
            f"but X has only {n_samples}; noise_rate must lie in [0, 1]."
        )

    # Roughly detect TMA columns (CD3/CD8 densities or names containing 'tma')
    is_tma_col = np.array([("numeric__cd3" in fn) or ("numeric__cd8" in fn) or ("tma" in fn.lower())
                           for fn in feature_names])
    if sigma_tma is None:
        sigma_tma = sigma_numeric

    for g in groups:
        if len(g) == 0:
            continue
        # Rows to perturb for this group (without replacement)
        idx_samples = _rng_choice(rng, n_samples, size=n_noisy, replace=False)

        numeric_cols, encoded_icd_cols, encoded_bin_cols, categorical_blocks = _split_group_by_type(feature_names, g)

        # ------- 1) numeric: add noise to the whole block (vectorized) -------
        if numeric_cols:
            cols = np.array(numeric_cols, dtype=int)
            # All-TMA blocks use sigma_tma; mixed blocks use sigma_numeric for all columns
            sigma = sigma_tma if np.all(is_tma_col[cols]) else sigma_numeric
            X_noisy[np.ix_(idx_samples, cols)] += rng.normal(0.0, sigma, size=(len(idx_samples), len(cols)))

        # ------- 2) categorical: pick one sub-block and re-draw its category -------
        if keep_one_hot_for_categorical and categorical_blocks:
            block_keys = list(categorical_blocks.keys())
            for i in idx_samples:
                k = block_keys[_rng_integers(rng, 0, len(block_keys))]
                block = categorical_blocks[k]
                j_new = block[_rng_integers(rng, 0, len(block))]
                X_noisy[i, block] = 0
                X_noisy[i, j_new] = 1

        # ------- 3) ICD (encoded__c/d/r/t***): keep the group one-hot -------
        if keep_one_hot_for_icd and encoded_icd_cols:
            for i in idx_samples:
                j_new = encoded_icd_cols[_rng_integers(rng, 0, len(encoded_icd_cols))]
                X_noisy[i, encoded_icd_cols] = 0
                X_noisy[i, j_new] = 1

        # ------- 4) encoded binary (non-ICD): pick one bit, flip with flip_prob_encoded -------
        if encoded_bin_cols and flip_prob_encoded > 0:
            for i in idx_samples:
                j = encoded_bin_cols[_rng_integers(rng, 0, len(encoded_bin_cols))]
                # Compat helper: identical draw to rng.random() on both RNG types
                if _rng_random(rng) < flip_prob_encoded:
                    X_noisy[i, j] = 1 - X_noisy[i, j]

    return X_noisy

from joblib import Parallel, delayed
from imblearn.over_sampling import SMOTE
from sklearn.metrics import roc_auc_score
import random, os

# ==================================================
# Global random seed
# ==================================================

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
# NOTE: assigning PYTHONHASHSEED at runtime does not change hash randomisation of
# this already-running interpreter; it can only reach child processes started
# afterwards (presumably the joblib worker processes), which inherit os.environ.
os.environ["PYTHONHASHSEED"] = str(SEED)

rng = np.random.RandomState(SEED)

# ==================================================
# Configuration
# ==================================================

# Comma-separated environment overrides. Blank entries (e.g. from a trailing
# comma) are dropped instead of becoming an empty model key or a float("") error.
model_keys_str = os.environ.get("MODEL_KEYS", "LR,WDRO,WDRO_MRO_GAME")
model_keys = [x.strip() for x in model_keys_str.split(",") if x.strip()]

wdro_eps = float(os.environ.get('WDRO_EPS', 0.05))
noise_rates = os.environ.get("NOISE_RATES", "0.0,0.1,0.2,0.3,0.4,0.5")
noise_rates = [float(x) for x in noise_rates.split(",") if x.strip()]

wdro_mro_T_value = int(os.environ.get('WDRO_MRO_T_VALUE', 20))
repeated_iterations = int(os.environ.get('REPEATED_SEED_ITERATION', 5))

# noise_rate -> model key -> split index -> list of per-iteration values
tpr_dict = {nr: {m: [[] for _ in range(len(data_split_paths))] for m in model_keys}
            for nr in noise_rates}
auc_dict = {nr: {m: [[] for _ in range(len(data_split_paths))] for m in model_keys}
            for nr in noise_rates}

# Container for group-level AUC (filled per split later, once groups are known)
group_auc_dict = {nr: {m: [] for m in model_keys} for nr in noise_rates}

# NOISE_AFTER_SMOTE = False

def run_experiment_for_split(i, noise_rate, iteration, seed,
                             X_train_clean, y_train_clean,
                             X_test_clean, y_test,
                             w_ref, b_ref, groups, feature_names, rng):
    """
    Train all requested models for one (split, noise_rate, iteration) task and return their AUCs.

    Steps:
      1) Headline AUC: flip ``int(noise_rate * n)`` training labels, inject feature
         noise into every group, oversample with SMOTE, then fit each requested model
         (LR, WDRO, WDRO_MRO_GAME — only those present in ``model_keys``) and score
         it on the clean test set.
      2) Group-level AUC: for each group separately, inject feature noise only into
         that group (labels untouched), SMOTE, fit every requested model and score.

    Parameters
    ----------
    i : int
        Split index into ``data_split_labels``.
    seed : int
        Seed for this task's private RandomState.
    w_ref, b_ref :
        Anchor coefficients; accepted for interface compatibility, not used here.
    groups : list[list[int]]
        Feature-index groups (one per modality).
    rng :
        Ignored: a fresh ``np.random.RandomState(seed)`` is created so each task
        is reproducible regardless of scheduling.

    Returns
    -------
    tuple
        ``(i, noise_rate, iteration, results)`` with ``results["auc"][m]`` the
        headline AUC of model ``m`` (None for keys this function does not train)
        and ``results["group_auc"][m][g_idx]`` a list holding the group-``g_idx``
        AUC (nan if that fit failed).
    """
    print(f"  Task: Split {i}, Noise {noise_rate:.1f}, Iteration {iteration}, Seed {seed}")

    # Models this function knows how to train, in a fixed order (determines RNG use order).
    supported = ("LR", "WDRO", "WDRO_MRO_GAME")
    requested = [name for name in supported if name in model_keys]

    results = {"auc": {}, "group_auc": {}}
    for m in model_keys:
        results["auc"][m] = None
        results["group_auc"][m] = [[] for _ in range(len(groups))]

    # Per-task RNG derived from the seed (the incoming rng is deliberately replaced)
    rng = np.random.RandomState(seed)
    split_label = data_split_labels[i]

    def _make_model(name):
        """Build a fresh, unfitted model for one of the supported keys."""
        if name == "LR":
            return return_optimal_model("logistic_regression", target, split_label, rng)
        if name == "WDRO":
            return return_optimal_model("wdro_logistic", target, split_label, rng)
        # Oracle-free dual-game WDRO-MRO logistic; identical configuration for the
        # headline and group-level runs (including the Wasserstein radius wdro_eps).
        return return_optimal_model(
            "wdro_mro_logistic_game", target, split_label, rng,
            wdro_mro_groups=groups,                # feature groups
            wdro_mro_alphas=[1.0] * len(groups),   # per-group α_k
            wdro_mro_p=2,                          # Wasserstein ℓ_p
            wdro_eps=wdro_eps,                     # radius ρ
            wdro_mro_eta=0.01,                     # primal step size η
            wdro_mro_eta_lambda=0.005,             # dual step size η_λ
            wdro_mro_T=wdro_mro_T_value,           # number of iterations
            wdro_mro_solver="MOSEK",               # inner solver
        )

    # --- noisy copy: label flips + feature noise in all groups ---
    X_tr, y_tr = X_train_clean.copy(), y_train_clean.copy()
    if noise_rate > 0:
        n_samples = len(y_tr)
        n_noisy = int(noise_rate * n_samples)
        if n_noisy > 0:
            flip_idx = rng.choice(n_samples, n_noisy, replace=False)
            y_tr[flip_idx] = 1 - y_tr[flip_idx]
        X_tr = inject_group_noise(X_tr, groups, feature_names, noise_rate, rng)
    X_tr, y_tr = SMOTE(random_state=rng).fit_resample(X_tr, y_tr)

    train_banner = {
        "LR": "    Training LR model...",
        "WDRO": "    Training WDRO model...",
        "WDRO_MRO_GAME": f"    Training WDRO-MRO-GAME model (T={wdro_mro_T_value} iterations)...",
    }
    for name in requested:
        print(train_banner[name])
        model = _make_model(name)
        model.fit(X_tr, y_tr)
        y_pred = model.predict_proba(X_test_clean)[:, 1]
        results["auc"][name] = roc_auc_score(y_test, y_pred)
        print(f"    {name} AUC: {results['auc'][name]:.4f}")

    # ==================================================
    # group-level AUC: perturb one group at a time (features only)
    # ==================================================
    for g_idx, g in enumerate(groups):
        Xg, yg = X_train_clean.copy(), y_train_clean.copy()
        if noise_rate > 0:
            Xg = inject_group_noise(Xg, [g], feature_names, noise_rate, rng)
        Xg, yg = SMOTE(random_state=rng).fit_resample(Xg, yg)

        # Construct every model before fitting any (same RNG order as before).
        group_models = [(name, _make_model(name)) for name in requested]
        for name, model in group_models:
            try:
                model.fit(Xg, yg)
                y_hat = model.predict_proba(X_test_clean)[:, 1]
                auc = roc_auc_score(y_test, y_hat)
            except Exception as exc:
                # Best effort: record nan but make the failure visible.
                print(f"    [warn] group {g_idx} {name} failed: {exc!r}")
                auc = np.nan
            results["group_auc"][name][g_idx].append(auc)

    auc_strs = [f"{m}={results['auc'][m]:.4f}" for m in model_keys if results['auc'][m] is not None]
    print(f"    Task completed - AUCs: {', '.join(auc_strs)}")
    return (i, noise_rate, iteration, results)




import numpy as np
import scipy.sparse as sp

def audit_standardization(preprocessor, X_train_clean, X_test_clean, tol_mean=1e-3, tol_std=0.05):
    """
    Print a report on whether the ``numeric__`` output columns look standardized (mean≈0, std≈1).

    Only columns whose feature name starts with ``numeric__`` are inspected, so
    one-hot / encoded columns are never flagged. Nothing is returned.

    Parameters
    ----------
    preprocessor : fitted transformer exposing ``get_feature_names_out()``
    X_train_clean, X_test_clean : 2-D dense ndarray or scipy sparse matrix
        Transformed matrices whose columns align with the feature names.
    tol_mean : float
        A training column is flagged when |mean| > tol_mean.
    tol_std : float
        A training column is flagged when |std - 1| > tol_std (population std, ddof=0).
    """
    names = np.array(preprocessor.get_feature_names_out())
    numeric_cols = np.flatnonzero(np.char.startswith(names.astype('U'), "numeric__"))
    n_num = len(numeric_cols)

    if n_num == 0:
        print("[Audit] No columns with prefix numeric__ detected, check your preprocessing naming.")
        return

    def _mean_std(M):
        # Dense: direct column statistics
        if not sp.issparse(M):
            return M.mean(axis=0), M.std(axis=0, ddof=0)
        # Sparse: std from E[X^2] - E[X]^2, clipped at 0 for round-off
        col_mean = np.array(M.mean(axis=0)).ravel()
        squared = M.copy()
        squared.data **= 2
        col_sq_mean = np.array(squared.mean(axis=0)).ravel()
        return col_mean, np.sqrt(np.maximum(col_sq_mean - col_mean**2, 0.0))

    train_mean, train_std = _mean_std(X_train_clean[:, numeric_cols])
    test_mean, test_std = _mean_std(X_test_clean[:, numeric_cols])

    off_mean = np.flatnonzero(np.abs(train_mean) > tol_mean)
    off_std = np.flatnonzero(np.abs(train_std - 1.0) > tol_std)

    print(f"[Audit] Number of numeric columns: {n_num}")
    print(f"[Train] Columns with non-zero mean: {len(off_mean)} / {n_num}")
    print(f"[Train] Columns with non-unit variance: {len(off_std)} / {n_num}")

    # Show up to 10 offending column names per issue
    if len(off_mean):
        print(" Example (mean issue):", names[numeric_cols[off_mean]][:10])
    if len(off_std):
        print("  Example (variance issue):", names[numeric_cols[off_std]][:10])

    # Test columns are scaled with training statistics, so their means need not be
    # exactly 0 — only a sanity summary is printed.
    print(f"[Test ] Median |mean| of numeric columns: {np.median(np.abs(test_mean)):.4f}")
    print(f"[Test ] Median std of numeric columns: {np.median(test_std):.4f}")


# ==================================================
# Parallel execution
# ==================================================

# Reinitialize containers (aligned with splits)
group_auc_dict = {nr: {m: [] for m in model_keys} for nr in noise_rates}

# targets.csv is identical for every split, so read it once outside the loop.
df_targets = pd.read_csv(data_dir/"targets.csv", dtype={"patient_id": str})

tasks = []
for i in range(len(data_split_paths)):
    # Prepare X_train_clean, y_train_clean, X_test_clean, y_test, groups, feature_names for split i

    print(f"Training and testing models on {data_split_labels[i]} data...")

    df_split = pd.read_json(data_split_paths[i], dtype={"patient_id": str})[["patient_id", "dataset"]]
    df_split = df_split.merge(df_targets, on="patient_id", how="inner")

    # === Target processing ===
    # .copy() after each boolean filter so the label re-encoding writes into an
    # independent frame rather than a view (avoids chained-assignment warnings).
    if target == "recurrence":
        df_split = df_split[
            ((df_split.recurrence=="yes") & (df_split.days_to_recurrence <= 365*3)) |
            ((df_split.recurrence=="no") & ((df_split.days_to_last_information > 365*3) |
                                            (df_split.survival_status=="living")))].copy()
        df_split["recurrence"] = df_split["recurrence"].replace({"no": 0, "yes": 1}).astype("int8")
    elif target == "survival_status":
        df_split = df_split[~(df_split.survival_status_with_cause=="deceased not tumor specific")].copy()
        df_split["survival_status"] = df_split["survival_status"].replace({"living": 0, "deceased": 1}).astype("int8")

    df_train = df_split[df_split.dataset=="training"][["patient_id", target]].copy()
    df_test  = df_split[df_split.dataset=="test"][["patient_id", target]].copy()
    df_train.columns, df_test.columns = ["patient_id","target"], ["patient_id","target"]
    df_train = df_train.merge(df, on="patient_id", how="inner")
    df_test  = df_test.merge(df, on="patient_id", how="inner")

    # ---------- Clean preprocessing to obtain "clean" features ----------

    preprocessor = setup_preprocessing_pipeline(df_train.columns[2:])
    X_train_clean = preprocessor.fit_transform(df_train.drop(["patient_id","target"], axis=1))
    X_test_clean  = preprocessor.transform(df_test.drop(["patient_id","target"], axis=1))
    y_train_clean = df_train["target"].to_numpy().astype(np.int8)
    y_test        = df_test["target"].to_numpy().astype(np.int8)
    audit_standardization(preprocessor, X_train_clean, X_test_clean)
    print("preprocessor:", preprocessor)  # Quick look at the structure
    print(f"preprocessor.named_transformers_:{preprocessor.named_transformers_}")   # Inspect each sub-transformer

    feature_names = preprocessor.get_feature_names_out()

    modalities, groups_modality, audit_df = build_modalities_groups_from_hancock(
        preprocessor,
        clinical_df=clinical,
        pathological_df=patho,
        blood_df=blood,
        icd_df=icd,
        tma_df=cell_density
    )

    # Use the modality-based groups (one group per HANCOCK source table)
    groups = groups_modality

    # Initialize the slot for split i in group_auc_dict (each model: one list per group).
    # Checked per model so an empty model_keys list does not raise IndexError.
    for nr in noise_rates:
        for m in model_keys:
            if len(group_auc_dict[nr][m]) <= i:
                group_auc_dict[nr][m].append([[] for _ in range(len(groups))])

    # ---------- Anchor: strictly train once on the clean training set ----------
    # Seeded explicitly: `rng` is a RandomState, so an `isinstance(rng, int)` guard
    # would always fall through to an unseeded model.
    anchor_lr = LogisticRegression(solver="liblinear", max_iter=2000, class_weight="balanced",
                                   random_state=SEED)
    anchor_lr.fit(X_train_clean, y_train_clean)
    w_ref = anchor_lr.coef_.ravel()
    b_ref = float(anchor_lr.intercept_.ravel()[0])
    for noise_rate in noise_rates:
        for iteration in range(repeated_iterations):
            # round(), not int(): int(0.29*100) == 28 would collide with noise_rate 0.28.
            seed = 42 + 7919*i + 101*round(noise_rate*100) + iteration
            tasks.append((i, noise_rate, iteration, seed,
                          X_train_clean, y_train_clean, X_test_clean, y_test,
                          w_ref, b_ref, groups, feature_names, rng))

print(
    "\n=== Starting Parallel Execution ===",
    f"Total tasks: {len(tasks)}",
    f"WDRO-MRO-GAME iterations per task: {wdro_mro_T_value}",
    f"Estimated time: {len(tasks) * wdro_mro_T_value * 0.1:.1f} seconds (rough estimate)",
    f"Using {os.cpu_count()} CPU cores",
    "=" * 50,
    sep="\n",
)

# One joblib task per (split, noise_rate, iteration) tuple; results come back in
# task order. mosek_child_initializer is passed to the worker pool (presumably run
# once per worker process to set up MOSEK — confirm against its definition).
all_results = Parallel(
    n_jobs=-1,
    verbose=10,
    initializer=mosek_child_initializer,
)(delayed(run_experiment_for_split)(*task_args) for task_args in tasks)

# Print the WDROMROLogisticGame T value and where results are saved.
# NOTE(review): the convergence lines are fixed text, not derived from any
# configuration visible here — keep them in sync with WDROMROLogisticGame.
print(
    "\n=== WDROMROLogisticGame Configuration ===",
    f"T value (iteration rounds): {wdro_mro_T_value}",
    f"Save folder: {results_dir}",
    f"Timestamp: {timestamp}",
    f"Full path: {results_dir.absolute()}",
    "Note: Convergence monitoring is enabled during training",
    "      - Progress printed every 20 iterations",
    "      - Convergence checked after 10 iterations",
    "      - Convergence threshold: 1e-6",
    sep="\n",
)


# ==================================================
# Collect results
# ==================================================

# Merge per-task results into the noise_rate -> model -> split containers.
# The loop variables (i, noise_rate, iteration, res, m, g_idx, g_res) are module
# globals and stay bound after the loop.
for i, noise_rate, iteration, res in all_results:
    for m in model_keys:
        # One headline AUC per iteration for split i (None for model keys that
        # run_experiment_for_split does not train).
        auc_dict[noise_rate][m][i].append(res["auc"][m])
        # Group-level AUCs: extend the per-group list of split i.
        for g_idx, g_res in enumerate(res["group_auc"][m]):
            group_auc_dict[noise_rate][m][i][g_idx].extend(g_res)


print("\n" + "="*80)
print("AUC_DICT STRUCTURE AND CONTENT")
print("="*80)

print(f"\nAUC_DICT structure:")
print(f"- Outer keys (noise_rates): {list(auc_dict.keys())}")
print(f"- Middle keys (model_keys): {list(auc_dict[0.0].keys()) if 0.0 in auc_dict else 'N/A'}")
print(f"- Inner keys (data_splits): {len(auc_dict[0.0]['LR']) if 0.0 in auc_dict and 'LR' in auc_dict[0.0] else 'N/A'}")

print(f"\nAUC_DICT sample content (first few values):")
for nr in sorted(auc_dict.keys())[:5]:  # only show first five noise rates
    print(f"\n--- Noise Rate: {nr} ---")
    for m in model_keys:
        print(f"  {m}:")
        for split_idx in range(len(auc_dict[nr][m])):
            # Model keys that are not trained keep None placeholders; drop them so
            # np.mean/np.min/np.max do not raise TypeError.
            values = [v for v in auc_dict[nr][m][split_idx] if v is not None]
            if values:
                print(f"    Split {split_idx}: {len(values)} values, mean={np.mean(values):.4f} (range: {np.min(values):.4f}-{np.max(values):.4f})")
            else:
                print(f"    Split {split_idx}: empty")


# Visual style shared by every grouped AUC boxplot drawn below.
_AUC_BOX_COLORS = ['#F5D2D2', '#F8F7BA', '#BDE3C3']  # light pink, light yellow, light green
_AUC_BOX_HATCHES = ['///', '\\\\\\', '']            # slash, backslash, solid
_AUC_BOX_OFFSET = 0.25   # horizontal gap between models inside one noise-rate group
_AUC_BOX_WIDTH = 0.2
_AUC_MODEL_DISPLAY = {"WDRO_MRO_GAME": "WDRO_MRO"}  # internal key -> label shown in plots


def _draw_grouped_auc_boxplot(ax, df):
    """Draw one box per (noise rate, model) pair of *df* onto *ax*.

    Boxes are grouped by noise rate; inside a group, model j is shifted right by
    ``j * _AUC_BOX_OFFSET``. Colour and hatch come from the model's index in
    ``df['model'].unique()``, so a model keeps its style even when some
    (noise rate, model) cells have no data. The x tick of each noise rate is
    centred under its group, and a model legend is placed below the axes.

    Args:
        ax: matplotlib Axes to draw on.
        df: non-empty long-format DataFrame with 'noise_rate', 'model' and
            'auc' columns.
    """
    from matplotlib.patches import Patch

    noise_levels = sorted(df['noise_rate'].unique())
    models = list(df['model'].unique())

    box_data, box_positions, box_model_idx = [], [], []
    for i, nr in enumerate(noise_levels):
        at_noise = df[df['noise_rate'] == nr]
        for j, model in enumerate(models):
            values = at_noise.loc[at_noise['model'] == model, 'auc'].values
            if len(values) > 0:
                box_data.append(values)
                box_positions.append(i + j * _AUC_BOX_OFFSET)
                box_model_idx.append(j)  # remember the model, not the box index

    bp = ax.boxplot(box_data, positions=box_positions, patch_artist=True,
                    widths=_AUC_BOX_WIDTH, showfliers=True)

    for patch, j in zip(bp['boxes'], box_model_idx):
        patch.set_facecolor(_AUC_BOX_COLORS[j % len(_AUC_BOX_COLORS)])
        patch.set_hatch(_AUC_BOX_HATCHES[j % len(_AUC_BOX_HATCHES)])
        patch.set_edgecolor('black')
        patch.set_linewidth(1.5)
        patch.set_alpha(0.8)

    # Put each ρ label under the middle of its group of boxes.
    group_centre = (len(models) - 1) * _AUC_BOX_OFFSET / 2
    ax.set_xticks([i + group_centre for i in range(len(noise_levels))])
    ax.set_xticklabels([f'ρ={nr}' for nr in noise_levels])

    legend_elements = [
        Patch(facecolor=_AUC_BOX_COLORS[j % len(_AUC_BOX_COLORS)],
              hatch=_AUC_BOX_HATCHES[j % len(_AUC_BOX_HATCHES)],
              label=_AUC_MODEL_DISPLAY.get(model, model), edgecolor='black')
        for j, model in enumerate(models)
    ]
    ax.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, -0.10),
              ncol=len(models), frameon=True)


def plot_auc_dict_visualization(auc_dict, model_keys, data_split_labels, noise_rates, save_dir=None, server_mode=True):
    """Plot, save and summarise the AUCs collected in *auc_dict*.

    ``auc_dict[noise_rate][model][split_idx]`` is a list of AUC values, where
    ``split_idx`` indexes *data_split_labels*.

    Files written to *save_dir* (created if missing; defaults to the
    module-level ``results_dir``):
      - auc_dict_plot_data.csv        long format: noise_rate, model, split, auc
      - auc_dict_summary_stats.csv    count/mean/std/min/max per (model, split, noise_rate)
      - auc_dict_visualization.pdf    2x2 overview: boxplot, mean trends,
                                      mean-AUC heatmap, noise sensitivity
      - auc_distribution_boxplot.pdf  the overview's boxplot as its own figure
      - auc_analysis_<split>.pdf      one boxplot per split that has data
                                      (name lower-cased, spaces -> underscores)

    Args:
        auc_dict: nested dict of AUC lists as described above.
        model_keys: models to include, in plotting order.
        data_split_labels: split names, aligned with the innermost list index.
        noise_rates: noise rates to include.
        save_dir: output directory; None means ``results_dir``.
        server_mode: if True, switch matplotlib to the non-interactive 'Agg'
            backend and close each figure after saving; otherwise call
            ``plt.show()`` after each save.

    Returns:
        (df_plot, summary_stats): the long-format DataFrame and the summary
        table. If *auc_dict* holds no values, the tables are still written
        (header only), no figures are drawn, and both return values are empty.
    """
    import matplotlib
    if server_mode:
        # Server mode: non-interactive backend, no display needed.
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    from pathlib import Path

    plt.style.use('default')
    sns.set_palette("husl")

    if save_dir is None:
        save_dir = results_dir
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    def _save_figure(figure, filename):
        """Save *figure* into save_dir, then show it or close it."""
        figure.savefig(save_dir / filename, dpi=300, bbox_inches='tight')
        if not server_mode:
            plt.show()
        else:
            plt.close(figure)  # server mode: release the figure's memory

    # 1. Flatten auc_dict into a long-format DataFrame (one row per AUC value).
    plot_data = [
        {'noise_rate': nr, 'model': m, 'split': split_name, 'auc': val}
        for nr in noise_rates
        for m in model_keys
        for split_idx, split_name in enumerate(data_split_labels)
        for val in (auc_dict[nr][m][split_idx] or [])
    ]
    # Explicit columns keep the schema (and CSV header) even when there are no rows.
    df_plot = pd.DataFrame(plot_data, columns=['noise_rate', 'model', 'split', 'auc'])

    # 2. Tables first, so they exist even if a plotting step fails later.
    df_plot.to_csv(save_dir / 'auc_dict_plot_data.csv', index=False)
    if df_plot.empty:
        summary_stats = pd.DataFrame(columns=['count', 'mean', 'std', 'min', 'max'])
    else:
        summary_stats = df_plot.groupby(['model', 'split', 'noise_rate'])['auc'].agg([
            'count', 'mean', 'std', 'min', 'max'
        ]).round(4)
    summary_stats.to_csv(save_dir / 'auc_dict_summary_stats.csv')

    if df_plot.empty:
        print("⚠ auc_dict contains no AUC values; skipping plots.")
        return df_plot, summary_stats

    # 3. Overview figure (saved as auc_dict_visualization.pdf).
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('AUC Distribution by Model and Noise Rate', fontsize=16)

    # 3.1 AUC distribution per model at each noise rate (all splits pooled).
    ax1 = axes[0, 0]
    _draw_grouped_auc_boxplot(ax1, df_plot)
    ax1.set_xlabel('Noise Rate (ρ)')
    ax1.set_ylabel('AUC')

    # The same boxplot, saved as a standalone figure.
    fig_box, ax_box = plt.subplots(figsize=(10, 6))
    _draw_grouped_auc_boxplot(ax_box, df_plot)
    ax_box.set_xlabel('Noise Rate (ρ)', fontsize=12)
    ax_box.set_ylabel('AUC', fontsize=12)
    fig_box.tight_layout()
    _save_figure(fig_box, 'auc_distribution_boxplot.pdf')
    print("✓ auc_distribution_boxplot.pdf: 单独保存的箱线图")

    # 3.2 Mean AUC versus noise rate, one line per (model, split).
    ax2 = axes[0, 1]
    for split in data_split_labels:
        split_data = df_plot[df_plot['split'] == split]
        mean_auc = split_data.groupby(['noise_rate', 'model'])['auc'].mean().reset_index()
        for model in model_keys:
            model_data = mean_auc[mean_auc['model'] == model]
            if len(model_data) > 0:
                ax2.plot(model_data['noise_rate'], model_data['auc'],
                         marker='o', label=f'{model} ({split})', linewidth=2)
    ax2.set_title('Mean AUC Trends Across Noise Rates')
    ax2.set_xlabel('Noise Rate (ρ)')
    ax2.set_ylabel('Mean AUC')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax2.grid(True, alpha=0.3)

    # 3.3 Mean AUC heatmap (model x noise rate, all splits pooled).
    ax3 = axes[1, 0]
    pivot_data = df_plot.groupby(['model', 'noise_rate'])['auc'].mean().unstack()
    sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='RdYlBu_r', ax=ax3,
                annot_kws={'size': 10})
    ax3.set_title('Mean AUC Heatmap (Model vs Noise Rate)', fontsize=14)
    ax3.set_xlabel('Noise Rate (ρ)', fontsize=12)
    ax3.set_ylabel('Model', fontsize=12)
    ax3.tick_params(axis='both', labelsize=10)
    ax3.set_yticklabels([_AUC_MODEL_DISPLAY.get(t.get_text(), t.get_text())
                         for t in ax3.get_yticklabels()])
    colorbar = ax3.collections[0].colorbar if ax3.collections else None
    if colorbar is not None:
        colorbar.ax.tick_params(labelsize=10)
        colorbar.set_label('Mean AUC', fontsize=12)

    # 3.4 Noise sensitivity: mean AUC at ρ=0.0 minus mean AUC at the largest ρ
    #     present in the data (previously hard-coded to ρ=0.5).
    ax4 = axes[1, 1]
    clean_nr = 0.0
    max_nr = df_plot['noise_rate'].max()
    sensitivity_data = []
    for model in model_keys:
        for split in data_split_labels:
            model_split_data = df_plot[(df_plot['model'] == model) & (df_plot['split'] == split)]
            if len(model_split_data) == 0:
                continue
            clean_auc = model_split_data.loc[model_split_data['noise_rate'] == clean_nr, 'auc'].mean()
            max_noise_auc = model_split_data.loc[model_split_data['noise_rate'] == max_nr, 'auc'].mean()
            if not (np.isnan(clean_auc) or np.isnan(max_noise_auc)):
                sensitivity_data.append({
                    'model': model,
                    'split': split,
                    'sensitivity': clean_auc - max_noise_auc,
                })

    if sensitivity_data:
        sens_df = pd.DataFrame(sensitivity_data)
        sns.barplot(data=sens_df, x='model', y='sensitivity', hue='split', ax=ax4)
        ax4.set_title('Noise Sensitivity (Clean AUC - Max Noise AUC)')
        ax4.set_xlabel('Model')
        ax4.set_ylabel('AUC Drop')
        ax4.legend(title='Split')

    # The overview was previously never saved (nor closed) despite being announced.
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    _save_figure(fig, 'auc_dict_visualization.pdf')

    # 4. One boxplot figure per data split.
    for split_name in data_split_labels:
        split_data = df_plot[df_plot['split'] == split_name]
        if split_data.empty:
            print(f"⚠ No AUC values for split {split_name!r}; skipping its figure.")
            continue
        fig_split, ax = plt.subplots(1, 1, figsize=(8, 6))
        _draw_grouped_auc_boxplot(ax, split_data)
        ax.set_xlabel('Noise Rate (ρ)')
        ax.set_ylabel(f"AUC({split_name.lower()})")
        fig_split.tight_layout()
        _save_figure(fig_split, f'auc_analysis_{split_name.lower().replace(" ", "_")}.pdf')

    print(f"\n图表已保存到: {save_dir}")
    print(f"- auc_dict_visualization.pdf: 总体可视化")
    print(f"- auc_analysis_*.pdf: 各分割详细分析")
    print(f"- auc_dict_plot_data.csv: 绘图数据")
    print(f"- auc_dict_summary_stats.csv: 汇总统计")

    return df_plot, summary_stats


def create_auc_dict_heatmap(auc_dict, model_keys, data_split_labels, noise_rates, save_dir=None, server_mode=True):
    """Save side-by-side mean-AUC heatmaps (model x noise rate), one panel per split.

    Cell (model, ρ) in the panel for split k is the mean of
    ``auc_dict[ρ][model][k]``, or NaN when that list is empty. Model names are
    shown with the same display mapping as the other AUC plots
    (WDRO_MRO_GAME -> WDRO_MRO). The figure is written to
    ``<save_dir>/auc_heatmaps.pdf``.

    Args:
        auc_dict: ``auc_dict[noise_rate][model][split_idx]`` -> list of AUCs.
        model_keys: models, one heatmap row each.
        data_split_labels: split names; one panel per label (any count).
        noise_rates: noise rates, one heatmap column each.
        save_dir: output directory (created if missing); None means the
            module-level ``results_dir``.
        server_mode: if True, switch to the non-interactive 'Agg' backend and
            close the figure after saving; otherwise call ``plt.show()``.
    """
    import matplotlib
    if server_mode:
        # Server mode: non-interactive backend, no display needed.
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    from pathlib import Path

    if save_dir is None:
        save_dir = results_dir
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    n_splits = len(data_split_labels)
    if n_splits == 0:
        print("⚠ No data splits given; skipping heatmap.")
        return

    model_display = {"WDRO_MRO_GAME": "WDRO_MRO"}
    row_labels = [model_display.get(m, m) for m in model_keys]
    column_labels = [f'ρ={nr}' for nr in noise_rates]

    # One 6-inch-wide panel per split (the old layout assumed exactly three).
    fig, axes = plt.subplots(1, n_splits, figsize=(6 * n_splits, 6), squeeze=False)

    for split_idx, (ax, split_name) in enumerate(zip(axes[0], data_split_labels)):
        heatmap_rows = []
        for model in model_keys:
            row = []
            for nr in noise_rates:
                values = auc_dict[nr][model][split_idx]
                row.append(np.mean(values) if values else np.nan)
            heatmap_rows.append(row)
        heatmap_df = pd.DataFrame(heatmap_rows, index=row_labels, columns=column_labels)

        sns.heatmap(heatmap_df, annot=True, fmt='.3f', cmap='RdYlBu_r', ax=ax,
                    cbar_kws={'label': 'Mean AUC'}, annot_kws={'size': 14})
        ax.set_title(f'{split_name}', fontsize=14)
        ax.set_xlabel('Noise Rate', fontsize=14)
        ax.set_ylabel('Model', fontsize=14)
        ax.tick_params(axis='both', labelsize=14)
        colorbar = ax.collections[0].colorbar if ax.collections else None
        if colorbar is not None:
            colorbar.ax.tick_params(labelsize=14)
            colorbar.set_label('Mean AUC', fontsize=14)

    fig.tight_layout()
    fig.savefig(save_dir / 'auc_heatmaps.pdf', dpi=300, bbox_inches='tight')
    if not server_mode:
        plt.show()
    else:
        plt.close(fig)  # server mode: release the figure's memory

    print(f"热力图已保存到: {save_dir / 'auc_heatmaps.pdf'}")

def save_auc_visualizations(auc_dict, model_keys, data_split_labels, noise_rates, save_dir):
    """Generate every AUC figure/table into *save_dir* and report what exists.

    Runs plot_auc_dict_visualization and create_auc_dict_heatmap (both with
    server_mode=True), then prints each expected output file with its size,
    or marks it missing. Per-split file names are derived from
    *data_split_labels* with the naming rule used when they are saved
    (``auc_analysis_<label lower-cased, spaces -> underscores>.pdf``), so the
    check follows the actual splits instead of a hard-coded list.

    Any exception while generating is printed with its traceback and
    ``(None, None)`` is returned instead of being raised, so a plotting failure
    does not stop the caller.

    Returns:
        (df_plot, summary_stats) from plot_auc_dict_visualization, or
        (None, None) if generation failed.
    """
    from pathlib import Path
    import traceback

    # Make sure the output directory exists.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("AUC可视化图表生成和保存")
    print(f"{'='*60}")
    print(f"保存目录: {save_dir.absolute()}")

    try:
        print("\n1. 生成综合可视化图表...")
        df_plot, summary_stats = plot_auc_dict_visualization(
            auc_dict, model_keys, data_split_labels, noise_rates,
            save_dir, server_mode=True
        )
        print("   ✓ 综合可视化图表生成完成")

        print("\n2. 生成热力图...")
        create_auc_dict_heatmap(
            auc_dict, model_keys, data_split_labels, noise_rates,
            save_dir, server_mode=True
        )
        print("   ✓ 热力图生成完成")

        # Report which of the expected outputs were written.
        print(f"\n3. 检查保存的文件:")
        split_files = [
            f'auc_analysis_{split_name.lower().replace(" ", "_")}.pdf'
            for split_name in data_split_labels
        ]
        expected_files = [
            'auc_dict_visualization.pdf',
            'auc_distribution_boxplot.pdf',
            'auc_heatmaps.pdf',
            *split_files,
            'auc_dict_plot_data.csv',
            'auc_dict_summary_stats.csv',
        ]

        for filename in expected_files:
            file_path = save_dir / filename
            if file_path.exists():
                file_size = file_path.stat().st_size
                print(f"   ✓ {filename} ({file_size:,} bytes)")
            else:
                print(f"   ✗ {filename} (未找到)")

        print(f"\n{'='*60}")
        print("AUC可视化完成！")
        print(f"所有图表已保存到: {save_dir.absolute()}")
        print(f"{'='*60}")

        return df_plot, summary_stats

    except Exception as e:
        # Deliberate best effort: report the failure but let the script continue.
        print(f"\n❌ 生成图表时出错: {e}")
        traceback.print_exc()
        return None, None

def save_auc_dict_to_csv(auc_dict, model_keys, data_split_labels, noise_rates, save_dir):
    """Write one summary row per non-empty (noise rate, model, split) cell of *auc_dict*.

    The file ``<save_dir>/auc_dict_raw.csv`` (directory created if missing) has
    the columns noise_rate, model, split, split_idx, auc_values (the raw AUC
    list, written in its string form), auc_count, auc_mean, auc_std
    (``np.std``, i.e. ddof=0), auc_min and auc_max. Cells whose AUC list is
    empty are skipped. The header is written even when no cell has data.

    Args:
        auc_dict: ``auc_dict[noise_rate][model][split_idx]`` -> list of AUCs.
        model_keys: models to export.
        data_split_labels: split names, aligned with the innermost list index.
        noise_rates: noise rates to export.
        save_dir: output directory.

    Returns:
        The DataFrame that was written.
    """
    import pandas as pd
    import numpy as np
    from pathlib import Path

    columns = ['noise_rate', 'model', 'split', 'split_idx', 'auc_values',
               'auc_count', 'auc_mean', 'auc_std', 'auc_min', 'auc_max']

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*50}")
    print("保存 AUC_DICT 原始数据为 CSV")
    print(f"{'='*50}")

    raw_data = []
    for nr in noise_rates:
        for m in model_keys:
            for split_idx, split_name in enumerate(data_split_labels):
                values = auc_dict[nr][m][split_idx]
                if not values:
                    continue  # no statistics for an empty cell
                raw_data.append({
                    'noise_rate': nr,
                    'model': m,
                    'split': split_name,
                    'split_idx': split_idx,
                    'auc_values': values,
                    'auc_count': len(values),
                    'auc_mean': np.mean(values),
                    'auc_std': np.std(values),
                    'auc_min': np.min(values),
                    'auc_max': np.max(values)
                })

    # Explicit columns: with no rows the CSV would otherwise have no header at
    # all, and pd.read_csv could not load it back.
    df_raw = pd.DataFrame(raw_data, columns=columns)
    df_raw.to_csv(save_dir / 'auc_dict_raw.csv', index=False)
    print(f"✓ auc_dict_raw.csv: 原始结构数据 ({len(df_raw)} 行)")

    print(f"\n{'='*50}")
    print("AUC_DICT 数据保存完成！")
    print(f"文件已保存到: {save_dir.absolute()}")
    print(f"{'='*50}")

    return df_raw
 

# Produce all AUC figures and tables, then the per-cell raw CSV, in results_dir.
df_plot, summary_stats = save_auc_visualizations(
    auc_dict=auc_dict,
    model_keys=model_keys,
    data_split_labels=data_split_labels,
    noise_rates=noise_rates,
    save_dir=results_dir,
)

df_raw = save_auc_dict_to_csv(
    auc_dict=auc_dict,
    model_keys=model_keys,
    data_split_labels=data_split_labels,
    noise_rates=noise_rates,
    save_dir=results_dir,
)


"""### 8.2. Comprehensive Evaluation

"""

import numpy as np
import pandas as pd

# NumPy 2.0 renamed np.trapz to np.trapezoid (the old name is deprecated);
# use whichever name this NumPy provides.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN when *values* is empty."""
    return np.mean(values) if len(values) > 0 else np.nan


def _nan_safe_max(values):
    """Return the largest non-NaN element of *values*, or NaN if there is none.

    The builtin ``max`` is order-dependent with NaN: ``max([nan, 0.9])`` is NaN
    while ``max([0.9, nan])`` is 0.9, because every comparison with NaN is
    False. Filtering NaNs first makes the result independent of order.
    """
    finite = [v for v in values if not np.isnan(v)]
    return max(finite) if finite else np.nan


def compute_all_metrics(
    auc_dict, model_keys,
    data_split_labels,
    noise_rates,
    group_auc_dict=None,
    groups=None,
    alphas=None,
    weighted_mode="max",
    modality_mode="max",
):
    """Compute performance, robustness, stability, regret and fairness metrics.

    Inputs are indexed as ``auc_dict[noise_rate][model][split_idx]`` -> list of
    AUCs and, optionally, ``group_auc_dict[noise_rate][model][split_idx][g]``
    -> list of AUCs for group ``g``; ``split_idx`` indexes *data_split_labels*.

    One output row per (model, split) with:
      - Performance: Average_AUC / Std_AUC over all AUCs at all noise rates
        (np.std, ddof=0).
      - Robustness (on the per-noise mean AUCs): Robust_AUC (worst),
        RR_AUC (worst / best), AuNC (trapezoidal area under mean AUC vs. ρ,
        ρ sorted, divided by the ρ span), Worst_Case_Drop and AUC_Range (both
        best - worst), Noise_Variance, Noise_Robustness_Index
        (Robust_AUC / Average_AUC).
      - Stability (needs >= 2 noise rates with data): Noise_Sensitivity_Drop
        (mean AUC at ρ=0.0 minus Robust_AUC), _Slope (degree-1 polyfit slope),
        _Relative ((max - min) / max).
      - Regret against the per-noise oracle (best mean AUC among *model_keys*,
        NaNs ignored): Max_Regret, Mean_Regret, Max_Regret_Ratio,
        Regret_Variance, and Weighted_Modality_Regret — group-level regret
        aggregated over noise rates by *weighted_mode*, then over groups by
        *modality_mode*.
      - Fairness (groups only): Group_Noise_Robustness (lowest group mean AUC
        at any noise rate), Worst_Group_Gap (Average_AUC minus that),
        Group_Fairness_Gap (max - min of per-group means over noise rates).

    Args:
        auc_dict, model_keys, data_split_labels, noise_rates: see above.
        group_auc_dict: optional group-level AUCs; group metrics need both
            this and *groups*.
        groups: group identifiers; only their count is used here.
        alphas: per-group weights (default 1.0 each); used only when
            ``modality_mode == "mean"``.
        weighted_mode: "max" (worst noise rate) or "mean" (average over noise
            rates) for the per-group regret.
        modality_mode: "max" (worst group, the previous fixed behaviour) or
            "mean" (alpha-weighted average over groups).

    Raises:
        ValueError: if a mode is not "max"/"mean", or alphas and groups differ
            in length.

    Returns:
        pandas DataFrame with one row per (model, split).
    """
    if weighted_mode not in ("max", "mean"):
        raise ValueError(f"weighted_mode must be 'max' or 'mean', got {weighted_mode!r}")
    if modality_mode not in ("max", "mean"):
        raise ValueError(f"modality_mode must be 'max' or 'mean', got {modality_mode!r}")

    use_groups = group_auc_dict is not None and groups is not None
    if use_groups:
        alphas = [1.0] * len(groups) if alphas is None else [float(a) for a in alphas]
        if len(alphas) != len(groups):
            raise ValueError(f"alphas has {len(alphas)} entries but groups has {len(groups)}")

    # AuNC integrates over ρ, so the curve must be ordered by noise rate.
    sorted_rates = sorted(noise_rates)
    rho_span = sorted_rates[-1] - sorted_rates[0] if len(sorted_rates) > 1 else 0.0

    results = []
    for m in model_keys:
        for i, split in enumerate(data_split_labels):

            # ========== Basic AUC ==========
            mean_per_noise = {nr: _mean_or_nan(auc_dict[nr][m][i]) for nr in noise_rates}
            aucs_all = [a for nr in noise_rates for a in auc_dict[nr][m][i]]
            avg_auc = np.mean(aucs_all) if aucs_all else np.nan
            std_auc = np.std(aucs_all) if aucs_all else np.nan

            # ========== Robustness ==========
            vals = list(mean_per_noise.values())
            has_valid = bool(np.any(~np.isnan(vals)))
            robust_auc = np.nanmin(vals) if has_valid else np.nan
            best_auc = np.nanmax(vals) if has_valid else np.nan
            rr_auc = robust_auc / (best_auc + 1e-12) if has_valid else np.nan
            if rho_span > 0:
                curve = [mean_per_noise[nr] for nr in sorted_rates]
                aunc = _trapezoid(curve, x=sorted_rates) / rho_span
            else:
                aunc = np.nan  # fewer than two distinct noise rates
            worst_case_drop = best_auc - robust_auc if has_valid else np.nan
            noise_var = np.nanvar(vals)
            nri = (robust_auc / (avg_auc + 1e-12)
                   if not (np.isnan(avg_auc) or np.isnan(robust_auc)) else np.nan)
            auc_range = best_auc - robust_auc if has_valid else np.nan  # same quantity as Worst_Case_Drop

            # ========== Noise Sensitivity ==========
            valid_points = [(nr, v) for nr, v in mean_per_noise.items() if not np.isnan(v)]
            noise_drop, noise_slope, noise_rel = np.nan, np.nan, np.nan
            if len(valid_points) > 1:
                rho_vals = [nr for nr, _ in valid_points]
                auc_vals = [v for _, v in valid_points]
                auc_at_clean = mean_per_noise.get(0.0, np.nan)
                if not np.isnan(auc_at_clean):
                    noise_drop = auc_at_clean - robust_auc  # clean AUC -> worst AUC
                try:
                    noise_slope = np.polyfit(rho_vals, auc_vals, 1)[0]
                except (np.linalg.LinAlgError, ValueError, TypeError):
                    noise_slope = np.nan
                noise_rel = (max(auc_vals) - min(auc_vals)) / (max(auc_vals) + 1e-12)

            # ========== Regret ==========
            regret_per_noise = []
            for nr in noise_rates:
                # Oracle = best mean AUC among all models at this noise rate.
                oracle = _nan_safe_max([_mean_or_nan(auc_dict[nr][m2][i]) for m2 in model_keys])
                regret_per_noise.append(oracle - mean_per_noise[nr])
            max_regret = np.nanmax(regret_per_noise)
            mean_regret = np.nanmean(regret_per_noise)
            max_regret_ratio = max_regret / (best_auc + 1e-12) if has_valid else np.nan
            regret_var = np.nanvar(regret_per_noise)

            # ========== Weighted Modality Regret ==========
            weighted_regret = np.nan
            if use_groups:
                group_regrets = []  # (r_g, alpha) for groups with usable data
                for g_idx, alpha in enumerate(alphas):
                    # Regret of this group at every noise rate where both the
                    # model and the cross-model oracle have data.
                    regrets_g = []
                    for nr in noise_rates:
                        oracle_g = _nan_safe_max(
                            [_mean_or_nan(group_auc_dict[nr][m2][i][g_idx]) for m2 in model_keys]
                        )
                        mean_g = _mean_or_nan(group_auc_dict[nr][m][i][g_idx])
                        if not (np.isnan(oracle_g) or np.isnan(mean_g)):
                            regrets_g.append(oracle_g - mean_g)
                    if regrets_g and not np.isnan(alpha):
                        # Aggregate over noise rates: worst case or average.
                        r_g = max(regrets_g) if weighted_mode == "max" else np.mean(regrets_g)
                        group_regrets.append((r_g, alpha))

                if group_regrets:
                    # Aggregate over groups: worst group or alpha-weighted average.
                    if modality_mode == "max":
                        weighted_regret = max(r for r, _ in group_regrets)
                    else:
                        denom = sum(a for _, a in group_regrets)
                        weighted_regret = (sum(a * r for r, a in group_regrets) / denom
                                           if denom > 0 else np.nan)

            # ========== Fairness & Group-Noise Robustness ==========
            gnr_auc, worst_group_gap, group_fairness_gap = np.nan, np.nan, np.nan
            if use_groups:
                group_means, group_worsts = [], []
                for g_idx in range(len(groups)):
                    aucs_group = [np.mean(group_auc_dict[nr][m][i][g_idx])
                                  for nr in noise_rates
                                  if len(group_auc_dict[nr][m][i][g_idx]) > 0]
                    if aucs_group:
                        group_means.append(np.nanmean(aucs_group))
                        group_worsts.append(np.nanmin(aucs_group))
                if group_means:
                    gnr_auc = np.nanmin(group_worsts)
                    worst_group_gap = (avg_auc - gnr_auc
                                       if not (np.isnan(avg_auc) or np.isnan(gnr_auc)) else np.nan)
                    group_fairness_gap = (np.nanmax(group_means) - np.nanmin(group_means)
                                          if len(group_means) > 1 else 0.0)

            # ========== Save results ==========
            results.append(dict(
                Model=m, Split=split,
                # Performance
                Average_AUC=avg_auc,
                Std_AUC=std_auc,
                # Robustness
                Robust_AUC=robust_auc,
                RR_AUC=rr_auc,
                AuNC=aunc,
                Worst_Case_Drop=worst_case_drop,
                Noise_Variance=noise_var,
                Noise_Robustness_Index=nri,
                AUC_Range=auc_range,
                # Noise Sensitivity
                Noise_Sensitivity_Drop=noise_drop,
                Noise_Sensitivity_Slope=noise_slope,
                Noise_Sensitivity_Relative=noise_rel,
                # Regret
                Max_Regret=max_regret,
                Mean_Regret=mean_regret,
                Max_Regret_Ratio=max_regret_ratio,
                Regret_Variance=regret_var,
                Weighted_Modality_Regret=weighted_regret,
                # Fairness
                Group_Noise_Robustness=gnr_auc,
                Worst_Group_Gap=worst_group_gap,
                Group_Fairness_Gap=group_fairness_gap,
                Weighted_Mode=weighted_mode
            ))

    return pd.DataFrame(results)

# === Evaluate the same models/splits under both regret-aggregation modes ===
# Both runs share every argument except weighted_mode, so the arguments are
# defined once and the two result tables stay directly comparable.
# weighted_mode controls how each modality's regret is aggregated over noise
# rates: "max" = worst noise rate, "mean" = average over noise rates.
_metric_kwargs = dict(
    auc_dict=auc_dict,                                  # Global AUC results dictionary
    model_keys=["LR", "WDRO", "WDRO_MRO_GAME"],         # LR = Logistic Regression
    data_split_labels=["In distribution", "Out of distribution", "Oropharynx"],
    noise_rates=noise_rates,                            # List of noise rates
    group_auc_dict=group_auc_dict,                      # ⭐ Must pass group-level AUC results
    groups=groups,                                      # Groups derived from preprocessor
    alphas=[1.0] * len(groups),                         # Equal weight for each modality
)

# Worst noise rate per modality ("max")
df_metrics_max = compute_all_metrics(**_metric_kwargs, weighted_mode="max")

# Average over noise rates per modality ("mean")
df_metrics_mean = compute_all_metrics(**_metric_kwargs, weighted_mode="mean")


# Save results to CSV files in timestamped output directory
df_metrics_max.to_csv(results_dir / "df_metrics_max.csv", index=False)
df_metrics_mean.to_csv(results_dir / "df_metrics_mean.csv", index=False)


# ✅ Metric columns of the compute_all_metrics output, grouped by what they measure.
metrics_by_category = {
    # Overall level and spread of AUC
    "Performance": ["Average_AUC", "Std_AUC"],
    # Behaviour across noise rates
    "Robustness": [
        "Robust_AUC",              # worst per-noise mean AUC
        "RR_AUC",                  # relative robustness (worst / best)
        "AuNC",                    # normalised area under the noise curve
        "Worst_Case_Drop",         # best - worst
        "Noise_Variance",          # variance of per-noise means
        "Noise_Robustness_Index",  # NRI
        "AUC_Range",               # max - min
    ],
    # How quickly AUC degrades as noise grows
    "Stability": [
        "Noise_Sensitivity_Drop",
        "Noise_Sensitivity_Slope",
        "Noise_Sensitivity_Relative",
    ],
    # Distance to the best model at each noise rate
    "Regret": [
        "Max_Regret",
        "Mean_Regret",
        "Max_Regret_Ratio",
        "Regret_Variance",
        "Weighted_Modality_Regret",
    ],
    # Group-level behaviour
    "Fairness": [
        "Group_Noise_Robustness",
        "Worst_Group_Gap",
        "Group_Fairness_Gap",
    ],
}

# One table per category: Model and Split, then that category's metrics.
for category, cols in metrics_by_category.items():
    selected_columns = ["Model", "Split", *cols]
    print(f"\n=== {category} ===")
    print(df_metrics_max[selected_columns].to_string(index=False))

import numpy as np
import pandas as pd

def best_by_split(df, maximize_cols, minimize_cols, abs_cols=("Noise_Sensitivity_Slope",)):
    """
    For each Split, find the best model for every metric column.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``Split``, ``Model`` and every column named in
        ``maximize_cols`` / ``minimize_cols``.
    maximize_cols : iterable of str
        Metrics where a larger value is better.
    minimize_cols : iterable of str
        Metrics where a smaller value is better.
    abs_cols : iterable of str
        Metrics ranked by absolute value instead of signed value. By default
        this is ``Noise_Sensitivity_Slope``, so the smallest-magnitude slope wins.
        That matches the bold-face rule used for the LaTeX tables.
        Pass ``()`` to rank every metric by its signed value.

    Returns
    -------
    pd.DataFrame
        Columns ``Split``, ``Metric``, ``Best Model``, ``Value``. There is one
        row per (split, metric), with the maximize metrics listed first.
        ``Value`` is the raw (signed) metric with 4 decimals. If a metric is
        NaN for every model of a split, ``Best Model`` and ``Value`` are "-",
        rather than naming an arbitrary model as best.
    """
    abs_set = set(abs_cols or ())

    def _pick(g, col, maximize):
        # NaN must never win; dropping it (instead of filling with +/-inf)
        # also keeps genuine +/-inf values rankable.
        vals = g[col].dropna()
        if col in abs_set:
            vals = vals.abs()
        if vals.empty:
            return "-", "-"
        idx = vals.idxmax() if maximize else vals.idxmin()
        return g.loc[idx, "Model"], f"{g.loc[idx, col]:.4f}"

    records = []
    for split, g in df.groupby("Split"):
        for cols, maximize in ((maximize_cols, True), (minimize_cols, False)):
            for col in cols:
                model, value = _pick(g, col, maximize)
                records.append([split, col, model, value])

    return pd.DataFrame(records, columns=["Split", "Metric", "Best Model", "Value"])

# ✅ Metric categorization: which direction counts as "better" for each metric
# when picking the best model per split.

# Larger is better.
maximize_cols = (
    ["Average_AUC"]                                   # performance
    + ["Robust_AUC", "RR_AUC", "AuNC"]                # robustness (AUC-style)
    + ["Group_Noise_Robustness", "Noise_Robustness_Index"]
)

# Smaller is better.
minimize_cols = (
    ["Worst_Case_Drop", "Noise_Variance", "AUC_Range"]              # robustness (negative direction)
    + ["Noise_Sensitivity_Drop", "Noise_Sensitivity_Slope",
       "Noise_Sensitivity_Relative"]                                # stability
    + ["Max_Regret", "Mean_Regret", "Max_Regret_Ratio",
       "Regret_Variance", "Weighted_Modality_Regret"]               # regret
    + ["Worst_Group_Gap", "Group_Fairness_Gap"]                     # fairness
)

# ✅ Best model per (Split, metric), once per weighted-regret mode ("max", "mean").
df_best_split_max, df_best_split_mean = (
    best_by_split(frame, maximize_cols, minimize_cols)
    for frame in (df_metrics_max, df_metrics_mean)
)

# ✅ Show both tables (header line followed by the table body).
print("=== Weighted Mode: max ===", df_best_split_max.to_string(index=False), sep="\n")
print("\n=== Weighted Mode: mean ===", df_best_split_mean.to_string(index=False), sep="\n")

# Persist both tables as CSV in the results directory.
df_best_split_max.to_csv(results_dir.joinpath("df_best_split_max.csv"), index=False)
df_best_split_mean.to_csv(results_dir.joinpath("df_best_split_mean.csv"), index=False)

"""### 8.3 LaTeX tables"""

import pandas as pd
import numpy as np

# Preferred direction per metric: True -> larger is better, False -> smaller is
# better. Lookups elsewhere use `.get(metric, True)`, so unlisted metrics are
# treated as larger-is-better.
HIGHER_IS_BETTER = dict(
    Average_AUC=True,
    Robust_AUC=True,
    RR_AUC=True,
    Worst_Case_Drop=False,
    Noise_Sensitivity_Drop=False,
    Noise_Sensitivity_Slope=False,
    Max_Regret=False,
    Mean_Regret=False,
    Weighted_Modality_Regret=False,
    Group_Noise_Robustness=True,
    Group_Fairness_Gap=False,
)

# LaTeX column headers per metric key; arrows mark the preferred direction.
# "AUC_mean_std" is a synthetic key for the combined "mean ± std" AUC column.
DISP = dict(
    AUC_mean_std=r"Avg $\uparrow$ AUC $\pm$ Std $\downarrow$",
    Max_Regret=r"Max Regret $\downarrow$",
    Mean_Regret=r"Mean Regret $\downarrow$",
    Weighted_Modality_Regret=r"WMR $\downarrow$",
    Robust_AUC=r"Robust AUC $\uparrow$",
    RR_AUC=r"RR-AUC $\uparrow$",
    Worst_Case_Drop=r"W.C. Drop $\downarrow$",
    Noise_Sensitivity_Drop=r"NS Drop $\downarrow$",
    Noise_Sensitivity_Slope=r"$\lvert$NS Slope$\rvert$ $\downarrow$",
    Group_Noise_Robustness=r"GNR $\uparrow$",
    Group_Fairness_Gap=r"GF Gap $\downarrow$",
)

def _fmt(x, d=3):
    if x is None or (isinstance(x, float) and (np.isnan(x) or np.isinf(x))):
        return "-"
    return f"{x:.{d}f}"

def _rel_imp(v, b, m):
    if b is None or pd.isna(b) or b == 0:
        return None
    up = HIGHER_IS_BETTER.get(m, True)
    return (v - b) / b * 100.0 if up else (b - v) / b * 100.0

def _best_mask(df, col):
    # Special-case: for Noise_Sensitivity_Slope, bold by smallest absolute magnitude per split
    if col == "Noise_Sensitivity_Slope":
        abs_col = df[col].abs()
        agg_abs_min = df.groupby("Split", observed=False)[col].transform(lambda s: s.abs().min())
        return abs_col == agg_abs_min
    up = HIGHER_IS_BETTER.get(col, True)
    agg = df.groupby("Split", observed=False)[col].transform("max" if up else "min")
    return df[col] == agg



def make_table_latex_auc_combined(
    df: pd.DataFrame,
    baseline_key="LR",
    rename_models=None,
    model_order_keys=("LR", "WDRO", "WDRO_MRO_GAME"),
    split_map=None,
    split_order=("ID", "OOD", "Oropharynx"),
    caption=None,
    label=None,
    include_rel_improvement: bool = True,
    timestamp_str=None,
):
    r"""Render the HANCOCK results as two LaTeX ``table`` environments.

    The first table holds Performance (Avg AUC +/- Std), Robustness and
    Fairness columns. The second table holds the Stability columns. Within each
    split and column, the best value (chosen by ``_best_mask``) is wrapped in
    ``\textbf{}``.

    Parameters
    ----------
    df : pd.DataFrame
        Per-(Model, Split) metric table with ``Model`` and ``Split`` columns.
        Metric columns are optional; missing ones render as "-".
    baseline_key : str
        Raw model name used as the reference for relative improvements.
    rename_models : dict or None
        Maps raw model name to LaTeX display name. ``None`` uses the built-in
        mapping for LR / WDRO / WDRO_MRO_GAME.
    model_order_keys : sequence of str
        Raw model names in display order. Models not listed are left out of
        the tables and of the per-split "best" computation.
    split_map : dict or None
        Maps raw split label to short label. ``None`` maps "In distribution"
        to "ID" and "Out of distribution" to "OOD".
    split_order : sequence of str
        Short split labels in display order. Rows whose split is not listed
        are dropped rather than printed as "nan".
    caption, label : str or None
        LaTeX caption / label. When ``None``, they are generated from the
        timestamp.
    include_rel_improvement : bool
        Append the relative improvement (%) over the baseline on the same split.
    timestamp_str : str or None
        Timestamp text for the default caption/label. When ``None``, it falls
        back to a module-level ``timestamp`` if one exists. Otherwise it uses
        the current local time as ``YYYYmmdd_HHMMSS``.

    Returns
    -------
    str
        Both ``table`` environments, separated by a blank line.
    """
    from datetime import datetime

    if rename_models is None:
        rename_models = {
            "LR": r"\raggedright \shortstack[l]{ERM \\ (Logistic)}",
            "WDRO": r"WDRO",
            "WDRO_MRO_GAME": r"\raggedright \shortstack[l]{WDRO-MRO\\ (ours)}",
        }
    if split_map is None:
        split_map = {"In distribution": "ID", "Out of distribution": "OOD"}

    # --- Timestamps: ts_tex goes into caption text (underscores escaped);
    # ts_lbl goes into \label, which must not contain backslashes/spaces. ---
    if timestamp_str is None:
        # Previously the only source: a module-level `timestamp` set elsewhere.
        timestamp_str = globals().get("timestamp")
    if timestamp_str is None:
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    ts = str(timestamp_str)
    ts_tex = ts.replace("_", r"\_")
    ts_lbl = ts.replace("_", "-")

    if caption is None:
        caption = (
            f"Generated on {ts_tex}. "
            "Evaluations on HANCOCK. Upper block reports performance, robustness, and fairness; "
            "lower block reports stability. "
            "Best values (per split, per column) are in \\textbf{bold}. "
        )
        if include_rel_improvement:
            caption += (
                "Numbers in parentheses indicate relative improvement (\\%) over the ERM baseline on the same split."
            )
    if label is None:
        label = f"tab:unified_results_{ts_lbl}"

    df = df.copy()

    # ---------- Normalise split labels ----------
    df["Split"] = df["Split"].astype("string").str.strip().map(lambda s: split_map.get(s, s))
    df["Split"] = pd.Categorical(df["Split"], categories=list(split_order), ordered=True)

    # ---------- Rename models ----------
    df["Model"] = df["Model"].replace(rename_models)
    baseline_name = rename_models.get(baseline_key, baseline_key)

    # ---------- Model order (only models actually present) ----------
    desired_models = [rename_models.get(k, k) for k in model_order_keys]
    present = set(df["Model"].unique())
    desired_models = [m for m in desired_models if m in present]
    df["Model"] = pd.Categorical(df["Model"], categories=desired_models, ordered=True)

    # Rows whose model/split is outside the requested layout are never shown,
    # so they must not take part in the per-split "best" computation either;
    # otherwise a hidden model could hold the maximum and nothing gets bolded.
    df = df[df["Model"].notna() & df["Split"].notna()]
    df = df.sort_values(["Model", "Split"]).reset_index(drop=True)

    # ---------- Column sets ----------
    # Upper table: Performance + Robustness + Fairness.
    TOP_KEYS = [
        "AUC_mean_std",
        "Robust_AUC", "RR_AUC", "Worst_Case_Drop",
        "Group_Noise_Robustness", "Group_Fairness_Gap",
    ]
    # Lower table: Stability only.
    BOTTOM_KEYS = ["Noise_Sensitivity_Drop", "Noise_Sensitivity_Slope"]

    # ---------- Best-value masks (the combined AUC cell follows Average_AUC) ----------
    mask_source = {"AUC_mean_std": "Average_AUC"}

    def _mask_for(key):
        col = mask_source.get(key, key)
        if col in df.columns:
            return _best_mask(df, col)
        return pd.Series(False, index=df.index)

    best_masks = {k: _mask_for(k) for k in TOP_KEYS + BOTTOM_KEYS}

    # ---------- Baseline values, averaged per split ----------
    base_tab = (df[df["Model"] == baseline_name]
                .groupby("Split", observed=False)
                .mean(numeric_only=True))

    # ---------- Cell formatting ----------
    def _rel_suffix(row, metric, value):
        """Return ' (+x.x%)' versus the baseline on the same split, or ''."""
        if not include_rel_improvement or row["Model"] == baseline_name:
            return ""
        sp = row["Split"]
        if base_tab.empty or metric not in base_tab.columns or sp not in base_tab.index:
            return ""
        b = base_tab.loc[sp, metric]
        if pd.isna(value) or pd.isna(b):
            return ""
        r = _rel_imp(value, b, metric)
        if r is None or pd.isna(r):
            return ""
        sign = "+" if r >= 0 else ""
        return f" ({sign}{_fmt(r, 1)}\\%)"

    def _cell(row, key):
        """Format one table cell, bolding it when it is the per-split best."""
        if key == "AUC_mean_std":
            mu = row.get("Average_AUC", np.nan)
            sd = row.get("Std_AUC", np.nan)
            s = _fmt(mu) + _rel_suffix(row, "Average_AUC", mu)
            if pd.notna(sd):
                s += f" $\\pm${_fmt(sd)}"
        else:
            v = row.get(key, np.nan)
            s = _fmt(v) + _rel_suffix(row, key, v)
        if bool(best_masks[key].loc[row.name]):
            s = r"\textbf{" + s + "}"
        return s

    # Non-empty (model, rows) groups in display order; \midrule goes only
    # between groups, so an empty group cannot leave a dangling rule.
    groups = [(mdl, df[df["Model"] == mdl].sort_values("Split")) for mdl in desired_models]
    groups = [(mdl, sub) for mdl, sub in groups if not sub.empty]

    def _append_body(lines, keys):
        """Append one \\multirow group per model for the given metric keys."""
        for g_i, (mdl, sub) in enumerate(groups):
            if g_i > 0:
                lines.append(r"\midrule")
            rows = list(sub.itertuples(index=True))
            for i, r in enumerate(rows):
                cells = [_cell(df.loc[r.Index], k) for k in keys]
                prefix = rf"\multirow{{{len(rows)}}}{{*}}{{{mdl}}} " if i == 0 else ""
                lines.append(prefix + rf"& {r.Split} & " + " & ".join(cells) + r" \\")

    # =======================
    # Table 1: Performance + Robustness + Fairness
    # =======================
    latex_top = [
        r"\begin{table}[t]",
        rf"\caption{{{caption}}}",
        rf"\label{{{label}}}",
        r"\resizebox{\textwidth}{!}{",
        r"\begin{tabular}{ll" + "c" * len(TOP_KEYS) + "}",
        r"\toprule",
        r"Model & Split & \multicolumn{1}{c}{\textbf{Performance}} & \multicolumn{3}{c}{\textbf{Robustness}} & \multicolumn{2}{c}{\textbf{Fairness}} \\",
        r"\cmidrule(lr){3-3}\cmidrule(lr){4-6}\cmidrule(lr){7-8}",
        "& & " + " & ".join(DISP[k] for k in TOP_KEYS) + r" \\",
        r"\midrule",
    ]
    _append_body(latex_top, TOP_KEYS)
    latex_top += [r"\bottomrule", r"\end{tabular}", r"}", r"\end{table}"]

    # =======================
    # Table 2: Stability
    # =======================
    latex_bot = [
        r"\begin{table}[t]",
        rf"\caption{{{caption} (Stability block)}}",
        rf"\label{{{label}_bottom}}",
        r"\centering",
        r"\begin{tabular}{ll" + "c" * len(BOTTOM_KEYS) + "}",
        r"\toprule",
        # Column ranges: Model(1), Split(2), Stability(3-4).
        r"Model & Split & \multicolumn{2}{c}{\textbf{Stability}} \\",
        r"\cmidrule(lr){3-4}",
        "& & " + " & ".join(DISP[k] for k in BOTTOM_KEYS) + r" \\",
        r"\midrule",
    ]
    _append_body(latex_bot, BOTTOM_KEYS)
    latex_bot += [r"\bottomrule", r"\end{tabular}", r"\end{table}"]

    return "\n".join(latex_top + [""] + latex_bot)


# Render the combined LaTeX tables for the "max" weighted mode, echo them, and
# save them with an explicit UTF-8 encoding. Without it, the output file would
# depend on the platform's locale default encoding.
latex_code = make_table_latex_auc_combined(df_metrics_max, include_rel_improvement=False)
print(latex_code)
with open(results_dir / "table_auc_combined.tex", "w", encoding="utf-8") as f:
    f.write(latex_code)

