import torch

class LSTM(torch.nn.Module):
    """Embed a symbol and a context sequence into vectors of the same width.

    The symbol is looked up in its own embedding table.  The context is
    embedded with a second table, run through a batch-first LSTM, averaged
    over the sequence axis and projected by a linear layer down to
    ``symbol_embedding_size``, so both outputs share their last dimension.

    Args:
        vocab_size: Number of rows in both embedding tables.
        layers: Number of stacked LSTM layers.
        symbol_embedding_size: Width of the symbol embedding and of the
            projected context embedding.
        context_embedding_size: Width of the context token embeddings fed
            to the LSTM.
        hidden_size: LSTM hidden state size per direction.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability between stacked LSTM layers.  It is
            only forwarded to the LSTM when ``layers > 1``.
        device: Device on which all parameters are created.

    Raises:
        ValueError: If ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        vocab_size             :int   = 111      ,
        layers                 :int   = 1        ,
        symbol_embedding_size  :int   = 128      ,
        context_embedding_size :int   = 128      ,
        hidden_size            :int   = 512      ,
        bidirectional          :bool  = True     ,
        dropout                :float = 0.1      ,
        device                 :str   = "cuda:0"
    ):
        super().__init__()
        # Validate here so a bad value is rejected regardless of layer count
        # (the single-layer path below no longer forwards it to torch).
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(
                f"dropout should be a number in range [0, 1], got {dropout!r}"
            )

        self.symbol_embeddings = torch.nn.Embedding(
            vocab_size, symbol_embedding_size, device=device
        )
        self.context_embeddings = torch.nn.Embedding(
            vocab_size, context_embedding_size, device=device
        )

        # torch.nn.LSTM applies dropout only *between* stacked layers; with a
        # single layer a non-zero value has no effect and merely triggers a
        # UserWarning, so pass 0.0 in that case.
        inter_layer_dropout = dropout if layers > 1 else 0.0
        self.encoder = torch.nn.LSTM(
            input_size=context_embedding_size,
            hidden_size=hidden_size,
            num_layers=layers,
            bias=True,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            device=device,
        )

        # A bidirectional LSTM concatenates both directions' hidden states.
        directions = 2 if bidirectional else 1
        self.last = torch.nn.Linear(
            hidden_size * directions, symbol_embedding_size, device=device
        )

    def forward(self, symbol: torch.Tensor, context: torch.Tensor) -> dict:
        """Compute the symbol embedding and the pooled context embedding.

        Args:
            symbol: Integer index tensor.  A singleton axis 1 (e.g. shape
                ``(batch, 1)``) is squeezed out of the resulting embedding.
            context: Integer index tensor; the encoder is batch-first, so
                this is expected to be ``(batch, seq_len)``.

        Returns:
            A dict with keys ``"symbol_emb"`` and ``"context_emb"``; both
            tensors have ``symbol_embedding_size`` as their last dimension.
        """
        symbol_emb = self.symbol_embeddings(symbol)
        # Drop axis 1 only when it is a singleton axis of the *index* tensor.
        # An unconditional squeeze(1) would strip the embedding axis itself
        # for 1-D input when symbol_embedding_size == 1.
        if symbol.dim() > 1 and symbol.size(1) == 1:
            symbol_emb = symbol_emb.squeeze(1)

        hidden_states, _ = self.encoder(self.context_embeddings(context))
        # NOTE: the mean runs over every time step, so any padded positions
        # supplied by the caller contribute to the average.
        pooled = hidden_states.mean(1)
        context_emb = self.last(pooled)

        return {
             "symbol_emb" :  symbol_emb ,
            "context_emb" : context_emb ,
        }


