import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMModel(nn.Module):
    """Single-layer LSTM encoder with masked max pooling over the sequence axis.

    The LSTM runs over every position of the input sequence. Positions whose
    mask entry is zero are then excluded from the max pooling, and the pooled
    vector goes through dropout and a linear output layer.
    """

    def __init__(self, in_features, hidden_dim, out_features, dropout=0.5):
        """Build the model.

        Args:
            in_features: Size of each input feature vector.
            hidden_dim: Hidden size of the LSTM (and of the pooled vector).
            out_features: Size of the final output vector.
            dropout: Dropout probability applied to the pooled vector.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_features, hidden_size=hidden_dim, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, out_features)

    def forward(self, data):
        """Encode a batch of masked sequences.

        Args:
            data: Indexable container. ``data[0]`` holds the node features
                ``[batch_size, N, in_features]`` and ``data[2]`` holds the mask
                ``[batch_size, N]``. Nonzero mask entries mark valid positions.
                ``data[1]`` is not read here.

        Returns:
            Tensor of shape ``[batch_size, out_features]``. A sequence whose
            mask is all zeros pools to a zero vector before the linear layer.
        """
        x = data[0]  # node features [batch_size, N, in_features]
        mask = data[2]  # mask [batch_size, N]

        # NOTE(review): padded positions are still fed through the LSTM. For a
        # unidirectional LSTM the valid outputs are unaffected only if padding
        # is trailing. Confirm this against the data pipeline.
        lstm_out, _ = self.lstm(x)  # [batch_size, N, hidden_dim]

        weight = mask.unsqueeze(-1).to(lstm_out.dtype)  # [batch_size, N, 1]
        padded = weight == 0

        # Keep the original scaling by the mask, then push padded positions to
        # -inf so they can never win the max. Merely zeroing them let a padded
        # 0 beat genuinely negative LSTM activations.
        scores = (lstm_out * weight).masked_fill(padded, float("-inf"))
        pooled = scores.max(dim=1).values  # [batch_size, hidden_dim]

        # A sequence with no valid position would pool to -inf. Return zeros
        # for it instead, as the previous code did, so the linear layer (and
        # its gradients) stay finite.
        has_valid = (~padded).any(dim=1)  # [batch_size, 1]
        pooled = pooled.masked_fill(~has_valid, 0.0)

        x = self.dropout(pooled)
        return self.fc(x)

