import torch


class DistanceLoss(torch.nn.Module):
    """Loss built from the pairwise Euclidean distances between activations.

    The loss is the mean of all off-diagonal pairwise distances that are
    greater than or equal to ``upper_bound`` minus the mean of all
    off-diagonal pairwise distances that are strictly less than
    ``lower_bound``. Distances falling in ``[lower_bound, upper_bound)``
    do not contribute.

    Args:
        upper_bound: Distances at or above this value feed the positive term.
        lower_bound: Distances strictly below this value feed the negative
            term.
    """

    def __init__(self, upper_bound: float = 10.0, lower_bound: float = 5.0) -> None:
        super().__init__()
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        """Compute the distance loss for a set of activation vectors.

        Args:
            activations: Tensor whose rows are compared pairwise with
                ``torch.cdist``. The diagonal mask is sized from
                ``activations.shape[0]``, so a 2-D ``(N, D)`` input is
                expected.

        Returns:
            Scalar tensor ``pos_loss - neg_loss``. A term whose selection is
            empty (its mean would be NaN) contributes ``0``.
        """
        pair_distances = torch.cdist(activations, activations)

        # Drop the diagonal (distance of each row to itself). The mask must
        # live on the same device as the distances; otherwise masked_select
        # fails for inputs that are not on the default device.
        num_rows = activations.shape[0]
        off_diagonal = ~torch.eye(
            num_rows, dtype=torch.bool, device=activations.device
        )
        pair_distances = pair_distances.masked_select(off_diagonal)

        # mean() of an empty selection is NaN; map that to 0 so an absent
        # term simply does not contribute to the loss.
        pos_loss = pair_distances[pair_distances >= self.upper_bound].mean()
        pos_loss = torch.nan_to_num(pos_loss, nan=0.0)
        neg_loss = pair_distances[pair_distances < self.lower_bound].mean()
        neg_loss = torch.nan_to_num(neg_loss, nan=0.0)
        return pos_loss - neg_loss