import torch.nn as nn
from src.modules.mp_attention import SelfAttention
from src.modules.mlp import MLP


class Encoder(nn.Module):
    """Stack of ``EncoderBlock`` layers framed by input dropout and a final LayerNorm.

    The sub-modules are stored in ``self.encoder`` (an ``nn.ModuleDict``) under
    the keys ``drop``, ``layers`` and ``norm``.

    Args:
        config: Settings object. This class reads ``dropout``,
            ``n_encoder_layers`` and ``d_model`` from it and passes the same
            object to every ``EncoderBlock`` it creates.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        input_dropout = nn.Dropout(config.dropout)

        # Blocks are created one by one, in stacking order.
        blocks = nn.ModuleList()
        for _ in range(config.n_encoder_layers):
            blocks.append(EncoderBlock(config))

        final_norm = nn.LayerNorm(config.d_model)

        # Key names and their order determine the parameter names exposed by
        # ``state_dict()``, so they are kept exactly as-is.
        self.encoder = nn.ModuleDict({
            "drop": input_dropout,
            "layers": blocks,
            "norm": final_norm,
        })

    def forward(self, x):
        """Apply dropout, then every block in order, then the final LayerNorm.

        Args:
            x: Input tensor; its last dimension must be compatible with
                ``LayerNorm(config.d_model)``.

        Returns:
            The output of the final LayerNorm.
        """
        stages = self.encoder
        hidden = stages["drop"](x)
        for block in stages["layers"]:
            hidden = block(hidden)
        return stages["norm"](hidden)


class EncoderBlock(nn.Module):
    """Pre-norm residual block: a ``SelfAttention`` sub-layer followed by an ``MLP``.

    Each sub-layer receives a LayerNorm-ed copy of its input, and its output
    is added back onto the un-normalised input.

    Args:
        config: Settings object. ``d_model`` is read here for the two
            LayerNorms; the object is also passed to ``SelfAttention`` and
            ``MLP``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.d_model

        # Attention sub-layer and its pre-normalisation.
        self.norm1 = nn.LayerNorm(width)
        self.attention = SelfAttention(config)

        # MLP sub-layer and its pre-normalisation.
        self.norm2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run the attention residual step, then the MLP residual step.

        Args:
            x: Input tensor; its last dimension must be compatible with
                ``LayerNorm(config.d_model)``.

        Returns:
            ``h + mlp(norm2(h))`` where ``h = x + attention(norm1(x))``.
        """
        after_attention = x + self.attention(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))
