import torch.nn.functional as F
import torch
# For simplicity, we also use uniqueness and synergy from other modalities as negative samples.
# This simplification introduces only a marginal performance difference.
def multi_view_contrastive_loss_diverse(feats, modality_num=2, temperature=0.2):
    """
    Compute contrastive loss using all modality pairs among the first `modality_num` features.
    Remaining features are treated as negatives.

    For every ordered pair (i, j) with i != j among the first `modality_num`
    entries, `feats[i]` is the anchor and the same-row embedding of `feats[j]`
    is its positive. The negatives of a row are the same-row embeddings of every
    entry after the first `modality_num`. The positive is always placed at
    logit index 0 and the loss is the cross-entropy against that index.

    Args:
        feats: list of [n, d] projected, normalized embeddings
        modality_num: number of modalities for which we define positive pairs
        temperature: temperature for contrastive scaling

    Returns:
        scalar contrastive loss, averaged over the
        `modality_num * (modality_num - 1)` ordered pairs. The plain float 0.0
        is returned when `modality_num < 2`. When `feats` contains no entries
        beyond the first `modality_num` (no negatives), every softmax is over
        the positive alone and the loss is a zero tensor.
    """
    if modality_num < 2:
        return 0.0
    device = feats[0].device
    n = feats[0].shape[0]

    # torch.stack rejects an empty sequence, so only stack when negatives exist.
    extra_feats = feats[modality_num:]
    if len(extra_feats) > 0:
        negatives = torch.stack(extra_feats, dim=1)  # shape: [n, num_neg, d]
    else:
        negatives = None

    # The positive always sits at index 0 of the logits, for every pair.
    labels = torch.zeros(n, dtype=torch.long, device=device)

    loss = 0.0

    for i in range(modality_num):
        anchor = feats[i]

        # Negative logits depend only on the anchor, so compute them once per
        # anchor rather than once per (anchor, positive) pair.
        neg_logits = None
        if negatives is not None:
            neg_logits = torch.bmm(negatives, anchor.unsqueeze(2)).squeeze(2)  # [n, num_neg]

        for j in range(modality_num):
            if i == j:
                continue

            # Positive logits: row-wise dot product between anchor and positive
            pos_logits = torch.sum(anchor * feats[j], dim=1, keepdim=True)  # [n, 1]

            if neg_logits is None:
                logits = pos_logits / temperature
            else:
                logits = torch.cat([pos_logits, neg_logits], dim=1) / temperature

            loss += F.cross_entropy(logits, labels)

    return loss / (modality_num * (modality_num - 1))


# For some datasets, using the raw contrastive loss may work better.
def multi_view_contrastive_loss(embeddings, modality_num=2, temperature=0.1):
    """
    Compute a symmetric in-batch contrastive loss between pairs of views.

    Only the first `modality_num` entries of `embeddings` are used. For each
    unordered pair of views (i, j), the [n, n] similarity matrix is scored
    with cross-entropy in both directions (i -> j and j -> i). Row k of one
    view is matched to row k of the other; all other rows of the batch act
    as negatives.

    Args:
        embeddings: list of projected, L2-normalized view tensors [n, d_e]
        modality_num: number of modalities for which we define positive pairs
        temperature: scalar or learnable torch.tensor
    Returns:
        scalar loss averaged over the ordered view pairs actually used. The
        plain float 0.0 is returned when fewer than two views are available
        (`modality_num < 2` or fewer than two entries in `embeddings`).
    """
    if modality_num < 2:
        return 0.0
    views = embeddings[:modality_num]
    num_views = len(views)
    if num_views < 2:
        # No pair can be formed, so there is nothing to contrast.
        return 0.0

    n = views[0].shape[0]
    labels = torch.arange(n, device=views[0].device)

    loss = 0.0
    for i in range(num_views):
        for j in range(i + 1, num_views):
            logits_ij = views[i] @ views[j].T / temperature
            loss += F.cross_entropy(logits_ij, labels)
            # The j -> i logits are the transpose of the i -> j logits, so
            # reuse them instead of running a second matmul.
            loss += F.cross_entropy(logits_ij.T, labels)

    # Average over the number of terms actually summed.
    return loss / (num_views * (num_views - 1))