import bisect

import numpy as np
import matplotlib.pyplot as plt


# Module-level plotting style shared by plot_diagrams.
# Colormap that plot_diagrams calls with an integer index (palette(0), palette(2), ...)
# to pick a colour for each curve.
palette = plt.get_cmap('tab20')
# Font properties passed as `prop=` to every legend.
font1 = dict(
    family='sans-serif',
    weight='normal',
    size=12,
)
# Marker shapes available for curves; plot_diagrams currently uses only the first.
marker_list = ['D', 'o', 'v', '^', 's', 'H']

def get_bucket_scores(y_score, buckets=10):
    """
    Organizes real-valued posterior probabilities into equal-width buckets.

    For example, if we have 10 buckets, the probabilities 0.0, 0.1,
    0.2 are placed into buckets 0 (0.0 <= p < 0.1), 1 (0.1 <= p < 0.2),
    and 2 (0.2 <= p < 0.3), respectively. Scores below 0 land in the first
    bucket; scores >= 1 (and NaN) land in the last bucket.

    Args:
        y_score: iterable of real-valued scores.
        buckets: number of buckets covering [0, 1]; must be >= 1.

    Returns:
        Tuple ``(bucket_values, bucket_indices)`` of two lists, each holding
        ``buckets`` lists: ``bucket_values[k]`` are the scores placed in
        bucket k and ``bucket_indices[k]`` their positions in ``y_score``.

    Raises:
        ValueError: if ``buckets`` is less than 1.
    """
    if buckets < 1:
        raise ValueError(f"buckets must be >= 1, got {buckets}")

    # Exclusive upper edge of every bucket, computed exactly as (j + 1) / buckets
    # so bucket boundaries match the original comparison-based scan.
    upper_edges = [(j + 1) / buckets for j in range(buckets)]
    last_bucket = buckets - 1

    bucket_values = [[] for _ in range(buckets)]
    bucket_indices = [[] for _ in range(buckets)]
    for i, score in enumerate(y_score):
        # bisect_right returns the first edge strictly greater than `score`,
        # i.e. the first bucket with score < upper edge. Anything beyond the
        # final edge (score >= 1, or NaN, which compares False) is clamped
        # into the last bucket.
        j = min(bisect.bisect_right(upper_edges, score), last_bucket)
        bucket_values[j].append(score)
        bucket_indices[j].append(i)
    return (bucket_values, bucket_indices)

def get_bucket_confidence(bucket_values):
    """
    Returns the mean score of every bucket, in bucket order.

    A bucket that holds no scores is reported with the sentinel -1.
    """
    confidences = []
    for scores in bucket_values:
        if len(scores) == 0:
            confidences.append(-1.)
            continue
        confidences.append(np.mean(scores))
    return confidences


def get_bucket_accuracy(bucket_values, y_true, y_pred):
    """
    Returns the fraction of correct predictions in every bucket.

    Despite its name, ``bucket_values`` must contain, for each bucket, the
    sample *indices* used to look up ``y_true`` and ``y_pred`` (for example
    the ``bucket_indices`` list produced by ``get_bucket_scores``). A bucket
    with no samples is reported with the sentinel -1.
    """
    accuracies = []
    for indices in bucket_values:
        hits = [int(y_true[k] == y_pred[k]) for k in indices]
        accuracies.append(np.mean(hits) if hits else -1.)
    return accuracies

def calculate_error(n_samples, bucket_values, bucket_confidence, bucket_accuracy):
    r"""
    Computes several metrics used to measure calibration error:
        - Expected Calibration Error (ECE): \sum_k (b_k / n) |acc(k) - conf(k)|
        - Maximum Calibration Error (MCE): max_k |acc(k) - conf(k)|
        - Total Calibration Error (TCE): \sum_k |acc(k) - conf(k)|
    where b_k is the number of samples in bucket k and n is ``n_samples``.
    Empty buckets contribute nothing to any of the three metrics.

    Args:
        n_samples: total number of samples; must equal the combined size of
            all buckets.
        bucket_values: per-bucket lists of samples (only their lengths are used).
        bucket_confidence: mean confidence of each bucket.
        bucket_accuracy: accuracy of each bucket.

    Returns:
        Tuple ``(ece, mce, tce)``, each multiplied by 100 (i.e. in percent).

    Raises:
        AssertionError: if the three per-bucket sequences differ in length or
            the bucket sizes do not sum to ``n_samples`` (these checks are
            skipped when Python runs with ``-O``).
    """
    assert len(bucket_values) == len(bucket_confidence) == len(bucket_accuracy), (
        "bucket_values, bucket_confidence and bucket_accuracy must have equal length"
    )
    assert sum(map(len, bucket_values)) == n_samples, (
        "bucket sizes must sum to n_samples"
    )

    expected_error, max_error, total_error = 0., 0., 0.
    for bucket, accuracy, confidence in zip(bucket_values, bucket_accuracy, bucket_confidence):
        if len(bucket) == 0:
            # Empty buckets only carry the -1 placeholder from the
            # get_bucket_* helpers; skip them.
            continue
        delta = abs(accuracy - confidence)
        expected_error += (len(bucket) / n_samples) * delta
        max_error = max(max_error, delta)
        total_error += delta
    return (expected_error * 100., max_error * 100., total_error * 100.)


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN (without a RuntimeWarning) when ``values`` is empty."""
    return np.mean(values) if len(values) > 0 else np.nan


def _non_empty_bucket_stats(bucket_values):
    """Return (mean confidence, fraction of samples) for every non-empty bucket, in bucket order."""
    bucket_confidence = get_bucket_confidence(bucket_values)
    bin_conf_avg = []
    bin_counts = []
    for bucket, confidence in zip(bucket_values, bucket_confidence):
        # Select by size rather than by the -1 sentinel, so a bucket whose
        # mean happens to equal -1 is not mistaken for an empty one.
        if len(bucket) > 0:
            bin_conf_avg.append(confidence)
            bin_counts.append(len(bucket))
    bin_counts = np.asarray(bin_counts, dtype=float)
    # With no samples at all there is nothing to normalize; return the empty array.
    bin_ratio = bin_counts / bin_counts.sum() if bin_counts.size else bin_counts
    return bin_conf_avg, bin_ratio


def calibration_util(preds, confs, labels, buckets=10):
    """
    Computes accuracy, ECE and the per-bin statistics consumed by ``plot_diagrams``.

    Args:
        preds: predicted label for each sample.
        confs: confidence of each prediction (presumably in [0, 1]).
        labels: ground-truth label for each sample, aligned with ``preds``.
        buckets: number of equal-width confidence bins (default 10).

    Returns:
        dict with keys:
            "acc": fraction of correct predictions.
            "ece": expected calibration error as a fraction (not percent).
            "bin_conf_avg" / "bin_acc_avg": mean confidence / accuracy of
                each non-empty bin.
            "bin_ratio": fraction of all samples falling in each non-empty bin.
            "avg_conf": mean confidence over all samples.
            "avg_conf_correct" / "avg_conf_incorrect": mean confidence over
                correct / incorrect predictions (NaN if that subset is empty).
            "bin_conf_avg_correct", "bin_ratio_correct": confidence histogram
                of the correct predictions, normalized within that subset.
            "bin_conf_avg_incorrect", "bin_ratio_incorrect": the same for the
                incorrect predictions.
    """
    preds_arr = np.asarray(preds)
    labels_arr = np.asarray(labels)
    confs_arr = np.asarray(confs)
    correct_mask = preds_arr == labels_arr
    incorrect_mask = preds_arr != labels_arr

    # Reliability statistics over all samples.
    bucket_values, bucket_indices = get_bucket_scores(confs_arr, buckets=buckets)
    bucket_confidence = get_bucket_confidence(bucket_values)
    bucket_accuracy = get_bucket_accuracy(bucket_indices, labels_arr, preds_arr)
    # calculate_error reports percentages; ECE is stored as a fraction.
    ece = calculate_error(len(preds_arr), bucket_values, bucket_confidence, bucket_accuracy)[0] / 100

    bin_conf_avg, bin_ratio = _non_empty_bucket_stats(bucket_values)
    bin_acc_avg = [
        accuracy
        for bucket, accuracy in zip(bucket_values, bucket_accuracy)
        if len(bucket) > 0
    ]

    # Confidence histograms restricted to correct / incorrect predictions.
    correct_confs = confs_arr[correct_mask]
    incorrect_confs = confs_arr[incorrect_mask]
    bin_conf_avg_correct, bin_ratio_correct = _non_empty_bucket_stats(
        get_bucket_scores(correct_confs, buckets=buckets)[0]
    )
    bin_conf_avg_incorrect, bin_ratio_incorrect = _non_empty_bucket_stats(
        get_bucket_scores(incorrect_confs, buckets=buckets)[0]
    )

    return {
        "acc": _mean_or_nan(correct_mask),
        "ece": ece,
        "bin_conf_avg": bin_conf_avg,
        "bin_acc_avg": bin_acc_avg,
        "bin_ratio": bin_ratio,
        "avg_conf": _mean_or_nan(confs_arr),
        "avg_conf_correct": _mean_or_nan(correct_confs),
        "avg_conf_incorrect": _mean_or_nan(incorrect_confs),
        "bin_conf_avg_correct": bin_conf_avg_correct,
        "bin_ratio_correct": bin_ratio_correct,
        "bin_conf_avg_incorrect": bin_conf_avg_incorrect,
        "bin_ratio_incorrect": bin_ratio_incorrect,
    }

def _format_unit_axis(ax, title, ylabel):
    """Apply the shared title, labels, [0, 1] limits, tick size and legend style to ``ax``."""
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Confidence", fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.legend(loc='best', prop=font1, frameon=True, fancybox=True, framealpha=0.8, borderpad=1)


def plot_diagrams(ax1, ax2, calibration_dict, base_model, dataset_name, prefix=""):
    """
    Draws a reliability diagram on ``ax1`` and a confidence histogram on ``ax2``.

    Args:
        ax1: matplotlib Axes for the reliability diagram (accuracy vs. confidence).
        ax2: matplotlib Axes for the per-bin sample fractions.
        calibration_dict: dict produced by ``calibration_util``.
        base_model: model name used in legend labels.
        dataset_name: task name used in both titles.
        prefix: extra text appended to both titles.
    """
    cd = calibration_dict
    # Common style for every data curve.
    curve_style = dict(linewidth=3.0, marker=marker_list[0], markersize=10)

    # Reliability diagram: per-bin accuracy against per-bin confidence, plus y = x.
    ax1.plot(cd['bin_conf_avg'], cd['bin_acc_avg'], color=palette(0),
             label=f"{base_model}, Acc: {cd['acc']:.4f}, ECE: {cd['ece']:.4f}", **curve_style)
    ax1.plot([0.0, 1.0], [0.0, 1.0], linewidth=3.0, linestyle='dashed', color='black',
             label="Perfect Calibration")

    # Confidence histogram: fraction of samples per bin for all / correct / incorrect predictions.
    ax2.plot(cd['bin_conf_avg'], cd['bin_ratio'], color=palette(0),
             label=f"{base_model}, Avg conf: {cd['avg_conf']:.4f}", **curve_style)
    ax2.plot(cd['bin_conf_avg_correct'], cd['bin_ratio_correct'], color=palette(2),
             label=f"{base_model}, Avg conf (Correct): {cd['avg_conf_correct']:.4f}", **curve_style)
    ax2.plot(cd['bin_conf_avg_incorrect'], cd['bin_ratio_incorrect'], color=palette(4),
             label=f"{base_model}, Avg conf (Incorrect): {cd['avg_conf_incorrect']:.4f}", **curve_style)
    # Dashed vertical markers at overall accuracy and mean confidence.
    ax2.vlines([cd['acc'], cd['avg_conf']], ymin=0, ymax=1, linestyles='dashed', color=palette(0))
    ax2.text(cd['acc'], 0.7, "Acc.", rotation=-90, fontsize=12, fontweight='bold')
    ax2.text(cd['avg_conf'], 0.7, "Avg. Conf", rotation=-90, fontsize=12, fontweight='bold')

    _format_unit_axis(ax1, f"Task: {dataset_name} Reliability Diagram {prefix}", "Accuracy")
    # bin_ratio values are fractions that sum to 1 (plotted on a [0, 1] axis), not percentages.
    _format_unit_axis(ax2, f"Task: {dataset_name} Confidence Histogram {prefix}", "Fraction of Samples")