import torch as th


class LSTM(th.nn.Module):
    """
    The LSTM network module, consisting of a single LSTM cell followed by a
    linear output layer.

    Each call to ``forward`` advances the recurrence by exactly one step and
    returns both the projected output and the updated (h, c) state so the
    caller can feed it back in on the next step.
    """
    def __init__(self, _input_size, _hidden_size, _output_size,
                 _hidden_bias=False, _output_bias=False):
        """
        Initialization of the LSTM network

        :param _input_size: Number of features in each input vector.
        :param _hidden_size: Number of features in the hidden (h) and cell (c)
            states.
        :param _output_size: Number of features produced by the output layer.
        :param _hidden_bias: Whether the LSTM cell uses bias terms
            (disabled by default).
        :param _output_bias: Whether the output layer uses a bias term
            (disabled by default).
        """
        super().__init__()

        # Attribute names are kept stable: they determine the state_dict keys
        # used when saving/loading checkpoints.
        self.lstm = th.nn.LSTMCell(_input_size, _hidden_size, bias=_hidden_bias)
        self.output_layer = th.nn.Linear(_hidden_size, _output_size,
                                         bias=_output_bias)

    def forward(self, _net_input, _lstm_state=None):
        """
        Forward pass of the LSTM network module for a single time step.

        :param _net_input: Input tensor in a layout accepted by
            ``torch.nn.LSTMCell`` (e.g. ``(batch, input_size)``).
        :param _lstm_state: Two-element sequence ``(lstm_h, lstm_c)`` holding
            the previous hidden and cell states, or ``None`` to start from
            all-zero states (``torch.nn.LSTMCell``'s default).
        :return: Tuple ``(output, [lstm_h, lstm_c])``. ``output`` is the
            output layer applied to the new hidden state, and the list holds
            the new hidden and cell states for the next step.
        """
        if _lstm_state is None:
            # Let LSTMCell initialise h and c to zeros itself.
            _hx = None
        else:
            # Unpacking validates that exactly two states were supplied and
            # normalises lists (e.g. the state returned by a previous call)
            # into the tuple LSTMCell expects.
            _lstm_h, _lstm_c = _lstm_state
            _hx = (_lstm_h, _lstm_c)

        # Advance the recurrence by one step
        _lstm_h, _lstm_c = self.lstm(_net_input, _hx)

        # Project the new hidden state through the linear output layer
        _output = self.output_layer(_lstm_h)

        return _output, [_lstm_h, _lstm_c]