import torch
import torch.nn.functional as F


def siglip_loss(logits: torch.Tensor) -> torch.Tensor:
    """Compute the pairwise sigmoid loss from SigLIP.

    Reference: "Sigmoid Loss for Language Image Pre-Training"
    (https://arxiv.org/pdf/2303.15343).

    Entry ``(i, j)`` of ``logits`` is scored as an independent binary
    problem: diagonal entries are treated as positives (label ``+1``) and
    every off-diagonal entry as a negative (label ``-1``).

    Args:
        logits: Square 2-D tensor of pairwise logits. No scale or bias is
            applied inside this function, so the values are used as given.

    Returns:
        A scalar tensor holding the sum of ``-logsigmoid(label * logit)``
        over all entries. The sum is NOT divided by the batch size; the
        paper's reference pseudocode divides by it, so divide the result by
        ``logits.size(0)`` if that normalisation is wanted.

    Raises:
        ValueError: If ``logits`` is not a square 2-D tensor.
    """
    # Without this check a 1-D or (n, 1) input would silently broadcast
    # against the (n, n) label matrix and yield a meaningless loss.
    if logits.dim() != 2 or logits.size(0) != logits.size(1):
        raise ValueError(
            "siglip_loss expects a square 2-D logits matrix, got shape "
            f"{tuple(logits.shape)}"
        )

    batch_size = logits.size(0)
    # +1 on the diagonal (matching pairs), -1 everywhere else.
    # The labels keep the default dtype so that low-precision logits are
    # promoted before the reduction, exactly as in the original code.
    labels = 2 * torch.eye(batch_size, device=logits.device) - 1
    # Pairwise sigmoid loss, summed over every (i, j) pair.
    return -torch.sum(F.logsigmoid(labels * logits))
