import torch

def l2_distance(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Return pairwise Euclidean distances between the rows of X and Y.

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, so the
    bulk of the work is a single matrix multiply.

    Args:
        X: Tensor of shape (N, D), or (..., N, D) with leading batch dims.
        Y: Tensor of shape (M, D), or (..., M, D); leading batch dims must
            broadcast against those of X.

    Returns:
        Tensor of shape (N, M) (or (..., N, M)) whose entry [..., i, j] is
        ||X[..., i, :] - Y[..., j, :]||_2.

    Note:
        The expansion is subject to floating-point cancellation, so nearly
        identical points may come out as a small positive distance rather
        than exactly 0. If that matters, compare against
        ``torch.cdist(X, Y, compute_mode="donot_use_mm_for_euclid_dist")``.
    """
    x_sq = (X ** 2).sum(dim=-1, keepdim=True)      # (..., N, 1)
    y_sq = (Y ** 2).sum(dim=-1).unsqueeze(-2)      # (..., 1, M)
    cross = X @ Y.transpose(-2, -1)                # (..., N, M)

    dist_sq = x_sq + y_sq - 2 * cross
    # Rounding can push true zeros slightly negative; clamp before sqrt.
    dist_sq = torch.clamp(dist_sq, min=0.0)

    # d/dz sqrt(z) is infinite at z == 0, so a plain torch.sqrt yields inf/NaN
    # gradients wherever a distance is exactly zero (e.g. a point against
    # itself). Give sqrt a dummy positive input at those positions and select
    # 0 afterwards: the forward value is unchanged, the gradient there is 0.
    positive = dist_sq > 0
    safe_sq = torch.where(positive, dist_sq, torch.ones_like(dist_sq))
    dist = torch.sqrt(safe_sq)
    return torch.where(positive, dist, torch.zeros_like(dist))
