import math
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn import functional as F

from .param_inject import *


class TulipNet(nn.Module):
    """Wrapper that scores predictive uncertainty by perturbing a network's parameters.

    ``self.net`` is the working copy that ``injecting`` passes to ``inject_net``
    and that all perturbed passes run through; ``self.original`` is a deep copy
    taken at construction time, used for plain ``forward`` calls and by
    ``resetting``.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Snapshot taken before any injection so the clean weights can be restored.
        self.original = deepcopy(net)
        # Last perturbation power applied; -1.0 means "none applied yet".
        self.prev_perturb_power = -1.0

    def resetting(self):
        """Replace ``self.net`` with a fresh deep copy of ``self.original``."""
        self.net = deepcopy(self.original)
        # The restored copy has no injected perturbation, so forget the cached power.
        self.prev_perturb_power = -1.0

    def injecting(self, perturb_power=0.15):
        """Call ``inject_net`` on ``self.net`` and remember ``perturb_power``."""
        inject_net(self.net, perturb_power=perturb_power)
        self.prev_perturb_power = perturb_power

    def change_perturb_power(self, perturb_power=0.15):
        """Set every ``ParameterInjector``'s norm to ``perturb_power`` x its parameter norm.

        Does nothing when ``perturb_power`` equals the last power applied.
        """
        if perturb_power == self.prev_perturb_power:
            return
        for module in self.modules():
            if isinstance(module, ParameterInjector):
                module.set_norm(module.get_param_norm() * perturb_power)
        # Record the new power so a repeated call with the same value is a no-op.
        self.prev_perturb_power = perturb_power

    def get_predictions(self, x, times=-1, dropout_iters=10, gammas=1):
        """Run ``self.net`` on ``x`` several times, calling ``resample_perturb`` before each pass.

        Args:
            x: input passed unchanged to ``self.net``.
            times: number of passes; values <= 0 fall back to ``dropout_iters``.
            dropout_iters: default number of passes.
            gammas: ``sample_gamma`` per pass. A non-tensor value is repeated
                for every pass; a tensor is indexed per pass and must have at
                least ``times`` entries.

        Returns:
            The per-pass outputs stacked along a new leading dim 0
            (``[times, *output_shape]``).
        """
        if times <= 0:
            times = dropout_iters

        if not isinstance(gammas, torch.Tensor):
            # One entry per pass actually run; sizing this by ``dropout_iters``
            # raised IndexError whenever ``times > dropout_iters``.
            gammas = torch.Tensor([gammas] * times)

        enable_perturb(self.net)
        logits = []
        for i in range(times):
            resample_perturb(self.net, sample_gamma=gammas[i])
            logits.append(self.net(x))

        return torch.stack(logits, dim=0)

    def margin(self, logits):
        """Return the negated (top-1 minus top-2) logit gap, averaged over dim 0.

        Larger values mean a smaller margin between the two leading classes.
        """
        top2 = torch.topk(logits, 2, dim=-1).values
        margin = top2[..., 0] - top2[..., 1]
        return -margin.mean(0)

    def entropy(self, logits):
        """Return the entropy of the softmax probabilities averaged over dim 0."""
        probs = F.softmax(logits, dim=-1)
        model_prediction = probs.mean(0)
        # 1e-8 keeps log() finite for zero-probability classes.
        return -torch.sum(
            model_prediction * torch.log(model_prediction + 1e-8), dim=-1
        )

    def energy(self, logits):
        """Return ``-logsumexp`` over the last dim, averaged over dim 0."""
        energy = -torch.logsumexp(logits, dim=-1)
        return energy.mean(0)

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Plain forward pass through the unperturbed ``self.original`` network.

        Tries the three-argument call first and falls back to passing only
        ``return_feature`` for networks whose forward lacks ``return_feature_list``.
        """
        # NOTE(review): this catches *any* TypeError raised inside the wrapped
        # forward, not only a signature mismatch, so a genuine TypeError in the
        # model is retried with two arguments and may be masked.
        try:
            return self.original(x, return_feature, return_feature_list)
        except TypeError:
            return self.original(x, return_feature)

    def eval_forward(self, x, K=1, delta=2, lambd=0.005):
        """Return rescaled perturbed logits and their entropy-based uncertainty for ``x``.

        Args:
            x: input batch.
            K: multiplier on the deterministic-perturbation term ``Oxz``.
            delta: ``noise_norm_ex`` passed to ``set_perturb_norm`` for the
                deterministic perturbation pass.
            lambd: weight of ``Oxz`` subtracted from ``Ozz`` in
                ``get_uncertainty_from_ub``.

        Returns:
            ``(mean_logits, uncertainty)``: the rescaled perturbed logits
            (``self.ref_l``) averaged over dim 0, and the entropy returned by
            ``get_uncertainty_from_ub``.
        """
        # Stochastic perturbed passes, stacked on dim 0.
        logits = self.get_predictions(x)

        # Pop state; restored in the ``finally`` below even if a pass raises.
        cache = get_states(self.net)
        try:
            # Compute <g,p>
            set_perturb_norm(
                self.net,
                noise_norm=None,
                noise_norm_ex=delta,
                noise_pattern="prop-deterministic",
            )
            logits_det = self.get_predictions(x, times=1)  # [1, bs, outdim]

            # Compute original logits
            set_perturb_norm(
                self.net, noise_norm=0, noise_pattern="prop-deterministic"
            )
            logits_original = self.get_predictions(x, times=1)  # [1, bs, outdim]
        finally:
            # Push state
            set_states(self.net, cache)

        det_diff = logits_det - logits_original
        logits_diff = logits - logits_original

        # Mean over passes of the squared L2 distance to the unperturbed logits.
        Ozz = (logits_diff ** 2).mean(dim=0).sum(dim=-1)
        # squeeze(0) rather than squeeze() so the batch dim survives when bs == 1.
        Oxz = (
            K
            * math.sqrt(logits.shape[-1])
            * torch.norm(det_diff, dim=-1).squeeze(0)
        )

        uq = self.get_uncertainty_from_ub(
            x,
            logits_original.squeeze(0),
            logits,
            logits_det.squeeze(0),
            Ozz,
            Oxz,
            lambd,
        )

        # logits, uncertainty
        return self.ref_l.mean(dim=0), uq

    def get_uncertainty_from_ub(self, x, l, lp, ld, Ozz, gpp, lambd):
        """Shrink the deviations ``lp - l`` and return the entropy of the result.

        Each deviation is scaled per batch element by
        ``sqrt(max(Ozz - lambd * gpp, 0)) / sqrt(Ozz)``; where ``Ozz == 0`` the
        scale is 0. The rescaled logits are stored on ``self.ref_l``.

        Args:
            x: unused; kept for interface compatibility.
            l: reference logits (``[bs, outdim]`` as passed by ``eval_forward``).
            lp: perturbed logits, ``[passes, bs, outdim]``.
            ld: unused; kept for interface compatibility.
            Ozz: per-example spread of ``lp`` around ``l``, 1-D over the batch.
            gpp: per-example penalty subtracted (times ``lambd``) from ``Ozz``.
            lambd: weight of ``gpp``.

        Returns:
            Entropy of ``self.ref_l`` (see ``entropy``).
        """
        logit_differences = lp - l

        sqrt_Ozz = torch.sqrt(Ozz)
        sqrt_ub = torch.sqrt(
            torch.maximum(Ozz - lambd * gpp, torch.zeros_like(Ozz))
        )
        # Ozz == 0 would make the ratio 0/0 = NaN; use scale 0 there so
        # ref_l falls back to ``l`` instead of poisoning the entropy.
        nonzero = sqrt_Ozz > 0
        safe_denom = torch.where(nonzero, sqrt_Ozz, torch.ones_like(sqrt_Ozz))
        scale_gamma = torch.where(
            nonzero, sqrt_ub / safe_denom, torch.zeros_like(sqrt_ub)
        )

        self.ref_l = l + scale_gamma[None, :, None] * logit_differences
        return self.entropy(self.ref_l)
