from torch import nn
from Models.attention.blocks.encoder_layer import EncoderLayer, EncoderLayer_wo_v
from Models.attention.embedding.transformer_embedding import TransformerEmbedding


class Encoder(nn.Module):
    """Sequential stack of ``EncoderLayer`` blocks.

    Args:
        enc_voc_size: vocabulary size handed to ``TransformerEmbedding``.
        max_len: maximum sequence length handed to ``TransformerEmbedding``.
        d_model: model width shared by the embedding and every layer.
        ffn_hidden: feed-forward hidden width for each ``EncoderLayer``.
        n_head: number of attention heads for each ``EncoderLayer``.
        n_layers: how many ``EncoderLayer`` instances are stacked.
        drop_prob: dropout probability for the embedding and every layer.
        device: device handed to ``TransformerEmbedding`` only.

    NOTE(review): ``self.emb`` is constructed, so its parameters appear in
    ``parameters()`` and ``state_dict()``. However, ``forward`` never calls
    it, so the input is fed straight into the first layer. Presumably the
    callers already pass ``d_model``-wide features; confirm. The unused
    embedding parameters never receive gradients. ``DistributedDataParallel``
    may complain about this unless ``find_unused_parameters=True``.
    """

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super().__init__()
        # Built before the layers to keep the original parameter-registration
        # and random-initialisation order.
        self.emb = TransformerEmbedding(
            d_model=d_model,
            max_len=max_len,
            vocab_size=enc_voc_size,
            drop_prob=drop_prob,
            device=device,
        )

        def _make_layer():
            # One fresh, independently initialised encoder layer.
            return EncoderLayer(
                d_model=d_model,
                ffn_hidden=ffn_hidden,
                n_head=n_head,
                drop_prob=drop_prob,
            )

        self.layers = nn.ModuleList([_make_layer() for _layer_idx in range(n_layers)])

    def forward(self, x):
        """Feed ``x`` through every encoder layer in order and return the result.

        ``self.emb`` is not applied here; the input goes directly to the
        first layer.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden


class Brain(nn.Module):
    """Thin wrapper around a single ``EncoderLayer_wo_v`` block.

    Args:
        d_model: model width forwarded to ``EncoderLayer_wo_v``.
        ffn_hidden: feed-forward hidden width forwarded to ``EncoderLayer_wo_v``.
        n_head: number of attention heads forwarded to ``EncoderLayer_wo_v``.
        drop_prob: dropout probability forwarded to ``EncoderLayer_wo_v``.
        device: accepted but currently unused by this class.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob, device):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            ffn_hidden=ffn_hidden,
            n_head=n_head,
            drop_prob=drop_prob,
        )
        self.layer = EncoderLayer_wo_v(**layer_kwargs)

    def forward(self, x):
        """Return the wrapped layer's output for ``x``."""
        return self.layer(x)