import torch
import torch.nn.functional as F


def contrastive_loss(logits, targets):
    """
    计算 CLIP 风格的对比损失（InfoNCE Loss）
    :param logits: 形状 (batch_size, num_texts)，包含图像-文本相似度
    :param targets: 形状 (batch_size)，每个图像对应一个文本的索引（可以是自定义的目标）
    :return: 计算得到的 InfoNCE 损失
    """
    temperature = 0.07

    # 计算 softmax 归一化
    logits = logits / temperature

    # 计算 InfoNCE Loss
    loss_img_to_text = F.cross_entropy(logits, targets)  # 图像 -> 文本

    # 修复文本 -> 图像的损失：直接使用 logits，而不是转置
    loss_text_to_img = F.cross_entropy(logits, targets)  # 文本 -> 图像（没有转置）

    loss = (loss_img_to_text + loss_text_to_img) / 2  # 取平均
    return loss


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    Focal Loss 用于解决类别不平衡问题
    :param logits: 预测的logits (batch_size, num_classes)
    :param targets: 真实标签 (batch_size)
    :param alpha: 类别平衡因子
    :param gamma: 焦点调节因子
    :return: 计算后的 Focal Loss
    """
    # 计算交叉熵损失
    ce_loss = F.cross_entropy(logits, targets, reduction='none')

    # 获取预测类别的概率
    prob = torch.exp(-ce_loss)

    # 计算Focal Loss
    focal_loss = alpha * (1 - prob) ** gamma * ce_loss
    return focal_loss.mean()


# 计算对称KL散度损失
def symmetric_kl_divergence_loss(output_global, output_local):
    """
    计算对称KL散度
    :param output_global: 全局分支输出 logits
    :param output_local: 局部分支输出 logits
    :return: 对称KL散度
    """
    # KL散度：P || Q
    log_output_global = F.log_softmax(output_global, dim=1)
    output_local_softmax = F.softmax(output_local, dim=1)
    kl_loss_forward = F.kl_div(log_output_global, output_local_softmax, reduction='batchmean')

    # KL散度：Q || P
    log_output_local = F.log_softmax(output_local, dim=1)
    output_global_softmax = F.softmax(output_global, dim=1)
    kl_loss_backward = F.kl_div(log_output_local, output_global_softmax, reduction='batchmean')

    # 计算对称KL散度
    symmetric_kl_loss = (kl_loss_forward + kl_loss_backward) / 2.0
    return symmetric_kl_loss


# 计算总损失函数
def compute_loss(output_global, output_local, output_sum, targets, alpha=0.25, gamma=2.0, alpha_focal_loss = 0.5, alpha_kl_loss = 0.5):
    """
    计算每个batch的总损失，包括对比损失、Focal Loss 和 KL散度
    :param output_global: 全局分支输出
    :param output_local: 局部分支输出
    :param output_sum: 融合分支输出
    :param targets: 真实标签
    :param conloss: 对比损失函数
    :param alpha: Focal Loss中的类别平衡因子
    :param gamma: Focal Loss中的焦点调节因子
    :return: 返回总损失、各个部分损失
    """
    # 对比损失
    global_loss = contrastive_loss(output_global, targets)  # 全局对比损失
    local_loss = contrastive_loss(output_local, targets)  # 局部对比损失

    # 分类损失，使用Focal Loss代替交叉熵损失
    class_loss = focal_loss(output_sum, targets, alpha=alpha, gamma=gamma)

    # 计算全局和局部之间的KL散度
    kl_loss = symmetric_kl_divergence_loss(output_global, output_local)

    # 总损失是对比损失 + 分类损失 + KL散度损失
    total_loss = global_loss + local_loss + alpha_focal_loss * class_loss + alpha_kl_loss * kl_loss
    return total_loss, global_loss, local_loss, class_loss, kl_loss
