from typing import Any

import torch
import torch.nn as nn

from .base_postprocessor import BasePostprocessor


class TulipPostprocessor(BasePostprocessor):
    """Inject into ``net``, then score samples by negated uncertainty.

    Before scoring, the network is modified in place through its
    ``injecting(perturb_power)`` hook. When it must be re-injected,
    ``net.resetting()`` is called first. Predictions and uncertainties come
    from ``net.eval_forward``. All three hooks must be provided by the net.

    The hyperparameters are read from ``config.postprocessor.postprocessor_args``.
    They are exchanged with ``set_hyperparam`` / ``get_hyperparam`` in the
    order ``[K, lambd, delta, perturb_power]``:
        K, lambd, delta: forwarded to ``net.eval_forward`` on every call.
        perturb_power: forwarded to ``net.injecting``.

    NOTE(review): ``APS_mode`` and ``hyperparam_search_done`` are read but
    never assigned in this class. Presumably they are set by the base class
    or by the evaluator running the hyperparameter search; confirm.
    """

    # While the hyperparameter search is running, the net is reset and
    # re-injected each time this many postprocess() calls have elapsed
    # (i.e. whenever ``aps_count`` wraps back to 0).
    _APS_REINJECT_PERIOD = 5

    def __init__(self, config):
        super().__init__(config)
        pp_cfg = self.config.postprocessor
        self.args = pp_cfg.postprocessor_args
        # Sweep definition; presumably consumed by the hyperparameter search.
        self.args_dict = pp_cfg.postprocessor_sweep

        # Forwarded to net.eval_forward on every postprocess() call.
        self.K = self.args.K
        self.lambd = self.args.lambd
        self.delta = self.args.delta
        # Forwarded to net.injecting whenever the net is (re-)injected.
        self.perturb_power = self.args.perturb_power

        # Set once the single injection outside the search phase is done.
        self.injected = False
        # Number of postprocess() calls during the search, modulo the period.
        self.aps_count = 0

    def _reset_and_inject(self, net: nn.Module) -> None:
        """Call ``net.resetting()``, then inject the current ``perturb_power``."""
        net.resetting()
        net.injecting(self.perturb_power)

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Make sure ``net`` is injected, then predict and score ``data``.

        Injection policy:
          * APS mode with the search still running: reset and re-inject on
            the first call and then every ``_APS_REINJECT_PERIOD`` calls.
            This is presumably so that values from ``set_hyperparam`` reach
            the net. NOTE(review): this counts calls (batches), not
            hyperparameter changes; confirm the intent.
          * APS mode with the search finished: reset and re-inject once.
          * APS mode off: inject once, without resetting.

        Args:
            net: network exposing ``resetting``, ``injecting`` and
                ``eval_forward``.
            data: input passed unchanged to ``net.eval_forward``.

        Returns:
            ``(pred, score)``. ``pred`` holds the indices of the maximum
            logits along dim 1. ``score`` is the negated uncertainty, so a
            higher score means lower uncertainty.
        """
        aps_enabled = self.APS_mode
        if aps_enabled and not self.hyperparam_search_done:
            if self.aps_count == 0:
                self._reset_and_inject(net)
            self.aps_count = (self.aps_count + 1) % self._APS_REINJECT_PERIOD
        elif not self.injected:
            if aps_enabled:
                self._reset_and_inject(net)
            else:
                net.injecting(self.perturb_power)
            self.injected = True

        # Note the positional order passed here: K, delta, then lambd.
        logits, uncertainty = net.eval_forward(
            data, self.K, self.delta, self.lambd
        )
        pred = torch.max(logits, dim=1).indices
        return pred, -uncertainty

    def set_hyperparam(self, hyperparam: list):
        """Set ``K``, ``lambd``, ``delta``, ``perturb_power`` from ``hyperparam``.

        The values are read from indices 0-3 in that order. A new
        ``perturb_power`` only reaches the net at its next (re-)injection
        inside ``postprocess``.
        """
        self.K = hyperparam[0]
        self.lambd = hyperparam[1]
        self.delta = hyperparam[2]
        self.perturb_power = hyperparam[3]

    def get_hyperparam(self):
        """Return the current ``[K, lambd, delta, perturb_power]`` as a list."""
        return [self.K, self.lambd, self.delta, self.perturb_power]
