import warnings
import numpy as np

# GPU support via CuPy
try:
    import cupy as cp
    HAS_CUPY = True
except ImportError:
    cp = None
    HAS_CUPY = False

from typing import List, Callable, Dict, Any, Optional
from sklearn.decomposition import PCA

# Import scoring functions from multiscoring module to avoid duplication
try:
    from multiscoring_conformal_gpu import (
        build_softmax_diff_scoring_functions as _build_softmax_diff_scoring_functions,
        build_reduced_scoring_functions as _build_reduced_scoring_functions,
    )
    _HAS_MULTISCORING = True
except ImportError:
    _HAS_MULTISCORING = False


def to_cpu(arr):
    """
    Return `arr` as a NumPy array on the host.

    When CuPy is importable and `arr` exposes a `device` attribute, the
    conversion goes through `cp.asnumpy`; in every other case `np.asarray`
    is used.
    """
    needs_transfer = HAS_CUPY and hasattr(arr, 'device')
    if not needs_transfer:
        return np.asarray(arr)
    return cp.asnumpy(arr)


def to_gpu(arr):
    """
    Convert `arr` to a float32 array, using CuPy when it is importable.

    Without CuPy the result is a float32 NumPy array instead.
    """
    # `cp` is None when the import failed, so only touch it behind HAS_CUPY.
    backend = cp if HAS_CUPY else np
    return backend.asarray(arr, dtype=backend.float32)


# =============================================================================
# KERNEL FUNCTIONS (GPU-ACCELERATED)
# =============================================================================

def rbf_kernel_gpu(X, Y, sigma, xp=np):
    """
    Evaluate the RBF (Gaussian) kernel for every pair of rows of X and Y.

    K(x, y) = exp(-||x - y||^2 / (2 * sigma^2))

    Args:
        X: (n, p) array
        Y: (m, p) array
        sigma: Kernel bandwidth
        xp: Array module to compute with (numpy or cupy)

    Returns:
        K: (n, m) kernel matrix
    """
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, assembled by broadcasting
    # a column of row norms against a row of row norms.
    x_sq_norms = xp.sum(X**2, axis=1, keepdims=True)      # (n, 1)
    y_sq_norms = xp.sum(Y**2, axis=1, keepdims=True).T    # (1, m)
    cross_term = 2 * xp.dot(X, Y.T)                       # (n, m)

    # Round-off in the expansion can leave tiny negative values; clamp them.
    sq_dists = xp.maximum(x_sq_norms + y_sq_norms - cross_term, 0.0)

    return xp.exp(-sq_dists / (2 * sigma**2))


def compute_kernel_weights_gpu(x, X_cal, sigma, eps=1e-10, xp=np):
    """
    Compute normalized RBF kernel weights for local estimation.

    w_j(x) = K_h(x, X_j) / Σ_k K_h(x, X_k)

    The weights are evaluated in a numerically stable way: each row of
    squared distances is shifted by its minimum before exponentiating. The
    shift cancels in the ratio above, but it guarantees that the nearest
    calibration point always gets kernel value 1, so a query that is far
    from every calibration point (relative to `sigma`) still receives
    weights that sum to ~1 instead of an all-zero row caused by underflow.

    Args:
        x: Query point (p,) or (n_query, p)
        X_cal: Calibration features (n_cal, p)
        sigma: Kernel bandwidth
        eps: Small value added to the normalizer for numerical stability
        xp: numpy or cupy module

    Returns:
        weights: (n_cal,) or (n_query, n_cal) normalized weights
    """
    squeeze_output = x.ndim == 1
    if squeeze_output:
        x = x[None, :]  # (1, p)

    # Pairwise squared Euclidean distances (same expansion as rbf_kernel_gpu).
    dists_sq = (
        xp.sum(x**2, axis=1, keepdims=True)
        + xp.sum(X_cal**2, axis=1, keepdims=True).T
        - 2 * xp.dot(x, X_cal.T)
    )
    dists_sq = xp.maximum(dists_sq, 0.0)  # (n_query, n_cal)

    # Shift by the per-row minimum so exp() cannot underflow for the whole
    # row. Skipped for an empty calibration set, where min() is undefined.
    if dists_sq.shape[1] > 0:
        dists_sq = dists_sq - xp.min(dists_sq, axis=1, keepdims=True)

    K = xp.exp(-dists_sq / (2 * sigma**2))

    # Normalize each row
    K_sum = xp.sum(K, axis=1, keepdims=True) + eps
    weights = K / K_sum

    if squeeze_output:
        return weights[0]
    return weights


# =============================================================================
# LOCAL NONPARAMETRIC RANK ESTIMATOR FOR SCORE SPACE
# =============================================================================

def compute_local_rank_scores_gpu(s, S_cal, x, X_cal, sigma, eps=1e-10, xp=np):
    """
    Evaluate the local nonparametric geometric rank in score space.

    R̂_n(x, s) = Σ_{j=1}^n w_j(x) · (s - S_j) / ||s - S_j||

    with kernel weights w_j(x) = K_h(x, X_j) / Σ_k K_h(x, X_k).

    Args:
        s: Score vectors to evaluate (n_test, K) or (K,)
        S_cal: Calibration score vectors (n_cal, K)
        x: Feature points for kernel weights (n_test, p) or (p,)
        X_cal: Calibration features (n_cal, p)
        sigma: Kernel bandwidth
        eps: Lower bound on ||s - S_j|| to avoid division by zero
        xp: numpy or cupy module

    Returns:
        ranks: Geometric rank vectors (n_test, K), or (K,) when `s` was 1-D
    """
    # Everything is computed in float32 on the chosen backend.
    s, S_cal, x, X_cal = (
        xp.asarray(arr, dtype=xp.float32) for arr in (s, S_cal, x, X_cal)
    )

    # A single score vector is promoted to a batch of one and unwrapped again
    # at the end; a single feature vector is promoted likewise.
    single_query = s.ndim == 1
    if single_query:
        s = s[None, :]  # (1, K)
    if x.ndim == 1:
        x = x[None, :]  # (1, p)

    # Unpacking doubles as a check that `s` is 2-D at this point.
    _n_test, _n_scores = s.shape

    # Kernel weights of every calibration point for every query: (n_test, n_cal)
    weights = compute_kernel_weights_gpu(x, X_cal, sigma, eps, xp)

    # offsets[i, j, :] = s[i] - S_cal[j]            -> (n_test, n_cal, K)
    offsets = s[:, None, :] - S_cal[None, :, :]

    # lengths[i, j] = max(||s[i] - S_cal[j]||, eps) -> (n_test, n_cal)
    lengths = xp.maximum(xp.linalg.norm(offsets, axis=2), eps)

    # Unit directions from each calibration score towards each query score.
    directions = offsets / lengths[:, :, None]

    # ranks[i] = Σ_j weights[i, j] * directions[i, j] -> (n_test, K)
    ranks = xp.einsum('ij,ijk->ik', weights, directions)

    return ranks[0] if single_query else ranks


def compute_rank_scores_batch_gpu(S_all, S_cal, X_all, X_cal, sigma, eps=1e-10,
                                  max_chunk_bytes=2 * 1024**3):
    """
    Batch computation of rank scores ||R̂_n(x, s)|| for all test points.

    Uses CuPy when it is importable, otherwise NumPy. The test points are
    processed in chunks so that the (chunk, n_cal, K) intermediates of
    `compute_local_rank_scores_gpu` stay within a memory budget.

    Args:
        S_all: Score vectors (n_test, K)
        S_cal: Calibration score vectors (n_cal, K)
        X_all: Feature points (n_test, p)
        X_cal: Calibration features (n_cal, p)
        sigma: Kernel bandwidth
        eps: Numerical stability
        max_chunk_bytes: Approximate budget, in bytes, for the per-chunk
            intermediates (default 2 GiB). At least one test point is always
            processed per chunk, even if that alone exceeds the budget.

    Returns:
        rank_norms: ||R̂_n|| values as a float32 NumPy array of shape (n_test,)
    """
    xp = cp if HAS_CUPY else np

    # to_gpu() yields float32 CuPy arrays with CuPy and float32 NumPy arrays
    # without it, so one conversion path covers both backends.
    S_all, S_cal, X_all, X_cal = (
        to_gpu(arr) for arr in (S_all, S_cal, X_all, X_cal)
    )

    n_test = S_all.shape[0]
    if n_test == 0:
        # Nothing to score; also avoids a zero chunk size below.
        return np.zeros(0, dtype=np.float32)

    n_cal = S_cal.shape[0]
    K = S_all.shape[1]

    # Rough footprint of one test point: three float32 arrays of n_cal * K
    # elements (diff, normalized diff, and slack for temporaries).
    bytes_per_element = 4  # float32
    memory_per_test = n_cal * K * bytes_per_element * 3
    if memory_per_test > 0:
        chunk_size = max(1, int(max_chunk_bytes // memory_per_test))
    else:
        # Empty calibration set or zero-width scores: no large intermediates.
        chunk_size = n_test
    chunk_size = min(chunk_size, n_test)

    rank_norms = xp.zeros(n_test, dtype=xp.float32)

    for start in range(0, n_test, chunk_size):
        end = min(start + chunk_size, n_test)

        ranks_chunk = compute_local_rank_scores_gpu(
            S_all[start:end], S_cal, X_all[start:end], X_cal, sigma, eps, xp
        )
        rank_norms[start:end] = xp.linalg.norm(ranks_chunk, axis=1)

        # Release the chunk's intermediates before the next iteration.
        del ranks_chunk
        if HAS_CUPY:
            cp.get_default_memory_pool().free_all_blocks()

    return to_cpu(rank_norms)


# =============================================================================
# GRCP MULTICLASS PREDICTOR
# =============================================================================

class GRCPMulticlassPredictor:
    """
    GRCP (Geometric Rank Conformal Prediction) for Multiclass Classification.

    Uses the local nonparametric rank estimator in score space:
        R̂_n(x, s) = Σ_{j=1}^n w_j(x) · (s - S_j) / ||s - S_j||

    The conformity score is ||R̂_n(x, s)||.

    Key features:
    - Nonparametric rank estimation (no relabeling/CDF transform)
    - Kernel weighting on features X localizes the rank estimate
    - GPU-accelerated when CuPy is available

    Data splits:
    - Rank set: Used to define the reference distribution of score vectors
    - Calibration set: Used to compute the threshold via split conformal
    - Test set: Evaluation

    For convenience, you can also use a single calibration set (combining
    rank and calibration) with the `calibrate_single` method.
    """

    def __init__(
        self,
        scoring_functions: List[Callable],
        alpha: float = 0.1,
        sigma: float = 1.0,
        pca_dim: Optional[int] = 10,
        eps: float = 1e-10,
    ):
        """
        Initialize the GRCP multiclass predictor.

        Args:
            scoring_functions: List of scoring functions. Each takes (y_true, y_probs)
                              and returns a 1D array of scores.
            alpha: Miscoverage level (e.g., 0.1 for 90% coverage)
            sigma: Kernel bandwidth for local rank estimation.
                   Larger sigma = more global, smaller = more local.
            pca_dim: Number of PCA components for feature reduction (for kernel).
                     Set to None to skip PCA. Recommended for high-dim features.
            eps: Small value for numerical stability
        """
        self.scoring_functions = scoring_functions
        self.n_scores = len(scoring_functions)
        self.alpha = alpha
        self.sigma = sigma
        self.pca_dim = pca_dim
        self.eps = eps

        # Fitted attributes (populated by calibrate / calibrate_single)
        self.S_rank_: Optional[np.ndarray] = None  # Rank set scores (n_rank, K)
        self.X_rank_: Optional[np.ndarray] = None  # Rank set features (n_rank, p_reduced)
        self.threshold_: Optional[float] = None    # Calibrated threshold
        self.pca_model_: Optional[PCA] = None      # PCA model for feature reduction
        self.is_fitted_ = False

    def compute_scores(
        self,
        y_true: np.ndarray,
        y_probs: np.ndarray,
    ) -> np.ndarray:
        """
        Compute multivariate score vectors S ∈ R^{n × K}.

        Args:
            y_true: True labels (n,)
            y_probs: Predicted probabilities (n, n_classes)

        Returns:
            S: float32 score matrix (n, K) where K = n_scores; column j holds
               the output of the j-th scoring function.
        """
        n = len(y_true)
        S = np.zeros((n, self.n_scores), dtype=np.float32)
        for j, score_func in enumerate(self.scoring_functions):
            S[:, j] = score_func(y_true, y_probs)
        return S

    def _reduce_features(self, X: np.ndarray, fit: bool = False) -> np.ndarray:
        """
        Reduce feature dimensionality via PCA for kernel weights.

        PCA is skipped (X is returned unchanged) when `pca_dim` is None or X
        already has at most `pca_dim` columns.

        Args:
            X: Features (n, p)
            fit: Whether to fit the PCA model on X (True during calibration)

        Returns:
            X_reduced: Reduced features (n, pca_dim), or X itself when PCA is skipped

        Raises:
            ValueError: If a transform is requested before a PCA model was fitted.
        """
        if fit:
            # Drop any model left over from an earlier calibration so a skipped
            # fit cannot leave a stale projection behind.
            self.pca_model_ = None

        if self.pca_dim is None or X.shape[1] <= self.pca_dim:
            return X

        if fit:
            # PCA cannot produce more components than samples or features.
            n_components = min(self.pca_dim, X.shape[0], X.shape[1])
            self.pca_model_ = PCA(n_components=n_components)
            return self.pca_model_.fit_transform(X).astype(np.float32)

        if self.pca_model_ is None:
            raise ValueError("PCA model not fitted. Call calibrate first.")
        return self.pca_model_.transform(X).astype(np.float32)

    def calibrate(
        self,
        X_rank: np.ndarray,
        y_rank_true: np.ndarray,
        y_rank_probs: np.ndarray,
        X_cal: np.ndarray,
        y_cal_true: np.ndarray,
        y_cal_probs: np.ndarray,
        verbose: bool = True,
    ) -> "GRCPMulticlassPredictor":
        """
        Calibrate with separate rank and calibration sets (recommended).

        Args:
            X_rank: Rank set features (n_rank, p)
            y_rank_true: Rank set true labels (n_rank,)
            y_rank_probs: Rank set predicted probabilities (n_rank, n_classes)
            X_cal: Calibration set features (n_cal, p)
            y_cal_true: Calibration set true labels (n_cal,)
            y_cal_probs: Calibration set predicted probabilities (n_cal, n_classes)
            verbose: Print progress

        Returns:
            self

        Raises:
            ValueError: If the rank set or the calibration set is empty.
        """
        n_rank = len(y_rank_true)
        n_cal = len(y_cal_true)

        # Both sets are required: the rank set is the reference sample and the
        # threshold is an order statistic of the calibration scores.
        if n_rank == 0 or n_cal == 0:
            raise ValueError(
                "Rank and calibration sets must both be non-empty "
                f"(got n_rank={n_rank}, n_cal={n_cal})."
            )

        if verbose:
            print(f"[GRCP] Calibrating with rank set (n={n_rank}) and cal set (n={n_cal})")

        # 1. Compute score vectors for rank set (reference distribution)
        self.S_rank_ = self.compute_scores(y_rank_true, y_rank_probs)

        if verbose:
            print(f"[GRCP] Rank scores shape: {self.S_rank_.shape}")

        # 2. Reduce features via PCA (fit on rank set, applied to cal set)
        X_rank = np.asarray(X_rank, dtype=np.float32)
        X_cal = np.asarray(X_cal, dtype=np.float32)

        self.X_rank_ = self._reduce_features(X_rank, fit=True)
        X_cal_reduced = self._reduce_features(X_cal, fit=False)

        if verbose:
            print(f"[GRCP] Features reduced: {X_rank.shape[1]} -> {self.X_rank_.shape[1]}")

        # 3. Compute calibration scores using rank estimator
        S_cal = self.compute_scores(y_cal_true, y_cal_probs)

        if verbose:
            print("[GRCP] Computing calibration scores...")

        cal_scores = compute_rank_scores_batch_gpu(
            S_cal, self.S_rank_, X_cal_reduced, self.X_rank_, self.sigma, self.eps
        )

        # 4. Threshold = the ceil((n_cal + 1)(1 - alpha))-th smallest score.
        # NOTE(review): the index is clipped to n_cal, so when
        # (n_cal + 1)(1 - alpha) > n_cal the largest calibration score is
        # used; confirm this is the intended behaviour for small n_cal.
        q_level = np.ceil((n_cal + 1) * (1 - self.alpha))
        q_level = int(max(1, min(q_level, n_cal)))
        cal_scores_sorted = np.sort(cal_scores)
        self.threshold_ = float(cal_scores_sorted[q_level - 1])

        if verbose:
            print(f"[GRCP] Threshold: {self.threshold_:.4f} (alpha={self.alpha})")
            print(f"[GRCP] Calibration scores: min={cal_scores.min():.4f}, "
                  f"max={cal_scores.max():.4f}, mean={cal_scores.mean():.4f}")

        self.is_fitted_ = True
        return self

    def calibrate_single(
        self,
        X_cal: np.ndarray,
        y_cal_true: np.ndarray,
        y_cal_probs: np.ndarray,
        split_ratio: float = 0.5,
        verbose: bool = True,
        random_state: Optional[int] = 42,
    ) -> "GRCPMulticlassPredictor":
        """
        Calibrate with a single calibration set (split internally).

        Convenience method that shuffles the calibration data and splits it
        into rank and calibration subsets before calling `calibrate`.

        Args:
            X_cal: Calibration features (n_cal, p)
            y_cal_true: Calibration true labels (n_cal,)
            y_cal_probs: Calibration predicted probabilities (n_cal, n_classes)
            split_ratio: Ratio of data for rank set (rest for calibration)
            verbose: Print progress
            random_state: Seed for the shuffle. The default (42) reproduces
                the previously hard-coded split.

        Returns:
            self
        """
        # Fancy indexing below needs arrays; accept lists and other sequences.
        X_cal = np.asarray(X_cal)
        y_cal_true = np.asarray(y_cal_true)
        y_cal_probs = np.asarray(y_cal_probs)

        n_cal = len(y_cal_true)

        # Shuffle and split
        rng = np.random.default_rng(seed=random_state)
        perm = rng.permutation(n_cal)
        n_rank = int(n_cal * split_ratio)

        if n_rank < 20 or (n_cal - n_rank) < 20:
            warnings.warn(f"Small split sizes: rank={n_rank}, cal={n_cal - n_rank}. "
                         "Consider using more calibration data.")

        idx_rank = perm[:n_rank]
        idx_cal = perm[n_rank:]

        return self.calibrate(
            X_rank=X_cal[idx_rank],
            y_rank_true=y_cal_true[idx_rank],
            y_rank_probs=y_cal_probs[idx_rank],
            X_cal=X_cal[idx_cal],
            y_cal_true=y_cal_true[idx_cal],
            y_cal_probs=y_cal_probs[idx_cal],
            verbose=verbose,
        )

    def compute_conformity_score(
        self,
        s: np.ndarray,
        x: np.ndarray,
    ) -> float:
        """
        Compute conformity score ||R̂_n(x, s)|| for a single point.

        Args:
            s: Score vector (K,)
            x: Feature vector (p,) - will be PCA-reduced if needed

        Returns:
            score: ||R̂_n(x, s)||

        Raises:
            ValueError: If the predictor has not been calibrated.
        """
        if not self.is_fitted_:
            raise ValueError("Predictor not fitted. Call calibrate() first.")

        s = np.asarray(s, dtype=np.float32)
        x = np.asarray(x, dtype=np.float32)

        # Reduce features (the PCA helper works on 2-D input)
        x_reduced = self._reduce_features(x[None, :], fit=False)[0]

        # to_gpu() returns float32 NumPy arrays when CuPy is unavailable, so
        # the same call sequence serves both backends.
        xp = cp if HAS_CUPY else np
        rank = compute_local_rank_scores_gpu(
            to_gpu(s), to_gpu(self.S_rank_), to_gpu(x_reduced), to_gpu(self.X_rank_),
            self.sigma, self.eps, xp,
        )
        return float(to_cpu(xp.linalg.norm(rank)))

    def is_inside(self, s: np.ndarray, x: np.ndarray) -> bool:
        """
        Test if a score vector is inside the conformal region.

        Args:
            s: Score vector (K,)
            x: Feature vector (p,)

        Returns:
            inside: True if ||R̂_n(x, s)|| <= threshold
        """
        score = self.compute_conformity_score(s, x)
        return score <= self.threshold_

    def is_inside_batch(
        self,
        S: np.ndarray,
        X: np.ndarray,
    ) -> np.ndarray:
        """
        Test multiple score vectors (batch version).

        Args:
            S: Score vectors (n, K)
            X: Feature vectors (n, p)

        Returns:
            inside: Boolean array (n,)

        Raises:
            ValueError: If the predictor has not been calibrated.
        """
        if not self.is_fitted_:
            raise ValueError("Predictor not fitted. Call calibrate() first.")

        S = np.asarray(S, dtype=np.float32)
        X = np.asarray(X, dtype=np.float32)

        if S.shape[0] == 0:
            return np.zeros(0, dtype=bool)

        # Reduce features
        X_reduced = self._reduce_features(X, fit=False)

        # Compute rank scores
        rank_norms = compute_rank_scores_batch_gpu(
            S, self.S_rank_, X_reduced, self.X_rank_, self.sigma, self.eps
        )

        return rank_norms <= self.threshold_

    def predict_set(
        self,
        y_probs: np.ndarray,
        X: np.ndarray,
        candidate_labels: Optional[np.ndarray] = None,
        verbose: bool = False,
    ) -> List[np.ndarray]:
        """
        Build conformal prediction sets for classification.

        For each test point i, returns the set of labels y such that
        ||R̂_n(x_i, s(x_i, y))|| <= threshold.

        Args:
            y_probs: Predicted probabilities (n_test, n_classes)
            X: Test features (n_test, p)
            candidate_labels: Labels to consider (array or sequence).
                              If None, uses [0, 1, ..., n_classes-1]
            verbose: Print progress

        Returns:
            prediction_sets: List of arrays, each containing labels in the set

        Raises:
            ValueError: If the predictor has not been calibrated.
        """
        if not self.is_fitted_:
            raise ValueError("Predictor not fitted. Call calibrate() first.")

        y_probs = np.asarray(y_probs, dtype=np.float32)
        X = np.asarray(X, dtype=np.float32)
        n_test, n_classes = y_probs.shape

        if candidate_labels is None:
            candidate_labels = np.arange(n_classes)
        else:
            # Boolean-mask indexing below requires an array, not a list.
            candidate_labels = np.asarray(candidate_labels)
        n_labels = len(candidate_labels)

        # Nothing to score: every prediction set is empty.
        if n_test == 0 or n_labels == 0:
            return [candidate_labels[:0] for _ in range(n_test)]

        # Reduce features
        X_reduced = self._reduce_features(X, fit=False)

        if verbose:
            print(f"[GRCP] Building prediction sets for {n_test} test points, "
                  f"{n_labels} candidate labels")

        # ==================== VECTORIZED COMPUTATION ====================
        # Compute scores for all (test point, candidate label) pairs
        # Shape: (n_test, n_labels, K)
        all_scores = np.zeros((n_test, n_labels, self.n_scores), dtype=np.float32)

        for j, label in enumerate(candidate_labels):
            # Hypothesize that every test point carries this label
            y_fake = np.full(n_test, label, dtype=int)
            all_scores[:, j, :] = self.compute_scores(y_fake, y_probs)

        # Flatten to (n_test * n_labels, K) in test-major order; np.repeat
        # duplicates each feature row n_labels times so rows stay aligned.
        S_flat = all_scores.reshape(-1, self.n_scores)
        X_flat = np.repeat(X_reduced, n_labels, axis=0)

        if verbose:
            print(f"[GRCP] Computing {n_test * n_labels} rank scores...")

        # Batch compute rank norms
        rank_norms_flat = compute_rank_scores_batch_gpu(
            S_flat, self.S_rank_, X_flat, self.X_rank_, self.sigma, self.eps
        )

        # Reshape back: (n_test, n_labels)
        rank_norms = rank_norms_flat.reshape(n_test, n_labels)

        # A label is kept when its rank norm does not exceed the threshold
        inside = rank_norms <= self.threshold_  # (n_test, n_labels)
        prediction_sets = [candidate_labels[row_mask] for row_mask in inside]

        if verbose:
            sizes = [len(ps) for ps in prediction_sets]
            print(f"[GRCP] Prediction set sizes: min={min(sizes)}, max={max(sizes)}, "
                  f"mean={np.mean(sizes):.2f}")

        return prediction_sets


# =============================================================================
# EVALUATION FUNCTION (for compare_methods_multiclass_gpu.py integration)
# =============================================================================

def eval_grcp_multiclass(
    X_cal: np.ndarray,
    y_cal: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    base_clf: Any,
    scoring_functions: List[Callable],
    alpha: float,
    sigma: float = 1.0,
    pca_dim: Optional[int] = 10,
    split_ratio: float = 0.5,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Evaluate GRCP for multiclass classification.

    Calibrates a `GRCPMulticlassPredictor` on an internal rank/calibration
    split of the calibration data, builds prediction sets for the test data,
    and reports empirical coverage and average set size.

    Args:
        X_cal: Calibration features (n_cal, p)
        y_cal: Calibration true labels (n_cal,)
        X_test: Test features (n_test, p)
        y_test: Test true labels (n_test,)
        base_clf: Trained classifier with predict_proba method
        scoring_functions: List of scoring functions
        alpha: Miscoverage level
        sigma: Kernel bandwidth
        pca_dim: PCA dimension for features (None to skip)
        split_ratio: Ratio for rank/cal split
        verbose: Print progress

    Returns:
        Dictionary with keys "coverage", "target_coverage",
        "average_set_size" and "prediction_sets" (a list of Python sets).
        Coverage and average size are NaN when the test set is empty.
    """
    # Get predictions (the classifier sees the inputs exactly as passed in)
    y_cal_probs = np.asarray(base_clf.predict_proba(X_cal))
    y_test_probs = np.asarray(base_clf.predict_proba(X_test))

    # n_classes is only reported; not every classifier exposes `n_classes_`,
    # so fall back to the width of the probability matrix.
    n_classes = getattr(base_clf, "n_classes_", y_cal_probs.shape[1])

    if verbose:
        print(f"[GRCP Eval] n_classes={n_classes}, alpha={alpha}, sigma={sigma}")

    # The internal split uses fancy indexing, which needs arrays.
    X_cal = np.asarray(X_cal)
    y_cal = np.asarray(y_cal)
    y_test = np.asarray(y_test)

    # Create and calibrate predictor
    predictor = GRCPMulticlassPredictor(
        scoring_functions=scoring_functions,
        alpha=alpha,
        sigma=sigma,
        pca_dim=pca_dim,
    )

    predictor.calibrate_single(
        X_cal=X_cal,
        y_cal_true=y_cal,
        y_cal_probs=y_cal_probs,
        split_ratio=split_ratio,
        verbose=verbose,
    )

    # Build prediction sets
    prediction_sets = predictor.predict_set(
        y_probs=y_test_probs,
        X=X_test,
        verbose=verbose,
    )

    # Compute metrics
    n_test = len(y_test)
    coverage_count = sum(
        1 for label, ps in zip(y_test, prediction_sets) if label in ps
    )
    total_size = sum(len(ps) for ps in prediction_sets)

    if n_test > 0:
        coverage = coverage_count / n_test
        avg_size = total_size / n_test
    else:
        # No test points: the averages are undefined rather than an error.
        coverage = float("nan")
        avg_size = float("nan")

    # Convert to sets for WSC computation
    prediction_sets_as_sets = [set(ps.tolist()) for ps in prediction_sets]

    if verbose:
        print(f"[GRCP Eval] Coverage: {coverage:.4f}, Avg size: {avg_size:.2f}")

    return {
        "coverage": coverage,
        "target_coverage": 1.0 - alpha,
        "average_set_size": avg_size,
        "prediction_sets": prediction_sets_as_sets,
    }


# =============================================================================
# SCORING FUNCTIONS (from multiscoring_conformal_gpu.py)
# =============================================================================

def build_softmax_diff_scoring_functions(n_classes: int) -> List[Callable]:
    """
    Create one scoring function per class for GRCP multiclass classification.

    As defined in 5_Xps.tex:
        S(x,y) = (|p̂(x)_1 - 1(y=1)|, ..., |p̂(x)_d - 1(y=d)|) ∈ R^d

    The c-th function returns |p_c - 1_{y=c}| for every sample.

    Args:
        n_classes: Number of classes K

    Returns:
        List of K scoring functions, one per class. When the multiscoring
        module was importable, its implementation is returned instead of the
        local fallback.
    """
    # Delegate to the shared implementation when it is available.
    if _HAS_MULTISCORING:
        return _build_softmax_diff_scoring_functions(n_classes)

    # Local fallback: a factory binds the class index for each closure.
    def make_score_for_coord(c: int):
        def score_func(y_true: np.ndarray, y_probs: np.ndarray) -> np.ndarray:
            labels = np.asarray(y_true, dtype=int).ravel()
            probs = np.asarray(y_probs, dtype=np.float32)
            # 1.0 where the label equals this class, else 0.0
            is_class = (labels == c).astype(np.float32)
            class_prob = probs[:, c]
            return np.abs(class_prob - is_class)
        return score_func

    functions = []
    for class_idx in range(n_classes):
        functions.append(make_score_for_coord(class_idx))
    return functions


def build_reduced_scoring_functions(K: int = 5) -> List[Callable]:
    """
    Create scoring functions based on the K largest predicted probabilities.

    Intended for problems with many classes, where using only the top-K
    sorted probabilities keeps the score vector low-dimensional.

    Args:
        K: Number of top probabilities to use

    Returns:
        List of K scoring functions based on sorted probabilities. When the
        multiscoring module was importable, its implementation is returned
        instead of the local fallback.
    """
    # Delegate to the shared implementation when it is available.
    if _HAS_MULTISCORING:
        return _build_reduced_scoring_functions(K)

    # Local fallback: the r-th function scores 1 - (r-th largest probability).
    def make_topk_score(rank: int):
        def score_func(y_true: np.ndarray, y_probs: np.ndarray) -> np.ndarray:
            labels = np.asarray(y_true, dtype=int).ravel()
            probs = np.asarray(y_probs, dtype=np.float32)
            n_samples = len(labels)
            n_classes = probs.shape[1]

            # Fewer classes than the requested rank: constant score of one.
            if rank >= n_classes:
                return np.ones(n_samples, dtype=np.float32)

            # Ascending sort, then reverse the columns for descending order.
            descending = np.sort(probs, axis=1)[:, ::-1]
            return 1.0 - descending[:, rank]
        return score_func

    functions = []
    for r in range(K):
        functions.append(make_topk_score(r))
    return functions


# =============================================================================
# BANDWIDTH SELECTION HEURISTICS
# =============================================================================

def median_heuristic(X: np.ndarray, subsample: int = 1000) -> float:
    """
    Compute median heuristic for RBF kernel bandwidth.

    sigma = median(||x_i - x_j||) / sqrt(2), taken over all pairs i < j.

    Args:
        X: Data points (n, p)
        subsample: Maximum number of points to use (for efficiency). Larger
            inputs are subsampled without replacement with a fixed seed.

    Returns:
        sigma: Recommended bandwidth as a Python float, never below 1e-4.
            With fewer than two points no pairwise distance exists and the
            floor 1e-4 is returned.
    """
    # Local import, mirroring how this helper has always pulled in its
    # distance routine only when called.
    from scipy.spatial.distance import pdist

    min_sigma = 1e-4  # Lower bound that keeps the bandwidth strictly positive

    X = np.asarray(X)
    n = X.shape[0]
    if n > subsample:
        rng = np.random.default_rng(seed=42)
        idx = rng.choice(n, size=subsample, replace=False)
        X = X[idx]

    # The median of zero pairs is undefined (NaN); return the floor instead.
    if X.shape[0] < 2:
        return min_sigma

    # pdist yields each off-diagonal pair once (condensed upper triangle),
    # so no full (n, n) matrix or index arrays are materialized.
    median_dist = float(np.median(pdist(X, metric='euclidean')))
    if not np.isfinite(median_dist):
        return min_sigma

    sigma = median_dist / float(np.sqrt(2))
    return max(sigma, min_sigma)  # Ensure positive


if __name__ == "__main__":
    # Smoke test: runs the full evaluation pipeline on random synthetic data
    # with a classifier that emits random probabilities.
    print("Testing GRCP Multiclass...")

    np.random.seed(42)
    n_features, n_classes = 10, 5

    def _synthetic_split(n_samples):
        # Features are drawn before labels so the global RNG is consumed in
        # the same order for every split.
        features = np.random.randn(n_samples, n_features).astype(np.float32)
        labels = np.random.randint(0, n_classes, n_samples)
        return features, labels

    # The train split is never used below, but drawing it keeps the random
    # stream (and therefore the cal/test data) unchanged.
    X_train, y_train = _synthetic_split(500)
    X_cal, y_cal = _synthetic_split(200)
    X_test, y_test = _synthetic_split(100)

    class DummyClassifier:
        """Stand-in classifier that returns random, row-normalized probabilities."""

        def __init__(self, n_classes):
            self.n_classes_ = n_classes

        def predict_proba(self, X):
            n_rows = X.shape[0]
            raw = np.random.rand(n_rows, self.n_classes_).astype(np.float32)
            row_totals = raw.sum(axis=1, keepdims=True)
            return raw / row_totals

    # Evaluate with one |p_c - 1{y=c}| score per class
    results = eval_grcp_multiclass(
        X_cal=X_cal,
        y_cal=y_cal,
        X_test=X_test,
        y_test=y_test,
        base_clf=DummyClassifier(n_classes),
        scoring_functions=build_softmax_diff_scoring_functions(n_classes),
        alpha=0.1,
        sigma=1.0,
        pca_dim=5,
        verbose=True,
    )

    coverage = results['coverage']
    target = results['target_coverage']
    avg_size = results['average_set_size']

    print(f"\nResults:")
    print(f"  Coverage: {coverage:.4f} (target: {target:.4f})")
    print(f"  Avg set size: {avg_size:.2f}")
    print("\nTest passed!")
