from typing import Any

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class BasePostprocessor:
    """Baseline postprocessor that scores network outputs by maximum softmax probability.

    ``setup`` and ``inference`` are no-op hooks in this base class; presumably
    subclasses override them (NOTE(review): confirm against subclasses).
    """

    def __init__(self, config):
        """Store ``config`` on the instance unchanged; nothing else is done here."""
        self.config = config

    def setup(self, net: nn.Module, id_loader: DataLoader):
        """No-op hook: the base class ignores ``net`` and ``id_loader`` and returns None."""

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Run ``net`` on ``data`` and return the top softmax class and its probability.

        Gradient tracking is disabled for the whole call by the decorator.

        Args:
            net: Model, called as ``net(data)``.
            data: Input forwarded to ``net`` as-is.

        Returns:
            A ``(pred, conf)`` pair. ``pred`` is the index of the largest softmax
            probability along dim 1. ``conf`` is that probability.
            NOTE(review): assumes ``net`` returns logits with classes on dim 1,
            e.g. shape (batch, num_classes). Confirm with callers.
        """
        logits = net(data)
        probs = logits.softmax(dim=1)
        top = torch.max(probs, dim=1)
        # Order is (prediction, confidence): indices first, then values.
        return top.indices, top.values

    def inference(self, net: nn.Module, dataloader: DataLoader):
        """No-op hook: the base class ignores its arguments and returns None."""