import torch

class LSTM(torch.nn.Module):
    """Single-layer LSTM encoder followed by a linear head on the final time step.

    Args:
        emb_dim: Size of each input feature vector (``input_size`` of the LSTM).
        out_size: Number of output features produced by the linear head.
        hidden: Hidden-state size of the LSTM.
    """

    def __init__(self, emb_dim, out_size, hidden):
        super().__init__()
        self.lstm = torch.nn.LSTM(emb_dim, hidden, batch_first=True)
        self.out = torch.nn.Linear(hidden, out_size)

    def forward(self, inputs):
        """Run the LSTM over ``inputs`` and project the last time step's output.

        Args:
            inputs: Batch-first tensor of shape ``(batch, seq_len, emb_dim)``,
                or an unbatched ``(seq_len, emb_dim)`` tensor.

        Returns:
            Tensor of shape ``(batch, out_size)``, or ``(out_size,)`` for
            unbatched input.
        """
        # The final (h_n, c_n) states are not needed; only per-step outputs are.
        feats, _ = self.lstm(inputs)
        # Index the time axis from the right so both batched (B, T, H) and
        # unbatched (T, H) outputs select the final step. ``feats[:, -1]``
        # would instead take the last hidden unit of every step for 2-D input.
        last_token_feat = feats[..., -1, :]
        # NOTE(review): with right-padded variable-length batches, step -1 may
        # be padding rather than each sequence's true last token. Confirm that
        # callers pass equal-length (unpadded) sequences.
        return self.out(last_token_feat)