import copy

import torch
import torch.nn as nn

from .vae import VAE
from ..utils.matrix_utils import add_diagonal
from ..kernels.composition_kernels import KernelList

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.lazy import lazify

# Public API of this module.
__all__ = [
    'GPVAE',
    'SparseGPVAE',
    'SparseGPVAE2',
    'SparseFITCGPVAE',
    'TitsiasSparseGPVAE',
    'HybridGPVAE',
    'GPPVAE',
]

# Small constant added to covariance diagonals (via `add_diagonal`) to improve
# their conditioning; it also scales the random perturbation applied to the
# initial variational parameters of TitsiasSparseGPVAE.
JITTER = 1e-5


class GPVAE(VAE):
    """VAE with GP prior.

    Each latent dimension has its own kernel (deep-copied from ``kernel``),
    so the prior over the latents is a stack of ``latent_dim`` GPs.

    :param encoder: the encoder network.
    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param kernel: the GP kernel, or a list of ``latent_dim`` kernels (one per
        latent dimension).
    :param add_jitter: whether to add jitter to the GP prior covariance matrix.
    """
    def __init__(self, encoder, decoder, latent_dim, kernel, add_jitter=False):
        super().__init__(encoder, decoder, latent_dim)

        self.add_jitter = add_jitter

        if not isinstance(kernel, list):
            # Independent copies so each latent GP has its own
            # hyperparameters.
            kernels = [copy.deepcopy(kernel) for _ in range(latent_dim)]
            self.kernels = KernelList(kernels)

        else:
            assert len(kernel) == latent_dim, 'Number of kernels must be ' \
                                              'equal to the latent dimension.'
            self.kernels = KernelList(copy.deepcopy(kernel))

    def get_latent_prior(self, x, diag=False):
        """Return the GP prior mean and covariance evaluated at ``x``.

        :param x: the input locations.
        :param diag: forwarded to the kernels; True requests only the diagonal
            of the prior covariance.
        :return: ``(mf, kff)`` where ``mf`` is a zero mean of shape
            ``(latent_dim, x.shape[0])`` and ``kff`` the stacked kernel
            matrices.
        """
        kff = self.kernels.forward(x, x, diag)

        if self.add_jitter:
            # Add jitter to improve condition number.
            kff = add_diagonal(kff, JITTER)

        # Allocate the zero mean with the covariance's dtype and device so
        # that downstream arithmetic never mixes CPU and GPU tensors.
        mf = torch.zeros(self.latent_dim, x.shape[0], dtype=kff.dtype,
                         device=kff.device)

        return mf, kff

    def get_latent_dists(self, x, y, mask=None, x_test=None):
        """Combine the encoder likelihood terms with the GP prior.

        The encoder output ``(mu, sigma)`` for ``y`` is used as a diagonal
        Gaussian likelihood term ``N(mu; f, diag(sigma^2))`` per latent
        dimension, and the GP posterior is computed in closed form (GPML
        section 3.4.3).

        :param x: the training inputs.
        :param y: the observations passed to the encoder.
        :param mask: optional mask, forwarded to the encoder when given.
        :param x_test: optional test inputs. When given, the posterior at
            ``x_test`` is returned instead of the posterior at ``x``.
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov)``.
        """
        # Likelihood terms.
        if mask is not None:
            lf_y_mu, lf_y_sigma = self.encoder(y, mask)
        else:
            lf_y_mu, lf_y_sigma = self.encoder(y)

        # Move the latent dimension first so it acts as the batch dimension.
        lf_y_mu = lf_y_mu.transpose(0, 1)
        lf_y_sigma = lf_y_sigma.transpose(0, 1)
        lf_y_cov = lf_y_sigma.pow(2).diag_embed()
        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()
        lf_y_root_precision = lf_y_sigma.pow(-1).diag_embed()

        # GP prior.
        pf_mu, kff = self.get_latent_prior(x)

        # See GPML section 3.4.3. With S = diag(sigma^2):
        #   W = I + S^{-1/2} Kff S^{-1/2},
        #   (Kff + S)^{-1} = S^{-1/2} W^{-1} S^{-1/2}.
        a = kff.matmul(lf_y_root_precision)
        at = a.transpose(-1, -2)
        w = lf_y_root_precision.matmul(a)
        w = add_diagonal(w, 1)
        # Can improve efficiency here using back substitution.
        winv = w.inverse()

        if x_test is not None:
            # GP prior.
            ps_mu, kss = self.get_latent_prior(x_test)

            # GP conditional prior.
            ksf = self.kernels.forward(x_test, x)
            kfs = ksf.transpose(-1, -2)

            # GP test posterior: mean Ksf (Kff + S)^{-1} mu and covariance
            # Kss - Ksf (Kff + S)^{-1} Kfs.
            b = lf_y_root_precision.matmul(winv.matmul(lf_y_root_precision))
            c = ksf.matmul(b)
            qs_cov = kss - c.matmul(kfs)
            qs_mu = c.matmul(lf_y_mu.unsqueeze(2))
            qs_mu = qs_mu.squeeze(2)

            return qs_mu, qs_cov, ps_mu, kss
        else:
            # GP training posterior: Kff - Kff (Kff + S)^{-1} Kff, with mean
            # qf_cov S^{-1} mu.
            qf_cov = kff - a.matmul(winv.matmul(at))
            qf_mu = qf_cov.matmul(lf_y_precision.matmul(lf_y_mu.unsqueeze(2)))
            qf_mu = qf_mu.squeeze(2)

            return qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov

    def sample_latent_posterior(self, x, y, mask=None, num_samples=1,
                                full_cov=True, **kwargs):
        """Draw ``num_samples`` samples from the latent posterior.

        :param full_cov: if True, sample with the full covariance using the
            GPyTorch ``MultivariateNormal``; otherwise sample each point
            independently from its marginal variance.
        :param kwargs: forwarded to ``get_latent_dists`` (e.g. ``x_test``).
        :return: a list of ``num_samples`` sample tensors.
        """
        # Latent posterior distribution.
        qf_mu, qf_cov = self.get_latent_dists(x, y, mask, **kwargs)[:2]

        if full_cov:
            # Use GPyTorch MultivariateNormal class for sampling using the
            # full covariance matrix.
            qf = MultivariateNormal(qf_mu, lazify(qf_cov))
            samples = [qf.sample() for _ in range(num_samples)]
        else:
            # Batched diagonal of each latent dimension's covariance.
            qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                       for _ in range(num_samples)]

        return samples


class SparseGPVAE(GPVAE):
    """VAE with GP prior using Rich's VFE approximation.

    The encoder consumes ``(x, y[, mask])`` and returns inducing locations
    ``z`` together with the mean and standard deviation of a diagonal
    Gaussian likelihood term over the latent values at ``z``.

    :param encoder: the encoder network.
    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param kernel: the GP kernel.
    :param add_jitter: whether to add jitter to the GP prior covariance matrix.
    """
    def __init__(self, encoder, decoder, latent_dim, kernel, add_jitter=False):
        super().__init__(encoder, decoder, latent_dim, kernel, add_jitter)

    def get_latent_dists(self, x, y, mask=None, x_test=None, full_cov=False):
        """Compute the sparse GP posterior implied by the encoder outputs.

        :param x: the training inputs.
        :param y: the observations.
        :param mask: optional mask, forwarded to the encoder when given.
        :param x_test: optional test inputs; when given, the posterior at
            ``x_test`` is returned instead of the one at ``x``.
        :param full_cov: if False, only the diagonal of the prior covariance
            at ``x`` / ``x_test`` is requested from the kernels.
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lu_y_mu,
            lu_y_cov)``.
        """
        encoder_inputs = (x, y) if mask is None else (x, y, mask)
        z, lu_y_mu, lu_y_sigma = self.encoder(*encoder_inputs)

        # Latent dimension first; S = diag(lu_y_sigma^2) per dimension.
        lu_y_mu = lu_y_mu.transpose(0, 1)
        lu_y_sigma = lu_y_sigma.transpose(0, 1)
        lu_y_cov = lu_y_sigma.pow(2).diag_embed()
        lu_y_precision = lu_y_sigma.pow(-2).diag_embed()
        s_inv_half = lu_y_sigma.pow(-1).diag_embed()

        pu_mu, kuu = self.get_latent_prior(z)

        # See Spatio-Temporal VAEs: Sparse Approximations.
        # W = I + S^{-1/2} Kuu S^{-1/2}, so that
        # (Kuu + S)^{-1} = S^{-1/2} W^{-1} S^{-1/2}.
        kuu_s = kuu.matmul(s_inv_half)
        w_inv = add_diagonal(s_inv_half.matmul(kuu_s), 1).inverse()
        kuu_plus_s_inv = s_inv_half.matmul(w_inv.matmul(s_inv_half))
        mu_col = lu_y_mu.unsqueeze(2)

        if x_test is None:
            # q(u): covariance Kuu - Kuu (Kuu + S)^{-1} Kuu, mean cov S^{-1} mu.
            qu_cov = kuu - kuu_s.matmul(w_inv.matmul(kuu_s.transpose(-1, -2)))
            qu_mu = qu_cov.matmul(lu_y_precision.matmul(mu_col)).squeeze(2)

            # q(f) at the training inputs.
            pf_mu, kff = self.get_latent_prior(x, diag=(not full_cov))
            kfu = self.kernels.forward(x, z)
            gain = kfu.matmul(kuu_plus_s_inv)
            qf_cov = kff - gain.matmul(kfu.transpose(-1, -2))
            qf_mu = gain.matmul(mu_col).squeeze(2)

            return qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lu_y_mu, lu_y_cov

        # q(f) at the test inputs.
        ps_mu, kss = self.get_latent_prior(x_test, diag=(not full_cov))
        ksu = self.kernels.forward(x_test, z)
        gain = ksu.matmul(kuu_plus_s_inv)
        qs_cov = kss - gain.matmul(ksu.transpose(-1, -2))
        qs_mu = gain.matmul(mu_col).squeeze(2)

        return qs_mu, qs_cov, ps_mu, kss

    def sample_latent_posterior(self, x, y, mask=None, num_samples=1,
                                full_cov=False, **kwargs):
        """Draw ``num_samples`` samples from the latent posterior.

        ``full_cov`` is forwarded to ``get_latent_dists`` and also selects the
        sampler: a full-covariance GPyTorch ``MultivariateNormal`` when True,
        otherwise independent Gaussians built from the marginal variances.
        """
        dists = self.get_latent_dists(x, y, mask, full_cov=full_cov, **kwargs)
        qf_mu, qf_cov = dists[0], dists[1]

        if not full_cov:
            qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            return [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                    for _ in range(num_samples)]

        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        return [qf.sample() for _ in range(num_samples)]


class SparseGPVAE2(GPVAE):
    """VAE with GP prior using Michael's VFE approximation.

    The inducing locations ``z`` are a module parameter; the encoder provides
    a diagonal Gaussian likelihood term at the training inputs.

    :param encoder: the encoder network.
    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param kernel: the GP kernel.
    :param z: initial inducing point locations.
    :param add_jitter: whether to add jitter to the GP prior covariance
        matrix.
    :param fixed_inducing: if True, the inducing point locations are stored
        with ``requires_grad=False`` and hence not optimised.
    """
    def __init__(self, encoder, decoder, latent_dim, kernel, z,
                 add_jitter=False, fixed_inducing=False):
        super().__init__(encoder, decoder, latent_dim, kernel, add_jitter)

        self.z = nn.Parameter(z, requires_grad=not fixed_inducing)

    def get_latent_dists(self, x, y, mask=None, x_test=None, full_cov=False):
        """Compute the VFE posterior over the inducing and latent values.

        :param x: the training inputs.
        :param y: the observations passed to the encoder.
        :param mask: optional mask, forwarded to the encoder when given.
        :param x_test: optional test inputs; when given, the posterior at
            ``x_test`` is returned instead of the one at ``x``.
        :param full_cov: if False, only the diagonal of the prior covariance
            at ``x`` / ``x_test`` is requested from the kernels.
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lf_y_mu,
            lf_y_cov)``.
        """
        # Likelihood terms.
        if mask is not None:
            lf_y_mu, lf_y_sigma = self.encoder(y, mask)
        else:
            lf_y_mu, lf_y_sigma = self.encoder(y)

        # Latent dimension first, consistent with the other models.
        lf_y_mu = lf_y_mu.transpose(0, 1)
        lf_y_sigma = lf_y_sigma.transpose(0, 1)
        lf_y_cov = lf_y_sigma.pow(2).diag_embed()
        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()

        # GP prior at the inducing points.
        pu_mu, kuu = self.get_latent_prior(self.z)

        # GP conditional prior.
        kfu = self.kernels.forward(x, self.z)
        kuf = kfu.transpose(-1, -2)

        # See Spatio-Temporal VAEs: Michael's Approach. With
        # Lambda = diag(lf_y_sigma^2):
        #   phi  = (Kuu + Kuf Lambda^{-1} Kfu)^{-1}
        #   q(u) = N(Kuu phi Kuf Lambda^{-1} mu, Kuu phi Kuu)
        # TODO: replace the explicit inverses with Cholesky solves for better
        # numerical stability.
        kuu_inv = kuu.inverse()
        phi = (kuu + kuf.matmul(lf_y_precision).matmul(kfu)).inverse()
        qu_mu = kuu.matmul(phi.matmul(kuf.matmul(lf_y_precision.matmul(
            lf_y_mu.unsqueeze(2)))))

        if x_test is not None:
            # GP prior.
            ps_mu, kss = self.get_latent_prior(x_test, diag=(not full_cov))

            # GP conditional prior.
            ksu = self.kernels.forward(x_test, self.z)
            kus = ksu.transpose(-1, -2)

            # q(f_s) = N(Ksu Kuu^{-1} qu_mu, Kss - Ksu (Kuu^{-1} - phi) Kus).
            qs_cov = kss - ksu.matmul(kuu_inv - phi).matmul(kus)
            qs_mu = ksu.matmul(kuu_inv.matmul(qu_mu)).squeeze(2)

            return qs_mu, qs_cov, ps_mu, kss
        else:
            # GP prior.
            # Note that only diagonals are needed when optimising ELBO.
            pf_mu, kff = self.get_latent_prior(x, diag=(not full_cov))

            qf_cov = kff - kfu.matmul(kuu_inv - phi).matmul(kuf)
            qf_mu = kfu.matmul(kuu_inv.matmul(qu_mu)).squeeze(2)

            qu_cov = kuu.matmul(phi.matmul(kuu))
            qu_mu = qu_mu.squeeze(2)

            return qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lf_y_mu, lf_y_cov

    def sample_latent_posterior(self, x, y, mask=None, num_samples=1,
                                full_cov=False, **kwargs):
        """Draw ``num_samples`` samples from the latent posterior.

        :param full_cov: forwarded to ``get_latent_dists``; if True, sample
            with the full covariance using the GPyTorch
            ``MultivariateNormal``, otherwise use the marginal variances.
        :param kwargs: forwarded to ``get_latent_dists`` (e.g. ``x_test``).
        :return: a list of ``num_samples`` sample tensors.
        """
        # Latent posterior distribution.
        qf_mu, qf_cov = self.get_latent_dists(x, y, mask, full_cov=full_cov,
                                              **kwargs)[:2]

        if full_cov:
            # Use GPyTorch MultivariateNormal class for sampling using the
            # full covariance matrix.
            qf = MultivariateNormal(qf_mu, lazify(qf_cov))
            samples = [qf.sample() for _ in range(num_samples)]
        else:
            qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                       for _ in range(num_samples)]

        return samples


class TitsiasSparseGPVAE(GPVAE):
    """VAE with GP prior using Titsias' VFE approximation.

    There is no encoder: the likelihood term over the latent values at the
    inducing points ``z`` is given by free parameters ``lu_mu`` and
    ``lu_sigma`` of shape ``(num_inducing, latent_dim)``.

    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param kernel: the GP kernel.
    :param z: the initial inducing point locations; a 1-D tensor is treated
        as ``num_inducing`` one-dimensional points.
    :param add_jitter: whether to add jitter to the GP prior covariance matrix.
    :param initial_mu: A float or Tensor, sets the initial variational mean
        (a small JITTER-scaled random perturbation is added).
    :param initial_sigma: A float or Tensor, sets the initial variational
        sigma (a small JITTER-scaled random perturbation is added).
    :param min_sigma: A float, the offset used by
        ``get_inducing_likelihood``: ``min_sigma + (1 - min_sigma) * lu_sigma``.
    """
    def __init__(self, decoder, latent_dim, kernel, z, add_jitter=False,
                 initial_mu=0., initial_sigma=1., min_sigma=0.):
        super().__init__(None, decoder, latent_dim, kernel, add_jitter)

        self.num_inducing = z.shape[0]

        # Ensure inducing points are two-dimensional.
        if len(z.shape) == 1:
            z = z.unsqueeze(1)

        self.z = nn.Parameter(z, requires_grad=True)
        self.min_sigma = min_sigma

        # Initialise mean and sigma of the variational distribution.
        self.lu_mu = nn.Parameter(self._perturbed_init(initial_mu),
                                  requires_grad=True)
        self.lu_sigma = nn.Parameter(self._perturbed_init(initial_sigma),
                                     requires_grad=True)

    def _perturbed_init(self, value):
        """Return ``value`` broadcast to (num_inducing, latent_dim) + noise."""
        # `torch.tensor(tensor)` warns about copy-construction; this is the
        # recommended equivalent and also accepts plain Python floats.
        base = torch.as_tensor(value).clone().detach()
        return base + JITTER * torch.randn(self.num_inducing, self.latent_dim)

    def get_inducing_likelihood(self):
        """Return the variational likelihood parameters ``(lu_mu, lu_sigma)``.

        NOTE(review): the transform on ``lu_sigma`` is affine, so ``min_sigma``
        is not a lower bound and sigma can reach zero or become negative
        during optimisation; confirm whether a squashing function was
        intended.
        """
        # Get variational likelihood parameters.
        lu_mu = self.lu_mu
        lu_sigma = self.lu_sigma

        lu_sigma = self.min_sigma + (1 - self.min_sigma) * lu_sigma

        return lu_mu, lu_sigma

    def get_latent_dists(self, x, x_test=None, full_cov=False, **kwargs):
        """Compute the VFE posterior from the variational inducing terms.

        :param x: the training inputs.
        :param x_test: optional test inputs; when given, the posterior at
            ``x_test`` is returned instead of the one at ``x``.
        :param full_cov: if False, only the diagonal of the prior covariance
            at ``x`` / ``x_test`` is requested from the kernels.
        :param kwargs: ignored (accepted for interface compatibility).
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lu_mu,
            lu_cov)``.
        """
        # Likelihood terms.
        lu_mu, lu_sigma = self.get_inducing_likelihood()

        # Latent dimension first; S = diag(lu_sigma^2) per dimension.
        lu_mu = lu_mu.transpose(0, 1)
        lu_sigma = lu_sigma.transpose(0, 1)
        lu_cov = lu_sigma.pow(2).diag_embed()
        lu_precision = lu_sigma.pow(-2).diag_embed()
        lu_root_precision = lu_sigma.pow(-1).diag_embed()

        # GP prior.
        pu_mu, kuu = self.get_latent_prior(self.z)

        # See Spatio-Temporal VAEs: Sparse Approximations.
        # (Kuu + S)^{-1} = S^{-1/2} W^{-1} S^{-1/2}, W = I + S^{-1/2} Kuu S^{-1/2}.
        a = kuu.matmul(lu_root_precision)
        at = a.transpose(-1, -2)
        w = lu_root_precision.matmul(a)
        w = add_diagonal(w, 1)
        winv = w.inverse()

        if x_test is not None:
            # GP prior.
            ps_mu, kss = self.get_latent_prior(x_test, diag=(not full_cov))

            # GP conditional prior.
            ksu = self.kernels.forward(x_test, self.z)
            kus = ksu.transpose(-1, -2)

            # GP test posterior.
            b = lu_root_precision.matmul(winv.matmul(lu_root_precision))
            c = ksu.matmul(b)
            qs_cov = kss - c.matmul(kus)
            qs_mu = c.matmul(lu_mu.unsqueeze(2))
            qs_mu = qs_mu.squeeze(2)

            return qs_mu, qs_cov, ps_mu, kss
        else:
            # GP inducing point posterior.
            qu_cov = kuu - a.matmul(winv.matmul(at))
            qu_mu = qu_cov.matmul(lu_precision.matmul(lu_mu.unsqueeze(2)))
            qu_mu = qu_mu.squeeze(2)

            # GP prior.
            pf_mu, kff = self.get_latent_prior(x, diag=(not full_cov))

            # GP conditional prior.
            kfu = self.kernels.forward(x, self.z)
            kuf = kfu.transpose(-1, -2)

            # GP training posterior.
            b = lu_root_precision.matmul(winv.matmul(lu_root_precision))
            c = kfu.matmul(b)
            qf_cov = kff - c.matmul(kuf)
            qf_mu = c.matmul(lu_mu.unsqueeze(2))
            qf_mu = qf_mu.squeeze(2)

            return qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, kuu, lu_mu, lu_cov

    def sample_latent_posterior(self, x, num_samples=1, full_cov=False,
                                **kwargs):
        """Draw ``num_samples`` samples from the latent posterior.

        :param full_cov: forwarded to ``get_latent_dists``; if True, sample
            with the full covariance using the GPyTorch
            ``MultivariateNormal``, otherwise use the marginal variances.
        :param kwargs: forwarded to ``get_latent_dists`` (e.g. ``x_test``).
        :return: a list of ``num_samples`` sample tensors.
        """
        # Latent posterior distribution.
        qf_mu, qf_cov = self.get_latent_dists(x, full_cov=full_cov,
                                              **kwargs)[:2]

        if full_cov:
            # Use GPyTorch MultivariateNormal class for sampling using the
            # full covariance matrix.
            qf = MultivariateNormal(qf_mu, lazify(qf_cov))
            samples = [qf.sample() for _ in range(num_samples)]
        else:
            qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                       for _ in range(num_samples)]

        return samples


class SparseFITCGPVAE(GPVAE):
    """VAE with sparse GP prior using the FITC approximation.

    The FITC prior covariance is ``Qff + Lff`` with
    ``Qff = Kfu Kuu^{-1} Kuf`` and ``Lff = diag(Kff - Qff)``.

    :param encoder: the encoder network.
    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param z: A torch.Tensor, initial inducing point locations (deep-copied).
    :param kernel: the GP kernel.
    """
    def __init__(self, encoder, decoder, latent_dim, z,
                 kernel):
        super().__init__(encoder, decoder, latent_dim, kernel)

        self.z = nn.Parameter(copy.deepcopy(z), requires_grad=True)

    def get_latent_fitc_prior(self, x):
        """Return the FITC GP prior ``(pf_mu, pf_cov)`` at ``x``.

        Jitter is always added to Kff and Kuu here, independently of
        ``add_jitter``.
        """
        kff = self.kernels.forward(x, x, diag=True)
        kff = add_diagonal(kff, JITTER)
        kuu = self.kernels.forward(self.z, self.z)
        kuu = add_diagonal(kuu, JITTER)

        # FITC Gaussian process prior.
        kfu = self.kernels.forward(x, self.z)
        kuf = kfu.transpose(-1, -2)
        kuuinv = kuu.inverse()
        qff = kfu.matmul(kuuinv.matmul(kuf))
        # Batched over latent dimensions: diag(Kff - Qff) as diagonal matrices.
        lff = (kff.diagonal(dim1=-2, dim2=-1)
               - qff.diagonal(dim1=-2, dim2=-1)).diag_embed()

        pf_cov = qff + lff

        # Zero mean allocated alongside the covariance (same dtype / device).
        pf_mu = torch.zeros(self.latent_dim, x.shape[0], dtype=pf_cov.dtype,
                            device=pf_cov.device)

        return pf_mu, pf_cov

    def get_latent_dists(self, x, y, mask=None, x_test=None):
        """Combine the encoder likelihood terms with the FITC prior.

        :param x: the training inputs.
        :param y: the observations passed to the encoder.
        :param mask: optional mask, forwarded to the encoder when given.
        :param x_test: optional test inputs; when given, the posterior at
            ``x_test`` is returned instead of the one at ``x``.
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov)``.
        """
        # Likelihood terms.
        if mask is not None:
            lf_y_mu, lf_y_sigma = self.encoder(y, mask)
        else:
            lf_y_mu, lf_y_sigma = self.encoder(y)

        # Reshape.
        lf_y_mu = lf_y_mu.transpose(0, 1)  # [latent_dim, M]
        lf_y_sigma = lf_y_sigma.transpose(0, 1)  # [latent_dim, M]
        lf_y_cov = lf_y_sigma.pow(2).diag_embed()
        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()

        # GP prior.
        # NOTE(review): unlike get_latent_fitc_prior, no jitter is forced onto
        # Kff / Kuu here (only via add_jitter, which __init__ leaves at its
        # default); confirm this is intended.
        pf_mu, kff = self.get_latent_prior(x)
        _, kuu = self.get_latent_prior(self.z)
        kfu = self.kernels.forward(x, self.z)
        kuf = kfu.transpose(-2, -1)

        # See Spatio-Temporal VAEs: Sparse Approximations
        kuuinv = kuu.inverse()
        qff = kfu.matmul(kuuinv.matmul(kuf))
        lff = (kff.diagonal(dim1=-2, dim2=-1)
               - qff.diagonal(dim1=-2, dim2=-1)).diag_embed()
        pf_cov = qff + lff

        # Lff_hat = Lff + Sigma_phi is diagonal, so invert it elementwise.
        lffhat = lff + lf_y_cov
        lffhatinv = lffhat.diagonal(dim1=-2, dim2=-1).pow(-1).diag_embed()

        # A = Kuf * Lff_hat^{-1}
        a = kuf.matmul(lffhatinv)
        at = a.transpose(-1, -2)

        # W = Kuu + Kuf * Lff_hat^{-1} * Kfu
        w = kuu + a.matmul(kfu)
        winv = w.inverse()

        # (Q_ff + Lambda_ff + Sigma_phi)^{-1}, via the Woodbury identity.
        sum_cov_inv = lffhatinv - at.matmul(winv.matmul(a))

        if x_test is not None:
            # GP prior
            ps_mu, kss = self.get_latent_prior(x_test)
            # NOTE(review): this uses the exact cross-covariance Ksf rather
            # than the FITC approximation Qsf = Ksu Kuu^{-1} Kuf; confirm.
            ksf = self.kernels.forward(x_test, x)
            kfs = ksf.transpose(-1, -2)

            # GP test posterior.
            qs_cov = kss - ksf.matmul(sum_cov_inv.matmul(kfs))
            qs_mu = ksf.matmul(sum_cov_inv.matmul(lf_y_mu.unsqueeze(2)))
            qs_mu = qs_mu.squeeze(2)

            return qs_mu, qs_cov, ps_mu, kss

        else:
            # GP training posterior: P - P (P + Sigma_phi)^{-1} P with
            # P = Qff + Lff, and mean qf_cov Sigma_phi^{-1} mu.
            qf_cov = pf_cov - pf_cov.matmul(sum_cov_inv.matmul(pf_cov))
            qf_mu = qf_cov.matmul(lf_y_precision.matmul(lf_y_mu.unsqueeze(2)))
            qf_mu = qf_mu.squeeze(2)

            return qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov


class HybridGPVAE(nn.Module):
    """VAE with a mixture of GP and Gaussian latent prior.

    The latent space is the concatenation of ``gp_latent_dim`` dimensions
    with GP priors and ``sn_latent_dim`` dimensions with a standard normal
    prior.

    :param gp_encoder: the encoder network for latent GPs.
    :param sn_encoder: the encoder network for standard normal latent
        dimensions.
    :param decoder: the decoder network.
    :param gp_latent_dim: the number of latent GPs.
    :param sn_latent_dim: the number of standard normal latent dimensions.
    :param kernel: the GP kernel, or a list of ``gp_latent_dim`` kernels.
    :param add_jitter: whether to add jitter to the GP prior covariance matrix.
    """
    def __init__(self, gp_encoder, sn_encoder, decoder, gp_latent_dim,
                 sn_latent_dim, kernel, add_jitter=False):
        super().__init__()

        self.gp_encoder = gp_encoder
        self.sn_encoder = sn_encoder
        self.decoder = decoder
        self.gp_latent_dim = gp_latent_dim
        self.sn_latent_dim = sn_latent_dim
        self.add_jitter = add_jitter

        if not isinstance(kernel, list):
            kernels = [copy.deepcopy(kernel) for _ in range(gp_latent_dim)]
            self.kernels = KernelList(kernels)

        else:
            assert len(kernel) == gp_latent_dim, 'Number of kernels must be ' \
                                                 'equal to the latent ' \
                                                 'dimension.'
            self.kernels = KernelList(copy.deepcopy(kernel))

    def get_latent_gp_prior(self, x, diag=False):
        """Return the GP prior ``(mf, kff)`` for the GP latents at ``x``."""
        kff = self.kernels.forward(x, x, diag)

        if self.add_jitter:
            # Add jitter to improve condition number.
            kff = add_diagonal(kff, JITTER)

        # Zero mean with the covariance's dtype / device.
        mf = torch.zeros(self.gp_latent_dim, x.shape[0], dtype=kff.dtype,
                         device=kff.device)

        return mf, kff

    def get_latent_sn_prior(self, x):
        """Return the standard normal prior ``(pf_mu, pf_cov)`` at ``x``."""
        pf_mu = torch.zeros(self.sn_latent_dim, x.shape[0], device=x.device)
        pf_cov = torch.ones(self.sn_latent_dim, x.shape[0],
                            device=x.device).diag_embed()

        return pf_mu, pf_cov

    def get_latent_gp_dists(self, x, y, mask=None, x_test=None):
        """Compute the GP posterior for the GP latents (GPML section 3.4.3).

        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov)``.
        """
        # Likelihood terms.
        if mask is not None:
            lf_y_mu, lf_y_sigma = self.gp_encoder(y, mask)
        else:
            lf_y_mu, lf_y_sigma = self.gp_encoder(y)

        # Latent dimension first; S = diag(sigma^2) per dimension.
        lf_y_mu = lf_y_mu.transpose(0, 1)
        lf_y_sigma = lf_y_sigma.transpose(0, 1)
        lf_y_cov = lf_y_sigma.pow(2).diag_embed()
        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()
        lf_y_root_precision = lf_y_sigma.pow(-1).diag_embed()

        # GP prior.
        pf_mu, kff = self.get_latent_gp_prior(x)

        # (Kff + S)^{-1} = S^{-1/2} W^{-1} S^{-1/2}, W = I + S^{-1/2} Kff S^{-1/2}.
        a = kff.matmul(lf_y_root_precision)
        at = a.transpose(-1, -2)
        w = lf_y_root_precision.matmul(a)
        w = add_diagonal(w, 1)
        winv = w.inverse()

        if x_test is not None:
            # GP prior (this class has no `get_latent_prior`; use the GP one).
            ps_mu, kss = self.get_latent_gp_prior(x_test)

            # GP conditional prior.
            ksf = self.kernels.forward(x_test, x)
            kfs = ksf.transpose(-1, -2)

            # GP test posterior.
            b = lf_y_root_precision.matmul(winv.matmul(lf_y_root_precision))
            c = ksf.matmul(b)
            qs_cov = kss - c.matmul(kfs)
            qs_mu = c.matmul(lf_y_mu.unsqueeze(2))
            qs_mu = qs_mu.squeeze(2)

            return qs_mu, qs_cov, ps_mu, kss
        else:
            # GP training posterior.
            qf_cov = kff - a.matmul(winv.matmul(at))
            qf_mu = qf_cov.matmul(lf_y_precision.matmul(lf_y_mu.unsqueeze(2)))
            qf_mu = qf_mu.squeeze(2)

            return qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov

    def get_latent_sn_dists(self, x, y, mask=None):
        """Return ``(qf_mu, qf_cov, pf_mu, pf_cov)`` for the normal latents."""
        # Posterior.
        if mask is not None:
            qf_mu, qf_sigma = self.sn_encoder(y, mask)
        else:
            qf_mu, qf_sigma = self.sn_encoder(y)

        # Reshape.
        qf_mu = qf_mu.transpose(0, 1)
        qf_sigma = qf_sigma.transpose(0, 1)
        qf_cov = qf_sigma.pow(2).diag_embed()

        # Prior.
        pf_mu, pf_cov = self.get_latent_sn_prior(x)

        return qf_mu, qf_cov, pf_mu, pf_cov

    def sample_latent_gp_posterior(self, x, y=None, mask=None, num_samples=1,
                                   full_cov=True, **kwargs):
        """Sample the GP latents: posterior if ``y`` is given, else prior."""
        # Latent posterior distribution.
        if y is not None:
            qf_mu, qf_cov = self.get_latent_gp_dists(x, y, mask, **kwargs)[:2]
        else:
            qf_mu, qf_cov = self.get_latent_gp_prior(x)

        if full_cov:
            # Use GPyTorch MultivariateNormal class for sampling using the
            # full covariance matrix.
            qf = MultivariateNormal(qf_mu, lazify(qf_cov))
            samples = [qf.sample() for _ in range(num_samples)]
        else:
            qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
            samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                       for _ in range(num_samples)]

        return samples

    def sample_latent_sn_posterior(self, x, y=None, mask=None,
                                   num_samples=1, **kwargs):
        """Sample the normal latents: posterior if ``y`` is given, else prior.

        Extra ``kwargs`` are accepted and ignored.
        """
        if y is not None:
            # Latent posterior distribution.
            qf_mu, qf_cov = self.get_latent_sn_dists(x, y, mask)[:2]
        else:
            # Latent posterior distribution is the prior.
            qf_mu, qf_cov = self.get_latent_sn_prior(x)

        qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
        samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                   for _ in range(num_samples)]

        return samples

    def predict_y(self, **kwargs):
        """Predict outputs by decoding joint latent samples.

        ``kwargs`` are forwarded to both latent samplers.

        :return: ``(py_f_mu, py_f_sigma, py_f_samples)``: the mean of the
            decoder means, the standard deviation of the output samples, and
            the list of output samples.
        """
        # Sample latent posterior distributions.
        f_gp_samples = self.sample_latent_gp_posterior(**kwargs)
        f_sn_samples = self.sample_latent_sn_posterior(**kwargs)

        py_f_mus, py_f_sigmas, py_f_samples = [], [], []
        for f_gp, f_sn in zip(f_gp_samples, f_sn_samples):
            # Latent sample.
            f = torch.cat([f_gp, f_sn], dim=0)

            # Output conditional posterior distribution.
            py_f_mu, py_f_sigma = self.decoder(f.transpose(0, 1))
            py_f_mus.append(py_f_mu)
            py_f_sigmas.append(py_f_sigma)
            py_f_samples.append(
                py_f_mu + py_f_sigma * torch.randn_like(py_f_mu))

        py_f_mu = torch.stack(py_f_mus).mean(0).detach()
        # NOTE(review): torch.std is unbiased by default, so this is NaN when
        # only one sample is drawn.
        py_f_sigma = torch.stack(py_f_samples).std(0).detach()

        return py_f_mu, py_f_sigma, py_f_samples


class GPPVAE(VAE):
    """VAE with GP prior and fully-factorised approximate posterior.

    At the training inputs the approximate posterior is the encoder output
    itself, ``q(f) = N(lf_y_mu, diag(lf_y_sigma^2))``.

    :param encoder: the encoder network.
    :param decoder: the decoder network.
    :param latent_dim: the dimension of latent space.
    :param kernel: the GP kernel, or a list of ``latent_dim`` kernels.
    :param add_jitter: whether to add jitter to the GP prior covariance matrix.
    """
    def __init__(self, encoder, decoder, latent_dim, kernel, add_jitter=False):
        super().__init__(encoder, decoder, latent_dim)

        self.add_jitter = add_jitter

        if not isinstance(kernel, list):
            kernels = [copy.deepcopy(kernel) for _ in range(latent_dim)]
            self.kernels = KernelList(kernels)

        else:
            assert len(kernel) == latent_dim, 'Number of kernels must be ' \
                                              'equal to the latent dimension.'
            self.kernels = KernelList(copy.deepcopy(kernel))

    def get_latent_prior(self, x, diag=False):
        """Return the GP prior ``(mf, kff)`` evaluated at ``x``."""
        kff = self.kernels.forward(x, x, diag)

        if self.add_jitter:
            # Add jitter to improve condition number.
            kff = add_diagonal(kff, JITTER)

        # Zero mean with the covariance's dtype / device.
        mf = torch.zeros(self.latent_dim, x.shape[0], dtype=kff.dtype,
                         device=kff.device)

        return mf, kff

    def get_latent_dists(self, x, y, mask=None, x_test=None):
        """Return the approximate posterior and the GP prior.

        :param x: the training inputs.
        :param y: the observations passed to the encoder.
        :param mask: optional mask, forwarded to the encoder when given.
        :param x_test: optional test inputs; when given, ``q(f)`` is pushed
            through the GP conditional ``p(f_s | f)``.
        :return: ``(qs_mu, qs_cov, ps_mu, kss)`` when ``x_test`` is given,
            otherwise ``(lf_y_mu, lf_y_cov, pf_mu, kff)``.
        """
        # Likelihood terms.
        if mask is not None:
            lf_y_mu, lf_y_sigma = self.encoder(y, mask)
        else:
            lf_y_mu, lf_y_sigma = self.encoder(y)

        # Reshape.
        lf_y_mu = lf_y_mu.transpose(0, 1)
        lf_y_sigma = lf_y_sigma.transpose(0, 1)
        lf_y_cov = lf_y_sigma.pow(2).diag_embed()

        # GP prior.
        pf_mu, kff = self.get_latent_prior(x)

        if x_test is None:
            return lf_y_mu, lf_y_cov, pf_mu, kff

        # GP prior.
        ps_mu, kss = self.get_latent_prior(x_test)

        # GP conditional prior.
        ksf = self.kernels.forward(x_test, x)
        kfs = ksf.transpose(-1, -2)
        kff_inv = kff.inverse()

        # GP test posterior. With q(f) = N(m, S) and
        # p(f_s | f) = N(Ksf Kff^{-1} f, Kss - Ksf Kff^{-1} Kfs):
        #   q(f_s) = N(Ksf Kff^{-1} m,
        #              Kss - Ksf Kff^{-1} Kfs + Ksf Kff^{-1} S Kff^{-1} Kfs).
        proj = ksf.matmul(kff_inv)
        qs_mu = proj.matmul(lf_y_mu.unsqueeze(2))
        qs_cov = kss - proj.matmul(kfs) + proj.matmul(
            lf_y_cov).matmul(kff_inv).matmul(kfs)
        qs_mu = qs_mu.squeeze(2)

        return qs_mu, qs_cov, ps_mu, kss

    def sample_latent_posterior(self, x, y, mask=None, num_samples=1,
                                **kwargs):
        """Draw ``num_samples`` samples using the marginal variances only.

        :param kwargs: forwarded to ``get_latent_dists`` (e.g. ``x_test``).
        :return: a list of ``num_samples`` sample tensors.
        """
        # Latent posterior distribution.
        qf_mu, qf_cov = self.get_latent_dists(x, y, mask, **kwargs)[:2]

        qf_sigma = qf_cov.diagonal(dim1=-2, dim2=-1) ** 0.5
        samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)
                   for _ in range(num_samples)]

        return samples
