import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    def __init__(self, label_smoothing=0.0, **kwargs):
        super(BCEWithLogitsLoss, self).__init__(**kwargs)
        self.label_smoothing = label_smoothing

    def forward(self, input, target):
        if self.label_smoothing:
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        return super(BCEWithLogitsLoss, self).forward(input, target)


class SigmoidFocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, label_smoothing=0.0, reduction="mean"):
        """
        Args:
            alpha (float): Weighting factor in range (0,1) to balance positive vs negative examples.
            gamma (float): Focusing parameter to reduce the relative loss for well-classified examples.
            label_smoothing (float): Label smoothing factor to reduce the confidence of the true label.
            reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
                             'none': no reduction will be applied,
                             'mean': the sum of the output will be divided by the number of elements in the output,
                             'sum': the output will be summed.
        """
        super(SigmoidFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        """
        Args:
            input (Tensor): Predicted logits for each example.
            target (Tensor): Ground truth binary labels (0 or 1) for each example.
        """
        if self.label_smoothing:
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing

        p = torch.sigmoid(input)

        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        p_t = p * target + (1 - p) * (1 - target)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        # Check reduction option and return loss accordingly
        if self.reduction == "none":
            pass
        elif self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        else:
            raise ValueError(
                f"Invalid Value for arg 'reduction': '{self.reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
            )
        return loss
