import torch
import numpy as np

from gpvae.utils.gaussian_utils import gaussian_diagonal_ll

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.distributions.multivariate_normal import kl_mvn_mvn
from gpytorch.lazy import lazify

# Names exported by a star-import of this module.
# NOTE(review): iwae_estimator is defined in this module but is not listed
# here -- confirm whether that omission is intentional.
__all__ = [
    'td_estimator',
    'pd_estimator',
    'analytical_estimator',
    'elbo_estimator',
    'conditional_td_estimator',
]


def td_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True):
    """Return a negated Monte-Carlo ELBO estimate using diagonal Gaussians.

    Latent samples are drawn element-wise from the diagonals of the q(f)
    covariances, and the log q(f) and log p(f) terms are both evaluated with
    ``gaussian_diagonal_ll`` on those diagonals (off-diagonal covariance
    entries are not used).

    Args:
        model: Object exposing ``get_latent_dists(x, y, mask)``, which returns
            ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``, which
            returns a mean and a scale (the scale is squared before being
            passed on as a variance).
        x: Inputs; ``x.shape[0]`` is the final normaliser.
        y: Observations; ``y.shape[0] * y.shape[1]`` is taken as the total
            number of observations.
        mask: Optional mask forwarded to the model and to the likelihood.
            Entries contribute ``abs(1 - mask)`` to the missing count.
        num_samples: Number of Monte-Carlo samples drawn from q(f).
        decoder_scale: Weight on the log p(y|f) term. When ``mask`` is given
            and this is ``None`` it defaults to the reciprocal of the observed
            fraction. When ``mask`` is ``None`` it is reset to 1 regardless
            of the value passed.
        make_lazy: Unused here; kept so the signature matches the other
            estimators.

    Returns:
        The negative of the sample-averaged sum
        ``decoder_scale * log p(y|f) - log q(f) + log p(f)``, divided by
        ``x.shape[0]``.
    """
    if mask is None:
        # No mask: any caller-supplied scale is overridden.
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the fraction of entries that are observed.
        missing = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - missing / total)

    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    # Keep only the marginal variances of each covariance matrix.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    pf_var = torch.stack([cov.diag() for cov in pf_cov])

    running = 0
    for _ in range(num_samples):
        # Reparameterised draw from the diagonal approximation of q(f).
        noise = torch.randn_like(qf_mu)
        f = qf_mu + qf_var.pow(0.5) * noise

        # log p(y|f), weighted by the decoder scale.
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = gaussian_diagonal_ll(y, dec_mu, dec_sigma ** 2, mask)
        running = running + decoder_scale * log_py_f.sum()

        # Minus log q(f).
        running = running - gaussian_diagonal_ll(f, qf_mu, qf_var).sum()

        # Plus log p(f).
        running = running + gaussian_diagonal_ll(f, pf_mu, pf_var).sum()

    # Average over samples, then normalise by the leading size of x.
    return -(running / num_samples / x.shape[0])


def pd_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,
                 make_lazy=True):
    """Return a negated Monte-Carlo ELBO estimate with a detached log q(f).

    Samples are drawn with ``rsample`` from the full q(f), but the log q(f)
    term is evaluated under a second distribution built from detached copies
    of the q(f) mean and covariance, so that term only receives gradient
    through the sample itself.

    Args:
        model: Object exposing ``get_latent_dists(x, y, mask)``, which returns
            ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``, which
            returns a mean and a scale (the scale is squared before being
            passed on as a variance).
        x: Inputs; ``x.shape[0]`` is the final normaliser.
        y: Observations; ``y.shape[0] * y.shape[1]`` is taken as the total
            number of observations.
        mask: Optional mask forwarded to the model and to the likelihood.
        num_samples: Number of Monte-Carlo samples drawn from q(f).
        decoder_scale: Weight on the log p(y|f) term. When ``mask`` is given
            and this is ``None`` it defaults to the reciprocal of the observed
            fraction. When ``mask`` is ``None`` it is reset to 1 regardless
            of the value passed.
        make_lazy: If true, covariances are passed through ``lazify`` before
            building the distributions.

    Returns:
        The negative of the sample-averaged sum
        ``decoder_scale * log p(y|f) - log q(f) + log p(f)``, divided by
        ``x.shape[0]``.
    """
    if mask is None:
        # No mask: any caller-supplied scale is overridden.
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the fraction of entries that are observed.
        missing = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - missing / total)

    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    def as_cov(cov):
        # Optionally wrap a covariance for GPyTorch's lazy machinery.
        return lazify(cov) if make_lazy else cov

    qf = MultivariateNormal(qf_mu, as_cov(qf_cov))
    # Same distribution as qf, but cut off from the variational parameters.
    qf_fixed = MultivariateNormal(qf_mu.detach(), as_cov(qf_cov.detach()))
    pf = MultivariateNormal(pf_mu, as_cov(pf_cov))

    running = 0
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f), weighted by the decoder scale.
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = gaussian_diagonal_ll(y, dec_mu, dec_sigma ** 2, mask)
        running = running + decoder_scale * log_py_f.sum()

        # Minus log q(f), evaluated under the detached distribution.
        running = running - qf_fixed.log_prob(f).sum()

        # Plus log p(f).
        running = running + pf.log_prob(f).sum()

    # Average over samples, then normalise by the leading size of x.
    return -(running / num_samples / x.shape[0])


def analytical_estimator(model, x, y, mask=None, num_samples=1,
                         decoder_scale=None, make_lazy=True):
    """Return a negated ELBO estimate that uses a closed-form KL term.

    The expected log-likelihood is estimated by Monte-Carlo sampling from
    q(f); the KL(q(f) || p(f)) term is obtained from ``kl_mvn_mvn`` rather
    than from samples.

    Args:
        model: Object exposing ``get_latent_dists(x, y, mask)``, which returns
            ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``, which
            returns a mean and a scale (the scale is squared before being
            passed on as a variance).
        x: Inputs; ``x.shape[0]`` is the final normaliser.
        y: Observations; ``y.shape[0] * y.shape[1]`` is taken as the total
            number of observations.
        mask: Optional mask forwarded to the model and to the likelihood.
        num_samples: Number of Monte-Carlo samples drawn from q(f).
        decoder_scale: Weight on the log p(y|f) term. When ``mask`` is given
            and this is ``None`` it defaults to the reciprocal of the observed
            fraction. When ``mask`` is ``None`` it is reset to 1 regardless
            of the value passed.
        make_lazy: If true, covariances are passed through ``lazify`` before
            building the distributions.

    Returns:
        The negative of (sample-averaged scaled log p(y|f) minus the summed
        KL), divided by ``x.shape[0]``.
    """
    if mask is None:
        # No mask: any caller-supplied scale is overridden.
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the fraction of entries that are observed.
        missing = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - missing / total)

    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    if make_lazy:
        qf_cov, pf_cov = lazify(qf_cov), lazify(pf_cov)
    qf = MultivariateNormal(qf_mu, qf_cov)
    pf = MultivariateNormal(pf_mu, pf_cov)

    def scaled_decoder_ll():
        # One-sample estimate of the scaled log p(y|f) term.
        f = qf.rsample()
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = gaussian_diagonal_ll(y, dec_mu, dec_sigma ** 2, mask)
        return decoder_scale * log_py_f.sum()

    # Average the likelihood term over samples from q(f).
    expected_ll = sum(scaled_decoder_ll() for _ in range(num_samples))
    expected_ll = expected_ll / num_samples

    # KL(q(f) || p(f)), summed over all returned entries.
    kl_total = kl_mvn_mvn(qf, pf).sum()

    # Normalise by the leading size of x and negate.
    return -((expected_ll - kl_total) / x.shape[0])


def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True):
    """Return an ELBO estimate with a closed-form KL term.

    Unlike the loss-style estimators in this module, the result is neither
    negated, scaled by a decoder weight, nor divided by ``x.shape[0]``.

    Args:
        model: Object exposing ``get_latent_dists(x, y, mask)``, which returns
            ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``, which
            returns a mean and a scale (the scale is squared before being
            passed on as a variance).
        x: Inputs forwarded to ``model.get_latent_dists``.
        y: Observations.
        mask: Optional mask forwarded to the model and to the likelihood.
        num_samples: Number of Monte-Carlo samples drawn from q(f).
        make_lazy: If true, covariances are passed through ``lazify`` before
            building the distributions.

    Returns:
        Sample-averaged summed log p(y|f) minus the summed KL(q(f) || p(f)).
    """
    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y, mask)

    if make_lazy:
        qf_cov, pf_cov = lazify(qf_cov), lazify(pf_cov)
    qf = MultivariateNormal(qf_mu, qf_cov)
    pf = MultivariateNormal(pf_mu, pf_cov)

    # Monte-Carlo estimate of the expected log-likelihood under q(f).
    ll_total = 0
    for _ in range(num_samples):
        f = qf.rsample()
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = gaussian_diagonal_ll(y, dec_mu, dec_sigma ** 2, mask)
        ll_total = ll_total + log_py_f.sum()
    expected_ll = ll_total / num_samples

    # Subtract KL(q(f) || p(f)), summed over all returned entries.
    return expected_ll - kl_mvn_mvn(qf, pf).sum()


def iwae_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True):
    """Return an importance-weighted estimate of the log marginal likelihood.

    For each sample ``f`` drawn from q(f) a log-weight
    ``log p(f) + log p(y|f) - log q(f)`` is formed (each term summed over all
    of its entries); the result is ``logsumexp`` of those log-weights minus
    ``log(num_samples)``.

    The log-weights are gathered with ``torch.stack`` so that the result keeps
    the dtype and device of the model's outputs instead of being copied into
    a default-dtype CPU buffer.

    Args:
        model: Object exposing ``get_latent_dists(x, y, mask)``, which returns
            ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``, which
            returns a mean and a scale (the scale is squared before being
            passed on as a variance).
        x: Inputs forwarded to ``model.get_latent_dists``.
        y: Observations.
        mask: Optional mask forwarded to the model and to the likelihood.
        num_samples: Number of importance samples; must be at least 1.
        make_lazy: If true, covariances are passed through ``lazify`` before
            building the distributions.

    Returns:
        A scalar tensor holding the importance-weighted bound.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1.')

    # Latent distributions.
    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(
        x, y, mask)

    # Required distributions.
    if make_lazy:
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Unnormalised log importance weights, one scalar tensor per sample.
    log_weights = []
    for _ in range(num_samples):
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        py_f_term = py_f_term.sum()

        # log p(f) term.
        pf_term = pf.log_prob(f).sum()

        # log q(f) term.
        qf_term = qf.log_prob(f).sum()

        log_weights.append(pf_term + py_f_term - qf_term)

    # log(1/K * sum_k w_k) = logsumexp(log w) - log K; the constant is
    # subtracted once rather than from every weight.
    log_weights = torch.stack(log_weights)
    return torch.logsumexp(log_weights, dim=0) - float(np.log(num_samples))


def conditional_td_estimator(model, x, y, y_c, mask=None, mask_c=None,
                             num_samples=1, decoder_scale=None,
                             make_lazy=True):
    """Return a negated diagonal-Gaussian ELBO estimate with separate context.

    The latent distributions are obtained from ``y_c`` / ``mask_c``, while
    the log p(y|f) term and the decoder scale are computed from ``y`` /
    ``mask``. Samples and the log q(f), log p(f) terms use only the
    diagonals of the returned covariances.

    Args:
        model: Object exposing ``get_latent_dists(x, y_c, mask_c)``, which
            returns ``(qf_mu, qf_cov, pf_mu, pf_cov)``, and ``decoder(f)``,
            which returns a mean and a scale (the scale is squared before
            being passed on as a variance).
        x: Inputs; ``x.shape[0]`` is the final normaliser.
        y: Observations scored by the decoder likelihood;
            ``y.shape[0] * y.shape[1]`` is taken as the total number of
            observations.
        y_c: Observations passed to ``model.get_latent_dists``.
        mask: Optional mask for ``y``, forwarded to the likelihood.
        mask_c: Optional mask for ``y_c``, forwarded to the model.
        num_samples: Number of Monte-Carlo samples drawn from q(f).
        decoder_scale: Weight on the log p(y|f) term. When ``mask`` is given
            and this is ``None`` it defaults to the reciprocal of the observed
            fraction. When ``mask`` is ``None`` it is reset to 1 regardless
            of the value passed.
        make_lazy: Unused here; kept so the signature matches the other
            estimators.

    Returns:
        The negative of the sample-averaged sum
        ``decoder_scale * log p(y|f) - log q(f) + log p(f)``, divided by
        ``x.shape[0]``.
    """
    if mask is None:
        # No mask: any caller-supplied scale is overridden.
        decoder_scale = 1.
    elif decoder_scale is None:
        # Reciprocal of the fraction of entries of y that are observed.
        missing = 1. * torch.sum(abs(1 - mask))
        total = y.shape[0] * y.shape[1]
        decoder_scale = 1. / (1. - missing / total)

    # Latents are conditioned on the context observations only.
    qf_mu, qf_cov, pf_mu, pf_cov = model.get_latent_dists(x, y_c, mask_c)

    # Keep only the marginal variances of each covariance matrix.
    qf_var = torch.stack([cov.diag() for cov in qf_cov])
    pf_var = torch.stack([cov.diag() for cov in pf_cov])

    running = 0
    for _ in range(num_samples):
        # Reparameterised draw from the diagonal approximation of q(f).
        noise = torch.randn_like(qf_mu)
        f = qf_mu + qf_var.pow(0.5) * noise

        # log p(y|f), weighted by the decoder scale.
        dec_mu, dec_sigma = model.decoder(f.transpose(0, 1))
        log_py_f = gaussian_diagonal_ll(y, dec_mu, dec_sigma ** 2, mask)
        running = running + decoder_scale * log_py_f.sum()

        # Minus log q(f).
        running = running - gaussian_diagonal_ll(f, qf_mu, qf_var).sum()

        # Plus log p(f).
        running = running + gaussian_diagonal_ll(f, pf_mu, pf_var).sum()

    # Average over samples, then normalise by the leading size of x.
    return -(running / num_samples / x.shape[0])
