import torch.distributions as dist


def kld_logits(p_logits, q_logits):
    """Return KL(P || Q) for two categorical distributions given as logits.

    Both arguments are wrapped in ``torch.distributions.Categorical`` via
    its ``logits=`` parameter, so they are treated as unnormalized log
    probabilities over the last dimension. Leading dimensions act as batch
    dimensions.

    Args:
        p_logits: Logits of the reference distribution P.
        q_logits: Logits of the approximating distribution Q.

    Returns:
        The KL divergence KL(P || Q), one value per batch element.
    """
    source = dist.Categorical(logits=p_logits)
    target = dist.Categorical(logits=q_logits)
    return dist.kl_divergence(source, target)


def kld_probs(p_probs, q_probs):
    """Return KL(P || Q) for two categorical distributions given as probabilities.

    Both arguments are wrapped in ``torch.distributions.Categorical`` via
    its ``probs=`` parameter. That parameter rescales them to sum to one
    over the last dimension. Leading dimensions act as batch dimensions.

    Args:
        p_probs: Probabilities (or non-negative weights) of P.
        q_probs: Probabilities (or non-negative weights) of Q.

    Returns:
        The KL divergence KL(P || Q), one value per batch element.
    """
    p_dist, q_dist = (dist.Categorical(probs=weights) for weights in (p_probs, q_probs))
    return dist.kl_divergence(p_dist, q_dist)
