from enum import Enum

import torch.nn as nn
import numpy as np
import torch


def set_bn_eval(m):
    """Put ``m`` into eval mode if it is a batch-norm layer; otherwise do nothing.

    The single-module signature makes it usable as a callback for
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.modules.batchnorm._BatchNorm):
        return
    m.eval()


def set_bn_train(m):
    """Put ``m`` into training mode if it is a batch-norm layer; otherwise do nothing.

    The single-module signature makes it usable as a callback for
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.modules.batchnorm._BatchNorm):
        return
    m.train()


class Summary(Enum):
    """Selects which aggregate a meter reports in its summary line."""

    NONE = 0     # report nothing
    AVERAGE = 1  # report the running average
    SUM = 2      # report the accumulated sum
    COUNT = 3    # report the accumulated count


class AverageMeter:
    """Track the latest value, weighted sum, count and running average of a metric.

    Attributes:
        name: Label used when the meter is rendered as text.
        fmt: Format spec (e.g. ``":.3f"``) applied to ``val`` and ``avg`` in ``str()``.
        summary_type: ``Summary`` member choosing what ``summary()`` reports.
        val: Value passed to the most recent ``update`` call.
        sum: Accumulated ``val * n`` over all updates since the last reset.
        count: Accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using ``self.fmt`` for both numbers."""
        template = "".join(["{name} {val", self.fmt, "} ({avg", self.fmt, "})"])
        return template.format(**vars(self))

    def summary(self):
        """Return the one-line summary chosen by ``self.summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ""),
            (Summary.AVERAGE, "{name} {avg:.3f}"),
            (Summary.SUM, "{name} {sum:.3f}"),
            (Summary.COUNT, "{name} {count:.3f}"),
        )
        # Identity comparison, so an unhashable summary_type still ends in ValueError.
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**vars(self))
        raise ValueError("invalid summary type %r" % self.summary_type)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Ground-truth class indices; ``target.size(0)`` is used as the
            batch size.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, holding the
        percentage of samples whose target is among the ``k`` highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Rows of `ranked` are rank positions, columns are samples.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]


def metric_ece_aurc_eaurc(confidences, truths, bin_size=0.1):
    """Compute calibration error and risk-coverage metrics for a set of predictions.

    Args:
        confidences: Per-sample class scores, 2-D (argmax/max are taken along axis 1).
        truths: Ground-truth class index for each sample.
        bin_size: Width of each confidence bin used for the calibration error.

    Returns:
        Tuple ``(ece, aurc, e_aurc)`` where ``ece`` is the bin-count-weighted
        mean |accuracy - confidence| scaled by 100, and the two AURC values
        from ``calc_aurc`` are scaled by 1000.
    """
    confidences = np.asarray(confidences)
    truths = np.asarray(truths)

    n_samples = len(confidences)
    top_class = np.argmax(confidences, axis=1)
    top_conf = np.amax(confidences, axis=1)

    # One entry per bin (lower, upper]: |accuracy - mean confidence| * bin size.
    weighted_gaps = []
    for bin_upper in np.arange(bin_size, 1 + bin_size, bin_size):
        bin_acc, bin_conf, bin_n = compute_bin(
            bin_upper - bin_size, bin_upper, top_conf, top_class, truths
        )
        weighted_gaps.append(abs(bin_acc - bin_conf) * bin_n)

    ece = 100 * sum(weighted_gaps) / n_samples
    aurc, e_aurc = calc_aurc(confidences, truths)
    return ece, aurc * 1000, e_aurc * 1000


def compute_bin(conf_thresh_lower, conf_thresh_upper, conf, pred, true):
    """Summarise the samples whose confidence lies in ``(lower, upper]``.

    Args:
        conf_thresh_lower: Exclusive lower confidence bound of the bin.
        conf_thresh_upper: Inclusive upper confidence bound of the bin.
        conf: Per-sample confidence values.
        pred: Per-sample predicted labels.
        true: Per-sample ground-truth labels.

    Returns:
        Tuple ``(accuracy, avg_conf, bin_count)`` for the bin, or ``(0, 0, 0)``
        when no sample falls inside it.
    """
    bin_confs = []
    n_correct = 0
    for p, t, c in zip(pred, true, conf):
        if not (conf_thresh_lower < c <= conf_thresh_upper):
            continue
        bin_confs.append(c)
        if p == t:
            n_correct += 1

    bin_count = len(bin_confs)
    if bin_count < 1:
        return 0, 0, 0

    bin_accuracy = float(n_correct) / bin_count
    avg_conf = sum(bin_confs) / bin_count
    return bin_accuracy, avg_conf, bin_count


def calc_aurc(confidences, labels):
    """Compute the area under the risk-coverage curve (AURC) and its excess (E-AURC).

    Samples are ranked by their maximum class score, most confident first.
    For each coverage level (the top-i samples) the risk is the fraction of
    those samples that are misclassified; AURC is the mean of these risks.
    E-AURC subtracts the AURC of an ideal ranking that places every
    misclassified sample last.

    Args:
        confidences: Per-sample class scores, 2-D (argmax/max are taken along axis 1).
        labels: Ground-truth class index for each sample.

    Returns:
        Tuple ``(aurc, eaurc)``.
    """
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)
    predictions = np.argmax(confidences, axis=1)
    max_confs = np.max(confidences, axis=1)

    n = len(labels)
    # Descending-confidence ordering of the samples.
    order = np.argsort(max_confs)[::-1]
    errors = labels[order] != predictions[order]

    # The builtin `float` replaces the `np.float` alias, which was removed in NumPy 1.24.
    risk_cov = np.cumsum(errors).astype(float) / np.arange(1, n + 1)
    aurc = np.mean(risk_cov)

    # Optimal ranking: all `nrisk` errors sit at the lowest-confidence positions.
    nrisk = np.sum(errors)
    ranks = np.arange(1, nrisk + 1).astype(float)
    opt_aurc = (1.0 / n) * np.sum(ranks / (n - nrisk + ranks))
    eaurc = aurc - opt_aurc
    return aurc, eaurc
