from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from .base_postprocessor import BasePostprocessor


class ASHPostprocessor(BasePostprocessor):
    """Energy-based scoring on top of the network's ``forward_threshold`` pass.

    The wrapped network must expose ``forward_threshold(data, percentile)``.
    This postprocessor supplies the percentile from the config. It turns the
    returned logits into a prediction (argmax) plus an energy-style confidence
    (``logsumexp`` over dim 1).
    """

    def __init__(self, config):
        super(ASHPostprocessor, self).__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        # Percentile forwarded to ``net.forward_threshold``; can be changed
        # later via ``set_hyperparam``.
        self.percentile = self.args.percentile
        # Sweep specification from the config; not read inside this class.
        self.args_dict = self.config.postprocessor.postprocessor_sweep

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Score one batch.

        Args:
            net: model exposing ``forward_threshold(data, percentile)``.
            data: input batch passed straight to ``forward_threshold``.

        Returns:
            ``(pred, energyconf)``. ``pred`` is the argmax over dim 1 of the
            logits. ``energyconf`` is ``logsumexp`` over dim 1, computed on
            CPU.
        """
        output = net.forward_threshold(data, self.percentile)
        _, pred = torch.max(output, dim=1)
        # ``.detach()`` is the supported replacement for the legacy ``.data``.
        energyconf = torch.logsumexp(output.detach().cpu(), dim=1)
        return pred, energyconf

    @staticmethod
    def _infer_device(net: nn.Module) -> torch.device:
        """Return the device of ``net``'s first parameter, or a sane default."""
        try:
            return next(net.parameters()).device
        except StopIteration:
            # Parameter-less module: prefer CUDA when present, else CPU.
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @torch.no_grad()
    def extract_stats(self, net, save_pth, ood_data_loader):
        """Score every batch of ``ood_data_loader`` and save to ``save_pth/ash.npy``.

        Each batch is expected to be a mapping with a ``'data'`` entry. As a
        side effect, ``net`` is switched to eval mode. ``save_pth`` is created
        if it does not exist.
        """
        net.eval()
        # Move inputs to wherever the model lives instead of assuming CUDA,
        # so this also works on CPU-only machines.
        device = self._infer_device(net)

        ash_scores = []
        for batch in ood_data_loader:
            inputs = batch['data'].to(device)
            _, score = self.postprocess(net, inputs)
            ash_scores.extend(score.cpu().tolist())
        ash_scores = np.array(ash_scores)

        os.makedirs(save_pth, exist_ok=True)
        ash_file_pth = os.path.join(save_pth, 'ash.npy')
        np.save(ash_file_pth, ash_scores)

    def set_hyperparam(self, hyperparam: list):
        """Set the percentile from the first element of ``hyperparam``."""
        self.percentile = hyperparam[0]

    def get_hyperparam(self):
        """Return the current percentile used by ``forward_threshold``."""
        return self.percentile
