import torch

from ..utils.gaussian_utils import gaussian_diagonal_ll

from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.lazy import lazify

# Public ELBO estimators exported by this module.
__all__ = [
    'ds_sf_estimator',
    'sa_sf_estimator',
    'ds_pd_estimator',
    'sa_pd_estimator',
]


def ds_sf_estimator(model, x, y, mask=None, num_samples=1,
                    decoder_scale=None, make_lazy=True):
    """Estimates the negative ELBO using the weighted score function estimator
    for the GPVAE.

    Uses q(f) = p(f) l(f|y) / Zq, so the ELBO is rewritten as
    E_q[log p(y|f)] - E_q[log l(f|y)] + log Zq. The expectations are
    estimated with (non-reparameterised) samples from q(f) and
    score-function gradients.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    Ignored (forced to 1) when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance matrices with
    GPyTorch's lazify before building the MultivariateNormal distributions.
    :return: A scalar torch.Tensor, a surrogate negative ELBO averaged over
    x.shape[0]. Its gradient is the score-function estimate; its value also
    contains score-function terms, so it is not itself an ELBO estimate.
    """
    if mask is not None:
        # Scale the decoder term by the fraction of entries of y that are
        # observed: sum(mask) is treated as the count of missing entries
        # (presumably mask is nonzero where y is missing -- TODO confirm).
        # Assumes y is 2-D.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
    else:
        # Without a mask, any caller-supplied decoder_scale is overridden.
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(
        x, y, mask)

    # Zq = integral of p(f) l(f|y) df = N(0; lf_y_mu, pf_cov + lf_y_cov) for a
    # zero-mean prior (as the log-prob at zeros below assumes). This needs
    # the *prior* covariance, not qf_cov.
    sum_cov = pf_cov + lf_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch lazy tensors for the covariances.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    # Marginal variances: the diagonal of each covariance in the batch.
    qf_var = torch.diagonal(qf_cov, dim1=-2, dim2=-1)
    lf_y_var = torch.diagonal(lf_y_cov, dim1=-2, dim2=-1)

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Non-reparameterised sample: q(f) receives gradients only through
        # the score-function term below.
        f = qf.sample()

        # log p(y|f) term (its gradient trains the decoder).
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator += py_f_term.sum() * decoder_scale

        # log l(f|y) term, only used inside the score-function weight.
        lf_y_term = gaussian_diagonal_ll(f.T, lf_y_mu.T, lf_y_var.T)

        # log q(f) term weighted by the detached integrand. The decoder part
        # carries decoder_scale so the q(f) gradient matches the scaled
        # objective above.
        qf_term = gaussian_diagonal_ll(f.T, qf_mu.T, qf_var.T)
        weight = decoder_scale * py_f_term.detach() - lf_y_term.detach()
        estimator += (qf_term * weight).sum()

    # Average the Monte-Carlo terms over samples from q(f).
    estimator /= num_samples

    # log p(f) = log q(f) + log Zq - log l(f|y), hence
    # E[grad log p(f)] = grad log Zq - E[grad log l(f|y)].
    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()
    # Surrogate whose gradient w.r.t. the l(f|y) parameters equals
    # E_q[grad log l(f|y)]; the detaches keep q(f) fixed.
    lf_y_term_ = (-0.5 * lf_y_var.log()
                  - (0.5 / lf_y_var) * (
                      qf_var + (qf_mu - lf_y_mu) ** 2).detach()
                  + (1 / lf_y_var.detach()) * lf_y_mu * (
                      qf_mu - lf_y_mu).detach())
    estimator += zq_term - lf_y_term_.sum()

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator


def sa_sf_estimator(model, x, y, mask=None, num_samples=1,
                    decoder_scale=None, make_lazy=True):
    """Estimates the negative ELBO using the semi-analytic weighted score
    function estimator for the GPVAE model.

    Only E_q[log p(y|f)] is estimated by Monte Carlo (with a score-function
    gradient for q(f)). E_q[log l(f|y)] and log Zq are computed in closed
    form.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    Ignored (forced to 1) when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance matrices with
    GPyTorch's lazify before building the MultivariateNormal distributions.
    :return: A scalar torch.Tensor, a surrogate negative ELBO averaged over
    x.shape[0]. Its gradient is the estimator; because of the score-function
    term its value is not itself an ELBO estimate.
    """
    if mask is not None:
        # Scale the decoder term by the fraction of entries of y that are
        # observed: sum(mask) is treated as the count of missing entries
        # (presumably mask is nonzero where y is missing -- TODO confirm).
        # Assumes y is 2-D.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
    else:
        # Without a mask, any caller-supplied decoder_scale is overridden.
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(
        x, y, mask)

    # Covariance of Zq = N(0; lf_y_mu, pf_cov + lf_y_cov).
    sum_cov = pf_cov + lf_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch lazy tensors for the covariances.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    # Marginal variances: the diagonal of each covariance in the batch.
    qf_var = torch.diagonal(qf_cov, dim1=-2, dim2=-1)
    lf_y_var = torch.diagonal(lf_y_cov, dim1=-2, dim2=-1)

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Non-reparameterised sample: q(f) is trained via the score function.
        f = qf.sample()

        # log p(y|f) term (its gradient trains the decoder).
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator += py_f_term.sum() * decoder_scale

        # log q(f) term weighted by the detached, decoder_scale-weighted
        # decoder log-likelihood, matching the scaled objective above.
        qf_term = gaussian_diagonal_ll(f.T, qf_mu.T, qf_var.T)
        estimator += (qf_term * (decoder_scale * py_f_term.detach())).sum()

    # Average the Monte-Carlo terms over samples from q(f).
    estimator /= num_samples

    # Closed-form E_q[log l(f|y)] for diagonal Gaussians:
    # log N(qf_mu; lf_y_mu, lf_y_var) - 0.5 * qf_var / lf_y_var.
    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()
    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()
    estimator += - lf_y_term

    # log Zq term.
    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()
    estimator += zq_term

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator


def ds_pd_estimator(model, x, y, mask=None, num_samples=1,
                    decoder_scale=None, make_lazy=True, mf=False, idx=None):
    """Estimates the gradient of the negative ELBO using the full path
    derivative estimator for the GPVAE model.

    Uses q(f) = p(f) l(f|y) / Zq, so the ELBO is rewritten as
    E_q[log p(y|f)] - E_q[log l(f|y)] + log Zq. Samples from q(f) are
    reparameterised, so gradients flow through f. The gradient w.r.t. the
    parameters of l(f|y) is added once via a surrogate term.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    Ignored (forced to 1) when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance matrices with
    GPyTorch's lazify before building the MultivariateNormal distributions.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only used when mf is True).
    :return: A scalar torch.Tensor, a surrogate negative ELBO averaged over
    x.shape[0], whose gradient is the path-derivative estimate.
    """
    if mask is not None:
        # Scale the decoder term by the fraction of entries of y that are
        # observed: sum(mask) is treated as the count of missing entries
        # (presumably mask is nonzero where y is missing -- TODO confirm).
        # Assumes y is 2-D.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
    else:
        # Without a mask, any caller-supplied decoder_scale is overridden.
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y, mask)

    # Zq = integral of p(f) l(f|y) df = N(0; lf_y_mu, pf_cov + lf_y_cov) for a
    # zero-mean prior (as the log-prob at zeros below assumes). This needs
    # the *prior* covariance, not qf_cov.
    sum_cov = pf_cov + lf_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch lazy tensors for the covariances.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    # Marginal variances: the diagonal of each covariance in the batch.
    qf_var = torch.diagonal(qf_cov, dim1=-2, dim2=-1)
    lf_y_var = torch.diagonal(lf_y_cov, dim1=-2, dim2=-1)

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Reparameterised sample so gradients flow through f.
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator += decoder_scale * py_f_term.sum()

        # -log l(f|y) with l's parameters detached: only the path through f
        # contributes here; the parameter gradient is added once below.
        lf_y_term = gaussian_diagonal_ll(f, lf_y_mu.detach(),
                                         lf_y_var.detach())
        estimator += - lf_y_term.sum()

    # Average the Monte-Carlo terms over samples from q(f). This must happen
    # before the terms below are added: they are computed once, not per
    # sample, and must not be divided by num_samples.
    estimator /= num_samples

    # log p(f) = log q(f) + log Zq - log l(f|y), hence
    # E[grad log p(f)] = grad log Zq - E[grad log l(f|y)].
    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()
    # Surrogate whose gradient w.r.t. the l(f|y) parameters equals
    # E_q[grad log l(f|y)]; the detaches keep q(f) fixed.
    lf_y_term_ = (-0.5 * lf_y_var.log()
                  - (0.5 / lf_y_var) * (
                      qf_var + (qf_mu - lf_y_mu) ** 2).detach()
                  + (1 / lf_y_var.detach()) * lf_y_mu * (
                      qf_mu - lf_y_mu).detach())
    estimator += zq_term - lf_y_term_.sum()

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator


def sa_pd_estimator(model, x, y, mask=None, num_samples=1,
                    decoder_scale=None, make_lazy=True, mf=False, idx=None):
    """Estimates the gradient of the negative ELBO using the semi-analytic
    path derivative estimator for the GPVAE model.

    Only E_q[log p(y|f)] is estimated by Monte Carlo, with reparameterised
    samples from q(f). E_q[log l(f|y)] and log Zq are computed in closed
    form.

    :param model: A nn.Module, the model to evaluate on.
    :param x: A torch.Tensor, the input data.
    :param y: A torch.Tensor, the output data.
    :param mask: A torch.Tensor, the mask to apply to the output data.
    :param num_samples: An int, the number of samples to estimate the ELBO
    gradient with.
    :param decoder_scale: None or a float, the amount by which to scale the
    decoder term, p(y|f), by. Relevant in the presence of missing values.
    Ignored (forced to 1) when mask is None.
    :param make_lazy: A bool, whether to wrap the covariance matrices with
    GPyTorch's lazify before building the MultivariateNormal distributions.
    :param mf: A bool, whether the model uses mean-field variational
    inference or not.
    :param idx: A torch.Tensor, the data indices (only used when mf is True).
    :return: A scalar torch.Tensor, the negative ELBO estimate averaged over
    x.shape[0].
    """
    if mask is not None:
        # Scale the decoder term by the fraction of entries of y that are
        # observed: sum(mask) is treated as the count of missing entries
        # (presumably mask is nonzero where y is missing -- TODO confirm).
        # Assumes y is 2-D.
        if decoder_scale is None:
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
    else:
        # Without a mask, any caller-supplied decoder_scale is overridden.
        decoder_scale = 1.

    estimator = 0

    # Latent distributions.
    if mf:
        # Pass mean-field models the data indices.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, idx=idx)
    else:
        # Pass amortisation models the observation data.
        qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = \
            model.get_latent_dists(x, y, mask)

    # Covariance of Zq = N(0; lf_y_mu, pf_cov + lf_y_cov).
    sum_cov = pf_cov + lf_y_cov

    # Required distributions.
    if make_lazy:
        # Use GPyTorch lazy tensors for the covariances.
        qf = MultivariateNormal(qf_mu, lazify(qf_cov))
        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))
    else:
        qf = MultivariateNormal(qf_mu, qf_cov)
        zq = MultivariateNormal(lf_y_mu, sum_cov)

    # Marginal variances: the diagonal of each covariance in the batch.
    qf_var = torch.diagonal(qf_cov, dim1=-2, dim2=-1)
    lf_y_var = torch.diagonal(lf_y_cov, dim1=-2, dim2=-1)

    # Monte-Carlo estimate of ELBO gradient.
    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.
    for _ in range(num_samples):
        # Reparameterised sample so gradients flow through f.
        f = qf.rsample()

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)
        estimator += decoder_scale * py_f_term.sum()

    # Average the Monte-Carlo terms over samples from q(f).
    estimator /= num_samples

    # Closed-form E_q[log l(f|y)] for diagonal Gaussians:
    # log N(qf_mu; lf_y_mu, lf_y_var) - 0.5 * qf_var / lf_y_var.
    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()
    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()
    estimator += - lf_y_term

    # log Zq term.
    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()
    estimator += zq_term

    # Average over the batch.
    estimator /= x.shape[0]

    return - estimator
