import torch

class LSTM(torch.nn.Module):
    """Sequence classifier: embedding -> (bi)LSTM -> mean over time -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        layers: Number of stacked LSTM layers.
        embedding_size: Width of each embedding vector (LSTM input size).
        output_size: Number of output features produced by the final linear
            layer.
        hidden_size: Hidden state width per LSTM direction.
        dropout: Dropout probability handed to ``torch.nn.LSTM``. It only
            acts between stacked layers, so it is ignored when ``layers``
            is 1.
        bidirectional: Whether the LSTM runs in both directions; doubles the
            feature width fed to the final linear layer.
        device: Device on which all parameters are allocated.
    """

    def __init__(
        self,
        vocab_size     : int   = 111      ,
        layers         : int   = 1        ,
        embedding_size : int   = 128      ,
        output_size    : int   = 8        ,
        hidden_size    : int   = 512      ,
        dropout        : float = 0.1      ,
        bidirectional  : bool  = True     ,
        device         : str   = "cuda:0"
    ):
        super().__init__()
        self.embeddings = torch.nn.Embedding(vocab_size, embedding_size, device=device)
        # torch.nn.LSTM applies dropout only between stacked layers and warns
        # when a non-zero value is combined with a single layer. Passing 0.0
        # in that case keeps the computation identical and avoids the warning.
        inter_layer_dropout = dropout if layers > 1 else 0.0
        self.encoder    = torch.nn.LSTM(
            input_size    = embedding_size      ,
            hidden_size   = hidden_size         ,
            num_layers    = layers              ,
            bias          = True                ,
            batch_first   = True                ,
            dropout       = inter_layer_dropout ,
            bidirectional = bidirectional       ,
            device        = device              ,
        )
        # A bidirectional encoder concatenates both directions' features.
        encoder_width = hidden_size * (2 if bidirectional else 1)
        self.last = torch.nn.Linear(encoder_width, output_size, device=device)

    def forward(self, src, tgt=None):
        """Compute logits for a batch of token-id sequences.

        Args:
            src: Integer tensor of token ids. Presumably shaped
                ``(batch, 1, seq_len)`` or ``(batch, seq_len)`` -- TODO
                confirm against the data loader.
            tgt: Unused by this model; accepted (and optional) so the call
                signature stays compatible with existing callers.

        Returns:
            A dict with the single key ``"logits"`` holding the output of the
            final linear layer applied to the time-averaged encoder states.
        """
        embedded = self.embeddings(src)
        # Drop the singleton axis only when it is an extra (4-D) axis. An
        # unconditional squeeze(1) would also remove the time axis of a 2-D
        # (batch, 1) input and make the LSTM see an unbatched sequence.
        if embedded.dim() == 4:
            embedded = embedded.squeeze(1)
        encoded, _ = self.encoder(embedded)
        # NOTE(review): the mean runs over every time step, so any padding
        # positions would be averaged in as well -- confirm this is intended.
        logits = self.last(encoded.mean(dim=1))
        return {"logits" : logits}


