import torch.nn as nn

from slot_attention.model.transformer_blocks.layer_norm import LayerNorm
from slot_attention.model.transformer_blocks.mlp import MLP
from slot_attention.model.transformer_blocks.attention_module import AttentionModule

# https://sh-tsang.medium.com/review-pre-ln-transformer-on-layer-normalization-in-the-transformer-architecture-b6c91a89e9ab

class SelfAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention followed by an MLP.

    Each sub-layer reads a normalized copy of the stream, and its output is
    added back onto the un-normalized stream as a residual::

        x = x + attn(ln_1(x))
        x = x + mlp(ln_2(x))

    Args:
        params: Configuration object forwarded unchanged to ``AttentionModule``
            (its contents are not inspected here).
        embed_dim: Feature size of the residual stream. It is used for both
            LayerNorms, the attention module and the MLP input.
        n_heads: Number of attention heads, passed to ``AttentionModule``.
        hidden_dim: Hidden size, passed to ``MLP``.
        qk_dim: Query/key dimension, passed to ``AttentionModule``.
        layernorm_bias: Passed as the second argument of both ``LayerNorm``
            layers. NOTE(review): presumably toggles a learnable bias; confirm
            in ``layer_norm.py``.
        mlp_dropout: Dropout rate forwarded to ``MLP``. The default of ``0.0``
            matches the value that was previously hard-coded.
    """

    def __init__(
        self,
        params,
        embed_dim,
        n_heads,
        hidden_dim,
        qk_dim,
        layernorm_bias=True,
        mlp_dropout=0.0,
    ):
        super().__init__()
        self.ln_1 = LayerNorm(embed_dim, layernorm_bias)
        self.attn = AttentionModule(params, embed_dim=embed_dim, n_heads=n_heads, qk_dim=qk_dim)
        self.ln_2 = LayerNorm(embed_dim, layernorm_bias)
        self.mlp = MLP(embed_dim, hidden_dim, dropout=mlp_dropout)

    def forward(self, x):
        """Apply self-attention and the MLP, each with a pre-norm residual.

        Args:
            x: Residual-stream tensor whose last dimension is ``embed_dim``.
                NOTE(review): presumably (batch, tokens, embed_dim); TODO confirm.

        Returns:
            ``x`` after both residual updates have been added.
        """
        # Queries, keys and values all come from the same normalized stream.
        # The raw (un-normalized) x carries the residual.
        normed = self.ln_1(x)
        x = x + self.attn(q=normed, k=normed, v=normed)
        # The MLP reads a freshly normalized stream; its output is also residual.
        x = x + self.mlp(self.ln_2(x))
        return x
