import torch
import math


@torch.no_grad()
def calculate_kendall_tau(X_target, y_true, y_pred_mean, y_pred_cov, device=None):
    """Compute the expected Kendall tau-a of a Gaussian prediction against targets.

    For every unordered pair ``i < j`` the probability that the prediction
    ranks item ``i`` above item ``j`` is

        p_ij = Phi((mu_i - mu_j) / sqrt(var_i + var_j - 2 * cov_ij)),

    where ``Phi`` is the standard normal CDF. A pair counts as concordant
    with probability ``p_ij`` when ``y_true[i] > y_true[j]`` and with
    probability ``1 - p_ij`` when ``y_true[i] < y_true[j]``. Pairs tied in
    ``y_true`` contribute to neither count. The result is
    ``(E[Nc] - E[Nd]) / (N * (N - 1) / 2)``. Tied pairs still count in the
    denominator, so this is tau-a, not tau-b.

    All N * (N - 1) / 2 pairs are materialized at once, so time and memory
    are O(N^2). Gradients are not tracked, because the result is returned
    as a plain Python float.

    Args:
        X_target: Unused; accepted only for interface compatibility.
        y_true: Ground-truth targets, flattened to N values. Non-tensor
            input is converted to float32.
        y_pred_mean: Predictive means; must contain exactly N values after
            flattening.
        y_pred_cov: Either an (N, N) predictive covariance matrix, or a
            length-N vector of marginal variances (treated as independent,
            i.e. zero cross-covariance).
        device: Device to compute on. If None, uses ``y_true``'s device when
            ``y_true`` is a CUDA tensor; otherwise CUDA if available, else CPU.

    Returns:
        The expected tau as a Python float in [-1, 1], or ``nan`` when
        fewer than two targets are given.

    Raises:
        ValueError: If ``y_pred_mean`` or ``y_pred_cov`` does not match the
            number of targets.
    """
    if device is None:
        if isinstance(y_true, torch.Tensor) and y_true.is_cuda:
            device = y_true.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Non-tensor inputs (lists, numpy arrays, ...) are converted to float32.
    if not isinstance(y_true, torch.Tensor):
        y_true = torch.tensor(y_true, dtype=torch.float32)
    if not isinstance(y_pred_mean, torch.Tensor):
        y_pred_mean = torch.tensor(y_pred_mean, dtype=torch.float32)
    if not isinstance(y_pred_cov, torch.Tensor):
        y_pred_cov = torch.tensor(y_pred_cov, dtype=torch.float32)

    y_true = y_true.to(device).flatten()
    y_pred_mean = y_pred_mean.to(device).flatten()
    y_pred_cov = y_pred_cov.to(device)

    n = y_true.shape[0]
    if n < 2:
        return float("nan")

    # Validate shapes up front: mismatched sizes would otherwise either be
    # silently truncated by the pair indexing below or fail with IndexError.
    if y_pred_mean.numel() != n:
        raise ValueError(
            f"y_pred_mean has {y_pred_mean.numel()} elements, expected {n}"
        )
    if y_pred_cov.dim() == 1:
        if y_pred_cov.shape[0] != n:
            raise ValueError(
                f"y_pred_cov variance vector has {y_pred_cov.shape[0]} "
                f"elements, expected {n}"
            )
        var = y_pred_cov
    elif y_pred_cov.dim() == 2 and tuple(y_pred_cov.shape) == (n, n):
        var = y_pred_cov.diagonal()
    else:
        raise ValueError(
            f"y_pred_cov must have shape ({n}, {n}) or ({n},), "
            f"got {tuple(y_pred_cov.shape)}"
        )

    # Row/column indices of the strict upper triangle: every pair i < j once.
    idx_i, idx_j = torch.triu_indices(n, n, offset=1, device=device)

    mu_diff = y_pred_mean[idx_i] - y_pred_mean[idx_j]
    var_diff = var[idx_i] + var[idx_j]
    if y_pred_cov.dim() == 2:
        # Var(f_i - f_j) = var_i + var_j - 2 cov_ij; a 1-D input means cov_ij = 0.
        var_diff = var_diff - 2.0 * y_pred_cov[idx_i, idx_j]

    # Clamp guards against zero or slightly negative variances (deterministic
    # predictions, round-off in nearly singular covariances). With equal means
    # this yields z = 0 (p = 0.5); with distinct means p saturates to 0 or 1.
    sigma_diff = torch.sqrt(torch.clamp(var_diff, min=1e-12))
    z = mu_diff / sigma_diff
    p = 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))  # standard normal CDF

    y_i = y_true[idx_i]
    y_j = y_true[idx_j]
    ind_gt = (y_i > y_j).float()
    ind_lt = (y_i < y_j).float()
    n_concordant = torch.sum(ind_gt * p + ind_lt * (1.0 - p))
    n_discordant = torch.sum(ind_gt * (1.0 - p) + ind_lt * p)

    # n >= 2 here, so there is always at least one pair.
    total_pairs = idx_i.numel()
    tau = (n_concordant - n_discordant) / float(total_pairs)
    return tau.item()
