import torch


class KLDivergence:
    """Element-wise KL-divergence penalty computed from two sets of log-probabilities.

    With ``x = reference_model_logp - trained_model_logp`` the value returned is
    ``exp(x) - x - 1``, which is non-negative for every real ``x`` and zero
    exactly when the two log-probabilities agree.

    NOTE(review): this has the form of the low-variance "k3" KL estimator;
    interpreting it as an estimate of KL(trained || reference) presumes the
    log-probabilities were evaluated on samples drawn from the trained
    model -- confirm against the caller.
    """

    def __init__(self) -> None:
        """Create the (stateless) callable; there is nothing to configure."""

    def __call__(
        self,
        trained_model_logp: torch.Tensor,
        reference_model_logp: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``exp(x) - x - 1`` for ``x = reference_model_logp - trained_model_logp``.

        Args:
            trained_model_logp: Log-probabilities under the model being trained.
            reference_model_logp: Log-probabilities under the reference model.
                Combined with ``trained_model_logp`` by ordinary tensor
                subtraction, so the usual broadcasting rules apply.

        Returns:
            A tensor of per-element penalties with the broadcast shape of the
            two inputs. No reduction is applied.
        """
        log_ratio = reference_model_logp - trained_model_logp
        # expm1(x) == exp(x) - 1, but evaluated without first rounding exp(x)
        # to a value next to 1.0. When the two models nearly agree (x ~ 0) the
        # naive `exp(x) - x - 1` loses every significant digit in float32 and
        # can even come out slightly negative; this form keeps the ~x**2 / 2
        # leading term intact.
        return torch.expm1(log_ratio) - log_ratio
