import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence
class LSTMClassifier(nn.Module):
    """Per-time-step sequence classifier: an LSTM followed by a linear head.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of the LSTM.
        num_classes: Number of output logits produced for each time step.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            PyTorch only inserts dropout *between* layers, so the value is
            ignored (forwarded as 0) when ``num_layers == 1``.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=1, dropout=0.5):
        super().__init__()

        # nn.LSTM warns (and does nothing) when dropout > 0 with a single
        # layer; only forward the dropout when it can actually take effect.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for every time step.

        Args:
            x: Tensor of shape (batch_size, sequence_length, input_dim)
                (batch-first, matching ``batch_first=True`` on the LSTM).

        Returns:
            Tensor of shape (batch_size, sequence_length, num_classes)
            containing raw logits (no softmax is applied).
        """
        # lstm_out: (batch_size, sequence_length, hidden_dim); the final
        # (h_n, c_n) state is not needed for per-step classification.
        lstm_out, _ = self.lstm(x)
        return self.classifier(lstm_out)


class LSTMForecaster(nn.Module):
    """LSTM with a linear head that maps each time step back to ``input_dim`` features.

    Args:
        input_dim: Number of features per time step (also the output size).
        hidden_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            PyTorch only inserts dropout *between* layers, so the value is
            ignored (forwarded as 0) when ``num_layers == 1``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout=0.5):
        super().__init__()

        # nn.LSTM warns (and does nothing) when dropout > 0 with a single
        # layer; only forward the dropout when it can actually take effect.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.forecaster = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Project every LSTM output step back into the input feature space.

        Args:
            x: Padded, batch-first tensor of shape
                (batch_size, sequence_length, input_dim).

        Returns:
            A tuple ``(output, hidden)``:
            - output: (batch_size, sequence_length, input_dim) per-step
              predictions. Whether these are next-step forecasts depends on
              how the caller aligns targets (NOTE(review): confirm with caller).
            - hidden: final hidden state ``h_n`` of shape
              (num_layers, batch_size, hidden_dim). ``batch_first`` does not
              apply to ``h_n``.
        """
        lstm_out, (hidden, _) = self.lstm(x)
        output = self.forecaster(lstm_out)
        return output, hidden