
import warnings
import os
import gc

# GPU support via CuPy
try:
    import cupy as cp
    HAS_CUPY = True
except ImportError:
    cp = None
    HAS_CUPY = False

from sklearn.decomposition import PCA

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=".*force_all_finite.*"
)

import numpy as np
from typing import Dict, Any
import pandas as pd
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances

from sklearn.datasets import fetch_openml
from ucimlrepo import fetch_ucirepo

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
try:
    from xgboost import XGBClassifier
    HAS_XGB = True
except ImportError:
    HAS_XGB = False

from load_dataset_grinsztajn import load_dataset
from multiscoring_conformal_gpu import (MultiScoringConformalPredictor, QuantileType, 
                                    build_softmax_diff_scoring_functions,
                                    build_logit_margin_scoring_functions,
                                    build_reduced_scoring_functions)
from multiscoring_conformal_cond_gpu import MultiScoringConformalPredictorCond

# Import GRCP for multiclass (local nonparametric rank, no relabeling)
from grcp_multiclass import (
    GRCPMulticlassPredictor,
    eval_grcp_multiclass,
    build_softmax_diff_scoring_functions as grcp_build_softmax_diff_scoring_functions,
    median_heuristic,
)

# Import WSC version GPU
from worst_slab_cov_gpu import calculer_wsc_gpu

try:
    from mapie.classification import MapieClassifier
    MAPIE_LEGACY = True
except ImportError:
    from mapie.classification import SplitConformalClassifier
    MAPIE_LEGACY = False


def to_cpu(arr):
    """Return ``arr`` as a host-side NumPy array.

    When CuPy was imported successfully and ``arr`` exposes a ``device``
    attribute, the conversion goes through ``cp.asnumpy``; any other input is
    handed to ``np.asarray``.
    """
    goes_through_cupy = HAS_CUPY and hasattr(arr, 'device')
    return cp.asnumpy(arr) if goes_through_cupy else np.asarray(arr)


def to_gpu(arr):
    """Convert ``arr`` to a float32 array, using CuPy when it is importable.

    With CuPy available the result comes from ``cp.asarray``; otherwise the
    function falls back to a float32 ``np.asarray`` conversion.
    """
    if not HAS_CUPY:
        return np.asarray(arr, dtype=np.float32)
    return cp.asarray(arr, dtype=cp.float32)


#


# ---------------------------------------------------------------------
# 2. Split train / cal / test
# ---------------------------------------------------------------------
def split_train_cal_test(
    X: np.ndarray,
    y: np.ndarray,
    test_size: float = 0.2,
    cal_size: float = 0.2,
    random_state: int = 42,
):
    """
    Stratified three-way split into train / calibration / test subsets.

    The fractions are expressed relative to the full dataset:
        train : 1 - test_size - cal_size
        cal   : cal_size
        test  : test_size

    Parameters
    ----------
    X, y : array-like
        Features and labels; both are converted with ``np.asarray``.
    test_size, cal_size : float
        Fractions of the full dataset, each strictly between 0 and 1, whose
        sum must be strictly below 1 so that the training part is non-empty.
    random_state : int
        Seed forwarded to both ``train_test_split`` calls.

    Returns
    -------
    tuple
        ``(X_train, X_cal, X_test, y_train, y_cal, y_test)``.

    Raises
    ------
    ValueError
        If the fractions are outside (0, 1) or leave no room for training data.
    """
    # Validate up front: otherwise an invalid combination only surfaces inside
    # the second train_test_split call, as an error about a derived fraction.
    if not (0.0 < test_size < 1.0) or not (0.0 < cal_size < 1.0):
        raise ValueError(
            "test_size and cal_size must be fractions strictly between 0 and 1 "
            f"(got test_size={test_size}, cal_size={cal_size})."
        )
    if test_size + cal_size >= 1.0:
        raise ValueError(
            "test_size + cal_size must be < 1 to leave data for training "
            f"(got {test_size} + {cal_size})."
        )

    X = np.asarray(X)
    y = np.asarray(y)

    # First split: (train + cal) versus test.
    X_train_cal, X_test, y_train_cal, y_test = train_test_split(
        X, y,
        test_size=test_size,
        stratify=y,
        random_state=random_state,
    )

    # Second split: train versus cal. cal_size refers to the full dataset, so
    # it is rescaled to a fraction of what remains after removing the test set.
    relative_cal_size = cal_size / (1.0 - test_size)
    X_train, X_cal, y_train, y_cal = train_test_split(
        X_train_cal, y_train_cal,
        test_size=relative_cal_size,
        stratify=y_train_cal,
        random_state=random_state,
    )

    return X_train, X_cal, X_test, y_train, y_cal, y_test


# ---------------------------------------------------------------------
# 3. Entraînement du modèle de base
# ---------------------------------------------------------------------
def train_base_model(
    X_train: np.ndarray,
    y_train: np.ndarray,
    random_state: int = 42,
    model_type: str = "random_forest",  # "random_forest" ou "xgboost"
) -> Any:
    """
    Fit and return the multiclass base classifier.

    Parameters
    ----------
    X_train, y_train : np.ndarray
        Training features and labels.
    random_state : int
        Seed passed to whichever estimator is built.
    model_type : str
        "random_forest" (default) or "xgboost". For "xgboost" the function
        tries a GPU configuration first, then a CPU "hist" configuration, and
        finally falls back to a random forest if XGBoost is missing or both
        attempts raise.

    Returns
    -------
    Any
        The fitted estimator.

    Raises
    ------
    ValueError
        If ``model_type`` is neither "random_forest" nor "xgboost".
    """
    n_classes = int(len(np.unique(y_train)))

    if model_type not in ("random_forest", "xgboost"):
        raise ValueError(f"model_type inconnu: {model_type}. Utilisez 'random_forest' ou 'xgboost'.")

    if model_type == "random_forest":
        from sklearn.ensemble import RandomForestClassifier

        print("[INFO] Entraînement Random Forest...")
        forest = RandomForestClassifier(
            n_estimators=500,
            max_depth=20,
            min_samples_split=5,
            min_samples_leaf=2,
            max_features='sqrt',
            random_state=random_state,
            n_jobs=-1,
            verbose=1,
        )
        forest.fit(X_train, y_train)
        print("[INFO] Base model: RandomForestClassifier.")
        return forest

    # From here on model_type == "xgboost".
    if HAS_XGB:
        # Hyper-parameters common to the GPU and CPU attempts.
        shared_params = dict(
            n_estimators=300,
            learning_rate=0.1,
            max_depth=6,
            subsample=0.9,
            colsample_bytree=0.9,
            objective="multi:softprob",
            num_class=n_classes,
            eval_metric="mlogloss",
            random_state=random_state,
            n_jobs=-1,
        )

        # Attempt 1: GPU training. Any failure is reported and we move on.
        try:
            print("Essai XGBoost GPU...")
            gpu_model = XGBClassifier(
                tree_method="gpu_hist",
                predictor="gpu_predictor",
                **shared_params,
            )
            gpu_model.fit(X_train, y_train)
            print("[INFO] Base model: XGBClassifier (GPU).")
            return gpu_model
        except Exception as gpu_error:
            print(f"[INFO] XGBoost GPU non disponible ({gpu_error}), essai CPU hist...")

        # Attempt 2: CPU histogram method, logging the training loss.
        try:
            cpu_model = XGBClassifier(
                tree_method="hist",
                verbosity=1,
                **shared_params,
            )
            cpu_model.fit(
                X_train, y_train,
                eval_set=[(X_train, y_train)],
                verbose=10,
            )
            print("[INFO] Base model: XGBClassifier (CPU hist).")
            return cpu_model
        except Exception as cpu_error:
            print(f"[WARN] Échec de XGBClassifier ({cpu_error}). Fallback vers Random Forest.")

    # Last resort: XGBoost is not installed or both attempts failed.
    print("[INFO] XGBoost non disponible, utilisation de Random Forest...")
    from sklearn.ensemble import RandomForestClassifier
    fallback_forest = RandomForestClassifier(
        n_estimators=500,
        max_depth=20,
        random_state=random_state,
        n_jobs=-1,
    )
    fallback_forest.fit(X_train, y_train)
    print("[INFO] Base model: RandomForestClassifier (fallback).")
    return fallback_forest



# ---------------------------------------------------------------------
# 4. Évaluation CP multiscoring (GEOMETRIC ou OT_MK)
# ---------------------------------------------------------------------
def eval_multiscoring_cp(
    X_cal: np.ndarray,
    y_cal: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    base_clf: Any,
    quantile_type: QuantileType,
    alpha: float,
    K_rank: int = 5,
    h: float = 1.0,
    X_cond_cal: np.ndarray | None = None,
    X_cond_test: np.ndarray | None = None,
    use_conditional: bool = False,
    k_neighbors: int = 1000,
    ot_mk_method: str = "exact",
    score_type: str = "logit_margin",  # "softmax_diff", "logit_margin", "reduced"
) -> Dict[str, Any]:
    """
    Calibrate a multi-score conformal predictor and evaluate it on a test set.

    The function (1) picks the scoring functions and the raw model outputs
    according to ``score_type``, (2) calibrates a predictor chosen from
    ``quantile_type`` / ``use_conditional``, (3) computes the score vector of
    every (test point, candidate class) pair, (4) asks the predictor which
    pairs are inside the conformal region and (5) builds the prediction sets.
    An empty prediction set is replaced by the single highest-scoring class.

    Parameters
    ----------
    X_cal, y_cal : np.ndarray
        Calibration features and labels.
    X_test, y_test : np.ndarray
        Test features and labels.
    base_clf : Any
        Fitted classifier exposing ``n_classes_`` and ``predict_proba``
        (``predict(..., output_margin=True)`` is used when it looks like an
        XGBoost model, i.e. it has ``get_booster``).
    quantile_type : QuantileType
        GEOMETRIC and OT_MK use ``MultiScoringConformalPredictor``; any other
        member uses ``MultiScoringConformalPredictorCond`` with a Gaussian
        kernel.
    alpha : float
        Miscoverage level; the target coverage reported is ``1 - alpha``.
    K_rank : int
        Number of scores requested from ``build_reduced_scoring_functions``
        (only read when ``score_type == "reduced"``).
    h : float
        Kernel bandwidth for the conditional predictor.
    X_cond_cal, X_cond_test : np.ndarray | None
        Conditioning covariates; default to ``X_cal`` / ``X_test``.
    use_conditional : bool
        With ``QuantileType.OT_MK``, enables the conditional OT variant.
    k_neighbors : int
        Neighbour count passed to ``enable_conditional_ot``.
    ot_mk_method : str
        Value assigned to the predictor's ``ot_mk_method`` for OT_MK.
    score_type : str
        - "logit_margin" (default): z_j - z_y, computed from logits (or from
          log-probabilities for non-XGBoost models);
        - "reduced": scoring functions from ``build_reduced_scoring_functions``;
        - any other value is treated as "softmax_diff": |p_j - 1_{y=j}|.

    Returns
    -------
    Dict[str, Any]
        ``coverage``, ``target_coverage``, ``average_set_size`` (floats) and
        ``prediction_sets`` (list of ``set`` of class indices, one per test
        point).
    """
    n_classes = base_clf.n_classes_

    # =====================================================================
    # SCORE TYPE SELECTION: softmax_diff, logit_margin, or reduced
    # =====================================================================
    print(f"    [Multivariate] Score type: {score_type}")

    use_logit_margin = score_type == "logit_margin"
    if use_logit_margin:
        # Logit-margin vector: S_k(x, y) = z_k(x) - z_y(x)
        scoring_functions = build_logit_margin_scoring_functions(n_classes)
        print(f"    [Multivariate] Mode LOGIT_MARGIN: {n_classes} scoring functions")

        if HAS_XGB and hasattr(base_clf, 'get_booster'):
            # XGBoost: raw margins are available directly.
            print("    [Multivariate] XGBoost détecté: extraction des logits via output_margin=True")
            y_cal_scores_input = base_clf.predict(X_cal, output_margin=True)
            y_test_scores_input = base_clf.predict(X_test, output_margin=True)
        else:
            # Other models: approximate logits with log-probabilities; eps
            # keeps log() finite when a class has probability exactly 0.
            print("    [Multivariate] Modèle non-XGBoost: approximation logits via log(proba + eps)")
            eps = 1e-10
            y_cal_scores_input = np.log(base_clf.predict_proba(X_cal) + eps)
            y_test_scores_input = np.log(base_clf.predict_proba(X_test) + eps)

        print(f"    [Multivariate] Shape logits cal: {y_cal_scores_input.shape}, test: {y_test_scores_input.shape}")
    else:
        if score_type == "reduced":
            scoring_functions = build_reduced_scoring_functions(K=K_rank)
            print(f"    [Multivariate] Mode REDUCED: top-{K_rank} scoring functions")
        else:
            # softmax_diff: one score per class, whatever the number of
            # classes (no top-K reduction is applied on this path).
            scoring_functions = build_softmax_diff_scoring_functions(n_classes)
            print(f"    [Multivariate] Mode SOFTMAX_DIFF: {n_classes} scoring functions")

        # Both modes feed class probabilities to the predictor.
        y_cal_scores_input = base_clf.predict_proba(X_cal)
        y_test_scores_input = base_clf.predict_proba(X_test)
        print(f"    [Multivariate] Shape probs cal: {y_cal_scores_input.shape}, test: {y_test_scores_input.shape}")

    print(f"    [Multivariate] Nombre de classes: {n_classes}, dimension effective des scores: {len(scoring_functions)}")

    # =====================================================================
    # CALIBRATION
    # =====================================================================
    conditional_ot = quantile_type == QuantileType.OT_MK and use_conditional
    unconditional_type = quantile_type in {QuantileType.GEOMETRIC, QuantileType.OT_MK}

    # The predictor is deliberately not called ``cp``: that name is the
    # module-level CuPy alias.
    if conditional_ot:
        # OT-CP+: OT-MK quantile conditioned on covariates.
        print(f"    [OT-CP+] Using conditional OT with k={k_neighbors} neighbors, method={ot_mk_method}")
        predictor = MultiScoringConformalPredictor(
            scoring_functions=scoring_functions,
            quantile_type=quantile_type,
            alpha=alpha,
        )
        predictor.ot_mk_method = ot_mk_method
        predictor.enable_conditional_ot(k_neighbors=k_neighbors)
        predictor.calibrate(
            y_cal_true=y_cal,
            y_cal_pred=y_cal_scores_input,
            task="classification",
            X_cal=X_cond_cal if X_cond_cal is not None else X_cal,
        )
    elif unconditional_type:
        predictor = MultiScoringConformalPredictor(
            scoring_functions=scoring_functions,
            quantile_type=quantile_type,
            alpha=alpha,
        )
        if quantile_type == QuantileType.OT_MK:
            predictor.ot_mk_method = ot_mk_method
            print(f"    [Multivariate] OT-MK method: {ot_mk_method}")
        print(f"    [Multivariate] Calibration sur {len(y_cal)} échantillons...")
        predictor.calibrate(y_cal_true=y_cal, y_cal_pred=y_cal_scores_input, task="classification")
        print("    [Multivariate] Calibration terminée.")
    else:
        # Kernel-based conditional predictor.
        print(f"    [Multivariate] Mode conditionnel ({quantile_type.name})")
        print(f"    [Multivariate] Kernel: gaussian, bande passante h={h:.4f}")
        predictor = MultiScoringConformalPredictorCond(
            scoring_functions=scoring_functions,
            quantile_type=quantile_type,
            alpha=alpha,
            kernel_name="gaussian",
            h=h,
        )
        if X_cond_cal is None:
            X_cond_cal = X_cal
        print(f"    [Multivariate] Shape covariables cal: {X_cond_cal.shape}")
        print(f"    [Multivariate] Calibration sur {len(y_cal)} échantillons...")
        predictor.calibrate(X_cal=X_cond_cal, y_cal_true=y_cal, y_cal_pred=y_cal_scores_input)
        print("    [Multivariate] Calibration terminée.")

    # Use the score dimension reported by the calibrated predictor.
    n_scores = predictor.n_scores
    print(f"    [Multivariate] Dimension des scores après calibration: {n_scores}")

    candidate_classes = np.arange(n_classes)
    n_test = X_test.shape[0]

    print(f"    [Multivariate] Construction des prediction sets sur {n_test} échantillons test...")
    print("    [Multivariate] Vectorisation des calculs de scores...")

    # ==================== VECTORIZED SCORE COMPUTATION ====================
    # all_scores[i, k, j] = j-th score of test point i if its label were k.
    # NOTE(review): these scores are computed inline from the formulas below,
    # not by calling `scoring_functions`. For score_type="reduced" this is the
    # softmax-diff formula on the first n_scores columns — confirm it matches
    # what build_reduced_scoring_functions computes during calibration.
    all_scores = np.zeros((n_test, n_classes, n_scores), dtype=np.float32)
    leading_outputs = y_test_scores_input[:, :n_scores]  # (n_test, n_scores)

    print(f"    [Multivariate] Calcul des scores pour {n_classes} classes × {n_scores} scores...")

    if use_logit_margin:
        # score_j = z_j - z_k for candidate class k.
        print("    [Multivariate] Mode LOGIT_MARGIN: z_j - z_k pour chaque classe candidate k")
        for k in range(n_classes):
            if k % 10 == 0 or k == n_classes - 1:
                print(f"      Classe {k+1}/{n_classes}...")
            z_k = y_test_scores_input[:, k]
            all_scores[:, k, :] = leading_outputs - z_k[:, None]
    else:
        # score_j = |p_j - 1_{k=j}|: p_j everywhere except 1 - p_k at j == k.
        print(f"    [Multivariate] Mode SOFTMAX_DIFF: |p_j - 1_{{y=j}}|")
        for k in range(n_classes):
            if k % 10 == 0 or k == n_classes - 1:
                print(f"      Classe {k+1}/{n_classes}...")
            all_scores[:, k, :] = leading_outputs
            if k < n_scores:
                all_scores[:, k, k] = 1.0 - leading_outputs[:, k]

    print(f"    [Multivariate] Scores calculés: shape={all_scores.shape}")
    print("    [Multivariate] Vérification batch sur GPU/CPU...")

    # ==================== BATCH INSIDE CHECK ====================
    # Covariates are needed for OT-CP+ and for the kernel-based predictor.
    need_covariates = conditional_ot or not unconditional_type

    if not need_covariates:
        # Non-conditional: check all (n_test * n_classes) points at once.
        print(f"    [Multivariate] Mode non-conditionnel: batch de {n_test * n_classes} vérifications...")
        scores_flat = all_scores.reshape(-1, n_scores)
        print(f"      Appel is_inside_batch sur {scores_flat.shape[0]} points...")
        inside_flat = predictor.is_inside_batch(scores_flat)
        print("      Transfert CPU...")
        inside_flat = to_cpu(inside_flat)  # single host transfer
        # Force a boolean mask so it is always used as a mask when indexing.
        all_inside = np.asarray(inside_flat, dtype=bool).reshape(n_test, n_classes)
        print("      Terminé.")
    else:
        import time

        if X_cond_test is None:
            X_cond_test = X_test

        print(f"    [Multivariate] Conditional mode: processing {n_test} test points × {n_classes} classes...")

        all_inside = np.zeros((n_test, n_classes), dtype=bool)

        # Test points are processed in chunks only to report progress.
        batch_size = 100
        n_batches = (n_test + batch_size - 1) // batch_size
        has_batch_check = hasattr(predictor, 'is_inside_batch_conditional')

        start_time = time.perf_counter()

        for b in range(n_batches):
            start_idx = b * batch_size
            end_idx = min((b + 1) * batch_size, n_test)

            if b % 2 == 0 or b == n_batches - 1:
                elapsed = time.perf_counter() - start_time
                progress_pct = (b + 1) / n_batches * 100
                eta = (elapsed / (b + 1)) * (n_batches - b - 1) if b > 0 else 0
                print(f"      Progress: {end_idx}/{n_test} ({progress_pct:.1f}%) - ETA: {eta:.1f}s")

            for i in range(start_idx, end_idx):
                scores_i = all_scores[i]  # (n_classes, n_scores)
                X_i = X_cond_test[i]

                if has_batch_check:
                    # One call covers every candidate class of this point.
                    all_inside[i] = to_cpu(predictor.is_inside_batch_conditional(scores_i, X_i))
                else:
                    for k in range(n_classes):
                        all_inside[i, k] = predictor.is_inside(scores_i[k], X_i)

        total_time = time.perf_counter() - start_time
        # Guard the throughput figure against a zero elapsed time.
        checks_per_s = n_test * n_classes / total_time if total_time > 0 else float("inf")
        print(f"      Terminé en {total_time:.1f}s ({checks_per_s:.0f} checks/s)")

    print("    [Multivariate] Vérification terminée, construction des prediction sets...")

    # ==================== BUILD PREDICTION SETS ====================
    prediction_sets = []
    coverage_count = 0
    total_size = 0

    for i in range(n_test):
        pred_set = candidate_classes[all_inside[i]]

        # Fallback for an empty set: keep the top class. The argmax is taken
        # on the model outputs actually computed above (logits, log-probs or
        # probabilities all rank classes identically), which exist on every
        # score_type / model path.
        if pred_set.size == 0:
            pred_set = np.array([int(np.argmax(y_test_scores_input[i]))], dtype=int)

        prediction_sets.append(pred_set)
        total_size += pred_set.size
        if y_test[i] in pred_set:
            coverage_count += 1

    coverage = coverage_count / n_test
    avg_size = total_size / n_test

    print(f"    [Multivariate] Construction terminée: {n_test} prediction sets créés")
    print(f"    [Multivariate] Coverage empirique: {coverage:.4f}, taille moyenne: {avg_size:.2f}")

    # The WSC computation expects Python sets of class indices.
    prediction_sets_as_sets = [set(ps.tolist()) for ps in prediction_sets]

    return {
        "coverage": coverage,
        "target_coverage": 1.0 - alpha,
        "average_set_size": avg_size,
        "prediction_sets": prediction_sets_as_sets,
    }


# ---------------------------------------------------------------------
# 5. Évaluation CP vanilla avec MAPIE (APS / score cumulé)
# ---------------------------------------------------------------------
def eval_mapie_cp(
    X_cal: np.ndarray,
    y_cal: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    base_clf: Any,
    alpha: float,
    method: str = "aps",
) -> Dict[str, float]:
    """
    Baseline split-conformal evaluation through MAPIE on a prefit classifier.

    Depending on which MAPIE API was importable, this wraps ``base_clf`` in
    ``MapieClassifier`` (``MAPIE_LEGACY`` is True) or in
    ``SplitConformalClassifier`` (otherwise), calibrates it on
    ``(X_cal, y_cal)`` and predicts sets for ``X_test`` at level ``alpha``.

    Parameters
    ----------
    X_cal, y_cal : np.ndarray
        Calibration data; the base model is expected to be already fitted.
    X_test, y_test : np.ndarray
        Test data. ``y_test`` is used as a column index into the set mask, so
        it is assumed to hold integer class indices — TODO confirm at callers.
    base_clf : Any
        Fitted classifier handed to MAPIE.
    alpha : float
        Miscoverage level (``confidence_level = 1 - alpha`` in the new API).
    method : str
        Conformity score name forwarded to MAPIE.

    Returns
    -------
    dict
        ``coverage``, ``target_coverage``, ``average_set_size`` and
        ``prediction_sets`` (list of ``set`` of class indices).
    """
    if not MAPIE_LEGACY:
        # Newer API: confidence level instead of alpha, explicit prefit flag.
        conformalizer = SplitConformalClassifier(
            estimator=base_clf,
            confidence_level=1.0 - alpha,
            conformity_score=method,
            prefit=True,
            random_state=42,
        )
        conformalizer = conformalizer.conformalize(X_cal, y_cal)
        top1_labels, set_indicators = conformalizer.predict_set(X_test)
    else:
        # Legacy API: cv="prefit" makes fit() a pure calibration step.
        legacy_wrapper = MapieClassifier(
            estimator=base_clf,
            method=method,
            cv="prefit",
            random_state=42,
        )
        legacy_wrapper.fit(X_cal, y_cal)
        top1_labels, set_indicators = legacy_wrapper.predict(X_test, alpha=alpha)

    # A single level is requested, so only channel 0 of the last axis is used.
    membership = set_indicators[:, :, 0].astype(bool)
    n_test, _ = membership.shape

    # Empty sets are completed with the point prediction.
    for row in np.flatnonzero(~membership.any(axis=1)):
        membership[row, int(top1_labels[row])] = True

    sizes = membership.sum(axis=1).astype(float)
    avg_size = float(sizes.mean())

    # Covered when the true label's column is switched on in its row.
    coverage = np.mean(membership[np.arange(n_test), y_test])

    # The WSC computation expects Python sets of class indices.
    prediction_sets = [set(np.flatnonzero(row_mask).tolist()) for row_mask in membership]

    return {
        "coverage": coverage,
        "target_coverage": 1.0 - alpha,
        "average_set_size": avg_size,
        "prediction_sets": prediction_sets,
    }


def compute_wsc_for_method(
    X_test: np.ndarray,
    y_test: np.ndarray,
    prediction_sets: list,
    delta: float = 0.01,
    M: int = 1000,
) -> float:
    """Return the Worst Slab Coverage (WSC) of one method as a Python float.

    Delegates to ``calculer_wsc_gpu`` and reduces its output to the minimum
    value, i.e. the worst coverage found over the M random directions.

    Args:
        X_test: Test features, forwarded unchanged to ``calculer_wsc_gpu``.
        y_test: Test labels, forwarded unchanged.
        prediction_sets: Prediction sets of the method being evaluated
            (one per test point in the callers of this file).
        delta: Forwarded as the ``delta`` keyword of ``calculer_wsc_gpu``.
        M: Forwarded as the ``M`` keyword (number of random directions).

    Returns:
        The smallest value returned by ``calculer_wsc_gpu``.
    """
    return float(
        np.min(calculer_wsc_gpu(X_test, y_test, prediction_sets, delta=delta, M=M))
    )


# ---------------------------------------------------------------------
# 6. Boucle principale sur les datasets
# ---------------------------------------------------------------------
def _print_cp_metrics(metrics):
    """Print empirical coverage, target coverage and mean set size of one method.

    Args:
        metrics: Mapping with the keys ``"coverage"``, ``"target_coverage"``
            and ``"average_set_size"``.
    """
    print(f"  Coverage empirique   : {metrics['coverage']:.2%}")
    print(f"  Coverage cible       : {metrics['target_coverage']:.2%}")
    print(f"  Taille moyenne sets  : {metrics['average_set_size']:.2f}")


def _print_method_stats(df_results, method):
    """Print mean ± std of coverage, width and WSC over the rows of one method.

    Nothing is printed when ``df_results`` has no row for ``method``. WSC is
    reported as ``N/A`` when the method has no non-missing WSC value.

    Args:
        df_results: DataFrame with the columns ``"Method"``, ``"Coverage"``,
            ``"Width"`` and ``"WSC"``.
        method: Value of the ``"Method"`` column to summarise.
    """
    rows = df_results[df_results["Method"] == method]
    if len(rows) == 0:
        return
    cov_mean = rows["Coverage"].mean()
    cov_std = rows["Coverage"].std()
    width_mean = rows["Width"].mean()
    width_std = rows["Width"].std()
    wsc_vals = rows["WSC"].dropna()
    if len(wsc_vals) > 0:
        wsc_txt = f"{wsc_vals.mean():.3f}±{wsc_vals.std():.3f}"
    else:
        wsc_txt = "N/A"
    print(f"{method:8s}: cov={cov_mean:.3f}±{cov_std:.3f}, width={width_mean:.2f}±{width_std:.2f}, wsc={wsc_txt}")


if __name__ == "__main__":
    # Experiment driver: for every dataset and every seed, load the data,
    # train a base classifier, evaluate each conformal method, compute the
    # Worst Slab Coverage of each one, and write one CSV per dataset.

    # Integer entries are UCI repository ids (loaded with fetch_ucirepo),
    # "emnist" goes through load_dataset, any other string is an OpenML name.
    datasets = ["Fashion-MNIST", "mnist_784", "letter", "emnist"]
    # Alternative selections:
    # datasets = ["yeast", "vehicle", "segment", "satimage", 186, 602, "Fashion-MNIST", "mnist_784", "letter"]
    # datasets = [186, "letter"]
    datasets_id = {602: "Dry Bean Dataset", 186: "Wine Quality Dataset"}

    # Miscoverage level alpha per dataset (target coverage is 1 - alpha).
    alphas_cp = {
        "letter": 0.03, "satimage": 0.05, "vehicle": 0.1, "segment": 0.02,
        "yeast": 0.2, 602: 0.05, 186: 0.15,
        "Dry Bean Dataset": 0.05, "Wine Quality Dataset": 0.15,
        "mnist_784": 0.03, "Fashion-MNIST": 0.1, "emnist": 0.1,
    }
    # Test / calibration sizes handed to split_train_cal_test.
    test_size = {
        "letter": 0.3, "satimage": 0.3, "vehicle": 0.25, "segment": 0.25,
        "yeast": 0.25, 602: 0.3, 186: 0.3,
        "mnist_784": 0.45, "Fashion-MNIST": 0.45, "emnist": 0.45,
    }
    cal_size = {
        "letter": 0.3, "satimage": 0.3, "vehicle": 0.25, "segment": 0.25,
        "yeast": 0.3, 602: 0.3, 186: 0.3,
        "mnist_784": 0.45, "Fashion-MNIST": 0.45, "emnist": 0.45,
    }
    standardize = True
    model_type = "random_forest"  # "random_forest" or "xgboost"
    MAX_SAMPLES = 25000  # Cap on the number of samples per dataset (None disables it)
    # OT-MK variant per dataset; a dataset missing here raises KeyError.
    ot_mk_method = {"Fashion-MNIST": "hybrid", "mnist_784": "barycentric", "letter": "barycentric", "emnist": "hybrid"}

    # Option to disable OT-CP+ (conditional) evaluation
    EVAL_OT_CP_COND = True  # Set to False to disable OT-CP+ evaluation

    # =====================================================================
    # MULTI-RUN CONFIGURATION
    # =====================================================================
    n_runs = 10
    random_seeds = [42 + i * 111 for i in range(n_runs)]

    # Create output directory for results
    results_dir = "results_classif"
    os.makedirs(results_dir, exist_ok=True)
    print(f"[CONFIG] n_runs={n_runs}, seeds={random_seeds}")
    print(f"[CONFIG] Results will be saved to: {results_dir}/")

    # Dictionary to accumulate results across runs for each dataset
    all_results = {}

    for dataset_name in datasets:
        # Human-readable name used for logging, storage and file names
        current_dataset_name = datasets_id[dataset_name] if isinstance(dataset_name, int) else dataset_name
        all_results[current_dataset_name] = []
        print("=" * 80)
        print(f"DATASET : {current_dataset_name}")
        print("=" * 80)

        # =====================================================================
        # MULTI-RUN LOOP
        # =====================================================================
        for run_idx, random_state in enumerate(random_seeds):
            print("\n" + "#" * 80)
            print(f"### RUN {run_idx + 1}/{n_runs} | Seed={random_state} | Dataset={current_dataset_name}")
            print("#" * 80)

            # Seed NumPy's global RNG so that every np.random.choice
            # subsampling below is reproducible for this run.
            np.random.seed(random_state)

            # NOTE(review): `out` is not read anywhere in this section —
            # candidate for removal once the rest of the file is checked.
            out = {}  # Reset per-run output dict

            # ---------------------------------------------
            # Dataset loading (EMNIST, UCI or OpenML)
            # ---------------------------------------------
            is_emnist = (dataset_name == "emnist")

            if is_emnist:
                # Special case: EMNIST ships separate train and test sets.
                print("[EMNIST] Chargement EMNIST Balanced depuis load_dataset...")
                # Train set (later split into train + calibration)
                X_df_train, y_train_all, classes = load_dataset(
                    "emnist",
                    root="./data_cache",
                    train=True,
                    download=True,
                    max_rows=None,
                    seed=random_state
                )
                # Test set, loaded separately
                X_df_test, y_test, _ = load_dataset(
                    "emnist",
                    root="./data_cache",
                    train=False,
                    download=True,
                    max_rows=None,
                    seed=random_state
                )

                print(f"[EMNIST] {len(classes)} classes")
                print(f"[EMNIST] Train set: {X_df_train.shape}, Test set: {X_df_test.shape}")

                # Convert to NumPy
                X_train_all = X_df_train.values.astype(np.float32)
                y_train_all = y_train_all.values
                X_test = X_df_test.values.astype(np.float32)
                y_test = y_test.values

                # Default calibration fraction. It is bound up front so the
                # split below also works when MAX_SAMPLES is set but the data
                # is already small enough (previously a NameError); it is
                # overridden when the fixed-size subsampling applies.
                cal_ratio = cal_size[dataset_name]

                if MAX_SAMPLES is not None:
                    total_samples = X_train_all.shape[0] + X_test.shape[0]
                    if total_samples > MAX_SAMPLES:
                        # Fixed allocation: test=40%, cal=40%, train=remainder (20%)
                        n_test_sub = int(MAX_SAMPLES * 0.4)   # 10000 for MAX_SAMPLES=25000
                        n_cal_sub = int(MAX_SAMPLES * 0.4)    # 10000 for MAX_SAMPLES=25000
                        n_train_sub = MAX_SAMPLES - n_test_sub - n_cal_sub  # 5000 for MAX_SAMPLES=25000
                        n_train_all_sub = n_train_sub + n_cal_sub  # train+cal before the split

                        print(f"[EMNIST] Sous-échantillonnage: {total_samples} -> {MAX_SAMPLES} échantillons")
                        print(f"[EMNIST] Cible: Train={n_train_sub}, Cal={n_cal_sub}, Test={n_test_sub}")

                        # Subsample train_all (split into train+cal afterwards).
                        # The order of the two np.random.choice calls is kept
                        # so seeded runs draw the same indices as before.
                        indices_train = np.random.choice(X_train_all.shape[0], n_train_all_sub, replace=False)
                        X_train_all = X_train_all[indices_train]
                        y_train_all = y_train_all[indices_train]

                        # Subsample the test set
                        indices_test = np.random.choice(X_test.shape[0], n_test_sub, replace=False)
                        X_test = X_test[indices_test]
                        y_test = y_test[indices_test]

                        # Calibration share of the subsampled train+cal pool
                        cal_ratio = n_cal_sub / n_train_all_sub

                # Split train_all into train + calibration
                print(f"[EMNIST] Séparation train en train+cal (cal_ratio={cal_ratio:.2f})...")
                X_train, X_cal, y_train, y_cal = train_test_split(
                    X_train_all, y_train_all,
                    test_size=cal_ratio,
                    stratify=y_train_all,
                    random_state=random_state
                )

                print(f"[EMNIST] Train: {X_train.shape}, Cal: {X_cal.shape}, Test: {X_test.shape}")
            elif isinstance(dataset_name, int):
                # Best-effort load: any failure skips this run instead of
                # aborting the whole experiment.
                try:
                    ds = fetch_ucirepo(id=dataset_name)
                    X = ds.data.features.to_numpy()
                    # targets is a DataFrame -> take its first column
                    y = ds.data.targets.iloc[:, 0].to_numpy()
                except Exception as e:
                    print(f"Erreur de connexion UCI: {e}")
                    print(f"Skipping dataset {datasets_id.get(dataset_name, dataset_name)}...")
                    continue
                print(f"Shape X : {X.shape}, y : {y.shape}")
            else:
                try:
                    X, y = fetch_openml(dataset_name, version=1, as_frame=False, return_X_y=True)
                    if dataset_name == "mnist_784" or dataset_name == "Fashion-MNIST":
                        # Scale pixel values to [0, 1] and make labels integers
                        X = X.astype("float32") / 255.0
                        y = y.astype("int64")
                except Exception as e:
                    print(f"Erreur de chargement OpenML: {e}")
                    print(f"Skipping dataset {dataset_name}...")
                    continue
                print(f"Shape X : {X.shape}, y : {y.shape}")

            # Subsample to MAX_SAMPLES (non-EMNIST datasets only; EMNIST was
            # handled above with its own fixed allocation)
            if not is_emnist and MAX_SAMPLES is not None and X.shape[0] > MAX_SAMPLES:
                print(f"Sous-échantillonnage: {X.shape[0]} -> {MAX_SAMPLES} échantillons")
                indices = np.random.choice(X.shape[0], MAX_SAMPLES, replace=False)
                X = X[indices]
                y = y[indices]
                print(f"Shape après sous-échantillonnage: {X.shape}")

            alpha_cp = alphas_cp[dataset_name]
            # Looked up here rather than at first use so that a missing
            # configuration entry fails before any model is trained.
            ot_method = ot_mk_method[dataset_name]

            if not is_emnist:
                # Always run LabelEncoder so that classes are 0..K-1
                y = np.asarray(y)
                le = LabelEncoder()
                y = le.fit_transform(y)
                print(f"Labels encodés en entiers : {np.unique(y)}")

                # ---------------------------------------------
                # Split train / cal / test
                # ---------------------------------------------
                X_train, X_cal, X_test, y_train, y_cal, y_test = split_train_cal_test(
                    X, y,
                    test_size=test_size[dataset_name],
                    cal_size=cal_size[dataset_name],
                    random_state=random_state,
                )

                print(f"Taille train : {X_train.shape[0]}")
                print(f"Taille cal   : {X_cal.shape[0]}")
                print(f"Taille test  : {X_test.shape[0]}")

            # Standardisation: fitted on train only, applied to cal and test
            if standardize:
                scaler = StandardScaler()
                X_train = scaler.fit_transform(X_train)
                X_cal = scaler.transform(X_cal)
                X_test = scaler.transform(X_test)
                print("Données standardisées.")

            # Kernel bandwidth sigma = median pairwise Euclidean distance,
            # computed on at most 10000 training points to avoid OOM.
            n_train = X_train.shape[0]
            n_sigma_samples = min(10000, n_train)

            if n_train > n_sigma_samples:
                print(f"Sous-échantillonnage pour calcul sigma: {n_sigma_samples} échantillons (sur {n_train})")
                indices = np.random.choice(n_train, n_sigma_samples, replace=False)
                X_train_sigma = X_train[indices]
            else:
                X_train_sigma = X_train

            D = pairwise_distances(X_train_sigma, metric="euclidean")
            # Keep only the strict upper triangle (each pair once, no diagonal)
            triu = np.triu_indices_from(D, k=1)
            dists = D[triu]
            sigma = np.median(dists)
            # The distance matrix and its index arrays are only needed for
            # sigma; release them before training and GPU work.
            del D, triu, dists

            print(f"Bande passante sigma (kernel cond. géométrique) : {sigma:.3f}")
            # ---------------------------------------------
            # PCA fitted on X_train; the 3-D projections of cal/test are the
            # conditioning features of OT-CP+.
            pca = PCA(n_components=3, svd_solver='full')
            pca.fit(X_train)
            X_cal_pca = pca.transform(X_cal)
            X_test_pca = pca.transform(X_test)

            # ---------------------------------------------
            # Base model training
            # ---------------------------------------------
            clf = train_base_model(X_train, y_train, random_state=random_state, model_type=model_type)

            y_cal_pred = clf.predict(X_cal)
            y_test_pred = clf.predict(X_test)

            acc_cal = accuracy_score(y_cal, y_cal_pred)
            acc_test = accuracy_score(y_test, y_test_pred)

            print(f"\nAccuracy (calibration) : {acc_cal:.3f}")
            print(f"Accuracy (test)        : {acc_test:.3f}")

            # ---------------------------------------------
            # Vanilla CP (MAPIE, LAC/APS/RAPS)
            # ---------------------------------------------

            print("\n[MAPIE] CP vanilla (LAC)...")
            metrics_mapie_lac = eval_mapie_cp(
                X_cal, y_cal, X_test, y_test, clf, alpha_cp, method="lac"
            )
            _print_cp_metrics(metrics_mapie_lac)
            # ---------------------------------------------
            print("\n[MAPIE] CP vanilla (APS)...")
            metrics_mapie_aps = eval_mapie_cp(
                X_cal, y_cal, X_test, y_test, clf, alpha_cp, method="aps"
            )
            _print_cp_metrics(metrics_mapie_aps)
            # ---------------------------------------------
            print("\n[MAPIE] CP vanilla (RAPS)...")
            metrics_mapie_raps = eval_mapie_cp(
                X_cal, y_cal, X_test, y_test, clf, alpha_cp, method="raps"
            )
            _print_cp_metrics(metrics_mapie_raps)

            # ---------------------------------------------
            # MultiScoring CP – OT_MK (OT-CP)
            # ---------------------------------------------
            print(f"\n[MultiScoring] OT-CP (Monge-Kantorovich) - {ot_method}...")
            print(f"  Alpha = {alpha_cp}, coverage cible = {1.0 - alpha_cp:.2%}")
            metrics_ot = eval_multiscoring_cp(
                X_cal, y_cal, X_test, y_test, clf,
                quantile_type=QuantileType.OT_MK,
                alpha=alpha_cp,
                ot_mk_method=ot_method,
            )
            _print_cp_metrics(metrics_ot)

            # ---------------------------------------------
            # MultiScoring CP – OT_MK CONDITIONAL (OT-CP+)
            # ---------------------------------------------
            print("\n[MultiScoring] OT-CP+ (conditional OT)...")
            print(f"  Alpha = {alpha_cp}, coverage cible = {1.0 - alpha_cp:.2%}")

            # Initialize for summary (will be used even if skipped)
            metrics_ot_cond = None
            wsc_ot_cond = np.nan

            if EVAL_OT_CP_COND:
                # Neighbourhood size: a tenth of the calibration set, capped at 1000
                k_neighbors = min(1000, X_cal.shape[0] // 10)

                metrics_ot_cond = eval_multiscoring_cp(
                    X_cal, y_cal, X_test, y_test, clf,
                    quantile_type=QuantileType.OT_MK,
                    alpha=alpha_cp,
                    use_conditional=True,  # Enable OT-CP+
                    k_neighbors=k_neighbors,
                    ot_mk_method=ot_method,
                    X_cond_cal=X_cal_pca,
                    X_cond_test=X_test_pca,
                )
                _print_cp_metrics(metrics_ot_cond)
            else:
                print("  SKIPPED (EVAL_OT_CP_COND=False). Set EVAL_OT_CP_COND=True to enable.")

            # ---------------------------------------------
            # GRCP LOCAL (Local nonparametric rank, no relabeling)
            # ---------------------------------------------
            print("\n[GRCP] LOCAL (Local nonparametric rank, no relabeling)...")
            print(f"  Alpha = {alpha_cp}, coverage cible = {1.0 - alpha_cp:.2%}")
            print(f"  Kernel bandwidth sigma = {sigma:.4f}")

            # Build scoring functions for GRCP
            n_classes_grcp = clf.n_classes_
            grcp_scoring_functions = grcp_build_softmax_diff_scoring_functions(n_classes_grcp)

            metrics_grcp_local = eval_grcp_multiclass(
                X_cal=X_cal, y_cal=y_cal,
                X_test=X_test, y_test=y_test,
                base_clf=clf,
                scoring_functions=grcp_scoring_functions,
                alpha=alpha_cp,
                sigma=sigma,
                pca_dim=3,  # PCA reduction for kernel weights
                split_ratio=0.5,
                verbose=True,
            )
            _print_cp_metrics(metrics_grcp_local)

            # ---------------------------------------------
            # Worst Slab Coverage (WSC) of every method
            # ---------------------------------------------
            print("\n[WSC] Calcul du Worst Slab Coverage...")
            wsc_delta = 0.01
            wsc_M = 1000
            print(f"  Paramètres: delta={wsc_delta}, M={wsc_M} directions aléatoires")
            print(f"  Shape X_test pour WSC: {X_test.shape}")

            print("  [WSC] LAC...")
            wsc_lac = compute_wsc_for_method(X_test, y_test, metrics_mapie_lac['prediction_sets'], delta=wsc_delta, M=wsc_M)
            print(f"  WSC LAC        : {wsc_lac:.3f}")

            print("  [WSC] APS...")
            wsc_aps = compute_wsc_for_method(X_test, y_test, metrics_mapie_aps['prediction_sets'], delta=wsc_delta, M=wsc_M)
            print(f"  WSC APS        : {wsc_aps:.3f}")

            print("  [WSC] RAPS...")
            wsc_raps = compute_wsc_for_method(X_test, y_test, metrics_mapie_raps['prediction_sets'], delta=wsc_delta, M=wsc_M)
            print(f"  WSC RAPS       : {wsc_raps:.3f}")

            print("  [WSC] OT-CP...")
            wsc_ot = compute_wsc_for_method(X_test, y_test, metrics_ot['prediction_sets'], delta=wsc_delta, M=wsc_M)
            print(f"  WSC OT-CP      : {wsc_ot:.3f}")

            if metrics_ot_cond is not None:
                print("  [WSC] OT-CP+...")
                wsc_ot_cond = compute_wsc_for_method(X_test, y_test, metrics_ot_cond['prediction_sets'], delta=wsc_delta, M=wsc_M)
                print(f"  WSC OT-CP+     : {wsc_ot_cond:.3f}")

            print("  [WSC] GRCP...")
            wsc_grcp_local = compute_wsc_for_method(X_test, y_test, metrics_grcp_local['prediction_sets'], delta=wsc_delta, M=wsc_M)
            print(f"  WSC GRCP       : {wsc_grcp_local:.3f}")

            # ---------------------------------------------
            # Per-run summary
            # ---------------------------------------------
            print(f"\n--- RÉSUMÉ RUN {run_idx + 1}/{n_runs} (Seed={random_state}) ---")
            print(f"Dataset : {current_dataset_name}")
            print(f"Accuracy test      : {acc_test:.3f}")
            print(f"MAPIE LAC : cov = {metrics_mapie_lac['coverage']:.3f}, "
                  f"size = {metrics_mapie_lac['average_set_size']:.2f}, wsc = {wsc_lac:.3f}")
            print(f"MAPIE APS : cov = {metrics_mapie_aps['coverage']:.3f}, "
                  f"size = {metrics_mapie_aps['average_set_size']:.2f}, wsc = {wsc_aps:.3f}")
            print(f"MAPIE RAPS: cov = {metrics_mapie_raps['coverage']:.3f}, "
                  f"size = {metrics_mapie_raps['average_set_size']:.2f}, wsc = {wsc_raps:.3f}")
            print(f"OT-CP     : cov = {metrics_ot['coverage']:.3f}, "
                  f"size = {metrics_ot['average_set_size']:.2f}, wsc = {wsc_ot:.3f}")
            if metrics_ot_cond is not None:
                print(f"OT-CP+    : cov = {metrics_ot_cond['coverage']:.3f}, "
                      f"size = {metrics_ot_cond['average_set_size']:.2f}, wsc = {wsc_ot_cond:.3f}")
            print(f"GRCP      : cov = {metrics_grcp_local['coverage']:.3f}, "
                  f"size = {metrics_grcp_local['average_set_size']:.2f}, wsc = {wsc_grcp_local:.3f}")

            # One entry per method for this run. "ML" is the bare classifier:
            # its "coverage" is the test accuracy, its set size is 1, no WSC.
            methods_data = {
                "ML": {"cov": acc_test, "width": 1, "wsc": None},
                "LAC": {"cov": metrics_mapie_lac['coverage'], "width": metrics_mapie_lac['average_set_size'], "wsc": wsc_lac},
                "APS": {"cov": metrics_mapie_aps['coverage'], "width": metrics_mapie_aps['average_set_size'], "wsc": wsc_aps},
                "RAPS": {"cov": metrics_mapie_raps['coverage'], "width": metrics_mapie_raps['average_set_size'], "wsc": wsc_raps},
                "OT-CP": {"cov": metrics_ot['coverage'], "width": metrics_ot['average_set_size'], "wsc": wsc_ot},
                "GRCP": {"cov": metrics_grcp_local['coverage'], "width": metrics_grcp_local['average_set_size'], "wsc": wsc_grcp_local}
            }

            # Add OT-CP+ only if evaluated
            if metrics_ot_cond is not None:
                methods_data["OT-CP+"] = {"cov": metrics_ot_cond['coverage'], "width": metrics_ot_cond['average_set_size'], "wsc": wsc_ot_cond}

            # Accumulate results for this run (one row per method)
            for method, m in methods_data.items():
                all_results[current_dataset_name].append({
                    "Dataset": current_dataset_name,
                    "Run": run_idx + 1,
                    "Seed": random_state,
                    "Method": method,
                    "Coverage": m["cov"],
                    "Coverage_target": 1.0 - alpha_cp,
                    "Width": m["width"],
                    "WSC": m.get("wsc", None),
                })

            # Memory cleanup between runs
            gc.collect()
            if HAS_CUPY:
                cp.get_default_memory_pool().free_all_blocks()

            print(f"[RUN {run_idx + 1}/{n_runs}] Completed for {current_dataset_name}")

        # =====================================================================
        # END OF MULTI-RUN LOOP - Save CSV for this dataset
        # =====================================================================
        if not all_results[current_dataset_name]:
            # Every run was skipped (dataset could not be loaded). There is
            # nothing to save or summarise; the code below would otherwise
            # fail on the empty frame.
            print(f"\n[WARN] No completed run for {current_dataset_name}; nothing to save.")
            continue

        csv_filename = os.path.join(results_dir, f"results_multiclass_{current_dataset_name.replace(' ', '_')}.csv")
        df_results = pd.DataFrame(all_results[current_dataset_name])
        df_results.to_csv(csv_filename, index=False)
        # Counts are read from the results so they stay correct when some
        # runs were skipped.
        n_completed = df_results["Run"].nunique()
        n_methods = df_results["Method"].nunique()
        print(f"\n[SAVED] Résultats pour {current_dataset_name} sauvegardés dans : {csv_filename}")
        print(f"[INFO] {len(df_results)} lignes ({n_completed} runs × {n_methods} méthodes)")

        # Print summary statistics across runs
        print(f"\n{'='*60}")
        print(f"SUMMARY STATISTICS FOR {current_dataset_name} ({n_completed} runs)")
        print(f"{'='*60}")
        # "OT-CP+" only has rows when EVAL_OT_CP_COND was enabled; methods
        # without rows are skipped by the helper.
        for method in ["ML", "LAC", "APS", "RAPS", "OT-CP", "GRCP", "OT-CP+"]:
            _print_method_stats(df_results, method)

        print(f"\nDONE with {current_dataset_name}.")
