import torch


def mse(predictions: torch.Tensor, labels: torch.Tensor, ptr: torch.Tensor | None = None) -> torch.Tensor:
    """
    Computes the mean squared error of the predicted positions given the internal next mesh positions.
    """
    if ptr is None:
        loss = torch.mean((predictions - labels) ** 2)
    else:
        # use scatter mean to compute the mse for each batch element separately
        loss = torch.scatter(src=(predictions - labels) ** 2, dim=0, index=ptr, reduce="mean")
    return loss