import torch
import torch.nn.functional as F

from mow.common.tensor import split_views


def _multi_positive_nll(
    sims: torch.Tensor, pos_mask: torch.Tensor, dim: int
) -> torch.Tensor:
    """
    Return the mean negative log of the softmax mass that lands on positives.

    Softmax is taken over ``dim`` of ``sims``; the probabilities at entries
    where ``pos_mask`` is True are summed, and ``-log`` of that sum is
    averaged over the remaining axis. This equals
    ``-log((sims.exp() * pos_mask).sum(dim) / sims.exp().sum(dim)).mean()``
    but is evaluated with ``torch.logsumexp`` so that large ``sims`` values
    (e.g. from a small temperature) cannot overflow ``exp`` into
    ``inf / inf = NaN``.
    """
    # Non-positive entries become -inf so they contribute exp(-inf) = 0.
    log_pos = torch.logsumexp(
        sims.masked_fill(~pos_mask, float("-inf")), dim=dim
    )
    log_all = torch.logsumexp(sims, dim=dim)
    # A slice with no positives gives log_pos = -inf, i.e. an infinite loss,
    # the same result as -log(0) in the direct formulation.
    return (log_all - log_pos).mean()


def contrastive_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 0.1,
    lambda_: float = 0.0,
) -> torch.Tensor:
    """
    Blend an index-matched and a label-aware contrastive loss over two views.

    ``logits`` and ``labels`` are each split into two views with
    ``split_views(..., num_views=2)``. Presumably these are the two halves
    along the batch axis, with row ``i`` of view 1 paired with row ``i`` of
    view 2 (TODO confirm against ``split_views``).

    Both terms are built from the temperature-scaled cosine similarity matrix
    ``sims[i, j] = cos(view1[i], view2[j]) / temperature``:

    * ``loss_a``: cross-entropy over rows and over columns (averaged) whose
      target for row/column ``i`` is index ``i``.
    * ``loss_b``: every pair with equal labels is a positive; each row and
      each column contributes ``-log`` of the total softmax probability on
      its positives (rows and columns averaged).

    Args:
        logits: Embeddings for both views. Each view is expected to be
            ``(batch_size, dim)``; cosine similarity is taken over the last
            axis.
        labels: Labels for both views (presumably 1-D per view), compared
            with ``==`` to decide which pairs are positives.
        temperature: Divisor applied to the cosine similarities.
        lambda_: Mixing weight. The result is
            ``(1 - w) * loss_a + w * loss_b`` with
            ``w = lambda_ / (lambda_ + 1)``, so 0 gives only ``loss_a`` and
            larger values shift toward ``loss_b``. Not validated;
            ``lambda_ == -1`` divides by zero.

    Returns:
        A scalar tensor.
    """
    logits_1, logits_2 = split_views(logits, num_views=2)
    labels_1, labels_2 = split_views(labels, num_views=2)

    # (batch_size, batch_size): sims[i, j] compares logits_1[i] with logits_2[j].
    sims = F.cosine_similarity(
        logits_1.unsqueeze(1), logits_2.unsqueeze(0), dim=-1
    )
    sims = sims / temperature

    # Index-matched term: the positive for row i (and column i) is index i.
    targets = torch.arange(logits_1.shape[0], device=logits.device)
    loss_a = (
        F.cross_entropy(sims, targets)
        + F.cross_entropy(sims.transpose(0, 1), targets)
    ) / 2.0

    # Label-aware term: every same-label pair counts as a positive.
    pos_mask = labels_1.unsqueeze(1) == labels_2.unsqueeze(0)
    loss_b = (
        _multi_positive_nll(sims, pos_mask, dim=-1)  # over rows
        + _multi_positive_nll(sims, pos_mask, dim=0)  # over columns
    ) / 2.0

    lambda_norm = lambda_ / (lambda_ + 1.0)
    return (1.0 - lambda_norm) * loss_a + lambda_norm * loss_b
