import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics.functional as tmF
from torchmetrics import AUROC
from sklearn.cluster import KMeans
from scipy.stats import pearsonr, spearmanr, kendalltau

# Default partition-wise calibration error (PCE) configuration passed by ``evaluate``.
n_partitions = 4  # number of independent partitionings of the samples
n_groups = 4  # number of groups (random groups or k-means clusters) per partitioning

def evaluate(y, true_prob=None, features=None, pred_prob=None, pred_logits=None):
    """Compute calibration, distance and classification metrics for a model.

    Every metric is evaluated on both ``pred_prob`` (keys ending in ``_prob``)
    and ``pred_logits`` (keys ending in ``_logits``). ``pred_logits`` is
    therefore required in practice: every helper is also called on it.

    Args:
        y: ground-truth labels.
        true_prob: optional reference probabilities. When given, distance,
            divergence and correlation metrics against it are added.
        features: per-sample features used to build the k-means partitions
            of the PCE metrics.
        pred_prob: predicted probabilities. When ``None`` and ``pred_logits``
            is given, it is computed as a softmax over dim 1.
        pred_logits: raw model outputs.

    Returns:
        dict mapping metric names to values. If AUROC cannot be computed for
        a variant, its ``auc_<variant>`` entry is NaN and the legacy ``"auc"``
        key is also set to NaN.
    """
    # ECE is reported for each of these bin counts.
    ece_bin_counts = (10, 15, 20, 50)
    # PCE uses the last ECE bin count (50). This is the value it used to pick
    # up implicitly from the leaked ``for n_bins`` loop variable.
    pce_n_bins = ece_bin_counts[-1]

    if pred_logits is not None:
        pred_logits = pred_logits.contiguous()
    if pred_prob is None and pred_logits is not None:
        pred_prob = torch.softmax(pred_logits, dim=1)

    # (key suffix, predictions) pairs; each metric is computed for both.
    variants = (("prob", pred_prob), ("logits", pred_logits))
    results = {}

    if true_prob is not None:
        # Metrics comparing predictions against the reference probabilities.
        reference_metrics = (
            ("mae_dist", mae_dist),
            ("mse_dist", mse_dist),
            ("kl_dist", kl_divergence),
            ("cross_entropy", cross_entropy),
            ("js_dist", js_divergence),
            ("gen_kld", generalized_kl_divergence),
            ("chebyshev_dist", chebyshev_dist),
            ("pearson_dist", pearson_correlation),
            ("spearman_dist", spearman_corr),
            ("kendall_dist", kendall_corr),
        )
        for name, metric_fn in reference_metrics:
            for suffix, preds in variants:
                results[f"{name}_{suffix}"] = metric_fn(preds, true_prob)

    for n_bins in ece_bin_counts:
        results.update(ece_with_bins_number(n_bins, pred_logits, pred_prob, y))

    results.update(
        pce_with_bgp_number(features, pce_n_bins, n_groups, n_partitions,
                            pred_logits, pred_prob, true_prob, y)
    )

    # Metrics comparing predictions against the labels ``y``.
    label_metrics = (
        ("nll", nll),
        ("cross_entropy_y", binary_cross_entropy),
        ("mse_y", mse_dist),
        ("mae_y", mae_dist),
        ("acc", acc),
        ("pcoc", pcoc),
        ("bias", bias),
        ("bias_abs", bias_abs),
    )
    for name, metric_fn in label_metrics:
        for suffix, preds in variants:
            results[f"{name}_{suffix}"] = metric_fn(preds, y)

    # AUROC is best-effort: it can fail for some inputs (presumably e.g. when
    # ``y`` holds a single class), so a failure is recorded as NaN for that
    # variant instead of aborting the whole evaluation. ``Exception`` (not a
    # bare ``except``) so KeyboardInterrupt/SystemExit still propagate.
    auroc = AUROC(task="binary")
    auc_failed = False
    for suffix, preds in variants:
        try:
            results[f"auc_{suffix}"] = auroc(preds, y).item()
        except Exception:
            results[f"auc_{suffix}"] = float("nan")
            auc_failed = True
    if auc_failed:
        # Legacy key written by the original error path; kept for callers.
        results["auc"] = float("nan")

    return results


def pce_with_bgp_number(features, n_bins, n_groups, n_partitions, pred_logits, pred_prob, true_prob, y):
    """Compute random- and k-means-partition PCE for probabilities and logits.

    Keys have the form
    ``pce_{rand|kmeans}_{prob|logits}_b{n_bins}_g{n_groups}_p{n_partitions}``.
    Both partition schemes are run with the given ``n_groups`` and
    ``n_partitions``, so every key describes the configuration actually used.
    Previously ``pce_rand`` fell back to its own defaults (2 partitions) while
    its key still advertised ``n_partitions``.

    ``true_prob`` is unused; it is kept only for signature compatibility.
    """
    suffix = f"b{n_bins}_g{n_groups}_p{n_partitions}"
    variants = (("prob", pred_prob), ("logits", pred_logits))
    results = {}
    for name, preds in variants:
        results[f"pce_rand_{name}_{suffix}"] = pce_rand(
            n_bins, preds, y, n_groups=n_groups, n_partitions=n_partitions
        )
    for name, preds in variants:
        results[f"pce_kmeans_{name}_{suffix}"] = pce_kmeans(
            n_bins, preds, y, features, n_groups, n_partitions
        )
    return results


def ece_with_bins_number(n_bins, pred_logits, pred_prob, y):
    """Return binned and quantile ECE for probabilities and logits at ``n_bins``.

    Keys are ``ece_bin_<n>_{prob|logits}`` (``ece``) and
    ``ece_quan_<n>_{prob|logits}`` (``ece_quantile``), inserted in that order.
    """
    scores = {}
    for tag, estimator in (("bin", ece), ("quan", ece_quantile)):
        for variant, preds in (("prob", pred_prob), ("logits", pred_logits)):
            scores[f"ece_{tag}_{n_bins}_{variant}"] = estimator(n_bins, preds, y)
    return scores


def mae_dist(pred, true_prob):
    """Return the mean absolute error (element-averaged L1 distance) as a float."""
    l1 = F.l1_loss(pred, true_prob, reduction='mean')
    return l1.item()


def mse_dist(pred_prob, true_prob):
    """Return the root-mean-squared error between the two tensors as a float.

    Despite the name, the mean squared error is square-rooted, so this is
    RMSE: the Euclidean distance divided by ``sqrt(numel)``.
    """
    mean_squared = F.mse_loss(true_prob, pred_prob, reduction='mean')
    return mean_squared.sqrt().item()



def huber_dist(pred, true_prob, delta=1.0):
    """Return the mean Huber loss between ``pred`` and ``true_prob``.

    Quadratic (``0.5 * r**2``) for residuals up to ``delta``, linear
    (``delta * r - 0.5 * delta**2``) beyond it.
    """
    abs_err = (true_prob - pred).abs()
    quadratic = 0.5 * abs_err ** 2
    linear = delta * abs_err - 0.5 * delta ** 2
    return torch.where(abs_err <= delta, quadratic, linear).mean().item()


def kl_divergence(q, p, epsilon: float = 1e-8):
    """Return the mean element-wise Bernoulli KL divergence ``KL(Bern(p) || Bern(q))``.

    Each element of ``p`` (reference) and ``q`` (prediction) is treated as a
    success probability and clipped to ``[epsilon, 1 - epsilon]``.

    The values are converted to float64 before clipping. In float32 the upper
    bound ``1 - 1e-8`` rounds to exactly 1.0, so a probability of 1 would
    produce ``0 * log(0 / ...)`` = NaN. Tensors are moved to the CPU first so
    GPU inputs are accepted.
    """
    p_clipped = np.clip(p.detach().cpu().numpy().astype(np.float64), epsilon, 1.0 - epsilon)
    q_clipped = np.clip(q.detach().cpu().numpy().astype(np.float64), epsilon, 1.0 - epsilon)

    # Positive-outcome and negative-outcome contributions of the Bernoulli KL.
    term1 = p_clipped * np.log(p_clipped / q_clipped)
    term2 = (1 - p_clipped) * np.log((1 - p_clipped) / (1 - q_clipped))

    return (term1 + term2).mean().item()


def cross_entropy(pred_prob, true_prob, epsilon=1e-8):
    """Return ``-mean(true_prob * log(pred_prob))`` as a float.

    ``pred_prob`` is clamped into ``[epsilon, 1 - epsilon]`` and divided by its
    sum along the last dimension before the log is taken.

    NOTE(review): for a 1-D tensor of per-sample binary probabilities, the
    renormalisation runs across the whole batch and there is no
    negative-class term. Confirm the expected input shape.
    """
    clipped = pred_prob.clamp(epsilon, 1.0 - epsilon)
    normalised = clipped / clipped.sum(dim=-1, keepdim=True)
    return -(true_prob * normalised.log()).mean().item()


def binary_cross_entropy(pred_prob, y, epsilon=1e-8):
    """Return the mean binary cross-entropy of labels ``y`` under ``pred_prob``.

    Probabilities are clamped into ``[epsilon, 1 - epsilon]``, and ``epsilon``
    is also added inside each log. The second guard matters because, in
    float32, ``1 - 1e-8`` rounds to 1.0.
    """
    p = pred_prob.clamp(epsilon, 1 - epsilon)
    log_neg = torch.log(1 - p + epsilon)
    log_pos = torch.log(p + epsilon)
    per_sample = (1 - y) * log_neg + y * log_pos
    return -per_sample.mean().item()


def js_divergence(pred_prob, true_prob, epsilon=1e-8):
    """Return the Jensen-Shannon divergence ``0.5 * (KL(P||M) + KL(Q||M))``.

    ``epsilon`` is added to both inputs (without renormalising) to avoid
    ``log(0)``; ``M`` is their midpoint. ``F.kl_div(input, target)`` takes the
    log of the second distribution as ``input`` and computes ``KL(target ||
    exp(input))``, averaged with ``reduction='batchmean'``.
    """
    p = pred_prob + epsilon
    q = true_prob + epsilon
    log_mid = torch.log(0.5 * (p + q))
    kl_p = F.kl_div(log_mid, p, reduction='batchmean')
    kl_q = F.kl_div(log_mid, q, reduction='batchmean')
    return (0.5 * (kl_p + kl_q)).item()


def generalized_kl_divergence(p, q, epsilon=1e-8):
    """Return the summed generalized KL divergence ``sum(p*log(p/q) - p + q)``.

    Both inputs are clamped from below at ``epsilon``. The result is a sum
    over all elements, so it grows with the number of elements.
    """
    p_safe = p.clamp(min=epsilon)
    q_safe = q.clamp(min=epsilon)
    elementwise = p_safe * (p_safe.log() - q_safe.log()) + (q_safe - p_safe)
    return elementwise.sum().item()


def manhattan_dist(pred, true):
    """Return the total (summed) L1 distance between ``pred`` and ``true``."""
    absolute_gaps = (pred - true).abs()
    return absolute_gaps.sum().item()


def chebyshev_dist(pred, true):
    """Return the L-infinity distance: the largest absolute element-wise gap."""
    absolute_gaps = (pred - true).abs()
    return absolute_gaps.max().item()


def pearson_correlation(pred, true):
    """Return the Pearson correlation of the flattened tensors as a Python float.

    Both tensors are detached, moved to the CPU and flattened. The original
    called ``true.numpy()`` directly, which fails for tensors that require
    grad or live on the GPU. It also passed multi-dimensional arrays such as
    ``(N, 1)``, which ``pearsonr`` does not treat as a single 1-D sample.
    ``float`` matches the return type of ``spearman_corr`` / ``kendall_corr``.
    """
    pred_np = pred.detach().cpu().numpy().flatten()
    true_np = true.detach().cpu().numpy().flatten()
    corr, _ = pearsonr(pred_np, true_np)
    return float(corr)


def spearman_corr(pred, true):
    """Return the Spearman rank correlation of the flattened tensors as a float.

    Tensors are moved to the CPU before the numpy conversion, so GPU inputs
    work too (``.numpy()`` fails on CUDA tensors). This matches how
    ``kendall_corr`` prepares its inputs.
    """
    pred_np = pred.detach().cpu().numpy().flatten()
    true_np = true.detach().cpu().numpy().flatten()

    corr, _ = spearmanr(pred_np, true_np)
    return float(corr)


def kendall_corr(pred, true):
    """Return Kendall's tau between the flattened tensors as a Python float."""
    def to_flat_array(tensor):
        # Detach and move to CPU so grad-tracking / GPU tensors convert cleanly.
        return tensor.detach().cpu().numpy().flatten()

    tau, _p_value = kendalltau(to_flat_array(pred), to_flat_array(true))
    return float(tau)


def acc(pred_prob, y):
    """Return the fraction of predictions whose label matches ``y``.

    Multi-dimensional inputs are reduced with ``argmax`` over dim 1; 1-D inputs
    are rounded (``torch.round`` rounds half to even, so 0.5 maps to 0).
    """
    if pred_prob.dim() > 1:
        labels = pred_prob.argmax(dim=1)
    else:
        labels = pred_prob.round()
    return labels.eq(y).float().mean().item()


def nll(pred_prob, y, epsilon=1e-8):
    """Return the mean negative log-likelihood of labels ``y`` under ``pred_prob``.

    ``F.nll_loss`` expects log-probabilities, so the probabilities are clamped
    to at least ``epsilon`` and log-transformed first. Passing raw
    probabilities would compute ``-mean(p[y])``, which is not an NLL.

    Args:
        pred_prob: class probabilities of shape ``(N, C)``, or a 1-D tensor of
            positive-class probabilities for a binary problem (the same 1-D /
            2-D split that ``acc`` supports).
        y: class labels of shape ``(N,)``; cast to ``long`` as required by
            ``F.nll_loss``.
        epsilon: lower bound applied before the log, to avoid ``log(0)``.
    """
    if pred_prob.dim() == 1:
        # Binary case: expand P(y=1) into a two-column [P(y=0), P(y=1)] matrix.
        pred_prob = torch.stack([1 - pred_prob, pred_prob], dim=1)
    log_prob = torch.log(pred_prob.clamp(min=epsilon))
    return F.nll_loss(log_prob, y.long()).item()


def pcoc(pred_prob, y, epsilon=1e-8):
    """Return the ratio of mean prediction to observed positive rate (PCOC).

    ``epsilon`` is added to the observed rate to avoid dividing by zero.
    """
    mean_prediction = pred_prob.mean()
    observed_rate = y.float().mean() + epsilon
    return (mean_prediction / observed_rate).item()


def bias(pred_prob, y, epsilon=1e-8):
    """Return the signed relative bias ``mean(pred) / (mean(y) + epsilon) - 1``."""
    ratio = pred_prob.mean() / (y.float().mean() + epsilon)
    return (ratio - 1).item()


def bias_abs(pred_prob, y, num_groups=20, epsilon=1e-8):
    """Return the size-weighted mean absolute relative bias over random groups.

    Samples are assigned uniformly at random (via the global torch RNG) to
    ``num_groups`` groups. For each non-empty group, the bias
    ``|mean(pred) / (mean(y) + epsilon) - 1|`` is weighted by the group size.

    Empty groups are skipped. Their mean is NaN, and because ``NaN * 0`` is
    NaN, including them would make the whole result NaN. That happens
    whenever there are fewer samples than groups, and often for small batches.
    """
    group_probs = torch.ones(num_groups) / num_groups
    group_idx = torch.multinomial(group_probs, len(pred_prob), replacement=True)

    total_bias, total_weight = 0.0, 0.0
    for i in range(num_groups):
        mask = group_idx == i
        group_weight = mask.sum().float()
        if group_weight == 0:
            continue

        group_pred = pred_prob[mask].mean()
        group_y = y[mask].float().mean()
        group_bias = torch.abs(group_pred / (group_y + epsilon) - 1)

        total_bias += group_bias * group_weight
        total_weight += group_weight

    return float(total_bias / (total_weight + epsilon))


def ece(n_bins, pred_prob, y):
    """Return the binary expected calibration error with ``n_bins`` bins as a float.

    Thin wrapper around torchmetrics' functional ``calibration_error`` with
    ``task="binary"``.
    """
    score = tmF.calibration_error(pred_prob, y, task="binary", n_bins=n_bins)
    return score.item()


def ece_quantile(n_bins, pred_prob, y):
    """Return the expected calibration error using equal-count (quantile) bins.

    Samples are sorted by predicted probability and cut into ``n_bins``
    contiguous chunks of (almost) equal size; empty chunks are skipped. The
    result is ``sum(|mean(label) - mean(prob)| * bin_size) / N``.

    NOTE(review): ``torch.argsort`` sorts along the last dim, so this appears
    to assume 1-D ``pred_prob``. Confirm with callers.
    """
    probs = pred_prob.clone().detach()
    labels = y.clone().detach()

    order = torch.argsort(probs)
    probs_sorted = probs[order]
    labels_sorted = labels[order]

    edges = np.linspace(0, len(probs_sorted), n_bins + 1, dtype=int)
    weighted_gap = 0.0
    for start, stop in zip(edges[:-1], edges[1:]):
        bin_probs = probs_sorted[start:stop]
        if len(bin_probs) == 0:
            continue
        bin_labels = labels_sorted[start:stop]
        gap = bin_labels.float().mean() - bin_probs.mean()
        weighted_gap += gap.abs() * len(bin_probs)

    weighted_gap /= len(labels)
    return weighted_gap.item()


def pce(n_bins, partitions, pred_prob, y):
    """Return the partition-wise calibration error averaged over partitionings.

    ``partitions`` holds one column of group ids per partitioning. For each
    column, the ECE of every group is weighted by the group's share of all
    samples and summed; the column totals are then averaged.
    """
    n_samples, n_partitions = partitions.shape
    total_pce = 0.0

    for assignment in partitions.unbind(dim=1):
        partition_error = 0.0
        for group_id in torch.unique(assignment):
            in_group = assignment == group_id
            group_preds = pred_prob[in_group]
            if len(group_preds) == 0:
                continue
            share = len(group_preds) / n_samples
            partition_error += share * ece(n_bins, group_preds, y[in_group])
        total_pce += partition_error

    return total_pce / n_partitions


def pce_rand(n_bins, pred_prob, y, n_groups=4, n_partitions=2):
    """Return the PCE over ``n_partitions`` random partitionings of the samples.

    Each partitioning assigns every sample independently to one of
    ``n_groups`` groups with ``torch.randint`` (global torch RNG), so the
    result varies between calls unless the caller seeds torch.
    """
    sample_count = y.shape[0]
    random_assignments = [
        torch.randint(0, n_groups, (sample_count,)) for _ in range(n_partitions)
    ]
    return pce(n_bins, torch.stack(random_assignments, dim=1), pred_prob, y)


def pce_kmeans(n_bins, pred_prob, y, features, n_groups=4, n_partitions=2):
    """Return the PCE over k-means partitionings of ``features``.

    Partitioning ``i`` (for ``i`` in ``0 .. n_partitions - 1``) is a KMeans
    clustering with ``n_groups`` clusters and ``random_state=i``.

    ``features`` is detached and moved to the CPU before the numpy
    conversion, so GPU tensors are accepted. The conversion is done once
    rather than once per seed.
    """
    features_np = features.detach().cpu().numpy()

    partitions = []
    for seed in range(n_partitions):
        # NOTE(review): KMeans already receives random_state=seed, so this
        # global reseed looks redundant. It is kept because it changes the
        # caller-visible global NumPy RNG state.
        np.random.seed(seed)
        kmeans = KMeans(n_clusters=n_groups, init='random', random_state=seed)
        cluster_labels = kmeans.fit_predict(features_np)
        partitions.append(torch.tensor(cluster_labels, dtype=torch.long))

    partitions_tensor = torch.stack(partitions, dim=1)  # shape [n_samples, n_partitions]
    return pce(n_bins, partitions_tensor, pred_prob, y)
