import torch
import numpy as np
from sklearn.metrics import confusion_matrix as cm
from sklearn.metrics import multilabel_confusion_matrix as mcm
from sklearn.metrics import precision_recall_fscore_support as pr
from sklearn.metrics import average_precision_score as aps
from torchmetrics import Metric
from torchmetrics.classification import MultilabelAveragePrecision


class AccuracyPL(Metric):
    """Top-k accuracy, in percent, accumulated over batches.

    One ``correct_{k}`` state is registered for every entry of ``topk``,
    together with a shared ``total`` sample counter. All states use the
    ``"sum"`` distributed reduction.

    Args:
        topk: The ``k`` values to track. ``compute`` reports them in
            ascending order of ``k`` with duplicates removed.
    """
    full_state_update = False

    def __init__(self, topk=(1,)):
        super().__init__()
        self.topk = topk
        for k in topk:
            self.add_state(f"correct_{k}", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")

    def _unique_ks(self):
        """Return the tracked ``k`` values, de-duplicated and sorted ascending."""
        return sorted(set(self.topk))

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate top-k hits for one batch.

        Args:
            preds: 2-D tensor of per-class scores, one row per sample
                (top-k is taken along dim 1).
            target: Tensor with one class index per row of ``preds``.
        """
        assert preds.shape[0] == target.shape[0]
        # Never ask for more candidates than there are classes.
        maxk = min(max(self.topk), preds.size()[1])
        _, pred = preds.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row r holds every sample's rank-r guess
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
        # Update every registered state instead of hard-coding correct_1 and
        # correct_5, which failed for any ``topk`` other than one containing
        # both 1 and 5.
        for k in self._unique_ks():
            state = f"correct_{k}"
            # Slicing past maxk is harmless: it simply takes all rows.
            hits = correct[:k].reshape(-1).float().sum(0) * 100.
            setattr(self, state, getattr(self, state) + hits)
        self.total += target.numel()

    def compute(self):
        """Return ``(acc@k_1, ..., acc@k_n, total)`` with k ascending.

        For ``topk=(1, 5)`` this is ``(top1, top5, total)``.
        """
        accuracies = tuple(
            getattr(self, f"correct_{k}") / self.total for k in self._unique_ks()
        )
        return (*accuracies, self.total)


class AP_PL(MultilabelAveragePrecision):
    """``MultilabelAveragePrecision`` subclass that stores each batch as given.

    ``update`` is replaced with a version that appends the raw tensors to the
    ``preds`` / ``target`` list states and counts samples in ``total``.
    ``compute`` is not overridden here, so the inherited one is used.

    NOTE(review): the parent class may register ``preds``/``target`` states
    itself and normally formats inputs in its own ``update``; confirm that
    bypassing that step is intended.
    """
    is_differentiable = False
    higher_is_better = None
    full_state_update = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-batch tensors are kept in lists and concatenated across processes.
        for list_state in ("preds", "target"):
            self.add_state(list_state, default=[], dist_reduce_fx="cat")
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, pred: torch.Tensor, y: torch.Tensor):
        """Store one batch of predictions and targets.

        Args:
            pred: Prediction tensor; its first dimension must match ``y``.
            y: Target tensor for the same samples.
        """
        n_pred, n_target = pred.shape[0], y.shape[0]
        assert n_pred == n_target
        self.preds.append(pred)
        self.target.append(y)
        self.total += n_target
        

def calc_acc(x):
    """Return the diagonal sum of ``x`` divided by the sum of all entries.

    For a confusion matrix this is the accuracy. The result is rounded to
    four decimal places.
    """
    on_diagonal = np.diag(x).sum()
    everything = np.sum(x)
    return np.round(on_diagonal / everything, 4)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy in percent for each ``k`` in ``topk``.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: Tensor with one class index per row of ``output``.
        topk: The ``k`` values to evaluate.

    Returns:
        A list with one scalar tensor per entry of ``topk``, in the same order.
    """
    num_classes = output.size()[1]
    # Cap k at the number of classes so ``topk`` never over-asks.
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    ranked = output.topk(maxk, 1, True, True)[1].t()  # (maxk, batch)
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    results = []
    for k in topk:
        hit_count = hits[:min(k, maxk)].reshape(-1).float().sum(0)
        results.append(hit_count * 100. / batch_size)
    return results


def multi_label_acc(output, target, class_to_idx):
    """Compute an accuracy value per class from per-label confusion matrices.

    Args:
        output: Predicted labels, passed as the first argument to
            ``multilabel_confusion_matrix``.
        target: Reference labels, passed as the second argument.
        class_to_idx: Mapping from class name to class index.

    Returns:
        Dict mapping each class name to ``calc_acc`` of that class's matrix.
    """
    labels = sorted(class_to_idx.values())
    name_of = {str(idx): name for name, idx in class_to_idx.items()}
    # NOTE(review): sklearn documents the argument order as (y_true, y_pred).
    # Passing (output, target) only swaps the off-diagonal cells, which leaves
    # the diagonal-based accuracy unchanged.
    per_label = mcm(output, target, labels=labels)
    return {
        name_of[str(label)]: calc_acc(per_label[pos])
        for pos, label in enumerate(labels)
    }


def precision_recall(output, target, class_to_idx):
    """Compute per-class and mean precision/recall.

    Args:
        output: Predicted labels.
        target: Reference labels.
        class_to_idx: Mapping from class name to class index.

    Returns:
        Dict with ``<name>_precision`` and ``<name>_recall`` for every class
        in ``class_to_idx``, plus ``mean_precision`` and ``mean_recall``.
    """
    names = list(class_to_idx)
    labels = [class_to_idx[name] for name in names]
    precision, recall, _, _ = pr(target,
                                 output,
                                 average=None,
                                 labels=labels)
    result = {}
    # With average=None the returned arrays follow the order of ``labels``.
    # Index by position, not by label value, so non-contiguous or reordered
    # class indices still map to the right class.
    for pos, name in enumerate(names):
        result[name + '_precision'] = precision[pos]
        result[name + '_recall'] = recall[pos]
    result['mean_precision'] = np.mean(precision)
    result['mean_recall'] = np.mean(recall)
    return result


def compute_bin(
    conf_thresh_lower: float,
    conf_thresh_upper: float,
    conf, pred, true):
    """Summarise the samples whose confidence lies in ``(lower, upper]``.

    Args:
        conf_thresh_lower: Exclusive lower confidence bound of the bin.
        conf_thresh_upper: Inclusive upper confidence bound of the bin.
        conf: Per-sample confidences.
        pred: Per-sample predicted labels.
        true: Per-sample reference labels.

    Returns:
        ``(accuracy, average confidence, sample count)`` for the bin, or
        ``(0, 0, 0)`` when no sample falls inside it.
    """
    in_bin = []
    for p, t, c in zip(pred, true, conf):
        if conf_thresh_lower < c <= conf_thresh_upper:
            in_bin.append((p, t, c))

    bin_count = len(in_bin)
    if bin_count < 1:
        return 0, 0, 0

    hits = sum(1 for p, t, _ in in_bin if p == t)
    avg_conf = sum(c for _, _, c in in_bin) / bin_count
    acc = float(hits) / bin_count
    return acc, avg_conf, bin_count


def calc_aurc(confidences, labels):
    """Compute the area under the risk-coverage curve (AURC) and excess AURC.

    Samples are ranked by their maximum confidence, highest first. The risk
    at each coverage level is the error rate among the samples covered so far.

    Args:
        confidences: 2-D array-like of per-class confidences, one row per
            sample; the arg-max along axis 1 is taken as the prediction.
        labels: 1-D array-like of reference class indices.

    Returns:
        ``(aurc, eaurc)``, where ``eaurc`` is ``aurc`` minus the AURC of an
        ideal ranking that places every error last.
    """
    confidences = np.array(confidences)
    labels = np.array(labels)
    predictions = np.argmax(confidences, axis=1)
    max_confs = np.max(confidences, axis=1)
    n = len(labels)

    # Most confident samples first.
    order = np.argsort(max_confs)[::-1]
    errors = labels[order] != predictions[order]

    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    risk_cov = np.divide(
        np.cumsum(errors).astype(np.float64),
        np.arange(1, n + 1))
    nrisk = np.sum(errors)
    aurc = np.mean(risk_cov)

    # Ideal ranking: the nrisk errors take the last nrisk positions.
    error_ranks = np.arange(1, nrisk + 1)
    opt_aurc = (1. / n) * np.sum(
        np.divide(error_ranks.astype(np.float64), n - nrisk + error_ranks))
    eaurc = aurc - opt_aurc
    return aurc, eaurc


def metric_ece_aurc_eaurc(confidences, truths, bin_size=0.1):
    """Compute expected calibration error together with AURC and excess AURC.

    Args:
        confidences: 2-D array-like of per-class confidences, one row per
            sample.
        truths: 1-D array-like of reference class indices.
        bin_size: Width of each confidence bin used for the ECE.

    Returns:
        ``(ece, aurc, e_aurc)``, with ``ece`` scaled by 100 and both AURC
        values scaled by 1000.
    """
    confidences = np.asarray(confidences)
    truths = np.asarray(truths)
    total = len(confidences)
    predictions = np.argmax(confidences, axis=1)
    max_confs = np.amax(confidences, axis=1)

    # Sum |accuracy - confidence| over the bins, weighted by bin population.
    weighted_gap = 0
    for upper_bound in np.arange(bin_size, 1 + bin_size, bin_size):
        acc, avg_conf, bin_count = compute_bin(
            upper_bound - bin_size, upper_bound,
            max_confs, predictions, truths)
        weighted_gap = weighted_gap + abs(acc - avg_conf) * bin_count
    ece = 100 * weighted_gap / total

    aurc, e_aurc = calc_aurc(confidences, truths)
    return ece, aurc * 1000, e_aurc * 1000