

from torch import nn

from slot_attention.model.transformer_blocks.self_attention_block import SelfAttentionBlock


class TransformerEncoder(nn.Module):
    """Sequential stack of ``SelfAttentionBlock`` modules.

    The encoder owns ``depth`` blocks, all constructed with the same
    arguments, and applies them one after another to its input.

    Args:
        params: Passed unchanged to every ``SelfAttentionBlock``.
        dim: Passed unchanged to every ``SelfAttentionBlock``.
        depth: Number of blocks to create (and apply) in sequence.
        n_heads: Passed unchanged to every ``SelfAttentionBlock``.
        mlp_dim: Passed unchanged to every ``SelfAttentionBlock``.
        qk_dim: Passed unchanged to every ``SelfAttentionBlock``.

    NOTE(review): the meaning of the forwarded arguments is defined by
    ``SelfAttentionBlock`` and is not visible here -- consult that class.
    """

    def __init__(self, params, dim, depth, n_heads, mlp_dim, qk_dim):
        super().__init__()
        # Each block is an independent instance (no weight sharing); they
        # are created in order so registration indices match depth order.
        blocks = [
            SelfAttentionBlock(params, dim, n_heads, mlp_dim, qk_dim)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result.

        With zero blocks the input is returned as-is.
        """
        out = x
        for layer in self.layers:
            out = layer(out)
        return out