from typing import Union, Tuple, Sequence, Optional

import torch

from models.common.rnn import RNN


class S2GaussianRNN(torch.nn.Module):
    """Recurrent encoder whose output is projected to the two parameter sets of a gaussian."""

    def __init__(
            self,
            layer_type: str,
            input_dimension: int,
            hidden_dimensions: Union[Sequence[int], int],
            latent_dim: int = 10,
            n_recurrent_layers: Optional[int] = None,
            recurrent_activation: str = 'tanh',
            recurrent_bias: bool = True,
            bidirectional: bool = False,
            projection_size: int = 0,
            dilation: Optional[Sequence[int]] = None,
            dropout: float = 0.0,
            mode: str = 's2s',
            logvar_out: bool = True,
    ):
        """Build an RNN followed by a single linear layer that emits gaussian parameters.

        The linear layer produces ``2 * latent_dim`` features, which ``forward`` splits into
        a mean half and a spread half (log-variance or standard deviation, see ``logvar_out``).

        :param layer_type: The type of recurrent layer (RNN, LSTM, GRU).
        :type layer_type: str
        :param input_dimension: Number of dimensions of the feature space of the input data.
        :type input_dimension: int
        :param hidden_dimensions: Hidden state size of the recurrent layers; either one int shared
            by all layers or one entry per layer. The last entry determines the input width of the
            linear layer.
        :type hidden_dimensions: Union[Sequence[int], int]
        :param latent_dim: Number of dimensions of the gaussian; the linear layer outputs twice
            this many features.
        :type latent_dim: int
        :param n_recurrent_layers: Number of recurrent layers, if hidden_dimensions is the same for all layers (int).
        :type n_recurrent_layers: Optional[int]
        :param recurrent_activation: Activation function to use in each recurrent layer (has to be defined in torch.nn).
        :type recurrent_activation: str
        :param recurrent_bias: Whether to use bias in the computations of the recurrent layers.
        :type recurrent_bias: bool
        :param bidirectional: Use bidirectional recurrent layers. Doubles the input width of the
            linear layer.
        :type bidirectional: bool
        :param projection_size: Parameter relevant for LSTM layers only. See pytorch docs for details.
        :type projection_size: int
        :param dilation: Passed through unchanged to the underlying RNN (presumably one dilation
            factor per recurrent layer -- confirm against models.common.rnn.RNN).
        :type dilation: Optional[Sequence[int]]
        :param dropout: Passed through unchanged to the underlying RNN (presumably a dropout
            probability -- confirm against models.common.rnn.RNN).
        :type dropout: float
        :param mode: Type of model (s2s, s2as, s2fh, s2mh)
        :type mode: str
        :param logvar_out: If True, the second output of ``forward`` is the raw linear output
            (to be read as a log-variance). If False, a softplus is applied to it so that it can
            serve as a standard deviation.
        :type logvar_out: bool
        """

        super().__init__()

        # Read by forward() to decide whether the softplus is applied.
        self.logvar = logvar_out

        # Sub-modules are created in a fixed order (rnn, linear, softplus) so that parameter
        # initialisation and state_dict layout stay stable.
        self.rnn = RNN(
            layer_type=layer_type,
            model=mode,
            input_dimension=input_dimension,
            hidden_dimensions=hidden_dimensions,
            n_recurrent_layers=n_recurrent_layers,
            recurrent_activation=recurrent_activation,
            recurrent_bias=recurrent_bias,
            bidirectional=bidirectional,
            projection_size=projection_size,
            dilation=dilation,
            dropout=dropout,
        )

        # Width of the features handed to the linear layer: the hidden size of the last
        # recurrent layer, counted once per direction.
        # NOTE(review): projection_size is not taken into account here; if RNN forwards it to an
        # LSTM as proj_size, the RNN output width may differ from this value -- confirm.
        if isinstance(hidden_dimensions, int):
            last_hidden = hidden_dimensions
        else:
            last_hidden = hidden_dimensions[-1]
        direction_factor = 2 if bidirectional else 1
        out_hidden_size = direction_factor * last_hidden

        # One half of the outputs is the mean, the other half the spread parameter.
        self.linear = torch.nn.Linear(out_hidden_size, 2 * latent_dim)  # TODO: make parameter
        self.softplus = torch.nn.Softplus()

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode an input sequence as the parameters of a gaussian distribution.

        :param inputs: Tensor of inputs of shape (time, batch_size, feature_dimension)
        :type inputs: torch.Tensor
        :return: The mean and, depending on ``logvar_out``, either the raw second half of the
            linear output (log-variance) or its softplus (standard deviation).
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """

        gaussian_params = self.linear(self.rnn(inputs))

        # Cut the last dimension into two equally sized halves.
        mean, spread = torch.tensor_split(gaussian_params, 2, dim=-1)

        if self.logvar:
            # Log-variance is unconstrained, so it is returned as is.
            return mean, spread

        # Softplus keeps the standard deviation positive.
        return mean, self.softplus(spread)
