import torch

class Transformer(torch.nn.Module):
    def __init__(
        self,
        vocab_size             :int          = 111,
        layers                 :int          = 1,
        symbol_embedding_size  :int          = 128,
        context_embedding_size :int          = 128,
        feedforward_size       :int          = 512,
        heads                  :int          = 2,
        activation             :str          = "gelu",
        dropout                :float        = 0.1,
        init                   :float | None = None,
        device                 :str          = "cuda:0"
    ):
        super().__init__()
        self.symbol_embeddings  = torch.nn.Embedding(vocab_size,  symbol_embedding_size, device=device)
        self.context_embeddings = torch.nn.Embedding(vocab_size, context_embedding_size, device=device)
        if init is not None: torch.nn.init.normal_(self.symbol_embeddings.weight, 0, init)
        if init is not None: torch.nn.init.normal_(self.context_embeddings.weight, 0, init)
        self.encoder      = torch.nn.TransformerEncoder(
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model         = context_embedding_size,
                nhead           = heads            ,
                dim_feedforward = feedforward_size ,
                dropout         = dropout          ,
                activation      = activation       ,
                batch_first     = True             ,
                device          = device
            ),
            num_layers = layers,
            enable_nested_tensor = False,
        )
        self.last = torch.nn.Linear(context_embedding_size, symbol_embedding_size, device=device)

    def forward(self, symbol, context):
        symbol_emb = self.symbol_embeddings(symbol).squeeze(1)
        context_emb = self.last(self.encoder(self.context_embeddings(context)).mean(1))
        return {
             "symbol_emb" :  symbol_emb , 
            "context_emb" : context_emb ,
        }


