import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over two sets of embeddings.

    Row ``i`` of ``image_features`` is treated as the positive match of row
    ``i`` of ``tabular_features``; every other row in the batch acts as a
    negative. The loss is the average of the image->tabular and
    tabular->image cross-entropies.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax inside the cross-entropy.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, image_features, tabular_features):
        """Return the scalar contrastive loss for a batch of paired features.

        Both inputs are presumably 2-D ``(batch, dim)`` tensors with the same
        batch size (the code relies on ``.T`` and ``shape[0]``) -- confirm
        against the caller.
        """
        # L2-normalise along the last axis so the dot product below is a
        # cosine similarity.
        img_unit = F.normalize(image_features, dim=-1)
        tab_unit = F.normalize(tabular_features, dim=-1)

        # Entry (i, j) compares image i with tabular row j.
        similarity = (img_unit @ tab_unit.T) / self.temperature

        # The matching pair sits on the diagonal, so the target class of
        # row i is simply i.
        targets = torch.arange(image_features.shape[0], device=image_features.device)

        # Same similarity matrix, read row-wise and then column-wise.
        image_to_tabular = self.cross_entropy_loss(similarity, targets)
        tabular_to_image = self.cross_entropy_loss(similarity.T, targets)

        return (image_to_tabular + tabular_to_image) / 2


class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the element's own class,
    and by a class-balancing weight derived from ``alpha``.

    Args:
        gamma: Focusing exponent; larger values down-weight well-classified
            elements more strongly.
        alpha: Weight applied to positive targets; negatives get
            ``1 - alpha``.
        size_average: ``'mean'`` averages the per-element losses. Any other
            value (documented use: ``'sum'``) sums them.
    """

    def __init__(self, gamma=2, alpha=0.5, size_average='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        """Return the reduced focal loss.

        Args:
            input: Raw (pre-sigmoid) logits.
            target: Binary labels with the same shape as ``input``. Float,
                integer and bool tensors are all accepted.

        Returns:
            A scalar tensor: the mean of the per-element losses when
            ``size_average == 'mean'``, otherwise their sum.
        """
        # binary_cross_entropy_with_logits and the ``1 - target`` arithmetic
        # below need a floating-point target; integer/bool label tensors
        # would otherwise raise. Floating targets are left untouched so
        # existing behaviour is unchanged.
        if not torch.is_floating_point(target):
            target = target.to(input.dtype)

        bce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')

        # p_t: probability assigned to the element's own class, chosen by
        # thresholding the target at 0.5.
        p = torch.sigmoid(input)
        p = torch.where(target >= 0.5, p, 1 - p)
        modulating_factor = (1 - p) ** self.gamma

        # alpha for positives, (1 - alpha) for negatives.
        alpha = self.alpha * target + (1 - self.alpha) * (1 - target)

        focal_loss = alpha * modulating_factor * bce_loss

        if self.size_average == 'mean':
            return focal_loss.mean()
        return focal_loss.sum()