
from __future__ import annotations
import numpy as np
from typing import Dict, Tuple, Literal

try:
    import cupy as cp
    HAS_CUPY = True
except ImportError:
    cp = None
    HAS_CUPY = False

try:
    import ot
except ImportError as e:
    raise ImportError(
        "This module requires POT library.\n"
        "Install with: pip install POT"
    ) from e


def to_cpu(arr):
    """Return `arr` as a NumPy array on the host.

    When CuPy is importable and `arr` exposes a ``device`` attribute, the
    conversion goes through ``cp.asnumpy``; any other input is handed to
    ``np.asarray``.
    """
    looks_like_device_array = HAS_CUPY and hasattr(arr, 'device')
    return cp.asnumpy(arr) if looks_like_device_array else np.asarray(arr)


def to_gpu(arr):
    """Convert `arr` to a float32 array, on the GPU when CuPy is available.

    The input is always cast to float32, whichever backend is selected.
    """
    backend = cp if HAS_CUPY else np
    return backend.asarray(arr, dtype=backend.float32)


# ---------------------------------------------------------------------------
# 1. Reference ranks construction (unchanged)
# ---------------------------------------------------------------------------

def build_reference_ranks(
    n: int,
    d: int,
    task: Literal["regression", "classification"] = "regression",
    rng=None,
):
    """
    Build the reference rank vectors {U_i}_{i=1}^n with U_i = (i/n) * theta_i.

    - Classification: theta_i is drawn on the standard simplex
      {theta in R^d_+ : ||theta||_1 = 1} by normalising i.i.d. Exp(1) draws
      (obtained as -log(Uniform(0, 1))) by their sum.
    - Regression: theta_i is drawn on the unit L2 sphere by normalising
      standard normal draws.

    Parameters
    ----------
    n : int
        Number of reference vectors.
    d : int
        Dimension of each vector.
    task : {"regression", "classification"}
        Selects the distribution of the directions theta_i.
    rng : optional
        NumPy random generator (anything exposing ``uniform`` and
        ``standard_normal``). When provided it is always used, so results are
        reproducible with or without CuPy; the samples are then moved to the
        active backend. When None, CuPy's global generator is used if CuPy is
        available, otherwise a fresh ``np.random.default_rng()``.

    Returns
    -------
    U0 : array of shape (n, d), float32
        CuPy array when CuPy is available, NumPy array otherwise.

    Raises
    ------
    ValueError
        If `task` is neither "regression" nor "classification".
    """
    # Validate first so that no sampling work is done for a bad `task`.
    if task not in ("regression", "classification"):
        raise ValueError(f"task doit être 'regression' ou 'classification', reçu: {task}")

    xp = cp if HAS_CUPY else np

    # Radii i/n for i = 1..n, as a column so they broadcast over coordinates.
    radii = (xp.arange(1, n + 1, dtype=xp.float32) / n)[:, None]  # (n, 1)

    # Sample on the device only when the caller did not supply a generator;
    # a supplied generator must be honoured for reproducibility.
    sample_on_device = HAS_CUPY and rng is None
    if not sample_on_device and rng is None:
        rng = np.random.default_rng()

    if task == "classification":
        if sample_on_device:
            uniform_samples = cp.random.uniform(0, 1, size=(n, d), dtype=cp.float32)
        else:
            uniform_samples = xp.asarray(
                rng.uniform(0, 1, size=(n, d)).astype(np.float32)
            )
        # Keep samples away from 0 so the logarithm stays finite.
        uniform_samples = xp.clip(uniform_samples, 1e-10, 1.0)
        exp_samples = -xp.log(uniform_samples)

        # Normalise each row by its sum to land on the simplex.
        theta = exp_samples / exp_samples.sum(axis=1, keepdims=True)  # (n, d)
    else:
        if sample_on_device:
            theta = cp.random.standard_normal(size=(n, d), dtype=cp.float32)
        else:
            theta = xp.asarray(
                rng.standard_normal(size=(n, d)).astype(np.float32)
            )
        norms = xp.linalg.norm(theta, axis=1, keepdims=True)
        # Guard against a zero draw: leave such a row at the origin.
        norms = xp.where(norms == 0.0, xp.float32(1.0), norms)
        theta = theta / norms

    # U_i = (i/n) * theta_i
    return radii * theta  # (n, d)


def _empirical_cdf_1d(values):
    """Empirical CDF of 1D values, evaluated at each of the values.

    Returns ``tau`` with ``tau[i] = #{j : v[j] <= v[i]} / n`` as a float32
    array of the same length as the flattened input. Tied values receive the
    same level, which matches the ``mean(r_ref <= r_new)`` rule used for new
    scores in ``mk_rank_new_score``. An empty input yields an empty array.
    """
    xp = cp if HAS_CUPY else np
    v = xp.asarray(values, dtype=xp.float32).ravel()
    n = v.size
    if n == 0:
        return xp.empty(0, dtype=xp.float32)

    # For each value, count how many sorted entries are <= it
    # (insertion point on the right of any equal run).
    sorted_v = xp.sort(v)
    counts = xp.searchsorted(sorted_v, v, side="right")
    return counts.astype(xp.float32) / n


# ---------------------------------------------------------------------------
# 2. Fit MK rank model (store necessary info for out-of-sample extension)
# ---------------------------------------------------------------------------

def fit_mk_rank_model(
    scores_calib,
    task: Literal["regression", "classification"] = "regression",
    reg: float = 0.0,
    rng=None,
    numItermax: int = 1000000,
) -> Dict:
    """
    Fit the MK rank map on calibration scores.

    The calibration scores are matched to reference rank vectors by solving a
    discrete optimal transport problem with uniform marginals (exact EMD when
    ``reg`` is 0 or None, Sinkhorn when ``reg > 0``). The solve runs on the
    CPU through POT; the stored arrays live on the active backend.

    Parameters
    ----------
    scores_calib : array-like, shape (n,) or (n, d)
        Calibration scores; a 1D input is treated as n scores of dimension 1.
    task : {"regression", "classification"}
        Forwarded to ``build_reference_ranks``.
    reg : float or None
        Entropic regularisation strength; values <= 0 (or None) select EMD.
    rng : optional
        If it exposes ``integers``, one draw from it seeds the generator
        passed to ``build_reference_ranks``; any other non-None value falls
        back to a fixed seed. None gives an unseeded generator.
    numItermax : int
        Iteration cap passed to the POT solver.

    Returns
    -------
    model : dict with keys:
        - S_ref : original scores (n, d)
        - S_ref_norm : normalized scores (n, d)
        - score_mean, score_std : normalization parameters
        - U0 : reference rank vectors (n, d)
        - U_ref : transported references (n, d), i.e., U_ref[i] = R_n(S_ref[i])
        - r_ref : radii ||U_ref[i]|| (n,)
        - tau_ref : empirical CDF of radii (n,)
        - assignment : integer indices sigma such that S_i -> U0[sigma[i]]
        - reg : regularization parameter
        - task : "regression" or "classification"

    Raises
    ------
    ValueError
        If `scores_calib` contains no scores.
    """
    xp = cp if HAS_CUPY else np

    X = xp.asarray(scores_calib, dtype=xp.float32)
    if X.ndim == 1:
        X = X[:, None]
    n, d = X.shape
    if n == 0:
        raise ValueError("scores_calib must contain at least one score.")

    # Per-coordinate standardisation; the epsilon avoids division by zero
    # for constant coordinates.
    score_mean_ = X.mean(axis=0)
    score_std_ = X.std(axis=0) + 1e-8
    X_norm = (X - score_mean_) / score_std_

    # Dedicated generator for the reference vectors.
    if rng is None:
        rng_refs = np.random.default_rng(seed=None)
    else:
        # Derive a distinct seed from the caller's generator.
        base_seed = rng.integers(0, 2**31) if hasattr(rng, 'integers') else 42
        rng_refs = np.random.default_rng(seed=base_seed + 12345)

    U0 = build_reference_ranks(n, d, task=task, rng=rng_refs)

    # POT works on host arrays.
    X_cpu = to_cpu(X)
    U0_cpu = to_cpu(U0)

    a = np.full(n, 1.0 / n, dtype=np.float32)
    b = np.full(n, 1.0 / n, dtype=np.float32)
    # NOTE(review): the cost is built from the raw scores X, whereas the
    # out-of-sample methods compare normalised scores (S_ref_norm). Confirm
    # this is intended before switching the cost to the normalised scores.
    M = ot.dist(X_cpu, U0_cpu, metric="euclidean") ** 2

    reg = float(reg) if reg is not None else 0.0

    if reg > 0.0:
        gamma = ot.sinkhorn(a, b, M, reg, numItermax=numItermax)
    else:
        gamma = ot.emd(a, b, M, numItermax=numItermax)

    # sigma[i] = j means S_i -> U0[j]; taken as the heaviest entry of row i
    # of the transport plan.
    assignment = np.argmax(gamma, axis=1)  # (n,) integer indices

    # Rank of each calibration score.
    U_ref_cpu = U0_cpu[assignment]  # (n, d)
    U_ref = to_gpu(U_ref_cpu)

    # Radii and their empirical levels.
    r_ref = xp.linalg.norm(U_ref, axis=1)
    tau_ref = _empirical_cdf_1d(r_ref)

    model: Dict = {
        "S_ref": X,
        "S_ref_norm": X_norm,
        "score_mean": score_mean_,
        "score_std": score_std_,
        "U0": to_gpu(U0),
        "U_ref": U_ref,
        "r_ref": r_ref,
        "tau_ref": tau_ref,
        # Kept as integers: to_gpu would cast the indices to float32, which
        # cannot be used for indexing and is lossy for large n.
        "assignment": xp.asarray(assignment),
        "reg": reg,
        "task": task,
    }
    return model


# ---------------------------------------------------------------------------
# 3. Extension methods (barycentric, local_ot, hybrid)
# ---------------------------------------------------------------------------

def _barycentric_extension_gpu(distances, U_k):
    """Inverse-distance weighted average of the rows of `U_k`.

    Runs on the active backend (CuPy when available, NumPy otherwise).
    Distances are floored at 1e-10 before inversion and the weights are
    normalised to sum to one.
    """
    xp = cp if HAS_CUPY else np
    inverse = 1.0 / xp.maximum(distances, 1e-10)
    weights = inverse / xp.sum(inverse)
    return xp.sum(weights[:, None] * U_k, axis=0)


def _barycentric_extension(distances, U_k):
    """Inverse-distance weighted average of the rows of `U_k`, on the CPU.

    Both inputs are first brought to host memory. Distances are floored at
    1e-10 before inversion and the weights are normalised to sum to one.
    """
    host_distances = to_cpu(distances)
    host_ranks = to_cpu(U_k)
    inverse = 1.0 / np.maximum(host_distances, 1e-10)
    weights = inverse / np.sum(inverse)
    return np.sum(weights[:, None] * host_ranks, axis=0)


def _local_ot_extension(S_k, U_k, s_new, task: str = "regression"):
    """Extend the rank map to `s_new` by a local optimal transport solve (CPU).

    The k neighbour scores `S_k` are augmented with `s_new` and transported
    (uniform marginals, squared Euclidean cost) onto the k neighbour ranks
    `U_k`. The rank of `s_new` is the barycentric projection of its row of
    the transport plan, i.e. the plan row normalised by its mass and applied
    to `U_k`.

    Parameters
    ----------
    S_k : array (k, d)
        Neighbour scores.
    U_k : array (k, d)
        Ranks of the neighbours.
    s_new : array (d,)
        Query score.
    task : str
        Accepted for signature parity with the other extensions; not used.

    Returns
    -------
    u_new : ndarray (d,)
        Falls back to the inverse-distance barycentric extension when the
        solver raises RuntimeError or returns an unusable row.
    """
    S_k = to_cpu(S_k)
    U_k = to_cpu(U_k)
    s_new = to_cpu(s_new)

    k = S_k.shape[0]
    S_aug = np.vstack([S_k, s_new[None, :]])
    n_aug = k + 1

    # float64 marginals: the two sides have different sizes, so float32
    # rounding of 1/(k+1) and 1/k could make their total masses disagree.
    a = np.full(n_aug, 1.0 / n_aug)
    b = np.full(k, 1.0 / k)
    M = ot.dist(S_aug, U_k, metric="euclidean") ** 2

    u_new = None
    try:
        gamma = ot.emd(a, b, M, numItermax=100000)
    except RuntimeError:
        gamma = None

    if gamma is not None:
        w_new = gamma[-1, :]
        # The row only carries mass 1/(k+1); without this normalisation the
        # resulting rank would be shrunk by a factor of k+1.
        mass = w_new.sum()
        if np.isfinite(mass) and mass > 0:
            u_new = np.sum((w_new / mass)[:, None] * U_k, axis=0)

    if u_new is None:
        u_new = _barycentric_extension(
            np.linalg.norm(S_k - s_new[None, :], axis=1), U_k
        )
    return u_new


def _hybrid_extension_gpu(S_k, U_k, s_new, distances, task: str = "regression"):
    """Barycentric extension with a radius correction, on the active backend.

    The barycentric rank is rescaled so that its norm moves towards the
    inverse-distance weighted mean of the neighbour radii. The scale factor
    is clipped to [0.5, 2.0]; no rescaling happens when the barycentric rank
    has norm <= 1e-10. `S_k`, `s_new` and `task` are accepted but not used.
    """
    xp = cp if HAS_CUPY else np

    u_bary = _barycentric_extension_gpu(distances, U_k)
    neighbour_radii = xp.linalg.norm(U_k, axis=1)
    bary_radius = xp.linalg.norm(u_bary)

    # Same inverse-distance weights as the barycentric step.
    inverse = 1.0 / xp.maximum(distances, 1e-10)
    weights = inverse / xp.sum(inverse)
    target_radius = xp.sum(weights * neighbour_radii)

    if not (float(to_cpu(bary_radius)) > 1e-10):
        return u_bary

    scale = xp.clip(target_radius / (bary_radius + 1e-10), 0.5, 2.0)
    return u_bary * scale


def _hybrid_extension(S_k, U_k, s_new, distances, task: str = "regression"):
    """Barycentric extension with a radius correction, on the CPU.

    The barycentric rank is rescaled so that its norm moves towards the
    inverse-distance weighted mean of the neighbour radii. The scale factor
    is clipped to [0.5, 2.0]; no rescaling happens when the barycentric rank
    has norm <= 1e-10. `s_new` and `task` are accepted but not used.
    """
    # S_k is brought to the host like the other inputs but is not used below.
    S_k = to_cpu(S_k)
    host_ranks = to_cpu(U_k)
    host_distances = to_cpu(distances)

    u_bary = _barycentric_extension(host_distances, host_ranks)
    neighbour_radii = np.linalg.norm(host_ranks, axis=1)
    bary_radius = np.linalg.norm(u_bary)

    # Same inverse-distance weights as the barycentric step.
    inverse = 1.0 / np.maximum(host_distances, 1e-10)
    weights = inverse / np.sum(inverse)
    target_radius = np.sum(weights * neighbour_radii)

    if not (bary_radius > 1e-10):
        return u_bary

    scale = np.clip(target_radius / (bary_radius + 1e-10), 0.5, 2.0)
    return u_bary * scale


# ---------------------------------------------------------------------------
# 4. CORRECTED: Exact out-of-sample MK rank computation
# ---------------------------------------------------------------------------

def mk_rank_new_score(
    score_new,
    model: Dict,
    method: Literal["exact", "nearest_neighbor", "barycentric", "local_ot", "hybrid"] = "exact",
    k_nn: int = 10,
) -> Tuple:
    """
    Compute the MK rank of a single new score.

    Parameters
    ----------
    score_new : array (d,)
        Score to rank; it is standardised with the model's mean and std.
    model : dict
        Fitted model returned by fit_mk_rank_model.
    method : str
        - "exact": reference point U0[j] closest to the normalised score
        - "nearest_neighbor": rank of the closest calibration score
        - "barycentric": inverse-distance weighted mean of the k-NN ranks
        - "local_ot": local optimal transport over the k nearest neighbours
        - "hybrid": barycentric rank with a radius correction
    k_nn : int
        Neighbour count for the three k-NN methods, clamped to [1, n].

    Returns
    -------
    u_new : array (d,)
        Rank vector assigned to the score.
    r_new : float
        Norm of `u_new`.
    tau_new : float
        Fraction of calibration radii that are <= `r_new`.

    Raises
    ------
    ValueError
        If the score has the wrong dimension or `method` is not recognised.
    """
    xp = cp if HAS_CUPY else np

    calib_scores = model["S_ref_norm"]
    calib_ranks = model["U_ref"]
    reference_points = model["U0"]
    calib_radii = model["r_ref"]
    mean = model["score_mean"]
    std = model["score_std"]
    task = model.get("task", "regression")

    n, d = calib_scores.shape

    s = xp.asarray(score_new, dtype=xp.float32).ravel()
    if s.shape[0] != d:
        raise ValueError(f"score_new must have dimension {d}, got {s.shape[0]}.")

    s_norm = (s - mean) / std

    if method == "exact":
        # Closest reference point, by squared Euclidean distance.
        offsets = reference_points - s_norm[None, :]  # (n, d)
        sq_gaps = xp.sum(offsets**2, axis=1)  # (n,)
        closest = int(to_cpu(xp.argmin(sq_gaps)))
        u_new = reference_points[closest].copy()

    elif method == "nearest_neighbor":
        # Rank of the closest calibration score.
        gaps = xp.linalg.norm(calib_scores - s_norm[None, :], axis=1)
        closest = int(to_cpu(xp.argmin(gaps)))
        u_new = calib_ranks[closest].copy()

    elif method in ["barycentric", "local_ot", "hybrid"]:
        gaps = xp.linalg.norm(calib_scores - s_norm[None, :], axis=1)

        # A query sitting on a calibration point simply reuses its rank.
        coincident = xp.where(gaps < 1e-10)[0]
        if coincident.size > 0:
            hit = int(to_cpu(coincident[0]))
            u_new = calib_ranks[hit].copy()
        else:
            k = min(max(1, int(k_nn)), n)
            neighbours = xp.argpartition(gaps, k - 1)[:k]
            nbr_gaps = gaps[neighbours]
            nbr_scores = calib_scores[neighbours]
            nbr_ranks = calib_ranks[neighbours]

            if method == "barycentric":
                u_new = _barycentric_extension_gpu(nbr_gaps, nbr_ranks)
            elif method == "local_ot":
                # The local OT solve runs on the CPU; move the result back.
                u_new = to_gpu(
                    _local_ot_extension(
                        to_cpu(nbr_scores), to_cpu(nbr_ranks), to_cpu(s_norm), task=task
                    )
                )
            else:
                u_new = _hybrid_extension_gpu(
                    nbr_scores, nbr_ranks, s_norm, nbr_gaps, task=task
                )

    else:
        raise ValueError(f"method must be 'exact', 'nearest_neighbor', 'barycentric', 'local_ot', or 'hybrid', got {method}")

    # Radius of the rank and its level among the calibration radii.
    r_new = float(to_cpu(xp.linalg.norm(u_new)))
    tau_new = float(to_cpu(xp.mean(calib_radii <= r_new)))

    return u_new, r_new, tau_new


def _local_ot_extension_batch(S_k_batch, U_k_batch, s_new_batch, task: str = "regression"):
    """
    Batch version of local OT extension using joblib parallelization.
    
    Args:
        S_k_batch: (m, k, d) - k neighbors for each of m query points
        U_k_batch: (m, k, d) - corresponding ranks for each neighbor set
        s_new_batch: (m, d) - query points (normalized)
        task: "regression" or "classification"
    
    Returns:
        u_new_batch: (m, d) - MK ranks for each query point
    """
    m = s_new_batch.shape[0]
    d = s_new_batch.shape[1]
    
    # Try to use joblib for parallelization
    try:
        from joblib import Parallel, delayed
        USE_JOBLIB = True
    except ImportError:
        USE_JOBLIB = False
    
    def compute_single_ot(idx):
        """Compute local OT for a single point."""
        S_k = S_k_batch[idx]
        U_k = U_k_batch[idx]
        s_new = s_new_batch[idx]
        
        k = S_k.shape[0]
        S_aug = np.vstack([S_k, s_new[None, :]])
        n_aug = k + 1

        a = np.full(n_aug, 1.0 / n_aug, dtype=np.float32)
        b = np.full(k, 1.0 / k, dtype=np.float32)
        M = ot.dist(S_aug, U_k, metric="euclidean") ** 2

        try:
            gamma = ot.emd(a, b, M, numItermax=100000)
            w_new = gamma[-1, :]
            u_new = np.sum(w_new[:, None] * U_k, axis=0)
        except RuntimeError:
            # Fallback to barycentric
            distances = np.linalg.norm(S_k - s_new[None, :], axis=1)
            d_safe = np.maximum(distances, 1e-10)
            w = 1.0 / d_safe
            w = w / np.sum(w)
            u_new = np.sum(w[:, None] * U_k, axis=0)
        return u_new
    
    if USE_JOBLIB and m > 100:
        # Use parallel processing for large batches
        n_jobs = min(8, max(1, m // 500))  # Adaptive number of jobs
        results = Parallel(n_jobs=n_jobs, backend="threading")(
            delayed(compute_single_ot)(i) for i in range(m)
        )
        u_new_batch = np.array(results, dtype=np.float32)
    else:
        # Sequential processing for small batches
        u_new_batch = np.zeros((m, d), dtype=np.float32)
        for i in range(m):
            u_new_batch[i] = compute_single_ot(i)
    
    return u_new_batch


def _mk_rank_vectors_chunk(xp, X_chunk, S_ref_norm, S_sq, U_ref, k, method, task):
    """Return the rank vectors (m_chunk, d) for one chunk of normalised queries.

    `xp` is the array module (NumPy or CuPy) that all array arguments belong
    to. `S_sq` holds the squared norms of the rows of `S_ref_norm`. `method`
    must be one of "nearest_neighbor", "barycentric", "local_ot", "hybrid";
    "local_ot" requires NumPy inputs.
    """
    # Pairwise distances via ||x||^2 + ||s||^2 - 2 x.s, clamped at zero
    # before the square root to absorb negative rounding residue.
    X_sq = xp.sum(X_chunk ** 2, axis=1, keepdims=True)  # (m_chunk, 1)
    XS = X_chunk @ S_ref_norm.T  # (m_chunk, n_ref)
    dist_sq = X_sq + S_sq[None, :] - 2 * XS
    dist = xp.sqrt(xp.maximum(dist_sq, 0))  # (m_chunk, n_ref)

    nn_idx = xp.argmin(dist, axis=1)  # (m_chunk,)
    if method == "nearest_neighbor":
        return U_ref[nn_idx]

    knn_indices = xp.argpartition(dist, k - 1, axis=1)[:, :k]  # (m_chunk, k)
    knn_distances = xp.take_along_axis(dist, knn_indices, axis=1)
    U_k_batch = U_ref[knn_indices]  # (m_chunk, k, d)

    if method == "local_ot":
        u_new_batch = _local_ot_extension_batch(
            S_ref_norm[knn_indices], U_k_batch, X_chunk, task=task
        )
    else:
        # Inverse-distance weights shared by "barycentric" and "hybrid".
        w = 1.0 / xp.maximum(knn_distances, 1e-10)
        w = w / w.sum(axis=1, keepdims=True)
        u_new_batch = xp.einsum('mk,mkd->md', w, U_k_batch)

        if method == "hybrid":
            # Rescale towards the weighted mean neighbour radius.
            r_bary = xp.linalg.norm(u_new_batch, axis=1, keepdims=True)
            r_k = xp.linalg.norm(U_k_batch, axis=2)
            r_pred = xp.sum(w * r_k, axis=1, keepdims=True)
            correction = xp.where(r_bary > 1e-10, r_pred / (r_bary + 1e-10), 1.0)
            correction = xp.clip(correction, 0.5, 2.0)
            u_new_batch = u_new_batch * correction

    # Queries coinciding with a calibration score reuse that score's rank.
    exact_match_mask = dist.min(axis=1) < 1e-10
    if bool(xp.any(exact_match_mask)):
        u_new_batch[exact_match_mask] = U_ref[nn_idx[exact_match_mask]]
    return u_new_batch


def mk_radii_new_scores(
    scores_new,
    model: Dict,
    method: Literal["exact", "nearest_neighbor", "barycentric", "local_ot", "hybrid"] = "exact",
    k_nn: int = 10,
):
    """
    Compute the rank radii for a batch of new scores.

    Batches of at most 10 scores, and the "exact" method, are handled one
    score at a time through mk_rank_new_score. Larger batches are processed
    in chunks with batched k-NN: on the GPU when CuPy is available and the
    method is not "local_ot", on the CPU otherwise.

    Parameters
    ----------
    scores_new : array-like, shape (m,) or (m, d)
        New scores; a 1D input is treated as m scores of dimension 1.
    model : dict
        Fitted model returned by fit_mk_rank_model.
    method : str
        Extension method, as in mk_rank_new_score.
    k_nn : int
        Neighbour count for the k-NN methods, clamped to [1, n_ref].

    Returns
    -------
    radii : array (m,), float32
        CuPy array when CuPy is available, NumPy array otherwise.

    Raises
    ------
    ValueError
        If `method` is not recognised, or the score dimension does not match
        the model.
    """
    xp = cp if HAS_CUPY else np
    X = xp.asarray(scores_new, dtype=xp.float32)
    if X.ndim == 1:
        X = X[:, None]
    m, d = X.shape

    # Small batches and the "exact" method go through the per-score routine.
    if m <= 10 or method == "exact":
        radii = xp.zeros(m, dtype=xp.float32)
        for j in range(m):
            _, r_new, _ = mk_rank_new_score(X[j], model, method=method, k_nn=k_nn)
            radii[j] = r_new
        return radii

    if method not in ("nearest_neighbor", "barycentric", "local_ot", "hybrid"):
        raise ValueError(f"Unknown method: {method}")

    S_ref_norm = model["S_ref_norm"]
    U_ref = model["U_ref"]
    score_mean = model["score_mean"]
    score_std = model["score_std"]
    task = model.get("task", "regression")

    n_ref, d_ref = S_ref_norm.shape
    # Same check as the per-score routine; without it a mismatched dimension
    # could broadcast silently in the normalisation below.
    if d != d_ref:
        raise ValueError(f"scores_new must have dimension {d_ref}, got {d}.")
    k = min(max(1, int(k_nn)), n_ref)

    X_norm = (X - score_mean) / score_std  # (m, d)

    # local_ot relies on POT and therefore always runs on the CPU.
    on_gpu = HAS_CUPY and method in ("barycentric", "hybrid", "nearest_neighbor")
    if on_gpu:
        backend = cp
        chunk_size = min(8000, m)
    else:
        backend = np
        X_norm = to_cpu(X_norm)
        S_ref_norm = to_cpu(S_ref_norm)
        U_ref = to_cpu(U_ref)
        chunk_size = min(2000, m)

    # Squared norms of the calibration scores are reused by every chunk.
    S_sq = backend.sum(S_ref_norm ** 2, axis=1)  # (n_ref,)
    all_radii = backend.zeros(m, dtype=backend.float32)

    for chunk_start in range(0, m, chunk_size):
        chunk_end = min(chunk_start + chunk_size, m)
        u_new_batch = _mk_rank_vectors_chunk(
            backend, X_norm[chunk_start:chunk_end], S_ref_norm, S_sq, U_ref,
            k, method, task,
        )
        all_radii[chunk_start:chunk_end] = backend.linalg.norm(u_new_batch, axis=1)

    if on_gpu:
        return all_radii
    return to_gpu(all_radii) if HAS_CUPY else all_radii


def mk_levels_new_scores(
    scores_new,
    model: Dict,
    method: Literal["exact", "nearest_neighbor", "barycentric", "local_ot", "hybrid"] = "exact",
    k_nn: int = 10,
):
    """
    Compute the empirical level of each score in a batch.

    Every score is ranked individually with mk_rank_new_score and the third
    element of its result (the level) is collected into a float32 array on
    the active backend. A 1D input is treated as m scores of dimension 1.
    """
    xp = cp if HAS_CUPY else np
    X = xp.asarray(scores_new, dtype=xp.float32)
    if X.ndim == 1:
        X = X[:, None]
    m, d = X.shape

    levels = [
        mk_rank_new_score(X[row], model, method=method, k_nn=k_nn)[2]
        for row in range(m)
    ]
    return xp.asarray(levels, dtype=xp.float32)


# ---------------------------------------------------------------------------
# 5. Quantile thresholds (unchanged)
# ---------------------------------------------------------------------------

def mk_quantile_from_levels(tau_values, alpha: float) -> float:
    """Return the `alpha`-quantile of the levels in `tau_values`.

    The values are flattened and cast to float32 on the active backend.
    Raises ValueError unless `alpha` lies strictly between 0 and 1.
    """
    xp = cp if HAS_CUPY else np
    levels = xp.asarray(tau_values, dtype=xp.float32).ravel()
    alpha_in_range = 0.0 < alpha < 1.0
    if not alpha_in_range:
        raise ValueError("alpha must be in (0,1).")
    quantile = xp.quantile(levels, alpha)
    return float(to_cpu(quantile))


def mk_quantile_from_radii(radii_values, coverage: float) -> float:
    """
    Conformal threshold on the radii.

    Returns the k-th smallest radius with k = ceil(coverage * (m + 1)),
    clamped to the range [1, m], where m is the number of radii.

    Raises ValueError if `radii_values` is empty or `coverage` is not
    strictly between 0 and 1.
    """
    xp = cp if HAS_CUPY else np
    flat = xp.asarray(radii_values, dtype=xp.float32).ravel()
    count = flat.size
    if count == 0:
        raise ValueError("radii_values must be non-empty.")
    coverage_in_range = 0.0 < coverage < 1.0
    if not coverage_in_range:
        raise ValueError("coverage must be in (0,1).")

    # 1-based order statistic ceil((m + 1) * coverage), kept inside [1, m].
    rank = int(np.ceil(coverage * (count + 1)))
    rank = min(max(rank, 1), count)
    ordered = xp.sort(flat)
    return float(to_cpu(ordered[rank - 1]))