from slot_attention.model.transformer_blocks.layer_norm import LayerNorm
from slot_attention.model.transformer_blocks.mlp import MLP
from slot_attention.model.transformer_blocks.attention_module import AttentionModule


import torch.nn as nn


class CrossAttentionBlock(nn.Module):
    """Residual block: attention from ``x`` onto ``kv``, followed by an MLP.

    Both sub-layers are wrapped in a residual connection, and each one
    receives a layer-normalised copy of the running activations. The
    key/value input is handed to the attention module as-is; this block
    applies no normalisation to it.
    """

    def __init__(self, params, embed_dim, n_heads, hidden_dim, qk_dim, layernorm_bias):
        """Create the four sub-modules of the block.

        Args:
            params: Forwarded unchanged as the first argument of
                ``AttentionModule`` (contents not visible from this file).
            embed_dim: Size given to both layer norms, the attention module
                and the MLP.
            n_heads: Forwarded to ``AttentionModule``.
            hidden_dim: Second positional argument of ``MLP``.
            qk_dim: Forwarded to ``AttentionModule``.
            layernorm_bias: ``bias`` flag passed to both ``LayerNorm`` layers.
        """
        super().__init__()

        # NOTE: attribute names and creation order are kept stable on purpose;
        # they determine state_dict keys and parameter registration order.
        self.ln_1 = LayerNorm(embed_dim, bias=layernorm_bias)
        self.attn = AttentionModule(
            params,
            embed_dim=embed_dim,
            n_heads=n_heads,
            qk_dim=qk_dim,
        )
        self.ln_2 = LayerNorm(embed_dim, bias=layernorm_bias)
        # The MLP is always built without dropout and with bias enabled.
        self.mlp = MLP(
            embed_dim,
            hidden_dim,
            dropout=0.0,
            bias=True,
        )

    def forward(self, x, kv):
        """Run the attention and MLP sub-layers, each with a residual add.

        Args:
            x: Input that supplies the (normalised) queries and the residual
                stream.
            kv: Input used for both keys and values of the attention module.

        Returns:
            The result of ``x + attn(...)`` further updated by
            ``+ mlp(ln_2(...))``. Presumably the same shape as ``x`` (the
            residual additions require compatible shapes) -- TODO confirm
            against ``AttentionModule``.
        """
        # Only the query side is normalised here; kv is used for k and v alike.
        normed_query = self.ln_1(x)
        attended = self.attn(q=normed_query, k=kv, v=kv)
        after_attn = x + attended

        # Second sub-layer: normalise, feed through the MLP, add back.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out