import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMDyn(nn.Module):
    """Single-layer LSTM followed by three linear read-out heads.

    The LSTM output at the last time step feeds ``linearA`` and ``linearB``;
    the ReLU-activated LSTM output at every time step feeds ``linearStates``.
    """

    def __init__(self, input_size=2, hidden_layer_size=200, state_size=1,
                 a_size=121, b_size=44):
        """Build the LSTM and its linear heads.

        Args:
            input_size: Number of features per time step of the input.
            hidden_layer_size: Width of the LSTM hidden state.
            state_size: Output width of ``linearStates``.
            a_size: Output width of ``linearA``.
            b_size: Output width of ``linearB``.
        """
        super().__init__()
        # NOTE: this reseeds the *global* torch RNG so that parameter
        # initialisation is reproducible; it also affects any random draws
        # the caller makes afterwards.
        torch.manual_seed(0)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)

        self.linearA = nn.Linear(hidden_layer_size, a_size)
        self.linearB = nn.Linear(hidden_layer_size, b_size)

        self.linearStates = nn.Linear(hidden_layer_size, state_size)

    def forward(self, input_seq: torch.Tensor):
        """Run the LSTM from a zero initial state and apply the heads.

        Args:
            input_seq: Batched input indexed as ``(batch, seq, feature)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            A 4-tuple ``(predictionA, predictionB, predictionStates, last)``:
            ``linearA`` and ``linearB`` applied to the last-step LSTM output,
            ``linearStates`` applied to the ReLU of the LSTM output at every
            step, and the raw last-step LSTM output itself.
        """
        # Omitting the initial state makes nn.LSTM start from zeros allocated
        # on the same device and with the same dtype as the input. Choosing
        # the device via torch.cuda.is_available() instead breaks whenever
        # the model is kept on the CPU of a CUDA-capable machine.
        lstm_out, self.hidden_cell = self.lstm(input_seq)

        last_step = lstm_out[:, -1, :]
        predictionA = self.linearA(last_step)
        predictionB = self.linearB(last_step)
        # Only this head sees a ReLU-activated view of the LSTM output.
        predictionStates = self.linearStates(F.relu(lstm_out))

        return predictionA, predictionB, predictionStates, last_step