from typing import Any
import torch
import torch.nn as nn
from .base_postprocessor import BasePostprocessor
from .weiper_kldiv.utils import (
    calculate_WeiPerKLDiv_score,
)


class WeiPerKLDivPostprocessor(BasePostprocessor):
    def __init__(self, config):
        super().__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        self.args_dict = self.config.postprocessor.postprocessor_sweep
        self.lambda_1 = self.args.lambda_1
        self.lambda_2 = self.args.lambda_2
        self.smoothing = self.args.smoothing
        self.smoothing_perturbed = self.args.smoothing_perturbed
        self.n_bins = self.args.n_bins
        self.perturbation_distance = self.args.perturbation_distance
        self.n_repeats = self.args.n_repeats
        self.n_samples_for_setup = self.args.n_samples_for_setup
        self.train_dataset_list = None

    def setup(
        self,
        net: nn.Module,
        id_loader_dict,
        ood_loader_dict,
        id_name="imagenet",
        valid_num=None,
        layer_names=None,
        aps=None,
        use_cache=False,
        hyperparameter_search=False,
        train_dl=None,
        **kwargs,
    ):
        net.eval()
        device = next(iter(net.parameters())).device

        if hasattr(self, "latents"):
            latents = self.latents
        else:
            if train_dl is None:
                train_dl = id_loader_dict["train"]
                if self.train_dataset_list is None:
                    iter_train = iter(train_dl)
                    self.train_dataset_list = (
                        [
                            next(iter_train)
                            for _ in range(
                                int(
                                    min(
                                        self.n_samples_for_setup,
                                        len(train_dl.dataset),
                                    )
                                    * len(train_dl)
                                    / len(train_dl.dataset)
                                )
                            )
                        ]
                        if self.n_samples_for_setup < len(train_dl.dataset)
                        else id_loader_dict["train"]
                    )
                with torch.no_grad():
                    latents = torch.cat(
                        [
                            net(entry["data"].to(device), return_feature=True)[1].cpu()
                            for entry in self.train_dataset_list
                        ],
                        dim=0,
                    )
            else:
                with torch.no_grad():
                    latents = torch.cat(
                        [
                            net(x.to(device), return_feature=True)[1].cpu()
                            for x in train_dl
                        ]
                    )
            self.latents = latents
        self.train_densities, self.W_tilde = calculate_WeiPerKLDiv_score(
            net,
            [
                latents,
                torch.zeros_like(self.latents[:10]),
                {"": torch.zeros_like(self.latents[:10])},
            ],
            lambda_1=self.lambda_1,
            lambda_2=self.lambda_2,
            n_bins=self.n_bins,
            n_repeats=self.n_repeats,
            smoothing=self.smoothing,
            smoothing_perturbed=self.smoothing_perturbed,
            perturbation_distance=self.perturbation_distance,
            device=device,
        )[3:]

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        net.eval()
        output, feature = net(data, return_feature=True)
        feature = feature.cpu()
        pred = torch.max(torch.softmax(output, dim=1), dim=1)[1]
        device = next(iter(net.parameters())).device
        conf = calculate_WeiPerKLDiv_score(
            net,
            [
                self.latents,
                feature,
                {"": torch.zeros_like(feature[:10])},
            ],
            lambda_1=self.lambda_1,
            lambda_2=self.lambda_2,
            n_bins=self.n_bins,
            n_repeats=self.n_repeats,
            smoothing=self.smoothing,
            smoothing_perturbed=self.smoothing_perturbed,
            perturbation_distance=self.perturbation_distance,
            train_densities=self.train_densities,
            device=device,
        )[0][0]
        return pred, conf

    def set_hyperparam(self, hyperparam: list):
        self.lambda_1 = hyperparam[0]
        self.lambda_2 = hyperparam[1]
        self.n_bins = hyperparam[2]
        self.perturbation_distance = hyperparam[3]
        self.smoothing = hyperparam[4]
        self.smoothing_perturbed = hyperparam[5]

    def get_hyperparam(self):
        return (
            self.lambda_1,
            self.lambda_2,
            self.n_bins,
            self.perturbation_distance,
            self.smoothing,
            self.smoothing_perturbed,
        )
