"""Optimal transport utilities for ARCOS experiments."""

import torch
import numpy as np
from typing import Tuple, Optional
from sklearn.metrics.pairwise import euclidean_distances


def _projected_quantile_gaps(X, Y, num_projections: int, device) -> torch.Tensor:
    """Project two feature sets onto shared random directions and diff their quantiles.

    Shared implementation of the ``*_fast`` estimators in this module.

    Args:
        X: First set of features (N, D); tensor or array-like.
        Y: Second set of features (M, D); tensor or array-like.
        num_projections: Number of random unit directions to draw.
        device: Device the computation runs on.

    Returns:
        Tensor of shape (100, num_projections) holding, for every direction,
        the projected quantiles of X minus those of Y at the levels
        ``linspace(0.01, 0.99, 100)`` (``torch.quantile`` linear interpolation).

    Raises:
        ValueError: If an input is not 2-D, is empty, or the feature
            dimensions of X and Y differ.
    """
    # The result is consumed via .item(); no autograd graph is needed.
    with torch.no_grad():
        # as_tensor is a no-op for float32 tensors already on `device`, and it
        # also casts float64/float16 tensors, which would otherwise fail in the
        # matmul against the float32 projections.
        X = torch.as_tensor(X, dtype=torch.float32, device=device)
        Y = torch.as_tensor(Y, dtype=torch.float32, device=device)
        if X.ndim != 2 or Y.ndim != 2:
            raise ValueError(
                f"Expected 2-D feature matrices, got shapes "
                f"{tuple(X.shape)} and {tuple(Y.shape)}"
            )
        if X.shape[1] != Y.shape[1]:
            raise ValueError(
                f"Feature dimensions differ: {X.shape[1]} vs {Y.shape[1]}"
            )
        if X.shape[0] == 0 or Y.shape[0] == 0:
            raise ValueError("Feature matrices must contain at least one sample")

        # A private generator seeded with 42 draws the same directions as the
        # former global torch.manual_seed(42) call, but leaves the caller's
        # global RNG state untouched.
        generator = torch.Generator(device=X.device)
        generator.manual_seed(42)
        projections = torch.randn(
            X.shape[1],
            num_projections,
            generator=generator,
            device=X.device,
            dtype=torch.float32,
        )
        projections = projections / torch.norm(projections, dim=0, keepdim=True)

        # Compare quantile functions on a fixed grid so that N and M may differ.
        # The extreme 1% tails on either side are excluded by the grid.
        # torch.quantile sorts internally, so no explicit sort is required.
        levels = torch.linspace(0.01, 0.99, 100, device=X.device, dtype=torch.float32)
        x_quantiles = torch.quantile(X @ projections, levels, dim=0)
        y_quantiles = torch.quantile(Y @ projections, levels, dim=0)
        return x_quantiles - y_quantiles


def compute_sliced_wasserstein_fast(
    X: torch.Tensor,
    Y: torch.Tensor,
    num_projections: int = 256,
    p: int = 1,
    device: str = "cuda"
) -> float:
    """Compute sliced Wasserstein distance using PyTorch (GPU-accelerated).

    Each direction's 1-D distance is estimated from 100 quantile levels in
    [0.01, 0.99], so the tails are excluded and N may differ from M. The
    global torch RNG is not modified.

    Args:
        X: First set of features (N, D); tensor or array-like of any float dtype
        Y: Second set of features (M, D)
        num_projections: Number of random projections
        p: Order of Wasserstein distance (1 or 2)
        device: Device to use for computation

    Returns:
        Mean over projections of the per-projection 1-D W_p estimate

    Raises:
        ValueError: If ``p`` is not 1 or 2, or the inputs are malformed.
    """
    # Validate before doing any (possibly GPU) work.
    if p not in (1, 2):
        raise ValueError(f"Unsupported p value: {p}")

    gaps = _projected_quantile_gaps(X, Y, num_projections, device)
    if p == 1:
        distances = torch.mean(torch.abs(gaps), dim=0)  # (num_projections,)
    else:
        distances = torch.sqrt(torch.mean(gaps ** 2, dim=0))  # (num_projections,)
    return torch.mean(distances).item()


def compute_max_sliced_wasserstein_fast(
    X: torch.Tensor,
    Y: torch.Tensor,
    K: int = 256,
    device: str = "cuda"
) -> float:
    """Compute max-sliced Wasserstein distance using PyTorch (GPU-accelerated).

    Uses the same random directions and quantile-grid W1 estimate as
    ``compute_sliced_wasserstein_fast`` (for equal K), but returns the maximum
    over directions instead of the mean. The global torch RNG is not modified.

    Args:
        X: First set of features (N, D); tensor or array-like of any float dtype
        Y: Second set of features (M, D)
        K: Number of random projections
        device: Device to use for computation

    Returns:
        Max over projections of the per-projection 1-D W1 estimate

    Raises:
        ValueError: If the inputs are malformed.
    """
    gaps = _projected_quantile_gaps(X, Y, K, device)
    distances = torch.mean(torch.abs(gaps), dim=0)  # (K,)
    return torch.max(distances).item()


def _project_and_sort(X, Y, num_projections: int):
    """Project two feature sets onto shared random unit directions and sort them.

    Shared implementation of the NumPy sliced estimators in this module.

    Args:
        X: First set of features (N, D); torch tensor or array-like.
        Y: Second set of features (M, D); torch tensor or array-like.
        num_projections: Number of random unit directions to draw.

    Returns:
        Tuple ``(x_sorted, y_sorted)`` of shapes (N, num_projections) and
        (M, num_projections); every column is sorted ascending.

    Raises:
        ValueError: If an input is not 2-D, is empty, or the feature
            dimensions differ.
    """
    if isinstance(X, torch.Tensor):
        X = X.detach().cpu().numpy()
    if isinstance(Y, torch.Tensor):
        Y = Y.detach().cpu().numpy()
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(
            f"Expected 2-D feature matrices, got shapes {X.shape} and {Y.shape}"
        )
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"Feature dimensions differ: {X.shape[1]} vs {Y.shape[1]}")
    if X.shape[0] == 0 or Y.shape[0] == 0:
        raise ValueError("Feature matrices must contain at least one sample")

    # A private RandomState(42) yields the same stream as the former global
    # np.random.seed(42) + np.random.randn, without resetting the global RNG.
    projections = np.random.RandomState(42).randn(X.shape[1], num_projections)
    projections = projections / np.linalg.norm(projections, axis=0, keepdims=True)
    return np.sort(X @ projections, axis=0), np.sort(Y @ projections, axis=0)


def _wasserstein_1d_per_column(x_sorted, y_sorted, p: int):
    """Exact 1-D W_p between uniform empirical distributions, column by column.

    Args:
        x_sorted: (N, K) array, each column sorted ascending.
        y_sorted: (M, K) array, each column sorted ascending.
        p: 1 or 2.

    Returns:
        (K,) array of per-column W_p distances.
    """
    n, m = len(x_sorted), len(y_sorted)
    if n == m:
        # Equal sizes: the optimal 1-D coupling pairs the i-th smallest values.
        gaps = x_sorted - y_sorted
        if p == 1:
            return np.mean(np.abs(gaps), axis=0)
        return np.sqrt(np.mean(gaps ** 2, axis=0))

    # Unequal sizes: W_p^p = integral over t in (0, 1] of |F^-1(t) - G^-1(t)|^p.
    # Both empirical quantile functions are step functions that only jump at
    # multiples of 1/n or 1/m, so the merged breakpoints split (0, 1] into
    # intervals on which both are constant.
    levels = np.union1d(np.arange(1, n + 1) / n, np.arange(1, m + 1) / m)
    widths = np.diff(levels, prepend=0.0)
    # Index each interval by its midpoint, which is strictly inside it and so
    # robust to rounding at the breakpoints: floor(t * n) is the 0-based rank.
    midpoints = levels - widths / 2
    x_idx = np.minimum((midpoints * n).astype(np.intp), n - 1)
    y_idx = np.minimum((midpoints * m).astype(np.intp), m - 1)
    gaps = x_sorted[x_idx] - y_sorted[y_idx]  # (num_intervals, K)
    if p == 1:
        return widths @ np.abs(gaps)
    return np.sqrt(widths @ gaps ** 2)


def compute_sliced_wasserstein(
    X: torch.Tensor,
    Y: torch.Tensor,
    num_projections: int = 256,
    p: int = 1
) -> float:
    """Compute sliced Wasserstein distance.

    For every random direction the exact 1-D W_p between the projected
    empirical distributions is computed (N and M may differ); the results are
    averaged. The global NumPy RNG is not modified.

    Args:
        X: First set of features (N, D)
        Y: Second set of features (M, D)
        num_projections: Number of random projections
        p: Order of Wasserstein distance (1 or 2)

    Returns:
        Sliced Wasserstein distance (mean over projections)

    Raises:
        ValueError: If ``p`` is not 1 or 2, or the inputs are malformed.
    """
    # Validate before projecting/sorting.
    if p not in (1, 2):
        raise ValueError(f"Unsupported p value: {p}")
    x_sorted, y_sorted = _project_and_sort(X, Y, num_projections)
    return np.mean(_wasserstein_1d_per_column(x_sorted, y_sorted, p))


def compute_max_sliced_wasserstein(
    X: torch.Tensor,
    Y: torch.Tensor,
    K: int = 256
) -> float:
    """Compute max-sliced Wasserstein distance.

    Uses the same random directions as ``compute_sliced_wasserstein`` (for the
    same count) and returns the largest exact 1-D W1 among them. N and M may
    differ. The global NumPy RNG is not modified.

    Args:
        X: First set of features (N, D)
        Y: Second set of features (M, D)
        K: Number of random projections

    Returns:
        Max-sliced Wasserstein distance (max over projections)

    Raises:
        ValueError: If the inputs are malformed or ``K`` is 0.
    """
    x_sorted, y_sorted = _project_and_sort(X, Y, K)
    return np.max(_wasserstein_1d_per_column(x_sorted, y_sorted, 1))


def compute_sinkhorn_w1(
    X: torch.Tensor,
    Y: torch.Tensor,
    reg: float = 0.01,
    max_iter: int = 1000,
    method: str = "sinkhorn",
) -> float:
    """Compute Sinkhorn W1 distance.

    Solves entropy-regularized optimal transport between the uniform empirical
    measures on X and Y with a Euclidean ground cost, and returns the transport
    cost ``sum(P * C)`` of the resulting plan P as the W1 estimate.

    If the POT library (``ot``) is not installed, a ``RuntimeWarning`` is
    emitted and ``compute_max_sliced_wasserstein(X, Y)`` is returned instead
    (``reg``, ``max_iter`` and ``method`` are then unused).

    Args:
        X: First set of features (N, D)
        Y: Second set of features (M, D)
        reg: Entropic regularization strength
        max_iter: Maximum Sinkhorn iterations
        method: Solver name forwarded to ``ot.sinkhorn``. The default matches
            the previous behavior; with a small ``reg`` relative to the cost
            scale, the kernel exp(-C / reg) may underflow, in which case
            ``"sinkhorn_log"`` (log-domain iterations) is the safer choice.

    Returns:
        Sinkhorn W1 distance
    """
    import warnings

    try:
        import ot
    except ImportError:
        # Library code should not print; a warning can be filtered or captured
        # and makes the metric substitution visible to the caller.
        warnings.warn(
            "POT library not available, falling back to max-sliced Wasserstein",
            RuntimeWarning,
            stacklevel=2,
        )
        return compute_max_sliced_wasserstein(X, Y)

    if isinstance(X, torch.Tensor):
        X = X.detach().cpu().numpy()
    if isinstance(Y, torch.Tensor):
        Y = Y.detach().cpu().numpy()

    # Pairwise Euclidean ground cost, shape (N, M).
    cost_matrix = euclidean_distances(X, Y)

    # Uniform weights on both empirical distributions.
    a = np.full(X.shape[0], 1.0 / X.shape[0])
    b = np.full(Y.shape[0], 1.0 / Y.shape[0])

    plan = ot.sinkhorn(a, b, cost_matrix, reg, method=method, numItermax=max_iter)

    # Transport cost of the regularized plan.
    return np.sum(plan * cost_matrix)


def compute_w1_feature_space(
    source_features: torch.Tensor,
    target_features: torch.Tensor,
    method: str = "max-sliced",
    device: str = "cuda",
    **kwargs
) -> float:
    """Compute W1 distance in feature space.

    Args:
        source_features: Source domain features (N, D)
        target_features: Target domain features (M, D)
        method: Method to use (max-sliced, sliced, sinkhorn, max-sliced-fast, sliced-fast)
        device: Device used by the "-fast" (PyTorch) methods; ignored otherwise
        **kwargs: Additional arguments forwarded to the selected method. For
            the projection-based methods the projection count may be spelled
            either ``K`` or ``num_projections``; it is renamed to the keyword
            the target function expects (``K`` for the max-sliced variants,
            ``num_projections`` for the sliced variants).

    Returns:
        W1 distance estimate

    Raises:
        ValueError: If ``method`` is unknown.
        TypeError: If both ``K`` and ``num_projections`` are passed for a
            projection-based method.
    """
    # Keyword each projection-based method uses for its projection count.
    projection_keyword = {
        "max-sliced": "K",
        "max-sliced-fast": "K",
        "sliced": "num_projections",
        "sliced-fast": "num_projections",
    }.get(method)

    if projection_keyword is not None:
        alias = "num_projections" if projection_keyword == "K" else "K"
        if alias in kwargs:
            if projection_keyword in kwargs:
                raise TypeError("Pass only one of 'K' and 'num_projections', not both")
            # **kwargs is a fresh dict per call, so renaming in place is safe.
            kwargs[projection_keyword] = kwargs.pop(alias)

    if method == "max-sliced":
        return compute_max_sliced_wasserstein(source_features, target_features, **kwargs)
    elif method == "sliced":
        return compute_sliced_wasserstein(source_features, target_features, **kwargs)
    elif method == "sinkhorn":
        return compute_sinkhorn_w1(source_features, target_features, **kwargs)
    elif method == "max-sliced-fast":
        return compute_max_sliced_wasserstein_fast(source_features, target_features, device=device, **kwargs)
    elif method == "sliced-fast":
        return compute_sliced_wasserstein_fast(source_features, target_features, device=device, **kwargs)
    else:
        raise ValueError(f"Unknown method: {method}")


def normalize_features(features: torch.Tensor, method: str = "l2") -> torch.Tensor:
    """Normalize each row (dim 1) of a feature tensor.

    Args:
        features: Feature tensor; statistics are taken along dim 1, i.e. per row
        method: "l2" (unit L2 norm), "minmax" (rescale to [0, 1]) or
            "zscore" (zero mean, unit unbiased std)

    Returns:
        Tensor of the same shape with each row normalized

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Floor applied to every denominator so zero/constant rows don't divide by zero.
    eps = 1e-8

    if method == "l2":
        row_norm = features.norm(p=2, dim=1, keepdim=True).clamp(min=eps)
        return features / row_norm

    if method == "minmax":
        lo = features.min(dim=1, keepdim=True).values
        hi = features.max(dim=1, keepdim=True).values
        return (features - lo) / (hi - lo).clamp(min=eps)

    if method == "zscore":
        centre = features.mean(dim=1, keepdim=True)
        spread = features.std(dim=1, keepdim=True).clamp(min=eps)
        return (features - centre) / spread

    raise ValueError(f"Unknown normalization method: {method}")


def compute_feature_statistics(features: torch.Tensor) -> dict:
    """Compute basic statistics of a feature matrix.

    Args:
        features: Feature matrix (N, D); torch tensor or array-like

    Returns:
        Dictionary with:
            'mean', 'std', 'min', 'max': per-dimension statistics over the
                N samples (axis 0).
            'l2_norm': L2 norm of every sample (axis 1).
            'cosine_similarity': mean pairwise dot product over all (i, j)
                sample pairs, including i == j. This equals the mean cosine
                similarity only when the rows are L2-normalized.
    """
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()

    mean = np.mean(features, axis=0)
    stats = {
        'mean': mean,
        'std': np.std(features, axis=0),
        'min': np.min(features, axis=0),
        'max': np.max(features, axis=0),
        'l2_norm': np.linalg.norm(features, axis=1),
        # mean_{i,j} <f_i, f_j> == ||mean_i f_i||^2, so the N x N Gram matrix
        # (O(N^2) memory and time) never needs to be materialized.
        'cosine_similarity': np.dot(mean, mean),
    }

    return stats
