import torch
import numpy as np

from ..utils.gaussian_utils import gaussian_diagonal_ll, gaussian_diagonal_kl

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.distributions.multivariate_normal import kl_mvn_mvn
from gpytorch.lazy import lazify


def analytical_estimator(model, x, y, mask=None, num_samples=1,
                         decoder_scale=None, make_lazy=True, mf=False,
                         idx=None):
    """Estimates the negative ELBO, using analytical results where possible.

    The KL divergence between the approximate posterior q(f) and the GP
    prior p(f) is computed in closed form. The expected log-likelihood
    E_q[log p(y|f)] is estimated by Monte Carlo with the reparameterisation
    trick. Intended for models with an encoder/decoder architecture, a GP
    prior over latent variables and an approximate posterior of the form
    q(f) = 1/Z p(f)l(f|y).

    The returned value is

        -(decoder_scale * E_q[log p(y|f)] - KL(q(f) || p(f))) / x.shape[0]

    i.e. the negative ELBO divided by ``x.shape[0]``.

    :param model: A nn.Module, the model to evaluate on. Must provide
    ``get_latent_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data. ``x.shape[0]`` is used to
    normalise the estimate.
    :param y: A torch.Tensor, the output data. Must be at least 2D when
    ``mask`` is given and ``decoder_scale`` is None (``y.shape[0] *
    y.shape[1]`` is taken as the number of entries).
    :param mask: A torch.Tensor, the mask to apply to the output data.
    When computing the default ``decoder_scale``, ``|1 - mask|`` is summed
    to count missing entries (so 0 marks a missing entry for a binary mask).
    :param num_samples: An int >= 1, the number of samples to estimate the
    ELBO gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. If None, it defaults to the reciprocal of the
    fraction of observed entries when ``mask`` is given, and to 1 otherwise.
    An explicitly supplied value is always used as given.
    :param make_lazy: A bool, whether to wrap the covariances with
    ``lazify`` before building the GPyTorch MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only passed to the model
    when ``mf`` is True).
    :return: A torch.Tensor, the negative ELBO estimate divided by
    ``x.shape[0]``.
    :raises ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be at least 1, got {}.'.format(num_samples))

    if decoder_scale is None:
        if mask is not None:
            # Scale decoder terms by the reciprocal of the proportion of
            # observed entries, to compensate for missing values.
            # NOTE(review): if every entry is missing this divides by zero.
            num_nan = 1. * torch.sum(abs(1 - mask))
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. / (1. - num_nan / num_observations)
        else:
            decoder_scale = 1.

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    # Full multivariate distributions; these are only used for the
    # closed-form KL term below, not for sampling.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Marginal standard deviations of q(f): one diagonal per leading entry
    # of qf_cov. Computed once, outside the sampling loop.
    qf_std = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5

    # Monte-Carlo estimate of E_q[log p(y|f)].
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    # Samples are drawn independently from each marginal of q(f), ignoring
    # off-diagonal covariance.
    # NOTE(review): this equals the joint expectation only if log p(y|f)
    # decomposes over the entries of f (presumably true for a per-point
    # decoder); confirm.
    estimator = 0
    for _ in range(num_samples):
        f = qf_mu + qf_std * torch.randn_like(qf_mu)

        # log p(y|f) term. The transpose assumes the latent samples are laid
        # out as (latent_dim, num_points) -- TODO confirm against the model.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator = estimator + decoder_scale * py_f_term.sum()

    # Average over samples from q(f).
    estimator = estimator / num_samples

    # KL(q(f) || p(f)). GPyTorch's kl_mvn_mvn(p, q) returns KL(p || q), so
    # the approximate posterior must be the first argument.
    kl_term = kl_mvn_mvn(qf, pf)
    estimator = estimator - kl_term.sum()

    # Normalise by the number of inputs in the batch.
    estimator = estimator / x.shape[0]

    return -estimator


def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                   mf=False, idx=None):
    """Estimates the ELBO for the GPVAE model, using analytical results where
    possible.

    Computes

        E_q[log p(y|f)] - KL(q(f) || p(f)),

    where the KL term is evaluated in closed form and the expected
    log-likelihood is estimated by Monte Carlo. The decoder term is not
    rescaled for missing values, and the result is not normalised by the
    batch size.

    :param model: A nn.Module, the model to evaluate on. Must provide
    ``get_latent_dists`` and ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int >= 1, the number of samples to estimate the
    ELBO with.
    :param make_lazy: A bool, whether to wrap the covariances with
    ``lazify`` before building the GPyTorch MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only passed to the model
    when ``mf`` is True).
    :return: A torch.Tensor, the ELBO estimate.
    :raises ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be at least 1, got {}.'.format(num_samples))

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    # Full multivariate distributions; these are only used for the
    # closed-form KL term below, not for sampling.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Marginal standard deviations of q(f): one diagonal per leading entry
    # of qf_cov. Computed once, outside the sampling loop.
    qf_std = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5

    # Monte-Carlo estimate of E_q[log p(y|f)].
    # See Spatio-Temporal VAEs: ELBO.
    # Samples are drawn independently from each marginal of q(f), ignoring
    # off-diagonal covariance.
    # NOTE(review): this equals the joint expectation only if log p(y|f)
    # decomposes over the entries of f (presumably true for a per-point
    # decoder); confirm.
    elbo = 0
    for _ in range(num_samples):
        f = qf_mu + qf_std * torch.randn_like(qf_mu)

        # log p(y|f) term. The transpose assumes the latent samples are laid
        # out as (latent_dim, num_points) -- TODO confirm against the model.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        elbo = elbo + py_f_term.sum()

    # Average over samples from q(f).
    elbo = elbo / num_samples

    # KL(q(f) || p(f)). GPyTorch's kl_mvn_mvn(p, q) returns KL(p || q), so
    # the approximate posterior must be the first argument.
    kl_term = kl_mvn_mvn(qf, pf)
    elbo = elbo - kl_term.sum()

    return elbo
