import copy
import torch

from .hvi_vae import HVIVAE
from ..utils.matrix_utils import add_diagonal
from ..kernels.composition_kernels import KernelList

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.lazy import lazify

__all__ = ['HVIGPVAE']

# Value added to the diagonal of the GP prior covariance when ``add_jitter``
# is enabled, to improve its condition number.
JITTER = 1e-5


class HVIGPVAE(HVIVAE):
    """Hierarchical VAE with a Gaussian process (GP) prior over the latents.

    Each of the ``latent_dim`` latent dimensions has its own GP kernel. The
    latent posterior combines the GP prior with Gaussian likelihood terms
    produced by ``latent_encoder`` from the observations ``y`` concatenated
    with auxiliary variables ``s``.

    :param latent_encoder: the latent encoder network.
    :param latent_decoder: the latent decoder network.
    :param auxiliary_encoder: the auxiliary encoder network.
    :param auxiliary_decoder: the auxiliary decoder network.
    :param latent_dim: the latent space dimensionality.
    :param auxiliary_dim: the auxiliary space dimensionality.
    :param kernel: either a single GP kernel, deep-copied once per latent
        dimension, or a list/tuple of ``latent_dim`` kernels (one per latent
        dimension, also deep-copied).
    :param add_jitter: whether to add ``JITTER`` to the diagonal of the GP
        prior covariance matrix.
    :raises ValueError: if a list/tuple of kernels is given whose length
        differs from ``latent_dim``.
    """

    def __init__(self, latent_encoder, latent_decoder, auxiliary_encoder,
                 auxiliary_decoder, latent_dim, auxiliary_dim, kernel,
                 add_jitter=False):
        super().__init__(latent_encoder, latent_decoder, auxiliary_encoder,
                         auxiliary_decoder, latent_dim, auxiliary_dim)

        self.add_jitter = add_jitter

        if isinstance(kernel, (list, tuple)):
            # Explicit raise (not assert) so the check survives `python -O`.
            if len(kernel) != latent_dim:
                raise ValueError('Number of kernels must be '
                                 'equal to the latent dimension.')
            kernels = list(copy.deepcopy(kernel))
        else:
            # Give every latent dimension its own copy so no kernel
            # parameters are shared between dimensions.
            kernels = [copy.deepcopy(kernel) for _ in range(latent_dim)]

        self.kernels = KernelList(kernels)

    def _zero_prior_mean(self, x, kff):
        """Return the zero GP prior mean, shape ``(latent_dim, x.shape[0])``.

        Allocated with the dtype and device of ``kff`` so it can be combined
        with the prior covariance (e.g. on a GPU) without a mismatch.
        """
        return torch.zeros(self.latent_dim, x.shape[0], dtype=kff.dtype,
                           device=kff.device)

    def get_latent_prior(self, x, diag=False):
        """Return the zero-mean GP prior over the latents at inputs ``x``.

        :param x: inputs at which to evaluate the prior.
        :param diag: forwarded to ``self.kernels.forward``; presumably
            requests only the covariance diagonal (TODO confirm against
            KernelList).
        :returns: ``(pf_mu, kff)``: the zero prior mean of shape
            ``(latent_dim, x.shape[0])`` and the prior covariance computed by
            ``self.kernels`` (with ``JITTER`` added to its diagonal when
            ``add_jitter`` is set).
        """
        kff = self.kernels.forward(x, x, diag)

        if self.add_jitter:
            # Add jitter to improve condition number.
            kff = add_diagonal(kff, JITTER)

        return self._zero_prior_mean(x, kff), kff

    def get_latent_dists(self, x, y, s, mask=None, x_test=None, kff=None):
        """Compute the GP posterior over the latents given auxiliary ``s``.

        The likelihood terms are a diagonal Gaussian (mean ``lf_sy_mu``,
        scale ``lf_sy_sigma``) produced by ``latent_encoder`` from ``s`` and
        ``y`` concatenated along the feature axis. They are combined with
        the GP prior following GPML section 3.4.3.

        :param x: training inputs.
        :param y: observations at ``x``.
        :param s: auxiliary variables at ``x``.
        :param mask: optional observation mask for ``y``; when given, it is
            passed to ``latent_encoder`` with the auxiliary entries marked
            as observed.
        :param x_test: optional test inputs; when given, the posterior at
            ``x_test`` is returned instead of the posterior at ``x``.
        :param kff: optional precomputed prior covariance at ``x``; computed
            with :meth:`get_latent_prior` when ``None``.
        :returns: when ``x_test`` is given, ``(qt_s_mu, qt_s_cov, pt_mu,
            ktt)``: the test posterior mean/covariance and the test prior
            mean/covariance. Otherwise ``(qf_s_mu, qf_s_cov, pf_mu, kff,
            lf_sy_mu, lf_sy_cov)``: the training posterior mean/covariance,
            the prior mean/covariance and the likelihood mean/covariance.
        """
        # Likelihood terms. Auxiliary variables and observations are
        # concatenated along the feature axis in both cases.
        y_ = torch.cat([s, y], dim=1)
        if mask is not None:
            # Auxiliary variables are always observed.
            mask_ = torch.cat([torch.ones_like(s), mask], dim=1)
            lf_sy_mu, lf_sy_sigma = self.latent_encoder(y_, mask_)
        else:
            lf_sy_mu, lf_sy_sigma = self.latent_encoder(y_)

        # Move the latent dimension to the front (one GP per latent
        # dimension); presumably (n, latent_dim) -> (latent_dim, n).
        lf_sy_mu = lf_sy_mu.transpose(0, 1)
        lf_sy_sigma = lf_sy_sigma.transpose(0, 1)
        lf_sy_cov = lf_sy_sigma.pow(2).diag_embed()
        lf_sy_precision = lf_sy_sigma.pow(-2).diag_embed()
        lf_sy_root_precision = lf_sy_sigma.pow(-1).diag_embed()

        # GP prior.
        if kff is None:
            pf_mu, kff = self.get_latent_prior(x)
        else:
            pf_mu = self._zero_prior_mean(x, kff)

        # See GPML section 3.4.3. With L = lf_sy_precision, only
        # B = I + L^{1/2} K L^{1/2} is inverted; its eigenvalues are >= 1
        # for a PSD K, so the inverse is well conditioned.
        a = kff.matmul(lf_sy_root_precision)  # K L^{1/2}
        at = a.transpose(-1, -2)  # L^{1/2} K
        w = lf_sy_root_precision.matmul(a)
        w = add_diagonal(w, 1)  # B
        winv = w.inverse()

        if x_test is not None:
            # GP prior at the test inputs.
            pt_mu, ktt = self.get_latent_prior(x_test)

            # Cross-covariance between test and training inputs.
            ktf = self.kernels.forward(x_test, x)
            kft = ktf.transpose(-1, -2)

            # GP test posterior, with
            # b = L^{1/2} B^{-1} L^{1/2} = (K + L^{-1})^{-1}.
            b = lf_sy_root_precision.matmul(winv.matmul(lf_sy_root_precision))
            c = ktf.matmul(b)
            qt_s_cov = ktt - c.matmul(kft)
            qt_s_mu = c.matmul(lf_sy_mu.unsqueeze(2)).squeeze(2)

            return qt_s_mu, qt_s_cov, pt_mu, ktt

        # GP training posterior:
        # cov = K - K L^{1/2} B^{-1} L^{1/2} K = (K^{-1} + L)^{-1},
        # mean = cov L mu.
        qf_s_cov = kff - a.matmul(winv.matmul(at))
        qf_s_mu = qf_s_cov.matmul(lf_sy_precision.matmul(
            lf_sy_mu.unsqueeze(2))).squeeze(2)

        return qf_s_mu, qf_s_cov, pf_mu, kff, lf_sy_mu, lf_sy_cov

    def get_auxiliary_dists(self, x, y, mask=None):
        """Return the mean and diagonal covariance of q(s | y).

        :param x: not used by this implementation.
        :param y: observations.
        :param mask: optional observation mask passed to
            ``auxiliary_encoder``.
        :returns: ``(qs_y_mu, qs_y_cov)`` where ``qs_y_cov`` is built by
            placing the squared encoder scales on a diagonal.
        """
        if mask is not None:
            qs_y_mu, qs_y_sigma = self.auxiliary_encoder(y, mask)
        else:
            qs_y_mu, qs_y_sigma = self.auxiliary_encoder(y)

        qs_y_cov = qs_y_sigma.pow(2).diag_embed()

        return qs_y_mu, qs_y_cov

    def sample_latent_posterior(self, x, y, mask=None, num_samples=None,
                                num_s_samples=1, num_f_samples=1,
                                full_cov=True, x_test=None):
        """Draw samples from the auxiliary and latent posteriors.

        For each of ``num_s_samples`` draws ``s ~ q(s | y)``, draw
        ``num_f_samples`` latent samples ``f ~ q(f | s, y)``.

        :param x: training inputs.
        :param y: observations at ``x``.
        :param mask: optional observation mask for ``y``.
        :param num_samples: if given, overrides ``num_s_samples`` with
            ``num_samples`` and ``num_f_samples`` with ``1``.
        :param num_s_samples: number of auxiliary samples.
        :param num_f_samples: number of latent samples per auxiliary sample.
        :param full_cov: if True, sample with the full posterior covariance;
            otherwise sample each entry independently from its marginal.
        :param x_test: optional test inputs; when given, latents are sampled
            from the posterior at ``x_test`` instead of at ``x``.
        :returns: ``(f_samples, s_samples)``: lists of length
            ``num_s_samples * num_f_samples`` and ``num_s_samples``.
        """
        if num_samples is not None:
            # One latent sample per auxiliary sample.
            num_s_samples, num_f_samples = num_samples, 1

        # Auxiliary posterior distribution (diagonal covariance).
        qs_y_mu, qs_y_cov = self.get_auxiliary_dists(x, y, mask)
        qs_y_std = qs_y_cov.diagonal(dim1=-2, dim2=-1).sqrt()

        s_samples, f_samples = [], []
        for _ in range(num_s_samples):
            s = qs_y_mu + qs_y_std * torch.randn_like(qs_y_mu)
            s_samples.append(s)

            # Latent posterior distribution.
            qf_s_mu, qf_s_cov = self.get_latent_dists(
                x, y, s, mask, x_test=x_test)[:2]

            if full_cov:
                # Use GPyTorch MultivariateNormal class for sampling
                # using the full covariance matrix.
                qf_s = MultivariateNormal(qf_s_mu, lazify(qf_s_cov))
                f_samples.extend(qf_s.sample() for _ in range(num_f_samples))
            else:
                # Round-off can make posterior variances slightly negative;
                # clamp so the square root does not produce NaN.
                qf_s_std = qf_s_cov.diagonal(
                    dim1=-2, dim2=-1).clamp_min(0).sqrt()
                f_samples.extend(
                    qf_s_mu + qf_s_std * torch.randn_like(qf_s_mu)
                    for _ in range(num_f_samples))

        return f_samples, s_samples
