import torch.nn as nn


class LSTM(nn.Module):
    """Embedding -> stacked unidirectional LSTM -> per-step linear projection.

    Args:
        num_blocks: Number of stacked LSTM layers (``num_layers`` of ``nn.LSTM``).
        data_dim: Number of rows in the embedding table, i.e. inputs must be
            integer indices in ``[0, data_dim)``.
        model_dim: Width of the embedding vectors and of the LSTM hidden state.
        label_dim: Number of output features produced by the final linear layer.
        dropout_rate: Dropout probability applied between stacked LSTM layers.
            It has no effect when ``num_blocks`` is 1, because ``nn.LSTM`` never
            applies dropout after the last layer.
    """

    def __init__(
        self,
        num_blocks: int,
        data_dim: int,
        model_dim: int,
        label_dim: int,
        dropout_rate: float,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(data_dim, model_dim)

        # nn.LSTM only applies dropout *between* layers; with a single layer a
        # non-zero value does nothing except trigger a UserWarning on every
        # construction. Pass 0.0 in that case -- the computation is identical.
        inter_layer_dropout = dropout_rate if num_blocks > 1 else 0.0

        self.lstm = nn.LSTM(
            model_dim,
            model_dim,
            num_layers=num_blocks,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=False,
        )
        self.linear = nn.Linear(model_dim, label_dim)

    def mask_grads(self):
        """Do nothing.

        NOTE(review): presumably kept so this model exposes the same hook as
        other models that need gradient masking -- confirm against callers.
        """
        pass

    def forward(self, x):
        """Return the linear projection of the LSTM output at every time step.

        Args:
            x: Integer index tensor fed to the embedding layer. The LSTM is
                built with ``batch_first=True``, so a batched input is
                expected as ``(batch, seq)`` -- TODO confirm against callers.

        Returns:
            Tensor whose last dimension is ``label_dim``; leading dimensions
            follow ``x``. The LSTM's final hidden/cell states are discarded.
        """
        x = self.embedding(x)
        x, _ = self.lstm(x)
        return self.linear(x)
