import torch
import numpy as np
from scipy.stats import pearsonr, spearmanr, kendalltau


def calculate_nmae(predictions, targets):
    """
    Calculate Normalized Mean Absolute Error (NMAE).

    The mean absolute error is divided by the range (max - min) of the
    targets.

    Args:
        predictions (torch.Tensor or array-like): Predicted values
        targets (torch.Tensor or array-like): Ground truth values

    Returns:
        float: NMAE value. 0.0 is returned when all targets are equal
            (zero range), regardless of the size of the error.

    Raises:
        ValueError: If targets is empty.
    """
    # Convert to numpy if tensors; anything else (lists, scalars, arrays)
    # goes through np.asarray so plain Python sequences are accepted too.
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    if targets.size == 0:
        raise ValueError("targets must contain at least one value")

    # Same number of elements but different shapes, e.g. (N, 1) vs (N,):
    # subtracting directly would broadcast to an (N, N) matrix and silently
    # produce a wrong error, so compare element by element instead.
    if predictions.shape != targets.shape and predictions.size == targets.size:
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

    # Calculate NMAE
    mae = np.mean(np.abs(predictions - targets))
    range_val = np.max(targets) - np.min(targets)

    # Handle division by zero
    if range_val == 0:
        return 0.0

    # Built-in float keeps the documented return type (and JSON
    # serializability) even for float32 inputs.
    return float(mae / range_val)


def calculate_nrms(predictions, targets):
    """
    Calculate Normalized Root Mean Square Error (NRMS).

    The root mean square error is divided by the range (max - min) of the
    targets.

    Args:
        predictions (torch.Tensor or array-like): Predicted values
        targets (torch.Tensor or array-like): Ground truth values

    Returns:
        float: NRMS value. 0.0 is returned when all targets are equal
            (zero range), regardless of the size of the error.

    Raises:
        ValueError: If targets is empty.
    """
    # Convert to numpy if tensors; anything else (lists, scalars, arrays)
    # goes through np.asarray so plain Python sequences are accepted too.
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    if targets.size == 0:
        raise ValueError("targets must contain at least one value")

    # Same number of elements but different shapes, e.g. (N, 1) vs (N,):
    # subtracting directly would broadcast to an (N, N) matrix and silently
    # produce a wrong error, so compare element by element instead.
    if predictions.shape != targets.shape and predictions.size == targets.size:
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

    # Calculate NRMS
    rmse = np.sqrt(np.mean((predictions - targets) ** 2))
    range_val = np.max(targets) - np.min(targets)

    # Handle division by zero
    if range_val == 0:
        return 0.0

    # Built-in float keeps the documented return type (and JSON
    # serializability) even for float32 inputs.
    return float(rmse / range_val)


def calculate_pearson_correlation(predictions, targets):
    """
    Calculate Pearson correlation coefficient.

    Both inputs are flattened to 1-D before the correlation is computed, so
    column or row vectors such as (N, 1) model outputs are compared element
    by element with (N,) targets.

    Args:
        predictions (torch.Tensor or array-like): Predicted values
        targets (torch.Tensor or array-like): Ground truth values

    Returns:
        float: Pearson correlation coefficient
    """
    # Convert to numpy if tensors
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    # pearsonr is only given 1-D samples here; passing (N, 1) or (1, N)
    # arrays through unchanged either errors or yields an array of
    # per-axis results depending on the SciPy version.
    predictions = np.asarray(predictions).reshape(-1)
    targets = np.asarray(targets).reshape(-1)

    # Calculate Pearson correlation
    corr, _ = pearsonr(predictions, targets)

    return float(corr)


def calculate_spearman_correlation(predictions, targets):
    """
    Calculate Spearman rank correlation coefficient.

    Both inputs are flattened to 1-D before the correlation is computed, so
    the result is always a single coefficient over all elements.

    Args:
        predictions (torch.Tensor or array-like): Predicted values
        targets (torch.Tensor or array-like): Ground truth values

    Returns:
        float: Spearman rank correlation coefficient
    """
    # Convert to numpy if tensors
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    # spearmanr treats the columns of 2-D input as separate variables, so a
    # (1, N) input would be read as N variables with one observation each
    # and return a matrix instead of a scalar. Flatten to avoid that.
    predictions = np.asarray(predictions).reshape(-1)
    targets = np.asarray(targets).reshape(-1)

    # Calculate Spearman rank correlation
    corr, _ = spearmanr(predictions, targets)

    return float(corr)


def calculate_kendall_correlation(predictions, targets):
    """
    Compute the Kendall tau rank correlation between predictions and targets.

    Args:
        predictions (torch.Tensor): Predicted values
        targets (torch.Tensor): Ground truth values

    Returns:
        float: Kendall tau rank correlation coefficient
    """
    def _as_numpy(values):
        # Tensors are detached and copied to host memory as numpy arrays;
        # any other input is handed to scipy unchanged.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    # Only the statistic is needed; the p-value is discarded.
    tau, _p_value = kendalltau(_as_numpy(predictions), _as_numpy(targets))
    return tau


def evaluate_predictions(predictions, targets):
    """
    Evaluate predictions using multiple metrics.

    Args:
        predictions (torch.Tensor): Predicted values
        targets (torch.Tensor): Ground truth values

    Returns:
        dict: Dictionary containing evaluation metrics, keyed by
            'nmae', 'nrms', 'pearson', 'spearman' and 'kendall'
    """
    # Convert tensors to numpy once here. Each metric function performs the
    # same detach/cpu/numpy conversion on tensor inputs, so passing tensors
    # through would repeat it (and any device-to-host copy) five times.
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    return {
        'nmae': calculate_nmae(predictions, targets),
        'nrms': calculate_nrms(predictions, targets),
        'pearson': calculate_pearson_correlation(predictions, targets),
        'spearman': calculate_spearman_correlation(predictions, targets),
        'kendall': calculate_kendall_correlation(predictions, targets),
    }