import torch
import torch.nn as nn
import torch.nn.functional as F

class BSCE_GRA_Loss(nn.Module):
    """
    Implements the Uncertainty Weighted Gradients for Model Calibration (BSCE-GRA) loss.

    Each sample's Cross-Entropy loss is multiplied by a weight derived from the
    generalized Brier Score (gBS) of that sample. The weight is excluded from
    the autograd graph, so it only rescales the Cross-Entropy gradient.

    Args:
        gamma (float): The exponent hyperparameter for the gBS calculation.
                       Controls the sensitivity to prediction errors.
                       Must be non-negative. Defaults to 3.0.
        beta (float): The norm order hyperparameter for the gBS calculation.
                      beta=2 gives the L2 norm used by the standard Brier Score.
                      Must be positive. Defaults to 3.0.
        size_average (bool): If True, ``forward`` returns the mean of the
                             weighted per-sample losses; if False (default),
                             it returns their sum.

    Raises:
        ValueError: If ``gamma`` is negative or ``beta`` is not positive.
    """
    def __init__(self, gamma: float = 3.0, beta: float = 3.0, size_average: bool = False):
        super().__init__()
        if gamma < 0:
            raise ValueError("Gamma must be non-negative.")
        if beta <= 0:
            raise ValueError("Beta must be positive.")
        self.gamma = gamma
        self.beta = beta
        # Use 'none' reduction to get per-sample losses, which we will then weight.
        self.ce_loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.size_average = size_average

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculates the BSCE-GRA loss.

        Args:
            logits (torch.Tensor): The raw, unnormalized output from the model's final layer.
                                   Shape: (N, C) where N is batch size, C is number of classes.
            labels (torch.Tensor): The ground truth class indices.
                                   Shape: (N,) where each value is an integer from 0 to C-1.
                                   Any dtype is accepted; values are cast to ``torch.long``.

        Returns:
            torch.Tensor: A scalar tensor: the sum of the weighted per-sample
            losses, or their mean when ``size_average`` is True.
        """
        # Cast once and reuse: both one_hot and CrossEntropyLoss need int64
        # class indices, so e.g. int32 labels must not reach either of them raw.
        targets = labels.to(torch.long)
        num_classes = logits.shape[1]

        # Step 1: Predicted probability distribution, used only for the weight.
        # Starting from detached logits keeps the weight out of the autograd
        # graph entirely (same values as detaching the final weight, but no
        # graph is built for it).
        probs = F.softmax(logits.detach(), dim=1)

        # Step 2: One-hot encoded labels for the gBS calculation.
        y_hot = F.one_hot(targets, num_classes=num_classes).float()

        # Step 3: Sample-wise uncertainty metric (gBS): ||p - y||_beta ** gamma,
        # computed as (sum_c |p_c - y_c| ** beta) ** (gamma / beta).
        # When beta == gamma this reduces to sum_c |p_c - y_c| ** gamma.
        abs_err_pow = torch.abs(probs - y_hot) ** self.beta
        gbs_weight = torch.sum(abs_err_pow, dim=1) ** (self.gamma / self.beta)

        # Step 4: Per-sample Cross-Entropy loss (this is the only term that
        # carries gradient back to the logits).
        ce_loss_per_sample = self.ce_loss_fn(logits, targets)

        # Step 5: Scale each sample's CE loss by its gradient-free weight, so
        # the weight acts as a pure scaling factor on the CE gradient.
        weighted_loss = gbs_weight * ce_loss_per_sample

        # Step 6: Reduce over the batch: mean if size_average, otherwise sum.
        if self.size_average:
            return weighted_loss.mean()
        return weighted_loss.sum()
        

