
import torch
import torch.nn.functional as F
from collections import defaultdict


class ContrastiveLoss(torch.nn.Module):
    """Multi-positive InfoNCE-style contrastive loss over explicit index pairs.

    Let ``s(i, j) = cos(embeddings[i], embeddings[j]) / temperature``. Take an
    anchor ``a`` that appears as the first column of ``pos_pairs``, and one of
    its positives ``p``. The per-pair loss is::

        -log( exp(s(a, p)) / (sum_{p'} exp(s(a, p')) + sum_{n} exp(s(a, n))) )

    The sums run over every positive and negative pair whose first column is
    ``a``. Per-pair losses are averaged within each anchor, then across
    anchors. Negative pairs whose anchor has no positive pair do not contribute.

    The loss is evaluated in log space (a per-anchor log-sum-exp), so it stays
    finite even for small temperatures where ``exp(1 / temperature)`` would
    overflow.

    Args:
        temperature: Divisor applied to cosine similarities before the softmax.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, embeddings: torch.Tensor, pos_pairs: torch.Tensor, neg_pairs: torch.Tensor) -> torch.Tensor:
        """Compute the scalar contrastive loss.

        Args:
            embeddings: Tensor indexed by the pair indices; cosine similarity is
                taken along dim 1 (presumably ``(N, D)`` — confirm with caller).
            pos_pairs: Integer index tensor whose rows are ``(anchor, positive)``.
            neg_pairs: Integer index tensor whose rows are ``(anchor, negative)``.
                It may be empty.

        Returns:
            A 0-dim tensor. If ``pos_pairs`` is empty there are no anchors. In
            that case a zero that is still attached to ``embeddings`` is
            returned, so ``backward()`` remains callable.
        """
        device = embeddings.device
        # Keep index tensors on the embeddings' device so the scatter/index ops
        # below do not mix devices.
        pos_pairs = pos_pairs.to(device)
        neg_pairs = neg_pairs.to(device)

        if pos_pairs.numel() == 0:
            return embeddings.sum() * 0.0

        pos_anchor = pos_pairs[:, 0]
        pos_logits = self._pair_logits(embeddings, pos_pairs)
        if neg_pairs.numel() > 0:
            neg_anchor = neg_pairs[:, 0]
            neg_logits = self._pair_logits(embeddings, neg_pairs)
        else:
            neg_anchor = pos_anchor.new_empty(0)
            neg_logits = pos_logits.new_empty(0)

        # Map every pair to a compact anchor slot. The first len(pos) entries
        # belong to the positive pairs.
        anchors, group = torch.unique(torch.cat([pos_anchor, neg_anchor]), return_inverse=True)
        num_groups = anchors.numel()
        pos_group = group[: pos_logits.shape[0]]

        all_logits = torch.cat([pos_logits, neg_logits])
        log_denom = self._grouped_logsumexp(all_logits, group, num_groups)

        # -log(exp(s_p) / denom) == log(denom) - s_p
        pair_loss = log_denom[pos_group] - pos_logits

        # Mean per anchor, then mean over anchors that own at least one positive.
        loss_sum = pair_loss.new_zeros(num_groups).index_add(0, pos_group, pair_loss)
        pos_count = torch.bincount(pos_group, minlength=num_groups)
        has_pos = pos_count > 0
        return (loss_sum[has_pos] / pos_count[has_pos]).mean()

    def _pair_logits(self, embeddings, pairs):
        """Return ``cos(embeddings[i], embeddings[j]) / temperature`` for each row ``(i, j)``."""
        return F.cosine_similarity(embeddings[pairs[:, 0]], embeddings[pairs[:, 1]]) / self.temperature

    @staticmethod
    def _grouped_logsumexp(values, groups, num_groups):
        """Return, for each group id in ``[0, num_groups)``, logsumexp of its ``values``."""
        # Shift by the per-group max for stability. The max is detached because
        # the shift cancels out of the gradient of logsumexp.
        group_max = torch.full((num_groups,), float("-inf"), dtype=values.dtype, device=values.device)
        group_max = group_max.scatter_reduce(0, groups, values.detach(), reduce="amax", include_self=True)
        shifted = torch.exp(values - group_max[groups])
        sums = values.new_zeros(num_groups).index_add(0, groups, shifted)
        # Every group holds its own max element, so sums >= 1 and log is finite.
        return group_max + torch.log(sums)


def cosine_similarity_stats(embeddings, pos_pairs, neg_pairs):
    """Summarize the cosine similarities of positive and negative pairs.

    Args:
        embeddings: Tensor indexed by the pair indices; similarity is taken
            along dim 1 (presumably ``(N, D)`` — confirm with caller).
        pos_pairs: Index tensor whose rows are ``(i, j)`` positive pairs.
        neg_pairs: Index tensor whose rows are ``(i, j)`` negative pairs.

    Returns:
        ``(pos_mean, pos_std, neg_mean, neg_std)`` as Python floats. ``std``
        uses torch's default estimator.
    """
    def _pair_sims(pairs):
        lhs = embeddings[pairs[:, 0]]
        rhs = embeddings[pairs[:, 1]]
        return F.cosine_similarity(lhs, rhs)

    summary = []
    for sims in (_pair_sims(pos_pairs), _pair_sims(neg_pairs)):
        summary.extend((sims.mean().item(), sims.std().item()))
    return tuple(summary)
