import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def accuracy(output, target):
    """Return the fraction of elements where ``output`` equals ``target``.

    :param output: predicted labels (array-like supporting ``==``).
    :param target: ground-truth labels; ``len(target)`` is the denominator.
    :return: number of matches divided by ``len(target)``.
    """
    matches = output == target
    return np.sum(matches) / len(target)

def accuracy_top_k(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested k.

    :param output: score tensor of shape ``(batch, num_classes)``; top-k is taken over dim 1.
    :param target: tensor of ``batch`` ground-truth class indices.
    :param topk: iterable of k values; each must not exceed ``num_classes``.
    :return: list with one 1-element tensor per k, holding the percentage of rows
        whose target is among the k highest-scoring columns.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose, so row j holds each sample's j-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` comes from a transposed tensor and is not contiguous, so
            # ``view(-1)`` fails for k > 1; ``reshape`` copies when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def get_acc_auroc_curves(logdir):
    """Load every scalar series stored in a TensorBoard log directory.

    :param logdir: Path to logs: E.g '.../methods/ARPL/log/(12.03.2021_|_32.570)/'
    :return: dict mapping each scalar tag to a numpy array whose rows are
        ``[step, value]`` pairs, in the order the accumulator returns them.
    """
    event_acc = EventAccumulator(logdir)
    event_acc.Reload()

    # Only scalars are collected. The public ``Scalars(tag)`` accessor is used instead
    # of reaching into the accumulator's private ``scalars._buckets`` reservoir.
    return {
        tag: np.array([[event.step, event.value] for event in event_acc.Scalars(tag)])
        for tag in event_acc.Tags()['scalars']
    }

class ClassAccEval():
    """Accumulate classifier scores over batches and report accuracy and entropy.

    Call ``update`` once per batch and ``compute`` at the end of an epoch; call
    ``reset`` to clear the buffers before the next epoch.
    """

    def __init__(self, logger, dataset='cifar10'):
        """
        :param logger: object with an ``info(str)`` method used to report results.
        :param dataset: dataset name; ``'VisDA'`` makes ``compute`` report mean
            per-class accuracy, any other value reports overall accuracy.
        """
        super(ClassAccEval, self).__init__()
        self.dataset = dataset
        self.class_num = None
        self.logger = logger
        # Initialise the buffers so ``update`` works even if ``reset`` is never called.
        self.reset()

    def reset(self):
        """Clear the accumulated scores and labels."""
        self.output_prob = []
        self.pids = []

    def update(self, output):  # called once for each batch
        """Store one batch.

        :param output: ``(scores, labels)`` pair. ``scores`` is a tensor that
            ``compute`` concatenates along dim 0; ``labels`` is anything
            ``np.asarray`` accepts.
        """
        prob, pid = output
        self.output_prob.append(prob)
        self.pids.extend(np.asarray(pid))

    def Entropy(self, input_):
        """Return the per-row entropy ``-sum(p * log(p + 1e-5))`` taken along dim 1."""
        epsilon = 1e-5  # keeps log() finite when a probability is exactly 0
        entropy = -input_ * torch.log(input_ + epsilon)
        return torch.sum(entropy, dim=1)

    def compute(self):  # called after each epoch
        """Evaluate everything passed to ``update`` since the last ``reset``.

        Softmax is applied to the stored scores before the entropy is computed
        (presumably the scores are logits; confirm with callers).

        :return: ``(accuracy, mean_entropy)``. For ``'VisDA'`` the accuracy is the
            mean per-class accuracy in [0, 1], and the entropy is averaged within
            each predicted class and then across classes. Otherwise it is the
            overall fraction of correct predictions and the plain mean entropy.
        """
        self.class_num = len(set(self.pids))
        output_prob = torch.cat(self.output_prob, dim=0)

        # torch.max over dim 1 already gives a 1-D tensor. The former
        # ``torch.squeeze`` turned a single-sample batch into a 0-d tensor.
        _, predict = torch.max(output_prob, 1)
        predict_cpu = predict.cpu()

        labels = torch.tensor(self.pids)
        if self.dataset == 'VisDA':
            output_prob = nn.Softmax(dim=1)(output_prob)
            _ent = self.Entropy(output_prob)
            # Skip classes that received no predictions: the mean over an empty
            # selection is NaN and would poison the class average.
            per_class_ent = [_ent[predict == ci].mean()
                             for ci in range(self.class_num)
                             if (predict == ci).any()]
            if per_class_ent:
                mean_ent = torch.stack(per_class_ent).mean()
            else:
                mean_ent = torch.tensor(float('nan'))

            matrix = confusion_matrix(labels.numpy(), predict_cpu.numpy())
            support = matrix.sum(axis=1)
            # A class that appears only among the predictions has an all-zero row.
            # Dividing by its zero support would make the mean accuracy NaN.
            present = support > 0
            acc = matrix.diagonal()[present] / support[present] * 100
            # aacc is the mean of the per-class accuracies, rescaled to [0, 1]
            aacc = acc.mean() / 100
            aa = [str(np.round(i, 2)) for i in acc]
            self.logger.info('Per-class accuracy is :')
            self.logger.info(' '.join(aa))
            return aacc, mean_ent

        num_correct = torch.sum(predict_cpu == labels).item()
        acc_value = num_correct / float(labels.size(0))
        output_prob = nn.Softmax(dim=1)(output_prob)
        mean_ent = torch.mean(self.Entropy(output_prob))
        self.logger.info('normal accuracy {} {} {}'.format(acc_value, mean_ent, ''))
        return acc_value, mean_ent
        

class ClassificationPredSaver(object):
    """Collect per-sample class scores over batches and save them (and labels) to disk."""

    def __init__(self, length, save_path=None):
        """
        :param length: total number of samples that will be passed to ``update``.
        :param save_path: output path. Its final extension is stripped; ``save``
            writes ``<path>.pth`` and, if labels were given, ``<path>_labels.pth``.
        """
        # Always define the attribute so ``save`` can report a missing path clearly.
        self.save_path = None
        if save_path is not None:
            # Strip only the final extension. Splitting on the first '.' truncated
            # relative paths such as './out/preds.pth' and dotted directory names.
            self.save_path = os.path.splitext(save_path)[0]

        self.length = length

        self.all_preds = None
        self.all_labels = None

        # Row offset at which the next batch is written.
        self.running_start_idx = 0

    def update(self, preds, labels=None):
        """Write one batch into the preallocated buffers.

        :param preds: ``B x C`` scores (tensor or numpy array).
        :param labels: optional length-``B`` labels (tensor or numpy array).
        """
        if torch.is_tensor(preds):
            preds = preds.detach().cpu().numpy()

        b, c = preds.shape

        if self.all_preds is None:
            self.all_preds = np.zeros((self.length, c))

        self.all_preds[self.running_start_idx: self.running_start_idx + b] = preds

        if labels is not None:
            if torch.is_tensor(labels):
                labels = labels.detach().cpu().numpy()

            if self.all_labels is None:
                self.all_labels = np.zeros((self.length,))

            self.all_labels[self.running_start_idx: self.running_start_idx + b] = labels

        # Maintain running index on dataset being evaluated
        self.running_start_idx += b

    def save(self):
        """Softmax the collected scores in place, save them, and evaluate if labels exist.

        :raises ValueError: if no ``save_path`` was given or nothing was collected.
        """
        if self.save_path is None:
            raise ValueError('ClassificationPredSaver.save() requires a save_path')
        if self.all_preds is None:
            raise ValueError('ClassificationPredSaver.save() called before any update()')

        # Softmax over preds
        preds = torch.from_numpy(self.all_preds)
        preds = torch.nn.Softmax(dim=-1)(preds)
        self.all_preds = preds.numpy()

        pred_path = self.save_path + '.pth'
        print(f'Saving all predictions to {pred_path}')

        torch.save(self.all_preds, pred_path)

        if self.all_labels is not None:

            # Evaluate
            self.evaluate()
            torch.save(self.all_labels, self.save_path + '_labels.pth')

    def evaluate(self):
        """Print top-1/5/10 accuracy, limited to k values not exceeding the class count."""
        num_classes = self.all_preds.shape[-1]
        # k == num_classes is a valid top-k; the old strict '<' also left the list
        # empty (and crashed) when there was a single class.
        topk = [k for k in (1, 5, 10) if k <= num_classes]
        # ``accuracy`` has no ``topk`` parameter; the top-k helper is ``accuracy_top_k``.
        acc = accuracy_top_k(torch.from_numpy(self.all_preds),
                             torch.from_numpy(self.all_labels), topk=topk)

        for k, a in zip(topk, acc):
            print(f'Top{k} Acc: {a.item()}')

class IndicatePlateau(object):
    """Signal when a tracked metric has stopped improving.

    ``step`` returns ``True`` once the metric has failed to improve for more than
    ``patience_epochs`` consecutive calls, and then resets the tracker.
    """

    def __init__(self, threshold=5e-4, patience_epochs=5, mode='min', threshold_mode='rel'):
        """
        :param threshold: minimum change that counts as an improvement.
        :param patience_epochs: number of non-improving steps tolerated before
            a plateau is reported.
        :param mode: ``'min'`` (lower is better) or ``'max'`` (higher is better).
        :param threshold_mode: ``'rel'`` (threshold is a fraction of ``best``) or
            ``'abs'`` (threshold is an absolute offset from ``best``).
        :raises ValueError: for an unknown ``mode`` or ``threshold_mode``.
        """
        self.patience = patience_epochs
        self.cooldown_counter = 0  # kept for compatibility; not used by ``step``
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.last_epoch = 0
        # Validates and stores mode/threshold/threshold_mode and sets mode_worse.
        # (It was previously called twice.)
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets best value, num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record a new metric value.

        :param metrics: the metric (anything ``float()`` accepts, e.g. a 0-d tensor).
        :param epoch: optional epoch index; defaults to the previous epoch + 1.
        :return: ``True`` if a plateau was detected (state is reset), else ``False``.
        """
        current = float(metrics)
        # Advance the epoch counter instead of overwriting it with None.
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            print('Tracked metric has plateaud')
            self._reset()
            return True
        return False

    def is_better(self, a, best):
        """Return whether ``a`` improves on ``best`` by more than the threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs'
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the comparison settings, store them, and set ``mode_worse``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = float('inf')
        else:  # mode == 'max':
            self.mode_worse = -float('inf')

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode


if __name__ == '__main__':
    # Demo: feed an exponentially decaying curve to the plateau detector and
    # report every step at which it signals a plateau.
    plateau = IndicatePlateau(threshold=0.0899)
    curve = np.exp(-0.01 * np.arange(0, 2000, 1))

    print(curve)
    for epoch_idx, value in enumerate(curve):
        if plateau.step(value):
            print(f'Plateaud at epoch {epoch_idx} with val {value}')