import torch.nn as nn
from linear_attention_transformer.linear_attention_transformer import (
    SelfAttention as LinearSelfAttention,
)


class NeoMLPAttention(nn.Module):
    """Stack of residual self-attention + feed-forward layers.

    Each of the ``num_layers`` iterations applies an attention sub-block
    followed by a feed-forward sub-block. Each sub-block is wrapped in a
    residual connection whose layer normalization is applied either to the
    sub-block input (pre-LN) or to the residual sum (post-LN).

    Args:
        embed_dim: Feature dimension of the input and output.
        num_heads: Number of attention heads.
        num_layers: Number of attention + feed-forward iterations.
        shared_weights: If True, build a single set of layer parameters and
            reuse it for every iteration.
        use_linear_attention: If True, use ``LinearSelfAttention`` from
            ``linear_attention_transformer`` instead of
            ``nn.MultiheadAttention``.
        use_ffn: If True, the feed-forward sub-block is
            Linear -> GELU -> Dropout -> Linear -> Dropout. Otherwise it is a
            bare GELU.
        ffn_dim: Hidden width of the feed-forward MLP. Defaults to
            ``embed_dim``.
        dropout: Dropout probability for ``nn.MultiheadAttention`` and the
            MLP. It is not passed to ``LinearSelfAttention``.
        use_layer_norm: If False, the layer norms are replaced by identities.
        pre_layer_norm: If True, use pre-LN residual blocks
            (``x + f(ln(x))``). Otherwise use post-LN (``ln(x + f(x))``).

    ``nn.MultiheadAttention`` is built with ``batch_first=True``, so inputs
    are expected as (batch, seq, embed_dim) on that path. The linear-attention
    path presumably expects the same layout (TODO confirm).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_layers,
        shared_weights=False,
        use_linear_attention=False,
        use_ffn=True,
        ffn_dim=None,
        dropout=0.0,
        use_layer_norm=True,
        pre_layer_norm=True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.shared_weights = shared_weights
        self.use_linear_attention = use_linear_attention
        self.pre_layer_norm = pre_layer_norm

        # With shared weights only one set of modules is created; forward()
        # then uses index 0 for every iteration.
        actual_num_layers = 1 if shared_weights else num_layers
        self.mha = nn.ModuleList(
            [
                (
                    LinearSelfAttention(self.embed_dim, self.num_heads)
                    if use_linear_attention
                    else nn.MultiheadAttention(
                        self.embed_dim,
                        self.num_heads,
                        dropout=dropout,
                        batch_first=True,
                    )
                )
                for _ in range(actual_num_layers)
            ]
        )
        ffn_dim = ffn_dim if ffn_dim is not None else self.embed_dim
        self.att_lin = nn.ModuleList(
            [
                (
                    nn.Sequential(
                        nn.Linear(self.embed_dim, ffn_dim),
                        nn.GELU(),
                        nn.Dropout(dropout),
                        nn.Linear(ffn_dim, self.embed_dim),
                        nn.Dropout(dropout),
                    )
                    if use_ffn
                    else nn.GELU()
                )
                for _ in range(actual_num_layers)
            ]
        )

        def make_norm():
            # Identity keeps the residual structure intact when layer norm
            # is disabled.
            if use_layer_norm:
                return nn.LayerNorm(self.embed_dim)
            return nn.Identity()

        self.ln1 = nn.ModuleList([make_norm() for _ in range(actual_num_layers)])
        self.ln2 = nn.ModuleList([make_norm() for _ in range(actual_num_layers)])

        if pre_layer_norm:
            self.attention_func = self._pre_ln_attn
            self.ffn_func = self._pre_ln_ffn
        else:
            self.attention_func = self._post_ln_attn
            self.ffn_func = self._post_ln_ffn

    def _attn_fn(self, x, i, attn_mask=None):
        """Apply the i-th attention module as self-attention over ``x``.

        ``attn_mask`` is forwarded to ``nn.MultiheadAttention``. The
        linear-attention path has no equivalent here, so a mask is rejected
        rather than silently ignored.
        """
        if self.use_linear_attention:
            if attn_mask is not None:
                raise ValueError(
                    "attn_mask is not supported when use_linear_attention=True"
                )
            return self.mha[i](x)
        # nn.MultiheadAttention returns (output, weights); keep the output.
        return self.mha[i](x, x, x, attn_mask=attn_mask)[0]

    def _pre_ln_attn(self, x, idx, attn_mask=None):
        """Pre-LN attention residual: ``x + attn(ln1(x))``."""
        return x + self._attn_fn(self.ln1[idx](x), idx, attn_mask=attn_mask)

    def _post_ln_attn(self, x, idx, attn_mask=None):
        """Post-LN attention residual: ``ln1(x + attn(x))``."""
        return self.ln1[idx](x + self._attn_fn(x, idx, attn_mask=attn_mask))

    def _pre_ln_ffn(self, x, idx):
        """Pre-LN feed-forward residual: ``x + ffn(ln2(x))``."""
        return x + self.att_lin[idx](self.ln2[idx](x))

    def _post_ln_ffn(self, x, idx):
        """Post-LN feed-forward residual: ``ln2(x + ffn(x))``."""
        return self.ln2[idx](x + self.att_lin[idx](x))

    def forward(self, x, attn_mask=None):
        """Run ``num_layers`` attention + feed-forward iterations over ``x``.

        Args:
            x: Input tensor, expected (batch, seq, embed_dim) on the
                ``nn.MultiheadAttention`` path.
            attn_mask: Optional mask passed to every ``nn.MultiheadAttention``
                call. It must be None when ``use_linear_attention`` is True.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for i in range(self.num_layers):
            idx = 0 if self.shared_weights else i
            x = self.attention_func(x, idx, attn_mask=attn_mask)
            x = self.ffn_func(x, idx)
        return x
