from typing import Any
import numpy as np
import torch
import torch.nn as nn
from edl import EDL
from edl.model import BeliefOptions
from torch.utils.data import DataLoader
from tqdm import tqdm
from .base_postprocessor import BasePostprocessor
from ..utils import comm

class EDLPostprocessor(BasePostprocessor):
    """Postprocessor for evidential deep learning (EDL) classifiers.

    For every batch it reports the predicted class, a ``max_probability``
    belief used as confidence, the ``alpha`` values produced by
    ``EDL.logits_to_alpha``, one score per uncertainty measure named in
    ``postprocessor_args.uncertainty_fn``, and the raw logits.

    With ``postprocessor_args.react`` enabled, logits come from
    ``net.forward_threshold(data, self.threshold)``. The threshold is the
    ``percentile``-th percentile of the features that
    ``net(data, return_feature=True)`` yields on the ID validation loader;
    they are collected in :meth:`setup`.
    """

    def __init__(self, config):
        super().__init__(config)

        self.args = self.config.postprocessor.postprocessor_args

        num_classes = config.dataset.num_classes
        # Names of the uncertainty measures evaluated in postprocess().
        self.uncertainty_fn = self.args.uncertainty_fn
        self.use_react = self.args.get('react', False)
        self.percentile = self.args.percentile
        self.args_dict = self.args.postprocessor_sweep
        # setup() collects ID features only while this flag is False. The
        # features are needed solely for the ReAct threshold, so the
        # collection pass is skipped unless ReAct is enabled.
        self.setup_flag = not self.use_react

        # Prefer the evidence/prior/loss settings from the trainer config and
        # fall back to the postprocessor's own settings when the trainer
        # section does not provide them. A missing entry may surface as
        # KeyError or AttributeError depending on the config class.
        try:
            trainer_args = self.config.trainer.trainer_args
            loss_fn = trainer_args.loss_fn
            evi_fn = trainer_args.evi_fn
            prior_fn = trainer_args.prior_fn
        except (KeyError, AttributeError):
            loss_fn = None
            evi_fn = self.args.evi_fn
            prior_fn = self.args.prior_fn

        self.edl_utils = EDL(K=num_classes,
                             evi_fn=evi_fn,
                             prior_fn=prior_fn,
                             loss_fn=loss_fn)

    def _forward_logits(self, net: nn.Module, data: Any):
        """Return the logits of ``net`` for ``data``, handling per-network APIs."""
        if self.use_react:
            return net.forward_threshold(data, self.threshold)

        network_name = self.config.network.name
        if 'postnet' in network_name:
            return net(data, return_output='logits')
        if network_name == 'natpn':
            # Recover logits from NatPN's alpha assuming
            # alpha = prior + exp(logits), which is why the evidence function
            # must be 'exp'. epsilon keeps log() finite when alpha equals the
            # prior.
            assert self.edl_utils.evi_fn == 'exp'
            y_pred, _ = net(data)
            epsilon = 1e-6
            priors = self.edl_utils.get_priors().to(y_pred.alpha)
            return (y_pred.alpha - priors + epsilon).log()
        return net(data)

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Compute EDL predictions and scores for one batch.

        Returns:
            ``(pred, conf, alpha, uncertainty_dict, logits)``:
            - ``pred``: the argmax of ``logits`` over dim 1.
            - ``conf``: the ``max_probability`` belief.
            - ``alpha``: the output of ``EDL.logits_to_alpha``.
            - ``uncertainty_dict``: maps every name in ``self.uncertainty_fn``
              to the value returned by ``EDL.get_uncertainty``.
            - ``logits``: the raw logits.
        """
        logits = self._forward_logits(net, data)
        _, pred = torch.max(logits, dim=1)

        self.edl_utils.set_belief_fn(BeliefOptions.max_probability)
        conf = self.edl_utils.get_belief(logits)
        alpha = self.edl_utils.logits_to_alpha(logits)

        # The EDL helper is configured through set_uncertainty_fn before each
        # get_uncertainty call, so switch it for every requested measure.
        uncertainty_dict = {}
        for fn in self.uncertainty_fn:
            self.edl_utils.set_uncertainty_fn(fn)
            uncertainty_dict[fn] = self.edl_utils.get_uncertainty(logits=logits)

        return pred, conf, alpha, uncertainty_dict, logits

    def inference(self,
                  net: nn.Module,
                  data_loader: DataLoader,
                  progress: bool = True):
        """Run :meth:`postprocess` over ``data_loader`` and gather the results.

        Returns:
            ``(pred, conf, alpha, label, uncertainty, logits)``:
            - ``pred``, ``conf``, ``alpha``, ``label`` and ``logits`` are numpy
              arrays concatenated over all batches; ``pred`` and ``label`` are
              cast to ``int``.
            - ``uncertainty`` is a dict mapping each uncertainty name to its
              concatenated numpy array. A name with no results keeps its
              empty list.
        """
        pred_list, conf_list, alpha_list, label_list, logits_list = [], [], [], [], []
        uncertainty_list = {fn: [] for fn in self.uncertainty_fn}

        for batch in tqdm(data_loader,
                          disable=not progress or not comm.is_main_process()):
            data = batch['data'].cuda()
            pred, conf, alpha, uncertainty_dict, logits = self.postprocess(net, data)

            pred_list.append(pred.cpu())
            conf_list.append(conf.cpu())
            # Labels are only collected, never fed to the network, so there
            # is no need to round-trip them through the GPU.
            label_list.append(batch['label'].cpu())
            logits_list.append(logits.cpu())
            alpha_list.append(alpha.cpu())
            for fn in uncertainty_dict:
                uncertainty_list[fn].append(uncertainty_dict[fn].cpu())

        # Convert the per-batch tensors into numpy arrays.
        pred_arr = torch.cat(pred_list).numpy().astype(int)
        conf_arr = torch.cat(conf_list).numpy()
        alpha_arr = torch.cat(alpha_list).numpy()
        label_arr = torch.cat(label_list).numpy().astype(int)
        logits_arr = torch.cat(logits_list).numpy()

        # Only existing keys are reassigned, so mutating during iteration is safe.
        for fn, chunks in uncertainty_list.items():
            if chunks:
                uncertainty_list[fn] = torch.cat(chunks).numpy()

        return pred_arr, conf_arr, alpha_arr, label_arr, uncertainty_list, logits_arr

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Collect ID validation features and derive the ReAct threshold.

        Features are collected at most once, guarded by ``self.setup_flag``,
        which starts False only when ReAct is enabled. ``ood_loader_dict``
        is unused.
        """
        if not self.setup_flag:
            activation_log = []
            net.eval()
            with torch.no_grad():
                for batch in tqdm(id_loader_dict['val'],
                                  desc='Setup: ',
                                  position=0,
                                  leave=True):
                    data = batch['data'].cuda()
                    data = data.float()

                    _, feature = net(data, return_feature=True)
                    activation_log.append(feature.data.cpu().numpy())

            self.activation_log = np.concatenate(activation_log, axis=0)
            self.setup_flag = True

        # postprocess() reads self.threshold on the ReAct path, so it must be
        # available once setup has run.
        if self.use_react:
            self._update_threshold()

    def _update_threshold(self):
        """Set ``self.threshold`` to the ``self.percentile``-th percentile of
        all logged ID feature values."""
        self.threshold = np.percentile(self.activation_log.flatten(),
                                       self.percentile)

    def set_hyperparam(self, hyperparam: list):
        """Apply a sweep candidate; ``hyperparam[0]`` is the percentile."""
        self.percentile = hyperparam[0]
        # Only the ReAct path uses a threshold, and only it has logged
        # features to compute one from.
        if not self.use_react:
            return
        self._update_threshold()
        # '{}' rather than '{:2d}' so non-integer percentiles (e.g. 99.5)
        # do not raise ValueError.
        print('Threshold at percentile {} over id data is: {}'.format(
            self.percentile, self.threshold))

    def get_hyperparam(self):
        """Return the current percentile hyperparameter."""
        return self.percentile