import torch
import numpy as np


def aceloss(softmaxes, labels, n_bins=15):
    """Compute the Adaptive Calibration Error (ACE) using equal-mass bins.

    Bin edges are placed at quantiles of the sorted top-1 confidences, so each
    of the ``n_bins`` bins holds roughly the same number of samples. Ties can
    collapse neighbouring bins. The error is the sample-weighted sum of
    ``|avg confidence - accuracy|`` over the non-empty bins ``(lower, upper]``.

    Args:
        softmaxes: Probability tensor. Class scores are reduced over dim 1
            (assumed shape ``(N, C)`` -- confirm with callers).
        labels: Ground-truth class indices, compared element-wise against the
            argmax predictions.
        n_bins: Number of adaptive bins.

    Returns:
        A tuple ``(ece, acc_in_bin, avg_conf_in_bin)``. ``ece`` is a
        shape-``(1,)`` tensor. ``acc_in_bin`` and ``avg_conf_in_bin`` are 1-D
        tensors with one entry per non-empty bin, in ascending confidence
        order. For empty input, ``ece`` is zero and both per-bin tensors are
        empty.
    """
    d = softmaxes.device

    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)

    n_samples = confidences.size(0)
    if n_samples == 0:
        # Nothing to bin (e.g. a class with no samples in a classwise split);
        # indexing the sorted confidences below would raise.
        return (
            torch.zeros(1, device=d),
            torch.tensor([], device=d),
            torch.tensor([], device=d),
        )

    conf_sorted = torch.sort(confidences)[0]
    n_per_bin = n_samples / n_bins
    # Use the confidence dtype so an edge copied from a sample compares equal
    # to that sample (a float32 copy of a float64 value can round below it).
    bin_boundaries = torch.zeros(n_bins + 1, device=d, dtype=conf_sorted.dtype)
    for i in range(1, n_bins):
        # Edge i is the confidence of the last of the lowest
        # round(i * n_per_bin) samples.
        idx = int(np.round(i * n_per_bin)) - 1
        # idx == -1 happens when n_samples < n_bins. Keep that edge at 0
        # instead of letting -1 wrap to the maximum confidence, which would
        # make the edges non-monotonic and count samples in several bins.
        if idx >= 0:
            bin_boundaries[i] = conf_sorted[idx]
    bin_boundaries[-1] = 1

    accuracy_in_bin_list = []
    avg_confidence_in_bin_list = []

    ece = torch.zeros(1, device=d)
    edges = bin_boundaries.tolist()
    for bin_lower, bin_upper in zip(edges[:-1], edges[1:]):
        in_bin = confidences.gt(bin_lower) & confidences.le(bin_upper)
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0.0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

            accuracy_in_bin_list.append(accuracy_in_bin)
            avg_confidence_in_bin_list.append(avg_confidence_in_bin)

    # torch.tensor copies the per-bin scalars into new 1-D tensors.
    acc_in_bin = torch.tensor(accuracy_in_bin_list, device=d)
    avg_conf_in_bin = torch.tensor(avg_confidence_in_bin_list, device=d)

    return ece, acc_in_bin, avg_conf_in_bin


def classwise_ace(softmaxes, labels, n_bins=15):
    """Run ``aceloss`` separately on the samples of each ground-truth class.

    The number of classes is taken from dim 1 of ``softmaxes``. For class
    ``c``, only the rows whose label equals ``c`` are passed to ``aceloss``.
    The confidence used inside ``aceloss`` is still the top-1 probability
    over all classes, not the probability assigned to ``c``.

    Returns:
        Three lists indexed by class: each class's ACE as a Python float, and
        the per-bin accuracy and average-confidence tensors that ``aceloss``
        returned for that class.
    """
    ace_per_class, acc_per_class, conf_per_class = [], [], []
    for cls in range(softmaxes.size(1)):
        rows = torch.where(labels == cls)
        class_ace, class_acc, class_conf = aceloss(softmaxes[rows], labels[rows], n_bins)
        ace_per_class.append(class_ace.item())
        acc_per_class.append(class_acc)
        conf_per_class.append(class_conf)
    return ace_per_class, acc_per_class, conf_per_class
