import torch
import torch.nn as nn

# Names exported by ``from <this module> import *``.
__all__ = ['VAE']


class VAE(nn.Module):
    """Vanilla VAE with standard normal prior.

    :param encoder: the encoder network. Called as ``encoder(y)`` or
        ``encoder(y, mask)``; must return a ``(mu, sigma)`` pair.
    :param decoder: the decoder network. Called on a latent sample; must
        return a ``(mu, sigma)`` pair.
    :param latent_dim: latent space dimensionality.
    """
    def __init__(self, encoder, decoder, latent_dim):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim

    def get_latent_prior(self, x):
        """Return the standard normal prior over the latents.

        :param x: inputs; only ``x.shape[0]`` (and ``x.device``, when
            present) is used.
        :return: ``(pf_mu, pf_cov)`` where ``pf_mu`` is a zero tensor of
            shape ``(latent_dim, N)`` and ``pf_cov`` is a stack of
            identity matrices of shape ``(latent_dim, N, N)``, with
            ``N = x.shape[0]``.
        """
        num_points = x.shape[0]
        # Allocate on the same device as x so the prior can be combined
        # with posterior/decoder tensors living on an accelerator. Inputs
        # without a ``device`` attribute fall back to the default device.
        device = getattr(x, 'device', None)

        # Standard normal prior.
        pf_mu = torch.zeros(self.latent_dim, num_points, device=device)
        pf_cov = torch.ones(
            self.latent_dim, num_points, device=device).diag_embed()

        return pf_mu, pf_cov

    def get_latent_dists(self, x, y, mask=None):
        """Return the approximate posterior and the prior over latents.

        :param x: inputs, forwarded to :meth:`get_latent_prior`.
        :param y: observations passed to the encoder.
        :param mask: optional mask; when given it is passed to the encoder
            as a second positional argument.
        :return: ``(qf_mu, qf_cov, pf_mu, pf_cov)``. The posterior
            covariance is diagonal, built from the squared encoder sigma.
        """
        # Posterior.
        if mask is not None:
            qf_mu, qf_sigma = self.encoder(y, mask)
        else:
            qf_mu, qf_sigma = self.encoder(y)

        # Reshape: swap the first two axes. Presumably the encoder emits
        # (N, latent_dim), which would make this match the prior's
        # (latent_dim, N) layout -- TODO confirm against the encoder.
        qf_mu = qf_mu.transpose(0, 1)
        qf_sigma = qf_sigma.transpose(0, 1)
        qf_cov = qf_sigma.pow(2).diag_embed()

        # Prior.
        pf_mu, pf_cov = self.get_latent_prior(x)

        return qf_mu, qf_cov, pf_mu, pf_cov

    def sample_latent_posterior(self, x, y=None, mask=None, num_samples=1,
                                **kwargs):
        """Draw samples from the latent posterior (or prior if no ``y``).

        :param x: inputs.
        :param y: observations; when ``None`` samples come from the prior.
        :param mask: optional mask forwarded to the encoder.
        :param num_samples: number of latent samples to draw.
        :param kwargs: accepted for call-site compatibility and ignored.
        :return: list of ``num_samples`` tensors shaped like the latent
            mean.
        """
        if y is not None:
            # Latent posterior distribution.
            qf_mu, qf_cov = self.get_latent_dists(x, y, mask)[:2]
        else:
            # Latent posterior distribution is the prior.
            qf_mu, qf_cov = self.get_latent_prior(x)

        # Only the marginal variances are needed; take all diagonals in a
        # single batched call rather than looping over latent dimensions.
        qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1).sqrt()
        samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                   for _ in range(num_samples)]

        return samples

    def predict_y(self, **kwargs):
        """Predict outputs by decoding latent samples.

        :param kwargs: forwarded to :meth:`sample_latent_posterior`
            (``x``, ``y``, ``mask``, ``num_samples``).
        :return: ``(py_f_mu, py_f_sigma, py_f_samples)``: the detached
            mean of the decoder means, a detached standard deviation, and
            the list of output samples (one per latent sample).
        """
        # Sample latent posterior distribution.
        f_samples = self.sample_latent_posterior(**kwargs)

        py_f_mus, py_f_sigmas, py_f_samples = [], [], []
        for f in f_samples:
            # Output conditional posterior distribution.
            py_f_mu, py_f_sigma = self.decoder(f.transpose(0, 1))
            py_f_mus.append(py_f_mu)
            py_f_sigmas.append(py_f_sigma)
            py_f_samples.append(
                py_f_mu + py_f_sigma * torch.randn_like(py_f_mu))

        py_f_mu = torch.stack(py_f_mus).mean(0).detach()
        if len(py_f_samples) > 1:
            # Empirical spread of the output samples.
            py_f_sigma = torch.stack(py_f_samples).std(0).detach()
        else:
            # The sample std of a single draw is NaN; fall back to the
            # decoder's own sigma for that draw.
            py_f_sigma = py_f_sigmas[0].detach()

        return py_f_mu, py_f_sigma, py_f_samples
