import torch


def kl_divergence(logp, logq):
    """Monte Carlo estimate of the forward KL divergence KL(p || q).

    Uses KL(p || q) = E_p[log p - log q], so the inputs must be evaluated at
    samples drawn from p.

    Args:
        logp: log-densities of p at samples drawn from p.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor: the mean of ``logp - logq`` over all elements.
    """
    log_ratio = logp - logq
    return log_ratio.mean()


def reverse_kl_divergence(logp, logq):
    """Monte Carlo estimate of the reverse KL divergence KL(q || p).

    Uses KL(q || p) = E_q[log q - log p], so the inputs must be evaluated at
    samples drawn from q.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor: the mean of ``logq - logp`` over all elements.
    """
    log_ratio = logq - logp
    return log_ratio.mean()


def squared_hellinger_distance(logp, logq):
    """Monte Carlo estimate of the squared Hellinger distance between p and q.

    Computes ``mean(2 * (1 - sqrt(p / q)))`` over samples drawn from q. For
    normalized p and q this estimates ``integral (sqrt(p) - sqrt(q))**2``,
    i.e. the convention without the leading factor 1/2.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor.
    """
    # Work with the log-ratio directly: exponentiating logp and logq separately
    # underflows to 0 / 0 = NaN (value and gradient) when both are very negative.
    log_ratio = logp - logq
    # 2 * (1 - sqrt(t)) == -2 * expm1(log(t) / 2); expm1 keeps precision near t == 1.
    return torch.mean(-2 * torch.expm1(0.5 * log_ratio))


def total_variation_distance(logp, logq):
    """Monte Carlo estimate of the total variation distance between p and q.

    Computes ``mean(0.5 * |p / q - 1|)`` over samples drawn from q, which
    estimates ``0.5 * integral |p - q|``.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor.
    """
    # Form p / q from the log-ratio; exp(logp) / exp(logq) becomes 0 / 0 = NaN
    # when both log-densities are very negative.
    log_ratio = logp - logq
    # |t - 1| == |expm1(log t)|, which stays accurate when t is close to 1.
    return torch.mean(0.5 * torch.abs(torch.expm1(log_ratio)))


def chi_squared_divergence(logp, logq):
    """Monte Carlo estimate of the Pearson chi-squared divergence chi^2(p || q).

    Computes ``mean((p / q) ** 2 - 1)`` over samples drawn from q. For a
    normalized p this estimates ``integral (p - q)**2 / q``.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor.
    """
    # Form p / q from the log-ratio; exp(logp) / exp(logq) becomes 0 / 0 = NaN
    # when both log-densities are very negative.
    log_ratio = logp - logq
    # t ** 2 - 1 == expm1(2 * log t), accurate when t is close to 1.
    return torch.mean(torch.expm1(2 * log_ratio))


def alpha_divergence(logp, logq, alpha=0.0):
    """Monte Carlo estimate of Amari's alpha-divergence D_alpha(p || q).

    Computes ``mean(4 / (1 - alpha**2) * (1 - (p / q) ** ((1 - alpha) / 2)))``
    over samples drawn from q. With this sign convention, the limit
    alpha -> 1 is KL(q || p) and alpha -> -1 is KL(p || q). At alpha == 0
    the result is ``4 * (1 - E_q[sqrt(p / q)])``.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.
        alpha: order of the divergence; must satisfy -1 < alpha < 1.

    Returns:
        Scalar tensor.

    Raises:
        NotImplementedError: if alpha is +/-1 (the KL limits, which have
            dedicated estimators) or lies outside [-1, 1].
    """
    if alpha == 1:
        # alpha -> 1 gives KL(q || p), which reverse_kl_divergence estimates.
        raise NotImplementedError("Use reverse_kl_divergence instead.")
    elif alpha == -1:
        # alpha -> -1 gives KL(p || q), which kl_divergence estimates (samples from p).
        raise NotImplementedError("Use kl_divergence instead.")
    if not -1 <= alpha <= 1:
        raise NotImplementedError(f"alpha must satisfy -1 < alpha < 1, got {alpha}.")
    exponent = (1 - alpha) / 2
    # Use the log-ratio; exp(logp) / exp(logq) becomes 0 / 0 = NaN when both are
    # very negative.
    log_ratio = logp - logq
    # 1 - t ** exponent == -expm1(exponent * log t). This stays accurate as
    # alpha -> 1, where the exponent tends to 0.
    return torch.mean(-4 / (1 - alpha ** 2) * torch.expm1(exponent * log_ratio))


def jensen_shannon_divergence(logp, logq):
    """Monte Carlo estimate of the Jensen-Shannon divergence between p and q.

    With t = p / q and samples drawn from q, this computes
    ``mean(0.5 * ((t + 1) * log(2 / (t + 1)) + t * log(t)))``. For normalized
    p and q, this estimates ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` with
    ``m = (p + q) / 2``.

    Args:
        logp: log-densities of p at samples drawn from q.
        logq: log-densities of q at the same samples.

    Returns:
        Scalar tensor.
    """
    # Use the log-ratio; exp(logp) / exp(logq) becomes 0 / 0 = NaN when both are
    # very negative.
    log_ratio = logp - logq
    ratio = torch.exp(log_ratio)
    # log(2 / (t + 1)) == -log1p((t - 1) / 2), with t - 1 == expm1(log t).
    log_two_over_ratio_plus_one = -torch.log1p(0.5 * torch.expm1(log_ratio))
    # t * log(t) uses the exact log-ratio. If t underflows to 0 while log t is
    # finite, the term is 0 rather than 0 * log(0) = NaN.
    return torch.mean(0.5 * ((ratio + 1) * log_two_over_ratio_plus_one + ratio * log_ratio))


@torch.no_grad()
def ess(log_likelihood: torch.Tensor, log_prior: torch.Tensor, logq: torch.Tensor, beta=1.0):
    """Effective sample size of importance weights, reduced over dim 0.

    The weights are ``w = exp(beta * (log_likelihood + log_prior - logq))``
    and the result is ``(sum w) ** 2 / sum(w ** 2)``. It is evaluated in log
    space with ``logsumexp``, so large log-weights do not overflow. The
    function runs under ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        log_likelihood: per-sample log-likelihood values.
        log_prior: per-sample log-prior values.
        logq: per-sample log-density of the sampling distribution q.
        beta: exponent applied to the weights (w ** beta).

    Returns:
        Tensor with the effective sample size along dim 0 (a scalar for 1-D inputs).
    """
    log_weight = log_likelihood + log_prior - logq
    log_sum_w = torch.logsumexp(beta * log_weight, dim=0)
    log_sum_w_sq = torch.logsumexp(2.0 * beta * log_weight, dim=0)
    return torch.exp(2 * log_sum_w - log_sum_w_sq)
