import logging

import torch
import torch.distributions as tdist
from torch import Tensor

# Module-level logger. It is fetched by the fixed name 'custom' rather than
# __name__, presumably so it shares handlers/levels configured for 'custom'
# elsewhere in the project — confirm before renaming.
logger = logging.getLogger('custom')


def get_mc_estimate(p, q, samples) -> Tensor:
    """ Get per-sample Monte-Carlo terms for the kl-divergence KL(q||p).

    When samples are drawn from q (as in a VAE):
        KL = E_q [ log q / p ]
    so each sample z contributes log q(z) - log p(z). No reduction is
    applied here. Averaging over the sampling dimension yields the
    Monte-Carlo estimate. Summing over latent dimensions is left to the
    caller (``reduce_kl`` in this module does that summation).

    :param p: distribution to compare against (the prior in a VAE setting);
        must provide ``log_prob``
    :param q: distribution the samples were drawn from (the posterior in a
        VAE setting); must provide ``log_prob``
    :param samples: samples from q
    :return: tensor ``q.log_prob(samples) - p.log_prob(samples)``, with the
        shape produced by ``log_prob`` (no reduction)
    """
    # Keep distinct names for the log-densities instead of rebinding p/q,
    # which previously changed those names from distributions to tensors.
    log_q = q.log_prob(samples)
    log_p = p.log_prob(samples)
    return log_q - log_p


def reduce_kl(kl) -> torch.Tensor:
    """ Collapse a KL tensor to one value per sample.

    Delegates to ``sum_over_dims``, which sums over the non-batch
    dimensions.

    :param kl: KL tensor in one of the shapes accepted by ``sum_over_dims``
    :return: one tensor that summarizes KL-regularization
        - shape (N,) or K x N
    """
    per_sample_kl = sum_over_dims(kl)
    return per_sample_kl


def sum_over_dims(x):
    """ Sum over the non-batch dimensions of ``x``, inferred from its rank.

    :param x: tensor with one of the supported shapes
        - N x D                (rank 2)
        - K x N x D            (rank 3)
        - N x C x H x W        (rank 4)
        - K x N x C x H x W    (rank 5)
    :return: shape (N,) or K x N
    :raises ValueError: for any other rank
    """
    rank = len(x.size())
    if rank in (2, 3):
        # flat latent space: reduce the trailing feature dimension D
        return x.sum(-1)
    if rank in (4, 5):
        # spatial latent space: reduce the trailing C, H and W dimensions
        return x.sum((-3, -2, -1))
    raise ValueError('Illegal shape.')


def accumulate_distributions(old: dict, new: dict):
    """ Merge the two-level nested dict ``old`` into ``new``.

    Both dictionaries are assumed to have identical structure. Every
    second-level value of ``old`` is combined with its counterpart in
    ``new``. ``new`` is updated in place and also returned, which is
    helpful for iteratively collecting forward pass values.
    """
    # walk old[k1][k2] in the same order as nested iteration would
    leaves = ((k1, k2, v2)
              for k1, inner in old.items()
              for k2, v2 in inner.items())
    for k1, k2, v2 in leaves:
        _unpack_multiple_distributions(k1, k2, v2, new)
    return new


def _unpack_multiple_distributions(k1, k2, v2, new):
    """ Accumulate ``v2`` into ``new[k1][k2]``.

    ``v2`` is either a single leaf dict (e.g. the reconstruction) or a list
    of leaf dicts (posterior, conditional priors, ancestral samples). In
    the list case, each entry is merged into the matching list position
    of ``new[k1][k2]``. Falsy entries are skipped.
    """
    if isinstance(v2, list):
        for idx, entry in enumerate(v2):
            # a falsy entry marks an unconditional prior: nothing to merge
            if entry:
                _unpack_single_distribution(entry, new[k1][k2][idx])
        return
    # reconstruction
    _unpack_single_distribution(v2, new[k1][k2])


def _unpack_single_distribution(old, new):
    """ Concatenates leaf nodes on CPU, storing the results in ``new``.

    ``old`` and ``new`` are dicts with the keys ``'samples'`` and ``'dist'``.
    ``new`` is modified in place. Values from ``old`` are placed before
    values from ``new``.

    - ``'samples'``: concatenated along dim 1, the importance sampling
      dimension. This is skipped when ``old['samples']`` is None.
    - ``'dist'``: for Categorical, Normal, Bernoulli and Laplace
      distributions, the parameters are concatenated along the batch
      dimension. Any other distribution type cannot be merged. A warning
      is logged and ``new['dist']`` is left as-is, so the parameters from
      ``old`` are not carried over.
    """
    # samples: concatenate along importance sampling dimension
    if old['samples'] is not None:
        # we do not use samples for categorical distributions
        new['samples'] = torch.cat(
            (old['samples'].cpu(), new['samples'].cpu()), dim=1)

    old_dist = old['dist']
    if old_dist is None:
        return

    if isinstance(old_dist, tdist.Categorical):
        new['dist'] = _accumulate_categorical(old_dist, new['dist'])
    elif isinstance(old_dist, tdist.Normal):
        new['dist'] = _accumulate_normal(old_dist, new['dist'])
    elif isinstance(old_dist, tdist.Bernoulli):
        new['dist'] = _accumulate_bernoulli(old_dist, new['dist'])
    elif isinstance(old_dist, tdist.Laplace):
        new['dist'] = _accumulate_laplace(old_dist, new['dist'])
    else:
        # Previously this case was silently ignored, which discarded the
        # parameters accumulated in `old` without any trace. Keep the
        # best-effort behaviour but make the data loss visible.
        logger.warning(
            'Cannot accumulate distribution of type %s; keeping only the '
            'newest values.', type(old_dist).__name__)


def _accumulate_categorical(old_dist, new_dist):
    """ Merge two Categorical distributions into one.

    The logits of both are moved to CPU and concatenated along the batch
    dimension chosen by ``_get_batch_dimension``. ``old_dist`` comes first.
    """
    logits = [old_dist.logits.cpu(), new_dist.logits.cpu()]
    merged = torch.cat(logits, dim=_get_batch_dimension(logits[0]))
    return tdist.Categorical(logits=merged)


def _accumulate_normal(old_dist, new_dist):
    """ Merge two Normal distributions into one.

    Both the means and the scales are moved to CPU and concatenated along
    the batch dimension chosen by ``_get_batch_dimension``. ``old_dist``
    comes first.
    """
    pair = (old_dist, new_dist)
    locs = [d.mean.cpu() for d in pair]
    scales = [d.scale.cpu() for d in pair]
    batch_dim = _get_batch_dimension(locs[0])
    return tdist.Normal(loc=torch.cat(locs, dim=batch_dim),
                        scale=torch.cat(scales, dim=batch_dim))


def _accumulate_bernoulli(old_dist, new_dist):
    """ Merge two Bernoulli distributions into one.

    The logits of both are moved to CPU and concatenated along the batch
    dimension chosen by ``_get_batch_dimension``. ``old_dist`` comes first.
    """
    logits = [old_dist.logits.cpu(), new_dist.logits.cpu()]
    merged = torch.cat(logits, dim=_get_batch_dimension(logits[0]))
    return tdist.Bernoulli(logits=merged)


def _accumulate_laplace(old_dist, new_dist):
    """ Merge two Laplace distributions into one.

    Both the means (locations) and the scales are moved to CPU and
    concatenated along the batch dimension chosen by
    ``_get_batch_dimension``. ``old_dist`` comes first.
    """
    pair = (old_dist, new_dist)
    locs = [d.mean.cpu() for d in pair]
    scales = [d.scale.cpu() for d in pair]
    batch_dim = _get_batch_dimension(locs[0])
    return tdist.Laplace(loc=torch.cat(locs, dim=batch_dim),
                         scale=torch.cat(scales, dim=batch_dim))


def _get_batch_dimension(x):
    """ Return the index of the batch dimension of ``x`` based on its rank.

    Ranks 3 and 5 carry a leading importance sampling dimension, so the
    batch sits at index 1. Ranks 2 and 4 have no such dimension, so the
    batch sits at index 0.

    :raises ValueError: for any other rank
    """
    batch_dim_by_rank = {2: 0, 3: 1, 4: 0, 5: 1}
    rank = len(x.size())
    if rank not in batch_dim_by_rank:
        raise ValueError('Invalid input size.')
    return batch_dim_by_rank[rank]
