from typing import List, Union, Optional, Sequence

import torch

from models.common.rnn import RNN
from utils.utils import str2cls
from utils.torch_utils import transform_tensor


class V2SRNN(torch.nn.Module):
    """Vector-to-sequence RNN.

    Decodes a whole sequence from a single vector used as the initial hidden state. The
    model is a recurrent core (``RNN`` in 's2s' mode) followed by a stack of linear layers
    that map every hidden state to ``output_dim`` features. At inference time each
    prediction is fed back as the next RNN input. This is why ``output_dim`` is handed to
    ``RNN`` (presumably as its input size; TODO confirm against ``RNN``'s signature).
    """

    def __init__(self, layer_type: str, output_dim: int,
                 hidden_dimensions: Union[Sequence[int], int], linear_layers: Optional[List[int]] = None,
                 n_recurrent_layers: Optional[int] = None, recurrent_activation: str = 'tanh',
                 recurrent_bias: bool = True, projection_size: int = 0, dilation: Optional[Sequence[int]] = None,
                 dropout: float = 0.0, linear_activation: str = 'torch.nn.functional.relu',
                 final_activation: Optional[str] = None, output_shape: str = 'tbf'):
        """Build the recurrent core and the linear output head.

        :param layer_type: Recurrent layer type, forwarded to ``RNN``.
        :param output_dim: Feature size of every generated step. Generated steps are fed back
            into the RNN, so this is also passed to ``RNN``.
        :param hidden_dimensions: Recurrent hidden size(s), given as a single int or a sequence.
            The last size is the input size of the linear head.
        :param linear_layers: Sizes of the intermediate linear layers between the RNN and the
            output layer. ``None`` (or empty) means a single linear layer.
        :param n_recurrent_layers: Forwarded to ``RNN``.
        :param recurrent_activation: Forwarded to ``RNN``.
        :param recurrent_bias: Forwarded to ``RNN``.
        :param projection_size: Forwarded to ``RNN``.
        :param dilation: Forwarded to ``RNN``.
        :param dropout: Forwarded to ``RNN``.
        :param linear_activation: Name of the activation applied after every linear layer except
            the last one. It is resolved with ``str2cls``, e.g. the default
            'torch.nn.functional.relu'.
        :param final_activation: Name of the activation applied after the last linear layer,
            resolved with ``str2cls``. ``None`` means identity.
        :param output_shape: Axis layout of the returned tensor. The 'tbf' result of
            ``forward`` is converted to it with ``transform_tensor``.
        """
        super().__init__()

        self.output_shape = output_shape

        # NOTE(review): the meaning of the positional 's2s' and False arguments is defined by RNN.
        self.rnn = RNN(layer_type, 's2s', output_dim, hidden_dimensions, n_recurrent_layers, recurrent_activation,
                       recurrent_bias, False, projection_size, dilation, dropout)

        linear_input_dim = self._last_hidden_dim(hidden_dimensions)

        # Size chain: last RNN hidden size -> optional intermediate sizes -> output_dim.
        # Consecutive pairs become the Linear layers (a single layer when there are no intermediates).
        sizes = [linear_input_dim, *(linear_layers or []), output_dim]
        self.fcs = torch.nn.ModuleList([torch.nn.Linear(in_dim, out_dim)
                                        for in_dim, out_dim in zip(sizes[:-1], sizes[1:])])

        self.fc_act = str2cls(linear_activation)
        self.final_act = torch.nn.Identity() if final_activation is None else str2cls(final_activation)

    @staticmethod
    def _last_hidden_dim(hidden_dimensions: Union[Sequence[int], int]) -> int:
        """Return the size of the last recurrent layer.

        :param hidden_dimensions: A single hidden size, or any sequence of sizes
            (list, tuple, ``torch.Size``, ...), matching the ``Sequence[int]`` annotation.
        :return: The last element of a sequence, otherwise the value itself.
        """
        if isinstance(hidden_dimensions, Sequence):
            return hidden_dimensions[-1]
        return hidden_dimensions

    def _linear(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply linear layers and activations to inputs.

        Every layer except the last is followed by ``self.fc_act``. The last layer is
        followed by ``self.final_act``.

        :param inputs: Tensor whose last dimension matches the first linear layer's input size.
        :return: Transformed input tensor.
        """
        outputs = inputs

        for fc in self.fcs[:-1]:
            outputs = self.fc_act(fc(outputs))

        return self.final_act(self.fcs[-1](outputs))

    def forward(self, initial_hidden: torch.Tensor, seq_len: int, target: Optional[torch.Tensor] = None,
                reverse: bool = False) -> torch.Tensor:
        """Generate a sequence from an initial hidden state.

        The first step is always decoded directly from ``initial_hidden``.

        - Training mode with ``target`` given (teacher forcing): the remaining steps are
          computed in one RNN call, using the ground-truth steps as inputs. ``seq_len`` is
          ignored and the output has as many steps as ``target``.
        - Otherwise: the sequence is generated autoregressively, feeding each prediction
          back as the next input, for ``seq_len`` steps in total.

        :param initial_hidden: Initial hidden state of shape (batch_size, hidden_dimension),
            where hidden_dimension equals the last recurrent hidden size.
        :param seq_len: Number of steps to generate when not teacher forcing. At least one
            step is always produced.
        :param target: Target sequence of shape (time, batch_size, feature_dimension). It is
            used as RNN input during training only.
        :param reverse: If set, the sequence is generated back-to-front (the target is
            flipped accordingly). The result is flipped back to chronological order.
        :return: Generated sequence. Internally it has shape
            (time, batch_size, feature_dimension), i.e. 'tbf'; it is returned in the
            ``self.output_shape`` layout.
        """
        # Add a leading axis: (1, batch, hidden). NOTE(review): presumably the
        # (num_layers, batch, hidden) layout RNN expects; confirm for stacked layers.
        hidden = torch.unsqueeze(initial_hidden, dim=0)
        outputs = [self._linear(hidden)]
        # NOTE(review): an (h, c) pair with a zero cell state is built for every layer_type,
        # which presumes an LSTM-style core; confirm for other layer types.
        hidden = (hidden, torch.zeros_like(hidden))

        if self.training and target is not None:  # Teacher forcing with the reference target
            # Steps 1..T-1 are driven by the ground-truth step preceding each prediction in
            # generation order. When reversed, generation starts at target[-1].
            inputs = torch.flip(target[1:], dims=(0,)) if reverse else target[:-1]
            decoded = self._linear(self.rnn(inputs, hidden_states=hidden))
            result = torch.cat(outputs + [decoded], dim=0)

        else:  # Autoregressive generation from scratch
            for _ in range(1, seq_len):
                # The most recent prediction is the next input.
                out, hidden = self.rnn(outputs[-1], hidden_states=hidden, return_hidden=True)
                outputs.append(self._linear(out))

            result = torch.cat(outputs, dim=0)

        if reverse:
            # Generation ran back-to-front; restore chronological order.
            result = torch.flip(result, dims=(0,))

        return transform_tensor(result, 'tbf', self.output_shape)
