import torch
import torch.nn as nn


def reconstruction_loss(reconstructed, data):
    """Return the mean squared error between a reconstruction and its target.

    Both tensors are collapsed to ``(data.shape[0], -1, data.shape[3])``
    before the comparison, so ``data`` must have at least four dimensions
    and ``reconstructed`` must hold the same number of elements as ``data``.

    Args:
        reconstructed: Tensor produced by the model.
        data: Target tensor; its first and fourth dimension sizes define
            the collapsed layout.

    Returns:
        A scalar tensor with the MSE averaged over every element.
    """
    criterion = nn.MSELoss(reduction='mean')
    batch_size, last_dim = data.shape[0], data.shape[3]
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. tensors that went through permute/transpose), while reshape()
    # returns a view when possible and copies only when it has to.
    reconstructed = reconstructed.reshape(batch_size, -1, last_dim)
    data = data.reshape(batch_size, -1, last_dim)
    return criterion(reconstructed, data)


def reconstruction_loss_keypoint(reconstructed, data):
    """Return the summed per-item MSE over paired reconstructions and targets.

    For every index ``i`` of ``reconstructed``, both ``reconstructed[i]`` and
    ``data[i]`` are collapsed to ``(data[i].shape[0], -1, data[i].shape[3])``
    and compared with a mean-reduced MSE; the per-item losses are added up.
    The input sequences are left untouched, so tuples and other read-only
    sequences are accepted as well as lists.

    Args:
        reconstructed: Sequence of tensors produced by the model.
        data: Sequence of target tensors, indexable by the same positions;
            each target must have at least four dimensions.

    Returns:
        A scalar tensor holding the sum of the per-item losses, or the float
        ``0.0`` when ``reconstructed`` is empty.

    Raises:
        IndexError: If ``data`` has fewer items than ``reconstructed``.
    """
    criterion = nn.MSELoss(reduction='mean')
    total_recon_loss = 0.0
    for i, recon_item in enumerate(reconstructed):
        # Index data explicitly (rather than zip) so a too-short `data`
        # still raises instead of being silently truncated.
        target_item = data[i]
        batch_size, last_dim = target_item.shape[0], target_item.shape[3]
        # Work on local names: writing the reshaped tensors back into the
        # caller's sequences would mutate them (and fail on tuples).
        # reshape() also handles non-contiguous tensors, unlike view().
        recon_flat = recon_item.reshape(batch_size, -1, last_dim)
        target_flat = target_item.reshape(batch_size, -1, last_dim)
        total_recon_loss = total_recon_loss + criterion(recon_flat, target_flat)
    return total_recon_loss


def kl_divergence_loss(mean, log_var):
    """Return ``-0.5 * sum(1 + log_var - mean**2 - exp(log_var))``.

    The sum runs over every element of the broadcast result, so the value
    is a single scalar tensor (no averaging over any dimension).

    Args:
        mean: Tensor of means.
        log_var: Tensor of log-variances, broadcast-compatible with ``mean``.

    Returns:
        A scalar tensor with the summed divergence term.
    """
    squared_mean = mean.pow(2)
    variance = log_var.exp()
    per_element = 1 + log_var - squared_mean - variance
    return -0.5 * per_element.sum()
