import torch
import torch.nn.functional as F


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    Focal Loss 用于解决类别不平衡问题
    :param logits: 预测的logits (batch_size, num_classes)
    :param targets: 真实标签 (batch_size)
    :param alpha: 类别平衡因子
    :param gamma: 焦点调节因子
    :return: 计算后的 Focal Loss
    """
    # 计算交叉熵损失
    ce_loss = F.cross_entropy(logits, targets, reduction='none')

    # 获取预测类别的概率
    prob = torch.exp(-ce_loss)

    # 计算Focal Loss
    focal_loss = alpha * (1 - prob) ** gamma * ce_loss
    return focal_loss.mean()
