import torch
import torch.nn as nn

class GeneralizedCrossEntropyLoss(nn.Module):
    """Generalized cross-entropy (GCE) loss for binary classification.

    For every element the loss is ``(1 - p_y ** q) / q``, where ``p_y`` is the
    probability the model assigns to the target class::

        p_y = target * p + (1 - target) * (1 - p)

    With ``q = 1`` this reduces to ``1 - p_y``; as ``q`` approaches 0 the
    expression approaches ``-log(p_y)``.

    Args:
        q: Exponent of the loss, must lie in ``(0, 1]``.
        reduction: ``"mean"`` (default) averages over all elements, ``"sum"``
            adds them up, ``"none"`` returns the element-wise loss tensor.

    Raises:
        ValueError: If ``q`` is outside ``(0, 1]`` or ``reduction`` is not one
            of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, q=0.7, reduction="mean"):
        super().__init__()
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would let q == 0 through and divide by zero.
        if not 0 < q <= 1:
            raise ValueError("q should be in (0, 1]")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction should be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.q = q
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the GCE loss.

        Args:
            input (torch.Tensor): Predicted probabilities (i.e. after the
                sigmoid), not logits.
            target (torch.Tensor): Labels in ``{0, 1}``; float, integer or
                boolean tensors are accepted.

        Returns:
            torch.Tensor: A scalar for ``"mean"``/``"sum"`` reduction,
            otherwise the element-wise loss.

        Note:
            ``input`` and ``target`` are combined element-wise, so they should
            have the same shape; mismatched shapes such as ``(N, 1)`` vs
            ``(N,)`` broadcast silently instead of raising.
        """
        # Keep probabilities strictly inside (0, 1) to avoid numerical
        # instability of the power at the boundaries.
        probs = torch.clamp(input, min=1e-7, max=1 - 1e-7)

        # `1 - target` is not defined for bool tensors in torch, so convert
        # boolean labels to the dtype of the predictions first.
        if target.dtype == torch.bool:
            target = target.to(probs.dtype)

        # Probability assigned to the labelled class.
        likelihood = target * probs + (1 - target) * (1 - probs)
        loss = (1 - likelihood ** self.q) / self.q

        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        return loss

class GeneralizedCrossEntropyLossSoftLabels(nn.Module):
    """Generalized cross-entropy (GCE) loss for probabilistic (soft) labels.

    The soft label ``t`` and the predicted probability ``p`` are first mixed
    into a single likelihood ``t * p + (1 - t) * (1 - p)`` and the loss is
    ``(1 - likelihood ** q) / q``. For hard labels (``t`` in ``{0, 1}``) this
    yields the same value as the hard-label binary GCE loss.

    Note that for ``t == 0.5`` the likelihood equals 0.5 regardless of ``p``,
    so such samples contribute a constant to the loss.

    Args:
        q: Exponent of the loss, must lie in ``(0, 1]``.
        reduction: ``"mean"`` (default) averages over all elements, ``"sum"``
            adds them up, ``"none"`` returns the element-wise loss tensor.

    Raises:
        ValueError: If ``q`` is outside ``(0, 1]`` or ``reduction`` is not one
            of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, q=0.7, reduction="mean"):
        super().__init__()
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would let q == 0 through and divide by zero.
        if not 0 < q <= 1:
            raise ValueError("q should be in (0, 1]")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction should be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.q = q
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the soft-label GCE loss.

        Args:
            input (torch.Tensor): Predicted probabilities (after sigmoid),
                typically of shape (batch_size,).
            target (torch.Tensor): Soft labels (probabilities from the noisy
                labeler) in ``[0, 1]``, same shape as ``input``. Boolean hard
                labels are accepted as well.

        Returns:
            torch.Tensor: The average GCE loss for the batch by default; the
            summed or element-wise loss when ``reduction`` is ``"sum"`` or
            ``"none"``.

        Note:
            ``input`` and ``target`` are combined element-wise, so mismatched
            shapes such as ``(N, 1)`` vs ``(N,)`` broadcast silently instead
            of raising.
        """
        # Keep probabilities strictly inside (0, 1) to avoid numerical
        # instability of the power at the boundaries.
        probs = torch.clamp(input, min=1e-7, max=1 - 1e-7)

        # `1 - target` is not defined for bool tensors in torch, so convert
        # boolean labels to the dtype of the predictions first.
        if target.dtype == torch.bool:
            target = target.to(probs.dtype)

        # Likelihood of the soft label under the predicted Bernoulli.
        likelihood = target * probs + (1 - target) * (1 - probs)
        loss = (1 - likelihood ** self.q) / self.q

        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        return loss
