import torch


def batch_l2_distance(a: torch.Tensor, b: torch.Tensor, last_dimensions: int = 1) -> torch.Tensor:
    """Euclidean (L2) distance between ``a`` and ``b`` over their trailing dimensions.

    The trailing ``last_dimensions`` dimensions are treated as one flattened
    vector. All leading dimensions are kept as batch dimensions. ``a`` and
    ``b`` are broadcast against each other before the difference is taken.
    For example, a ``(B, N, D)`` batch can be compared with a single
    ``(N, D)`` or ``(1, D)`` reference.

    Args:
        a: First tensor.
        b: Second tensor, broadcastable with ``a``.
        last_dimensions: Number of trailing dimensions to reduce over; must be >= 1.

    Returns:
        The L2 distances. The shape is the broadcast shape of ``a - b`` with
        its last ``last_dimensions`` dimensions removed.
        - Integer or bool inputs are computed in, and returned as,
          ``torch.get_default_dtype()``.
        - Complex inputs yield a real tensor of ``|a - b|`` norms.
        - The gradient is 0 (not NaN) where the distance is exactly 0.

    Raises:
        ValueError: If ``last_dimensions < 1``.
        IndexError: Raised by ``Tensor.flatten`` if ``last_dimensions``
            exceeds the number of dimensions of the broadcast difference.
    """
    if last_dimensions < 1:
        raise ValueError(f"last_dimensions must be >= 1, got {last_dimensions}")

    # Integer dtypes would wrap on subtraction (e.g. uint8 3 - 5 == 254), so
    # compute in the default float dtype -- the same dtype torch.sqrt already
    # produced for integer inputs in the previous implementation.
    dtype = torch.result_type(a, b)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()

    # Subtract BEFORE flattening so broadcasting also works inside the reduced
    # trailing dimensions, and mismatched trailing shapes are rejected.
    diff = a.to(dtype) - b.to(dtype)

    # vector_norm (instead of pow/sum/sqrt) yields a zero gradient at zero
    # distance rather than NaN, and uses |z| for complex inputs.
    return torch.linalg.vector_norm(diff.flatten(-last_dimensions), ord=2, dim=-1)
