from typing import Union, Sequence, Optional, List, Tuple

import torch

from models.common.rnns import V2SRNN
from models.common import RNN
from models.common.vae import sample_normal


class Gaussian2SRNN(torch.nn.Module):
    """Recurrent network fed with a sample drawn from a Gaussian.

    The module wraps either an ``RNN`` (mode ``'s2s'``) or a ``V2SRNN`` (mode ``'v2s'``).
    On each forward pass the input of the wrapped network is either a sample obtained
    from ``sample_normal(mean, variance, ...)`` or the mean itself (see ``forward``).

    Args:
        layer_type: Recurrent layer type, forwarded to the wrapped network.
        mode: ``'s2s'`` to wrap ``RNN`` or ``'v2s'`` to wrap ``V2SRNN``.
        input_dimension: Input size, forwarded to the wrapped network.
        output_dim: Output size; only forwarded in ``'v2s'`` mode.
        hidden_dimensions: Hidden size(s), forwarded to the wrapped network.
        linear_layers: Only forwarded in ``'v2s'`` mode.
        n_recurrent_layers: Forwarded to the wrapped network.
        recurrent_activation: Forwarded to the wrapped network.
        recurrent_bias: Forwarded to the wrapped network.
        projection_size: Forwarded to the wrapped network.
        dilation: Forwarded to the wrapped network.
        dropout: Forwarded to the wrapped network.
        linear_activation: Only forwarded in ``'v2s'`` mode.
        final_activation: Only forwarded in ``'v2s'`` mode.
        bidirectional: Only forwarded in ``'s2s'`` mode; ignored in ``'v2s'`` mode.
        log_var: Value passed as ``log_var`` to ``sample_normal``. Presumably tells the
            sampler that the ``variance`` argument of ``forward`` holds log-variances.
            NOTE(review): the default ``False`` is chosen because the ``forward``
            argument is named ``variance`` -- confirm against ``sample_normal``.

    Raises:
        ValueError: If ``mode`` is neither ``'s2s'`` nor ``'v2s'``.
    """

    def __init__(self, layer_type: str, mode: str, input_dimension: int, output_dim: int,
                 hidden_dimensions: Union[Sequence[int], int], linear_layers: Optional[List[int]] = None,
                 n_recurrent_layers: Optional[int] = None, recurrent_activation: str = 'tanh',
                 recurrent_bias: bool = True, projection_size: int = 0, dilation: Optional[Sequence[int]] = None,
                 dropout: float = 0.0, linear_activation: str = 'torch.nn.functional.relu',
                 final_activation: Optional[str] = None, bidirectional: bool = False,
                 log_var: bool = False):

        super(Gaussian2SRNN, self).__init__()

        # ``forward`` reads this flag on every sampling call; it must exist on the
        # instance, otherwise ``torch.nn.Module.__getattr__`` raises AttributeError.
        self.log_var = log_var

        if mode == 's2s':
            self.rnn = RNN(layer_type, 's2s', input_dimension, hidden_dimensions, n_recurrent_layers,
                           recurrent_activation, recurrent_bias, bidirectional, projection_size, dilation, dropout)
        elif mode == 'v2s':
            self.rnn = V2SRNN(layer_type, input_dimension, output_dim, hidden_dimensions, linear_layers,
                              n_recurrent_layers, recurrent_activation, recurrent_bias, projection_size, dilation,
                              dropout, linear_activation, final_activation)
        else:
            raise ValueError(f'Unsupported mode {mode}.')

    def forward(self, mean: torch.Tensor, variance: torch.Tensor, num_samples: int = 1,
                force_sample: bool = False) -> Tuple[torch.Tensor, ...]:
        """Run the wrapped network on a Gaussian sample or on the mean.

        A sample is requested from ``sample_normal`` when the module is in training
        mode, when more than one sample is requested, or when ``force_sample`` is set.
        Otherwise the mean is passed to the wrapped network unchanged.

        Args:
            mean: Mean of the Gaussian.
            variance: Second Gaussian parameter handed to ``sample_normal``; how it is
                interpreted depends on the ``log_var`` flag given at construction.
            num_samples: Number of samples requested from ``sample_normal``.
            force_sample: Sample even in evaluation mode with a single sample.

        Returns:
            Whatever the wrapped network returns for the chosen input.
        """

        if self.training or num_samples > 1 or force_sample:
            sample = sample_normal(mean, variance, log_var=self.log_var, num_samples=num_samples)
        else:
            # Deterministic path: evaluation mode with a single, non-forced sample.
            sample = mean

        return self.rnn(sample)



