import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.utils import ModelOutput


class LatchOutput:
    """Plain attribute bag returned by the latch models' ``forward``.

    Every field defaults to ``None`` on the class; the latch models in this
    file create a fresh instance and overwrite the fields on it.

    Attributes:
        noisy_repr: The input representation with the generated noise added.
        true_repr: The representation exactly as it was passed in.
        attn_wts: Mixing weights over the noise branches, or ``None`` when the
            producing model has no attention head.
        noise: The perturbation that was added to ``true_repr``.
    """

    noisy_repr = true_repr = attn_wts = noise = None


class NoisyDenseLatchModel(nn.Module):
    """Latch that perturbs a representation with a single two-layer MLP.

    The MLP maps ``input_dim -> hidden_dim -> input_dim`` and its output is
    added to the input as a residual perturbation. No attention weights are
    produced by this variant.
    """

    def __init__(self, input_dim, hidden_dim, attn_hidden_dim, *args, **kwargs):
        """Build the perturbation network.

        Args:
            input_dim: Size of the representation's feature dimension.
            hidden_dim: Width of the MLP's hidden layer.
            attn_hidden_dim: Accepted but not used by this class (presumably
                kept so the signature matches ``NoisyLatchModel`` - confirm).
            *args: Forwarded to ``nn.Module.__init__``.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, repr) -> LatchOutput:
        """Return ``repr`` together with its perturbed version.

        Args:
            repr: Representation whose last dimension is ``input_dim``.

        Returns:
            A ``LatchOutput`` with ``true_repr`` set to the input, ``noise``
            set to the MLP output, ``noisy_repr`` set to their sum and
            ``attn_wts`` set to ``None``.
        """
        perturbation = self.model(repr)

        result = LatchOutput()
        result.true_repr = repr
        result.noise = perturbation
        result.noisy_repr = repr + perturbation
        result.attn_wts = None
        return result


class NoisyLatchModel(nn.Module):
    """Latch that mixes two perturbation MLPs with a learned two-way gate.

    A "retain" MLP and a "forget" MLP each map the representation back to its
    own dimensionality. A third MLP produces two scores per sample which are
    softmax-normalised and used to blend the two MLP outputs into a single
    perturbation that is added to the input.
    """

    @staticmethod
    def _mlp(in_dim, width, out_dim):
        """Build a ``Linear -> ReLU -> Linear`` stack."""
        return nn.Sequential(
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Linear(width, out_dim),
        )

    def __init__(self, input_dim, hidden_dim, attn_hidden_dim, *args, **kwargs):
        """Create the forget branch, the retain branch and the gate.

        Args:
            input_dim: Size of the representation's feature dimension.
            hidden_dim: Hidden width of the forget and retain MLPs.
            attn_hidden_dim: Hidden width of the gating MLP.
            *args: Forwarded to ``nn.Module.__init__``.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Creation order (forget, retain, gate) is kept so that parameter
        # initialisation and registration order stay the same.
        self.un_model_forget = self._mlp(input_dim, hidden_dim, input_dim)
        self.un_model_retain = self._mlp(input_dim, hidden_dim, input_dim)
        self.attn_wts = self._mlp(input_dim, attn_hidden_dim, 2)

    def forward(self, repr) -> LatchOutput:
        """Blend the retain/forget perturbations and add them to ``repr``.

        Args:
            repr: 2-D tensor of shape ``(batch, input_dim)``; the einsum
                below only accepts that rank.

        Returns:
            A ``LatchOutput`` holding the input (``true_repr``), the blended
            perturbation (``noise``), their sum (``noisy_repr``) and the
            ``(batch, 2)`` gate probabilities (``attn_wts``). Column 0 of the
            gate weights the retain branch, column 1 the forget branch.
        """
        gate = torch.softmax(self.attn_wts(repr), dim=1)
        retain_out = self.un_model_retain(repr)
        forget_out = self.un_model_forget(repr)

        # (batch, 2, input_dim): index 0 is retain, index 1 is forget.
        branches = torch.stack((retain_out, forget_out), dim=1)
        blended = torch.einsum("bn,bni->bi", gate, branches)

        result = LatchOutput()
        result.true_repr = repr
        result.noise = blended
        result.noisy_repr = repr + blended
        result.attn_wts = gate
        return result


class NoisyUnlearnModel(nn.Module):
    """Frozen encoder and head with an optional trainable latch between them.

    ``repr_model`` turns the input into a representation and ``out_model``
    turns a representation into logits. Both have ``requires_grad`` switched
    off on all their parameters, so only the internal ``NoisyLatchModel`` keeps
    trainable parameters. When ``latch`` is true the representation is passed
    through the latch before it reaches ``out_model``.
    """

    def __init__(
        self,
        repr_model,
        out_model,
        input_dim,
        hidden_dim,
        attn_hidden_dim,
        loss_fn=None,
        use_attn_loss=False,
        latch=False,
    ):
        """Wire the latch between the two (frozen) models.

        Args:
            repr_model: Module mapping the raw input to a representation.
                Its parameters are frozen in place.
            out_model: Module mapping a representation to logits. Its
                parameters are frozen in place.
            input_dim: Feature size of the representation.
            hidden_dim: Hidden width of the latch's forget/retain MLPs.
            attn_hidden_dim: Hidden width of the latch's gating MLP.
            loss_fn: Optional callable invoked as
                ``loss_fn(y_new, y_orig, y)`` where ``y_new`` / ``y_orig`` are
                the head outputs for the noisy / true representation.
            use_attn_loss: When true, a cross-entropy term between the latch
                gate and ``y`` is added to the loss.
            latch: When false the latch is bypassed and no loss is computed.
        """
        super().__init__()
        # The latch is always constructed, even if ``latch`` is False.
        self.latch_model = NoisyLatchModel(input_dim, hidden_dim, attn_hidden_dim)
        self.repr_model = repr_model
        self.out_model = out_model
        self.__freeze_model(self.repr_model)
        self.__freeze_model(self.out_model)
        self.loss_fn = loss_fn
        self.use_attn_loss = use_attn_loss
        self.latch = latch

    def __freeze_model(self, model):
        """Disable gradient tracking for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, x, y=None):
        """Run the pipeline and, when possible, compute the training loss.

        Args:
            x: Input accepted by ``repr_model``.
            y: Optional targets. A loss is only computed when ``latch`` is
                enabled, ``loss_fn`` is set and ``y`` is given.

        Returns:
            A ``ModelOutput`` with ``logits`` set to the head output for the
            (possibly noisy) representation, and ``loss`` assigned when a loss
            was computed.
        """
        output = ModelOutput()

        # True (un-perturbed) representation from the frozen encoder.
        features = self.repr_model(x)

        if self.latch:
            latch_out = self.latch_model(features)
            if self.loss_fn is not None and y is not None:
                y_new = self.out_model(latch_out.noisy_repr)
                y_orig = self.out_model(latch_out.true_repr)
                loss = self.loss_fn(y_new, y_orig, y)
                if self.use_attn_loss:
                    # ``attn_wts`` is already softmax-normalised, while
                    # ``F.cross_entropy`` expects logits and applies
                    # log-softmax itself. Feeding log-probabilities makes that
                    # internal log-softmax a no-op, so the term is the actual
                    # cross-entropy of the gate instead of a doubly
                    # normalised one. The clamp keeps log() finite if a
                    # probability underflows to zero.
                    attn = latch_out.attn_wts
                    floor = torch.finfo(attn.dtype).tiny
                    attn_logits = torch.log(torch.clamp(attn, min=floor))
                    # Out-of-place add: do not mutate what loss_fn returned.
                    loss = loss + F.cross_entropy(attn_logits, y)
                output.loss = loss
            # From here on the head sees the perturbed representation.
            features = latch_out.noisy_repr

        output.logits = self.out_model(features)
        return output
