import torch

from proteinfoundation.nn.modules.adaptive_ln_scale import (
    AdaptiveLayerNorm,
    AdaptiveOutputScale,
)
from proteinfoundation.nn.modules.swiglu import SwiGLU


class Transition(torch.nn.Module):
    """Position-wise gated feed-forward block with output masking.

    Computes ``linear_out(SwiGLU(Linear(LN?(x)))) * mask[..., None]``, where
    the leading LayerNorm is only applied when ``layer_norm`` is truthy.

    Args:
        dim: Feature size of the input and of the output.
        expansion_factor: Multiplier for the hidden width; the hidden size is
            ``int(dim * expansion_factor)``.
        layer_norm: If truthy, apply ``torch.nn.LayerNorm(dim)`` to the input
            before the feed-forward path.
    """

    def __init__(self, dim, expansion_factor=4, layer_norm=False):
        super().__init__()

        hidden = int(dim * expansion_factor)

        self.use_layer_norm = layer_norm
        # Only register the norm when requested, so the state_dict has no
        # "ln.*" entries otherwise.
        if layer_norm:
            self.ln = torch.nn.LayerNorm(dim)

        # The projection emits 2 * hidden features, and linear_out consumes
        # `hidden`. This implies SwiGLU halves the last dimension (presumably
        # a gate/value split; its definition is not visible here).
        up_proj = torch.nn.Linear(dim, hidden * 2, bias=False)
        self.swish_linear = torch.nn.Sequential(up_proj, SwiGLU())
        self.linear_out = torch.nn.Linear(hidden, dim, bias=False)

    def forward(self, x, mask):
        """Apply the block and zero out masked positions.

        Args:
            x: Tensor whose last dimension is ``dim``.
            mask: Tensor broadcastable against ``x`` once a trailing singleton
                axis is appended (presumably ``x.shape[:-1]`` — confirm with
                callers). Multiplied into the output.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.ln(x) if self.use_layer_norm else x
        h = self.swish_linear(h)
        h = self.linear_out(h)
        # The extra trailing axis lets a per-position mask broadcast over features.
        return h * mask[..., None]


class TransitionADALN(torch.nn.Module):
    """Transition block conditioned on ``cond`` through adaptive norm and scaling.

    Pipeline: ``AdaptiveLayerNorm(x, cond)`` -> ``Transition`` (without its own
    LayerNorm) -> ``AdaptiveOutputScale(., cond)`` -> multiply by the mask.

    Args:
        dim: Feature size of ``x``.
        dim_cond: Feature size of the conditioning tensor ``cond``.
        expansion_factor: Forwarded to the inner ``Transition``.
    """

    def __init__(self, *, dim, dim_cond, expansion_factor=4):
        super().__init__()
        self.adaln = AdaptiveLayerNorm(dim=dim, dim_cond=dim_cond)
        # Normalization is done by `adaln`, so the inner block skips its own LayerNorm.
        self.transition = Transition(
            dim=dim,
            expansion_factor=expansion_factor,
            layer_norm=False,
        )
        self.scale_output = AdaptiveOutputScale(dim=dim, dim_cond=dim_cond)

    def forward(self, x, cond, mask):
        """Run the conditioned transition and zero out masked positions.

        Args:
            x: Input features (last dimension presumably ``dim``; confirm).
            cond: Conditioning tensor passed to the adaptive norm and the
                output scale (last dimension presumably ``dim_cond``; confirm).
            mask: Passed to every sub-module, then multiplied into the result
                after a trailing singleton axis is appended.

        Returns:
            The masked output of the adaptive output scale.
        """
        normed = self.adaln(x, cond, mask)
        hidden = self.transition(normed, mask)
        scaled = self.scale_output(hidden, cond, mask)
        # Re-mask at the end. Whether scale_output already masks is not visible here.
        return scaled * mask[..., None]
