import torch
import torch.nn as nn
from torch.nn import functional as F
import copy
from itertools import chain
from method.base import BaseMethod
import torch.distributed as dist
from method.netvlad import NetVLAD
import time
import pdb
from method.clip import clip
import logging

class AlignLoss(nn.Module):
    """Alignment loss: summed MSE between L2-normalised query/key features.

    Args:
        t_q: Stored as ``self.t_q``; not read by the loss computation.
        t_k: Stored as ``self.t_k``; not read by the loss computation.
    """

    def __init__(self, t_q=1, t_k=1):
        super().__init__()

        self.t_q = t_q
        self.t_k = t_k

        self.loss_fn = nn.MSELoss()

    def self_dist(self, q, k):
        """Return the MSE between ``q`` and ``k`` after L2-normalising the last dim."""
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        return self.loss_fn(q, k)

    def forward(self, q_feats, k_feats):
        """Sum ``self_dist`` over paired query/key features.

        Args:
            q_feats: Sequence (list, tuple or tensor iterated along dim 0) of
                query features.
            k_feats: Matching sequence of key features; pairs are formed with
                ``zip``, so extra items on either side are ignored.

        Returns:
            A 0-dim tensor holding the summed loss. For empty ``q_feats`` a
            zero scalar is returned instead of raising ``IndexError``.
        """
        if len(q_feats) == 0:
            # Nothing to align: the sum over zero pairs is zero.
            if isinstance(q_feats, torch.Tensor):
                return q_feats.new_zeros(())
            return torch.zeros(())

        # Accumulator takes the dtype and device of the first query feature.
        loss = torch.tensor(0).to(q_feats[0])
        for q_feat, k_feat in zip(q_feats, k_feats):
            loss += self.self_dist(q_feat, k_feat)

        return loss

# Part-level contrastive loss
def contrastive_loss(q, k, temperature=0.07):
    """Part-level contrastive loss between two (N, K, D) tensors.

    Every (sample, part) vector of ``q`` is matched against the vector at the
    same position in ``k``; every other position in ``k`` acts as a negative.

    Args:
        q: Tensor of shape (N, K, D).
        k: Tensor that can be reshaped to (N * K, D).
        temperature: Softmax temperature dividing the logits.

    Returns:
        Scalar cross-entropy loss.
    """
    N, K, D = q.shape
    n_items = N * K

    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors,
    # are accepted instead of raising.
    # Normalising first and using one matmul gives the same cosine similarities
    # as broadcasting F.cosine_similarity, without materialising an
    # (N*K, N*K, D) intermediate. eps matches cosine_similarity's default.
    q_flat = F.normalize(q.reshape(n_items, D), p=2, dim=-1, eps=1e-8)
    k_flat = F.normalize(k.reshape(n_items, D), p=2, dim=-1, eps=1e-8)
    similarity_matrix = q_flat @ k_flat.T  # (N*K, N*K)

    # Positives sit on the diagonal.
    sim_pos = similarity_matrix.diagonal()
    # Negatives: same matrix with the diagonal pushed to a very negative value
    # so the positive pair is not counted twice in the softmax.
    sim_neg = similarity_matrix - torch.eye(n_items, device=q.device) * 1e9

    # Column 0 is the positive, followed by the masked similarities:
    # logits has shape (N*K, N*K + 1).
    logits = torch.cat([sim_pos.unsqueeze(1), sim_neg], dim=1)

    # The positive is always at column 0.
    targets = torch.zeros(n_items, dtype=torch.long, device=q.device)
    return F.cross_entropy(logits / temperature, targets)


# Sample-level contrastive loss
def contrastive_loss_sample(z1, z2, temperature=0.07):
    """Sample-level contrastive loss, split into its positive and negative parts.

    Args:
        z1 (torch.Tensor): Tensor of shape (N, D).
        z2 (torch.Tensor): Tensor of shape (N, D).
        temperature (float): Temperature dividing the similarities (default 0.07).

    Returns:
        total_loss (torch.Tensor): ``pos_loss + neg_loss``.
        pos_loss (torch.Tensor): Negative mean of the diagonal (matched-pair)
            scaled similarities.
        neg_loss (torch.Tensor): Mean row-wise log-sum-exp of the scaled
            similarities (the log-denominator, positives included).
    """
    # Unit-length rows make the dot products below cosine similarities.
    anchors = F.normalize(z1, p=2, dim=1)
    candidates = F.normalize(z2, p=2, dim=1)

    # (N, N) temperature-scaled similarity matrix.
    scaled_sim = torch.mm(anchors, candidates.T) / temperature

    # Matched pairs live on the diagonal.
    pos_loss = -torch.diagonal(scaled_sim).mean()
    # Per-row log of the softmax denominator, averaged over rows.
    neg_loss = torch.logsumexp(scaled_sim, dim=1).mean()

    return pos_loss + neg_loss, pos_loss, neg_loss
    
class CLUE(BaseMethod):
    # CLUE using the momentum network, better performance

    def __init__(self, args):
        """Build the part heads, text-whitening state and momentum branches.

        Args:
            args: Run configuration. Read here: ``is_parts``, ``part_method``,
                ``n_parts``, ``temperature`` and, optionally,
                ``txt_whiten_warmup_steps`` (defaults to 1000).

        Raises:
            ValueError: If ``is_parts == "netvlad"`` and ``part_method`` is
                neither ``"_global_part"`` nor ``"_part"``.

        Note:
            ``self.with_texts``, ``self.is_recon``, ``self.encoder`` and
            ``self.projector`` are read but not set here; presumably they are
            created by ``BaseMethod.__init__`` -- verify against the base class.
        """
        super().__init__(args)

        # ============== part clustering ============
        self.is_parts = args.is_parts
        self.part_method = args.part_method
        self.n_parts = args.n_parts
        if self.is_parts == "proto":
            temp_vector = generate_orthonormal_vectors(self.n_parts, 2048)
            self.part_proto = nn.Parameter(temp_vector, requires_grad=True)
        elif self.is_parts == "netvlad":
            self.net_vlad = NetVLAD(num_clusters=args.n_parts, dim=2048, alpha=1.0)

            # Input width of the fusion heads: flattened part descriptors,
            # plus the 2048-d global feature in "_global_part" mode.
            part_dim = self.n_parts * 2048
            if self.part_method == "_global_part":
                fused_dim = part_dim + 2048
            elif self.part_method == "_part":
                fused_dim = part_dim
            else:
                # Fail here with a clear message; otherwise fc_q is never
                # created and the momentum copy below dies with an opaque
                # AttributeError.
                raise ValueError(
                    f"Unsupported part_method {self.part_method!r} for is_parts='netvlad'; "
                    "expected '_global_part' or '_part'."
                )

            # LayerNorm is used here instead of BatchNorm1d. fc_q is left
            # trainable (not frozen) so a dedicated optimizer can update it.
            self.fc_q = nn.Sequential(
                nn.Linear(fused_dim, 2048),
                nn.LayerNorm(2048),
                nn.ReLU(),
                nn.Linear(2048, 2048),
            )

            # forward() feeds fc_text in both the "sample_level" and the
            # "sample_level_tmp" text modes, so build it for either.
            if self.with_texts in ("sample_level", "sample_level_tmp"):
                self.fc_text = nn.Sequential(
                    nn.Linear(fused_dim, 2048),
                    nn.BatchNorm1d(2048),
                    nn.ReLU(),
                    nn.Linear(2048, 512),
                )

            # Re-initialise fc_q (see _init_fc_q).
            self._init_fc_q()
        self.sd_loss = AlignLoss()
        if self.is_recon == "recon":
            self.recon_projector = nn.Sequential(
                nn.Linear(512, 2048),
                nn.BatchNorm1d(2048),
                nn.ReLU(),
                nn.Linear(2048, 512),
                nn.LayerNorm(512),  # LayerNorm as the final layer
            )

        # --- Text whitening (reduces anisotropy) ---
        # On a single GPU nn.BatchNorm1d could be used instead.
        self.txt_whiten = nn.SyncBatchNorm(512, affine=False)
        self.txt_whiten_eps = 1e-5
        self.txt_whiten_momentum = 0.9
        # Number of warm-up steps; overridable through args.
        self.txt_whiten_warmup_steps = getattr(args, "txt_whiten_warmup_steps", 1000)
        # Step counters are buffers so they are saved/restored with the model.
        self.register_buffer("txt_whiten_step", torch.tensor(0, dtype=torch.long))
        # Debug step counter for periodic monitoring
        self.register_buffer("debug_step", torch.tensor(0, dtype=torch.long))
        # Weights of the light regularisers (kept very small).
        self.w_txt_var = 1e-3     # text variance regulariser
        self.w_txt_xcorr = 1e-2   # light cross-modal correlation alignment (Barlow style)

        # ============= momentum encoder ================
        self.momentum_encoder = copy.deepcopy(self.encoder)
        # momentum_projector is kept for compatibility; forward() does not use it.
        self.momentum_projector = copy.deepcopy(self.projector)

        # ============= momentum netvlad (only when using netvlad) ================
        if self.is_parts == "netvlad":
            self.momentum_net_vlad = copy.deepcopy(self.net_vlad)
            for param in self.momentum_net_vlad.parameters():
                param.requires_grad = False

            # ============= momentum fc_q for key branch ================
            self.momentum_fc_q = copy.deepcopy(self.fc_q)
            for param in self.momentum_fc_q.parameters():
                param.requires_grad = False
            # Note: there's no momentum_fc_k; key branch uses momentum_fc_q.
        else:
            self.momentum_net_vlad = None
            self.momentum_fc_q = None

        # Momentum copies are never optimised directly.
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
        for param in self.momentum_projector.parameters():
            param.requires_grad = False

        self.temp = args.temperature

        # # ================ feature attention =======================
        # self.attn = args.attn
        # if self.attn is True:
        #     self.attention = nn.Sequential(
        #         nn.Linear(2048, 2048),
        #         nn.ReLU(),
        #         nn.Linear(2048, 1),
        #         nn.Sigmoid()
        #     )

    def _init_fc_q(self):
        """
        Initialize fc_q for better local-feature extraction stability.
        - First Linear: Kaiming normal (for ReLU)
        - LayerNorm: gamma=1, beta=0 (standard initialization)
        - Last Linear: Orthogonal with small gain (0.01) to avoid dominating at start
        - In _global_part mode, bias initial contribution toward local features by
          down-weighting global block columns in the first Linear.
        """
        if not hasattr(self, 'fc_q'):
            return
        # Extract modules
        linear_layers = [m for m in self.fc_q if isinstance(m, nn.Linear)]
        ln_layers = [m for m in self.fc_q if isinstance(m, nn.LayerNorm)]

        if len(linear_layers) >= 1:
            first_linear = linear_layers[0]
            nn.init.kaiming_normal_(first_linear.weight, nonlinearity='relu')
            if first_linear.bias is not None:
                nn.init.zeros_(first_linear.bias)
            # Rebalance global/local contributions for global_part
            if getattr(self, 'part_method', None) == "_global_part":
                in_features = first_linear.weight.shape[1]
                global_dim = 2048
                # Input order is [global(2048), local(n_parts*2048)]
                if in_features >= global_dim:
                    with torch.no_grad():
                        first_linear.weight[:, :global_dim] *= 0.5  # down-weight global block
                        # keep local block unchanged

        if len(ln_layers) >= 1:
            ln = ln_layers[0]
            if hasattr(ln, 'weight') and ln.weight is not None:
                ln.weight.data.fill_(1.0)
            if hasattr(ln, 'bias') and ln.bias is not None:
                ln.bias.data.zero_()

        if len(linear_layers) >= 2:
            last_linear = linear_layers[1]
            nn.init.orthogonal_(last_linear.weight, gain=0.01)
            if last_linear.bias is not None:
                nn.init.zeros_(last_linear.bias)


    @torch.no_grad()
    def update_momentum_params(self, m):
        """Exponential-moving-average update of every momentum module.

        Each momentum parameter becomes ``m * momentum + (1 - m) * online``.
        The encoder and projector are always updated (the projector only for
        compatibility); NetVLAD and ``fc_q`` are updated only when their
        momentum copies exist, i.e. are not ``None``.

        Args:
            m: Momentum coefficient applied to the existing momentum weights.
        """
        def _ema(online, target):
            # Blend parameters pairwise, in registration order.
            for src, dst in zip(online.parameters(), target.parameters()):
                dst.data = dst.data * m + src.data * (1. - m)

        _ema(self.encoder, self.momentum_encoder)
        _ema(self.projector, self.momentum_projector)

        # The online modules below are only touched when a momentum copy exists.
        if self.momentum_net_vlad is not None:
            _ema(self.net_vlad, self.momentum_net_vlad)
        if self.momentum_fc_q is not None:
            # There is no momentum_fc_k; the key branch uses momentum_fc_q.
            _ema(self.fc_q, self.momentum_fc_q)


    def forward(self, samples):
        """Compute the global, part-level and text losses for one batch.

        Args:
            samples: Indexable batch. ``samples[0]`` is the list of augmented
                image views (tensors) and ``samples[1]`` the captions handed
                to ``clip.tokenize``. Further entries are ignored.

        Returns:
            Tuple ``(global_loss, sd_loss, text_loss)``:
            - ``global_loss``: ``self.cross_entropy`` between the online
              embedding of view ``q`` and the gathered momentum embedding of
              view ``v`` (all pairs with ``q != v``), averaged over the pairs,
              with the Sinkhorn assignment as target.
            - ``sd_loss``: part-level loss selected by ``self.is_parts``
              (a zero tensor when it is ``None``).
            - ``text_loss``: image-text loss selected by ``self.with_texts``
              and ``self.is_recon`` (a zero tensor when text is disabled).

        Raises:
            ValueError: For an unsupported ``is_parts``, ``with_texts`` or
                ``is_recon`` value.

        Note:
            ``ForwardWrapper``, ``sinkhorn_knopp``, ``cross_entropy``,
            ``_get_part_feature``, ``clip_model`` and ``momentum`` are not
            defined in this class; presumably they come from ``BaseMethod``.
            The method calls ``concat_all_gather`` and therefore needs an
            initialised ``torch.distributed`` process group.
        """
        # Split the batch into image views and captions.
        views, texts = samples[0], samples[1]
        views = [x.cuda(non_blocking=True) for x in views]
        text_tokens = None
        if self.with_texts is not None:
            # Tokenise only when a text branch will consume the tokens.
            text_tokens = torch.cat([clip.tokenize(t, truncate=True) for t in texts])
            text_tokens = text_tokens.cuda(non_blocking=True)

        # Online branch over all views; fp holds the feature maps.
        h, emb, fp = self.ForwardWrapper(views, self.encoder, self.projector)
        h = h[0]
        with torch.no_grad():
            self.update_momentum_params(self.momentum)
            # The momentum branch skips the projector.
            h_m, emb_m, fp_m = self.ForwardWrapper(views[:2], self.momentum_encoder, use_projector=False)
            h_m_stand = h_m[1]

        emb_m = [concat_all_gather(x) for x in emb_m]
        # The global contrast uses view 0 to build H; only that entry is
        # gathered. TODO: earlier results were obtained with the standard-aug view.
        h_m = concat_all_gather(h_m[0])

        # --- Monitor: global features before assignment ---
        rank = dist.get_rank() if dist.is_initialized() else 0
        if not torch.isfinite(h).all() and rank == 0:
            logging.error(f"Monitor h non-finite: min={h.nan_to_num().min().item():.6f} max={h.nan_to_num().max().item():.6f} norm_min={h.norm(dim=1).min().item():.6f} norm_max={h.norm(dim=1).max().item():.6f}")
        if not torch.isfinite(h_m).all() and rank == 0:
            logging.error(f"Monitor h_m non-finite: min={h_m.nan_to_num().min().item():.6f} max={h_m.nan_to_num().max().item():.6f} norm_min={h_m.norm(dim=1).min().item():.6f} norm_max={h_m.norm(dim=1).max().item():.6f}")

        # Numerical stability: normalise the features before the assignment.
        h_norm = F.normalize(h, dim=1)
        h_m_norm = F.normalize(h_m, dim=1)
        assign = self.sinkhorn_knopp(h_norm @ h_m_norm.T)
        if (not torch.isfinite(assign).all()) and rank == 0:
            asn = assign.nan_to_num()
            logging.error(f"Monitor assign non-finite: min={asn.min().item():.6f} max={asn.max().item():.6f} row_sum_min={asn.sum(1).min().item():.6f} row_sum_max={asn.sum(1).max().item():.6f}")

        # Periodic monitoring: a failure to bump the counter must not stop training.
        try:
            self.debug_step += 1
            step = int(self.debug_step.item())
        except RuntimeError:
            step = 0
        log_now = rank == 0 and (step % 50 == 0)

        def _log_debug(build_message):
            """Emit one periodic rank-0 debug line without ever raising."""
            if not log_now:
                return
            try:
                logging.info(build_message())
            except Exception:
                # Monitoring is best-effort; record the failure instead of hiding it.
                logging.debug("debug monitor failed", exc_info=True)

        _log_debug(lambda: (
            f"[dbg {step}] h_norm ok, h_m_norm ok, assign row_sum range: "
            f"[{assign.sum(1).min().item():.6f}, {assign.sum(1).max().item():.6f}]"
        ))

        # Fall back to a uniform assignment on non-finite values so the
        # cross-entropy terms below do not blow up.
        if not torch.isfinite(assign).all():
            assign = torch.ones_like(assign)
            assign = assign / assign.sum(dim=1, keepdim=True)

        # ================ part clustering ==================
        if self.is_parts == 'proto':
            # ---- query branch ----
            q_feat = self.layer4_features
            q_feat_part = self._get_part_feature(q_feat, self.part_proto)
            q_feat_global_part = torch.cat([h, q_feat_part], dim=1)

            # ---- key branch ----
            with torch.no_grad():
                k_feat = self.layer4_features_momentum_standard.detach()
                # Un-gathered momentum feature of view 1 (was an undefined name).
                k_feat_global = h_m_stand
                k_feat_part = self._get_part_feature(k_feat, self.part_proto)
                k_feat_global_part = torch.cat([k_feat_global, k_feat_part], dim=1)

            sd_loss = self.sd_loss(q_feat_global_part, k_feat_global_part)

        elif self.is_parts == "netvlad":
            # ---- query branch ----
            q_feat = fp[0]
            q_feat_part = self.net_vlad(q_feat)
            q_feat_part = q_feat_part.view(q_feat_part.size(0), -1)  # flatten
            _log_debug(lambda: (
                f"[dbg {step}] q_feat_part norm range: "
                f"[{q_feat_part.norm(dim=1).min().item():.6f}, {q_feat_part.norm(dim=1).max().item():.6f}]"
            ))
            if self.part_method == "_global_part":
                q_feat_global_part = torch.cat([h, q_feat_part], dim=1)
                q_feat_global_part = F.normalize(q_feat_global_part, p=2, dim=1)
                q_feat_proj = self.fc_q(q_feat_global_part)
            elif self.part_method == "_part":
                q_feat_proj = self.fc_q(q_feat_part)
            q_feat_proj = F.normalize(q_feat_proj, dim=-1)

            # ---- key branch: momentum NetVLAD + momentum fc_q, no gradients ----
            with torch.no_grad():
                # Gather momentum feature maps from all ranks as negatives.
                k_feat = concat_all_gather(fp_m[1])
                k_feat_part = self.momentum_net_vlad(k_feat)
                k_feat_part = k_feat_part.view(k_feat_part.size(0), -1)  # flatten
                if self.part_method == "_global_part":
                    k_feat_global = F.normalize(h_m_stand, p=2, dim=1)
                    k_feat_global = concat_all_gather(k_feat_global)
                    # Monitor momentum local features
                    if (not torch.isfinite(k_feat_part).all()) and rank == 0:
                        kfp = k_feat_part.nan_to_num()
                        logging.error(f"Monitor k_feat_part non-finite: min={kfp.min().item():.6f} max={kfp.max().item():.6f} norm_min={kfp.norm(dim=1).min().item():.6f} norm_max={kfp.norm(dim=1).max().item():.6f}")
                    _log_debug(lambda: (
                        f"[dbg {step}] k_feat_part norm range: "
                        f"[{k_feat_part.norm(dim=1).min().item():.6f}, {k_feat_part.norm(dim=1).max().item():.6f}]"
                    ))
                    # NOTE(review): the query branch normalises the concatenation
                    # as a whole, while here global and part features are
                    # normalised separately before concatenation -- confirm the
                    # asymmetry is intended.
                    k_feat_part_norm = F.normalize(k_feat_part, p=2, dim=1)
                    k_feat_global_part = torch.cat([k_feat_global, k_feat_part_norm], dim=1)
                    k_feat_proj = self.momentum_fc_q(k_feat_global_part)
                elif self.part_method == "_part":
                    k_feat_proj = self.momentum_fc_q(k_feat_part)

                if (not torch.isfinite(k_feat_proj).all()) and rank == 0:
                    kfp = k_feat_proj.nan_to_num()
                    logging.error(f"Monitor k_feat_proj non-finite: min={kfp.min().item():.6f} max={kfp.max().item():.6f} norm_min={kfp.norm(dim=1).min().item():.6f} norm_max={kfp.norm(dim=1).max().item():.6f}")
                k_feat_proj = F.normalize(k_feat_proj, dim=-1)
                _log_debug(lambda: (
                    f"[dbg {step}] q_proj/k_proj norm range: "
                    f"q:[{q_feat_proj.norm(dim=1).min().item():.6f},{q_feat_proj.norm(dim=1).max().item():.6f}] "
                    f"k:[{k_feat_proj.norm(dim=1).min().item():.6f},{k_feat_proj.norm(dim=1).max().item():.6f}]"
                ))

            part_sim = (q_feat_proj @ k_feat_proj.T) / (self.temp + 1e-12)
            if (not torch.isfinite(part_sim).all()) and rank == 0:
                psm = part_sim.nan_to_num()
                logging.error(f"Monitor part_sim non-finite: min={psm.min().item():.6f} max={psm.max().item():.6f}")
            else:
                _log_debug(lambda: (
                    f"[dbg {step}] part_sim stats: mean={part_sim.mean().item():.6f} "
                    f"std={part_sim.std().item():.6f} min={part_sim.min().item():.6f} max={part_sim.max().item():.6f}"
                ))
            sd_loss = self.cross_entropy(part_sim, assign)

        elif self.is_parts is None:
            sd_loss = torch.tensor(0).cuda()
        else:
            # A bare assert on a non-empty string never fires; raise instead
            # so sd_loss cannot be left undefined.
            raise ValueError("The part clustering is not supported!!!")

        # ================ text loss ==================
        # An image and its own caption form the positive pair; all other
        # captions are negatives.
        if self.with_texts == "sample_level":
            # ---------- image side ----------
            if self.part_method == "_global_part":
                q_feat_text = self.fc_text(q_feat_global_part)
            elif self.part_method == "_part":
                q_feat_text = self.fc_text(q_feat_part)
            q_feat_text = F.normalize(q_feat_text, dim=-1)

            # ---------- text side: encode, then gather across ranks ----------
            with torch.no_grad():
                t_local = self.clip_model.encode_text(text_tokens)
            t_all = concat_all_gather(t_local)

            # ---------- warm-up statistics, then fixed whitening ----------
            # During warm-up the norm layer runs in train mode and updates its
            # running statistics; afterwards it is put in eval mode and acts
            # as a fixed standardisation.
            if self.training and (self.txt_whiten_step.item() < self.txt_whiten_warmup_steps):
                self.txt_whiten.train()
                t_all = self.txt_whiten(t_all)
                self.txt_whiten_step += 1
            else:
                self.txt_whiten.eval()
                with torch.no_grad():
                    t_all = self.txt_whiten(t_all)

            # Project onto the unit sphere and stop gradients on the text side.
            t_all = F.normalize(t_all, dim=-1)
            t_all = t_all.detach()

            # Slice out this rank's rows so positives stay aligned with the images.
            N_local = t_local.size(0)
            start = dist.get_rank() * N_local
            end = start + N_local
            t_local_whiten = t_all[start:end]

            # ---------- text loss: contrastive or pure reconstruction ----------
            if self.is_recon == "contrastive":
                # Rows: local image vectors. Columns: gathered text vectors.
                logits = (q_feat_text @ t_all.T) / self.temp
                targets = torch.arange(N_local, device=logits.device) + start
                text_loss = F.cross_entropy(logits, targets)

            elif self.is_recon == "recon":
                # Projection -> normalisation -> SmoothL1 against the fixed text target.
                q_proj = self.recon_projector(q_feat_text)
                q_proj = F.normalize(q_proj, dim=1)
                text_loss = F.smooth_l1_loss(q_proj, t_local_whiten, beta=0.5)

            else:
                raise ValueError("The text fusion mode is not supported!")

        elif self.with_texts == "sample_level_tmp":
            # ---------- image side ----------
            if self.part_method == "_global_part":
                q_feat_text = self.fc_text(q_feat_global_part)
            elif self.part_method == "_part":
                q_feat_text = self.fc_text(q_feat_part)
            q_feat_text = F.normalize(q_feat_text, dim=-1)

            # ---------- text side: encode -> gather -> whiten -> L2 -> detach ----------
            with torch.no_grad():
                t_local = self.clip_model.encode_text(text_tokens)

            # Gather first so the whitening statistics see every rank's texts.
            t_all = concat_all_gather(t_local)
            t_all = self.txt_whiten(t_all)
            t_all = F.normalize(t_all, dim=-1)
            t_all = t_all.detach()

            # Slice out this rank's rows so positives stay aligned with the images.
            N_local = t_local.size(0)
            start = dist.get_rank() * N_local
            end = start + N_local
            t_local_whiten = t_all[start:end]

            # ---------- light regularisers ----------
            def _variance_loss(x: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
                # Penalise per-dimension batch standard deviations below 1.
                std = torch.sqrt(x.var(dim=0, unbiased=False) + eps)
                return torch.mean(F.relu(1.0 - std) ** 2)

            def _xcorr_loss_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                # Barlow-Twins-style penalty: squared distance between the
                # cross-covariance of x and y and the identity matrix.
                B, D = x.size()
                x_zm = x - x.mean(dim=0, keepdim=True)
                y_zm = y - y.mean(dim=0, keepdim=True)
                C = (x_zm.T @ y_zm) / (B - 1 + 1e-9)
                I = torch.eye(D, device=C.device)
                return ((C - I) ** 2).sum()

            # ---------- text loss: contrastive or reconstruction ----------
            if self.is_recon == "contrastive":
                # Rows: local image vectors. Columns: gathered text vectors.
                logits = (q_feat_text @ t_all.T) / self.temp
                targets = torch.arange(N_local, device=logits.device) + start
                text_loss = F.cross_entropy(logits, targets)

                txt_var_loss = _variance_loss(t_local_whiten)
                img_txt_xcorr = _xcorr_loss_fn(q_feat_text, t_local_whiten)
                text_loss = text_loss + self.w_txt_var * txt_var_loss + self.w_txt_xcorr * img_txt_xcorr

            elif self.is_recon == "recon":
                # Projection head -> normalisation -> (1 - cos) + norm regulariser.
                q_proj = self.recon_projector(q_feat_text)
                q_proj = F.normalize(q_proj, dim=1)

                recon_loss = (1.0 - F.cosine_similarity(q_proj, t_local_whiten, dim=1)).mean()
                reg_loss = ((q_proj.norm(p=2, dim=1) - 1.0) ** 2).mean()
                text_loss = recon_loss + 1.0 * reg_loss

                txt_var_loss = _variance_loss(t_local_whiten)
                img_txt_xcorr = _xcorr_loss_fn(q_proj, t_local_whiten)
                text_loss = text_loss + self.w_txt_var * txt_var_loss + self.w_txt_xcorr * img_txt_xcorr

            else:
                raise ValueError("The text fusion mode is not supported!")

        elif self.with_texts is None:
            text_loss = torch.tensor(0, device=h.device)
        else:
            raise ValueError("The text fusion is not supported!!!")

        # ================ global loss ==================
        # Soft-label alignment against `assign`. For hard one-hot targets,
        # pass generate_hard_assign(assign, dist.get_rank()) instead.
        total_loss = 0
        n_loss_terms = 0
        for q, emb_q in enumerate(emb):
            for v, emb_k in enumerate(emb_m):
                if v == q:
                    continue
                emb_sim = (emb_q @ emb_k.T) / (self.temp + 1e-12)
                total_loss += self.cross_entropy(emb_sim, assign)
                n_loss_terms += 1

        return total_loss / n_loss_terms, sd_loss, text_loss


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every rank and concatenate the copies along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, and the
    whole function runs under ``torch.no_grad()``.
    """
    world_size = dist.get_world_size()
    # One receive buffer per rank, shaped like the local tensor.
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.ones_like(tensor))
    dist.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)

def generate_orthonormal_vectors(n, dim):
    """Return ``n`` random orthonormal row vectors of length ``dim``.

    The vectors are the left singular vectors of a random Gaussian
    ``(dim, n)`` matrix.

    Args:
        n: Number of vectors to generate.
        dim: Dimensionality of each vector.

    Returns:
        Contiguous tensor of shape ``(n, dim)`` with orthonormal rows.

    Raises:
        ValueError: If ``n > dim``; at most ``dim`` mutually orthogonal
            vectors exist, and the reduced SVD would silently return fewer
            than ``n`` rows.
    """
    if n > dim:
        raise ValueError(
            f"Cannot build {n} orthonormal vectors in a {dim}-dimensional space."
        )
    A = torch.randn(dim, n)
    # torch.svd is deprecated; the reduced torch.linalg.svd yields the same U.
    U, _, _ = torch.linalg.svd(A, full_matrices=False)
    # Columns of U are orthonormal; transpose so each row is one vector.
    return U.T.contiguous()


@torch.no_grad()
def generate_hard_assign(tensor, device_id):
    """Build a one-hot assignment matrix of shape (N, world_size * N).

    The matrix is zero everywhere except the N-column block that starts at
    ``device_id * N``, which holds an N x N identity matrix.

    Args:
        tensor: Reference tensor; only its first dimension (N) and its device
            are read.
        device_id: Index of the column block that receives the identity
            (the caller's rank).

    Returns:
        Float tensor of shape (N, world_size * N) on ``tensor``'s device.
    """
    n_local = tensor.size(0)
    n_total = dist.get_world_size() * n_local

    # Start from all zeros, then write the identity into this rank's block.
    hard = torch.zeros(n_local, n_total, device=tensor.device)
    block_start = device_id * n_local
    hard[:, block_start:block_start + n_local] = torch.eye(n_local, device=tensor.device)
    return hard