from typing import Union, Tuple, Sequence

import torch

from models.common.rnn import RNN


class S2GaussianRNN(torch.nn.Module):
    """Recurrent encoder whose output is linearly projected onto the two parameters of a gaussian."""

    def __init__(self, layer_type: str, input_dimension: int, hidden_dimensions: Union[Sequence[int], int],
                 latent_dim: int = 10, bidirectional: bool = False, mode: str = 's2s', logvar_out: bool = True):
        """Build the recurrent network and the linear head that emits the gaussian parameters.

        :param layer_type: Passed through to ``RNN`` as its first argument (presumably selects the recurrent
            cell type -- confirm against ``models.common.rnn``).
        :param input_dimension: Passed through to ``RNN`` as the input feature size.
        :param hidden_dimensions: Hidden size(s) handed to ``RNN``. Either a single int or a sequence of ints;
            for a sequence, the last entry determines the input width of the linear head.
        :param latent_dim: Size of each of the two outputs; the linear head emits ``2 * latent_dim`` features.
        :param bidirectional: Passed through to ``RNN``. When true, the linear head expects twice the last
            hidden size (presumably because both directions are concatenated -- confirm against ``RNN``).
        :param mode: Passed through to ``RNN`` as its second argument.
        :param logvar_out: If true, the second output of :meth:`forward` is returned unconstrained (to be read
            as a log-variance); otherwise it is passed through a softplus so it can be read as a standard
            deviation.
        """
        super().__init__()

        self.logvar = logvar_out
        self.rnn = RNN(layer_type, mode, input_dimension, hidden_dimensions, bidirectional=bidirectional)

        # Only the width of the final recurrent layer matters for sizing the linear head.
        last_width = hidden_dimensions if isinstance(hidden_dimensions, int) else hidden_dimensions[-1]
        directions = 2 if bidirectional else 1

        # One half of the projection becomes the mean, the other half the std / log-variance.
        self.linear = torch.nn.Linear(directions * last_width, 2 * latent_dim)  # TODO: make parameter
        self.softplus = torch.nn.Softplus()

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode an input sequence as the parameters of a gaussian distribution.

        :param inputs: Tensor of inputs of shape (time, batch_size, feature_dimension)
        :type inputs: torch.Tensor
        :return: The mean, followed by either the unconstrained second half of the projection (when the module
            was built with ``logvar_out=True``) or its softplus (when built with ``logvar_out=False``).
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        projected = self.linear(self.rnn(inputs))

        # Split the last dimension into two equal halves: (mean, std-or-logvar).
        mean, spread = torch.tensor_split(projected, 2, dim=-1)

        if self.logvar:
            return mean, spread

        # Softplus keeps the standard deviation strictly positive.
        return mean, self.softplus(spread)
