import torch
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import torch.nn.functional as F
import torch.distributed as dist

class CrossDomainPrototypeManager(nn.Module):
    """
    跨域共享原型管理模块  
      - 使用一组共享原型（每个类别有若干原型，不区分源域和目标域）  
      - 原型以不可学习变量形式存储（requires_grad=False）  
      - 初始阶段采用 KMeans 初始化；后续利用 GMM 拟合更新，
        得到各聚类的均值、协方差和权重，然后采用 EMA 更新原型  
      - 同时在 forward 中计算原型相关 loss，包括特征–原型对齐 loss、
        类内原型 loss 和类间原型 loss
    """
    def __init__(self, num_classes, prototype_nums, feature_dim, coff, gmm_config=None,
                 margin_intra=0.0, margin_inter=0.0, margin_cross=0.0, lambda_intra=1.0, lambda_inter=1.0, lambda_distin=1.0, lambda_cross=1.0):
        """
        Create the per-class prototype, variance and mixture-weight stores.

        Args:
            num_classes: number of classes; class ids run from 1 to num_classes.
            prototype_nums: dict mapping class id -> number of prototypes of that
                class, e.g. ``{1: 3, 2: 2}`` or ``{'1': 3, '2': 2}``. Keys are
                normalised to strings because ``nn.ParameterDict`` only accepts
                string keys and every other method looks classes up by string id.
            feature_dim: dimensionality of a feature / prototype vector.
            coff: EMA smoothing coefficient; ``update_GMM`` blends
                ``(1 - coff) * old + coff * new``.
            gmm_config: optional dict with the keys ``'covariance_type'``,
                ``'max_iter'`` and ``'reg_covar'`` read by ``gmm_update``.
                Defaults to diag / 100 / 1e-6.
            margin_intra: threshold handed to the intra-class prototype term.
            margin_inter: threshold handed to the inter-class prototype term.
            margin_cross: threshold handed to the cross-class feature term.
            lambda_intra: weight of the intra-class prototype term.
            lambda_inter: weight of the inter-class prototype term.
            lambda_distin: weight of the feature-prototype alignment term.
            lambda_cross: weight of the cross-class feature term.

        Raises:
            KeyError: if ``prototype_nums`` has no entry for one of the classes
                1..num_classes.
        """
        super().__init__()
        self.num_classes = num_classes
        # Accept int or str class ids; store them as str (see docstring).
        self.prototype_nums = {str(key): count for key, count in prototype_nums.items()}
        self.feature_dim = feature_dim
        self.coff = coff  # EMA smoothing coefficient
        self.gmm_config = gmm_config if gmm_config is not None else {
            'covariance_type': 'diag', 'max_iter': 100, 'reg_covar': 1e-6
        }
        # Loss hyper-parameters
        self.margin_intra = margin_intra
        self.margin_inter = margin_inter
        self.margin_cross = margin_cross
        self.lambda_intra = lambda_intra
        self.lambda_inter = lambda_inter
        self.lambda_distin = lambda_distin
        self.lambda_cross = lambda_cross

        # Per class: shared prototypes (means), diagonal variances and mixture
        # weights. They are Parameters so they follow .to()/state_dict(), but
        # they are frozen (requires_grad=False) and only updated by hand.
        self.prototypes = nn.ParameterDict()
        self.covariances = nn.ParameterDict()
        self.weights = nn.ParameterDict()
        class_keys = [str(cls_id) for cls_id in range(1, num_classes + 1)]
        for key in class_keys:
            num_proto = self.prototype_nums[key]
            self.prototypes[key] = nn.Parameter(torch.zeros(num_proto, feature_dim), requires_grad=False)
            self.covariances[key] = nn.Parameter(torch.ones(num_proto, feature_dim), requires_grad=False)
            self.weights[key] = nn.Parameter(torch.ones(num_proto) / num_proto, requires_grad=False)

        # Domain of the batch currently being processed (set by forward()).
        self.domain = None
        # Number of features assigned to each prototype, tracked separately for
        # the source and the target domain.
        self.feature_counts = {
            domain: {key: [0] * self.prototype_nums[key] for key in class_keys}
            for domain in ['source', 'target']
        }
        
        # self.rank_cons = True
    @torch.no_grad()
    def kmeans_initialize(self, features_dict):
        """
        对每个类别的跨域特征使用 KMeans 进行初始化  
        参数:
           features_dict: dict，格式 { cls: Tensor, ... }  
                          每个 Tensor 的形状为 (N, feature_dim)
        """
        for domain in ['source', 'target']:
            features_dict_cur = features_dict[domain]
            for cls, features in features_dict_cur.items():
                num_proto = self.prototype_nums[cls]
                features_np = features.detach().cpu().numpy()
                kmeans = KMeans(n_clusters=num_proto, init='k-means++', max_iter=300, n_init=10, random_state=0)
                kmeans.fit(features_np)
                # 使用 KMeans 得到的聚类中心作为初始均值，同时利用 kmeans.labels_ 更新参数
                self.update_GMM(features_np, kmeans.labels_, cls, init=True, domain=domain)

    @torch.no_grad()
    def update_GMM(self, features, labels, cls, init=False, domain=None):
        """
        根据 GMM 结果更新指定类别的原型（均值）、协方差和权重  
        参数:
           features: Tensor 或 numpy 数组，形状 (N, feature_dim)
           labels: numpy 数组，形状 (N,)  —— 聚类标签
           cls: str，类别 id
           init: bool，是否为初始化阶段（初始化时直接赋值，不采用 EMA）
        """
        if isinstance(features, np.ndarray):
            features = torch.tensor(features, dtype=self.prototypes[cls].dtype).to(self.prototypes[cls].device)
        num_components = self.prototype_nums[cls]
        device = features.device
        new_mu = torch.zeros(num_components, self.feature_dim, device=device)
        new_cov = torch.zeros(num_components, self.feature_dim, device=device)
        new_weight = torch.zeros(num_components, device=device)
        for comp in range(num_components):
            mask = (labels == comp)
            if mask.sum() == 0:
                new_mu[comp] = self.prototypes[cls][comp]
                new_cov[comp] = self.covariances[cls][comp]
                new_weight[comp] = self.weights[cls][comp]
            else:
                comp_features = features[mask]
                comp_mean = comp_features.mean(dim=0)
                if comp_features.size(0) > 1:
                    comp_features = comp_features - comp_mean
                    comp_var = torch.var(comp_features, dim=0, unbiased=True)
                    comp_var = torch.clamp(comp_var, min=1e-6)
                else:
                    comp_var = self.covariances[cls][comp]
                comp_weight = torch.tensor(mask.sum()).float() / features.size(0)
                new_mu[comp] = comp_mean
                new_cov[comp] = comp_var
                new_weight[comp] = comp_weight
            if init:
                self.feature_counts[domain][cls][comp] = mask.sum().item() if mask.sum() > 0 else 0
            else:
                self.feature_counts[self.domain][cls][comp] += mask.sum().item() if mask.sum() > 0 else 0
        if init:
            self.prototypes[cls].data.copy_(new_mu)
            self.covariances[cls].data.copy_(new_cov)
            new_weight = new_weight.cpu().numpy()
            new_weight = new_weight / np.sum(new_weight)
            new_weight = np.round(new_weight, decimals=3)
            new_weight[-1] = 1 - np.sum(new_weight[:-1])
            new_weight = torch.tensor(new_weight, device=device)
            self.weights[cls].data.copy_(new_weight)
        else:
            updated_mu = (1 - self.coff) * self.prototypes[cls] + self.coff * new_mu
            updated_cov = (1 - self.coff) * self.covariances[cls] + self.coff * new_cov
            updated_weight = (1 - self.coff) * self.weights[cls] + self.coff * new_weight
            self.prototypes[cls].data.copy_(updated_mu)
            self.covariances[cls].data.copy_(updated_cov)
            updated_weight = updated_weight.cpu().numpy()
            updated_weight = updated_weight / np.sum(updated_weight)
            updated_weight = np.round(updated_weight, decimals=3)
            updated_weight[-1] = 1 - np.sum(updated_weight[:-1])
            updated_weight = torch.tensor(updated_weight, device=device)
            self.weights[cls].data.copy_(updated_weight)

    @torch.no_grad()
    def gmm_update(self, features, cls, update=True, max_samples=3000):
        """
        完整 GMM 更新流程：  
          1. 将输入特征转换为 numpy 数组  
          2. 用当前原型（若非全 0）作为初始化拟合 GMM  
          3. 根据 GMM 预测的聚类标签更新原型  
        参数:
           features: Tensor，形状 (N, feature_dim)
           cls: str，类别 id
        返回:
           labels: numpy 数组，预测的聚类标签
        """
        if len(features) > max_samples:
            indices = np.random.choice(len(features), max_samples, replace=False)
            features = features[indices]

        covariance_type = self.gmm_config.get('covariance_type', 'diag')
        num_components = self.prototype_nums[cls]
        if isinstance(features, np.ndarray):
            features_np = features
        else:
            features_np = features.detach().cpu().numpy().astype(np.float64)
        gmm = GaussianMixture(
            n_components=num_components,
            covariance_type=covariance_type,
            max_iter=self.gmm_config.get('max_iter', 100),
            reg_covar=self.gmm_config.get('reg_covar', 1e-6),
            random_state=0
        )
        current_proto = self.prototypes[cls].detach().cpu().numpy().astype(np.float64)
        if not np.allclose(current_proto, 0):
            gmm.means_init = current_proto
        
        # weights_np = self.weights[cls].detach().cpu().numpy()
        # weights_np = weights_np.astype(np.float32)
        # weights_np = np.round(weights_np, decimals=3)
        # weights_np = weights_np / np.sum(weights_np)  # 明确归一化
        # weights_np[-1] = 1 - np.sum(weights_np[:-1])  # 确保权重和为 1
        # gmm.weights_init = weights_np
        # self.weights[cls].data.copy_(torch.tensor(weights_np, device=self.prototypes[cls].device))
        gmm.weights_init = self.weights[cls].detach().cpu().numpy()
        if covariance_type == 'full':
            precisions_init = np.array([np.diag(1.0 / (var + 1e-6))
                                        for var in self.covariances[cls].detach().cpu().numpy()])
            gmm.precisions_init = precisions_init
        else:
            precisions_init = 1.0 / (self.covariances[cls].detach().cpu().numpy() + 1e-6).astype(np.float64)
            gmm.precisions_init = precisions_init
        try:
            gmm.fit(features_np)
            labels = gmm.predict(features_np)
        except Exception as e:
            print(e)
            dist_matrix = torch.cdist(features, self.prototypes[cls], p=2)
            labels = torch.argmin(dist_matrix, dim=1).cpu().numpy()
        # 确保每个聚类至少有一个样本
        labels = self.ensure_non_empty_clusters(features_np, labels, num_components, cls).cpu().numpy().astype(np.int32)
        if update:
            self.update_GMM(features, labels, cls, init=False)
        return labels

    @torch.no_grad()
    def ensure_non_empty_clusters(self, data, labels, n_clusters, cls):
        """
        如果某个聚类为空，则重新分配样本保证每个聚类至少有一个样本
        参数:
           data: numpy 数组，形状 (N, feature_dim)
           labels: numpy 数组，形状 (N,)
           n_clusters: 聚类个数
           cls: 当前类别 id，用于取出对应原型
        返回:
           labels: 修正后的 labels（tensor形式）
        """
        data_tensor = torch.tensor(data).float().to(self.prototypes[cls].device)
        labels_tensor = torch.tensor(labels).long().to(self.prototypes[cls].device)
        prototypes_tensor = self.prototypes[cls]  # (n_clusters, feature_dim)
        dist_matrix = torch.cdist(prototypes_tensor, data_tensor, p=2)  # (n_clusters, N)
        for cluster in range(n_clusters):
            if (labels_tensor == cluster).sum() == 0:
                min_val, min_idx = torch.min(dist_matrix[cluster], dim=0)
                labels_tensor[min_idx] = cluster
        return labels_tensor

    def orthogonality_regularization(self, embeddings):
        similarity_matrix = torch.mm(embeddings, embeddings.T)  # (N, N)
        mask = torch.eye(similarity_matrix.size(0), device=embeddings.device).bool()
        off_diagonal = similarity_matrix.masked_fill(mask, 0)
        loss = off_diagonal.pow(2).mean()
        return loss

    def pairwise_dissimilarity_loss(self, embeddings, margin=1.0):
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.T)
        N = embeddings.size(0)
        mask = torch.eye(N, device=embeddings.device).bool()
        similarity_matrix = similarity_matrix.masked_fill(mask, 0)
        reg_loss = self.orthogonality_regularization(embeddings)
        loss = similarity_matrix.mean()
        if loss.item() < margin:
            use_loss = False
        else:
            use_loss = True
        return loss + reg_loss * 0.1, use_loss
        # 三元组替换

    def inter_pairwise_dissimilarity_loss(self, embeddings1, embeddings2, margin=1.0):
        # 对两个输入分别归一化
        embeddings1 = F.normalize(embeddings1, p=2, dim=1)
        embeddings2 = F.normalize(embeddings2, p=2, dim=1)
        
        # 计算两个集合之间的余弦相似度矩阵
        # 注意：这里直接计算所有 pair，无需 mask，因为两个集合本身不重叠
        similarity_matrix = torch.mm(embeddings1, embeddings2.T)
        
        # 计算所有 pair 的平均相似度作为损失
        loss = similarity_matrix.mean()
        
        # 为了与 pairwise_dissimilarity_loss 风格保持一致，加入正交性正则
        reg_loss = similarity_matrix.pow(2).mean()
        
        total_loss = loss + reg_loss * 0.1
        
        # 当损失低于 margin 时认为不需要施加该损失
        if loss.item() < margin:
            use_loss = False
        else:
            use_loss = True
        
        return total_loss, use_loss


    def pairwise_similarity_loss(self, embeddings):
        N = embeddings.size(0)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.T)  # (N, N)
        mask = torch.eye(similarity_matrix.size(0), device=embeddings.device).bool()
        off_diagonal = similarity_matrix.masked_fill(mask, 0)  # 排除对角线
        loss = off_diagonal.pow(2).mean()  # 惩罚 off-diagonal 项
        l1_reg = embeddings.abs().mean()
        avg_sim = similarity_matrix.mean()
        
        if avg_sim.item() > 0.4: 
            use_loss = False
        else:
            use_loss = True
        sim_loss = -off_diagonal.mean()
        return sim_loss + loss + l1_reg * 0.01, use_loss

    def pairwise_cross_loss(self, feats, current_cls, margin=0.5):
        """
        计算跨类别特征 loss：
          对于当前类别的样本 feats，
          收集除 current_cls 外的所有类别原型，归一化后计算余弦相似度，
          转换为距离（1 - sim），并取平均。loss 定义为负的平均距离，
          当平均距离小于 margin 时施加该 loss。
        返回：(loss, use_loss)
        """
        # feats: (N, D)
        norm_feats = F.normalize(feats, p=2, dim=1)
        neg_protos = []
        for other_cls in self.prototypes.keys():
            if int(other_cls) != int(current_cls):
                neg_protos.append(self.prototypes[other_cls])
        if len(neg_protos) == 0:
            return torch.tensor(0.0, device=feats.device), False
        neg_protos = torch.cat(neg_protos, dim=0)  # (P_neg, D)
        norm_neg_protos = F.normalize(neg_protos, p=2, dim=1)
        # 计算样本与负原型之间的余弦相似度，并转换为距离
        sim_matrix = torch.mm(norm_feats, norm_neg_protos.T)  # (N, P_neg)
        # dists = 1 - sim_matrix  # 越大表示样本与其他类别原型距离越大 
        # avg_dist = dists.mean()
        # 定义 loss 为负的平均距离，从而最大化 avg_dist
        # loss = -avg_dist
        # 当 avg_dist 小于 margin 时，说明样本与其他类别原型距离较近，需施加损失
        # use_loss = avg_dist.item() < margin
        avg_sim = sim_matrix.mean()
        if avg_sim.item() < margin:
            use_loss = False
        else:
            use_loss = True
        # use_loss = avg_sim.item() < margin
        return avg_sim, use_loss

    def compute_loss(self, features_dict, group_labels):
        """
        计算原型相关 loss（基于余弦相似度）：
        1. 特征–原型对齐 loss：对于每个域和每个类别，
            根据 group_labels，将每个样本与其所属组（对应原型）的归一化向量计算余弦相似度，
            换算为距离（1 - sim），并取平均；（该部分用来拉近同类特征和对应原型）
        2. 类内原型 loss（hinge loss）：对于同一类别内不同原型，
            先归一化后计算余弦距离（1 - sim）；若小于 margin_intra 则施加惩罚；
        3. 类间原型 loss（最大化损失）：对于不同类别的原型，
            先归一化后计算余弦距离，并取负的平均值（即鼓励不同类别的原型之间距离越大越好）；
        4. 跨类别特征 loss（最大化损失）：对于每个域中每个类别的样本，
            计算其与其他类别所有原型之间的余弦距离，并取负的平均值。
        
        总 loss = λ_feat * loss_feat + λ_intra * loss_intra + λ_inter * loss_inter + λ_cross * loss_cross
        
        返回:
        total_loss, loss_feat, loss_intra, loss_inter, loss_cross
        """
        # --------------------
        # 计算类别权重（基于各类别样本数）：
        # freq_dict = {}
        # for domain in features_dict:
        #     for cls, feats in features_dict[domain].items():
        #         if cls not in freq_dict:
        #             freq_dict[cls] = len(feats)
        #         else:
        #             freq_dict[cls] = freq_dict[cls] + len(feats)
        # freq_values = torch.tensor(list(freq_dict.values()), dtype=torch.float32)
        # sum_freq = freq_values.sum()
        # class_weight = {}
        # for cls in self.prototype_nums.keys():
        #     # 注意：这里 key 使用字符串，因为 self.prototypes 的 key 是字符串
        #     class_weight[cls] = freq_dict[cls] / (sum_freq + 1e-6)
        # --------------------
        # 1. 特征–原型对齐 loss（使用余弦距离 1 - sim）
        loss_feat = 0.0
        count_feat = 0
        for domain in ['source', 'target']:
            if domain in features_dict:
                for cls, feats in features_dict[domain].items():
                    if feats.size(0) == 0:
                        continue
                    # 获取对应原型
                    prototypes = self.prototypes[cls]
                    assigned = group_labels[domain][cls]
                    
                    # 计算每个样本与其分配原型的距离
                    norm_feats = F.normalize(feats, p=2, dim=1)
                    norm_proto = F.normalize(prototypes, p=2, dim=1)
                    sim = (norm_feats * norm_proto[assigned]).sum(dim=1)
                    # loss_feat += (1 - sim).mean() * (class_weight[cls] + 1e-6)
                    loss_feat += (1 - sim).mean()
        #             norm_feats = F.normalize(feats, p=2, dim=1)  # (N, D)
        #             proto = self.prototypes[cls]            # (P, D)
        #             norm_proto = F.normalize(proto, p=2, dim=1)    # (P, D)
        #             group_labels_cur = group_labels[domain][cls]   # (N,)
        #             cls_loss = 0.0
        #             group_count = 0
        #             # 对每个原型组计算 loss
        #             for group_idx in range(norm_proto.size(0)):
        #                 mask = (group_labels_cur == group_idx)
        #                 if mask.sum() == 0:
        #                     continue
        #                 group_feats = norm_feats[mask]  # (n, D)
        #                 # 余弦相似度：sim = dot(product)，因为已经归一化
        #                 sim = torch.einsum('nc,c->n', group_feats, norm_proto[group_idx])
        #                 # 将相似度转为距离：1 - sim
        #                 d = 1 - sim
        #                 cls_loss += torch.mean(d)
        #                 group_count += 1
        #             if group_count > 0:
        #                 cls_loss = cls_loss / group_count
        #                 loss_feat += class_weight[cls] * cls_loss
        #                 count_feat += 1
        # loss_feat = loss_feat / count_feat if count_feat > 0 else torch.tensor(0.0, device=list(self.prototypes.values())[0].device)
        
        # --------------------
        # 2. 类内原型 loss（使用 hinge loss 对 1 - sim）
        loss_intra = 0.0
        count_intra = 0
        for cls in self.prototypes.keys():
            proto = self.prototypes[cls]  # (P, D)
            if proto.size(0) < 2:
                continue
            loss_intra_cur, use_loss = self.pairwise_dissimilarity_loss(proto, margin=self.margin_intra)
            if use_loss:
                loss_intra += loss_intra_cur
                count_intra += 1
            # norm_proto = F.normalize(proto, p=2, dim=1)  # (P, D)
            # # 计算 pairwise cosine similarity矩阵 (P, P)
            # sim_matrix = torch.mm(norm_proto, norm_proto.T)
            # cosine_distance = 1 - sim_matrix  # (P, P)
            # # 排除对角线
            # mask = torch.eye(cosine_distance.size(0), device=cosine_distance.device).bool()
            # cosine_distance = cosine_distance.masked_fill(mask, 1.0)
            # # 对于每对原型，如果距离小于 margin_intra，则施加惩罚
            # penalties = F.relu(self.margin_intra - cosine_distance)
            # loss_intra += torch.mean(penalties)
            # count_intra += 1
        loss_intra = loss_intra / count_intra if count_intra > 0 else torch.tensor(0.0, device=list(self.prototypes.values())[0].device)
        
        # --------------------
        # 3. 类间原型 loss（最大化距离：使用负的平均余弦距离）
        inter_losses = []
        count_inter = 0
        classes = list(self.prototypes.keys())
        for i in range(len(classes)):
            for j in range(i+1, len(classes)):
                proto_i = self.prototypes[classes[i]]  # (P_i, D)
                proto_j = self.prototypes[classes[j]]  # (P_j, D)
                # norm_proto_i = F.normalize(proto_i, p=2, dim=1)
                # norm_proto_j = F.normalize(proto_j, p=2, dim=1)
                # sim_matrix = torch.mm(norm_proto_i, norm_proto_j.T)  # (P_i, P_j)
                # cosine_distance = 1 - sim_matrix  # (P_i, P_j)
                # inter_losses.append(torch.mean(cosine_distance))
                # count_inter += 1

                inter_loss, use_loss = self.inter_pairwise_dissimilarity_loss(proto_i, proto_j, margin=self.margin_inter)
                if use_loss:
                    inter_losses.append(inter_loss)
                    count_inter += 1
        if inter_losses:
            # loss_inter = - torch.mean(torch.stack(inter_losses))  # 负号：鼓励距离最大化
            loss_inter = torch.mean(torch.stack(inter_losses))
            loss_inter = loss_inter / count_inter
        else:
            loss_inter = torch.tensor(0.0, device=list(self.prototypes.values())[0].device)
        
        # --------------------
        # 4. 跨类别特征 loss（最大化样本与其他类别原型之间的距离）
        loss_cross = 0.0
        count_cross = 0
        for domain in ['source', 'target']:
            if domain in features_dict:
                for cls, feats in features_dict[domain].items():
                    if feats.size(0) == 0:
                        continue
                    # norm_feats = F.normalize(feats, p=2, dim=1)  # (N, D)
                    # # 构造除当前类别外的所有原型
                    # neg_protos = []
                    # for other_cls in self.prototypes.keys():
                    #     if int(other_cls) != int(cls):
                    #         neg_protos.append(self.prototypes[other_cls])
                    # if len(neg_protos) == 0:
                    #     continue
                    # neg_protos = torch.cat(neg_protos, dim=0)  # (P_neg, D)
                    # norm_neg_protos = F.normalize(neg_protos, p=2, dim=1)
                    # # 计算样本与所有负原型的余弦相似度，再转换为距离
                    # dists = 1 - torch.mm(norm_feats, norm_neg_protos.T)  # (N, P_neg)
                    # # 我们希望样本与其他类别原型的距离尽可能大，
                    # # 所以定义 loss 为负的平均距离（越大越好，损失越低）
                    # loss_cross += - torch.mean(dists)
                    # count_cross += 1
                    cross_loss, use_loss = self.pairwise_cross_loss(feats, cls, margin=self.margin_cross)
                    if use_loss:
                        loss_cross += cross_loss
                        count_cross += 1
        loss_cross = loss_cross / count_cross if count_cross > 0 else torch.tensor(0.0, device=loss_feat.device)
        
        # 总 loss 加权组合，各项权重可根据需要调整
        total_loss = self.lambda_distin * loss_feat + self.lambda_intra * loss_intra + self.lambda_inter * loss_inter + self.lambda_cross * loss_cross
        return total_loss, loss_feat, loss_intra, loss_inter, loss_cross

    # @torch.no_grad()
    # def gather_global_features(self, features):
    #     """
    #     在 DDP 下，收集所有进程的特征。
    #     此处利用 all_gather_object 收集 numpy 数组，再合并。
    #     """
    #     if dist.is_initialized():
    #         world_size = dist.get_world_size()
    #         features_list = [None for _ in range(world_size)]
    #         # 将 features 转为 numpy 数组（假设 features 为 cpu tensor 或可转为 cpu）
    #         features_np = features.detach().cpu().numpy()
    #         dist.all_gather_object(features_list, features_np)
    #         global_features_np = np.concatenate(features_list, axis=0)
    #         global_features = torch.tensor(global_features_np, dtype=features.dtype, device=features.device)
    #         return global_features
    #     else:
    #         return features


    @torch.no_grad()
    def gather_global_features(self, features):
        """
        在 DDP 下，收集所有进程的特征。
        采用 all_gather 方式：先将每个进程的 features padding 到相同长度，
        然后 all_gather 收集，再去除填充部分。
        参数:
            features: Tensor，形状 (N, D)
        返回:
            global_features: Tensor，形状 (sum_i N_i, D)
        """
        if dist.is_initialized():
            world_size = dist.get_world_size()
            # 获取当前进程 features 的数量（放在 GPU 上的 tensor）
            local_size = torch.tensor([features.shape[0]], device=features.device, dtype=torch.long)
            size_list = [torch.zeros(1, device=features.device, dtype=torch.long) for _ in range(world_size)]
            dist.all_gather(size_list, local_size)
            sizes = [int(s.item()) for s in size_list]
            max_size = max(sizes)

            # 对当前进程 features 进行 pad，使得第一维大小等于 max_size
            pad_size = max_size - features.shape[0]
            if pad_size > 0:
                padding = torch.zeros((pad_size, features.shape[1]), device=features.device, dtype=features.dtype)
                padded_features = torch.cat([features, padding], dim=0)
            else:
                padded_features = features

            # 准备一个列表用于收集所有进程的 padded_features
            gathered = [torch.zeros_like(padded_features) for _ in range(world_size)]
            dist.all_gather(gathered, padded_features)

            # 去除每个进程填充的部分
            feature_list = []
            for i, feat in enumerate(gathered):
                feature_list.append(feat[:sizes[i]])
            global_features = torch.cat(feature_list, dim=0)
            return global_features
        else:
            return features
    # @torch.no_grad()
    # def gather_global_features(self, features):
    #     """
    #     在 DDP 下，收集所有进程的特征（基于张量的 all_gather）
    #     """
    #     if dist.is_initialized():
    #         # 确保特征在 GPU 上
    #         features = features.contiguous().to(device=features.device)
            
    #         # 获取所有进程的特征张量列表
    #         world_size = dist.get_world_size()
    #         features_list = [torch.zeros_like(features) for _ in range(world_size)]
            
    #         # 使用 all_gather 收集张量
    #         dist.all_gather(features_list, features)
            
    #         # 合并特征
    #         global_features = torch.cat(features_list, dim=0)
    #         return global_features
    #     else:
    #         return features


    # @torch.no_grad()
    # def sync_parameters(self):
    #     """
    #     将 prototypes、covariances、weights 从 rank 0 广播到其它进程，
    #     以保证在 DDP 中各个进程参数一致。
    #     """
    #     if dist.is_initialized():
    #         for key in self.prototypes.keys():
    #             dist.broadcast(self.prototypes[key], src=0)
    #             dist.broadcast(self.covariances[key], src=0)
    #             dist.broadcast(self.weights[key], src=0)


    @torch.no_grad()
    def sync_parameters(self):
        """
        使用非阻塞广播同步参数
        """
        if dist.is_initialized():
            for key in self.prototypes.keys():
                # 将参数移到当前设备
                proto = self.prototypes[key].to(device=torch.cuda.current_device())
                cov = self.covariances[key].to(device=torch.cuda.current_device())
                weight = self.weights[key].to(device=torch.cuda.current_device())
                
                # 异步广播
                dist.broadcast(proto, src=0, async_op=False)
                dist.broadcast(cov, src=0, async_op=False)
                dist.broadcast(weight, src=0, async_op=False)
                
                # 更新参数并释放临时变量
                self.prototypes[key].data.copy_(proto)
                self.covariances[key].data.copy_(cov)
                self.weights[key].data.copy_(weight)
                del proto, cov, weight

    def forward(self, batch_dict):
        """
        Update the prototypes from the batch and compute the prototype losses.

        1. For every class present in the batch, ``gmm_update`` assigns the
           class's ROI features to prototypes and refreshes the prototype
           statistics.
        2. ``compute_loss`` is evaluated on ``loss_features_dict`` (built from
           the batch when not supplied).

        Under torch.distributed the features of a class are gathered from all
        ranks, rank 0 performs the update on the global set, the parameters
        are broadcast, and every rank then labels its local features without
        updating. Which classes take part is decided from globally reduced
        counts, so all ranks enter the same collectives even when a class is
        missing from some local batches.

        Args:
            batch_dict: dict with
                'roi_head_features': Tensor of shape (N, D)
                'roi_labels': Tensor of shape (N,), compared against the class
                    ids 1..num_classes
                'domain' (optional): 'source' or 'target', default 'source'
                'loss_features_dict' (optional): dict in the format expected by
                    ``compute_loss``

        Returns:
            (total_loss, loss_feat, loss_intra, loss_inter, loss_cross)
        """
        features = batch_dict['roi_head_features']
        labels = batch_dict['roi_labels']
        domain = batch_dict.get('domain', 'source')
        self.domain = domain
        group_labels = {domain: {}}

        distributed = dist.is_initialized()
        rank = dist.get_rank() if distributed else 0
        class_ids = range(1, self.num_classes + 1)

        # Local and global number of samples per class.
        local_counts = [int((labels == cls_id).sum()) for cls_id in class_ids]
        if distributed:
            counts = torch.tensor(local_counts, device=features.device, dtype=torch.long)
            dist.all_reduce(counts)  # default reduction is SUM
            global_counts = counts.tolist()
        else:
            global_counts = local_counts

        for cls_id, n_local, n_global in zip(class_ids, local_counts, global_counts):
            if n_global == 0:
                continue  # consistent on every rank, no collective is skipped
            cls = str(cls_id)
            cls_features = features[labels == cls_id]

            # Release cached blocks before the gather / GMM step. Only valid
            # for CUDA tensors: torch.cuda.device rejects CPU devices.
            if features.is_cuda:
                with torch.cuda.device(features.device):
                    torch.cuda.empty_cache()

            if distributed:
                global_cls_features = self.gather_global_features(cls_features)
                if rank == 0:
                    # Only rank 0 updates, using the features of all ranks.
                    self.gmm_update(global_cls_features, cls)
                del global_cls_features
                # Wait for rank 0, then pull its parameters.
                dist.barrier()
                self.sync_parameters()
                if n_local == 0:
                    continue  # took part in the collectives, nothing to label
                # Label the local features with the synchronised parameters.
                group_labels_cur = self.gmm_update(cls_features, cls, update=False)
            else:
                group_labels_cur = self.gmm_update(cls_features, cls)

            group_labels[domain][cls] = torch.from_numpy(group_labels_cur).long().to(features.device)

        # Build loss_features_dict for the current domain unless one is given.
        if 'loss_features_dict' in batch_dict:
            loss_features_dict = batch_dict['loss_features_dict']
        else:
            loss_features_dict = {domain: {}}
            for cls_id, n_local in zip(class_ids, local_counts):
                if n_local > 0:
                    loss_features_dict[domain][str(cls_id)] = features[labels == cls_id]

        return self.compute_loss(loss_features_dict, group_labels)
