from typing import Literal

import torch


class AdaptiveLayerNorm(torch.nn.Module):
    """Layer norm whose scale and shift are predicted from a conditioning tensor.

    The input is normalised without learnable affine parameters and then
    modulated as ``norm(x) * gamma + beta``. ``gamma`` is a sigmoid-squashed
    linear projection of the layer-normalised conditioning, and ``beta`` is a
    bias-free linear projection of the same normalised conditioning.
    """

    def __init__(self, *, dim, dim_cond):
        """Create the normalisation and projection layers.

        Args:
            dim: Size of the last axis of the input being normalised.
            dim_cond: Size of the last axis of the conditioning tensor.
        """
        super().__init__()
        nn = torch.nn

        # No learnable affine here: scale/shift come from the conditioning.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm_cond = nn.LayerNorm(dim_cond)

        gamma_proj = nn.Linear(dim_cond, dim)
        self.to_gamma = nn.Sequential(gamma_proj, nn.Sigmoid())
        self.to_beta = nn.Linear(dim_cond, dim, bias=False)

    def forward(self, x, cond, mask):
        """Normalise ``x``, modulate it with ``cond`` and apply ``mask``.

        Args:
            x: Tensor whose last axis has size ``dim``.
            cond: Tensor whose last axis has size ``dim_cond``; after
                projection it must broadcast against ``x`` (presumably it
                shares the leading axes of ``x`` -- confirm against callers).
            mask: Tensor multiplied into the result after gaining a trailing
                singleton axis (presumably 1 for valid and 0 for padded
                positions -- confirm against callers).

        Returns:
            The modulated, masked tensor.
        """
        cond_normed = self.norm_cond(cond)
        scale = self.to_gamma(cond_normed)
        shift = self.to_beta(cond_normed)

        modulated = self.norm(x) * scale + shift
        # Trailing singleton axis lets the mask broadcast over features.
        return modulated * mask.unsqueeze(-1)


class AdaptiveOutputScale(torch.nn.Module):
    """Gate an output tensor with a sigmoid scale predicted from conditioning.

    The gate's linear layer starts with zero weights and a constant bias, so
    at initialisation every gate value equals ``sigmoid(bias)`` regardless of
    the conditioning input.
    """

    def __init__(self, *, dim, dim_cond, adaln_zero_bias_init_value=-2.0):
        """Create the gating projection.

        Args:
            dim: Size of the last axis of the tensor being scaled.
            dim_cond: Size of the last axis of the conditioning tensor.
            adaln_zero_bias_init_value: Initial value for every bias entry of
                the gating linear layer.
        """
        super().__init__()

        gate_linear = torch.nn.Linear(dim_cond, dim)
        torch.nn.init.zeros_(gate_linear.weight)
        torch.nn.init.constant_(gate_linear.bias, adaln_zero_bias_init_value)

        self.to_adaln_zero_gamma = torch.nn.Sequential(
            gate_linear,
            torch.nn.Sigmoid(),
        )

    def forward(self, x, cond, mask):
        """Scale ``x`` by the conditioning-derived gate and apply ``mask``.

        Args:
            x: Tensor whose last axis has size ``dim``.
            cond: Tensor whose last axis has size ``dim_cond``; the resulting
                gate must broadcast against ``x``.
            mask: Tensor multiplied into the result after gaining a trailing
                singleton axis (presumably 1 for valid and 0 for padded
                positions -- confirm against callers).

        Returns:
            ``x * gate * mask[..., None]``.
        """
        gate = self.to_adaln_zero_gamma(cond)
        gated = x * gate
        return gated * mask.unsqueeze(-1)


class AdaptiveLayerNormIdentical(torch.nn.Module):
    """Adaptive layer norm driven by one conditioning vector per batch element.

    The input is normalised without learnable affine parameters and modulated
    as ``norm(x) * gamma + beta``. ``gamma`` (sigmoid-squashed) and ``beta``
    are linear projections of a 2D conditioning tensor and are broadcast
    identically over every position of the corresponding batch element.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_cond: int,
        # Annotation matches the values the assert below actually accepts.
        mode: Literal["single", "pair"],
        use_ln_cond: bool = False,
    ):
        """Create the normalisation and projection layers.

        Args:
            dim: Size of the last axis of the input being normalised.
            dim_cond: Size of the last axis of the conditioning tensor.
            mode: ``"single"`` for 3D inputs ``[b, n, dim]`` or ``"pair"`` for
                4D inputs ``[b, n, n, dim]``.
            use_ln_cond: When true, layer-normalise the conditioning before
                projecting it; otherwise it is projected as given.

        Raises:
            AssertionError: If ``mode`` is neither ``"single"`` nor ``"pair"``.
        """

        super().__init__()
        assert mode in [
            "single",
            "pair",
        ], f"Mode {mode} not valid for AdaptiveLayerNormIdentical"
        self.mode = mode
        self.use_ln_cond = use_ln_cond

        # No learnable affine here: scale/shift come from the conditioning.
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False)
        if use_ln_cond:
            # Only registered when requested, so it is absent otherwise.
            self.norm_cond = torch.nn.LayerNorm(dim_cond)

        self.to_gamma = torch.nn.Sequential(
            torch.nn.Linear(dim_cond, dim), torch.nn.Sigmoid()
        )

        self.to_beta = torch.nn.Linear(dim_cond, dim, bias=False)

    def forward(
        self, x: torch.Tensor, cond: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Normalise ``x``, modulate it with ``cond`` and apply ``mask``.

        Args:
            x: ``[b, n, dim]`` in ``single`` mode, ``[b, n, n, dim]`` in
                ``pair`` mode.
            cond: ``[b, dim_cond]`` conditioning tensor.
            mask: ``[b, n]`` in ``single`` mode, ``[b, n, n]`` in ``pair``
                mode; multiplied into the result after gaining a trailing
                singleton axis.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            AssertionError: If the rank of ``x``, ``cond`` or ``mask`` does
                not match the configured mode.
        """

        assert (
            cond.dim() == 2
        ), f"Expected tensor cond with shape [b, dim_cond], got {cond.shape}"

        if self.mode == "single":
            assert (
                x.dim() == 3
            ), f"Expected tensor x with shape [b, n, dim] for `single` mode, got {x.shape}"
            assert (
                mask.dim() == 2
            ), f"Expected 2D tensor mask with shape [b, n] for `single` mode, got {mask.shape}"
        else:
            # __init__ only admits "single" or "pair", so this is pair mode.
            assert (
                x.dim() == 4
            ), f"Expected tensor x with shape [b, n, n, d] for `pair` mode, got {x.shape}"
            assert (
                mask.dim() == 3
            ), f"Expected tensor mask with shape [b, n, n] for `pair` mode, got {mask.shape}"

        normed = self.norm(x)
        normed_cond = self.norm_cond(cond) if self.use_ln_cond else cond

        gamma = self.to_gamma(normed_cond)
        beta = self.to_beta(normed_cond)

        # Insert singleton position axes so the per-batch scale/shift
        # broadcast over every position of x.
        if self.mode == "single":
            gamma_brc = gamma[..., None, :]
            beta_brc = beta[..., None, :]
        else:
            gamma_brc = gamma[..., None, None, :]
            beta_brc = beta[..., None, None, :]

        out = normed * gamma_brc + beta_brc
        # Trailing singleton axis lets the mask broadcast over features.
        return out * mask[..., None]
