from __future__ import annotations

import os
import time
import warnings
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Callable
from math import gamma

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# GPU support via CuPy
try:
    import cupy as xp_gpu
    HAS_CUPY = True
except ImportError:
    xp_gpu = None
    HAS_CUPY = False

from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.multioutput import MultiOutputRegressor

# Try to import XGBoost for GPU acceleration
try:
    import xgboost as xgb
    HAS_XGBOOST = True
except ImportError:
    xgb = None
    HAS_XGBOOST = False

# Local imports
from load_dataset_grinsztajn import load_dataset
from train_cqr_v2_utils import split_train_cal_test
from multiscoring_conformal_gpu import MultiScoringConformalPredictor, QuantileType

# TR imports
from tr_retr import fit_tr_and_standardize, tr_transform, tr_retransform
from tr_retr_conditional import (
    ConditionalTR,
    fit_conditional_tr_and_standardize,
    conditional_tr_transform,
    conditional_tr_retransform,
    StabilizedConditionalTR,
    fit_stabilized_conditional_tr_and_standardize,
)

# WSC utilities
try:
    from worst_slab_cov_gpu import calculer_wsc_regression
    WSC_AVAILABLE = True
except Exception:
    WSC_AVAILABLE = False
    print("[OT-CP Benchmark] WSC not available - will be skipped")


# =============================================================================
# SCORING FUNCTIONS FOR MULTI-OUTPUT REGRESSION
# =============================================================================

def build_multioutput_residual_scoring_functions(n_outputs: int) -> List[Callable]:
    """
    Create one absolute-residual scoring function per output dimension.

    The callable at position ``j`` maps ``(y_true, y_pred)`` to the vector
    ``|y_true[:, j] - y_pred[:, j]|``. Both inputs are promoted with
    ``np.atleast_2d``, so a single 1-D sample is scored as one row.

    Args:
        n_outputs: number of output dimensions (one scorer per dimension).

    Returns:
        List of ``n_outputs`` scoring callables, ordered by dimension.
    """
    def _scorer_for(dim: int) -> Callable:
        # A factory gives each closure its own ``dim`` (avoids late binding).
        def _abs_residual(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
            truth = np.atleast_2d(y_true)
            pred = np.atleast_2d(y_pred)
            return np.abs(truth[:, dim] - pred[:, dim])
        return _abs_residual

    return list(map(_scorer_for, range(n_outputs)))


# =============================================================================
# BASE MODEL TRAINING
# =============================================================================

def train_base_model(
    X_train: np.ndarray,
    y_train: np.ndarray,
    random_state: int = 42,
    dataset_name: str = None,
):
    """
    Fit a boosted-tree regressor independently on every output column.

    Each output dimension gets its own estimator via ``MultiOutputRegressor``.
    When XGBoost imported successfully it is tried first (``device='cuda'``
    when CuPy is importable, ``'cpu'`` otherwise). If building or fitting it
    raises, the error is printed and sklearn's ``GradientBoostingRegressor``
    is used instead.

    Args:
        X_train: training covariates.
        y_train: training targets, one column per output.
        random_state: seed forwarded to each per-output estimator.
        dataset_name: ``"rf1"`` selects shallower trees (``max_depth=2``);
            any other value uses ``max_depth=3``.

    Returns:
        A fitted ``MultiOutputRegressor``.
    """
    # Only the tree depth differs between datasets; everything else is shared.
    hyperparams = {
        "n_estimators": 100,
        "learning_rate": 0.1,
        "max_depth": 2 if dataset_name == "rf1" else 3,
        "random_state": random_state,
    }

    if HAS_XGBOOST:
        device = 'cuda' if HAS_CUPY else 'cpu'
        print(f"  Using XGBoost (device={device})")
        try:
            xgb_model = MultiOutputRegressor(
                xgb.XGBRegressor(tree_method='hist', device=device, **hyperparams)
            )
            xgb_model.fit(X_train, y_train)
            return xgb_model
        except Exception as exc:
            # Best-effort: any XGBoost failure (presumably e.g. an unusable
            # CUDA device) drops through to the sklearn path below.
            print(f"  XGBoost failed: {exc}, falling back to sklearn")

    print("  Using sklearn GradientBoostingRegressor (CPU)")
    sk_model = MultiOutputRegressor(GradientBoostingRegressor(**hyperparams))
    sk_model.fit(X_train, y_train)
    return sk_model

# =============================================================================
# VOLUME ESTIMATION (Ball sampling, adapted from GCQR_3)
# =============================================================================

def _unit_ball_volume(d: int) -> float:
    """Volume of unit ball in d dimensions."""
    return (np.pi ** (d / 2)) / gamma(d / 2 + 1)


def estimate_volume_mc_otcp(
    cp: "MultiScoringConformalPredictor",
    scoring_functions: List[Callable],
    y_pred: np.ndarray,
    X_query: np.ndarray,
    cal_residuals_std: np.ndarray,
    n_samples: int = 10000,
    margin_factor: float = 1.5,
    use_conditional: bool = False,
    verbose: bool = False,
) -> Dict:
    """
    Monte-Carlo volume estimate of the OT-CP prediction region via ball sampling.

    Uniform samples are drawn from a d_y-dimensional Euclidean ball centred on
    ``y_pred``. The region volume is estimated as
    ``(fraction of samples accepted by cp) * (ball volume)``.

    A ball is used instead of a bounding hypercube because
    V_ball / V_cube -> 0 as d grows, so cube sampling would waste most samples
    in high dimension. Only the part of the region inside the sampling ball is
    measured.

    Sampling radius: each output gets ``4.5 * |cal_residuals_std[j]|``
    (3 sigma times a 1.5 safety factor). The Euclidean norm of that vector is
    then multiplied by a dimension-dependent margin
    ``1 + 0.1 * log(d_y + 1) / log(16)``, clamped to ``[1.05, margin_factor]``.
    The 1.05 floor is applied last, so it wins when ``margin_factor < 1.05``.

    The sampler uses a fixed seed (42), so every call draws the same
    direction/radius pattern (common random numbers across test points).

    Args:
        cp: calibrated conformal predictor exposing ``is_inside_batch`` (and
            ``is_inside_batch_conditional`` for the conditional path).
        scoring_functions: one callable per score, each mapping
            ``(Y_samples, y_pred_broadcast)`` to a length-``n_samples`` vector.
        y_pred: (d_y,) prediction centre.
        X_query: (d_x,) covariate for the conditional model; only used when
            ``use_conditional`` is True.
        cal_residuals_std: (d_y,) std of calibration residuals per dimension.
        n_samples: number of MC samples.
        margin_factor: upper bound on the radius margin multiplier (see above).
        use_conditional: call ``cp.is_inside_batch_conditional`` when
            ``X_query`` is not None; otherwise ``cp.is_inside_batch``.
        verbose: print detailed diagnostics.

    Returns:
        Dict with keys:
            ``volume``: estimated region volume.
            ``log_volume``: natural log of the volume, computed in log space
                so it does not under/overflow for large d_y; ``-inf`` when no
                sample is accepted or the radius is zero.
            ``inside_ratio``, ``sampling_ball_volume``,
            ``sampling_ball_radius``, ``margin_factor_used``.
    """
    from math import lgamma  # log-Gamma keeps the ball volume finite for large d_y

    d_y = len(y_pred)
    n_scores = len(scoring_functions)

    if verbose:
        print(f"      [MC Volume] Starting estimation: d_y={d_y}, n_samples={n_samples}")
        print(f"      [MC Volume] y_pred norm: {np.linalg.norm(y_pred):.4f}, range: [{y_pred.min():.4f}, {y_pred.max():.4f}]")
        print(f"      [MC Volume] cal_residuals_std: mean={cal_residuals_std.mean():.4f}, max={cal_residuals_std.max():.4f}")

    # Per-dimension radius = 3 sigma * 1.5 safety factor, combined into a
    # single Euclidean radius for the sampling ball.
    sampling_radius_per_dim = 1.5 * np.abs(cal_residuals_std) * 3.0
    sampling_radius = np.linalg.norm(sampling_radius_per_dim)

    if verbose:
        print(f"      [MC Volume] Base sampling radius: {sampling_radius:.4f}")

    # Widen the ball slightly with dimension; clamp to [1.05, margin_factor]
    # (the floor is applied last, so it takes precedence).
    margin_factor_tight = 1.0 + 0.1 * np.log(d_y + 1) / np.log(16)
    margin_factor_tight = min(margin_factor_tight, margin_factor)
    margin_factor_tight = max(margin_factor_tight, 1.05)

    sampling_ball_radius = sampling_radius * margin_factor_tight

    if verbose:
        print(f"      [MC Volume] Adjusted margin factor: {margin_factor_tight:.4f}")
        print(f"      [MC Volume] Final sampling ball radius: {sampling_ball_radius:.4f}")

    # Reference ball volume in log space:
    #   log V = (d/2) log(pi) - lgamma(d/2 + 1) + d log(R)
    log_unit_ball_vol = 0.5 * d_y * np.log(np.pi) - lgamma(d_y / 2 + 1)
    unit_ball_vol = float(np.exp(log_unit_ball_vol))
    if sampling_ball_radius > 0:
        log_sampling_ball_volume = float(log_unit_ball_vol + d_y * np.log(sampling_ball_radius))
    else:
        log_sampling_ball_volume = -np.inf
    sampling_ball_volume = float(np.exp(log_sampling_ball_volume))

    if verbose:
        print(f"      [MC Volume] Unit ball volume (d={d_y}): {unit_ball_vol:.6e}")
        print(f"      [MC Volume] Sampling ball volume: {sampling_ball_volume:.6e}")

    # Fixed seed: identical sample pattern for every call.
    rng = np.random.default_rng(seed=42)

    if verbose:
        print(f"      [MC Volume] Generating {n_samples} uniform samples in {d_y}D ball...")

    # Uniform directions on the sphere: normalized Gaussian vectors.
    directions = rng.standard_normal((n_samples, d_y)).astype(np.float32)
    directions = directions / (np.linalg.norm(directions, axis=1, keepdims=True) + 1e-12)

    # Radii r = R * U^(1/d) make the samples uniform in the ball's volume.
    u_radii = rng.uniform(0, 1, size=(n_samples, 1)).astype(np.float32)
    radii = sampling_ball_radius * (u_radii ** (1.0 / d_y))

    if verbose:
        print(f"      [MC Volume] Radii stats: min={radii.min():.4f}, mean={radii.mean():.4f}, max={radii.max():.4f}")

    # Uniform samples in the ball centred on y_pred.
    Y_samples = y_pred + directions * radii

    if verbose:
        sample_norms = np.linalg.norm(Y_samples - y_pred, axis=1)
        print(f"      [MC Volume] Sample distances from center: min={sample_norms.min():.4f}, mean={sample_norms.mean():.4f}, max={sample_norms.max():.4f}")

    if verbose:
        print(f"      [MC Volume] Computing {n_scores} scores for {n_samples} samples...")

    y_pred_broadcast = np.broadcast_to(y_pred.reshape(1, -1), (n_samples, d_y))
    all_scores = np.zeros((n_samples, n_scores), dtype=np.float32)
    for j, score_func in enumerate(scoring_functions):
        all_scores[:, j] = score_func(Y_samples, y_pred_broadcast)

    if verbose:
        print(f"      [MC Volume] Score stats: min={all_scores.min():.4f}, mean={all_scores.mean():.4f}, max={all_scores.max():.4f}")
        print(f"      [MC Volume] Score per dim: " + ", ".join([f"dim{j}={all_scores[:,j].mean():.3f}" for j in range(min(5, n_scores))]) + ("..." if n_scores > 5 else ""))

    if verbose:
        print(f"      [MC Volume] Testing membership (conditional={use_conditional})...")

    if use_conditional and X_query is not None:
        # Conditional OT-CP: batch method (one local model for all samples).
        inside = cp.is_inside_batch_conditional(all_scores, X_query)
    else:
        inside = cp.is_inside_batch(all_scores)

    # The predictor may return a CuPy array; bring it to host memory.
    if hasattr(inside, "get"):
        inside = inside.get()
    inside = np.asarray(inside)

    inside_ratio = float(inside.mean())
    volume = inside_ratio * sampling_ball_volume
    log_volume = (
        float(np.log(inside_ratio) + log_sampling_ball_volume)
        if inside_ratio > 0
        else -np.inf
    )

    if verbose:
        n_inside = inside.sum()
        print(f"      [MC Volume] Membership results: {n_inside}/{n_samples} inside ({inside_ratio*100:.2f}%)")
        print(f"      [MC Volume] Estimated volume: {volume:.6e}")
        print(f"      [MC Volume] Log volume: {np.log(max(volume, 1e-30)):.4f}")

    return {
        "volume": volume,
        "log_volume": log_volume,
        "inside_ratio": inside_ratio,
        "sampling_ball_volume": sampling_ball_volume,
        "sampling_ball_radius": sampling_ball_radius,
        "margin_factor_used": margin_factor_tight,
    }


# =============================================================================
# OT-CP EVALUATION PIPELINE
# =============================================================================

def run_otcp_pipeline(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    X_cal: np.ndarray,
    Y_cal: np.ndarray,
    X_test: np.ndarray,
    Y_test: np.ndarray,
    alpha: float = 0.1,
    random_state: int = 42,
    dataset_name: str = None,
    tr: Optional[Dict] = None,
    tr_cond: Optional[ConditionalTR] = None,
    volume_correction_method: str = "jacobian",
    use_conditional_ot: bool = False,
    k_neighbors: int = 100,
    ot_method: str = "local_ot",
    n_vol_samples: int = 20,
    n_mc_samples: int = 10000,
) -> Dict:
    """
    Full OT-CP pipeline with GCQR3-style preprocessing.

    Steps:
        1. Fit the multi-output base model on the training split.
        2. Predict on the calibration and test covariates.
        3. Build one absolute-residual score per output.
        4. Calibrate an OT-MK ``MultiScoringConformalPredictor`` (conditional
           OT-CP+ when ``use_conditional_ot``).
        5. Measure empirical coverage on every test point.
        6. Compute worst-slab coverage when the WSC module imported.
        7. Estimate the region volume by MC on ``n_vol_samples`` randomly
           chosen test points, optionally corrected by the log-|Jacobian| of
           the target transform.

    Args:
        X_train, Y_train: training data (for base model)
        X_cal, Y_cal: calibration data
        X_test, Y_test: test data
        alpha: miscoverage level
        random_state: random seed for the base model and the predictor
        dataset_name: for logging and dataset-specific settings
        tr: global TR dict (if using global TR)
        tr_cond: conditional TR object (if using conditional TR)
        volume_correction_method: "jacobian", "retransform", or "none". Any
            value other than "none" adds the per-point log-Jacobian when a TR
            is supplied ("retransform" currently applies that same correction).
        use_conditional_ot: whether to use OT-CP+ (conditional)
        k_neighbors: number of neighbors for conditional OT
        ot_method: method for OT-MK
        n_vol_samples: number of test points for volume estimation
        n_mc_samples: MC samples per test point for volume

    Raises:
        ValueError: if ``volume_correction_method`` is not one of the three
            accepted values.

    Returns:
        Dict with coverage, WSC and volume metrics plus the run configuration.
        When a correction applies, ``norm_log_vol_final`` is the mean of
        per-point (log vol + log-Jacobian) divided by d_y.
        ``norm_log_vol_transformed`` is log(mean volume) / d_y in the
        transformed space.
    """
    valid_corrections = ("jacobian", "retransform", "none")
    if volume_correction_method not in valid_corrections:
        raise ValueError(
            f"volume_correction_method must be one of {valid_corrections}, "
            f"got {volume_correction_method!r}"
        )

    d_y = Y_train.shape[1]
    n_test = X_test.shape[0]

    print("=" * 70)
    print(" OT-CP Pipeline")
    print("=" * 70)
    print(f"  d_y={d_y}, α={alpha}")
    print(f"  OT method: {ot_method}")
    print(f"  Conditional OT: {use_conditional_ot}")

    # 1. Train base model
    print("\n[1] Training base model...")
    base_model = train_base_model(X_train, Y_train, random_state, dataset_name)

    # 2. Get predictions
    print("\n[2] Computing predictions...")
    Y_cal_pred = base_model.predict(X_cal)
    Y_test_pred = base_model.predict(X_test)

    # 3. Build scoring functions
    print("\n[3] Setting up scoring functions...")
    scoring_functions = build_multioutput_residual_scoring_functions(d_y)

    # 4. Create and calibrate conformal predictor
    print("\n[4] Calibrating OT-CP...")
    cp = MultiScoringConformalPredictor(
        scoring_functions=scoring_functions,
        quantile_type=QuantileType.OT_MK,
        alpha=alpha,
        random_state=random_state,
    )
    cp.ot_mk_method = ot_method
    cp.split_ratio = 1/3  # D1 = 1/3 of cal (rank estimation), D2 = 2/3 of cal (threshold)

    if use_conditional_ot:
        cp.enable_conditional_ot(k_neighbors=k_neighbors)
        cp.calibrate(
            y_cal_true=Y_cal,
            y_cal_pred=Y_cal_pred,
            X_cal=X_cal,
            task="regression",
        )
    else:
        cp.calibrate(
            y_cal_true=Y_cal,
            y_cal_pred=Y_cal_pred,
            task="regression",
        )

    print(f"  OT-MK radius threshold: {cp.ot_mk_radius_threshold_:.4f}")

    # 5. Evaluate coverage on all test points
    print("\n[5] Evaluating coverage on all test points...")
    coverage_status = np.zeros(n_test, dtype=bool)

    # The residual scorers built above operate row-wise, so score the whole
    # test set with one call per output instead of one call per point.
    all_test_scores = np.zeros((n_test, d_y), dtype=np.float32)
    for j, score_func in enumerate(scoring_functions):
        all_test_scores[:, j] = score_func(Y_test, Y_test_pred)

    if use_conditional_ot:
        # Each test point has its own X_query, so membership is checked per point.
        for i in range(n_test):
            coverage_status[i] = bool(cp.is_inside(all_test_scores[i], X_query=X_test[i]))
            if (i + 1) % 200 == 0:
                print(f"    Coverage evaluation progress: {i+1}/{n_test}")
    else:
        # Non-conditional: single batch call
        coverage_status = cp.is_inside_batch(all_test_scores)

    # Ensure coverage_status is a NumPy array (may be CuPy)
    if HAS_CUPY and hasattr(coverage_status, 'get'):
        coverage_status = coverage_status.get()
    coverage_status = np.asarray(coverage_status)

    coverage = np.mean(coverage_status)
    print(f"  Coverage: {coverage:.4f} (target: {1 - alpha:.2f})")

    # 6. WSC computation
    print("\n[6] Computing Worst Slab Coverage (WSC)...")
    wsc_min = np.nan

    if WSC_AVAILABLE:
        delta_map = {"scm20d": 0.10, "sgemm": 0.01, "rf1": 0.10, "rf2": 0.10, "scm1d": 0.10}
        delta = delta_map.get(dataset_name, 0.10)
        M_directions = 100

        try:
            # Ensure X_test is a NumPy array (may be CuPy)
            X_test_np = X_test
            if HAS_CUPY and hasattr(X_test_np, 'get'):
                X_test_np = X_test_np.get()
            X_test_np = np.asarray(X_test_np, dtype=np.float32)

            wsc_min, wsc_list = calculer_wsc_regression(
                X_test_np, coverage_status.astype(np.float32), delta=delta, M=M_directions, random_state=42
            )
            print(f"  WSC_min: {wsc_min:.4f} (delta={delta})")
        except Exception as e:
            # Best-effort metric: report the failure and keep wsc_min = NaN.
            print(f"  WSC computation failed: {e}")
    else:
        print("  WSC not available")

    # 7. Volume estimation
    print("\n[7] Volume estimation (ball sampling)...")
    n_vol = min(n_vol_samples, n_test)
    rng = np.random.default_rng(seed=42)
    vol_indices = rng.choice(n_test, size=n_vol, replace=False)

    # Calibration residual spread sets the MC sampling radius.
    cal_residuals = Y_cal - Y_cal_pred
    cal_residuals_std = np.std(cal_residuals, axis=0)

    print(f"  Config: n_points={n_vol}, n_mc_samples={n_mc_samples}")

    # A global TR has a constant log-|Jacobian|: compute it once.
    # NOTE(review): eigvalsh assumes tr['sqrtS'] is symmetric — confirm.
    global_log_jac = 0.0
    if tr_cond is None and tr is not None:
        eigvals = np.abs(np.linalg.eigvalsh(tr['sqrtS']))
        global_log_jac = float(np.sum(np.log(np.clip(eigvals, 1e-30, None))))

    all_vol_results = []
    log_jacobians = []

    # Enable verbose logging for first few points
    verbose_vol_estimation = True

    for idx, i in enumerate(vol_indices):
        y_pred_i = Y_test_pred[i]
        X_query_i = X_test[i] if use_conditional_ot else None

        # Only verbose for first 3 points to avoid excessive output
        verbose_this_point = verbose_vol_estimation and (idx < 3)
        if verbose_this_point:
            print(f"\n    --- Volume estimation for test point {idx+1}/{n_vol} (index={i}) ---")

        vol_res = estimate_volume_mc_otcp(
            cp=cp,
            scoring_functions=scoring_functions,
            y_pred=y_pred_i,
            X_query=X_query_i,
            cal_residuals_std=cal_residuals_std,
            n_samples=n_mc_samples,
            margin_factor=1.5,
            use_conditional=use_conditional_ot,
            verbose=verbose_this_point,
        )

        # Per-point log-Jacobian for the TR correction (conditional TR depends on x).
        if tr_cond is not None:
            log_jac = tr_cond.get_log_det_jacobian(X_test[i])
        else:
            log_jac = global_log_jac

        vol_res['log_jacobian'] = log_jac
        log_jacobians.append(log_jac)
        all_vol_results.append(vol_res)

        if (idx + 1) % 5 == 0:
            print(f"    Progress: {idx+1}/{n_vol}")

    # Aggregate volume metrics
    volumes = np.array([r['volume'] for r in all_vol_results])
    inside_ratios = np.array([r['inside_ratio'] for r in all_vol_results])

    mean_volume_transformed = float(np.mean(volumes))
    mean_inside_ratio = float(np.mean(inside_ratios))

    eps_vol = 1e-30
    log_jacobian_mean = float(np.mean(log_jacobians))

    # Determine TR type
    tr_type = "none"
    if tr_cond is not None:
        tr_type = "conditional"
    elif tr is not None:
        tr_type = "global"

    if tr_type != "none" and volume_correction_method != "none":
        # Per-point correction: average of log volumes (geometric mean), each
        # shifted by its own log-Jacobian.
        log_vols_corrected = []
        for res in all_vol_results:
            log_vol_transf = np.log(max(res['volume'], eps_vol))
            log_vol_corrected = log_vol_transf + res['log_jacobian']
            log_vols_corrected.append(log_vol_corrected)

        mean_log_vol_corrected = float(np.mean(log_vols_corrected))
        norm_log_vol_final = mean_log_vol_corrected / d_y
        norm_log_vol_transformed = (1 / d_y) * np.log(max(mean_volume_transformed, eps_vol))
    else:
        norm_log_vol_transformed = (1 / d_y) * np.log(max(mean_volume_transformed, eps_vol))
        norm_log_vol_final = norm_log_vol_transformed

    print(f"\n  Volume Summary:")
    print(f"    Mean inside ratio:    {mean_inside_ratio*100:.2f}%")
    print(f"    Mean vol (transformed): {mean_volume_transformed:.6e}")
    print(f"    (1/d) log vol (transf): {norm_log_vol_transformed:+.4f}")
    print(f"    (1/d) log vol (final):  {norm_log_vol_final:+.4f}")
    print(f"    TR type:                {tr_type}")

    return {
        "coverage": coverage,
        "target": 1 - alpha,
        "ot_threshold": cp.ot_mk_radius_threshold_,
        "vol_transformed": mean_volume_transformed,
        "inside_ratio": mean_inside_ratio,
        "norm_log_vol_final": norm_log_vol_final,
        "norm_log_vol_transformed": norm_log_vol_transformed,
        "wsc_min": wsc_min,
        "log_jacobian_mean": log_jacobian_mean,
        "tr_type": tr_type,
        "ot_method": ot_method,
        "conditional_ot": use_conditional_ot,
        "k_neighbors": k_neighbors if use_conditional_ot else None,
    }


# =============================================================================
# Results Saving
# =============================================================================

def save_results_to_csv(
    results_list: list,
    output_dir: str = "results",
    filename: str = "results_otcp_gcqr_preprocessing.csv",
) -> str:
    """
    Write per-run results to ``output_dir/filename`` and return that path.

    Only a fixed subset of metric columns is kept, in this order: ``run``,
    ``seed``, ``coverage``, ``wsc_min``, ``norm_log_vol_scaled``,
    ``norm_log_vol_transformed``, ``time``. Columns absent from the data are
    skipped and every other key is dropped. Floats are written with six
    decimals. ``output_dir`` is created if it does not exist.
    """
    minimal_columns = (
        'run', 'seed', 'coverage', 'wsc_min',
        'norm_log_vol_scaled', 'norm_log_vol_transformed', 'time',
    )

    os.makedirs(output_dir, exist_ok=True)

    frame = pd.DataFrame(results_list)
    present = [col for col in minimal_columns if col in frame.columns]
    frame = frame[present]

    destination = os.path.join(output_dir, filename)
    frame.to_csv(destination, index=False, float_format='%.6f')

    print(f"\n✓ Results saved to: {destination}")
    return destination


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":

    # Configuration matching GCQR3
    datasets = ["scm20d", "rf1", "scm1d", "rf2", "sgemm"]  # available: scm20d, rf1, scm1d, rf2, sgemm
    alpha = 0.1
    base_seed = 42
    n_runs = 10  # Number of runs per dataset
    seeds = [base_seed + i * 111 for i in range(n_runs)]  # Reproducible seeds
    tr_retr_bool = True
    ridge = 1e-3
    mode = "tr_retr" if tr_retr_bool else "standard"
    # TR mode: "conditional", "stabilized", or "global"
    tr_mode = "stabilized"
    tr_knn_k = 100
    tr_shrinkage_alpha = 0.5
    tr_fallback_threshold = 1e6

    # Volume correction method
    volume_correction_method = "jacobian"

    # OT-CP settings
    use_conditional_ot = True  # Set to True for OT-CP+ (slower but adaptive)
    ot_cp_plus = "+" if use_conditional_ot else ""
    k_neighbors = 1000  # For conditional OT
    ot_method = "local_ot"

    # Dataset-specific parameters
    test_size = {"scm20d": 0.2, "sgemm": 0.2, "rf1": 0.2, "rf2": 0.2, "scm1d": 0.2}
    cal_size = {"scm20d": 0.6, "sgemm": 0.6, "rf1": 0.6, "rf2": 0.6, "scm1d": 0.6}  # D1+D2 combined

    MAX_SAMPLES = 25000

    print("=" * 70)
    print(" OT-CP Evaluation with GCQR3 Preprocessing")
    print("=" * 70)
    print(f"  TR mode: {tr_mode}")
    print(f"  OT method: {ot_method}")
    print(f"  Conditional OT: {use_conditional_ot}")
    print(f"  Number of runs: {n_runs}")
    print(f"  Seeds: {seeds}")

    # Paths of the per-dataset CSVs, reported in the final summary.
    saved_paths = []

    for dataset in datasets:
        dataset_results = []  # Results for this dataset only
        print(f"\n{'='*70}")
        print(f" Dataset: {dataset} (mode: {mode})")
        print(f"{'='*70}")

        # Load dataset once (before run loop)
        print(f"\nLoading dataset: {dataset}")
        if dataset == "rf1":
            X_full, Y_full = load_dataset(dataset, use_rf1_resplit=True, random_state=base_seed)
        else:
            X_full, Y_full = load_dataset(dataset, random_state=base_seed)

        X_full = X_full.values.astype(np.float32) if hasattr(X_full, "values") else np.asarray(X_full, dtype=np.float32)
        Y_full = np.asarray(Y_full, dtype=np.float32)
        print(f"  Full dataset shape: X={X_full.shape}, Y={Y_full.shape}")

        for run_idx, seed in enumerate(seeds):
            print(f"\n  --- Run {run_idx + 1}/{n_runs} (seed={seed}) ---")
            run_start_time = time.time()

            # Copy data for this run
            X = X_full.copy()
            Y = Y_full.copy()

            # Subsample if needed (with run-specific seed).
            # NOTE: this reseeds NumPy's global legacy RNG; kept so that
            # subsamples match earlier runs.
            n_samples = X.shape[0]
            if MAX_SAMPLES is not None and n_samples > MAX_SAMPLES:
                print(f"    Subsampling to {MAX_SAMPLES} samples...")
                np.random.seed(seed)
                indices = np.random.choice(n_samples, MAX_SAMPLES, replace=False)
                X = X[indices]
                Y = Y[indices]

            # Split (with run-specific seed)
            X_train, X_cal, X_test, Y_train, Y_cal, Y_test = split_train_cal_test(
                X, Y,
                test_size=test_size[dataset],
                cal_size=cal_size[dataset],
                random_state=seed,
            )

            # Standardize X (fit on train only)
            if dataset != "sgemm":
                scaler_X = StandardScaler()
                X_train = scaler_X.fit_transform(X_train).astype(np.float32)
                X_cal = scaler_X.transform(X_cal).astype(np.float32)
                X_test = scaler_X.transform(X_test).astype(np.float32)

            # TR transformation
            tr = None
            tr_cond = None

            if tr_retr_bool:
                if tr_mode == "stabilized":
                    print(f"    Using STABILIZED CONDITIONAL TR (k={tr_knn_k}, shrinkage={tr_shrinkage_alpha})...")
                    tr_cond, Y_train = fit_stabilized_conditional_tr_and_standardize(
                        X_train, Y_train,
                        method="knn",
                        k=tr_knn_k,
                        ridge=ridge,
                        shrinkage_alpha=tr_shrinkage_alpha,
                        fallback_threshold=tr_fallback_threshold,
                        eigenvalue_shrinkage=0.0,
                    )
                    Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                    Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                elif tr_mode == "conditional":
                    print(f"    Using CONDITIONAL TR (KNN, k={tr_knn_k})...")
                    tr_cond, Y_train = fit_conditional_tr_and_standardize(
                        X_train, Y_train,
                        method="knn",
                        k=tr_knn_k,
                        ridge=ridge,
                    )
                    Y_cal = conditional_tr_transform(X_cal, Y_cal, tr_cond)
                    Y_test = conditional_tr_transform(X_test, Y_test, tr_cond)

                else:  # global
                    print(f"    Using GLOBAL TR...")
                    tr, Y_train = fit_tr_and_standardize(Y_train, ridge=ridge)
                    Y_cal = tr_transform(Y_cal, tr)
                    Y_test = tr_transform(Y_test, tr)
            else:
                scaler_Y = StandardScaler()
                Y_train = scaler_Y.fit_transform(Y_train).astype(np.float32)
                Y_cal = scaler_Y.transform(Y_cal).astype(np.float32)
                Y_test = scaler_Y.transform(Y_test).astype(np.float32)

            # Additional normalization (same as GCQR3): the largest training
            # target norm becomes 1.
            norm_max_train = np.linalg.norm(Y_train, axis=1).max()
            scale_factor = 1 / norm_max_train
            Y_train = Y_train * scale_factor
            Y_cal = Y_cal * scale_factor
            Y_test = Y_test * scale_factor
            print(f"    Applied additional normalization (max norm scaling): factor=1/{norm_max_train:.4f}")
            print(f"    Shapes: train={X_train.shape}, cal={X_cal.shape}, test={X_test.shape}")

            # Run OT-CP pipeline
            results = run_otcp_pipeline(
                X_train, Y_train,
                X_cal, Y_cal,
                X_test, Y_test,
                alpha=alpha,
                random_state=seed,
                dataset_name=dataset,
                tr=tr,
                tr_cond=tr_cond,
                volume_correction_method=volume_correction_method,
                use_conditional_ot=use_conditional_ot,
                k_neighbors=k_neighbors,
                ot_method=ot_method,
                n_vol_samples=50,
                n_mc_samples=20000,
            )

            run_time = time.time() - run_start_time

            # Add run metadata
            results['run'] = run_idx + 1
            results['seed'] = seed
            results['time'] = run_time
            # Undo the max-norm scaling: each of the d_y axes was scaled by
            # 1/norm_max_train, so (1/d) log vol shifts by log(norm_max_train).
            results['norm_log_vol_scaled'] = results['norm_log_vol_transformed'] + np.log(norm_max_train)

            dataset_results.append(results)

            # Print summary for this run
            print(f"    Run {run_idx + 1} complete: coverage={results['coverage']:.4f}, "
                  f"wsc={results['wsc_min']:.4f}, time={run_time:.1f}s")

        # Save results for this dataset (one CSV per dataset)
        saved_paths.append(save_results_to_csv(
            dataset_results,
            output_dir="results_otcp",
            filename=f"results_otcp_{ot_cp_plus}_{dataset}_{mode}.csv"
        ))

        # Print dataset summary
        print(f"\n  === {dataset} Summary ({n_runs} runs) ===")
        coverages = np.array([float(r['coverage']) for r in dataset_results])
        wscs = np.array([float(r['wsc_min']) for r in dataset_results if not np.isnan(r['wsc_min'])])
        vols = np.array([float(r['norm_log_vol_scaled']) for r in dataset_results])
        print(f"    Coverage: {np.mean(coverages):.4f} ± {np.std(coverages):.4f}")
        if len(wscs) > 0:
            print(f"    WSC:      {np.mean(wscs):.4f} ± {np.std(wscs):.4f}")
        else:
            print(f"    WSC:      N/A (computation failed)")
        print(f"    Log-vol:  {np.mean(vols):.4f} ± {np.std(vols):.4f}")

    # Final summary: list every saved file (not just the last dataset's).
    print("\n" + "=" * 70)
    print(" BENCHMARK COMPLETE")
    print("=" * 70)
    print(f"  Datasets evaluated: {len(datasets)}")
    print(f"  Runs per dataset: {n_runs}")
    print("  Results saved to:")
    for saved_path in saved_paths:
        print(f"    {saved_path}")