import torch
import numpy as np

def feature_correlations(X, device='cpu'):
    """
    Compute feature-feature Pearson correlations for the input data.

    Args:
        X (numpy.ndarray): Input data of shape (n_samples, n_features).
        device (str or torch.device): Device to perform the computation ('cpu' or 'cuda').

    Returns:
        numpy.ndarray: float32 correlation matrix of shape (n_features, n_features).
        Any entry involving a zero-variance feature (including that feature's
        diagonal entry) is 0 rather than NaN; with a single sample every entry
        is 0. All entries are clamped to [-1, 1].
    """
    # Move data to the specified device
    X = torch.tensor(X, dtype=torch.float32, device=device)

    # Center each feature (column) so the cross-product below is a covariance
    X_centered = X - X.mean(dim=0, keepdim=True)

    # Unbiased (n - 1) covariance; the normalisation cancels in the ratio below
    covariance = torch.mm(X_centered.T, X_centered) / (X.size(0) - 1)

    # Per-feature standard deviations and their outer product
    std_devs = torch.sqrt(torch.diag(covariance))
    std_matrix = std_devs[:, None] * std_devs[None, :]

    correlation_matrix = covariance / std_matrix
    # Zero-variance features (or n_samples == 1) yield 0/0 = NaN; report as 0.
    correlation_matrix[torch.isnan(correlation_matrix)] = 0
    # Pearson r is bounded by 1 in exact arithmetic, but float32 rounding can
    # produce values such as 1.0000001, which break arccos/arctanh downstream.
    correlation_matrix = correlation_matrix.clamp(-1.0, 1.0)

    # Convert to NumPy and return
    return correlation_matrix.cpu().numpy()


def concordance_index(A: np.ndarray, B: np.ndarray, device: str = 'cpu') -> float:
    """
    Computes a global concordance index between two NumPy matrices by comparing
    all pairwise feature orderings across all samples using PyTorch backend.

    For every sample (row) s and every feature pair (i, j) with i < j, the sign
    of ``A[s, i] - A[s, j]`` is compared with the sign of ``B[s, i] - B[s, j]``.
    Comparisons where either difference is zero (a tie) or NaN are skipped.
    The result is ``(concordant - discordant) / (concordant + discordant)``,
    pooled over all samples and feature pairs.

    Parameters:
        A (np.ndarray): shape (n_samples, n_features), or (n_features,) for a
            single sample.
        B (np.ndarray): same shape as ``A``.
        device (str): 'cpu' or 'cuda'

    Returns:
        float: Concordance index in [-1, 1], or NaN when no comparison
        survives the tie/NaN filter (e.g. fewer than two features).

    Raises:
        ValueError: if ``A`` and ``B`` do not have the same shape.
    """
    A = torch.tensor(A, dtype=torch.float32, device=device)
    B = torch.tensor(B, dtype=torch.float32, device=device)

    # Support 1D input by unsqueezing to (1, n_features)
    if A.ndim == 1:
        A = A.unsqueeze(0)
    if B.ndim == 1:
        B = B.unsqueeze(0)

    # Unpacking also enforces that A is 2-D.
    _, n_features = A.shape
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if B.shape != A.shape:
        raise ValueError(
            "A and B must have the same shape, "
            f"got {tuple(A.shape)} and {tuple(B.shape)}"
        )

    # Create all (i, j) index pairs for features where i < j
    i_idx, j_idx = torch.triu_indices(n_features, n_features, offset=1, device=A.device)

    # Compute differences for each pair (broadcast over samples)
    A_diff = A[:, i_idx] - A[:, j_idx]  # shape: (n_samples, n_pairs)
    B_diff = B[:, i_idx] - B[:, j_idx]  # shape: (n_samples, n_pairs)

    # A comparison counts only if both differences are non-zero (no tie) and
    # not NaN. `NaN != 0` is True, so without the isnan checks a NaN difference
    # would pass the tie filter and be tallied as concordant or discordant.
    valid = (
        (A_diff != 0)
        & (B_diff != 0)
        & ~torch.isnan(A_diff)
        & ~torch.isnan(B_diff)
    )

    # Sign agreement
    sign_agreement = torch.sign(A_diff) == torch.sign(B_diff)

    # Keep only valid comparisons
    concordant = (sign_agreement & valid).sum()
    discordant = (~sign_agreement & valid).sum()
    total = valid.sum()

    if total == 0:
        # Plain float, consistent with the annotated return type and the
        # `.item()` result of the normal path.
        return float('nan')

    return ((concordant - discordant).float() / total).item()