import torch


def logsumexp_mse(predictions: torch.Tensor, labels: torch.Tensor, logsumexp_dim: int = 0) -> torch.Tensor:
    """
    Compute the logsumexp, over one dimension, of the mean squared error of predictions vs. labels.

    The squared error ``(predictions - labels) ** 2`` is averaged over every dimension except
    ``logsumexp_dim``. This gives one MSE value per slice along that dimension. Those values are
    then combined with ``torch.logsumexp``.

    Args:
        predictions: Predicted values.
        labels: Target values. Must be broadcastable with ``predictions``.
        logsumexp_dim: Dimension (of the broadcasted error tensor) that is kept through the mean
            and then reduced by logsumexp. Negative values count from the end, as in torch.

    Returns:
        A 0-d tensor holding the logsumexp of the per-slice MSE values.

    Raises:
        IndexError: If ``logsumexp_dim`` is out of range for the broadcasted error tensor.
    """
    # Use the broadcasted error tensor's rank, not just predictions', so broadcasting labels works.
    squared_error = (predictions - labels) ** 2
    ndim = squared_error.dim()
    if ndim == 0:
        # A single error value: logsumexp over one element is that element.
        return squared_error
    if not -ndim <= logsumexp_dim < ndim:
        raise IndexError(
            f"logsumexp_dim {logsumexp_dim} is out of range for a {ndim}-dimensional error tensor"
        )
    # Normalize negative dims; otherwise `d != -1` never matches and everything gets averaged away.
    keep_dim = logsumexp_dim % ndim
    reduce_dims = tuple(d for d in range(ndim) if d != keep_dim)
    # Skip the mean when there is nothing to average, so an empty dim tuple is never passed to torch.mean.
    per_slice_mse = squared_error.mean(dim=reduce_dims) if reduce_dims else squared_error
    # Only `keep_dim` survives the mean, so the result is 1-D and dim=0 is the right axis.
    return torch.logsumexp(per_slice_mse, dim=0)
