import torch.nn.functional as func
import torch



def loss_function(recon_x, x):
    """Return the total squared error between a reconstruction and its target.

    Args:
        recon_x: Reconstructed tensor.
        x: Target tensor that ``recon_x`` is compared against.

    Returns:
        A zero-dimensional tensor holding the squared differences summed
        over every element (mean squared error with ``reduction='sum'``).
    """
    return func.mse_loss(input=recon_x, target=x, reduction='sum')


def loss_function_sumless(recon_x, x):
    """Return the squared error summed over dims 1-3, leaving dim 0 unreduced.

    Args:
        recon_x: Reconstructed tensor; must have at least four dimensions
            because dims 1, 2 and 3 are reduced.
        x: Target tensor that ``recon_x`` is compared against.

    Returns:
        A tensor of squared-error sums with dims 1, 2 and 3 removed.
        Presumably dim 0 is the batch axis, giving one loss per sample --
        confirm against callers.
    """
    # Element-wise squared differences, same shape as the inputs.
    squared_error = func.mse_loss(recon_x, x, reduction='none')
    return squared_error.sum(dim=(1, 2, 3))