import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from scipy.stats import pearsonr, spearmanr, kendalltau

#-------------------------------
# Metrics for Failure Prediction
#-------------------------------

def compute_auroc(true_label, confidence):
    """Return the area under the ROC curve of `confidence` w.r.t. `true_label`.

    Thin wrapper that forwards both arguments unchanged to
    ``sklearn.metrics.roc_auc_score``.
    """
    return roc_auc_score(true_label, confidence)

def compute_fpr95(true_label, confidence):
    """Return the false-positive rate at the first ROC point with TPR >= 0.95.

    The ROC curve comes from ``sklearn.metrics.roc_curve``. An ``IndexError``
    is raised if no point on the curve reaches a true-positive rate of 0.95.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true_label, confidence)
    # Positions of every ROC point whose TPR is at least 95 %; take the first.
    qualifying = np.flatnonzero(true_pos_rate >= 0.95)
    return false_pos_rate[qualifying[0]]

def compute_auprc_success(true_label, confidence):
    """Return the average precision with label 1 treated as the positive class.

    Both arguments are forwarded unchanged to
    ``sklearn.metrics.average_precision_score`` with ``pos_label=1``.
    """
    return average_precision_score(true_label, confidence, pos_label=1)

def compute_auprc_error(true_label, confidence):
    """Return the average precision with label 0 treated as the positive class.

    The confidence scores are negated before being handed to
    ``sklearn.metrics.average_precision_score`` so that a *low* confidence
    ranks a sample as more likely to belong to the positive (label 0) class.

    Args:
        true_label: array-like of labels, passed through to scikit-learn.
        confidence: array-like of numeric confidence scores (plain lists are
            accepted as well as NumPy arrays).

    Returns:
        The average-precision score computed with ``pos_label=0``.
    """
    # np.asarray lets plain Python sequences be negated, not just ndarrays.
    negated_scores = -np.asarray(confidence)
    return average_precision_score(true_label, negated_scores, pos_label=0)

def compute_aurc(true_label, confidence):
    """Compute the area under the risk-coverage curve (AURC).

    Samples are dropped one at a time in order of increasing confidence.
    After each drop the selective risk (mean residual of the samples still
    kept) is computed, and the recorded risks are integrated with the
    trapezoidal rule, each segment weighted by the fraction of samples it
    spans.

    Args:
        true_label: 1-D array-like; the residual of a sample is
            ``1 - true_label`` (so 1 contributes no risk and 0 contributes
            full risk).
        confidence: 1-D array-like of confidence scores, same length as
            ``true_label``.

    Returns:
        The AURC value (0 when only a single sample is given).

    Raises:
        ZeroDivisionError: if the inputs are empty.
    """
    # Coerce so that plain Python sequences work with the arithmetic and
    # fancy indexing below.
    true_label = np.asarray(true_label)
    confidence = np.asarray(confidence)

    residuals = 1 - true_label
    n = len(residuals)
    idx_sorted = np.argsort(confidence)

    # Risk at full coverage: mean residual over all samples.
    error_sum = sum(residuals[idx_sorted])
    risks = [error_sum / n]
    weights = []
    tmp_weight = 0

    for i in range(n - 1):
        # Drop the i-th least confident sample and recompute the risk of
        # the n - 1 - i samples that remain.
        error_sum = error_sum - residuals[idx_sorted[i]]
        selective_risk = error_sum / (n - 1 - i)
        tmp_weight += 1
        # A curve point is recorded for the first drop, and afterwards only
        # when the dropped confidence differs from the previously dropped
        # one; otherwise the weight keeps accumulating.
        if i == 0 or confidence[idx_sorted[i]] != confidence[idx_sorted[i - 1]]:
            risks.append(selective_risk)
            weights.append(tmp_weight / n)
            tmp_weight = 0

    # Close the curve with a final flat segment for any weight still pending.
    if tmp_weight > 0:
        risks.append(risks[-1])
        weights.append(tmp_weight / n)

    # Trapezoidal integration of risk over the recorded segments.
    return sum(
        (risks[i] + risks[i + 1]) * 0.5 * weights[i] for i in range(len(weights))
    )

#------------------------------------
# Metrics for Correlation with Margin
#------------------------------------
    
def spearman_corr(metric, margin):
    """Return the Spearman rank correlation between `metric` and `margin`.

    Only the correlation coefficient from ``scipy.stats.spearmanr`` is
    returned; the accompanying p-value is discarded.
    """
    coefficient, _p_value = spearmanr(metric, margin)
    return coefficient

def pearson_corr(metric, margin):
    """Return the Pearson correlation coefficient between `metric` and `margin`.

    Only the coefficient from ``scipy.stats.pearsonr`` is returned; the
    accompanying p-value is discarded.
    """
    coefficient, _p_value = pearsonr(metric, margin)
    return coefficient

def kendall_corr(metric, margin):
    """Return Kendall's tau between `metric` and `margin`.

    Only the tau statistic from ``scipy.stats.kendalltau`` is returned; the
    accompanying p-value is discarded.
    """
    tau, _p_value = kendalltau(metric, margin)
    return tau

#-----------------------------------
# Metrics for Confidence Calibration
#-----------------------------------

def compute_ece(confidence, true_y, pred_y, n_bins):
    """Compute the expected calibration error (ECE), expressed in percent.

    Confidences are grouped into ``n_bins`` equal-width bins over [0, 1].
    For each non-empty bin the absolute gap between accuracy and mean
    confidence is taken; the gaps are averaged, weighted by bin population,
    and multiplied by 100.

    Args:
        confidence: 1-D array-like of confidence scores.
        true_y: 1-D array-like of ground-truth labels.
        pred_y: 1-D array-like of predicted labels.
        n_bins: number of equal-width bins.

    Returns:
        The ECE in percent. Samples whose confidence is not in (0, 1] fall
        outside every bin and are ignored; if no sample lands in any bin
        the division by a zero count yields ``nan``.
    """
    # Coerce so that plain Python sequences support mask indexing.
    confidence = np.asarray(confidence)
    true_y = np.asarray(true_y)
    pred_y = np.asarray(pred_y)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # right=True: index b (1-based) means bins[b - 1] < confidence <= bins[b].
    indices = np.digitize(confidence, bins, right=True)

    bin_acc = np.zeros(n_bins)
    bin_conf = np.zeros(n_bins)
    bin_counts = np.zeros(n_bins, dtype=np.int64)

    for b in range(n_bins):
        in_bin = indices == b + 1
        count = int(in_bin.sum())
        if count > 0:
            bin_acc[b] = np.mean(true_y[in_bin] == pred_y[in_bin])
            bin_conf[b] = np.mean(confidence[in_bin])
            bin_counts[b] = count

    gaps = np.abs(bin_acc - bin_conf)
    return np.sum(gaps * bin_counts) / np.sum(bin_counts) * 100

def compute_nll(confidence, true_y, n_classes):
    """Return the categorical cross-entropy of `confidence` against `true_y`.

    `true_y` is one-hot encoded to depth `n_classes` and compared with
    `confidence` using a default-constructed
    ``tf.keras.losses.CategoricalCrossentropy``; the scalar loss is returned
    as a NumPy value.

    NOTE(review): presumably `confidence` holds per-class probabilities of
    shape (n_samples, n_classes) -- confirm against callers.
    """
    one_hot_targets = tf.one_hot(true_y, depth=n_classes)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    loss = loss_fn(one_hot_targets, confidence)
    return loss.numpy()

def compute_brier_score(confidence, true_y, n_classes):
    """Compute the multi-class Brier score.

    For every sample the squared differences between the predicted scores
    and the one-hot encoded label are summed over classes; the result is the
    mean of those sums over all samples.

    Args:
        confidence: array-like of shape (n_samples, n_classes) with the
            predicted score of each class.
        true_y: 1-D array-like of integer class labels.
        n_classes: number of classes (width of the one-hot encoding).

    Returns:
        The mean per-sample sum of squared errors.
    """
    scores = np.asarray(confidence)
    labels = np.asarray(true_y).reshape(-1, 1)
    # One-hot encode in NumPy so no tensor is mixed into the arithmetic.
    # Labels outside [0, n_classes) give an all-zero row.
    one_hot = (labels == np.arange(n_classes)).astype(scores.dtype)
    return np.mean(np.sum((scores - one_hot) ** 2, axis=1))

def compute_classification_error(metric, true_y, k=1):
    """Return the top-`k` classification error rate.

    A sample counts as correct when its label ``true_y[i]`` is among the
    indices of the `k` largest entries of ``metric[i]``. The result is one
    minus the fraction of correct samples.
    """
    n_samples = len(metric)
    # Count the rows whose true label appears among the k highest-scoring indices.
    hits = sum(
        1
        for row, scores in enumerate(metric)
        if true_y[row] in np.argsort(scores)[-k:]
    )
    return 1 - hits / n_samples