import torch
import torch.nn as nn

__all__ = ['HVIVAE']


class HVIVAE(nn.Module):
    """Vanilla hierarchical VAE with standard normal prior.

    Two stochastic layers are used: an auxiliary variable ``s`` whose
    approximate posterior is produced from ``y`` by the auxiliary encoder,
    and a latent variable ``f`` whose approximate posterior is produced
    from ``(s, y)`` by the latent encoder. Predictions for ``y`` are
    obtained by passing samples of ``f`` through the latent decoder.

    :param latent_encoder: the latent encoder network. Called as
        ``latent_encoder([s, y])`` or ``latent_encoder([s, y], [None, mask])``
        and expected to return a ``(mu, sigma)`` pair.
    :param latent_decoder: the latent decoder network. Called with a
        transposed latent sample and expected to return ``(mu, sigma)``.
    :param auxiliary_encoder: the auxiliary encoder network. Called as
        ``auxiliary_encoder(y)`` or ``auxiliary_encoder(y, mask)`` and
        expected to return a ``(mu, sigma)`` pair.
    :param auxiliary_decoder: the auxiliary decoder network (stored, but
        not used by any method of this class).
    :param latent_dim: the latent space dimensionality.
    :param auxiliary_dim: the auxiliary space dimensionality.
    """
    def __init__(self, latent_encoder, latent_decoder, auxiliary_encoder,
                 auxiliary_decoder, latent_dim, auxiliary_dim):
        super().__init__()

        self.latent_encoder = latent_encoder
        self.latent_decoder = latent_decoder
        self.auxiliary_encoder = auxiliary_encoder
        self.auxiliary_decoder = auxiliary_decoder
        self.latent_dim = latent_dim
        self.auxiliary_dim = auxiliary_dim

    def get_latent_prior(self, x):
        """Return the standard normal prior over the latent variable.

        :param x: input tensor; only ``x.shape[0]`` and ``x.device`` are
            used.
        :return: ``(pf_mu, pf_cov)`` where ``pf_mu`` is a zero tensor of
            shape ``(latent_dim, x.shape[0])`` and ``pf_cov`` holds one
            identity matrix of size ``x.shape[0]`` per latent dimension.
        """
        num_points = x.shape[0]

        # Allocate on the same device as x so that the prior can be
        # combined with posterior terms computed on that device.
        pf_mu = torch.zeros(self.latent_dim, num_points, device=x.device)
        pf_cov = torch.ones(
            self.latent_dim, num_points, device=x.device).diag_embed()

        return pf_mu, pf_cov

    def get_latent_dists(self, x, y, s, mask=None):
        """Return the latent posterior given ``s`` and the latent prior.

        :param x: input tensor, used only to size the prior.
        :param y: observations passed to the latent encoder.
        :param s: auxiliary sample passed to the latent encoder.
        :param mask: optional mask applied to ``y`` only (``s`` is passed
            with a ``None`` mask).
        :return: ``(qf_s_mu, qf_s_cov, pf_mu, pf_cov)``. The posterior
            mean and sigma returned by the encoder are transposed on
            their first two dimensions, and the covariance is the
            diagonal embedding of the squared sigma.
        """
        # Posterior.
        inputs = [s, y]
        if mask is not None:
            qf_s_mu, qf_s_sigma = self.latent_encoder(inputs, [None, mask])
        else:
            qf_s_mu, qf_s_sigma = self.latent_encoder(inputs)

        # Reshape so the layout matches the prior returned by
        # get_latent_prior (latent dimension first).
        qf_s_mu = qf_s_mu.transpose(0, 1)
        qf_s_sigma = qf_s_sigma.transpose(0, 1)
        qf_s_cov = qf_s_sigma.pow(2).diag_embed()

        # Prior.
        pf_mu, pf_cov = self.get_latent_prior(x)

        return qf_s_mu, qf_s_cov, pf_mu, pf_cov

    def get_auxiliary_dists(self, x, y, mask=None):
        """Return the auxiliary posterior computed from ``y``.

        :param x: unused; kept for interface symmetry with
            :meth:`get_latent_dists`.
        :param y: observations passed to the auxiliary encoder.
        :param mask: optional mask forwarded to the auxiliary encoder.
        :return: ``(qs_y_mu, qs_y_cov)`` where the covariance is the
            diagonal embedding of the squared encoder sigma.
        """
        # Posterior.
        if mask is not None:
            qs_y_mu, qs_y_sigma = self.auxiliary_encoder(y, mask)
        else:
            qs_y_mu, qs_y_sigma = self.auxiliary_encoder(y)

        # Reshape.
        qs_y_cov = qs_y_sigma.pow(2).diag_embed()

        return qs_y_mu, qs_y_cov

    def sample_latent_posterior(self, x, y, mask=None, num_samples=None,
                                num_s_samples=1, num_f_samples=1):
        """Draw ancestral samples of the auxiliary and latent variables.

        :param x: input tensor forwarded to the distribution helpers.
        :param y: observations.
        :param mask: optional mask over ``y``.
        :param num_samples: if given, draw this many ``s`` samples with
            exactly one ``f`` sample each (overrides ``num_s_samples``
            and ``num_f_samples``).
        :param num_s_samples: number of auxiliary samples to draw.
        :param num_f_samples: number of latent samples per auxiliary
            sample.
        :return: ``(f_samples, s_samples)``, two lists of tensors. The
            ``f`` samples are ordered by the ``s`` sample they were
            conditioned on.
        """
        if num_samples is not None:
            # One f per s: a special case of the general nested scheme.
            num_s_samples, num_f_samples = num_samples, 1

        # Auxiliary posterior distribution.
        qs_y_mu, qs_y_cov = self.get_auxiliary_dists(x, y, mask)
        qs_y_sigma = qs_y_cov.diagonal(dim1=-2, dim2=-1) ** 0.5

        s_samples, f_samples = [], []
        for _ in range(num_s_samples):
            s = qs_y_mu + qs_y_sigma * torch.randn_like(qs_y_mu)
            s_samples.append(s)

            # Latent posterior distribution. The sample must be passed as
            # `s` and the mask as `mask`; swapping them feeds the mask to
            # the encoder as data and the sample as a mask.
            qf_s_mu, qf_s_cov = self.get_latent_dists(x, y, s, mask)[:2]
            qf_s_sigma = qf_s_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            f_samples.extend(
                qf_s_mu + qf_s_sigma * torch.randn_like(qf_s_mu)
                for _ in range(num_f_samples))

        return f_samples, s_samples

    def predict_y(self, **kwargs):
        """Predict ``y`` by decoding samples of the latent posterior.

        :param kwargs: keyword arguments forwarded unchanged to
            :meth:`sample_latent_posterior`.
        :return: ``(py_f_mu, py_f_sigma, py_f_samples)``. ``py_f_mu`` is
            the decoder mean averaged over latent samples (detached).
            ``py_f_sigma`` is the standard deviation across the drawn
            output samples (detached); when only one sample is available
            it is the decoder's own sigma instead. ``py_f_samples`` is
            the list of output samples, one per latent sample.
        """
        # Sample latent posterior distribution.
        f_samples = self.sample_latent_posterior(**kwargs)[0]

        py_f_mus, py_f_sigmas, py_f_samples = [], [], []
        for f in f_samples:
            # Output conditional posterior distribution.
            py_f_mu, py_f_sigma = self.latent_decoder(f.transpose(0, 1))
            py_f_mus.append(py_f_mu)
            py_f_sigmas.append(py_f_sigma)
            py_f_samples.append(
                py_f_mu + py_f_sigma * torch.randn_like(py_f_mu))

        py_f_mu = torch.stack(py_f_mus).mean(0).detach()
        if len(py_f_samples) > 1:
            py_f_sigma = torch.stack(py_f_samples).std(0).detach()
        else:
            # The sample standard deviation of a single draw is NaN, so
            # report the decoder's sigma for that draw instead.
            py_f_sigma = torch.stack(py_f_sigmas).mean(0).detach()

        return py_f_mu, py_f_sigma, py_f_samples
