

import torch

def attn_mask_from_radius(x, y, thresh, unsqueeze_head_dim=True):
    """
    Build a boolean attention mask that keeps only pairs lying within a radius.

    Every vector in ``x`` is compared against every vector in ``y`` using the
    Euclidean (L2) distance from ``torch.cdist``; a pair is kept when its
    distance does not exceed ``thresh``.

    Args:
        x (torch.Tensor): Tensor of shape (..., n, d) holding n vectors of dimension d.
        y (torch.Tensor): Tensor of shape (..., m, d) holding m vectors of dimension d.
        thresh (float): Radius; pairs with distance <= thresh are marked True.
        unsqueeze_head_dim (bool): When True (the default), a singleton axis is
            inserted at index 1 of the mask. The flag name suggests this is
            meant to broadcast over an attention-head axis.

    Returns:
        torch.BoolTensor: Mask of shape (..., n, m) when ``unsqueeze_head_dim``
        is False. When it is True the same mask has an extra size-1 axis at
        index 1, e.g. (b, 1, n, m) for inputs with a single leading batch axis.
        The axis always goes to index 1, so unbatched 2-D inputs give (n, 1, m).
    """
    # Pairwise L2 distances have shape (..., n, m); the comparison yields bools.
    within_radius = torch.cdist(x, y, p=2) <= thresh

    if not unsqueeze_head_dim:
        return within_radius

    # Add the singleton axis directly after the first dimension.
    return within_radius.unsqueeze(1)
