import torch
import torch.nn.functional as F

def js_divergence(P, Q):
    """
    计算两个分布 P 和 Q 的 Jensen-Shannon 散度。
    :param P: Tensor, 第一个分布 (归一化为概率分布)
    :param Q: Tensor, 第二个分布 (归一化为概率分布)
    :return: Tensor, JS 散度
    """
    M = 0.5 * (P + Q)
    js_div = 0.5 * (F.kl_div(P.log(), M, reduction='batchmean') + F.kl_div(Q.log(), M, reduction='batchmean'))
    return js_div

def compute_mmd(X, Y, sigma=1.0):
    """
    计算两个嵌入集合的 MMD。
    :param X: Tensor, 第一个嵌入集合 (N x D)
    :param Y: Tensor, 第二个嵌入集合 (M x D)
    :param sigma: float, 高斯核的带宽参数
    :return: Tensor, MMD^2
    """
    def rbf_kernel(X, Y, sigma=1.0):
        """
        计算高斯核矩阵。
        :param X: Tensor, 数据集合1 (N x D)
        :param Y: Tensor, 数据集合2 (M x D)
        :param sigma: float, 高斯核的带宽参数
        :return: Tensor, 核矩阵 (N x M)
        """
        XX = torch.sum(X**2, dim=1, keepdim=True)
        YY = torch.sum(Y**2, dim=1, keepdim=True)
        XY = torch.mm(X, Y.T)
        pairwise_dist = XX - 2 * XY + YY.T
        kernel = torch.exp(-pairwise_dist / (2 * sigma**2))
        return kernel
    
    K_XX = rbf_kernel(X, X, sigma)
    K_YY = rbf_kernel(Y, Y, sigma)
    K_XY = rbf_kernel(X, Y, sigma)
    mmd = K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()
    return mmd

# 假设P,Q是两个高斯分布，计算KL散度
def compute_kl_loss(P, Q):
    """
    计算两个高斯分布的 KL 散度。
    :param P: Tensor, 真实高斯分布
    :param Q: Tensor, 预测高斯分布
    :return: Tensor, KL 散度
    """
    mean_p, std_p = P.mean(), P.std()
    mean_q, std_q = Q.mean(), Q.std()
    kl = torch.log(std_q / std_p) + (std_p**2 + (mean_p - mean_q)**2) / (2 * std_q**2) - 0.5
    return kl
