import torch
import torch.nn as nn


def reconstruction_loss_mpjpe(reconstructed, data):
    """Return the mean per-joint position error over non-padded frames.

    Both inputs are laid out as (B, T, J, 3). A frame of ``data`` is
    treated as padding when every one of its coordinates is exactly zero;
    padded frames contribute neither to the error sum nor to the frame
    count.

    Args:
        reconstructed: Predicted joint positions, same shape as ``data``.
        data: Ground-truth joint positions, zero-filled on padded frames.

    Returns:
        A scalar tensor: the per-frame mean joint distance, averaged over
        all non-padded frames in the batch. Returns 0 when every frame is
        padding.
    """
    # A frame is real if it has any non-zero coordinate. Testing the *sum*
    # of coordinates would misclassify real frames whose values happen to
    # cancel out (e.g. root-centred or symmetric poses).
    valid = (data != 0).any(dim=-1).any(dim=-1)  # (B, T)

    # Guard the divisor so an all-padding batch gives 0 instead of 0/0 = NaN.
    num_valid = valid.sum().clamp(min=1)

    diff = reconstructed - data  # (B, T, J, 3)
    joint_distances = torch.norm(diff, dim=-1)  # (B, T, J)
    frame_error = joint_distances.mean(dim=2)  # (B, T)

    # Zero out padded frames before accumulating over time and batch.
    masked_error = frame_error * valid
    return masked_error.sum(dim=1).sum(dim=0) / num_valid


def reconstruction_loss_keypoint(reconstructed, data):
    """Return the summed mean-squared error over paired tensors.

    For each index ``i``, ``data[i]`` must have at least four dimensions;
    both ``reconstructed[i]`` and ``data[i]`` are reshaped to
    ``(data[i].shape[0], -1, data[i].shape[3])`` and compared with an
    element-wise mean MSE. The per-pair losses are added together.

    The inputs are not modified, so the same sequences can be passed
    again (e.g. to compute another metric or on the next call).

    Args:
        reconstructed: Indexable collection of predicted tensors.
        data: Indexable collection of target tensors, at least as long as
            ``reconstructed``.

    Returns:
        The sum of the per-pair MSE losses as a tensor, or the float
        ``0.0`` when ``reconstructed`` is empty.
    """
    criterion = nn.MSELoss(reduction='mean')
    total_recon_loss = 0.0
    for i, recon in enumerate(reconstructed):
        target = data[i]
        flat_shape = (target.shape[0], -1, target.shape[3])
        # Work on local reshaped tensors rather than writing them back into
        # the caller's containers; reshape (unlike view) also accepts
        # non-contiguous tensors.
        recon_flat = recon.reshape(flat_shape)
        target_flat = target.reshape(flat_shape)
        total_recon_loss = total_recon_loss + criterion(recon_flat, target_flat)
    return total_recon_loss


def kl_divergence_loss(mean, log_var):
    """Return the KL divergence term computed from ``mean`` and ``log_var``.

    Evaluates ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))``, summed
    over every element of the inputs (no averaging is applied).
    """
    mean_squared = mean.pow(2)
    variance = log_var.exp()
    per_element = 1 + log_var - mean_squared - variance
    return -0.5 * per_element.sum()
