
import torch
import sklearn
import torch.nn.functional as F

def pairwise_distance(a: torch.Tensor, squared: bool = False):
    '''
        Computes the pairwise Euclidean distance matrix between the rows of
        `a` with numerical stability, together with the matrix of inner
        products between the L2-normalized rows (cosine similarity).

        :param a: torch.Tensor (M, sz_embedding)
        :param squared: if True, will compute (euclidean_dist)^2
        :return: tuple (pairwise_distances, inner_prod)
            pairwise_distances: distance torch.Tensor (M, M) whose diagonal
                is explicitly set to zero
            inner_prod: torch.Tensor (M, M), a_norm @ a_norm^T where a_norm
                is `a` with every row L2-normalized
    '''

    # Inner products of the L2-normalized rows.
    a_norm = F.normalize(a, p=2, dim=-1)
    inner_prod = torch.mm(a_norm, torch.t(a_norm))

    # Euclidean distance in dot-product form:
    #   ||x - y||^2 = x'x - 2x'y + y'y
    # The squared row norms are computed once as an (M, 1) column and
    # broadcast against their own transpose, so the norm terms are exactly
    # symmetric and nothing is reduced twice.
    squared_norms = a.pow(2).sum(dim=1, keepdim=True)
    pairwise_distances_squared = (
        squared_norms + torch.t(squared_norms) - 2 * torch.mm(a, torch.t(a))
    )

    # Deal with numerical inaccuracies. Set small negatives to zero.
    pairwise_distances_squared = torch.clamp(
        pairwise_distances_squared, min=0.0
    )

    if squared:
        pairwise_distances = pairwise_distances_squared
    else:
        # Entries that are exactly zero after clamping.
        zero_mask = torch.le(pairwise_distances_squared, 0.0)

        # sqrt has an unbounded derivative at 0, so nudge the exact zeros by
        # a tiny epsilon before taking the root ...
        pairwise_distances = torch.sqrt(
            pairwise_distances_squared + zero_mask.float() * 1e-16
        )
        # ... and multiply them back to exactly zero afterwards (this also
        # zeroes the gradient flowing through those entries).
        pairwise_distances = pairwise_distances * (~zero_mask).float()

    # Explicitly set diagonals to zero since it is the distance to itself.
    mask_offdiagonals = 1 - torch.eye(
        a.size(0),
        device=pairwise_distances.device
    )
    pairwise_distances = pairwise_distances * mask_offdiagonals

    return pairwise_distances, inner_prod



