import torch
import random

class BCEwithProj(torch.nn.Module):
    """Binary cross-entropy on a learnable affine projection of the logits.

    Each raw score ``x`` is mapped to ``alpha * x + beta`` before the sigmoid
    BCE is applied. Both ``alpha`` and ``beta`` are trainable scalar
    parameters, so the model can learn its own temperature and bias.
    """

    def __init__(self, init_alpha: float = 1.0, init_beta: float = 0.0):
        """
        Args:
            init_alpha: initial value of the learnable scale (default 1.0).
            init_beta: initial value of the learnable bias (default 0.0).
        """
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(float(init_alpha)))
        self.beta = torch.nn.Parameter(torch.tensor(float(init_beta)))

    def forward(self, logits_per_image: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Return the mean BCE between sigmoid(alpha * logits + beta) and ``gt``.

        Args:
            logits_per_image: raw scores, shape (N, M).
            gt: target labels, shape (N, M). Must be a floating tensor with the
                same dtype as the logits; values are expected in [0, 1].

        Returns:
            Scalar tensor holding the mean loss over all elements.
        """
        projected = self.alpha * logits_per_image + self.beta

        # Use the fused logits form instead of sigmoid() + binary_cross_entropy().
        # The separate form saturates: once sigmoid(projected) rounds to 0 or 1,
        # the log is clamped and the gradient w.r.t. alpha/beta becomes zero.
        # It is also disallowed under autocast. The fused form applies the
        # log-sum-exp trick, so the value is the same for ordinary inputs and
        # stays finite with meaningful gradients for extreme ones.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(projected, gt)

        # Sampled (~1% of calls) diagnostic of the learned projection.
        # .item() forces a device sync, which is why it is not done every step.
        if random.random() < 0.01:
            print("BCEwithProj alpha: {}, beta: {}".format(self.alpha.item(), self.beta.item()))
        return loss
