import sys
import torch

import torch
import torch.nn.functional as F





def compute_joint(x_out, x_tf_out):
    """Build the symmetric, normalized joint matrix of two views' assignments.

    For every sample ``n`` the outer product ``x_out[n] ⊗ x_tf_out[n]`` is
    formed, and these products are summed over the batch. The ``(k, k)``
    result is then symmetrized and divided by its total so its entries sum
    to 1.

    Args:
        x_out: 2-D tensor ``(bn, k)`` for the first view.
        x_tf_out: 2-D tensor that must have the same ``(bn, k)`` shape.

    Returns:
        ``(k, k)`` tensor: symmetric, and summing to 1.
    """
    batch, num_cols = x_out.size()
    assert x_tf_out.size(0) == batch and x_tf_out.size(1) == num_cols

    # Broadcast to (bn, k, k) per-sample outer products, then reduce the batch axis.
    joint = (x_out[:, :, None] * x_tf_out[:, None, :]).sum(dim=0)
    # Average with the transpose so the matrix does not depend on view order.
    joint = (joint + joint.T) / 2.
    return joint / joint.sum()

def compute_feature_redundancy(x_out, x_tf_out, EPS=sys.float_info.epsilon):
    """Redundancy penalty between feature dimensions (mutual-information style).

    The two views are stacked along the batch axis and every row is
    L2-normalized. The Gram matrix ``G = H^T H / (2 * bn)`` of shape
    ``(k, k)`` is then formed, where ``bn`` is ``x_out``'s batch size.

    - Off-diagonal entries ``G_ab`` act as the "joint" term.
    - Products ``G_aa * G_bb`` of the matching diagonal entries act as the
      "product of marginals".

    The result is the mean of ``G_ab * log(G_ab / (G_aa * G_bb + EPS))``
    over all ``a != b``.

    Args:
        x_out: 2-D tensor ``(bn, k)`` for the first view.
        x_tf_out: 2-D tensor with ``k`` columns for the second view. Its
            batch size is not checked; the scale still uses ``2 * bn``.
        EPS: Floor applied to off-diagonal entries, and added to the
            denominator, so that ``log`` stays finite.

    Returns:
        0-dim tensor holding the mean redundancy term.
    """
    bn, k = x_out.size()
    stacked = torch.nn.functional.normalize(
        torch.cat([x_out, x_tf_out], dim=0), p=2, dim=1
    )
    gram = torch.mm(stacked.T, stacked) / (2 * bn)

    off_diag = ~torch.eye(k, dtype=torch.bool, device=x_out.device)
    diag_col = torch.diag(gram).view(k, 1)

    pairwise = gram[off_diag].view(k, k - 1)
    # Entries below EPS (zero or negative) are floored so log() stays finite.
    pairwise = torch.where(
        pairwise < EPS, torch.tensor([EPS], device=pairwise.device), pairwise
    )
    # Outer product of the diagonal, with its own diagonal removed the same way.
    marginal_prod = (diag_col @ diag_col.T)[off_diag].view(k, k - 1)

    return (pairwise * torch.log(pairwise / (marginal_prod + EPS))).mean()

def instance_contrastive_Loss(x_out, x_tf_out, lamb=10, lamb_redundancy=0.00001, EPS=sys.float_info.epsilon):
    """Cross-view consistency loss plus a weighted feature-redundancy penalty.

    The consistency term is
    ``-sum_ab P_ab * (log P_ab - lamb * log P_b - lamb * log P_a)``.
    Here ``P`` is the symmetric joint matrix from ``compute_joint``, and
    ``P_a`` / ``P_b`` are its row / column sums. With ``lamb == 1`` this
    equals the negative mutual information between the two views'
    assignments.

    Args:
        x_out: 2-D tensor ``(bn, k)`` for the first view.
        x_tf_out: 2-D tensor ``(bn, k)`` for the second view.
        lamb: Weight on the two marginal log terms.
        lamb_redundancy: Weight on ``compute_feature_redundancy``.
        EPS: Floor applied to the joint and marginals before taking logs.

    Returns:
        0-dim tensor: consistency loss + lamb_redundancy * redundancy loss.
    """
    _, k = x_out.size()
    joint = compute_joint(x_out, x_tf_out)
    assert (joint.size() == (k, k))

    # Marginals are taken from the joint *before* it is floored at EPS.
    row_marg = joint.sum(dim=1).view(k, 1).expand(k, k)
    col_marg = joint.sum(dim=0).view(1, k).expand(k, k)

    def _floor(t):
        # Replace values below EPS with EPS so log() stays finite.
        return torch.where(t < EPS, torch.tensor([EPS], device=t.device), t)

    joint = _floor(joint)
    col_marg = _floor(col_marg)
    row_marg = _floor(row_marg)

    log_terms = torch.log(joint) - lamb * torch.log(col_marg) - lamb * torch.log(row_marg)
    loss_consistency = (-joint * log_terms).sum()

    loss_redundancy = compute_feature_redundancy(x_out, x_tf_out, EPS)
    return loss_consistency + lamb_redundancy * loss_redundancy

def contrastive_loss(h_i, h_j, weight=None, temperature=1.0):
    """Compute a contrastive (InfoNCE-style) loss between two views.

    Row ``n`` of ``h_i`` and row ``n`` of ``h_j`` form the positive pair.
    Every other entry of row ``n`` in the similarity matrix
    ``S = h_i @ h_j.T / temperature`` is a negative. The positive pair is
    excluded from the denominator, so for each row

        loss_n = -log(exp(S_nn) / sum_{m != n} exp(S_nm))

    and the returned value is the mean of ``loss_n`` over the batch.

    Args:
        h_i: Tensor ``(batch_size, dim)``, representations of view i.
        h_j: Tensor ``(batch_size, dim)``, representations of view j.
        weight: Optional scalar, or broadcastable tensor, multiplied into the
            averaged loss.
        temperature: Similarities are divided by this value. The default of
            1.0 applies no scaling.

    Returns:
        Tensor: the averaged loss. It is 0-dim unless ``weight`` broadcasts to
        a larger shape.

    Note:
        Similarities are raw dot products. They equal cosine similarities only
        if the caller L2-normalizes the inputs.
    """
    batch_size = h_i.shape[0]
    # Dot-product similarity matrix (N, N), scaled by the temperature.
    similarity_matrix = torch.matmul(h_i, h_j.T) / temperature
    # Diagonal entries are the positive-pair similarities.
    positives = torch.diag(similarity_matrix)  # (N,)
    # Drop the positive pair from the denominator: -inf on the diagonal
    # contributes exp(-inf) = 0, like the original multiplicative mask.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=h_i.device)
    negatives = similarity_matrix.masked_fill(self_mask, float("-inf"))
    # -log(exp(pos) / sum(exp(neg))) == logsumexp(neg) - pos.
    # This avoids exp() overflowing to inf for large similarities.
    loss_partial = torch.logsumexp(negatives, dim=1) - positives  # (N,)
    loss = torch.sum(loss_partial) / batch_size  # mean loss
    if weight is not None:
        # Out-of-place multiply, so a tensor weight can broadcast.
        loss = loss * weight
    return loss
