import math

import torch
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.lazy import lazify

from ..utils.gaussian_utils import gaussian_diagonal_ll

# Public estimators exported by this module (controls `from ... import *`).
__all__ = [
    'analytical_estimator',
    'td_estimator',
    'elbo_estimator',
    'iwae_estimator',
]


def analytical_estimator(model, x, y, mask=None, num_samples=None,
                         num_s_samples=1, num_f_samples=1, decoder_scale=None,
                         debug=False, make_lazy=True):
    """Return a Monte-Carlo estimate of the negative hierarchical ELBO.

    s is sampled from q(s) using the diagonal of its covariance, and f from
    q(f|s) using only the diagonal of its covariance. The log Zq(s) term is
    computed as the log-density at zero of N(lf_sy_mu, pf_cov + lf_sy_cov).
    The other terms are Monte-Carlo estimates. Terms that depend on f are
    averaged over ``num_f_samples``. The full estimate is averaged over
    ``num_s_samples`` and then divided by ``x.shape[0]``.

    Args:
        model: Object providing ``get_latent_prior``,
            ``get_auxiliary_dists``, ``get_latent_dists``,
            ``latent_decoder`` and ``auxiliary_decoder``.
        x: Model inputs. ``x.shape[0]`` normalises the estimate.
        y: Observations.
        mask: Optional observation mask, passed to the model and to the
            log p(y|f) likelihood. If ``None``, an all-ones mask is used
            for y in the auxiliary-decoder input.
        num_samples: If given, overrides ``num_s_samples`` and forces
            ``num_f_samples = 1``.
        num_s_samples: Number of samples of s drawn from q(s).
        num_f_samples: Number of samples of f drawn per sample of s.
        decoder_scale: Multiplier for the log p(y|f) term. If ``None``,
            it is derived from ``mask``, or set to 1 when there is no mask.
        debug: If True, also return the recorded per-term values.
        make_lazy: If True, wrap the covariance with ``lazify`` before
            building the ``MultivariateNormal``.

    Returns:
        The negated estimate. If ``debug`` is True, returns
        ``(negated estimate, terms)``, where ``terms`` maps term names to
        lists of Python floats.
    """
    # Per-term values, recorded for debugging.
    terms = {'py_f_terms': [],
             'lf_sy_terms': [],
             'norm_terms': [],
             'rs_yf_terms': [],
             'qs_terms': []
             }

    if num_samples is not None:
        # Overwrite num_s_samples and num_f_samples.
        num_s_samples = num_samples
        num_f_samples = 1

    # Only derive decoder_scale when the caller did not supply one.
    if decoder_scale is None:
        if mask is not None:
            # NOTE(review): this was described as "the reciprocal of the
            # proportion of missing observations", but the value computed
            # is 1 - sum(mask) / (y.shape[0] * y.shape[1]). Confirm the mask
            # convention and the intended scaling.
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
        else:
            decoder_scale = 1.

    # Mask for y in the auxiliary-decoder input. The appended f gets an
    # all-ones mask below, so ones are treated as "observed". Without a
    # mask, treat all of y as observed (torch.cat cannot take None).
    aux_mask = torch.ones_like(y) if mask is None else mask

    estimator = 0

    # Latent GP prior.
    pf_mu, pf_cov = model.get_latent_prior(x)

    # Auxiliary distributions.
    qs_mu, qs_cov = model.get_auxiliary_dists(x, y, mask)
    qs_var = torch.stack([cov.diag() for cov in qs_cov])

    # Monte-Carlo estimate of the hierarchical ELBO.
    # See Spatio-Temporal VAEs: Hierarchical VI.
    for _ in range(num_s_samples):
        s = qs_mu + qs_var ** 0.5 * torch.randn_like(qs_mu)

        # Latent distributions.
        qf_s_mu, qf_s_cov, pf_mu, pf_cov, lf_sy_mu, lf_sy_cov = \
            model.get_latent_dists(x, y, s, mask, kff=pf_cov)
        sum_cov = pf_cov + lf_sy_cov

        # Required distributions.
        if make_lazy:
            # Use GPyTorch MultivariateNormal class for sampling.
            zq = MultivariateNormal(lf_sy_mu, lazify(sum_cov))
        else:
            zq = MultivariateNormal(lf_sy_mu, sum_cov)

        lf_sy_var = torch.stack([cov.diag() for cov in lf_sy_cov])
        qf_s_var = torch.stack([cov.diag() for cov in qf_s_cov])

        # log Zq(s) term.
        zq_term = zq.log_prob(torch.zeros_like(lf_sy_mu)).sum()
        estimator += zq_term
        terms['norm_terms'].append(zq_term.item())

        # log q(s) term.
        qs_term = gaussian_diagonal_ll(s, qs_mu, qs_var).sum()
        estimator += - qs_term
        terms['qs_terms'].append(qs_term.item())

        # Terms that depend on f are accumulated here, then averaged over
        # the f samples. Adding them straight to `estimator` would sum them
        # when num_f_samples > 1 instead of averaging.
        inner_estimator = 0
        for _ in range(num_f_samples):
            f = qf_s_mu + qf_s_var ** 0.5 * torch.randn_like(qf_s_mu)
            f_t = f.transpose(0, 1)

            # log p(y|f) term.
            py_f_mu, py_f_sigma = model.latent_decoder(f_t)
            py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2),
                                             mask)
            py_f_term = decoder_scale * py_f_term.sum()
            inner_estimator += py_f_term
            terms['py_f_terms'].append(py_f_term.item())

            # log r(s|y,f) term.
            y_ = torch.cat([y, f_t], dim=1)
            mask_ = torch.cat([aux_mask, torch.ones_like(f_t)], dim=1)
            rs_yf_mu, rs_yf_sigma = model.auxiliary_decoder(y_, mask_)
            rs_yf_term = gaussian_diagonal_ll(s, rs_yf_mu,
                                              rs_yf_sigma.pow(2)).sum()
            inner_estimator += rs_yf_term
            terms['rs_yf_terms'].append(rs_yf_term.item())

            # log l(f|s,y) term.
            lf_sy_term = gaussian_diagonal_ll(f, lf_sy_mu, lf_sy_var).sum()
            inner_estimator += - lf_sy_term
            terms['lf_sy_terms'].append(lf_sy_term.item())

        # Average over samples from q(f|s).
        estimator += inner_estimator / num_f_samples

    # Average over samples from q(s).
    estimator /= num_s_samples

    # Normalise by the batch size.
    estimator /= x.shape[0]

    if debug:
        return -estimator, terms
    return -estimator


def td_estimator(model, x, y, mask=None, num_samples=None, num_s_samples=1,
                 num_f_samples=1, decoder_scale=None, debug=False,
                 make_lazy=True):
    """Return a Monte-Carlo estimate of the negative hierarchical ELBO.

    s is sampled from q(s) using the diagonal of its covariance. f is drawn
    with ``rsample`` from the full-covariance q(f|s). Instead of a
    closed-form normaliser, the estimate includes log p(f) evaluated on a
    detached f. log l(f|s,y) is evaluated with detached ``lf_sy_mu`` and
    ``lf_sy_var``. Terms that depend on f are averaged over
    ``num_f_samples``. The full estimate is averaged over
    ``num_s_samples`` and then divided by ``x.shape[0]``.

    Args:
        model: Object providing ``get_latent_prior``,
            ``get_auxiliary_dists``, ``get_latent_dists``,
            ``latent_decoder`` and ``auxiliary_decoder``.
        x: Model inputs. ``x.shape[0]`` normalises the estimate.
        y: Observations.
        mask: Optional observation mask, passed to the model and to the
            log p(y|f) likelihood. If ``None``, an all-ones mask is used
            for y in the auxiliary-decoder input.
        num_samples: If given, overrides ``num_s_samples`` and forces
            ``num_f_samples = 1``.
        num_s_samples: Number of samples of s drawn from q(s).
        num_f_samples: Number of samples of f drawn per sample of s.
        decoder_scale: Multiplier for the log p(y|f) term. If ``None``,
            it is derived from ``mask``, or set to 1 when there is no mask.
        debug: If True, also return the recorded per-term values.
        make_lazy: If True, wrap covariances with ``lazify`` before
            building the ``MultivariateNormal`` objects.

    Returns:
        The negated estimate. If ``debug`` is True, returns
        ``(negated estimate, terms)``, where ``terms`` maps term names to
        lists of Python floats. The log p(f) values are recorded under
        ``'norm_terms'``.
    """
    # Per-term values, recorded for debugging.
    terms = {'py_f_terms': [],
             'lf_sy_terms': [],
             'norm_terms': [],
             'rs_yf_terms': [],
             'qs_terms': []
             }

    if num_samples is not None:
        # Overwrite num_s_samples and num_f_samples.
        num_s_samples = num_samples
        num_f_samples = 1

    # Only derive decoder_scale when the caller did not supply one.
    if decoder_scale is None:
        if mask is not None:
            # NOTE(review): this was described as "the reciprocal of the
            # proportion of missing observations", but the value computed
            # is 1 - sum(mask) / (y.shape[0] * y.shape[1]). Confirm the mask
            # convention and the intended scaling.
            num_nan = 1. * torch.sum(mask)
            num_observations = y.shape[0] * y.shape[1]
            decoder_scale = 1. - num_nan / num_observations
        else:
            decoder_scale = 1.

    # Mask for y in the auxiliary-decoder input. The appended f gets an
    # all-ones mask below, so ones are treated as "observed". Without a
    # mask, treat all of y as observed (torch.cat cannot take None).
    aux_mask = torch.ones_like(y) if mask is None else mask

    estimator = 0

    # Latent GP prior.
    pf_mu, pf_cov = model.get_latent_prior(x)

    # Auxiliary distributions.
    qs_mu, qs_cov = model.get_auxiliary_dists(x, y, mask)

    # Prior distribution p(f), built once from the initial prior moments.
    if make_lazy:
        # Use GPyTorch MultivariateNormal class for sampling.
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        pf = MultivariateNormal(pf_mu, pf_cov)

    qs_var = torch.stack([cov.diag() for cov in qs_cov])

    # Monte-Carlo estimate of the hierarchical ELBO.
    for _ in range(num_s_samples):
        s = qs_mu + qs_var ** 0.5 * torch.randn_like(qs_mu)

        # Latent distributions.
        qf_s_mu, qf_s_cov, pf_mu, pf_cov, lf_sy_mu, lf_sy_cov = \
            model.get_latent_dists(x, y, s, mask, kff=pf_cov)

        # Required distributions.
        if make_lazy:
            # Use GPyTorch MultivariateNormal class for sampling.
            qf_s = MultivariateNormal(qf_s_mu, lazify(qf_s_cov))
        else:
            qf_s = MultivariateNormal(qf_s_mu, qf_s_cov)

        lf_sy_var = torch.stack([cov.diag() for cov in lf_sy_cov])

        # log q(s) term.
        qs_term = gaussian_diagonal_ll(s, qs_mu, qs_var).sum()
        estimator += - qs_term
        terms['qs_terms'].append(qs_term.item())

        # Terms that depend on f are accumulated here, then averaged over
        # the f samples. Adding them straight to `estimator` would sum them
        # when num_f_samples > 1 instead of averaging.
        inner_estimator = 0
        for _ in range(num_f_samples):
            f = qf_s.rsample()
            f_t = f.transpose(0, 1)

            # log p(y|f) term.
            py_f_mu, py_f_sigma = model.latent_decoder(f_t)
            py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2),
                                             mask)
            py_f_term = decoder_scale * py_f_term.sum()
            inner_estimator += py_f_term
            terms['py_f_terms'].append(py_f_term.item())

            # log r(s|y,f) term.
            y_ = torch.cat([y, f_t], dim=1)
            mask_ = torch.cat([aux_mask, torch.ones_like(f_t)], dim=1)
            rs_yf_mu, rs_yf_sigma = model.auxiliary_decoder(y_, mask_)
            rs_yf_term = gaussian_diagonal_ll(s, rs_yf_mu,
                                              rs_yf_sigma.pow(2)).sum()
            inner_estimator += rs_yf_term
            terms['rs_yf_terms'].append(rs_yf_term.item())

            # log l(f|s,y) term. The parameters are detached, so gradients
            # reach this term only through f.
            lf_sy_term = gaussian_diagonal_ll(f, lf_sy_mu.detach(),
                                              lf_sy_var.detach()).sum()
            inner_estimator += - lf_sy_term
            terms['lf_sy_terms'].append(lf_sy_term.item())

            # log p(f) term. f is detached, so gradients do not flow back
            # through the sample here.
            pf_term = pf.log_prob(f.detach()).sum()
            inner_estimator += pf_term
            terms['norm_terms'].append(pf_term.item())

        # Average over samples from q(f|s).
        estimator += inner_estimator / num_f_samples

    # Average over samples from q(s).
    estimator /= num_s_samples

    # Normalise by the batch size.
    estimator /= x.shape[0]

    if debug:
        return -estimator, terms
    return -estimator


def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True):
    """Return a Monte-Carlo estimate of the hierarchical ELBO.

    For each of ``num_samples`` iterations, s is drawn from q(s) and a
    single f is drawn from q(f|s), both using only the diagonal of their
    covariance. The estimate sums:

    - log Zq(s), the log-density at zero of N(lf_sy_mu, pf_cov + lf_sy_cov);
    - -log q(s);
    - log p(y|f), unscaled;
    - log r(s|y,f);
    - -log l(f|s,y).

    The result is averaged over samples. It is not negated and not divided
    by the batch size.

    Args:
        model: Object providing ``get_latent_prior``,
            ``get_auxiliary_dists``, ``get_latent_dists``,
            ``latent_decoder`` and ``auxiliary_decoder``.
        x: Model inputs.
        y: Observations.
        mask: Optional observation mask, passed to the model and to the
            log p(y|f) likelihood. If ``None``, an all-ones mask is used
            for y in the auxiliary-decoder input.
        num_samples: Number of joint (s, f) samples.
        make_lazy: If True, wrap the covariance with ``lazify`` before
            building the ``MultivariateNormal``.

    Returns:
        The ELBO estimate.
    """
    # Mask for y in the auxiliary-decoder input. The appended f gets an
    # all-ones mask below, so ones are treated as "observed". Without a
    # mask, treat all of y as observed (torch.cat cannot take None).
    aux_mask = torch.ones_like(y) if mask is None else mask

    elbo = 0

    # Latent GP prior.
    pf_mu, pf_cov = model.get_latent_prior(x)

    # Auxiliary distributions.
    qs_mu, qs_cov = model.get_auxiliary_dists(x, y, mask)
    qs_var = torch.stack([cov.diag() for cov in qs_cov])

    # Monte-Carlo estimate of ELBO.
    # See Spatio-Temporal VAEs: Hierarchical VI.
    for _ in range(num_samples):
        s = qs_mu + qs_var ** 0.5 * torch.randn_like(qs_mu)

        # Latent distributions.
        qf_s_mu, qf_s_cov, pf_mu, pf_cov, lf_sy_mu, lf_sy_cov = \
            model.get_latent_dists(x, y, s, mask, kff=pf_cov)
        sum_cov = pf_cov + lf_sy_cov

        # Required distributions.
        if make_lazy:
            # Use GPyTorch MultivariateNormal class for sampling.
            zq = MultivariateNormal(lf_sy_mu, lazify(sum_cov))
        else:
            zq = MultivariateNormal(lf_sy_mu, sum_cov)

        lf_sy_var = torch.stack([cov.diag() for cov in lf_sy_cov])
        qf_s_var = torch.stack([cov.diag() for cov in qf_s_cov])

        # log Zq(s) term.
        elbo += zq.log_prob(torch.zeros_like(lf_sy_mu)).sum()

        # log q(s) term.
        elbo += - gaussian_diagonal_ll(s, qs_mu, qs_var).sum()

        # Sample from the latent posterior (diagonal of q(f|s) only).
        f = qf_s_mu + qf_s_var ** 0.5 * torch.randn_like(qf_s_mu)
        f_t = f.transpose(0, 1)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.latent_decoder(f_t)
        elbo += gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2),
                                     mask).sum()

        # log r(s|y,f) term.
        y_ = torch.cat([y, f_t], dim=1)
        mask_ = torch.cat([aux_mask, torch.ones_like(f_t)], dim=1)
        rs_yf_mu, rs_yf_sigma = model.auxiliary_decoder(y_, mask_)
        elbo += gaussian_diagonal_ll(s, rs_yf_mu, rs_yf_sigma.pow(2)).sum()

        # log l(f|s,y) term.
        elbo += - gaussian_diagonal_ll(f, lf_sy_mu, lf_sy_var).sum()

    # Average over samples from q(f,s|y).
    elbo /= num_samples

    return elbo


def iwae_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True):
    """Return the importance-weighted (IWAE) estimate log((1/K) * sum_k w_k).

    Each of the K = ``num_samples`` log-weights is

        log p(f) + log p(y|f) + log r(s|y,f) - log q(s) - log q(f|s),

    with s ~ q(s) (diagonal) and f ~ q(f|s) drawn via ``rsample`` from the
    full covariance. The log-weights are combined with ``logsumexp`` minus
    log K, which gives the log of the *mean* importance weight.

    Args:
        model: Object providing ``get_latent_prior``,
            ``get_auxiliary_dists``, ``get_latent_dists``,
            ``latent_decoder`` and ``auxiliary_decoder``.
        x: Model inputs.
        y: Observations.
        mask: Optional observation mask, passed to the model and to the
            log p(y|f) likelihood. If ``None``, an all-ones mask is used
            for y in the auxiliary-decoder input.
        num_samples: Number of importance samples K (must be >= 1).
        make_lazy: If True, wrap covariances with ``lazify`` before
            building the ``MultivariateNormal`` objects.

    Returns:
        A 0-dim tensor holding the IWAE estimate.

    Raises:
        ValueError: If ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1, got {}'
                         .format(num_samples))

    # Mask for y in the auxiliary-decoder input. The appended f gets an
    # all-ones mask below, so ones are treated as "observed". Without a
    # mask, treat all of y as observed (torch.cat cannot take None).
    aux_mask = torch.ones_like(y) if mask is None else mask

    # Latent GP prior.
    pf_mu, pf_cov = model.get_latent_prior(x)

    # Auxiliary distributions.
    qs_mu, qs_cov = model.get_auxiliary_dists(x, y, mask)
    qs_var = torch.stack([cov.diag() for cov in qs_cov])

    # Prior distribution p(f), built once from the initial prior moments.
    if make_lazy:
        # Use GPyTorch MultivariateNormal class for sampling.
        pf = MultivariateNormal(pf_mu, lazify(pf_cov))
    else:
        pf = MultivariateNormal(pf_mu, pf_cov)

    # Collect the log-weights in a list and stack them afterwards. This keeps
    # the terms' dtype and device and avoids in-place writes into a
    # preallocated float32 CPU tensor.
    log_weights = []

    # Importance-weighted Monte-Carlo estimate.
    for _ in range(num_samples):
        s = qs_mu + qs_var ** 0.5 * torch.randn_like(qs_mu)

        # Latent distributions.
        qf_s_mu, qf_s_cov, pf_mu, pf_cov, lf_sy_mu, lf_sy_cov = \
            model.get_latent_dists(x, y, s, mask, kff=pf_cov)

        # Required distributions.
        if make_lazy:
            # Use GPyTorch MultivariateNormal class for sampling.
            qf_s = MultivariateNormal(qf_s_mu, lazify(qf_s_cov))
        else:
            qf_s = MultivariateNormal(qf_s_mu, qf_s_cov)

        # Sample from latent posterior distribution.
        f = qf_s.rsample()
        f_t = f.transpose(0, 1)

        # log p(y|f) term.
        py_f_mu, py_f_sigma = model.latent_decoder(f_t)
        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2),
                                         mask).sum()

        # log r(s|y,f) term.
        y_ = torch.cat([y, f_t], dim=1)
        mask_ = torch.cat([aux_mask, torch.ones_like(f_t)], dim=1)
        rs_yf_mu, rs_yf_sigma = model.auxiliary_decoder(y_, mask_)
        rs_yf_term = gaussian_diagonal_ll(s, rs_yf_mu,
                                          rs_yf_sigma.pow(2)).sum()

        # log p(f) term.
        pf_term = pf.log_prob(f).sum()

        # log q(s) term.
        qs_term = gaussian_diagonal_ll(s, qs_mu, qs_var).sum()

        # log q(f|s) term.
        qf_s_term = qf_s.log_prob(f).sum()

        log_weights.append(
            pf_term + py_f_term + rs_yf_term - qs_term - qf_s_term)

    log_weights = torch.stack(log_weights)

    # log((1/K) * sum_k exp(log w_k)). The -log K term makes this the IWAE
    # bound rather than log of the *sum* of the weights.
    return torch.logsumexp(log_weights, dim=0) - math.log(num_samples)
