import torch

def calc_alignment(representation, label):
    """Compute a label-similarity-weighted alignment score for a batch.

    Each row of *representation* is L2-normalised along ``dim=1``. For every
    sample (the "anchor"), the squared Euclidean distances to all *other*
    samples are averaged, weighted by label similarity. The per-anchor
    averages are then averaged over the batch.

    Label similarity is ``1 - (d - min) / (max - min)``. Here ``d`` is the
    mean absolute difference between two label rows, and min/max are taken
    over the full pairwise matrix (diagonal included). Pairs with identical
    labels therefore get weight 1, and the most distant pair gets weight 0.

    Args:
        representation: floating-point tensor with ``n_samples`` rows;
            presumably shaped ``(n_samples, dim)`` (normalisation is applied
            along ``dim=1``) -- TODO confirm with callers.
        label: tensor with ``n_samples`` rows (scalars or vectors). Non-float
            labels are promoted to float because ``mean`` needs a float dtype.

    Returns:
        The alignment score as a Python float.

    Notes:
        * If all pairwise label distances are equal (e.g. every label is the
          same), the min-max rescale is undefined. Every pair then gets
          weight 1 instead of producing NaN.
        * Anchors whose neighbours all have similarity 0 (0/0) are skipped.
          If no anchor remains (e.g. fewer than two samples, or two samples
          with different labels), the result is NaN.
    """
    n_samples = len(representation)

    representation = torch.nn.functional.normalize(representation, dim=1)

    # Pairwise squared Euclidean distances between the normalised
    # representations, via broadcasting (1, n, d) - (n, 1, d) -> (n, n, d).
    ro = representation.reshape(1, n_samples, -1)
    rt = representation.reshape(n_samples, 1, -1)
    rep = (ro - rt).pow(2).sum(dim=2)

    # Pairwise label distance (mean absolute difference). `mean` is not
    # defined for integer tensors, so promote non-float labels first.
    if not torch.is_floating_point(label):
        label = label.float()
    lo = label.reshape(1, n_samples, -1)
    lt = label.reshape(n_samples, 1, -1)
    s = (lo - lt).abs().mean(dim=2)
    del lo, lt  # release the (n, n, k) intermediates early

    # Min-max rescale label distances into similarities in [0, 1].
    m = s.min()
    width = s.max() - m
    if width == 0:
        # Every pair is equally distant: weight all pairs uniformly rather
        # than dividing by zero. (A NaN width still falls through to the
        # rescale below, so NaN labels are not silently hidden.)
        s = torch.ones_like(s)
    else:
        s = 1.0 - (s - m) / width

    # Drop the diagonal (each sample paired with itself).
    mask = ~torch.eye(n_samples, device=representation.device).bool()
    masked_rep = rep.masked_select(mask).view(n_samples, -1)
    masked_s = s.masked_select(mask).view(n_samples, -1)

    # Similarity-weighted mean distance per anchor.
    numerator = (masked_s * masked_rep).sum(dim=1)
    denominator = masked_s.sum(dim=1)

    # Weights are >= 0, so a zero denominator means every weight in that row
    # is 0 and the row is 0/0. Exclude such anchors instead of letting one
    # NaN poison the batch mean.
    has_weight = denominator != 0
    per_anchor = numerator[has_weight] / denominator[has_weight]
    return per_anchor.mean().item()

def calc_uniformity(representation, t=2):
    """Return the log of the mean Gaussian potential over all sample pairs.

    Computes ``log(mean(exp(-t * ||x_i - x_j||^2)))`` over every unordered
    pair ``i < j`` of rows of *representation*, using ``torch.pdist``.

    Args:
        representation: 2-D tensor of shape ``(n_samples, dim)``, as
            ``torch.pdist`` requires. Rows are NOT normalised here; pass
            already-normalised features if that is what the metric expects.
        t: kernel temperature. The default of 2 reproduces the previous
            hard-coded behaviour.

    Returns:
        The score as a Python float. With a single row there are no pairs,
        so the mean (and thus the result) is NaN.
    """
    squared_dists = torch.pdist(representation, p=2).pow(2)
    return squared_dists.mul(-t).exp().mean().log().item()

def main():
    """Script entry point; currently a placeholder that does nothing."""
    return None

# Run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()