import torch

#https://github.com/facebookresearch/spreadingvectors/blob/main/train.py#L25
def pairwise_NNs_inner(x):
    """
    Pairwise nearest neighbors for L2-normalized vectors.
    Uses Torch rather than Faiss to remain on GPU.

    Args:
        x: 2-D tensor of shape (n, d), one vector per row. Rows are expected
            to be L2-normalized, so that a larger inner product means a
            smaller Euclidean distance.

    Returns:
        1-D tensor of n indices; entry i is the row of ``x`` with the largest
        inner product with row i, with row i itself excluded from the search.
        When n == 1 there is no other candidate and index 0 is returned.
    """
    # pairwise dot products (larger dot product = smaller distance)
    dots = torch.mm(x, x.t())
    # Mask self-matches with the lowest representable value. A mask of -1 is
    # not safe: -1 is a reachable inner product for unit vectors (antipodal
    # pairs, or slightly below -1 after rounding), so a row could tie with or
    # lose to its own masked entry and be reported as its own neighbor.
    if dots.is_floating_point():
        lowest = float("-inf")
    else:
        lowest = torch.iinfo(dots.dtype).min
    dots.fill_diagonal_(lowest)
    _, I = torch.max(dots, 1)  # max inner prod -> min distance
    return I

# Run the detector on a batch of batch_size non-marked images and ensure that the output vectors (and therefore the pfas) are well-behaved on the hypersphere, i.e. uniformly distributed.
####detector
# Uniformity loss over one batch of detector outputs: the loss decreases as
# each output moves away from its nearest neighbour within the batch.
pdist = torch.nn.PairwiseDistance(2)
# NOTE(review): assumes preds is a 2-D (batch, dim) tensor whose rows are
# L2-normalized, as pairwise_NNs_inner expects -- confirm that the detector
# normalizes its output (careful about normalization).
preds = detector_batched(imgs_batch)
n = preds.shape[0]  # n is the number of samples in batch

# The neighbour search runs over the whole batch at once (it needs a 2-D
# input) and only yields indices, so it is done on a detached tensor; the
# gradient flows through the distance computation below instead.
nearest_idx = pairwise_NNs_inner(preds.detach())
# Each row is an anchor; gather the vector of its nearest neighbour.
nearest_neighbour_ins = preds[nearest_idx]
distances = pdist(preds, nearest_neighbour_ins)
loss_uniform = -torch.log(n * distances).mean()
