import torch
import numpy as np

from ..utils.gaussian_utils import gaussian_diagonal_ll, gaussian_diagonal_kl

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.lazy import lazify

# Names exported by `from <module> import *`.
# NOTE(review): `new_td_estimator` and `elbo_estimator2` are defined in this
# module but not listed here; confirm whether they are meant to be public.
__all__ = ['sf_estimator', 'td_estimator', 'analytical_estimator',
           'vfe_td_estimator', 'vfe_analytical_estimator', 'elbo_estimator',
           'vfe_elbo_estimator', 'iwae_estimator', 'vfe_iwae_estimator',
           'hybrid_elbo_estimator', 'hybrid_td_estimator',
           'conditional_td_estimator', 'conditional_analytical_estimator']


def sf_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True):
    """Estimates the negative ELBO using the score function trick for models
    with an encoder/decoder architecture, a GP prior over latent variables
    and an approximate posterior of the form q(f) = 1/Z p(f)l(f|y).

    Samples are drawn with ``qf.sample()`` (no reparameterisation), so the
    returned value is a surrogate objective: its gradient is the
    score-function estimate, but its value is not itself the negative ELBO.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    when mask is None it is always set to 1.
    :param make_lazy: A bool, whether to wrap the covariance matrices with
    GPyTorch's ``lazify`` before building the MultivariateNormal objects.
    :return: A scalar torch.Tensor, the negated surrogate averaged over
    samples and divided by x.shape[0].
    """
    if mask is not None:
        # Scale decoder terms by the reciprocal of the proportion of observed
        # entries. Missing entries are counted as abs(1 - mask), consistent
        # with the other estimators in this module.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(abs(1 - mask))
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. / (1. - num_nan / num_observations)
    else:
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(
        x, y, mask, num_samples=num_samples)

    # Required distributions.
    if make_lazy:
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
    else:
        pf = MultivariateNormal(pf_mu, pf_cov)
        qf = MultivariateNormal(qf_mu, qf_cov)

    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # f carries no gradient; the mean and covariance of q(f) receive
        # gradient through the log q(f) score term below.
        f = qf.sample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()
        estimator += py_f_term

        # log l(f|y) term, only used (detached) inside the score weight.
        lf_y_term = gaussian_diagonal_ll(f, lf_y_mu, lf_y_var).sum()

        # log q(f) score term. Evaluated under the same full-covariance q(f)
        # that f was drawn from, and reduced to a scalar so the estimator
        # stays a scalar.
        qf_term = qf.log_prob(f).sum()
        estimator += qf_term * (py_f_term.detach() - lf_y_term.detach())

        # log p(f) term.
        pf_term = pf.log_prob(f).sum()
        estimator += pf_term

    # Inner summation over samples from q(f).
    estimator /= num_samples

    # Outer summation over batch.
    estimator /= x.shape[0]

    return - estimator


def td_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True, mf=False, idx=None):
    """Negative-ELBO objective for the GPVAE model whose gradient is the total
    derivative of the reparameterisation trick.

    Each sample comes from ``qf.rsample()``. The parameters of l(f|y) and the
    sample fed to p(f) are detached. The l(f|y) term therefore contributes
    gradient only through the sample, and the p(f) term only through the
    prior.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (and so is queried by ``idx`` rather than by ``y``).
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate averaged over samples and
    divided by x.shape[0].
    """
    # Without a mask the scale is 1; with one, default to the reciprocal of
    # the observed fraction (missing entries are counted as abs(1 - mask)).
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        missing = 1. * torch.sum(abs(1 - mask))
        total_entries = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - missing / total_entries)

    # Mean-field models are looked up by index; amortised models encode y.
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)
    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = latents

    as_cov = lazify if make_lazy else (lambda cov: cov)
    qf = MultivariateNormal(qf_mu, as_cov(qf_cov))
    pf = MultivariateNormal(pf_mu, as_cov(pf_cov))

    # l(f|y) is treated as a constant apart from its dependence on f.
    lf_y_mu_const = lf_y_mu.detach()
    lf_y_var_const = torch.stack([cov.diag() for cov in lf_y_cov]).detach()

    # Monte-Carlo estimate (see Spatio-Temporal VAEs: ELBO Gradient
    # Estimators).
    running = 0
    for _ in range(num_samples):
        f = qf.rsample()

        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = decoder_scale * gaussian_diagonal_ll(
            y, dec_mu, dec_sigma.pow(2), mask).sum()
        log_lf_y = gaussian_diagonal_ll(f, lf_y_mu_const,
                                        lf_y_var_const).sum()
        log_pf = pf.log_prob(f.detach()).sum()

        running = running + log_py_f - log_lf_y + log_pf

    # Average over samples, then over the batch.
    running = running / num_samples
    running = running / x.shape[0]
    return -running


def new_td_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                     make_lazy=True, mf=False, idx=None):
    """Total-derivative reparameterisation estimate of the negative ELBO for
    the GPVAE model.

    The computation currently matches ``td_estimator``:
    - samples come from ``qf.rsample()``;
    - l(f|y) is evaluated with detached parameters;
    - p(f) is evaluated at a detached sample.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (queried by ``idx``) or amortisation (queried by ``y``).
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate averaged over samples and
    divided by x.shape[0].
    """
    if mask is not None and decoder_scale is None:
        # Reciprocal of the observed fraction; abs(1 - mask) counts missing.
        n_missing = 1. * torch.sum(abs(1 - mask))
        observed_frac = 1. - n_missing / (y.shape[0] * y.shape[1])
        decoder_scale = 1. / observed_frac
    elif mask is None:
        decoder_scale = 1.

    get_dists = model.get_latent_dists
    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = (
        get_dists(x, idx=idx) if mf else get_dists(x, y, mask))

    if make_lazy:
        qf_cov_arg, pf_cov_arg = lazify(qf_cov), lazify(pf_cov)
    else:
        qf_cov_arg, pf_cov_arg = qf_cov, pf_cov
    qf = MultivariateNormal(qf_mu, qf_cov_arg)
    pf = MultivariateNormal(pf_mu, pf_cov_arg)

    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    def one_sample():
        """Draws f ~ q(f); returns the scaled log p(y|f), log l(f|y) and
        log p(f), each summed to a scalar."""
        f = qf.rsample()
        out_mu, out_sigma = model.decoder(f.transpose(0, 1))
        recon = gaussian_diagonal_ll(y, out_mu, out_sigma.pow(2), mask)
        lik = gaussian_diagonal_ll(f, lf_y_mu.detach(), lf_y_var.detach())
        prior = pf.log_prob(f.detach())
        return decoder_scale * recon.sum(), lik.sum(), prior.sum()

    estimate = 0
    for _ in range(num_samples):
        recon, lik, prior = one_sample()
        estimate = estimate + recon - lik + prior

    # Average over samples, then over the batch, and negate.
    return -(estimate / num_samples / x.shape[0])


def analytical_estimator(model, x, y, mask=None, num_samples=1,
                         decoder_scale=None, make_lazy=True, mf=False,
                         idx=None):
    """Estimates the negative ELBO for models with an encoder/decoder
    architecture, a GP prior over latent variables and an approximate
    posterior of the form q(f) = 1/Z p(f)l(f|y).

    Analytical results are used where possible, and the reparameterisation
    trick is used for the decoder term. The ELBO is computed as
    E_q[log p(y|f)] - E_q[log l(f|y)] + log Z, with
    Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). Only the decoder term is
    estimated by Monte-Carlo.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance with GPyTorch's
    ``lazify`` before building the MultivariateNormal object.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate divided by x.shape[0].
    """
    if mask is not None:
        # Scale decoder terms by the reciprocal of the proportion of observed
        # entries (missing entries are counted as abs(1 - mask)).
        if decoder_scale is None:
            num_nan = 1. * torch.sum(abs(1 - mask))
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. / (1. - num_nan / num_observations)
    else:
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y, mask)

    # Distribution whose density at pf_mu is the normaliser Z of q(f).
    sum_cov = pf_cov + lf_y_cov
    if make_lazy:
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of the decoder term.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Reparameterised sample from the elementwise marginals of q(f)
        # (mean qf_mu, variance diag(qf_cov)); off-diagonal covariance is not
        # sampled. NOTE(review): this matches E_q[log p(y|f)] only if the
        # decoder treats each row of f.transpose(0, 1) independently.
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()
        estimator += py_f_term

    # Inner summation over samples from q(f).
    estimator /= num_samples

    # E_q[log l(f|y)] in closed form for a diagonal l(f|y). This is the log
    # density at the mean of q(f) minus half the ratio of the variances.
    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()
    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()
    estimator += - lf_y_term

    # log Z term, with Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). Evaluating
    # at pf_mu keeps this correct for a non-zero prior mean. It equals the
    # value at zero when pf_mu == 0.
    zq_term = zq.log_prob(pf_mu).sum()
    estimator += zq_term

    # Outer summation over batch.
    estimator /= x.shape[0]

    return - estimator


def vfe_td_estimator(model, x, y, mask=None, num_samples=1,
                     decoder_scale=None, make_lazy=True, mf=False, idx=None):
    """Negative-ELBO objective for the SparseGPVAE model whose gradient is
    the total derivative of the reparameterisation trick.

    Each iteration draws two samples:
    - f elementwise from the diagonal of q(f);
    - u from the full q(u) via ``qu.rsample()``.

    l(u|y) is evaluated with detached parameters and p(u) at a detached u.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (queried by ``idx``) or amortisation (queried by ``y``).
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate averaged over samples and
    divided by x.shape[0].
    """
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the observed fraction; abs(1 - mask) counts missing.
        n_missing = 1. * torch.sum(abs(1 - mask))
        n_entries = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - n_missing / n_entries)

    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)
    (qf_mu, qf_cov, qu_mu, qu_cov,
     pu_mu, pu_cov, lu_y_mu, lu_y_cov) = latents

    if make_lazy:
        qu_cov_arg, pu_cov_arg = lazify(qu_cov), lazify(pu_cov)
    else:
        qu_cov_arg, pu_cov_arg = qu_cov, pu_cov
    qu = MultivariateNormal(qu_mu, qu_cov_arg)
    pu = MultivariateNormal(pu_mu, pu_cov_arg)

    qf_std = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5
    lu_y_mu_const = lu_y_mu.detach()
    lu_y_var_const = torch.stack([cov.diag() for cov in lu_y_cov]).detach()

    # Monte-Carlo estimate (see Spatio-Temporal VAEs: ELBO Gradient
    # Estimators).
    acc = 0
    for _ in range(num_samples):
        f = qf_mu + qf_std * torch.randn_like(qf_mu)
        u = qu.rsample()

        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = decoder_scale * gaussian_diagonal_ll(
            y, dec_mu, dec_sigma.pow(2), mask).sum()
        log_lu_y = gaussian_diagonal_ll(u, lu_y_mu_const,
                                        lu_y_var_const).sum()
        log_pu = pu.log_prob(u.detach()).sum()

        acc = acc + log_py_f - log_lu_y + log_pu

    acc = acc / num_samples
    acc = acc / x.shape[0]
    return -acc


def vfe_analytical_estimator(model, x, y, mask=None, num_samples=1,
                             decoder_scale=None, make_lazy=True, mf=False,
                             idx=None):
    """Estimates the negative ELBO using analytical results where possible
    for the SparseGPVAE model.

    Assuming q(u) = 1/Z p(u)l(u|y), the objective is
    E_q(f)[log p(y|f)] - E_q(u)[log l(u|y)] + log Z, with
    Z = N(pu_mu; lu_y_mu, pu_cov + lu_y_cov). Only the decoder term is
    estimated by Monte-Carlo.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values;
    forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance with GPyTorch's
    ``lazify`` before building the MultivariateNormal object.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate divided by x.shape[0].
    """
    if mask is not None:
        # Scale decoder terms by the reciprocal of the proportion of observed
        # entries (missing entries are counted as abs(1 - mask)).
        if decoder_scale is None:
            num_nan = 1. * torch.sum(abs(1 - mask))
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. / (1. - num_nan / num_observations)
    else:
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, lu_y_mu, lu_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, lu_y_mu, lu_y_cov = \
            model.get_latent_dists(x, y, mask)

    # Distribution whose density at pu_mu is the normaliser Z.
    sum_cov = pu_cov + lu_y_cov
    if make_lazy:
        zq = MultivariateNormal(lu_y_mu, lazify(sum_cov))
    else:
        zq = MultivariateNormal(lu_y_mu, sum_cov)

    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    qu_var = torch.stack([cov.diag() for cov in qu_cov])
    lu_y_var = torch.stack([cov.diag() for cov in lu_y_cov])

    # Monte-Carlo estimate of the decoder term.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Reparameterised sample from the elementwise marginals of q(f).
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()
        estimator += py_f_term

    # Inner summation over samples from q(f).
    estimator /= num_samples

    # E_q[log l(u|y)] in closed form for a diagonal l(u|y). This is the log
    # density at the mean of q(u) minus half the ratio of the variances.
    lu_y_term = gaussian_diagonal_ll(qu_mu, lu_y_mu, lu_y_var).sum()
    lu_y_term += - 0.5 * (qu_var / lu_y_var).sum()
    estimator += - lu_y_term

    # log Z term, with Z = N(pu_mu; lu_y_mu, pu_cov + lu_y_cov). Evaluating
    # at pu_mu keeps this correct for a non-zero prior mean. It equals the
    # value at zero when pu_mu == 0.
    zq_term = zq.log_prob(pu_mu).sum()
    estimator += zq_term

    # Outer summation over batch.
    estimator /= x.shape[0]

    return - estimator


def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                   mf=False, idx=None):
    """Estimates the ELBO using analytical results where possible for the
    GPVAE model.

    ELBO = E_q[log p(y|f)] - E_q[log l(f|y)] + log Z, with
    Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). The decoder term is estimated
    by Monte-Carlo using ``qf.rsample()``; the other two terms are computed
    in closed form. No decoder scaling is applied.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    with.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the ELBO estimate (not negated and not divided
    by the batch size).
    """
    elbo = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y, mask)

    sum_cov = pf_cov + lf_y_cov

    # q(f) for sampling, and the distribution whose density at pf_mu is Z.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of the decoder term.
    # See Spatio-Temporal VAEs: ELBO
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()
        elbo += py_f_term

    # Inner summation over samples from q(f).
    elbo /= num_samples

    # E_q[log l(f|y)] in closed form for a diagonal l(f|y).
    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()
    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()
    elbo += - lf_y_term

    # log Z term, with Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). It equals
    # the value at zero when pf_mu == 0.
    zq_term = zq.log_prob(pf_mu).sum()
    elbo += zq_term

    return elbo


def elbo_estimator2(model, x, y, mask=None, num_samples=1, make_lazy=True,
                    mf=False, idx=None):
    """Estimates the ELBO for models with an encoder/decoder architecture, a
    GP prior over latent variables and an approximate posterior of the form
    q(f) = 1/Z p(f)l(f|y).

    ELBO = E_q[log p(y|f) - log l(f|y)] + log Z, with
    Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). Unlike ``elbo_estimator``, the
    l(f|y) term is also estimated by Monte-Carlo with reparameterised samples
    from q(f). No decoder scaling is applied.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    with.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the ELBO estimate (not negated and not divided
    by the batch size).
    """
    elbo = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y, mask)

    sum_cov = pf_cov + lf_y_cov

    # q(f) for sampling, and the distribution whose density at pf_mu is Z.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of E_q[log p(y|f) - log l(f|y)].
    # See Spatio-Temporal VAEs: ELBO
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()
        elbo += py_f_term

        # log l(f|y) term.
        lf_y_term = gaussian_diagonal_ll(f, lf_y_mu, lf_y_var)
        lf_y_term = lf_y_term.sum()
        elbo += - lf_y_term

    # Inner summation over samples from q(f).
    elbo /= num_samples

    # log Z term, with Z = N(pf_mu; lf_y_mu, pf_cov + lf_y_cov). It equals
    # the value at zero when pf_mu == 0.
    zq_term = zq.log_prob(pf_mu).sum()
    elbo += zq_term

    return elbo


def vfe_elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                       mf=False, idx=None):
    """Estimates the ELBO for the SparseGPVAE model using analytical results
    where possible and the reparameterisation trick for the decoder term.

    Assuming q(u) = 1/Z p(u)l(u|y), the ELBO is computed as
    E_q(f)[log p(y|f)] - E_q(u)[log l(u|y)] + log Z, with
    Z = N(pu_mu; lu_y_mu, pu_cov + lu_y_cov). Only the decoder term is
    estimated by Monte-Carlo, using samples from the elementwise marginals of
    q(f). No decoder scaling is applied.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    with.
    :param make_lazy: A bool, whether to wrap the covariance with GPyTorch's
    ``lazify`` before building the MultivariateNormal object.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the ELBO estimate (not negated and not divided
    by the batch size).
    """
    elbo = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, lu_y_mu, lu_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, lu_y_mu, lu_y_cov = \
            model.get_latent_dists(x, y, mask)

    # Distribution whose density at pu_mu is the normaliser Z.
    sum_cov = pu_cov + lu_y_cov
    if make_lazy:
        zq = MultivariateNormal(lu_y_mu, lazify(sum_cov))
    else:
        zq = MultivariateNormal(lu_y_mu, sum_cov)

    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    qu_var = torch.stack([cov.diag() for cov in qu_cov])
    lu_y_var = torch.stack([cov.diag() for cov in lu_y_cov])

    # Monte-Carlo estimate of the decoder term.
    # See Spatio-Temporal VAEs: Sparse Approximations.
    for _ in range(num_samples):
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()
        elbo += py_f_term

    # Inner summation over samples from q(f).
    elbo /= num_samples

    # E_q[log l(u|y)] in closed form for a diagonal l(u|y).
    lu_y_term = gaussian_diagonal_ll(qu_mu, lu_y_mu, lu_y_var).sum()
    lu_y_term += - 0.5 * (qu_var / lu_y_var).sum()
    elbo += - lu_y_term

    # log Z term, with Z = N(pu_mu; lu_y_mu, pu_cov + lu_y_cov). It equals
    # the value at zero when pu_mu == 0.
    zq_term = zq.log_prob(pu_mu).sum()
    elbo += zq_term

    return elbo


def iwae_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                   mf=False, idx=None):
    """Estimates the importance weighted (IWAE) lower bound on log p(y) for
    the GPVAE model, using q(f) as the proposal:

        log (1/K) sum_k p(y|f_k) p(f_k) / q(f_k),   f_k ~ q(f).

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, K, the number of importance samples. Must be
    at least 1.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices.
    :return: A scalar torch.Tensor, the IWAE bound (not negated and not
    divided by the batch size).
    :raises ValueError: If num_samples < 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1.')

    # Latent distributions; only q(f) and p(f) are needed here.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, idx=idx)[:4]
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)[:4]

    # Required distributions.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Importance weighted Monte-Carlo estimate of the bound.
    log_weights = []
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()

        # log p(f) and log q(f) terms.
        pf_term = pf.log_prob(f).sum()
        qf_term = qf.log_prob(f).sum()

        log_weights.append(pf_term + py_f_term - qf_term)

    # Stacking keeps the log-weights on the model's device and dtype, rather
    # than copying them into a pre-allocated CPU float32 buffer.
    iwae = (torch.logsumexp(torch.stack(log_weights), dim=0)
            - np.log(num_samples))

    return iwae


def vfe_iwae_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                       mf=False, idx=True):
    """Estimates the importance weighted (IWAE) lower bound on log p(y) for
    the SparseGPVAE model.

    q(f) from ``model.get_latent_dists`` is the proposal, and the prior comes
    from ``model.get_latent_prior(x)``:

        log (1/K) sum_k p(y|f_k) p(f_k) / q(f_k),   f_k ~ q(f).

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, K, the number of importance samples. Must be
    at least 1.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices. NOTE(review): the default
    here is True, unlike the None default of the other estimators. With
    mf=True and no idx given, True is forwarded to get_latent_dists; confirm
    this is intended.
    :return: A scalar torch.Tensor, the IWAE bound (not negated and not
    divided by the batch size).
    :raises ValueError: If num_samples < 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1.')

    # Approximate posterior over f.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov = model.get_latent_dists(x, idx=idx)[:2]
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov = model.get_latent_dists(x, y, mask)[:2]

    # Prior over f.
    pf_mu, pf_cov = model.get_latent_prior(x)

    # Required distributions.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Importance weighted Monte-Carlo estimate of the bound.
    log_weights = []
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()

        # log p(f) and log q(f) terms.
        pf_term = pf.log_prob(f).sum()
        qf_term = qf.log_prob(f).sum()

        log_weights.append(pf_term + py_f_term - qf_term)

    # Stacking keeps the log-weights on the model's device and dtype, rather
    # than copying them into a pre-allocated CPU float32 buffer.
    iwae = (torch.logsumexp(torch.stack(log_weights), dim=0)
            - np.log(num_samples))

    return iwae


def hybrid_td_estimator(model, x, y, mask=None, num_samples=1,
                        decoder_scale=None, make_lazy=True, mf=False,
                        idx=None):
    """Negative-ELBO objective for a model whose latents are split into a GP
    part (``*_gp``) and a second part (``*_sn``). The ``*_sn`` part is
    handled only through the diagonals of its covariances.

    GP latents use the total-derivative reparameterisation estimator:
    - they are drawn with ``qf_gp.rsample()``;
    - l(f_gp|y) is evaluated with detached parameters;
    - p(f_gp) is evaluated at a detached sample.

    The ``*_sn`` latents are reparameterised elementwise and scored under
    diagonal Gaussians for both q and p, without detaching. The two samples
    are concatenated along dim 0 before decoding.

    :param model: A nn.Module, the model to evaluate on; it must provide
    ``get_latent_gp_dists``, ``get_latent_sn_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by; forced to 1 when mask is None.
    :param make_lazy: A bool, whether to wrap the GP covariances with
    GPyTorch's ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (queried by ``idx``) or amortisation (queried by ``y``).
    :param idx: A torch.Tensor, the data indices.
    :return: A torch.Tensor, the negated estimate averaged over samples and
    divided by x.shape[0].
    """
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the observed fraction; abs(1 - mask) counts missing.
        num_missing = 1. * torch.sum(abs(1 - mask))
        num_entries = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - num_missing / num_entries)

    # Both latent groups are queried with the same arguments.
    if mf:
        latent_args, latent_kwargs = (x,), {'idx': idx}
    else:
        latent_args, latent_kwargs = (x, y, mask), {}
    (qf_gp_mu, qf_gp_cov, pf_gp_mu, pf_gp_cov,
     lf_gp_y_mu, lf_gp_y_cov) = model.get_latent_gp_dists(*latent_args,
                                                          **latent_kwargs)
    qf_sn_mu, qf_sn_cov, pf_sn_mu, pf_sn_cov = model.get_latent_sn_dists(
        *latent_args, **latent_kwargs)

    as_cov = lazify if make_lazy else (lambda cov: cov)
    qf_gp = MultivariateNormal(qf_gp_mu, as_cov(qf_gp_cov))
    pf_gp = MultivariateNormal(pf_gp_mu, as_cov(pf_gp_cov))

    def diag_stack(covs):
        """Stacks the diagonals of a sequence of square matrices."""
        return torch.stack([cov.diag() for cov in covs])

    lf_gp_y_var = diag_stack(lf_gp_y_cov)
    qf_sn_var = diag_stack(qf_sn_cov)
    pf_sn_var = diag_stack(pf_sn_cov)

    lf_gp_y_mu_const = lf_gp_y_mu.detach()
    lf_gp_y_var_const = lf_gp_y_var.detach()
    qf_sn_std = qf_sn_var ** 0.5

    # Monte-Carlo estimate (see Spatio-Temporal VAEs: ELBO Gradient
    # Estimators).
    running = 0
    for _ in range(num_samples):
        f_gp = qf_gp.rsample()
        f_sn = qf_sn_mu + qf_sn_std * torch.randn_like(qf_sn_mu)
        f_all = torch.cat([f_gp, f_sn], dim=0)

        dec_mu, dec_sigma = model.decoder(f_all.transpose(0, 1))
        log_py_f = decoder_scale * gaussian_diagonal_ll(
            y, dec_mu, dec_sigma.pow(2), mask).sum()
        log_lf_gp = gaussian_diagonal_ll(f_gp, lf_gp_y_mu_const,
                                         lf_gp_y_var_const).sum()
        log_qf_sn = gaussian_diagonal_ll(f_sn, qf_sn_mu, qf_sn_var).sum()
        log_pf_gp = pf_gp.log_prob(f_gp.detach()).sum()
        log_pf_sn = gaussian_diagonal_ll(f_sn, pf_sn_mu, pf_sn_var).sum()

        running = (running + log_py_f - log_lf_gp - log_qf_sn
                   + log_pf_gp + log_pf_sn)

    running = running / num_samples
    running = running / x.shape[0]
    return -running


def hybrid_elbo_estimator(model, x, y, mask=None, num_samples=1,
                          make_lazy=True, mf=False, idx=None):
    """Estimates the ELBO for a hybrid model whose latent variables are split
    into a GP group (``f_gp``) and a second group (``f_sn``).

    The GP group has approximate posterior q(f_gp) = 1/Z p(f_gp)l(f_gp|y).
    For the ``sn`` group and for l(f_gp|y) only the diagonals of the
    covariances are used. The decoder term E_q[log p(y|f)] is estimated by
    Monte-Carlo with the reparameterisation trick; the remaining terms are
    computed analytically:

        ELBO = E_q[log p(y|f)] - E_q[log l(f_gp|y)] + log Z_q
               - KL(q(f_sn) || p(f_sn)).

    :param model: A nn.Module, the model to evaluate on. Must provide
    ``get_latent_gp_dists``, ``get_latent_sn_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples used to estimate the
    decoder term.
    :param make_lazy: A bool, whether to wrap the covariances with GPyTorch's
    ``lazify`` before building the MultivariateNormal distributions.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only used when ``mf`` is
    True).
    :return: A scalar torch.Tensor, the ELBO estimate. Note that it is NOT
    negated and NOT divided by the batch size.
    """
    elbo = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_gp_mu, qf_gp_cov, pf_gp_mu, pf_gp_cov, lf_gp_y_mu, lf_gp_y_cov = \
            model.get_latent_gp_dists(x, idx=idx)
        qf_sn_mu, qf_sn_cov, pf_sn_mu, pf_sn_cov = model.get_latent_sn_dists(x,
                                                                             idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_gp_mu, qf_gp_cov, pf_gp_mu, pf_gp_cov, lf_gp_y_mu, lf_gp_y_cov = \
            model.get_latent_gp_dists(x, y, mask)
        qf_sn_mu, qf_sn_cov, pf_sn_mu, pf_sn_cov = model.get_latent_sn_dists(x,
                                                                             y,
                                                                             mask)

    # Covariance of the normaliser Z_q = int p(f_gp) l(f_gp|y) df_gp.
    gp_sum_cov = pf_gp_cov + lf_gp_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch MultivariateNormal class for sampling.
        qf_gp = MultivariateNormal(qf_gp_mu, lazify(qf_gp_cov))
        zq_gp = MultivariateNormal(lf_gp_y_mu, lazify(gp_sum_cov))
    else:
        qf_gp = MultivariateNormal(qf_gp_mu, qf_gp_cov)
        zq_gp = MultivariateNormal(lf_gp_y_mu, gp_sum_cov)

    # Diagonal variances: one diagonal per leading entry of each covariance.
    qf_gp_var = torch.stack([cov.diag() for cov in qf_gp_cov])
    lf_gp_y_var = torch.stack([cov.diag() for cov in lf_gp_y_cov])
    qf_sn_var = torch.stack([cov.diag() for cov in qf_sn_cov])
    pf_sn_var = torch.stack([cov.diag() for cov in pf_sn_cov])

    # Monte-Carlo estimate of the decoder term of the ELBO.
    # See Spatio-Temporal VAEs: ELBO.
    for _ in range(num_samples):
        # Reparameterised samples from q(f_gp) and the diagonal q(f_sn).
        f_gp = qf_gp.rsample()
        f_sn = qf_sn_mu + qf_sn_var ** 0.5 * torch.randn_like(qf_sn_mu)
        f = torch.cat([f_gp, f_sn], dim=0)

        # log p(y|f) term; the decoder returns a mean and a standard
        # deviation, which is squared to give the variance.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()
        elbo += py_f_term

    # Average over samples from q(f).
    elbo /= num_samples

    # E_q[log l(f_gp|y)] term, in closed form for a diagonal Gaussian l:
    # log N(mu_q; mu_l, var_l) - 0.5 * sum(var_q / var_l).
    lf_gp_y_term = gaussian_diagonal_ll(qf_gp_mu, lf_gp_y_mu,
                                        lf_gp_y_var).sum()
    lf_gp_y_term += - 0.5 * (qf_gp_var / lf_gp_y_var).sum()
    elbo += - lf_gp_y_term

    # log Zq_gp term.
    # NOTE(review): evaluating at zero equals log Z_q only when the GP prior
    # mean pf_gp_mu is zero — confirm this holds for all models used here.
    zq_gp_term = zq_gp.log_prob(torch.zeros_like(lf_gp_y_mu)).sum()
    elbo += zq_gp_term

    # KL(qf_sn || pf_sn) term.
    kl_sn_term = gaussian_diagonal_kl(qf_sn_mu, qf_sn_var, pf_sn_mu, pf_sn_var)
    kl_sn_term = kl_sn_term.sum()
    elbo += - kl_sn_term

    return elbo


def conditional_td_estimator(model, x, y, y_c, mask=None, mask_c=None,
                             num_samples=1, decoder_scale=None,
                             make_lazy=True, mf=False, idx=None):
    """Estimates the gradient of the negative ELBO using the total
    derivative of the reparameterisation trick for the GPVAE model.

    For amortised models the latent distributions are computed from the
    context data (``y_c``, ``mask_c``), while the decoder likelihood p(y|f)
    is evaluated on the target data (``y``, ``mask``).

    :param model: A nn.Module, the model to evaluate on. Must provide
    ``get_latent_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data on which p(y|f) is evaluated.
    :param y_c: A torch.Tensor, the context output data passed to
    ``model.get_latent_dists`` (only used when ``mf`` is False).
    :param mask: A torch.Tensor, the mask to apply to the output data ``y``.
    For the default ``decoder_scale`` it is assumed to be a 0/1 mask where 0
    marks a missing entry.
    :param mask_c: A torch.Tensor, the mask to apply to ``y_c`` (only used
    when ``mf`` is False).
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    If None and ``mask`` is given, it is set to the reciprocal of the
    proportion of observed entries in ``mask``.
    :param make_lazy: A bool, whether to use the GPyTorch MultivariateNormal
    class for handling multivariate Gaussians.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only used when ``mf`` is
    True).
    :return: A scalar torch.Tensor, the negated estimate averaged over
    samples and divided by the batch size ``x.shape[0]``.
    """
    if mask is not None:
        # Scale decoder terms by the reciprocal of the proportion of
        # observed (non-missing) entries.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(abs(1 - mask))
            # Count over the mask itself so that the proportion is consistent
            # with num_nan for any mask shape (equal to y.shape[0] *
            # y.shape[1] for a 2-D mask matching y).
            num_observations = mask.numel()
            decoder_scale = 1. / (1. - num_nan / num_observations)
    else:
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the context observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y_c, mask_c)

    # Required distributions.
    if make_lazy:
        # Use GPyTorch MultivariateNormal class for sampling.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Only the diagonal of each l(f|y) covariance is used.
    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term; the decoder returns a mean and a standard
        # deviation, which is squared to give the variance.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()
        estimator += py_f_term

        # log l(f|y) term. The parameters of l are detached so that this
        # term's gradient flows only through the sample f.
        lf_y_term = gaussian_diagonal_ll(f, lf_y_mu.detach(),
                                         lf_y_var.detach())
        lf_y_term = lf_y_term.sum()
        estimator += - lf_y_term

        # log p(f) term. The sample is detached so that this term's gradient
        # flows only through the prior's parameters.
        pf_term = pf.log_prob(f.detach()).sum()
        estimator += pf_term

    # Average over samples from q(f).
    estimator /= num_samples

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator


def conditional_analytical_estimator(model, x, y, y_c, mask=None, mask_c=None,
                                     num_samples=1, decoder_scale=None,
                                     make_lazy=True, mf=False, idx=None):
    """Estimates the negative ELBO using analytical results where possible,
    and the reparameterisation trick for the decoder term for models with an
    encoder/decoder architecture and GP prior over latent variables and
    approximate posterior of the form q(f) = 1/Z p(f)l(f|y).

    The estimate is

        ELBO = E_q[log p(y|f)] - E_q[log l(f|y)] + log Z_q,

    where only the first term is estimated by Monte-Carlo. For amortised
    models the latent distributions are computed from the context data
    (``y_c``, ``mask_c``), while p(y|f) is evaluated on the target data
    (``y``, ``mask``).

    :param model: A nn.Module, the model to evaluate on. Must provide
    ``get_latent_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data on which p(y|f) is evaluated.
    :param y_c: A torch.Tensor, the context output data passed to
    ``model.get_latent_dists`` (only used when ``mf`` is False).
    :param mask: A torch.Tensor, the mask to apply to the output data ``y``.
    For the default ``decoder_scale`` it is assumed to be a 0/1 mask where 0
    marks a missing entry.
    :param mask_c: A torch.Tensor, the mask to apply to ``y_c`` (only used
    when ``mf`` is False).
    :param num_samples: An int, the number of samples used to estimate the
    decoder term.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    If None and ``mask`` is given, it is set to the reciprocal of the
    proportion of observed entries in ``mask``.
    :param make_lazy: A bool, whether to use the GPyTorch MultivariateNormal
    class for handling multivariate Gaussians.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only used when ``mf`` is
    True).
    :return: A scalar torch.Tensor, the negative ELBO estimate divided by the
    batch size ``x.shape[0]``.
    """
    if mask is not None:
        # Scale decoder terms by the reciprocal of the proportion of
        # observed (non-missing) entries.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(abs(1 - mask))
            # Count over the mask itself so that the proportion is consistent
            # with num_nan for any mask shape (equal to y.shape[0] *
            # y.shape[1] for a 2-D mask matching y).
            num_observations = mask.numel()
            decoder_scale = 1. / (1. - num_nan / num_observations)
    else:
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the context observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y_c, mask_c)

    # Covariance of the normaliser Z_q = int p(f) l(f|y) df.
    sum_cov = pf_cov + lf_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch MultivariateNormal class for sampling.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    # Diagonal variances: one diagonal per leading entry of each covariance.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])

    # Monte-Carlo estimate of the decoder term.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term; the decoder returns a mean and a standard
        # deviation, which is squared to give the variance.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()
        estimator += py_f_term

    # Average over samples from q(f).
    estimator /= num_samples

    # E_q[log l(f|y)] term, in closed form for a diagonal Gaussian l:
    # log N(mu_q; mu_l, var_l) - 0.5 * sum(var_q / var_l).
    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()
    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()
    estimator += - lf_y_term

    # log Zq term.
    # NOTE(review): evaluating at zero equals log Z_q only when the prior
    # mean pf_mu is zero — confirm this holds for all models used here.
    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()
    estimator += zq_term

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator
