from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import digamma, softmax
from sklearn.metrics import roc_auc_score, auc
import sklearn.metrics

# Additive floor used inside np.log(...) and denominators throughout this
# module so that zero probabilities/concentrations do not yield -inf or a
# division by zero.
EPS = 1e-12

# Global matplotlib styling, applied once at import time. It affects every
# figure drawn afterwards in the process, including the plots made below.
_RC_STYLE = {
    "font.family": "Times New Roman",
    "axes.labelsize": 18,
    "font.size": 18,
    "legend.fontsize": 16,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
}
matplotlib.rcParams.update(_RC_STYLE)


def kl_divergence(probs1, probs2):
    """Return KL(probs1 || probs2) computed row-wise over axis 1.

    Both arguments are presumably probability arrays sharing a shape whose
    class axis is axis 1; that axis is summed out of the result. ``EPS`` is
    added inside each logarithm so zero entries stay finite.
    """
    log_ratio = np.log(probs1 + EPS) - np.log(probs2 + EPS)
    return (probs1 * log_ratio).sum(axis=1)


def expected_pairwise_kl_divergence(probs):
    """Expected pairwise KL divergence (EPKL) between ensemble members.

    EPKL is the mean, over all ordered member pairs ``(i, j)`` (including the
    zero-valued ``i == j`` terms), of ``KL(p_i || p_j)``. The previous
    implementation summed the pairs without dividing by ``M**2``, so it was
    not an expectation.

    Args:
        probs: array of shape ``[batch_size, ensemble_size, num_classes]``
            (members on axis 1, classes on axis 2).

    Returns:
        float64 array of shape ``[batch_size]``.
    """
    log_probs = np.log(probs + EPS)
    # mean_ij sum_c p_ic (log p_ic - log p_jc)
    #   = mean_i sum_c p_ic log p_ic - sum_c mean_i(p_ic) * mean_j(log p_jc),
    # which replaces the O(M^2) pair loop with two reductions and keeps
    # float64 precision (the old accumulator was float32).
    self_term = np.mean(np.sum(probs * log_probs, axis=2), axis=1)
    cross_term = np.sum(np.mean(probs, axis=1) * np.mean(log_probs, axis=1), axis=1)
    return self_term - cross_term

def entropy_of_expected(probs):
    """Entropy of the (ensemble-averaged) softmax distribution.

    Despite the parameter name, ``probs`` is treated as *logits*: a softmax is
    applied over the class axis before the entropy is taken.

    Args:
        probs: logits of shape ``[batch_size, num_classes]`` (single model) or
            ``[ensemble_size, batch_size, num_classes]`` (ensemble; member
            distributions are averaged over axis 0 after the softmax).

    Returns:
        array of shape ``[batch_size]``.

    Raises:
        ValueError: if ``probs`` is neither 2-D nor 3-D (previously the
            function silently returned ``None``).
    """
    probs = np.asarray(probs)
    if probs.ndim == 2:
        mean_probs = softmax(probs, axis=1)
    elif probs.ndim == 3:
        mean_probs = np.mean(softmax(probs, axis=2), axis=0)
    else:
        raise ValueError(f"expected 2-D or 3-D logits, got shape {probs.shape}")
    return -np.sum(mean_probs * np.log(mean_probs + EPS), axis=1)


def expected_entropy(probs):
    """Average over axis 0 of the per-member softmax entropies.

    ``probs`` holds logits; softmax is applied over the last axis, entropy is
    taken over axis 2 (so the input is assumed 3-D, e.g.
    ``[ensemble_size, batch_size, num_classes]``) and the result is averaged
    over axis 0.
    """
    member_probs = softmax(probs, axis=-1)
    member_entropies = -np.sum(member_probs * np.log(member_probs + EPS), axis=2)
    return member_entropies.mean(axis=0)

def expected_entropy_pn(logits, offset=1):
    """Expected categorical entropy under a Dirichlet built from ``logits``.

    Concentrations are ``alpha_c = exp(logit_c) + offset`` along axis 1 and
    ``alpha_0 = sum_c alpha_c + EPS``. Returns the closed form
    ``-sum_c (alpha_c / alpha_0) * (psi(alpha_c + 1) - psi(alpha_0 + 1))``
    with axis 1 reduced.
    """
    concentrations = np.exp(logits) + offset
    precision = concentrations.sum(axis=1, keepdims=True) + EPS
    mean = concentrations / precision
    digamma_gap = digamma(concentrations + 1) - digamma(precision + 1)
    return -(mean * digamma_gap).sum(axis=1)
    


def get_ensemble_proxy_uncertainty_values(ensemble_logits):
    """Per-sample uncertainty measures from a Dirichlet proxy of an ensemble.

    The ensemble is summarised by its mean predictive distribution and by
    ``MKL = EPKL - MI``. A Dirichlet with that mean and precision
    ``alpha0 = (K - 1) / (2 * MKL)`` is then used to compute proxy measures.

    Args:
        ensemble_logits: array of shape ``[batch_size, ensemble_size, num_classes]``.

    Returns:
        dict of arrays of shape ``[batch_size]``: ``confidence`` (negated max
        mean probability, so larger means more uncertain),
        ``entropy_of_expected``, ``expected_entropy``, ``mutual_information``,
        ``EPKL`` and ``MKL``.
    """
    probs = softmax(ensemble_logits, axis=2)
    log_probs = np.log(probs + EPS)
    mean_probs = np.mean(probs, axis=1)
    conf = np.max(mean_probs, axis=1)
    num_classes = probs.shape[2]  # axis 1 is the ensemble axis, not classes

    # Ensemble measures are computed inline: entropy_of_expected /
    # expected_entropy take logits with the ensemble on axis 0, whereas here
    # ``probs`` already holds probabilities with the ensemble on axis 1.
    eoe = -np.sum(mean_probs * np.log(mean_probs + EPS), axis=1)
    exe = -np.mean(np.sum(probs * log_probs, axis=2), axis=1)
    mutual_info = eoe - exe

    # EPKL = mean over member pairs (i, j) of KL(p_i || p_j), in closed form:
    # mean_i sum_c p_ic log p_ic - sum_c mean_i(p_ic) * mean_j(log p_jc).
    epkl = -exe - np.sum(mean_probs * np.mean(log_probs, axis=1), axis=1)
    mkl = epkl - mutual_info

    # MKL >= 0 in exact arithmetic (Jensen); clip round-off so alpha0 > 0.
    alpha0 = (num_classes - 1) / (2 * np.maximum(mkl, 0.0)[:, None] + EPS)
    alphas = mean_probs * alpha0

    proxy_eoe = eoe  # the proxy shares the ensemble's mean distribution
    proxy_exe = -np.sum((alphas / alpha0) * (digamma(alphas + 1) - digamma(alpha0 + 1)),
                        axis=1)
    proxy_mutual_info = proxy_eoe - proxy_exe

    # Index instead of np.squeeze so batch_size == 1 keeps shape [1].
    proxy_epkl = (num_classes - 1) / alpha0[:, 0]
    # NOTE(review): mixes the ensemble EPKL with the proxy MI (as before);
    # proxy_epkl - proxy_mutual_info may have been intended — confirm.
    proxy_mkl = epkl - proxy_mutual_info

    uncertainty = {'confidence': -conf,
                   'entropy_of_expected': proxy_eoe,
                   'expected_entropy': proxy_exe,
                   'mutual_information': proxy_mutual_info,
                   'EPKL': proxy_epkl,
                   'MKL': proxy_mkl}

    return uncertainty


def get_model_uncertainty_values(logits, offset=0.5):
    """Batch-averaged uncertainty measures for a single Dirichlet (Prior Network) model.

    ``logits`` (``[batch_size, num_classes]``) define concentrations
    ``exp(logits) + offset``. Each measure is computed per sample from that
    Dirichlet and then averaged over the batch.

    Returns:
        dict of scalars: ``confidence`` (mean max expected probability),
        ``entropy_of_expected``, ``expected_entropy``, ``mutual_information``
        (their difference), ``EPKL`` (``(K - 1) / alpha_0``) and ``MKL``
        (``EPKL - mutual_information``).
    """
    concentrations = np.exp(logits) + offset
    precision = np.sum(concentrations, axis=1, keepdims=True) + EPS
    expected_probs = concentrations / precision
    num_classes = concentrations.shape[1]

    total_unc = -np.sum(expected_probs * np.log(expected_probs + EPS), axis=1)
    data_unc = -np.sum(expected_probs * (digamma(concentrations + 1) - digamma(precision + 1)),
                       axis=1)
    knowledge_unc = total_unc - data_unc

    pairwise_kl = np.squeeze((num_classes - 1) / precision)
    reverse_mi = pairwise_kl - knowledge_unc

    return {
        'confidence': np.mean(np.max(expected_probs, axis=1)),
        'entropy_of_expected': np.mean(total_unc),
        'expected_entropy': np.mean(data_unc),
        'mutual_information': np.mean(knowledge_unc),
        'EPKL': np.mean(pairwise_kl),
        'MKL': np.mean(reverse_mi),
    }



def get_calibration_errors(full_probs, accuracies, targets, num_bins):
    """Expected/maximum calibration error, Brier score and NLL.

    Args:
        full_probs: predicted class probabilities, shape ``[n, num_classes]``.
        accuracies: per-sample correctness (0/1 or bool) array, shape ``[n]``.
        targets: integer true labels, shape ``[n]``.
        num_bins: number of equal-width confidence bins on ``[0, 1]``.

    Returns:
        dict with keys ``"ECE"``, ``"MCE"``, ``"brier"`` and ``"NLL"``.
    """
    probs = full_probs.max(axis=1)
    assert len(probs) == len(accuracies)
    total_samples = len(probs)
    bins = np.linspace(0, 1, num_bins + 1, endpoint=True)

    # np.digitize sends a confidence of exactly 1.0 (common once a float32
    # softmax saturates) to index num_bins + 1, i.e. an extra bin past the
    # last edge. Fold it into the final bin so there are exactly num_bins bins.
    bin_indices = np.minimum(np.digitize(probs, bins), num_bins)

    calibration_error = 0
    max_calibration_error = 0
    for bin_ind in np.unique(bin_indices):
        in_bin = bin_indices == bin_ind
        samples_in_bin = np.count_nonzero(in_bin)

        error_for_bin = np.abs(probs[in_bin].mean() - accuracies[in_bin].mean()).item()
        calibration_error += (samples_in_bin / total_samples) * error_for_bin
        max_calibration_error = max(max_calibration_error, error_for_bin)

    rows = np.arange(total_samples)
    onehots = np.zeros_like(full_probs)
    onehots[rows, targets] = 1

    brier = np.square(full_probs - onehots).sum(1).mean().item()
    nll = -np.log(full_probs[rows, targets] + EPS).mean().item()

    return {"ECE": calibration_error, "MCE": max_calibration_error, "brier": brier, "NLL": nll}



def calc_tot_unc_auc_roc(in_probs, out_probs, plot=False):
    """ROC-AUC for detecting OOD inputs by total uncertainty.

    Total uncertainty is ``entropy_of_expected`` of each input, which accepts
    2-D single-model logits or 3-D ensemble logits (ensemble on axis 0).
    In-distribution samples are labelled 0 and OOD samples 1, so higher
    entropy is scored as "more likely OOD".

    Args:
        in_probs: logits for in-distribution data.
        out_probs: logits for out-of-distribution data.
        plot: if True, show the ROC curve and FPR/TPR-vs-threshold curves.

    Returns:
        area under the ROC curve.
    """
    in_t_unc = entropy_of_expected(in_probs)
    out_t_unc = entropy_of_expected(out_probs)
    all_t_unc = np.concatenate([in_t_unc, out_t_unc], axis=0)
    # Size the labels from the scores rather than in_probs.shape[0]: for 3-D
    # ensemble input axis 0 is the ensemble axis, not the sample axis.
    all_labels = np.concatenate([np.zeros(len(in_t_unc)), np.ones(len(out_t_unc))], axis=0)

    fpr, tpr, thresholds = sklearn.metrics.roc_curve(all_labels, all_t_unc)
    auc_roc = sklearn.metrics.auc(fpr, tpr)

    if plot:
        plt.plot(fpr, tpr, label='Total uncertainty')
        plt.title('Total uncertainty OOD detection ROC curve')
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.legend()
        plt.show()

        for rate, rate_name in ((fpr, 'FPR'), (tpr, 'TPR')):
            plt.plot(thresholds, rate, label='Total uncertainty')
            plt.xlabel('T')
            plt.ylabel(rate_name)
            plt.legend()
            plt.show()

    return auc_roc

def calc_ensemble_know_unc_auc_roc(in_probs, out_probs, plot=False):
    """Precision-recall AUC for detecting OOD inputs by ensemble knowledge uncertainty.

    Knowledge uncertainty is ``entropy_of_expected - expected_entropy``
    (mutual information), both computed from 3-D ensemble logits
    ``[ensemble_size, batch_size, num_classes]``. In-distribution samples
    are labelled 0 and OOD samples 1.

    Note: despite the ``auc_roc`` name, the value returned is the area under
    the precision-recall curve (``sklearn.metrics.auc(recall, precision)``).

    Args:
        in_probs: ensemble logits for in-distribution data.
        out_probs: ensemble logits for out-of-distribution data.
        plot: if True, show the precision-recall curve.

    Returns:
        area under the precision-recall curve.
    """
    in_k_unc = entropy_of_expected(in_probs) - expected_entropy(in_probs)
    out_k_unc = entropy_of_expected(out_probs) - expected_entropy(out_probs)
    all_k_unc = np.concatenate([in_k_unc, out_k_unc], axis=0)
    # One label per score, so the sample axis never has to be guessed.
    all_labels = np.concatenate([np.zeros(len(in_k_unc)), np.ones(len(out_k_unc))], axis=0)

    precision, recall, _ = sklearn.metrics.precision_recall_curve(all_labels, all_k_unc)
    aupr = sklearn.metrics.auc(recall, precision)

    if plot:
        plt.plot(recall, precision, label='Knowledge uncertainty')
        plt.title('Knowledge uncertainty OOD detection precision-recall curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend()
        plt.show()

    return aupr

def _pn_knowledge_uncertainty(logits, offset=1):
    """Mutual information of the Dirichlet with concentrations ``exp(logits) + offset``.

    Both terms use the same Dirichlet mean ``alpha / alpha_0``, matching
    ``get_model_uncertainty_values``, so the result is the Dirichlet mutual
    information and is non-negative.
    """
    alphas = np.exp(logits) + offset
    alpha0 = np.sum(alphas, axis=1, keepdims=True) + EPS
    mean_probs = alphas / alpha0
    total_unc = -np.sum(mean_probs * np.log(mean_probs + EPS), axis=1)
    data_unc = -np.sum(mean_probs * (digamma(alphas + 1) - digamma(alpha0 + 1)), axis=1)
    return total_unc - data_unc


def calc_pn_know_unc_auc_roc(in_preds, out_preds, plot=False):
    """Precision-recall AUC for detecting OOD inputs by Prior Network knowledge uncertainty.

    Knowledge uncertainty is the Dirichlet mutual information computed from
    2-D logits ``[batch_size, num_classes]`` with concentrations
    ``exp(logits) + 1``. Previously the total-uncertainty term used
    ``softmax(logits)`` while the expected-entropy term used the Dirichlet
    mean, so their difference was not a mutual information and could be
    negative. In-distribution samples are labelled 0 and OOD samples 1.

    Note: despite the ``auc_roc`` name, the value returned is the area under
    the precision-recall curve.

    Args:
        in_preds: logits for in-distribution data.
        out_preds: logits for out-of-distribution data.
        plot: if True, show the precision-recall curve.

    Returns:
        area under the precision-recall curve.
    """
    in_k_unc = _pn_knowledge_uncertainty(in_preds)
    out_k_unc = _pn_knowledge_uncertainty(out_preds)
    all_k_unc = np.concatenate([in_k_unc, out_k_unc], axis=0)
    all_labels = np.concatenate([np.zeros(len(in_k_unc)), np.ones(len(out_k_unc))], axis=0)

    precision, recall, _ = sklearn.metrics.precision_recall_curve(all_labels, all_k_unc)
    aupr = sklearn.metrics.auc(recall, precision)

    if plot:
        plt.plot(recall, precision, label='Knowledge uncertainty')
        plt.title('Knowledge uncertainty OOD detection precision-recall curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend()
        plt.show()

    return aupr


def reject_class(labels, preds, measure, measure_name, dataset_name, image_dir):
    """Prediction-rejection analysis for an uncertainty measure.

    Predictions are sorted by ``measure`` (ascending). The highest-measure
    predictions are progressively "rejected to an oracle", i.e. no longer
    counted as errors, and the resulting error curve is compared with a
    random-rejection baseline and with an oracle that rejects exactly the
    misclassified samples first.

    Args:
        labels: true class labels, shape ``[n]``.
        preds: predicted class labels, shape ``[n]``.
        measure: per-sample uncertainty, where higher values are rejected
            first; shape ``[n]``.
        measure_name: used in plot titles and file names.
        dataset_name: used in plot titles and file names.
        image_dir: ``pathlib.Path`` directory under which PDFs are written to
            ``image_dir / dataset_name``, or ``None`` to skip plotting.

    Returns:
        ``(rejection_ratio, auc_uns)`` with
        ``rejection_ratio = (auc_uns - auc_rnd) / (auc_orc - auc_rnd)``.
    """
    inds = np.argsort(measure)
    total_data = preds.shape[0]

    mistakes = labels[inds] != preds[inds]
    # errors[k]: error (% of all samples) when only the k + 1 lowest-measure
    # predictions are kept and the rest are handed to the oracle.
    errors = np.cumsum(mistakes, dtype=np.float32) * 100 / total_data
    percentages = np.linspace(100 / total_data, 100, num=total_data, endpoint=True, dtype=np.float32)

    base_error = errors[-1]
    n_items = errors.shape[0]
    auc_uns = 1 - auc(percentages / 100, errors[::-1] / 100)

    random_rejection = base_error * np.linspace(1, 1 - (n_items - 1) / n_items, num=n_items, endpoint=True,
                                                dtype=np.float32)
    auc_rnd = 1 - auc(percentages / 100, random_rejection / 100)

    # Oracle: each rejection removes one error until none remain. Count the
    # errors exactly; the old int(base_error / 100 * n_items) could truncate
    # (e.g. 2.9999998 -> 2) because base_error is a float32 percentage, and
    # divided by zero when there were no errors.
    n_errors = int(np.count_nonzero(mistakes))
    orc = np.zeros_like(errors)
    if n_errors > 0:
        orc[:n_errors] = base_error * np.linspace(1, 1 - (n_errors - 1) / n_errors, num=n_errors,
                                                  endpoint=True, dtype=np.float32)
    auc_orc = 1 - auc(percentages / 100, orc / 100)

    rejection_ratio = (auc_uns - auc_rnd) / (auc_orc - auc_rnd)

    if image_dir is not None:
        plot_dir = image_dir / dataset_name
        plot_dir.mkdir(parents=True, exist_ok=True)

        # Two otherwise identical figures: one shades the oracle-vs-random
        # gap, the other the uncertainty-vs-random gap.
        for suffix, orc_alpha, unc_alpha in (('oracle', 0.5, 0.0), ('uncertainty', 0.0, 0.5)):
            plt.plot(percentages, orc, lw=2)
            plt.fill_between(percentages, orc, random_rejection, alpha=orc_alpha)
            plt.plot(percentages, errors[::-1], lw=2)
            plt.fill_between(percentages, errors[::-1], random_rejection, alpha=unc_alpha)
            plt.plot(percentages, random_rejection, 'k--', lw=2)
            plt.legend(['Oracle', 'Uncertainty', 'Random'])
            plt.xlabel('Percentage of predictions rejected to oracle')
            plt.ylabel('Classification Error (%)')
            plt.title(f'{dataset_name}-{measure_name}')
            plt.savefig(plot_dir / f'rej-{dataset_name}-{measure_name}-{suffix}.pdf', format='pdf',
                        bbox_inches='tight', dpi=300)
            plt.close()

    return rejection_ratio, auc_uns