import torch

def calculate_embedding_distance(e1, e2):
    """Return the Euclidean (L2) distance between two embeddings.

    The difference ``e1 - e2`` is flattened by ``norm`` when no ``dim`` is
    given, so tensors of any matching/broadcastable shape are accepted.

    Args:
        e1: First embedding tensor.
        e2: Second embedding tensor (broadcast-compatible with ``e1``).

    Returns:
        A 0-d tensor holding ``||e1 - e2||_2``.
    """
    delta = e1 - e2
    return delta.norm(p=2)

def calculate_score_field_distance(s1, s2, sqrt_num_points):
    """Return the L2 distance between two score fields, scaled by a divisor.

    Args:
        s1: First score-field tensor.
        s2: Second score-field tensor (broadcast-compatible with ``s1``).
        sqrt_num_points: Value the raw distance is divided by; the name
            suggests the square root of the number of sample points
            (presumably making this an RMS-style distance — confirm with
            callers). Not checked for zero here.

    Returns:
        ``||s1 - s2||_2 / sqrt_num_points``.
    """
    raw_distance = (s1 - s2).norm(p=2)
    return raw_distance / sqrt_num_points

def calculate_token_level_distance(tl1, tl2, normalize=True):
    """Return the L2 distance between two sequences of token-level vectors.

    Rows whose L2 norm is exactly zero are removed from each input (wherever
    they occur, not only at the end), presumably because they are padding.
    The remaining rows are compared position-by-position: the shorter
    sequence is right-padded with zero rows so both have
    ``L = max(len1, len2)`` rows, and the Frobenius norm of the difference
    is taken.

    Args:
        tl1: 2-D tensor of shape ``(num_tokens_1, D)``.
        tl2: 2-D tensor of shape ``(num_tokens_2, D)``.
        normalize: If True, divide the distance by ``sqrt(L)``, i.e. return
            the root-mean-square of the per-token distances.

    Returns:
        A 0-d tensor. If both inputs contain only zero rows (``L == 0``) the
        result is 0 rather than NaN.
    """
    # Strip all-zero rows from each sequence.
    tl1 = tl1[tl1.norm(p=2, dim=1) > 0]
    tl2 = tl2[tl2.norm(p=2, dim=1) > 0]

    L1, L2 = tl1.shape[0], tl2.shape[0]
    L = max(L1, L2)

    # Right-pad the shorter sequence with zero rows so the shapes line up;
    # new_zeros keeps each tensor's own dtype and device.
    if L1 < L:
        tl1 = torch.cat([tl1, tl1.new_zeros(L - L1, tl1.shape[1])], dim=0)
    if L2 < L:
        tl2 = torch.cat([tl2, tl2.new_zeros(L - L2, tl2.shape[1])], dim=0)

    dist = torch.norm(tl1 - tl2, p=2)

    # When L == 0 both inputs were entirely zero rows: dist is already 0, and
    # dividing by sqrt(0) would turn it into NaN.
    if normalize and L > 0:
        dist = dist / (L ** 0.5)
    return dist
    