import torch
import pdb

from ..utils.gaussian_utils import gaussian_diagonal_ll

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.kl import kl_divergence
from gpytorch.lazy import lazify

# Names exported when this module is star-imported.
__all__ = ['sa_estimator', 'td_estimator', 'pd_estimator', 'elbo_estimator',
           'elbo_estimator2']


def sa_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True, mf=False, idx=None):
    """Semi-analytic loss for the SparseGPVAE model: the negated ELBO
    estimate whose gradient is used for training.

    The expected log-likelihood E_q(f)[log p(y|f)] is estimated by Monte
    Carlo, sampling f from the mean and the diagonal of each q(f)
    covariance, while KL(q(u)||p(u)) comes from ``kl_divergence`` rather
    than from samples.

    :param model: A nn.Module providing ``get_latent_dists`` and
    ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples of f to average over.
    :param decoder_scale: None or a float, the multiplier applied to the
    log p(y|f) term. When None and a mask is given it defaults to
    1 / (1 - sum(|1 - mask|) / (y.shape[0] * y.shape[1])). It is reset to
    1 whenever mask is None.
    :param make_lazy: A bool, whether to pass the covariances of p(u) and
    q(u) through ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (latents looked up by ``idx``) or amortisation (latents
    computed from ``y`` and ``mask``).
    :param idx: A torch.Tensor, the data indices (used only when ``mf``).
    :return: The negated estimate, divided by ``x.shape[0]``.
    """
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of one minus the fraction of entries flagged by
        # |1 - mask| (presumably the missing ones -- confirm mask is 0/1).
        flagged = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - flagged / total)

    # Mean-field models are queried by index, amortised ones by (y, mask).
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)

    # The last two entries (lf_y_mu, lf_y_cov) are not needed here.
    qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, _lf_mu, _lf_cov = latents

    # Prior and approximate posterior over u, needed for the KL term.
    wrap_cov = lazify if make_lazy else (lambda cov: cov)
    pu = MultivariateNormal(pu_mu, wrap_cov(pu_cov))
    qu = MultivariateNormal(qu_mu, wrap_cov(qu_cov))

    # Only the marginal variances of q(f) are used for sampling.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])

    # Monte-Carlo estimate of the expected log-likelihood.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    objective = 0
    for _ in range(num_samples):
        # Reparameterised draw from the marginals of q(f).
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)

        # Scaled log p(y|f) for this draw.
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_lik = gaussian_diagonal_ll(y, dec_mu, dec_sigma.pow(2), mask)
        objective += decoder_scale * log_lik.sum()

    # Average over the samples of f.
    objective /= num_samples

    # Subtract KL(q(u)||p(u)).
    objective -= kl_divergence(qu, pu).sum()

    # Normalise by the leading dimension of x.
    objective /= x.shape[0]

    return - objective


def td_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True, mf=False, idx=None):
    """Total-derivative loss for the SparseGPVAE model: the negated ELBO
    estimate in which every term is estimated by Monte Carlo.

    Each sample draws f from the mean and the diagonal of each q(f)
    covariance and u from q(u) via ``rsample``, then accumulates
    log p(y|f) - log q(u) + log p(u). Gradients flow through both the
    samples and the parameters of q(u).

    :param model: A nn.Module providing ``get_latent_dists`` and
    ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to average over.
    :param decoder_scale: None or a float, the multiplier applied to the
    log p(y|f) term. When None and a mask is given it defaults to
    1 / (1 - sum(|1 - mask|) / (y.shape[0] * y.shape[1])). It is reset to
    1 whenever mask is None.
    :param make_lazy: A bool, whether to pass the covariances of p(u) and
    q(u) through ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (latents looked up by ``idx``) or amortisation (latents
    computed from ``y`` and ``mask``).
    :param idx: A torch.Tensor, the data indices (used only when ``mf``).
    :return: The negated estimate, divided by ``x.shape[0]``.
    """
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of one minus the fraction of entries flagged by
        # |1 - mask| (presumably the missing ones -- confirm mask is 0/1).
        flagged = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - flagged / total)

    # Mean-field models are queried by index, amortised ones by (y, mask).
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)

    # The last two entries (lf_y_mu, lf_y_cov) are not needed here.
    qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, _lf_mu, _lf_cov = latents

    # Prior and approximate posterior over u.
    wrap_cov = lazify if make_lazy else (lambda cov: cov)
    pu = MultivariateNormal(pu_mu, wrap_cov(pu_cov))
    qu = MultivariateNormal(qu_mu, wrap_cov(qu_cov))

    # Only the marginal variances of q(f) are used for sampling.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])

    # Monte-Carlo estimate of the ELBO.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    objective = 0
    for _ in range(num_samples):
        # Reparameterised draws: f from the q(f) marginals, then u from q(u).
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)
        u = qu.rsample()

        # Scaled log p(y|f).
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_lik = gaussian_diagonal_ll(y, dec_mu, dec_sigma.pow(2), mask)
        log_lik = decoder_scale * log_lik.sum()

        # log q(u) and log p(u) at the sampled u.
        log_qu = qu.log_prob(u).sum()
        log_pu = pu.log_prob(u).sum()

        objective += log_lik - log_qu + log_pu

    # Average over the samples.
    objective /= num_samples

    # Normalise by the leading dimension of x.
    objective /= x.shape[0]

    return - objective


def pd_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True, mf=False, idx=None):
    """Path-derivative loss for the SparseGPVAE model: the negated ELBO
    estimate whose gradient omits the score-function term of log q(u).

    Each sample draws f from the mean and the diagonal of each q(f)
    covariance and u from q(u) via ``rsample``, then accumulates
    log p(y|f) - log q(u) + log p(u). Unlike the total-derivative
    estimator, log q(u) is evaluated under a copy of q(u) whose mean and
    covariance are detached from the autograd graph, so gradients reach
    that term only through the reparameterised sample u. The returned
    value is the same as for the total-derivative estimator given the same
    random draws; only its gradient differs.

    :param model: A nn.Module providing ``get_latent_dists`` and
    ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to average over.
    :param decoder_scale: None or a float, the multiplier applied to the
    log p(y|f) term. When None and a mask is given it defaults to
    1 / (1 - sum(|1 - mask|) / (y.shape[0] * y.shape[1])). It is reset to
    1 whenever mask is None.
    :param make_lazy: A bool, whether to pass the covariances of p(u) and
    q(u) through ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (latents looked up by ``idx``) or amortisation (latents
    computed from ``y`` and ``mask``).
    :param idx: A torch.Tensor, the data indices (used only when ``mf``).
    :return: The negated estimate, divided by ``x.shape[0]``.
    """
    if mask is None:
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of one minus the fraction of entries flagged by
        # |1 - mask| (presumably the missing ones -- confirm mask is 0/1).
        num_nan = 1. * torch.sum(abs(1 - mask))
        num_observations = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - num_nan / num_observations)

    # Mean-field models are queried by index, amortised ones by (y, mask).
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)

    # The last two entries (lf_y_mu, lf_y_cov) are not needed here.
    qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, _lf_mu, _lf_cov = latents

    # Prior and approximate posterior over u.
    wrap_cov = lazify if make_lazy else (lambda cov: cov)
    pu = MultivariateNormal(pu_mu, wrap_cov(pu_cov))
    qu = MultivariateNormal(qu_mu, wrap_cov(qu_cov))

    # Stop-gradient copy of q(u), used only to evaluate log q(u). Detaching
    # its parameters removes the score-function term (zero in expectation)
    # from the gradient, which is what distinguishes the path-derivative
    # estimator from the total-derivative one.
    qu_stop = MultivariateNormal(qu_mu.detach(), wrap_cov(qu_cov.detach()))

    # Only the marginal variances of q(f) are used for sampling.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])

    # Monte-Carlo estimate of the ELBO.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    estimator = 0
    for _ in range(num_samples):
        # Reparameterised draws: f from the q(f) marginals, then u from the
        # attached q(u) so that the sample still carries gradients.
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)
        u = qu.rsample()

        # Scaled log p(y|f).
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = decoder_scale * py_f_term.sum()

        # log q(u) with the parameters of q held fixed.
        qu_term = qu_stop.log_prob(u).sum()

        # log p(u).
        pu_term = pu.log_prob(u).sum()

        estimator += py_f_term - qu_term + pu_term

    # Average over the samples.
    estimator /= num_samples

    # Normalise by the leading dimension of x.
    estimator /= x.shape[0]

    return - estimator


def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,
                   mf=False, idx=None):
    """Estimates the ELBO using analytical results where possible, and the
    reparameterisation trick for the decoder term, for models with an
    encoder/decoder architecture, a GP prior over latent variables and an
    approximate posterior of the form q(f) = 1/Z p(f)l(f|y).

    The returned value is the ELBO itself: it is neither negated nor
    divided by the number of data points.

    :param model: A nn.Module providing ``get_latent_dists`` and
    ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    with.
    :param make_lazy: A bool, whether to pass the covariances of p(u) and
    q(u) through ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (latents looked up by ``idx``) or amortisation (latents
    computed from ``y`` and ``mask``).
    :param idx: A torch.Tensor, the data indices (used only when ``mf``).
    :return: The ELBO estimate, i.e. the sample average of log p(y|f)
    minus KL(q(u)||p(u)).
    :raises: Any exception raised while computing the KL divergence is
    propagated to the caller.
    """
    # Mean-field models are queried by index, amortised ones by (y, mask).
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)

    # The last two entries (lf_y_mu, lf_y_cov) are not needed here.
    qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, _lf_mu, _lf_cov = latents

    # Required distributions for KL divergence.
    if make_lazy:
        pu = MultivariateNormal(pu_mu, lazify(pu_cov))
        qu = MultivariateNormal(qu_mu, lazify(qu_cov))
    else:
        pu = MultivariateNormal(pu_mu, pu_cov)
        qu = MultivariateNormal(qu_mu, qu_cov)

    # Only the marginal variances of q(f) are used for sampling.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])

    # Monte-Carlo estimate of the expected log-likelihood.
    estimator = 0
    for _ in range(num_samples):
        # Reparameterised draw from the marginals of q(f).
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator += py_f_term.sum()

    # Average over the samples of f.
    estimator /= num_samples

    # KL(q(u)||p(u)) term. Failures are allowed to propagate: swallowing
    # them would return an estimate with the KL term silently missing.
    estimator -= kl_divergence(qu, pu).sum()

    return estimator


def elbo_estimator2(model, x, y, mask=None, num_samples=1, make_lazy=True,
                    mf=False, idx=None):
    """Estimates the ELBO of the SGP-VAE with every term estimated by Monte
    Carlo.

    Each sample draws f from the mean and the diagonal of each q(f)
    covariance and u from q(u) via ``sample`` (not ``rsample``), then
    accumulates log p(y|f) - log q(u) + log p(u). The result is neither
    negated nor divided by the number of data points.

    :param model: A nn.Module providing ``get_latent_dists`` and
    ``decoder``.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    with.
    :param make_lazy: A bool, whether to pass the covariances of p(u) and
    q(u) through ``lazify`` before building the MultivariateNormal objects.
    :param mf: A bool, whether the model uses mean-field variational
    inference (latents looked up by ``idx``) or amortisation (latents
    computed from ``y`` and ``mask``).
    :param idx: A torch.Tensor, the data indices (used only when ``mf``).
    :return: The sample average of the accumulated terms.
    """
    # Mean-field models are queried by index, amortised ones by (y, mask).
    if mf:
        latents = model.get_latent_dists(x, idx=idx)
    else:
        latents = model.get_latent_dists(x, y, mask)

    # The last two entries (lf_y_mu, lf_y_cov) are not needed here.
    qf_mu, qf_cov, qu_mu, qu_cov, pu_mu, pu_cov, _lf_mu, _lf_cov = latents

    # Prior and approximate posterior over u.
    wrap_cov = lazify if make_lazy else (lambda cov: cov)
    pu = MultivariateNormal(pu_mu, wrap_cov(pu_cov))
    qu = MultivariateNormal(qu_mu, wrap_cov(qu_cov))

    # Only the marginal variances of q(f) are used for sampling.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])

    # Monte-Carlo estimate of the ELBO.
    elbo = 0
    for _ in range(num_samples):
        # Draw f from the q(f) marginals, then u from q(u).
        f = qf_mu + qf_var ** 0.5 * torch.randn_like(qf_mu)
        u = qu.sample()

        # log p(y|f).
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_lik = gaussian_diagonal_ll(y, dec_mu, dec_sigma.pow(2), mask)
        log_lik = log_lik.sum()

        # log q(u) and log p(u) at the sampled u.
        log_qu = qu.log_prob(u).sum()
        log_pu = pu.log_prob(u).sum()

        elbo += log_lik - log_qu + log_pu

    # Average over the samples.
    elbo /= num_samples

    return elbo
