from __future__ import print_function
import torch
import torch.utils.data
from torch import nn

# Categorical cross-entropy whose per-sample losses are summed, not averaged
# (reduction='sum'); used as the reconstruction term of loss_functionM.
CELloss = nn.CrossEntropyLoss(reduction='sum')


def loss_functionM(M, target, mu, logvar):
    """Return the (reconstruction, KL) loss pair for a categorical output.

    Args:
        M: logits fed to ``CELloss`` -- presumably shaped (batch, classes);
            TODO confirm against the producing model.
        target: class labels; cast to ``long`` before being used as indices.
        mu: latent means.
        logvar: latent log-variances (``exp(logvar)`` is used as the variance).

    Returns:
        A 2-tuple: the summed categorical cross-entropy between ``M`` and
        ``target`` (the historical local name was ``BCE``, but this is not a
        binary cross-entropy), and the closed-form KL divergence of
        N(mu, exp(logvar)) from N(0, I), summed over every element.
    """
    labels = target.long()
    reconstruction = CELloss(M, labels)

    # Element-wise KL integrand; torch.sum with no dim reduces over all axes.
    kl_summand = 1 + logvar - torch.pow(mu, 2) - torch.exp(logvar)
    divergence = torch.sum(kl_summand) * -0.5
    return reconstruction, divergence


def loss_functionL(L, target, mu, logvar, reduction='mean'):
    """Return the (reconstruction, KL) loss pair for a real-valued output.

    Args:
        L: predictions, compared element-wise with ``target``.
        target: ground-truth values; broadcast against ``L``.
        mu: latent means.
        logvar: latent log-variances (``exp(logvar)`` is used as the variance).
        reduction: how the squared error is reduced. ``'mean'`` (the default,
            and the historical behaviour) averages over all elements.
            ``'sum'`` sums them, putting the reconstruction term on the same
            scale as the KL term, which is always summed.

    Returns:
        A 2-tuple ``(NGLL, KLD)``. ``NGLL`` is the reduced squared error
        between ``target`` and ``L``. Despite the name, it is plain squared
        error, with no 0.5 factor or additive constant of a Gaussian negative
        log-likelihood. ``KLD`` is the closed-form KL divergence of
        N(mu, exp(logvar)) from N(0, I), summed over every element.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'sum'``.
    """
    # NOTE(review): ``target - L`` broadcasts. Mismatched shapes such as (B,)
    # vs (B, 1) silently produce a (B, B) tensor; confirm shapes at call sites.
    squared_error = (target - L).pow(2)
    # With 'mean' the reconstruction term is divided by the element count
    # while KLD is not, so their relative weight depends on batch size and
    # output size; 'sum' keeps both terms on the same scale.
    if reduction == 'mean':
        NGLL = torch.mean(squared_error)
    elif reduction == 'sum':
        NGLL = torch.sum(squared_error)
    else:
        raise ValueError(
            "reduction must be 'mean' or 'sum', got %r" % (reduction,))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return NGLL, KLD


def loss_functionR(R, target, mu, logvar, reduction='mean'):
    """Return the (reconstruction, KL) loss pair for a real-valued output.

    Args:
        R: predictions, compared element-wise with ``target``.
        target: ground-truth values; broadcast against ``R``.
        mu: latent means.
        logvar: latent log-variances (``exp(logvar)`` is used as the variance).
        reduction: how the squared error is reduced. ``'mean'`` (the default,
            and the historical behaviour) averages over all elements.
            ``'sum'`` sums them, putting the reconstruction term on the same
            scale as the KL term, which is always summed.

    Returns:
        A 2-tuple ``(NGLL, KLD)``. ``NGLL`` is the reduced squared error
        between ``target`` and ``R``. Despite the name, it is plain squared
        error, with no 0.5 factor or additive constant of a Gaussian negative
        log-likelihood. ``KLD`` is the closed-form KL divergence of
        N(mu, exp(logvar)) from N(0, I), summed over every element.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'sum'``.
    """
    # NOTE(review): ``target - R`` broadcasts. Mismatched shapes such as (B,)
    # vs (B, 1) silently produce a (B, B) tensor; confirm shapes at call sites.
    squared_error = (target - R).pow(2)
    # With 'mean' the reconstruction term is divided by the element count
    # while KLD is not, so their relative weight depends on batch size and
    # output size; 'sum' keeps both terms on the same scale.
    if reduction == 'mean':
        NGLL = torch.mean(squared_error)
    elif reduction == 'sum':
        NGLL = torch.sum(squared_error)
    else:
        raise ValueError(
            "reduction must be 'mean' or 'sum', got %r" % (reduction,))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return NGLL, KLD


