from typing import Any

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor
from torch.utils.data import DataLoader


class ReactMaskPostprocessor(BasePostprocessor):
    """Postprocessor that thresholds network activations at an ID percentile.

    ``setup`` runs the network over an in-distribution loader, pools the last
    entry of the returned feature list to one value per channel, and stores
    the ``percentile``-th percentile of all pooled values as
    ``self.threshold``. ``postprocess`` hands that threshold to the network's
    ``forward_threshold`` method (presumably ReAct-style activation clipping,
    judging by the class name -- confirm against the network implementation).
    """

    def __init__(self, config):
        super().__init__(config)
        # Percentile (0-100) of ID activations used as the threshold. Not read
        # from ``config``; change it via ``set_hyperparam``.
        self.percentile = 90
        # Pooled ID activations, shape (num_samples, channels); set by setup().
        self.activation_log = None
        # Derived from activation_log and percentile; None until setup() runs.
        self.threshold = None

    @staticmethod
    def _infer_device(net: nn.Module) -> torch.device:
        """Return the device of ``net``'s first parameter (CUDA if it has none)."""
        first_param = next(net.parameters(), None)
        if first_param is None:
            # Preserve the historical behaviour of sending inputs to CUDA.
            return torch.device('cuda')
        return first_param.device

    def _update_threshold(self):
        """Recompute ``self.threshold`` from the stored activations and log it."""
        # np.percentile with the default axis=None works on the flattened
        # array, i.e. over every (sample, channel) value at once.
        self.threshold = np.percentile(self.activation_log, self.percentile)
        # ``{}`` rather than ``{:2d}``: a swept percentile may be a float,
        # which the ``d`` format code rejects with ValueError.
        print('Threshold at percentile {} over id data is: {}'.format(
            self.percentile, self.threshold))

    def setup(self, net: nn.Module, id_loader: DataLoader):
        """Collect pooled ID activations and derive ``self.threshold``.

        Args:
            net: Model called as ``net(data, return_feature_list=True)``; the
                second element of its output is indexed with ``[-1]`` to get
                the feature tensor (assumed to be the penultimate layer --
                TODO confirm).
            id_loader: Iterable of ``(data, label)`` pairs; labels are ignored.

        Raises:
            ValueError: If ``id_loader`` yields no batches.
        """
        net.eval()
        # Follow the model's device instead of hard-coding ``.cuda()`` so the
        # postprocessor also works for CPU-resident models.
        device = self._infer_device(net)
        pooled_batches = []
        with torch.no_grad():
            for data, _ in id_loader:
                data = data.to(device).float()
                _, features = net(data, return_feature_list=True)
                feature = features[-1]
                # Average every axis after the channel axis (e.g. spatial
                # H x W) -> (batch, channels). Pool on-device so only the
                # small pooled tensor is copied to host memory.
                pooled = feature.reshape(feature.shape[0], feature.shape[1],
                                         -1).mean(2)
                pooled_batches.append(pooled.cpu().numpy())

        if not pooled_batches:
            raise ValueError('id_loader yielded no batches; cannot estimate '
                             'the activation threshold')

        # Kept so set_hyperparam can recompute the threshold without another
        # pass over the data.
        self.activation_log = np.concatenate(pooled_batches, axis=0)
        self._update_threshold()

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Run ``net.forward_threshold`` with the threshold from ``setup``.

        Returns:
            The ``(output, feature)`` pair produced by
            ``net.forward_threshold(data, self.threshold)``.

        Raises:
            RuntimeError: If no threshold is available (``setup`` not run).
        """
        threshold = getattr(self, 'threshold', None)
        if threshold is None:
            raise RuntimeError('ReactMaskPostprocessor.setup() must be called '
                               'before postprocess()')
        output, feature = net.forward_threshold(data, threshold)
        return output, feature

    def set_hyperparam(self, hyperparam: list):
        """Set the percentile from ``hyperparam[0]``.

        If ``setup`` has already collected ID activations, the threshold is
        recomputed immediately so ``postprocess`` does not keep using a
        threshold derived from the previous percentile.
        """
        self.percentile = hyperparam[0]
        if getattr(self, 'activation_log', None) is not None:
            self._update_threshold()

    def get_hyperparam(self):
        """Return the current percentile."""
        return self.percentile