import torch

class Transformer(torch.nn.Module):
    def __init__(
        self,
        vocab_size             :int          = 111,
        layers                 :int          = 1,
        embedding_size         :int          = 128,
        output_size            :int          = 8,
        feedforward_size       :int          = 512,
        heads                  :int          = 2,
        activation             :str          = "gelu",
        dropout                :float        = 0.1,
        init                   :float | None = None,
        device                 :str          = "cuda:0"
    ):
        super().__init__()
        self.embeddings = torch.nn.Embedding(vocab_size, embedding_size, device=device)
        if init is not None: torch.nn.init.normal_(self.symbol_embeddings.weight, 0, init)
        if init is not None: torch.nn.init.normal_(self.context_embeddings.weight, 0, init)
        self.encoder      = torch.nn.TransformerEncoder(
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model         = embedding_size   ,
                nhead           = heads            ,
                dim_feedforward = feedforward_size ,
                dropout         = dropout          ,
                activation      = activation       ,
                batch_first     = True             ,
                device          = device
            ),
            num_layers = layers,
            enable_nested_tensor = False,
        )
        self.last = torch.nn.Linear(embedding_size, output_size, device=device)

    def forward(self, src, tgt):
        embeddings = self.embeddings(src).squeeze(1)
        logits = self.last(self.encoder(embeddings).mean(1))
        return {"logits" : logits}


