from torch import nn
from slot_attention.model.transformer_blocks.cross_attention_block import CrossAttentionBlock

from slot_attention.model.transformer_blocks.self_attention_block import SelfAttentionBlock

# https://github.com/karpathy/nanoGPT/blob/master/model.py

class VitoEncoder(nn.Module):
    """Stack of self- and cross-attention blocks that updates ``slots`` using ``patches``.

    The encoder is organised in three stages:

    1. ``encoding_layers``: ``params.vito_encoding_layers`` self-attention
       blocks applied to ``patches`` only.
    2. ``interleaving_layers``: ``params.vito_interleaving_layers`` groups of
       four blocks (two self-attention, two cross-attention) that update
       ``patches`` and ``slots`` in alternation.
    3. ``last_layer``: one final cross-attention block called with
       ``(slots, patches)``; its output is the encoder's result.

    Args:
        params: Configuration object. This class reads
            ``vito_encoding_layers`` and ``vito_interleaving_layers`` from it
            and also forwards it unchanged to every block.
        dim, n_heads, mlp_dim, qk_dim, layernorm_bias: Forwarded verbatim to
            every ``SelfAttentionBlock`` / ``CrossAttentionBlock`` constructor.
        depth: Accepted for signature compatibility; not used by this class
            (the layer counts come from ``params``).
    """

    def __init__(self, params, dim, depth, n_heads, mlp_dim, qk_dim, layernorm_bias):
        super().__init__()

        # Every block in the encoder is constructed from the same positional arguments.
        block_args = (params, dim, n_heads, mlp_dim, qk_dim, layernorm_bias)

        self.encoding_layers = nn.ModuleList(
            [SelfAttentionBlock(*block_args) for _ in range(params.vito_encoding_layers)]
        )

        # Each group holds, in order: self-attention for patches, self-attention
        # for slots, cross-attention called as (slots, patches), and
        # cross-attention called as (patches, slots).
        self.interleaving_layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        SelfAttentionBlock(*block_args),
                        SelfAttentionBlock(*block_args),
                        CrossAttentionBlock(*block_args),
                        CrossAttentionBlock(*block_args),
                    ]
                )
                for _ in range(params.vito_interleaving_layers)
            ]
        )

        self.last_layer = CrossAttentionBlock(*block_args)

    def forward(self, patches, slots):
        """Run the three stages and return the final ``slots``.

        Args:
            patches: Input passed through the encoding layers and then
                alternately updated alongside ``slots``.
            slots: Input updated in the interleaving stage and by the last layer.

        Returns:
            The output of ``last_layer(slots, patches)``. The updated
            ``patches`` are not returned.
        """
        # Stage 1: only the patches are transformed.
        for encode in self.encoding_layers:
            patches = encode(patches)

        # Stage 2: alternate the updates; each step sees the newest value of the other stream.
        for group in self.interleaving_layers:
            patch_self, slot_self, slot_cross, patch_cross = group
            patches = patch_self(patches)
            slots = slot_self(slots)
            slots = slot_cross(slots, patches)
            patches = patch_cross(patches, slots)

        # Stage 3: one last cross-attention call produces the result.
        return self.last_layer(slots, patches)