
# GPU support via CuPy
try:
    import cupy as cp
    import cupy.random as cp_random
    HAS_CUPY = True
    print("[WSC] CuPy disponible - GPU activé")
except ImportError:
    cp = None
    cp_random = None
    HAS_CUPY = False
    print("[WSC] CuPy non disponible - fallback CPU (numpy)")

import numpy as np
from typing import Tuple, Union, Optional

# Module actif (cp ou np)
xp = cp if HAS_CUPY else np


def to_cpu(arr) -> np.ndarray:
    """Move array to CPU."""
    if arr is None:
        return None
    if HAS_CUPY and isinstance(arr, cp.ndarray):
        return cp.asnumpy(arr)
    return np.asarray(arr)


def to_gpu(arr):
    """Move array to GPU if available."""
    if arr is None:
        return None
    if HAS_CUPY:
        if isinstance(arr, cp.ndarray):
            return arr
        return cp.asarray(arr)
    return np.asarray(arr)


def get_array_module(arr):
    """Get the array module (cupy or numpy) for the given array."""
    if HAS_CUPY and isinstance(arr, cp.ndarray):
        return cp
    return np


def ensure_gpu_array(arr, dtype=None):
    """Ensure array is on GPU with optional dtype conversion."""
    arr_gpu = to_gpu(arr)
    if dtype is not None:
        arr_gpu = arr_gpu.astype(dtype)
    return arr_gpu


def wsc_pour_direction_gpu(coverage_status, projections, delta: float) -> float:
    """
    Calcule le Worst Slab Coverage (WSC) pour une direction de projection donnée (GPU).

    Entièrement vectorisé sur GPU pour performance maximale.

    Args:
        coverage_status: Array booléen (N,) où True indique une couverture réussie (GPU ou CPU).
        projections: Array float (N,) des données projetées sur la direction 'v' (GPU ou CPU).
        delta: Proportion minimale des données requise dans chaque "slab".

    Returns:
        Le WSC minimum trouvé pour cette direction.
    """
    # S'assurer que les données sont sur GPU
    coverage_gpu = ensure_gpu_array(coverage_status, dtype=xp.float32)
    proj_gpu = ensure_gpu_array(projections, dtype=xp.float32)
    
    N = len(proj_gpu)
    N_min = int(np.ceil(delta * N))

    if N_min <= 0 or N_min > N:
        return 1.0

    # Triage selon la projection (GPU)
    indices_tries = xp.argsort(proj_gpu)
    sorted_coverage = coverage_gpu[indices_tries]
    
    # Somme cumulée pour calcul rapide des slabs (GPU)
    succes_cumules = xp.cumsum(sorted_coverage)

    # Calcul vectorisé du WSC minimum
    n_slabs = N - N_min + 1
    
    if n_slabs <= 0:
        return 1.0
    
    # Nombre de succès pour chaque slab via différences de cumsum
    # slab[i] = sum(coverage[i:i+N_min]) = cumsum[i+N_min-1] - cumsum[i-1]
    succes_at_end = succes_cumules[N_min - 1:]  # cumsum aux positions de fin
    
    # Construire cumsum aux positions i-1 (avec 0 pour i=0)
    succes_at_start = xp.zeros(n_slabs, dtype=xp.float32)
    if n_slabs > 1:
        succes_at_start[1:] = succes_cumules[:n_slabs - 1]
    
    # Différence vectorisée
    nombre_succes = succes_at_end - succes_at_start
    
    # Couverture de chaque slab
    couvertures = nombre_succes / N_min
    
    # WSC minimum (transfert minimal vers CPU pour le résultat final)
    wsc_min = float(xp.min(couvertures))
    
    return wsc_min


def wsc_batch_directions_gpu(coverage_status, all_projections, delta: float) -> np.ndarray:
    """
    Calcule le WSC pour plusieurs directions en parallèle (GPU optimisé).

    Args:
        coverage_status: Array booléen (N,) de couverture.
        all_projections: Array (N, M) de projections pour M directions.
        delta: Proportion minimale des données requise.

    Returns:
        Array (M,) des WSC pour chaque direction.
    """
    coverage_gpu = ensure_gpu_array(coverage_status, dtype=xp.float32)
    proj_gpu = ensure_gpu_array(all_projections, dtype=xp.float32)
    
    N, M = proj_gpu.shape
    N_min = int(np.ceil(delta * N))
    
    if N_min <= 0 or N_min > N:
        return np.ones(M)
    
    n_slabs = N - N_min + 1
    if n_slabs <= 0:
        return np.ones(M)
    
    # Tri de chaque colonne indépendamment
    indices_sorted = xp.argsort(proj_gpu, axis=0)  # (N, M)
    
    # Réorganiser coverage selon les indices triés pour chaque direction
    # coverage_sorted[i, m] = coverage[indices_sorted[i, m]]
    row_indices = indices_sorted.flatten()
    col_indices = xp.repeat(xp.arange(M), N)
    coverage_expanded = xp.broadcast_to(coverage_gpu[:, None], (N, M))
    
    # Utiliser advanced indexing pour trier
    sorted_coverage = coverage_gpu[indices_sorted]  # (N, M)
    
    # Cumsum par colonne
    cumsum = xp.cumsum(sorted_coverage, axis=0)  # (N, M)
    
    # Calculer WSC pour chaque direction
    # succes_at_end = cumsum[N_min-1:, :] shape (n_slabs, M)
    succes_at_end = cumsum[N_min - 1:, :]
    
    # succes_at_start avec padding de zéros
    succes_at_start = xp.zeros((n_slabs, M), dtype=xp.float32)
    if n_slabs > 1:
        succes_at_start[1:, :] = cumsum[:n_slabs - 1, :]
    
    # Différences
    nombre_succes = succes_at_end - succes_at_start  # (n_slabs, M)
    couvertures = nombre_succes / N_min
    
    # Min par colonne
    wsc_per_dir = xp.min(couvertures, axis=0)  # (M,)
    
    return to_cpu(wsc_per_dir)


def calculer_wsc_gpu(X_test, Y_test, C_test: list, delta: float, M: int, 
                     random_state: Optional[int] = None) -> list:
    """
    Calcule la distribution du WSC en échantillonnant M directions (GPU optimisé).

    Args:
        X_test: Caractéristiques des données de test (N, d).
        Y_test: Labels vrais (N,).
        C_test: Liste des ensembles de confiance prédits, C_test[i] est un set des classes.
        delta: Proportion minimale des données requise dans chaque "slab".
        M: Nombre de directions aléatoires à échantillonner.
        random_state: Graine pour reproductibilité.

    Returns:
        Liste des valeurs WSC pour chaque direction échantillonnée.
    """
    # Conversion en arrays (sur CPU d'abord pour C_test qui est une liste Python)
    X_test = np.asarray(X_test, dtype=np.float32)
    Y_test = np.asarray(Y_test)
    N, d = X_test.shape
    
    # Statut de couverture (CPU car C_test est une liste de sets Python)
    coverage_status = np.array([Y_test[i] in C_test[i] for i in range(N)], dtype=np.float32)
    
    # Transfert sur GPU
    X_gpu = to_gpu(X_test)
    coverage_gpu = to_gpu(coverage_status)
    
    # Générer M directions aléatoires sur GPU
    if random_state is not None:
        if HAS_CUPY:
            cp.random.seed(random_state)
        np.random.seed(random_state)
    
    if HAS_CUPY:
        V = cp.random.randn(d, M).astype(cp.float32)
        norms = cp.linalg.norm(V, axis=0, keepdims=True)
        V = V / norms
    else:
        V = np.random.randn(d, M).astype(np.float32)
        norms = np.linalg.norm(V, axis=0, keepdims=True)
        V = V / norms
    
    # Toutes les projections en une opération matricielle : (N, d) @ (d, M) -> (N, M)
    all_projections = X_gpu @ V
    
    # Calculer WSC batch pour toutes les directions
    wsc_results = wsc_batch_directions_gpu(coverage_gpu, all_projections, delta)
    
    return list(wsc_results)


def calculer_wsc_gpu_batch(X_test, Y_test, C_test: list, delta: float, M: int,
                          random_state: Optional[int] = None) -> list:
    """
    Version batch ultra-optimisée : projections ET WSC calculés en batch sur GPU.

    Args:
        X_test: Caractéristiques des données de test (N, d).
        Y_test: Labels vrais (N,).
        C_test: Liste des ensembles de confiance prédits.
        delta: Proportion minimale des données requise dans chaque "slab".
        M: Nombre de directions aléatoires à échantillonner.
        random_state: Graine pour reproductibilité.

    Returns:
        Liste des valeurs WSC pour chaque direction échantillonnée.
    """
    # Cette fonction utilise maintenant la même implémentation optimisée
    return calculer_wsc_gpu(X_test, Y_test, C_test, delta, M, random_state)


def calculer_wsc_regression_gpu(X_test, coverage_status, delta: float, M: int,
                                random_state: Optional[int] = None) -> Tuple[float, list]:
    """
    Calcule le WSC pour des tâches de régression (coverage_status pré-calculé).

    Optimisé GPU - Évite le calcul du coverage_status à partir de sets Python.

    Args:
        X_test: Caractéristiques de test (N, d).
        coverage_status: Array booléen (N,) indiquant si y_true est dans l'intervalle prédit.
        delta: Proportion minimale des données dans chaque slab.
        M: Nombre de directions aléatoires.
        random_state: Graine pour reproductibilité.

    Returns:
        (wsc_min, wsc_list): WSC minimum et liste des WSC par direction.
    """
    X_test = np.asarray(X_test, dtype=np.float32)
    coverage_status = np.asarray(coverage_status, dtype=np.float32)
    N, d = X_test.shape
    
    # Transfert sur GPU
    X_gpu = to_gpu(X_test)
    coverage_gpu = to_gpu(coverage_status)
    
    # Générer M directions aléatoires sur GPU
    if random_state is not None:
        if HAS_CUPY:
            cp.random.seed(random_state)
        np.random.seed(random_state)
    
    if HAS_CUPY:
        V = cp.random.randn(d, M).astype(cp.float32)
        norms = cp.linalg.norm(V, axis=0, keepdims=True)
        V = V / norms
    else:
        V = np.random.randn(d, M).astype(np.float32)
        norms = np.linalg.norm(V, axis=0, keepdims=True)
        V = V / norms
    
    # Projections batch
    all_projections = X_gpu @ V  # (N, M)
    
    # WSC batch
    wsc_results = wsc_batch_directions_gpu(coverage_gpu, all_projections, delta)
    wsc_min = float(np.min(wsc_results))
    
    return wsc_min, list(wsc_results)


# Alias pour compatibilité avec l'API non-GPU
calculer_wsc = calculer_wsc_gpu
calculer_wsc_batch = calculer_wsc_gpu_batch
calculer_wsc_regression = calculer_wsc_regression_gpu


def worst_slab_coverage_multireg_random_dirs_gpu(
    coverage_status,
    Z,
    M: int,
    delta: float,
    min_points: int = 30,
    random_state: Optional[int] = None
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
    """
    Calcule le WSC en multi-régression en tirant M directions aléatoires (GPU optimisé).

    Paramètres
    ----------
    coverage_status : (n_samples,)
        Indicateur binaire de couverture.
    Z : (n_samples, d_z)
        Features / embedding utilisés pour définir les slabs.
    M : int
        Nombre de directions aléatoires.
    delta : float
        Demi-largeur du slab.
    min_points : int
        Nb minimal de points dans un slab pour le prendre en compte.
    random_state : int ou None
        Graine pour la reproductibilité.

    Renvoie
    -------
    wsc : float
        Worst slab coverage (min sur les directions valides).
    coverages_per_dir : (M,)
        Couverture par direction (NaN si slab trop peu peuplé).
    counts_per_dir : (M,)
        Nombre de points dans chaque slab.
    directions_unit : (M, d_z)
        Les directions aléatoires normalisées utilisées.
    """
    # Conversion et transfert GPU
    Z_gpu = ensure_gpu_array(Z, dtype=xp.float32)
    coverage_gpu = ensure_gpu_array(coverage_status, dtype=xp.float32)
    n_samples, d_z = Z_gpu.shape

    # Graine pour reproductibilité
    if random_state is not None:
        if HAS_CUPY:
            cp.random.seed(random_state)
        np.random.seed(random_state)

    # Générer M directions aléatoires sur GPU
    if HAS_CUPY:
        directions = cp.random.randn(M, d_z).astype(cp.float32)
        norms = cp.linalg.norm(directions, axis=1, keepdims=True)
        directions_unit = directions / norms
    else:
        rng = np.random.default_rng(random_state)
        directions = rng.normal(size=(M, d_z)).astype(np.float32)
        norms = np.linalg.norm(directions, axis=1, keepdims=True)
        directions_unit = directions / norms

    # Appel de la fonction principale GPU
    wsc, coverages_per_dir, counts_per_dir = worst_slab_coverage_multireg_gpu(
        coverage_status=coverage_gpu,
        Z=Z_gpu,
        directions=directions_unit,
        delta=delta,
        min_points=min_points,
    )

    return wsc, to_cpu(coverages_per_dir), to_cpu(counts_per_dir), to_cpu(directions_unit)


# Alias pour compatibilité
worst_slab_coverage_multireg_random_dirs = worst_slab_coverage_multireg_random_dirs_gpu


def worst_slab_coverage_multireg_gpu(
    coverage_status,
    Z,
    directions,
    delta: float,
    min_points: int = 30
) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Calcule le Worst Slab Coverage (WSC) en multi-régression (GPU optimisé).

    Entièrement vectorisé sur GPU pour performance maximale.

    Args:
        coverage_status: Array (n_samples,) indicateur binaire de couverture.
        Z: Array (n_samples, d_z) features pour définir les slabs.
        directions: Array (n_dirs, d_z) directions normalisées.
        delta: Demi-largeur du slab.
        min_points: Nb minimal de points dans un slab.

    Returns:
        (wsc, coverages_per_dir, counts_per_dir)
    """
    # Transfert sur GPU
    coverage_gpu = ensure_gpu_array(coverage_status, dtype=xp.float32)
    Z_gpu = ensure_gpu_array(Z, dtype=xp.float32)
    directions_gpu = ensure_gpu_array(directions, dtype=xp.float32)

    n_samples = Z_gpu.shape[0]
    
    if coverage_gpu.shape[0] != n_samples:
        raise ValueError("coverage_status et Z doivent avoir le même nombre d'exemples.")

    if directions_gpu.ndim == 1:
        directions_gpu = directions_gpu.reshape(1, -1)

    # Normalisation des directions (GPU)
    norms = xp.linalg.norm(directions_gpu, axis=1, keepdims=True)
    if xp.any(norms == 0):
        raise ValueError("Certaines directions ont une norme nulle.")
    directions_unit = directions_gpu / norms

    n_dirs = directions_unit.shape[0]
    
    # Calcul vectorisé des projections : (n_samples, d_z) @ (d_z, n_dirs) -> (n_samples, n_dirs)
    all_projections = Z_gpu @ directions_unit.T  # (n_samples, n_dirs)
    
    # Masque pour les slabs : |proj| <= delta
    masks = xp.abs(all_projections) <= delta  # (n_samples, n_dirs)
    
    # Compter les points dans chaque slab
    counts_per_dir = xp.sum(masks, axis=0)  # (n_dirs,)
    
    # Calculer la couverture pour chaque direction
    # coverage_sum[k] = sum(coverage[i] for i where mask[i, k])
    coverage_expanded = coverage_gpu[:, None]  # (n_samples, 1)
    coverage_in_slab = masks * coverage_expanded  # (n_samples, n_dirs)
    coverage_sums = xp.sum(coverage_in_slab, axis=0)  # (n_dirs,)
    
    # Éviter division par zéro
    counts_safe = xp.maximum(counts_per_dir, 1)
    coverages_per_dir = coverage_sums / counts_safe  # (n_dirs,)
    
    # Marquer les directions avec trop peu de points comme NaN
    invalid_mask = counts_per_dir < min_points
    coverages_per_dir = xp.where(invalid_mask, xp.nan, coverages_per_dir)
    
    # Vérifier qu'au moins une direction est valide
    valid_count = xp.sum(~invalid_mask)
    if int(valid_count) == 0:
        raise ValueError(
            f"Aucune direction n'a suffisamment de points dans le slab "
            f"(min_points={min_points}). Augmente delta ou réduis min_points."
        )
    
    # WSC = minimum des couvertures valides
    wsc = float(xp.nanmin(coverages_per_dir))
    
    return wsc, to_cpu(coverages_per_dir), to_cpu(counts_per_dir).astype(int)


# Alias pour compatibilité
worst_slab_coverage_multireg = worst_slab_coverage_multireg_gpu
