import torch.nn as nn

from . import functional as F

__all__ = ["KLLoss"]


class KLLoss(nn.Module):
    def forward(self, x, y):
        return F.kl_loss(x, y)
