import torch
import torch.nn.functional as F
import torch.nn as nn

class InfoNCELoss(torch.nn.Module):
    """InfoNCE contrastive loss with explicit and in-batch negatives.

    Every anchor is scored, by temperature-scaled cosine similarity, against
    all rows of ``pos_emb``, all rows of ``neg_emb`` and every *other* anchor
    in the batch. The loss is the mean negative log-softmax probability of
    each anchor's own positives; positives that belong to other anchors
    therefore act as additional negatives.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax.
    """

    def __init__(self, temperature=0.05):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor_emb, pos_emb, neg_emb):
        """Compute the loss for one batch.

        Args:
            anchor_emb: ``[B, D]`` anchor embeddings.
            pos_emb: ``[B * P, D]`` positive embeddings, stored so that rows
                ``i * P`` to ``(i + 1) * P - 1`` belong to anchor ``i``.
            neg_emb: ``[M, D]`` negative embeddings, shared by all anchors
                (``M`` may be zero).

        Returns:
            Scalar tensor holding the mean loss over all ``B * P``
            anchor/positive pairs.

        Raises:
            ValueError: If the batch is empty, or if the number of positive
                rows is zero or not a multiple of ``B``.
        """
        B = anchor_emb.size(0)
        num_pos = pos_emb.size(0)
        num_neg = neg_emb.size(0)
        if B == 0:
            raise ValueError("anchor_emb must contain at least one anchor")
        if num_pos == 0 or num_pos % B != 0:
            # Floor division would otherwise misalign positives with anchors
            # (or yield P == 0 and a NaN loss) without any error.
            raise ValueError(
                f"pos_emb must hold a positive multiple of {B} rows "
                f"(one equal-sized group per anchor), got {num_pos}"
            )
        P = num_pos // B

        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
        pos_emb = F.normalize(pos_emb, p=2, dim=1)
        neg_emb = F.normalize(neg_emb, p=2, dim=1)

        # Candidate set, in column order: positives, negatives, anchors.
        candidates = torch.cat([pos_emb, neg_emb, anchor_emb], dim=0)
        # [B, B * P + M + B] scaled cosine similarities.
        sim_matrix = torch.matmul(anchor_emb, candidates.t()) / self.temperature

        device = anchor_emb.device
        rows = torch.arange(B, device=device).unsqueeze(1)
        cols = torch.arange(candidates.size(0), device=device).unsqueeze(0)

        # Row i owns the contiguous positive columns [i * P, (i + 1) * P).
        pos_mask = (cols >= rows * P) & (cols < (rows + 1) * P)

        # An anchor's similarity with itself is always 1 / temperature, the
        # largest possible logit. Left in the denominator it swamps the real
        # negatives, so exclude it while keeping the other anchors.
        self_mask = cols == rows + (num_pos + num_neg)
        sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))

        # logsumexp is the numerically stable form of log(sum(exp(.))).
        log_denominator = torch.logsumexp(sim_matrix, dim=1, keepdim=True)
        pos_logits = sim_matrix[pos_mask].view(B, P)
        return -(pos_logits - log_denominator).mean()

class TripletLoss(nn.Module):
    """Margin-based triplet loss over mean-pooled positive/negative groups.

    For each anchor, its group of positives and its group of negatives are
    averaged into one vector each. Anchor and pooled vectors are then
    L2-normalised and compared by squared Euclidean distance; the loss is the
    batch mean of ``relu(d_pos - d_neg + margin)``.

    Args:
        margin: Amount by which the negative distance should exceed the
            positive distance before a triplet stops contributing.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    @staticmethod
    def _pooled_unit(samples, batch_size):
        """Average each anchor's group of rows and L2-normalise the result."""
        group_size = samples.size(0) // batch_size
        grouped = samples.view(batch_size, group_size, -1)
        return F.normalize(torch.mean(grouped, dim=1), p=2, dim=1)

    def forward(self, anchor, positive, negative):
        """Return the scalar triplet loss for one batch.

        Args:
            anchor: ``[B, D]`` anchor embeddings.
            positive: ``[B * P, D]`` positives, grouped contiguously by anchor.
            negative: ``[B * N, D]`` negatives, grouped contiguously by anchor.
        """
        batch_size = anchor.size(0)

        # Pooling happens before normalisation, so the mean is taken over the
        # raw (un-normalised) embeddings.
        pos_unit = self._pooled_unit(positive, batch_size)
        neg_unit = self._pooled_unit(negative, batch_size)
        anchor_unit = F.normalize(anchor, p=2, dim=1)

        # Squared Euclidean distances between unit vectors.
        d_pos = torch.sum((anchor_unit - pos_unit) ** 2, dim=1)
        d_neg = torch.sum((anchor_unit - neg_unit) ** 2, dim=1)

        hinge = F.relu(d_pos - d_neg + self.margin)
        return torch.mean(hinge)