import torch
import torch.nn as nn

class LSTM(nn.Module):
    """Unidirectional multi-layer LSTM followed by a linear head.

    The head is applied to the top layer's LSTM output at the final time step.

    Args:
        num_outputs: Number of features produced by the linear head.
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        device: Stored on the instance only; this class never uses it.
    """

    def __init__(self, num_outputs, input_size, hidden_size, num_layers, device='cuda'):
        super().__init__()
        # Plain configuration attributes, kept for callers / introspection.
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_outputs = num_outputs
        self.device = device
        # Order matters: building the LSTM before the head fixes the RNG
        # consumption during parameter init and the state_dict key order.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_outputs)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the last time step.

        Args:
            x: Tensor laid out as (batch, seq, input_size) because the LSTM
                uses ``batch_first=True``. A 2-D tensor is treated as
                (batch, seq) and given a trailing feature axis of size 1,
                which only matches a model built with ``input_size == 1``.

        Returns:
            The head applied to ``output[:, -1, :]`` of the top LSTM layer.
        """
        features = x.unsqueeze(-1) if x.dim() == 2 else x
        sequence_out = self.lstm(features)[0]
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)

class Bidirectional_LSTM(nn.Module):
    """Bidirectional multi-layer LSTM followed by a linear head.

    The head sees the concatenation of the top layer's final forward hidden
    state and final backward hidden state, so both directions summarize the
    whole input sequence.

    Args:
        num_outputs: Number of features produced by the linear head.
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        device: Stored on the instance only; this class never uses it.
    """

    def __init__(self, num_outputs, input_size, hidden_size, num_layers, device='cuda'):
        super(Bidirectional_LSTM, self).__init__()
        self.device = device
        self.num_outputs = num_outputs
        self.num_layers = num_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, bidirectional=True)
        # Forward and backward hidden states are concatenated -> 2 * hidden_size.
        self.fc = nn.Linear(2*hidden_size, num_outputs)

    def forward(self, x):
        """Run the bidirectional LSTM over ``x`` and project its final states.

        Args:
            x: Tensor laid out as (batch, seq, input_size) because the LSTM
                uses ``batch_first=True``. A 2-D tensor is treated as
                (batch, seq) and given a trailing feature axis of size 1,
                which only matches a model built with ``input_size == 1``.

        Returns:
            Tensor of shape (batch, num_outputs).
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, batch, hidden_size), ordered per layer as
        # [forward, backward]; the last two rows belong to the top layer.
        # Do not use output[:, -1, :] here: its backward half is the backward
        # direction's *first* step, which has only seen the final input.
        last_forward, last_backward = h_n[-2], h_n[-1]
        return self.fc(torch.cat((last_forward, last_backward), dim=-1))
