import torch
import torch.nn.functional as F

class CLIPLoss(torch.nn.Module):
    """Symmetric contrastive (CLIP-style) loss between two feature batches.

    Row ``i`` of the source features is treated as the positive match of
    row ``i`` of the target features; every other row in the batch acts as
    a negative. The loss is the mean of the source->target and the
    target->source cross-entropy terms.
    """

    def __init__(self, n_views, temperature):
        """Store the loss configuration.

        Args:
            n_views: Kept as an attribute; it is not used by ``forward``.
            temperature: Softmax temperature. Logits are multiplied by
                ``1 / temperature``.
        """
        super().__init__()
        self.n_views = n_views
        # Stored as the inverse temperature, i.e. the logit scale.
        self.temperature = torch.tensor(1 / temperature, requires_grad=False)

    def forward(self, target_feats, source_feats):
        """Compute the similarity logits and the symmetric contrastive loss.

        Args:
            target_feats: 2-D tensor with one feature vector per row.
            source_feats: 2-D tensor with the same number of rows as
                ``target_feats``; row ``i`` is the positive pair of
                ``target_feats[i]``.

        Returns:
            A tuple ``(logits_per_source, logits_per_target, labels,
            loss_val)`` where ``logits_per_target`` is the transpose of
            ``logits_per_source``, ``labels`` is ``arange(batch)`` and
            ``loss_val`` is a scalar tensor.
        """
        device = target_feats.device

        # L2-normalise each row. F.normalize clamps the norm with a small
        # epsilon, so an all-zero row yields zeros instead of NaN.
        source_feats = F.normalize(source_feats, dim=1)
        target_feats = F.normalize(target_feats, dim=1)

        # Scaled cosine similarity as logits, shape = [batch, batch].
        logit_scale = self.temperature
        logits_per_source = logit_scale * source_feats @ target_feats.t()
        logits_per_target = logits_per_source.t()

        # The matching pair for row i sits on the diagonal (column i).
        labels = torch.arange(logits_per_target.size(0), device=device)
        loss_source = F.cross_entropy(logits_per_source, labels)
        # Use the transposed logits here so the target->source direction
        # actually contributes to the loss.
        loss_target = F.cross_entropy(logits_per_target, labels)
        loss_val = (loss_source + loss_target) / 2

        return logits_per_source, logits_per_target, labels, loss_val
