import torch

from ltsgns_mp.architectures.util.chamfer_distance import padded_chamfer_distance, \
    reshape_tensor_for_padded_chamfer_distance


def log_likelihood_per_time_step(predictions: torch.Tensor, gth: torch.Tensor, likelihood_std: float, gth_type: str):
    """Score predictions against ground truth as an (unnormalized) Gaussian log-likelihood.

    Args:
        predictions: Predicted positions. For ``"mesh"`` the last axis is summed over and
            the axis before it is averaged over (presumably coordinates and nodes, respectively
            — TODO confirm against callers).
        gth: Ground truth, broadcast-compatible with ``predictions`` for ``"mesh"``.
        likelihood_std: Standard deviation of the isotropic likelihood; the squared error
            (or chamfer distance) is divided by its square.
        gth_type: ``"mesh"`` for per-node correspondence, ``"point_cloud"`` for a one-sided
            (forward-only) chamfer distance against a padded point cloud.

    Returns:
        ``-0.5 * error / likelihood_std**2``, with the two trailing axes reduced for
        ``"mesh"``. For ``"point_cloud"`` the result is reshaped to
        ``(chamfer.shape[0], *batch_shape)`` as reported by the reshape helper.
        No normalization constant (``log(2*pi*sigma^2)``) is included.

    Raises:
        NotImplementedError: If ``gth_type`` is not one of the supported values.
    """
    variance = likelihood_std ** 2

    if gth_type == "mesh":
        squared_error = (predictions - gth) ** 2
        node_scores = -0.5 * torch.sum(squared_error / variance, dim=-1)
        # average the per-node scores into a single value per time step
        return node_scores.mean(dim=-1)

    if gth_type == "point_cloud":
        # collapse to a single batch dimension as expected by the padded chamfer helper
        flat_pred, flat_gth, batch_shape = reshape_tensor_for_padded_chamfer_distance(predictions, gth)
        chamfer = padded_chamfer_distance(flat_pred, flat_gth, density_aware=False, forward_only=True,
                                          point_reduction="mean")
        # restore the batch dimensions that were flattened above
        chamfer = chamfer.reshape(chamfer.shape[0], *batch_shape)
        # interpret the distance as a squared error under the Gaussian likelihood
        return -0.5 * chamfer / variance

    raise NotImplementedError(f"Unknown ground truth type: {gth_type}")
