

# GPU support via CuPy
try:
    import cupy as cp
    HAS_CUPY = True
except ImportError:
    cp = None
    HAS_CUPY = False

import numpy as np
from typing import List, Callable, Union, Tuple, Dict
from enum import Enum

# BallTree for fast k-NN (used in conditional OT-CP)
try:
    from sklearn.neighbors import BallTree
    HAS_BALLTREE = True
except ImportError:
    BallTree = None
    HAS_BALLTREE = False


def to_cpu(arr):
    """Return ``arr`` as a NumPy array.

    When CuPy is importable and ``arr`` exposes a ``device`` attribute, the
    conversion goes through ``cp.asnumpy``; otherwise ``np.asarray`` is used.
    """
    goes_through_cupy = HAS_CUPY and hasattr(arr, 'device')
    return cp.asnumpy(arr) if goes_through_cupy else np.asarray(arr)


def to_gpu(arr):
    """Return ``arr`` as a float32 array: CuPy if available, NumPy otherwise."""
    # Pick the array module once; both expose the same asarray/float32 names.
    xp = cp if HAS_CUPY else np
    return xp.asarray(arr, dtype=xp.float32)


# ------------------------------------------------------------------
# Imports pour les modes géométrique radial & OT-MK
# ------------------------------------------------------------------

try:
    from geo_quantile_gpu import (
        fit_center_outward_model,
        quantile_center_outward,
        level_new_score,
    )
except ImportError:
    print("Warning: Could not import geo_quantile_gpu. "
          "Geometric radial mode will not be available.")

# Optional Monge-Kantorovich (OT) rank backend used by the OT_MK mode.
try:
    from ot_mk_quantile3_gpu import (
        fit_mk_rank_model,
        mk_radii_new_scores,
        mk_rank_new_score,
    )
except ImportError:
    # On failure the three names stay undefined; the warning names the module
    # that was actually requested.
    print("Warning: Could not import ot_mk_quantile3_gpu. "
          "OT_MK mode will not be available.")


class QuantileType(Enum):
    """Kinds of conformal region that can be built in score space.

    The string values are what callers may pass instead of a member
    (``QuantileType("ot_mk")`` etc.).
    """
    # One 1-D quantile per score coordinate (product of intervals).
    STANDARD = "standard"
    # Center-outward geometric model with a threshold on the level tau.
    GEOMETRIC = "geometric"
    # Monge-Kantorovich (optimal transport) rank with a threshold on the radius.
    OT_MK = "ot_mk"
    # NOTE(review): no handling of this member is visible in this part of the
    # file - confirm where (or whether) it is implemented.
    GEOMETRIC_COND = "geometric_cond"
    # Local nonparametric rank (no relabeling).
    # NOTE(review): no handling visible in this part of the file either.
    GRCP_LOCAL = "grcp_local"

class MultiScoringConformalPredictor:
    """
    Backend générique de Conformal Prediction multi-scoring.

    - On part d'un vecteur de scores multivarié s(x,y) ∈ R^{d_score},
      où d_score = len(scoring_functions).

    - STANDARD :
        quantiles 1D par composante → produit d'intervalles dans R^{d_score}.
        (utile pour des cas simples / de debug)

    - OT_MK :
        rang multivarié OT à la CP-OT (Monge–Kantorovich) + seuil CP sur
        les rayons r = ||R_n(s)||.

    Cette classe NE fait AUCUNE hypothèse sur la nature de y (classe, réel, etc.).
    Elle travaille uniquement dans l'espace des scores s(x,y).
    """

    def __init__(
        self,
        scoring_functions: List[Callable],
        quantile_type: Union[str, QuantileType] = QuantileType.STANDARD,
        alpha: float = 0.1,
        random_state: int = 42
    ):
        """
        Create an uncalibrated multi-score conformal predictor.

        Parameters
        ----------
        scoring_functions : List[Callable]
            One function per score coordinate. Each is called as
            ``f(y_true, y_pred_like)`` and returns one score per sample
            (array of shape (n,)). In multiclass classification
            ``y_pred_like`` can be the matrix of class probabilities.
        quantile_type : str or QuantileType
            A QuantileType member, or its string value
            (e.g. 'standard', 'geometric', 'ot_mk').
        alpha : float
            Miscoverage level (nominal coverage is 1 - alpha).
        random_state : int
            Seed for the random generators used during calibration
            (data splits, OT reference points, ...).
        """
        self.scoring_functions = scoring_functions
        self.n_scores = len(scoring_functions)
        self.random_state = random_state

        # Strings are resolved by value; any other object is stored untouched.
        self.quantile_type = (
            QuantileType(quantile_type)
            if isinstance(quantile_type, str)
            else quantile_type
        )

        self.alpha = alpha

        # --- State filled in by calibrate() ---------------------------
        # Calibration scores, shape (n_cal, d_score).
        self.calibration_scores_: np.ndarray | None = None

        # STANDARD: one 1-D quantile per score coordinate.
        self.quantiles_: np.ndarray | None = None

        # GEOMETRIC: center-outward model and threshold on the level tau.
        self.geo_model_: Dict | None = None
        self.geo_threshold_: float | None = None

        # OT_MK: rank model, radius threshold and the D1/D2 split indices.
        self.ot_mk_model_: Dict | None = None
        self.ot_mk_radius_threshold_: float | None = None
        self.ot_mk_indices_D1_: np.ndarray | None = None
        self.ot_mk_indices_D2_: np.ndarray | None = None

        # --- Tunable settings (plain attributes, set after construction) ---
        self.split_ratio: float = 0.5    # fraction of the data assigned to D1
        self.ot_mk_k_nn: int | None = 10  # k-NN used when extending ranks (None = auto)
        self.ot_mk_reg: float = 0.0      # entropic regularisation (0.0 = EMD)
        self.ot_mk_method: str = "exact"  # "exact", "nearest_neighbor", "barycentric", "local_ot", "hybrid"
        self.max_samples_geometry: int = 10000  # cap on points used to fit the geometric model (GPU memory)

        # --- Conditional OT-CP (OT-CP+) -------------------------------
        self.X_D1_: np.ndarray | None = None  # D1 covariates, searched by k-NN
        self.scores_D1_: np.ndarray | None = None  # D1 scores, used to fit local models
        self.k_neighbors: int = 100  # neighbourhood size of each local model
        self.use_conditional_ot: bool = False  # switched on by enable_conditional_ot()
        self.knn_tree_: BallTree | None = None  # BallTree over X_D1_ when scikit-learn is available
    # ------------------------------------------------------------------
    # Utilitaire : calcul des scores sur un dataset
    # ------------------------------------------------------------------
    def compute_scores(
        self,
        y_true: np.ndarray,
        y_pred_like: np.ndarray,
    ) -> np.ndarray:
        """
        Evaluate every scoring function on a dataset.

        Parameters
        ----------
        y_true : (n,)
            True targets (labels, regression values, ...).
        y_pred_like : (n, ...)
            Prediction object handed unchanged to each scoring function
            (e.g. softmax probabilities, real-valued predictions).

        Returns
        -------
        S : ndarray, shape (n, d_score)
            Column j holds the output of ``self.scoring_functions[j]``.
        """
        n_samples = len(y_true)
        score_matrix = np.zeros((n_samples, self.n_scores), dtype=float)
        # One column per scoring function, in list order.
        for column, scorer in enumerate(self.scoring_functions):
            score_matrix[:, column] = scorer(y_true, y_pred_like)
        return score_matrix

    def enable_conditional_ot(self, k_neighbors: int = 100):
        """
        Switch the predictor to conditional OT-CP (OT-CP+) mode.

        Only the flag and the neighbourhood size are stored here; nothing is
        fitted by this call.

        Args:
            k_neighbors: Number of neighbors used for each local model.

        Raises:
            ValueError: If the predictor's quantile type is not OT_MK.
        """
        is_ot_mk = self.quantile_type == QuantileType.OT_MK
        if not is_ot_mk:
            raise ValueError("Conditional OT is only supported for QuantileType.OT_MK")

        self.k_neighbors = k_neighbors
        self.use_conditional_ot = True
        print(f"    [OT-CP+] Conditional mode enabled with k={k_neighbors} neighbors")

    # ------------------------------------------------------------------
    # Calibration : construit le contour dans l'espace des scores
    # ------------------------------------------------------------------
    def calibrate(
        self,
        y_cal_true: np.ndarray,
        y_cal_pred: np.ndarray,
        X_cal: np.ndarray | None = None,
        split_geometry: bool = True,
        task: str = "classification",
    ) -> "MultiScoringConformalPredictor":
        """
        Calibrate the conformal region on a calibration set.

        The calibration scores are computed with compute_scores() and the
        region is then built in score space according to self.quantile_type:

        - OT_MK: the scores are randomly split into D1 and D2 (D1 fraction
          self.split_ratio, at least one point on each side). Without
          conditional mode a single MK rank model is fitted on D1 and the
          threshold is an order statistic of the D2 radii. In conditional
          mode (OT-CP+) every D2 point gets its own local model, fitted on
          the scores of its self.k_neighbors nearest D1 neighbours in
          covariate space; the threshold is again an order statistic of the
          D2 radii.
        - GEOMETRIC: a center-outward model is fitted and the threshold is an
          order statistic of the levels tau.
        - any other type: one quantile per score coordinate (STANDARD).

        Parameters
        ----------
        y_cal_true : (n_cal,)
            True targets of the calibration set.
        y_cal_pred : (n_cal, ...)
            Predictions on the calibration set, passed to the scoring
            functions.
        X_cal : optional
            Covariates aligned with the calibration samples. Only read in
            OT-CP+ mode, where it is required.
        split_geometry : bool
            GEOMETRIC only. If True (and both parts would hold at least 10
            points) the model is fitted on one part and the threshold is
            computed on the other; otherwise the levels stored by the fitted
            model ("tau_ref") are used.
        task : str
            Forwarded unchanged to fit_mk_rank_model (OT_MK only).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            Empty calibration set, fewer than 2 points in OT_MK mode, or
            OT-CP+ requested without X_cal.
        ImportError
            The backend module required by the selected mode was not imported.
        """
        n_cal = len(y_cal_true)
        if n_cal == 0:
            # The finite-sample correction below divides by n_cal.
            raise ValueError("calibrate() requires a non-empty calibration set.")

        # 1) Calibration scores
        self.calibration_scores_ = self.compute_scores(y_cal_true, y_cal_pred)
        X = self.calibration_scores_   # (n_cal, d_score)

        # Finite-sample corrected level (n + 1)(1 - alpha) / n; it can exceed 1
        # for small n, so every consumer below clips it.
        adjusted_alpha = (n_cal + 1) * (1 - self.alpha) / n_cal

        def conformal_radius_threshold(radii):
            """Return the ceil((m + 1)(1 - alpha))-th smallest radius (rank
            clipped to [1, m]) and the number of radii equal to it."""
            m = len(radii)
            rank = int(np.ceil((m + 1) * (1.0 - self.alpha)))
            rank = max(1, min(rank, m))
            ordered = np.sort(radii)
            value = float(ordered[rank - 1])
            return value, int(np.sum(ordered == value))

        # --- Case 1: OT_MK (CP-OT, multivariate OT rank) ----------------
        if self.quantile_type == QuantileType.OT_MK:
            if "fit_mk_rank_model" not in globals():
                raise ImportError(
                    "OT_MK mode requires 'fit_mk_rank_model' and "
                    "'mk_radii_new_scores' from ot_mk_quantile3_gpu.py"
                )

            conditional = bool(getattr(self, 'use_conditional_ot', False))
            if conditional and X_cal is None:
                raise ValueError("OT-CP+ requires X_cal (covariates) for kNN-based local models")

            n = X.shape[0]
            if n < 2:
                # With a single point D2 would be empty and no threshold exists.
                raise ValueError(
                    "OT_MK calibration needs at least 2 calibration points "
                    "to form the D1/D2 split."
                )

            # Random D1/D2 split, shared by the conditional and global paths.
            rng_data = np.random.default_rng(seed=self.random_state)
            perm = rng_data.permutation(n)
            n1 = int(np.floor(self.split_ratio * n))
            n1 = max(1, min(n1, n - 1))
            idx_D1 = perm[:n1]
            idx_D2 = perm[n1:]
            S_D1 = X[idx_D1]
            S_D2 = X[idx_D2]

            if conditional:
                # =============================================================
                # OT-CP+ (conditional OT-CP)
                #
                # - D1 = rank set (source of the k-NN local models)
                # - D2 = calibration set (source of the threshold)
                #
                # For each D2 point:
                #   1. find its k nearest neighbours in D1 (covariates)
                #   2. fit a local MK model on those neighbours' scores
                #   3. compute the radius of the point's own score vector
                #
                # Threshold: order statistic of the D2 radii.
                # =============================================================
                print(f"    [OT-CP+] Split D1/D2: n_total={n}, n_D1={len(idx_D1)} (rank set), n_D2={len(idx_D2)} (calibration)")
                print(f"    [OT-CP+] D1 and D2 are DISJOINT: {set(idx_D1).isdisjoint(idx_D2)}")

                # Kept for the k-NN lookups and local fits done at inference.
                self.X_D1_ = X_cal[idx_D1]
                self.scores_D1_ = S_D1

                if HAS_BALLTREE:
                    self.knn_tree_ = BallTree(self.X_D1_, leaf_size=40)
                    print(f"    [OT-CP+] BallTree built on D1 (n={len(self.X_D1_)})")
                else:
                    self.knn_tree_ = None
                    print("    [OT-CP+] WARNING: BallTree not available, using naive k-NN")

                n_D2 = len(S_D2)
                print(f"    [OT-CP+] Computing D2 radii with local k-NN models (n_D2={n_D2}, k={self.k_neighbors})...")
                X_D2 = X_cal[idx_D2]

                # One batched neighbour query for all D2 points.
                if self.knn_tree_ is not None:
                    _, all_knn_idx = self.knn_tree_.query(X_D2, k=self.k_neighbors)
                else:
                    all_knn_idx, _ = find_k_nearest_neighbors(X_D2, self.X_D1_, k=self.k_neighbors)

                import time
                t0 = time.time()

                # Every local model draws its reference points from a generator
                # seeded identically, so radii from different models are comparable.
                fixed_seed_for_refs = self.random_state + 77777

                # joblib is optional; without it the loop below runs sequentially.
                try:
                    from joblib import Parallel, delayed
                    HAS_JOBLIB = True
                except ImportError:
                    HAS_JOBLIB = False

                def compute_single_radius(knn_idx, score, scores_D1, ot_mk_reg, fixed_seed, ot_method, ot_k_nn, task_type):
                    """Radius of one D2 score under the local model of its D1 neighbours."""
                    rng_local = np.random.default_rng(seed=fixed_seed)
                    local_model = fit_mk_rank_model(scores_D1[knn_idx], reg=ot_mk_reg, rng=rng_local, task=task_type)
                    _, r_i, _ = mk_rank_new_score(score, local_model, method=ot_method, k_nn=ot_k_nn)
                    return r_i

                if HAS_JOBLIB and n_D2 > 200:
                    import os
                    n_jobs = min(8, os.cpu_count() or 4)
                    print(f"    [OT-CP+] Using parallel calibration with {n_jobs} workers...")

                    r_D2 = np.array(Parallel(n_jobs=n_jobs, prefer="threads")(
                        delayed(compute_single_radius)(
                            all_knn_idx[i], S_D2[i], self.scores_D1_,
                            self.ot_mk_reg, fixed_seed_for_refs,
                            self.ot_mk_method, self.ot_mk_k_nn, task
                        ) for i in range(n_D2)
                    ), dtype=np.float32)

                    elapsed = time.time() - t0
                    rate = n_D2 / elapsed if elapsed > 0 else 0
                    print(f"      [OT-CP+] Processed {n_D2} D2 points in {elapsed:.1f}s ({rate:.1f} pts/s)")
                else:
                    # Sequential fallback with progress output (about 5 lines).
                    r_D2 = np.zeros(n_D2, dtype=np.float32)
                    log_interval = max(1, n_D2 // 5)
                    for i in range(n_D2):
                        r_D2[i] = compute_single_radius(
                            all_knn_idx[i], S_D2[i], self.scores_D1_,
                            self.ot_mk_reg, fixed_seed_for_refs,
                            self.ot_mk_method, self.ot_mk_k_nn, task
                        )

                        if (i + 1) % log_interval == 0 or (i + 1) == n_D2:
                            elapsed = time.time() - t0
                            rate = (i + 1) / elapsed if elapsed > 0 else 0
                            print(f"      [OT-CP+] Processed {i+1}/{n_D2} D2 points ({rate:.1f} pts/s)")

                # Global threshold from the D2 radii.
                threshold, n_ties = conformal_radius_threshold(r_D2)
                if n_ties > 1:
                    print(f"    [OT-CP+ WARNING] {n_ties} ties at threshold.")

                self.ot_mk_model_ = None  # No global model in OT-CP+

                print(f"    [OT-CP+] Calibration complete: threshold={threshold:.4f}")
                print(f"    [OT-CP+] At inference: local k-NN model from D1, compare to global threshold")
                # NOTE(review): `task` is not stored, while the inference methods
                # of this class pass task="classification" to fit_mk_rank_model;
                # confirm this is intended when calibrating with another task.

            else:
                # Original OT-CP path (non-conditional).
                rng_refs = np.random.default_rng(seed=self.random_state + 12345)

                # 1) Fit the MK rank map on D1.
                self.ot_mk_model_ = fit_mk_rank_model(
                    S_D1,
                    reg=self.ot_mk_reg,
                    rng=rng_refs,
                    task=task
                )

                # 2) Radii on D2 and CP-OT threshold.
                print(f"    [OT-CP] Computing radii on D2 (n={len(S_D2)})...")
                r_D2 = mk_radii_new_scores(
                    S_D2,
                    self.ot_mk_model_,
                    method=self.ot_mk_method,
                    k_nn=self.ot_mk_k_nn
                )
                threshold, n_ties = conformal_radius_threshold(r_D2)
                if n_ties > 1:
                    n2 = len(r_D2)
                    # The bounds are on coverage, hence 1 - alpha (not alpha).
                    print(f"    [OT-CP WARNING] {n_ties} ties at threshold. "
                          f"Coverage guarantee: [{1 - self.alpha:.3f}, {1 - self.alpha + n_ties/(n2+1):.3f}]")

            self.ot_mk_radius_threshold_ = threshold
            self.ot_mk_indices_D1_ = idx_D1
            self.ot_mk_indices_D2_ = idx_D2

            # Clear the state of the other quantile types.
            self.quantiles_ = None
            self.geo_model_ = None
            self.geo_threshold_ = None

            return self

        # --- Case 2: GEOMETRIC (center-outward) -------------------------
        if self.quantile_type == QuantileType.GEOMETRIC:
            if "fit_center_outward_model" not in globals():
                raise ImportError(
                    "Geometric mode requires "
                    "'fit_center_outward_model', 'quantile_center_outward', "
                    "and 'level_new_score' from geo_quantile_gpu.py"
                )

            def subsample_for_geometry(points, label):
                """Cap the number of points used to fit the geometry at
                self.max_samples_geometry (limit noted in the log message)."""
                if len(points) <= self.max_samples_geometry:
                    return points
                print(f"    [GEOMETRIC] Sous-échantillonnage {label}: {len(points)} -> {self.max_samples_geometry} points (éviter OOM GPU)")
                rng_subsample = np.random.default_rng(seed=self.random_state + 999)
                keep = rng_subsample.choice(len(points), size=self.max_samples_geometry, replace=False)
                return points[keep]

            X_geo = X
            X_thresh = X
            use_holdout = False
            if split_geometry:
                # Shuffle the indices to avoid ordering bias.
                rng = np.random.default_rng(seed=self.random_state)
                perm = rng.permutation(n_cal)
                n1 = int(n_cal * self.split_ratio)

                if n1 < 10 or (n_cal - n1) < 10:
                    # Too few points: fit and threshold on the full set.
                    print("Warning: Not enough points to split. Using full D_cal.")
                else:
                    X_geo = subsample_for_geometry(X[perm[:n1]], "de D1")
                    X_thresh = X[perm[n1:]]
                    use_holdout = True

                    # The correction uses the size of the held-out part.
                    n2 = len(X_thresh)
                    adjusted_alpha = (n2 + 1) * (1 - self.alpha) / n2
            else:
                # No split: geometry and threshold both come from D_cal.
                X_geo = subsample_for_geometry(X, "pour géométrie")

            # 1. Learn the geometry.
            self.geo_model_ = fit_center_outward_model(X_geo)

            # 2. Levels used for the threshold.
            if use_holdout:
                # Levels of the held-out points under the fitted geometry.
                taus = np.array([level_new_score(s, self.geo_model_) for s in X_thresh])
            else:
                # Reuse the levels the model stored during fitting.
                taus = self.geo_model_["tau_ref"]

            # 3. Order statistic of the levels (rank clipped to [1, len(taus)]).
            k = int(np.ceil(adjusted_alpha * len(taus)))
            k = max(1, min(k, len(taus)))
            tau_sorted = np.sort(taus)
            self.geo_threshold_ = float(tau_sorted[k - 1])

            # Clear the state of the other quantile types, as the other
            # branches do.
            self.quantiles_ = None
            self.ot_mk_model_ = None
            self.ot_mk_radius_threshold_ = None

            return self

        # --- Case 3: STANDARD (1-D quantile per coordinate) -------------
        # np.quantile rejects levels above 1; clipping to 1 selects the largest
        # calibration score, like the rank clipping in the branches above.
        level = min(1.0, adjusted_alpha)
        self.quantiles_ = np.zeros(self.n_scores)
        for j in range(self.n_scores):
            self.quantiles_[j] = float(np.quantile(X[:, j], level))

        self.geo_model_ = None
        self.geo_threshold_ = None
        self.ot_mk_model_ = None
        self.ot_mk_radius_threshold_ = None

        return self

    # ------------------------------------------------------------------
    # Calcul adaptatif de k_nn
    # ------------------------------------------------------------------
    def _get_effective_k_nn(self, n_samples: int) -> int | None:
        """
        Pick the k-NN size to use for a dataset of ``n_samples`` points.

        - Manual mode (self.ot_mk_k_nn is a positive number):
          ``min(self.ot_mk_k_nn, n_samples - 1)``.
        - Adaptive mode (self.ot_mk_k_nn is None or not positive):
          * n < 200          : None (no k-NN)
          * 200 <= n < 1000  : max(50, min(100, n // 2))
          * 1000 <= n < 5000 : max(100, min(500, n // 5))
          * n >= 5000        : max(200, min(1000, n // 5))

        Returns
        -------
        k_nn : int or None
            Number of neighbours, or None when k-NN is disabled.
        """
        fixed_k = self.ot_mk_k_nn
        if fixed_k is not None and fixed_k > 0:
            # Manual mode: the configured value, capped by the dataset size.
            return min(fixed_k, n_samples - 1)

        # Adaptive mode: guard clauses, from the smallest datasets upwards.
        if n_samples < 200:
            return None
        if n_samples < 1000:
            return max(50, min(100, n_samples // 2))
        if n_samples < 5000:
            return max(100, min(500, n_samples // 5))
        return max(200, min(1000, n_samples // 5))

    # ------------------------------------------------------------------
    # Niveaux / rayons pour de NOUVEAUX scores
    # ------------------------------------------------------------------
    def geometric_level(self, scores: np.ndarray) -> float:
        """
        GEOMETRIC mode: level tau of one score vector under the fitted model.

        Parameters
        ----------
        scores : (d_score,)
            Score vector s(x, y); it is flattened before use.

        Returns
        -------
        tau : float
            Value returned by level_new_score for this vector.

        Raises
        ------
        ValueError
            Wrong quantile type, model not fitted, or wrong vector length.
        """
        if self.quantile_type != QuantileType.GEOMETRIC:
            raise ValueError("geometric_level is only valid for QuantileType.GEOMETRIC.")
        if self.geo_model_ is None:
            raise ValueError("geo_model_ is not fitted. Call calibrate() first.")

        score_vec = np.asarray(scores, dtype=float).ravel()
        if score_vec.size != self.n_scores:
            raise ValueError("scores must have dimension d_score.")

        return float(level_new_score(score_vec, self.geo_model_))

    def ot_mk_radius(self, scores: np.ndarray) -> float:
        """
        OT_MK mode: radius of one score vector under the global rank model.

        The radius is computed by mk_rank_new_score with the instance
        settings self.ot_mk_method and self.ot_mk_k_nn.

        Parameters
        ----------
        scores : (d_score,)
            Score vector; it is flattened before use.

        Returns
        -------
        r : float

        Raises
        ------
        ValueError
            Wrong quantile type, no global model (self.ot_mk_model_ is None),
            or wrong vector length.
        """
        if self.quantile_type != QuantileType.OT_MK:
            raise ValueError("ot_mk_radius is only valid for QuantileType.OT_MK.")
        if self.ot_mk_model_ is None:
            raise ValueError("ot_mk_model_ is not fitted. Call calibrate() first.")

        score_vec = np.asarray(scores, dtype=float).ravel()
        if score_vec.size != self.n_scores:
            raise ValueError("scores must have dimension d_score.")

        # Only the middle element of the returned triple (the radius) is needed.
        rank_output = mk_rank_new_score(
            score_vec,
            self.ot_mk_model_,
            method=self.ot_mk_method,
            k_nn=self.ot_mk_k_nn,
        )
        return float(rank_output[1])

    # ------------------------------------------------------------------
    # Test d'appartenance au contour (1 point ou batch)
    # ------------------------------------------------------------------
    def is_inside(self, scores: np.ndarray, X_query: np.ndarray | None = None) -> bool:
        """
        Test whether one score vector s(x, y) lies inside the conformal region.

        Parameters
        ----------
        scores : (d_score,)
            Score vector; it is flattened before use.
        X_query : optional
            Covariate of the query point. Only read in conditional OT mode
            (OT-CP+), where it is required to find the local neighbourhood.

        Returns
        -------
        inside : bool

        Raises
        ------
        ValueError
            Wrong vector length, predictor not calibrated, missing X_query in
            OT-CP+ mode, or a quantile type this method does not handle.
        """
        scores = np.asarray(scores, dtype=float).ravel()
        if scores.shape[0] != self.n_scores:
            raise ValueError("scores must have dimension d_score.")

        if self.quantile_type == QuantileType.STANDARD:
            if self.quantiles_ is None:
                raise ValueError("quantiles_ not set. Call calibrate() first.")
            # Inside iff every coordinate is below its own quantile.
            return bool(np.all(scores <= self.quantiles_))

        elif self.quantile_type == QuantileType.GEOMETRIC:
            tau = self.geometric_level(scores)
            if self.geo_threshold_ is None:
                raise ValueError("geo_threshold_ not set. Call calibrate() first.")
            return bool(tau <= self.geo_threshold_)

        elif self.quantile_type == QuantileType.OT_MK:
            if getattr(self, 'use_conditional_ot', False):
                # =============================================================
                # OT-CP+ (conditional OT-CP)
                #
                # 1. Find the k nearest neighbours of X_query in D1 (covariates)
                # 2. Fit a local MK model on those neighbours' scores
                # 3. Compute the radius of `scores` under that model
                # 4. Compare with the global threshold (from D2 calibration)
                # =============================================================
                if X_query is None:
                    raise ValueError("OT-CP+ requires X_query (covariate) for local k-NN model")

                if self.ot_mk_radius_threshold_ is None:
                    raise ValueError("OT-CP+ not calibrated. Call calibrate() first.")

                # 1. k-NN of X_query in D1
                X_query_2d = np.atleast_2d(X_query)
                if self.knn_tree_ is not None:
                    _, knn_idx = self.knn_tree_.query(X_query_2d, k=self.k_neighbors)
                else:
                    knn_idx, _ = find_k_nearest_neighbors(
                        X_query_2d, self.X_D1_, k=self.k_neighbors
                    )
                knn_idx = knn_idx[0]

                # 2. Local MK model on the neighbours' scores. The seed is the
                # one calibrate() uses for its local models.
                S_neighbors = self.scores_D1_[knn_idx]
                rng_local = np.random.default_rng(seed=self.random_state + 77777)
                # NOTE(review): calibrate() forwards its `task` argument to
                # fit_mk_rank_model, whereas "classification" is hard-coded
                # here - confirm for predictors calibrated with another task.
                local_model = fit_mk_rank_model(
                    S_neighbors,
                    reg=self.ot_mk_reg,
                    rng=rng_local,
                    task="classification"
                )

                # 3. Radius under the local model, with the same method / k_nn
                # settings that calibrate() and is_inside_batch_conditional()
                # use, so the radius is comparable with the calibrated threshold.
                _, r, _ = mk_rank_new_score(
                    scores,
                    local_model,
                    method=self.ot_mk_method,
                    k_nn=self.ot_mk_k_nn,
                )

                # 4. Compare with the global threshold
                return bool(r <= self.ot_mk_radius_threshold_)
            else:
                # Original OT-CP (non-conditional): global rank model.
                r = self.ot_mk_radius(scores)
                if self.ot_mk_radius_threshold_ is None:
                    raise ValueError("ot_mk_radius_threshold_ not set. Call calibrate() first.")
                return bool(r <= self.ot_mk_radius_threshold_)

        else:
            raise ValueError(f"Unknown quantile_type: {self.quantile_type}")

    def is_inside_batch(self, scores: np.ndarray) -> np.ndarray:
        """
        Test membership in the conformal region for a batch of score vectors.

        A 1-D input is treated as a single vector and delegated to
        is_inside(). Conditional OT mode (OT-CP+) is not served by this
        method because it needs covariates; use
        is_inside_batch_conditional() for it.

        Parameters
        ----------
        scores : (n, d_score)
            Matrix of score vectors.

        Returns
        -------
        inside : (n,) bool

        Raises
        ------
        ValueError
            Input is not of shape (n, d_score), the predictor is not
            calibrated for its mode, or the quantile type is not handled.
        """
        S = np.asarray(scores, dtype=float)
        if S.ndim == 1:
            return np.array([self.is_inside(S)], dtype=bool)

        # Reject 0-d and >2-d inputs explicitly rather than failing on indexing.
        if S.ndim != 2 or S.shape[1] != self.n_scores:
            raise ValueError("scores must have shape (n, d_score).")

        n = S.shape[0]

        if self.quantile_type == QuantileType.STANDARD:
            if self.quantiles_ is None:
                raise ValueError("quantiles_ not set. Call calibrate() first.")
            return np.all(S <= self.quantiles_[None, :], axis=1)

        elif self.quantile_type == QuantileType.GEOMETRIC:
            if self.geo_model_ is None or self.geo_threshold_ is None:
                raise ValueError("GEOMETRIC mode not properly calibrated.")
            inside = np.zeros(n, dtype=bool)
            for i in range(n):
                tau_i = level_new_score(S[i], self.geo_model_)
                inside[i] = (tau_i <= self.geo_threshold_)
            return inside

        elif self.quantile_type == QuantileType.OT_MK:
            if self.ot_mk_model_ is None or self.ot_mk_radius_threshold_ is None:
                if (
                    getattr(self, 'use_conditional_ot', False)
                    and self.ot_mk_radius_threshold_ is not None
                ):
                    # Calibrated in OT-CP+ mode: a threshold exists but there is
                    # no global model, so point the caller at the right method.
                    raise ValueError(
                        "is_inside_batch does not support conditional OT (OT-CP+): "
                        "there is no global rank model. "
                        "Use is_inside_batch_conditional(scores, X_query) instead."
                    )
                raise ValueError("OT_MK mode not properly calibrated.")

            # Radii under the global model, with the instance method / k_nn settings.
            radii = mk_radii_new_scores(S, self.ot_mk_model_, method=self.ot_mk_method, k_nn=self.ot_mk_k_nn)
            return radii <= self.ot_mk_radius_threshold_

        else:
            raise ValueError(f"Unknown quantile_type: {self.quantile_type}")

    def is_inside_batch_conditional(self, scores: np.ndarray, X_query: np.ndarray) -> np.ndarray:
        """
        Batch membership test for conditional OT-CP (OT-CP+).

        All score vectors are tested against ONE local model: the MK model
        fitted on the scores of the self.k_neighbors nearest D1 neighbours of
        ``X_query``. The model is fitted a single time per call and the
        resulting radii are compared with the threshold stored by calibrate().

        Parameters
        ----------
        scores : (n, d_score)
            Score vectors to test; a 1-D input is treated as one vector.
        X_query : (d_x,) or (1, d_x)
            Single covariate point that determines the neighbourhood.

        Returns
        -------
        inside : (n,) bool
            Whether each score vector is inside the conformal region.

        Raises
        ------
        ValueError
            Conditional mode is off, no threshold is stored, or ``scores``
            does not have d_score columns.
        """
        if not getattr(self, 'use_conditional_ot', False):
            raise ValueError("is_inside_batch_conditional requires conditional OT mode. "
                             "Call enable_conditional_ot() and recalibrate.")

        threshold = self.ot_mk_radius_threshold_
        if threshold is None:
            raise ValueError("Not calibrated. Call calibrate() first.")

        score_batch = np.asarray(scores, dtype=np.float32)
        if score_batch.ndim == 1:
            score_batch = score_batch.reshape(1, -1)

        if score_batch.shape[1] != self.n_scores:
            raise ValueError(f"scores must have shape (n, {self.n_scores}).")

        query_row = np.asarray(X_query, dtype=np.float32).reshape(1, -1)

        # 1. Neighbours of the query in D1: BallTree when one was built,
        #    otherwise the fallback search routine.
        if self.knn_tree_ is None:
            neighbor_rows, _ = find_k_nearest_neighbors(
                query_row, self.X_D1_, k=self.k_neighbors
            )
        else:
            _, neighbor_rows = self.knn_tree_.query(query_row, k=self.k_neighbors)
        neighbor_idx = neighbor_rows[0]  # (k,)

        # 2. Fit the local MK model once. The generator seed is the one
        #    calibrate() uses for its local models.
        # NOTE(review): "classification" is hard-coded while calibrate()
        # accepts a `task` argument - confirm for other tasks.
        local_model = fit_mk_rank_model(
            self.scores_D1_[neighbor_idx],
            reg=self.ot_mk_reg,
            rng=np.random.default_rng(seed=self.random_state + 77777),
            task="classification"
        )

        # 3. Radii of all score vectors under that model, with the configured
        #    method / k_nn settings.
        radii = mk_radii_new_scores(
            score_batch,
            local_model,
            method=self.ot_mk_method,
            k_nn=self.ot_mk_k_nn,
        )

        # 4. Compare with the global threshold.
        return radii <= threshold

def build_softmax_diff_scoring_functions(n_classes: int):
    """
    Build one score function per class index c = 0..n_classes-1.

    The c-th function maps (y_true, y_probs) to the per-sample value

        | y_probs[i, c] - 1{y_true[i] == c} |

    so stacking the outputs of all returned functions gives, for each
    sample, a score vector with n_classes components.

    Parameters
    ----------
    n_classes : int
        Number of score functions to create (one per column index).

    Returns
    -------
    list of callables
        Each callable has signature (y_true, y_probs) -> array of shape (n,).
    """
    def _scorer_for_class(class_idx: int):
        # A factory is used so each closure captures its own class index.
        def score_func(y_true: np.ndarray, y_probs: np.ndarray) -> np.ndarray:
            labels = np.asarray(y_true, dtype=int).ravel()    # (n,)
            probs = np.asarray(y_probs, dtype=float)          # indexed as (n, K)
            # 1.0 where the label equals this class, 0.0 elsewhere.
            one_hot_col = np.where(labels == class_idx, 1.0, 0.0)
            return np.abs(probs[:, class_idx] - one_hot_col)
        return score_func

    return list(map(_scorer_for_class, range(n_classes)))


def build_logit_margin_scoring_functions(n_classes: int):
    """
    Build one logit-margin score function per class index k = 0..n_classes-1.

    With z the logits of a sample and y its true class, the k-th function
    returns

        S_k = z[k] - z[y]

    i.e. the NEGATED margin of the true class over class k. The value is
    negative when the true class has the larger logit, positive when class
    k has the larger logit, and exactly 0 when k == y. This sign convention
    makes larger scores mean "less conforming", like the softmax-difference
    scores.

    The second argument is intended to hold logits (pre-softmax activations)
    rather than probabilities; the code does not check this.

    Parameters
    ----------
    n_classes : int
        Number of score functions to create (one per column index).

    Returns
    -------
    scoring_functions : List[Callable]
        n_classes callables with signature (y_true, z_logits) -> (n,) array.
    """
    def _margin_scorer(k: int):
        def score_func(y_true: np.ndarray, z_logits: np.ndarray) -> np.ndarray:
            """
            Return z[:, k] - z[i, y_true[i]] for every sample i.

            Args:
                y_true: labels of the true classes, flattened to (n,)
                z_logits: array indexed as (n, K)

            Returns:
                (n,) array of negated logit margins for class k
            """
            labels = np.asarray(y_true, dtype=int).ravel()
            logits = np.asarray(z_logits, dtype=float)

            # Pick, row by row, the logit of the true class.
            row_ids = np.arange(len(labels))
            true_logit = logits[row_ids, labels]

            # Candidate-class logit minus true-class logit.
            return logits[:, k] - true_logit

        return score_func

    return list(map(_margin_scorer, range(n_classes)))


from typing import List, Callable

def build_adaptive_rank_aware_scoring_functions(K: int) -> List[Callable]:
    """
    Build K + 4 score functions that use the rank of the true class.

    For a sample with probability row p (num_classes >= K entries) and true
    label y, the returned functions produce, in order:

    - s1 : 1 - p[y]
    - s2 : rank of y in the descending sort of p, divided by
           (num_classes - 1); 0 for the top-1 class, 1 for the last one
           (all zeros when there is a single class)
    - s3 : sum of the probabilities sorted ahead of y (0 when y is top-1)
    - s4 : entropy of p divided by log(num_classes)
           (all zeros when there is a single class)
    - s5 .. s_{K+4} : the i-th largest probability, for i = 1..K
           (these do not depend on y)

    Ties between equal probabilities are ordered however `np.argsort`
    orders them.

    Every returned callable has signature (y_true, y_probs) -> (n_samples,)
    and recomputes the shared quantities on each call; nothing is memoised
    between components.

    Raises (at call time)
    ---------------------
    ValueError
        If y_probs has fewer than K columns.
    """

    def _shared_quantities(y_true: np.ndarray, y_probs: np.ndarray):
        """
        Compute, in one vectorised pass, everything the components need:
        true-class probability, true-class rank, descending-sorted
        probabilities, probability mass sorted ahead of the true class, and
        normalised entropy.
        """
        labels = np.asarray(y_true, dtype=int).ravel()     # (n,)
        probs = np.asarray(y_probs, dtype=float)           # (n, num_classes)
        n_samples, num_classes = probs.shape

        if num_classes < K:
            raise ValueError(
                f"num_classes must be at least K: got {num_classes} < {K}"
            )

        rows = np.arange(n_samples)
        p_true = probs[rows, labels]                       # (n,)

        # order[i, r] is the class sitting at rank r (descending probability).
        order = np.argsort(-probs, axis=1)

        # Invert the permutation: class_rank[i, c] is the rank of class c.
        class_rank = np.empty_like(order)
        np.put_along_axis(
            class_rank, order, np.arange(num_classes)[None, :], axis=1
        )
        rank_true = class_rank[rows, labels]               # (n,)

        p_sorted = np.take_along_axis(probs, order, axis=1)
        running_mass = np.cumsum(p_sorted, axis=1)

        # Mass of the classes ranked before the true one. For rank 0 the
        # lookup index is -1 (a valid but meaningless column) and the value
        # is discarded in favour of 0.
        cum_prob_ahead = np.where(
            rank_true > 0, running_mass[rows, rank_true - 1], 0.0
        )

        # Entropy, with a small epsilon guarding log(0), scaled by
        # log(num_classes) whenever that is non-zero.
        eps = 1e-12
        entropy = -(probs * np.log(probs + eps)).sum(axis=1)
        if num_classes > 1:
            entropy_norm = entropy / np.log(num_classes)
        else:
            entropy_norm = np.zeros_like(entropy)

        return {
            "p_true": p_true,
            "rank_true": rank_true,
            "cum_prob_ahead": cum_prob_ahead,
            "p_sorted": p_sorted,
            "entropy_norm": entropy_norm,
            "num_classes": num_classes,
        }

    def _make_component_score(component: str) -> Callable:
        """
        Return the (y_true, y_probs) -> (n,) score function for one named
        component.
        """
        def score_func(y_true: np.ndarray, y_probs: np.ndarray) -> np.ndarray:
            stats = _shared_quantities(y_true, y_probs)

            if component == "nonconformity":
                # s1: one minus the probability of the true class.
                return 1.0 - stats["p_true"]

            if component == "rank":
                # s2: rank rescaled by (num_classes - 1).
                num_classes = stats["num_classes"]
                if num_classes <= 1:
                    return np.zeros_like(stats["p_true"])
                return stats["rank_true"].astype(float) / (num_classes - 1)

            if component == "cumprob":
                # s3: probability mass sorted ahead of the true class.
                return stats["cum_prob_ahead"]

            if component == "entropy":
                # s4: normalised entropy.
                return stats["entropy_norm"]

            if component.startswith("p_top_"):
                # s_{4+i}: i-th largest probability (i is 1-based).
                i = int(component.split("_")[-1])
                if i < 1 or i > K:
                    raise ValueError(f"Invalid p_top index: {i} for K={K}")
                return stats["p_sorted"][:, i - 1]

            raise ValueError(f"Unknown component: {component}")

        return score_func

    # Fixed component order: the four label-dependent scores, then top-1..top-K.
    component_names = ["nonconformity", "rank", "cumprob", "entropy"]
    component_names.extend(f"p_top_{i}" for i in range(1, K + 1))

    return [_make_component_score(name) for name in component_names]


def build_reduced_scoring_functions(K: int) -> List[Callable]:
    """
    Build K score functions based on gaps to the top-K probabilities.

    For a sample with probability row p and true label y, the i-th function
    (i = 1..K) returns

        s_i = | p[y] - p_top_i |

    where p_top_i is the i-th largest entry of p. When i exceeds the number
    of columns, the smallest entry of p is used instead.

    The resulting score vector has K components regardless of how many
    classes there are, which keeps the dimension small on problems with
    many classes.

    Parameters
    ----------
    K : int
        Number of score functions to generate (how many top ranks to use).

    Returns
    -------
    scoring_functions : List[Callable]
        K callables with signature (y_true, y_probs) -> (n,) array.
    """

    def _gap_to_rank(k: int):
        """
        Create the score function for the k-th largest probability
        (k is 1-based).
        """
        def score_func(y_true: np.ndarray, y_probs: np.ndarray) -> np.ndarray:
            labels = np.asarray(y_true, dtype=int).ravel()   # (n,)
            probs = np.asarray(y_probs, dtype=float)         # (n, num_classes)

            # Probability assigned to the true class of each sample.
            p_true = probs[np.arange(probs.shape[0]), labels]

            # Each row sorted in descending order.
            descending = -np.sort(-probs, axis=1)

            # 0-based column of the k-th largest value, clamped to the last
            # column when k is larger than the number of classes.
            col = min(k, probs.shape[1]) - 1

            return np.abs(p_true - descending[:, col])

        return score_func

    # One function per rank k = 1..K.
    return list(map(_gap_to_rank, range(1, K + 1)))

def find_k_nearest_neighbors(X_query, X_ref, k, metric='euclidean', tree=None):
    """
    Find the k nearest rows of X_ref for every row of X_query.

    Args:
        X_query: (n_query, p) query points
        X_ref: (n_ref, p) reference points
        k: requested number of neighbors; capped at the number of rows
           in X_ref
        metric: metric name forwarded to sklearn's `pairwise_distances`
                (only used when no tree is given)
        tree: optional pre-built index exposing `query(X, k=...)` that
              returns (distances, indices); when given it is used instead
              of the brute-force path

    Returns:
        indices: (n_query, k_actual) indices of the nearest rows in X_ref
        distances: (n_query, k_actual) matching distances

        On the brute-force path each row is ordered by increasing distance.
    """
    queries = to_cpu(X_query)
    reference = to_cpu(X_ref)
    n_ref = reference.shape[0]
    k_actual = min(k, n_ref)

    # Delegate to the pre-built tree when the caller supplies one.
    if tree is not None:
        tree_dist, tree_idx = tree.query(queries, k=k_actual)
        return tree_idx, tree_dist

    # Brute force: full (n_query, n_ref) distance matrix.
    from sklearn.metrics import pairwise_distances
    dist_matrix = pairwise_distances(queries, reference, metric=metric)

    if k_actual >= n_ref:
        # Every reference row is requested: a full sort is all that is needed.
        indices = np.argsort(dist_matrix, axis=1)[:, :k_actual]
        return indices, np.take_along_axis(dist_matrix, indices, axis=1)

    # Partial selection of the k smallest per row, then order just those k.
    candidates = np.argpartition(dist_matrix, k_actual, axis=1)[:, :k_actual]
    candidate_dist = np.take_along_axis(dist_matrix, candidates, axis=1)
    by_distance = np.argsort(candidate_dist, axis=1)
    indices = np.take_along_axis(candidates, by_distance, axis=1)
    distances = np.take_along_axis(candidate_dist, by_distance, axis=1)
    return indices, distances