import torch


def kl_div(q: torch.distributions.Distribution, p: torch.distributions.Distribution) -> torch.Tensor:
    return torch.distributions.kl_divergence(q, p).mean()
