"""
@Description :   FocalLoss 的实现
@Author      :   tqychy 
@Time        :   2025/01/20 14:24:10
"""
import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, alpha=0.55, gamma=8, size_average=True):
        """
        focal_loss, -α(1-yi)**γ *ce_loss(xi,yi)

        """
        super(FocalLoss, self).__init__()
        self.size_average = size_average
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, gt_mask, pad_mask_with_gt):
        """
        :param preds: predicted similarity matrix [bs, N, N]
        :param gt_mask: positive labels [N, N]
        :param pad_mask_with_gt: positive and padded labels [N, N]
        :return: positive and negative loss, positive loss
        """

        loss_p = - self.alpha * (1 - preds[gt_mask]) ** self.gamma * \
            torch.log(preds[gt_mask] + torch.tensor([1e-9]).cuda())

        loss_n = - (1 - self.alpha) * preds[~pad_mask_with_gt] ** self.gamma * torch.log(
            1 - preds[~pad_mask_with_gt] + torch.tensor([1e-9]).cuda())

        if self.size_average:
            loss_np = torch.cat((loss_p, loss_n), 0).mean()
            loss_p = loss_p.mean()
        else:
            loss_np = torch.cat((loss_p, loss_n), 0).sum()
            loss_p = loss_p.sum()

        return 400 * loss_np, loss_p
    
class FocalLoss2(nn.Module):
    def __init__(self, alpha=0.55, gamma=8, reduction='mean'):
        """
        Args:
            alpha (float): 平衡正负样本的权重，默认 0.25（适用于正样本稀疏的场景）。
            gamma (float): 调整难易样本权重的指数因子，默认 2。
            reduction (str): 输出的聚合方式（'mean', 'sum', 'none'）。
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, probs, targets):
        """
        Args:
            probs (Tensor): 模型输出的归一化概率值，形状为 [N, *]。
            targets (Tensor): 真实标签的 0-1 矩阵，形状与 probs 相同。
        Returns:
            loss (Tensor): 计算得到的 Focal Loss。
        """
        # 确保概率值在 [0, 1] 范围内
        probs = torch.clamp(probs, min=1e-12, max=1-1e-12)  # 避免 log(0) 或 log(1) 导致的数值问题
        
        # 计算二元交叉熵的 log 概率
        ce_loss = - (targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))
        
        # 计算 p_t：对于正样本取 p，对于负样本取 1-p
        p_t = probs * targets + (1 - probs) * (1 - targets)
        
        # 计算调制因子 (1 - p_t)^gamma
        modulating_factor = (1 - p_t).pow(self.gamma)
        
        # 计算 alpha 权重：正样本用 alpha，负样本用 1-alpha
        alpha_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        
        # 组合得到 Focal Loss
        loss = alpha_weight * modulating_factor * ce_loss
        
        # 聚合损失
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return 400 * loss