from collections import defaultdict

import torch
from torch import nn
from torch.nn import functional as F

from src.utils.common import str_to_num

# Small positive constant (1e-7) kept at module level for numerical-stability tolerances.
eps = 1.0e-7


def get_simplex_ece(n, logits, labels, p=1):
    """Estimate the simplex (multi-class) expected calibration error.

    Binning: each sample's softmax vector is binned per class into ``n``
    equal-width intervals ``(k/n, (k+1)/n]``. Samples that share the same
    per-class bin indices form one cell of the probability simplex.

    Error term: for every sample in a cell with at least two members, the
    mean one-hot label of the *other* members of that cell (a leave-one-out
    label frequency) is compared with the sample's own softmax vector. The
    term ``sum(|H - prob| ** p)`` is accumulated. Cells with a single member
    contribute nothing.

    Args:
        n: Number of equal-width bins per class.
        logits: Unnormalised scores. Softmax is taken over dim 1 and
            ``logits.shape[1]`` is used as the number of classes (presumably
            shape ``(N, C)`` -- TODO confirm with callers).
        labels: Integer class index per sample (cast to int64 for one-hot).
        p: Exponent applied to each per-class absolute error.

    Returns:
        ``(binning_dict, ECE)``:

        - ``binning_dict`` maps the cell key produced by ``str_to_num`` to the
          list of softmax vectors falling in that cell.
        - ``ECE`` is the accumulated error divided by the number of samples,
          as a 0-dim tensor.

    Raises:
        ValueError: If ``logits`` contains no samples.
    """
    num_samples = len(logits)
    if num_samples == 0:
        raise ValueError("get_simplex_ece needs at least one sample")

    softmaxes = F.softmax(logits, dim=1)
    num_classes = logits.shape[1]
    labels_OHE = F.one_hot(labels.to(torch.int64), num_classes=num_classes).to(torch.float32)

    # Upper bin edges k/n for k = 1..n. They are built in the softmax dtype so
    # the comparison behaves like comparing each probability tensor element
    # with a Python float (which is cast to the tensor's dtype).
    bin_boundaries = torch.tensor(
        [(i + 1) / n for i in range(n)],
        dtype=softmaxes.dtype,
        device=softmaxes.device,
    )
    # With right=False, bucketize returns the first k such that
    # prob <= bin_boundaries[k]. It does this for every class of every sample
    # at once (no longer limited to the first 3 classes).
    bin_indices = torch.bucketize(softmaxes.detach(), bin_boundaries)

    binning_dict = defaultdict(list)
    labels_dict = defaultdict(list)
    for point, label, bins in zip(softmaxes, labels_OHE, bin_indices.tolist()):
        # NOTE(review): str_to_num presumably maps the per-class bin-index
        # list to a hashable cell key; confirm in src.utils.common.
        key = str_to_num(bins)
        binning_dict[key].append(point)
        labels_dict[key].append(label)

    ce_sum = softmaxes.new_zeros(())
    for key, points in binning_dict.items():
        count = len(points)
        if count < 2:
            # No other samples in this cell to estimate a label frequency from.
            continue
        cell_probs = torch.stack(points, dim=0)
        cell_labels = torch.stack(labels_dict[key], dim=0)
        # Leave-one-out mean label: remove each sample's OWN one-hot label
        # (not its softmax vector) from the cell total. The guard above
        # ensures count - 1 >= 1, so no epsilon is needed in the divisor.
        loo_freq = (cell_labels.sum(dim=0) - cell_labels) / (count - 1)
        ce_sum = ce_sum + torch.sum(torch.abs(loo_freq - cell_probs) ** p)

    ECE = ce_sum / num_samples

    return binning_dict, ECE