import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class InfoNCE(nn.Module):
    """Symmetric InfoNCE-style contrastive loss between two feature batches.

    Both inputs are L2-normalised along the last dimension, a scaled
    similarity matrix is built between them, and ``loss_function`` is applied
    to that matrix and to its transpose with the diagonal as the target
    (row ``i`` of the first batch is paired with row ``i`` of the second).
    The two terms are averaged.

    When ``world_size > 1`` the features of every rank are collected with
    ``torch.distributed.all_gather`` so that the other ranks' samples act as
    additional negatives.

    Args:
        loss_function: Callable invoked as ``loss_function(logits, labels)``
            where ``labels`` is a 1-D ``long`` tensor of class indices
            (e.g. ``nn.CrossEntropyLoss()``).
        device: Stored on the module as ``self.device``. It is not used to
            place tensors; labels follow the device of the logits.
        world_size: Number of participating processes. Values ``<= 1``
            disable the distributed gather entirely, so the module can be
            used without an initialised process group.
    """

    def __init__(self, loss_function, device='cuda', world_size=1):
        super().__init__()
        self.loss_function = loss_function
        # Retained so existing callers reading ``.device`` keep working.
        self.device = device
        self.world_size = world_size

    def _gather_across_ranks(self, local_features: torch.Tensor) -> torch.Tensor:
        """Return the local features followed by those of every other rank."""
        if self.world_size <= 1:
            # Single-process run: skip torch.distributed, which would raise
            # when no default process group has been initialised.
            return local_features

        rank = dist.get_rank()
        # NOTE(review): all_gather presumably requires every rank to
        # contribute a tensor of identical shape -- confirm the data loader
        # yields equal batch sizes on all ranks.
        buffers = [torch.zeros_like(local_features) for _ in range(self.world_size)]
        dist.all_gather(buffers, local_features)

        # The gathered copy of this rank's own features is dropped in favour
        # of the original tensor, presumably because the buffers filled by
        # all_gather are not connected to the autograd graph.
        return torch.cat([local_features] + buffers[:rank] + buffers[rank + 1:])

    def forward(self, image_features1, image_features2, logit_scale):
        """Compute the averaged two-direction contrastive loss.

        Args:
            image_features1: Feature tensor; normalised over its last dim.
                Assumed to be 2-D ``(batch, dim)`` -- TODO confirm.
            image_features2: Feature tensor paired row-by-row with
                ``image_features1``.
            logit_scale: Multiplier applied to the first feature matrix
                before the similarity matmul.

        Returns:
            The mean of ``loss_function`` evaluated on the similarity matrix
            and on its transpose.
        """
        image_features1 = F.normalize(image_features1, dim=-1)
        image_features2 = F.normalize(image_features2, dim=-1)

        all_image_features1 = self._gather_across_ranks(image_features1)
        all_image_features2 = self._gather_across_ranks(image_features2)

        # `*` and `@` share precedence and associate left-to-right, so this
        # is (logit_scale * features1) @ features2.T, as in the original.
        logits_per_image1 = logit_scale * all_image_features1 @ all_image_features2.T
        logits_per_image2 = logits_per_image1.T

        # Matching pairs sit on the diagonal. Labels are created on the
        # logits' device so the loss works wherever the inputs live.
        labels = torch.arange(
            logits_per_image1.shape[0],
            dtype=torch.long,
            device=logits_per_image1.device,
        )

        loss_1_to_2 = self.loss_function(logits_per_image1, labels)
        loss_2_to_1 = self.loss_function(logits_per_image2, labels)
        return (loss_1_to_2 + loss_2_to_1) / 2

