import torch.nn.functional as F


# Symmetric KL-divergence loss between two sets of logits
def symmetric_kl_divergence_loss(output_global, output_local):
    """
    Return the symmetric KL divergence between two sets of logits.

    Each input is turned into a distribution with a softmax over ``dim=1``.
    Both directed divergences are evaluated with ``F.kl_div`` using
    ``reduction='batchmean'`` and the result is their arithmetic mean.

    :param output_global: logits produced by the global branch
    :param output_local: logits produced by the local branch
    :return: average of the two directed KL divergences
    """
    def _directed_kl(input_logits, target_logits):
        # F.kl_div expects log-probabilities as its input and plain
        # probabilities as its target; it evaluates KL(target || input).
        return F.kl_div(
            F.log_softmax(input_logits, dim=1),
            F.softmax(target_logits, dim=1),
            reduction='batchmean',
        )

    # KL(softmax(output_local) || softmax(output_global))
    local_against_global = _directed_kl(output_global, output_local)
    # KL(softmax(output_global) || softmax(output_local))
    global_against_local = _directed_kl(output_local, output_global)

    # Average the two directions so the loss does not depend on argument order
    return (local_against_global + global_against_local) / 2.0
