import torch

def kl_divergence(mu, logsigma, reduction='mean'):
    kl = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp())
    if reduction == 'sum':
        return kl.sum()
    else:
        return kl.sum(dim=-1).mean()