from typing import Union, Tuple, Sequence, Optional, Callable, List

import torch
import torch.distributions as td

from models.common.rnn import RNN
from utils.torch_utils import transform_tensor


class S2DRNN(torch.nn.Module):
    """Sequence-to-distribution model: an RNN, a linear layer and a distribution head.

    The output of the linear layer is cut into chunks along its last dimension, each chunk
    is passed through its own transform, and the results are handed positionally to the
    distribution constructor.
    """

    def __init__(self, layer_type: str, input_dimension: int, hidden_dimensions: Union[Sequence[int], int],
                 latent_dim: int = 10, latent_split : Optional[Tuple[int, ...]] = (5, 5),
                 n_recurrent_layers: Optional[int] = None, recurrent_activation: str = 'tanh',
                 recurrent_bias: bool = True, bidirectional: bool = False, projection_size: int = 0,
                 dilation: Optional[Sequence[int]] = None, dropout: float = 0.0, mode: str = 's2s',
                 distribution : Callable[..., td.distribution.Distribution] = td.Normal,
                 parameter_transforms : Optional[Union[List[Callable], Callable]] = None,
                 input_shape : str = 'btf'):

        """An RNN followed by a fully connected neural network, that maps sequences to a distribution.

        :param layer_type: The type of recurrent layer (RNN, LSTM, GRU).
        :param input_dimension: Number of dimensions of the feature space of the input data.
        :param hidden_dimensions: Number of dimensions of each hidden state for all recurrent layers.
        :param latent_dim: Number of output features of the linear layer applied to the RNN output.
        :param latent_split: Sizes of the consecutive chunks the linear output is split into along its last
            dimension, one chunk per distribution parameter. Must sum to latent_dim. If None, the output is
            not split and is passed to the distribution as a single parameter.
        :param n_recurrent_layers: Number of recurrent layers, if hidden_dimensions is the same for all layers (int).
        :param recurrent_activation: Activation function to use in each recurrent layer (has to be defined in torch.nn).
        :param recurrent_bias: Whether to use bias in the computations of the recurrent layers.
        :param bidirectional: Use bidirectional recurrent layers.
        :param projection_size: Parameter relevant for LSTM layers only. See pytorch docs for details.
        :param dilation: Forwarded unchanged to the RNN.
        :param dropout: Forwarded unchanged to the RNN.
        :param mode: Type of model (s2s, s2as, s2fh, s2mh)
        :param distribution: Callable that builds the output distribution from the transformed parameters.
        :param parameter_transforms: Transform(s) applied to the chunks before they reach the distribution.
            A single callable is applied to every chunk; a list (or tuple) supplies one callable per chunk, in
            order. Chunks and transforms are paired with zip, so surplus entries on either side are ignored.
            If None, the chunks are passed through unchanged.
        :param input_shape: Axis order of the inputs given to forward, as understood by transform_tensor.
        :raises ValueError: If latent_split is given and does not sum to latent_dim.
        """  # TODO: docs

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # Done first so that no layers are built for an invalid configuration.
        if latent_split is not None and sum(latent_split) != latent_dim:
            raise ValueError(f"latent_split {tuple(latent_split)} sums to {sum(latent_split)}, "
                             f"but latent_dim is {latent_dim}.")

        super(S2DRNN, self).__init__()

        self.rnn = RNN(layer_type=layer_type, model=mode, input_dimension=input_dimension,
                       hidden_dimensions=hidden_dimensions, n_recurrent_layers=n_recurrent_layers,
                       recurrent_activation=recurrent_activation, recurrent_bias=recurrent_bias,
                       bidirectional=bidirectional, projection_size=projection_size, dilation=dilation, dropout=dropout)

        # The linear layer consumes the last recurrent layer's output, whose width is doubled
        # when the layer is bidirectional.
        last_hidden_size = hidden_dimensions if isinstance(hidden_dimensions, int) else hidden_dimensions[-1]
        out_hidden_size = 2 * last_hidden_size if bidirectional else last_hidden_size

        self.linear = torch.nn.Linear(out_hidden_size, latent_dim)  # TODO: make parameter

        if latent_split is None:
            # No split: the whole linear output is the single distribution parameter.
            self.latent_split = []
            n_parameters = 1
        else:
            # tensor_split expects the indices at which to cut, i.e. the running sums of all
            # chunk sizes except the last one.
            split_indices = []
            running_total = 0
            for chunk_size in latent_split[:-1]:
                running_total += chunk_size
                split_indices.append(running_total)

            self.latent_split = split_indices
            n_parameters = len(latent_split)

        # One transform per distribution parameter. The count comes from n_parameters rather than
        # from iterating latent_split, which may be None.
        if parameter_transforms is None:
            self.act = [torch.nn.Identity() for _ in range(n_parameters)]
        elif isinstance(parameter_transforms, (list, tuple)):
            self.act = list(parameter_transforms)
        else:
            self.act = [parameter_transforms for _ in range(n_parameters)]

        self.distribution = distribution

        self.input_shape = input_shape

    def distribution_parameters(self, network_output : torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Converts the network output into parameters for the output distribution.

        The output is split along its last dimension at the indices stored in self.latent_split and each
        chunk is passed through the transform at the same position in self.act.

        :param network_output: The output from the RNN after the linear layer has been applied.
        :return: Tuple of parameters passed to the distribution.
        """

        if self.latent_split:
            split_output = network_output.tensor_split(self.latent_split, dim=-1)
        else:
            # No split indices: keep the output as one chunk.
            split_output = (network_output,)

        return tuple(act(parameter) for act, parameter in zip(self.act, split_output))

    def forward(self, inputs: torch.Tensor, transform_output : bool = True, **kwargs) -> td.distribution.Distribution:
        """Encodes the input sequence as the parameters of the configured output distribution.

        Kwargs are passed to the distribution.

        :param inputs: Tensor of inputs with dimensions batch_size (b), time (t), and features (f), ordered as
            given by input_shape.
        :param transform_output: Whether to transform the output into the original shape when using a sequence-to-sequence model. Otherwise return a distribution with tensors of shape tbf.
        :return: Distribution
        """

        # The RNN is fed time-major ('tbf') data regardless of the configured input layout.
        permuted_inputs = transform_tensor(inputs, self.input_shape, 'tbf')

        network_output = self.linear(self.rnn(permuted_inputs))

        # NOTE(review): assumes the RNN exposes the mode it was built with as `.model` - confirm.
        if transform_output and self.rnn.model in ('s2s', 's2as'):
            network_output = transform_tensor(network_output, 'tbf', self.input_shape)

        return self.distribution(*self.distribution_parameters(network_output), **kwargs)


def make_positive(parameter : torch.Tensor, beta : float = 1, threshold : float = 20) -> torch.Tensor:
    """Maps a tensor to strictly positive values by applying softplus element-wise.

    :param parameter: Tensor to transform.
    :param beta: Passed to softplus as its beta argument.
    :param threshold: Passed to softplus as its threshold argument.
    :return: Result of softplus applied to parameter.
    """
    softplus = torch.nn.functional.softplus
    positive_parameter = softplus(parameter, beta=beta, threshold=threshold)
    return positive_parameter
