import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMU(nn.Module):
    """Single-layer LSTM followed by a two-layer MLP "error" head.

    ``forward`` runs the LSTM from a zero initial state. It returns the head's
    prediction for every time step and the LSTM output at the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Width of the LSTM hidden/cell state.
        state_size: Output width of ``linearStates``.
        seed: Value passed to ``torch.manual_seed`` before any layer is
            created, which makes weight initialisation reproducible. This
            reseeds PyTorch's *global* RNG. Pass ``None`` to leave the global
            RNG untouched.
    """

    def __init__(self, input_size=1, hidden_layer_size=200, state_size=1, seed=0):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)

        # NOTE(review): the "Train" head and ``linearStates`` are not used by
        # ``forward``. They are kept, in their original creation order, so
        # that seeded initialisation and ``state_dict`` keys stay unchanged.
        self.linearErrorTrain1 = nn.Linear(hidden_layer_size, 20)
        self.linearErrorTrain2 = nn.Linear(20, 11)

        self.linearErrorTest1 = nn.Linear(hidden_layer_size, 20)
        self.linearErrorTest2 = nn.Linear(20, 11)

        self.linearStates = nn.Linear(hidden_layer_size, state_size)

    def forward(self, input_seq):
        """Encode ``input_seq`` with the LSTM and apply the error head.

        Args:
            input_seq: Tensor of shape ``(batch, seq_len, input_size)``. The
                LSTM is built with ``batch_first=True``.

        Returns:
            Tuple ``(error_pred, last_output)``:
            ``error_pred`` has shape ``(batch, seq_len, 11)``;
            ``last_output`` is the LSTM output at the final time step, with
            shape ``(batch, hidden_layer_size)``.

        Side effects:
            Stores the final ``(h_n, c_n)`` pair in ``self.hidden_cell``.
        """
        batch_size = input_seq.shape[0]
        state_shape = (1, batch_size, self.hidden_layer_size)
        # Create the zero initial state with the input's device and dtype.
        # The previous ``torch.cuda.is_available()`` check put the state on
        # the GPU even when the model and input were on the CPU, and it always
        # used float32. Both cases caused mismatches inside the LSTM.
        self.hidden_cell = (
            input_seq.new_zeros(state_shape),
            input_seq.new_zeros(state_shape),
        )
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)

        hidden = F.relu(self.linearErrorTest1(lstm_out))
        error_pred = self.linearErrorTest2(hidden)
        return error_pred, lstm_out[:, -1, :]