from typing import Tuple, Union, Any
import torch


class VAE(torch.nn.Module):
    """Generic variational autoencoder chaining a probabilistic encoder and a decoder.

    The encoder is applied to the first element of the input tuple and must return a
    ``torch.distributions`` object. A latent code drawn from it (or its mean) is then
    handed to the decoder.
    """

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module):
        super().__init__()

        # Assigning nn.Module instances as attributes registers them as child modules,
        # so parameters and train()/eval() mode switches propagate to them.
        self.encoder = encoder
        self.decoder = decoder

    def sample(self, distribution: torch.distributions.distribution.Distribution, n_samples: int = 1,
               force_sample: bool = True, **kwargs) -> torch.Tensor:
        """Produce the latent code that is fed to the decoder.

        A reparameterised sample is drawn when the module is in training mode, when
        more than one sample is requested, or when ``force_sample`` is set. Otherwise
        the distribution's mean is returned.

        :param distribution: Latent distribution to draw from.
        :param n_samples: How many samples to draw.
        :param force_sample: If True, sample even in eval mode instead of using the mean.
        :param kwargs: Ignored by this implementation (presumably a hook for subclasses).
        :return: ``distribution.rsample((n_samples,))``, which carries a leading sample
            dimension of size ``n_samples`` even when it is 1, or ``distribution.mean``,
            which has no sample dimension.
        """
        # NOTE(review): the two branches differ in rank by one, so the decoder sees
        # differently shaped inputs in training vs. eval mode — confirm this is intended.
        must_sample = self.training or n_samples > 1 or force_sample
        if not must_sample:
            return distribution.mean

        sample_shape = (n_samples, )
        # rsample (not sample) keeps the draw differentiable w.r.t. the distribution's parameters.
        return distribution.rsample(sample_shape)

    def forward(self, x: Tuple[torch.Tensor], return_latent_sample: bool = False, n_samples: int = 1,
                force_sample: bool = False, **kwargs) -> Union[Tuple[torch.distributions.distribution.Distribution,
                                                                     Any],
                                                               Tuple[torch.distributions.distribution.Distribution,
                                                                     Any, torch.Tensor]]:
        """Encode ``x[0]``, obtain a latent code via :meth:`sample`, and decode it.

        The second returned element is typed ``Any`` because some methods use a
        deterministic decoder that returns a tensor rather than a distribution.

        :param x: Input tuple; only its first element is passed to the encoder.
        :param return_latent_sample: If True, also return the latent code given to the decoder.
        :param n_samples: Number of latent samples to draw.
        :param force_sample: If True, sample from the latent distribution even when not
            training, instead of using its mean.
        :param kwargs: Forwarded unchanged to :meth:`sample`.
        :return: ``(latent_distribution, decoder_output)``, with the latent code appended
            as a third element when ``return_latent_sample`` is set.
        """
        encoded = self.encoder(x[0])
        latent_code = self.sample(encoded, n_samples, force_sample, **kwargs)
        decoded = self.decoder(latent_code)

        result = (encoded, decoded)
        if return_latent_sample:
            result += (latent_code, )
        return result




