import torch
import torch.nn as nn

class TLSTMModel(nn.Module):
    """Unidirectional LSTM encoder followed by a linear head on the final state.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Size of the linear head's output.
        num_layers: Number of stacked LSTM layers (default 1, as before).
        dropout: Dropout probability between stacked LSTM layers. ``nn.LSTM``
            only applies dropout between layers, so the value is ignored
            (passed as 0) when ``num_layers == 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *, num_layers=1, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Non-zero dropout with a single layer does nothing and makes
            # PyTorch emit a UserWarning, so only forward it when stacking.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: Input tensor laid out batch-first, i.e. ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # Not passing (h0, c0) makes nn.LSTM start from zeros with the same
        # dtype and device as ``x``; hand-built float32 zeros broke models
        # converted to double/half precision.
        _, (h_n, _) = self.lstm(x)
        # h_n[-1] is the top layer's final hidden state, identical to
        # out[:, -1, :] for a unidirectional LSTM.
        return self.fc(h_n[-1])

class BidirectionalTLSTMModel(nn.Module):
    """Bidirectional LSTM encoder followed by a linear head.

    The head sees the concatenation of the forward direction's final state
    (after reading the sequence left-to-right) and the backward direction's
    final state (after reading it right-to-left).

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state per direction.
        output_dim: Size of the linear head's output.
        num_layers: Number of stacked LSTM layers (default 1, as before).
        dropout: Dropout probability between stacked LSTM layers. ``nn.LSTM``
            only applies dropout between layers, so the value is ignored
            (passed as 0) when ``num_layers == 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *, num_layers=1, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Non-zero dropout with a single layer does nothing and makes
            # PyTorch emit a UserWarning, so only forward it when stacking.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Forward and backward final states are concatenated -> 2 * hidden_dim.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: Input tensor laid out batch-first, i.e. ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # Not passing (h0, c0) makes nn.LSTM start from zeros with the same
        # dtype and device as ``x``; hand-built float32 zeros broke models
        # converted to double/half precision.
        _, (h_n, _) = self.lstm(x)
        # h_n is ordered layer-major with the forward direction first, so the
        # top layer's forward/backward final states are h_n[-2] / h_n[-1].
        # out[:, -1, hidden_dim:] would be wrong: at the last time step the
        # backward direction has only seen a single element.
        last = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        return self.fc(last)