import torch
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import torch.nn.functional as F
import torch.distributed as dist
from pcdet.config import cfg
from pcdet.models.model_utils.model_nms_utils import class_agnostic_nms

class CrossDomainPrototypeManager(nn.Module):
    """
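    Cross-domain shared prototype manager.
      - Keeps one set of prototypes per class (several prototypes per class), shared by the
        source and target domains.
      - Holds two kinds of prototypes: learnable prototypes (nn.Parameter, trained through the
        losses) and non-learnable GMM statistics (means, diagonal covariances, weights and
        precision Cholesky factors, stored with requires_grad=False).
      - The GMM statistics are re-estimated from features with a GaussianMixture fit (or KMeans /
        nearest-mean assignment, depending on the config); after the first update they are
        blended into the stored values with an EMA controlled by ``coff``.
      - compute_loss computes the prototype losses: feature-prototype alignment plus intra-class
        and inter-class terms, for both features and prototypes.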
    """
    def __init__(self, prototype_cfg, feature_dim):
        """Build the prototype manager from its config.

        Args:
            prototype_cfg: config object supporting both ``.get(key, default)`` and attribute
                access. Keys read here:
                  PROTOTYPE_NUMS (required): list; entry i is the number of prototypes of
                      class id str(i + 1).
                  use_backgroud / use_multi_backgroud (default False): append one extra
                      background class with 1 prototype, or (multi) with as many prototypes
                      as there are foreground classes.
                  max_iter (default 100): GMM EM iterations.
                  lambda (optional mapping): loss weights lambda_feat, lambda_feat_intra,
                      lambda_feat_inter, lambda_align, lambda_align_intra,
                      lambda_align_inter, lambda_bg (each default 1.0).
                  margin_feat_intra, margin_feat_inter, margin_align_intra,
                      margin_align_inter (default 0.1): cosine-similarity margins of the
                      hinge-style losses.
                  coff (default 1.0): EMA coefficient,
                      new = (1 - coff) * old + coff * current.
                  use_gmm (default True), use_cur_gmm (default True),
                      use_kmeans (default False): clustering / statistics update strategy.
            feature_dim: dimension of the features and prototypes.
        """
        super(CrossDomainPrototypeManager, self).__init__()
        self.prototype_cfg = prototype_cfg
        self.use_backgroud = prototype_cfg.get('use_backgroud', False)
        self.use_multi_backgroud = prototype_cfg.get('use_multi_backgroud', False)

        # Class ids are 1-based strings ('1', '2', ...), used as ParameterDict keys.
        prototype_nums = {str(i + 1): num for i, num in enumerate(prototype_cfg.PROTOTYPE_NUMS)}
        if self.use_backgroud:
            num_foreground = len(prototype_nums)
            # The background class takes the next (last) class id.
            prototype_nums[str(num_foreground + 1)] = num_foreground if self.use_multi_backgroud else 1
        self.num_classes = len(prototype_nums)
        self.prototype_nums = prototype_nums  # number of prototypes per class id
        self.feature_dim = feature_dim
        self.coff = prototype_cfg.get('coff', 1.0)  # EMA smoothing coefficient
        self.gmm_config = {
            'covariance_type': 'diag',
            'max_iter': prototype_cfg.get('max_iter', 100),
            'reg_covar': 1e-6,
        }

        # Loss weights: every weight defaults to 1.0 when the 'lambda' section or the key is missing.
        lambda_config = prototype_cfg.get('lambda', None)
        if lambda_config is None:
            lambda_config = {}
        self.lambda_feat = lambda_config.get('lambda_feat', 1.0)
        self.lambda_feat_intra = lambda_config.get('lambda_feat_intra', 1.0)
        self.lambda_feat_inter = lambda_config.get('lambda_feat_inter', 1.0)
        self.lambda_align = lambda_config.get('lambda_align', 1.0)
        self.lambda_align_intra = lambda_config.get('lambda_align_intra', 1.0)
        self.lambda_align_inter = lambda_config.get('lambda_align_inter', 1.0)
        self.lambda_bg = lambda_config.get('lambda_bg', 1.0)

        # Cosine-similarity margins of the hinge-style losses.
        self.margin_feat_intra = prototype_cfg.get('margin_feat_intra', 0.1)
        self.margin_feat_inter = prototype_cfg.get('margin_feat_inter', 0.1)
        self.margin_align_intra = prototype_cfg.get('margin_align_intra', 0.1)
        self.margin_align_inter = prototype_cfg.get('margin_align_inter', 0.1)

        # Learnable prototypes plus the non-learnable GMM statistics of every class.
        self.learnable_prototypes = nn.ParameterDict()
        self.gmm_prototypes = nn.ParameterDict()            # GMM means
        self.gmm_covariances = nn.ParameterDict()           # diagonal covariances
        self.gmm_weights = nn.ParameterDict()               # mixture weights
        self.gmm_precisions_cholesky_ = nn.ParameterDict()  # precision Cholesky factors
        for cls_id, num_proto in prototype_nums.items():
            self.learnable_prototypes[cls_id] = nn.Parameter(torch.randn(num_proto, feature_dim))
            self.gmm_prototypes[cls_id] = nn.Parameter(torch.zeros(num_proto, feature_dim), requires_grad=False)
            self.gmm_covariances[cls_id] = nn.Parameter(torch.zeros(num_proto, feature_dim), requires_grad=False)
            self.gmm_weights[cls_id] = nn.Parameter(torch.zeros(num_proto), requires_grad=False)
            self.gmm_precisions_cholesky_[cls_id] = nn.Parameter(torch.zeros(num_proto, feature_dim), requires_grad=False)

        # Number of features assigned to each prototype, tracked separately per domain.
        self.feature_counts = {
            domain: {cls_id: [0] * num_proto for cls_id, num_proto in prototype_nums.items()}
            for domain in ('source', 'target')
        }

        self.domain = None
        self.init = False  # set to True by gmm_update once statistics have been written

        self.init_prototypes(weight_init='xavier')

        self.use_gmm = prototype_cfg.get('use_gmm', True)
        self.use_cur_gmm = prototype_cfg.get('use_cur_gmm', True)
        self.use_kmeans = prototype_cfg.get('use_kmeans', False)
        

    def init_prototypes(self, weight_init='xavier'):
        if weight_init == 'kaiming':
            init_func = nn.init.kaiming_normal_
        elif weight_init == 'xavier':
            init_func = nn.init.xavier_normal_
        elif weight_init == 'normal':
            init_func = nn.init.normal_
        else:
            raise NotImplementedError
        for cls in range(1, self.num_classes + 1):
            cls = str(cls)
            # if self.use_backgroud and cls == str(self.num_classes):
            #     continue
            init_func(self.learnable_prototypes[cls])

    @torch.no_grad()
    def gmm_update(self, features_dict):
        """
        完整 GMM 更新流程：  
          1. 将输入特征转换为 numpy 数组  
          2. 用当前原型（若非全 0）作为初始化拟合 GMM  
          3. 根据 GMM 预测的聚类标签更新原型  
        参数:
           features_dict: dict，键为域名，值为该域下的特征
        """
        for domain in features_dict:
            for cls, feats in features_dict[domain].items():
                if feats.size(0) == 0:
                    continue
                self.gmm_update_cls(feats, cls, domain=domain)
        self.init = True

    @torch.no_grad()
    def gmm_update_cls(self, features, cls, update=True, domain=None):
        """Assign the features of one class to its prototypes and optionally update the GMM statistics.

        Steps:
          1. Convert ``features`` to a float64 numpy array.
          2. Label every sample with one of three strategies:
             - ``self.use_gmm``: fit a GaussianMixture (only when ``update`` is set and there
               is more than one sample), then ``predict``. Once ``self.init`` is set, the
               stored statistics are loaded into the mixture first, so ``predict`` works
               without a fit (``fit`` itself re-initialises; it does not warm-start from them).
               If fitting/predicting raises, fall back to the nearest stored mean.
             - ``self.use_kmeans`` and ``update``: KMeans clustering.
             - otherwise: nearest stored mean (Euclidean); no update is performed.
          3. Re-label samples so empty clusters receive samples (``ensure_non_empty_clusters``).
          4. If ``update``, write the new statistics through ``update_GMM``. When
             ``self.use_cur_gmm`` is set but the mixture has no statistics (it was neither
             fitted nor loaded), the update is skipped because there is nothing to copy.

        Args:
            features: Tensor or numpy array, shape (N, feature_dim).
            cls: str, class id.
            update: whether to (re)fit the clustering and update the stored statistics.
            domain: 'source' / 'target', forwarded to ``update_GMM`` for the feature counts.

        Returns:
            labels: numpy int32 array of shape (N,), the cluster index of each sample.
        """
        num_components = self.prototype_nums[cls]
        prototypes = self.gmm_prototypes[cls]
        if isinstance(features, np.ndarray):
            features_np = features.astype(np.float64, copy=False)
        else:
            features_np = features.detach().cpu().numpy().astype(np.float64)

        gmm = GaussianMixture(
            n_components=num_components,
            covariance_type=self.gmm_config.get('covariance_type', 'diag'),
            max_iter=self.gmm_config.get('max_iter', 100),
            reg_covar=self.gmm_config.get('reg_covar', 1e-6),
            random_state=0,
        )
        if self.init:
            # Load the stored statistics so that predict() works without (re)fitting.
            gmm.means_ = prototypes.detach().cpu().numpy().astype(np.float64)
            gmm.covariances_ = self.gmm_covariances[cls].detach().cpu().numpy().astype(np.float64)
            gmm.weights_ = self.gmm_weights[cls].detach().cpu().numpy().astype(np.float64)
            gmm.precisions_cholesky_ = self.gmm_precisions_cholesky_[cls].detach().cpu().numpy().astype(np.float64)

        def _nearest_mean_labels():
            # Nearest stored mean; works for tensor and numpy input and on the prototypes' device.
            feats = torch.as_tensor(features_np, dtype=prototypes.dtype, device=prototypes.device)
            return torch.argmin(torch.cdist(feats, prototypes, p=2), dim=1).cpu().numpy()

        if self.use_gmm:
            try:
                if update and features_np.shape[0] > 1:
                    gmm.fit(features_np)
                labels = gmm.predict(features_np)
            except Exception as e:  # best effort: any sklearn failure falls back to nearest mean
                print(e)
                labels = _nearest_mean_labels()
        elif self.use_kmeans and update:
            kmeans = KMeans(n_clusters=num_components, init='k-means++', max_iter=300, n_init=10, random_state=0)
            labels = kmeans.fit(features_np).labels_
        else:
            labels = _nearest_mean_labels()
            update = False

        # Make sure every cluster owns at least one sample (when possible).
        labels = self.ensure_non_empty_clusters(features_np, labels, num_components, cls).cpu().numpy().astype(np.int32)

        if update:
            # The mixture only has these attributes if it was fitted or loaded above.
            gmm_stats = {
                name: getattr(gmm, name, None)
                for name in ('means_', 'covariances_', 'weights_', 'precisions_cholesky_')
            }
            if self.use_cur_gmm and any(stat is None for stat in gmm_stats.values()):
                # update_GMM would copy these statistics verbatim; there is nothing to copy.
                return labels
            # NOTE(review): on the KMeans path with use_cur_gmm the statistics passed here are the
            # previously stored ones (loaded above), so the KMeans result only affects the counts.
            self.update_GMM(features, labels, cls, init=False, domain=domain, **gmm_stats)
        return labels


    @torch.no_grad()
    def update_GMM(self, features, labels, cls, init=False, domain=None, means_=None, covariances_=None, weights_=None, precisions_cholesky_=None):
        """
        根据 GMM 结果更新指定类别的原型（均值）、协方差和权重  
        参数:
           features: Tensor 或 numpy 数组，形状 (N, feature_dim)
           labels: numpy 数组，形状 (N,)  —— 聚类标签
           cls: str，类别 id
           init: bool，是否为初始化阶段（初始化时直接赋值，不采用 EMA）
        """
        if isinstance(features, np.ndarray):
            features = torch.tensor(features, dtype=self.gmm_prototypes[cls].dtype).to(self.gmm_prototypes[cls].device)
        else:
            features = features.to(self.gmm_prototypes[cls].device)
        num_components = self.prototype_nums[cls]
        device = features.device
        new_mu = torch.zeros(num_components, self.feature_dim, device=device)
        new_cov = torch.zeros(num_components, self.feature_dim, device=device)
        new_weight = torch.zeros(num_components, device=device)
        new_chol = torch.zeros(num_components, self.feature_dim, device=device)
        for comp in range(num_components):
            mask = (labels == comp)
            if mask.sum() == 0:
                new_mu[comp] = self.gmm_prototypes[cls][comp]
                new_cov[comp] = self.gmm_covariances[cls][comp]
                new_weight[comp] = self.gmm_weights[cls][comp]
            else:
                if not self.use_cur_gmm:
                    comp_features = features[mask]
                    comp_mean = comp_features.mean(dim=0)
                    if comp_features.size(0) > 1:
                        comp_features = comp_features - comp_mean
                        comp_var = torch.var(comp_features, dim=0, unbiased=True)
                        comp_var = torch.clamp(comp_var, min=1e-6)
                    else:
                        comp_var = self.gmm_covariances[cls][comp]
                    comp_weight = torch.tensor(mask.sum()).float() / features.size(0)
                    # cholesky_ = torch.cholesky(torch.inverse(torch.diag(comp_var)))
                    # np.diag(1 / np.sqrt(current_cov + 1e-6))
                    cholesky_ = torch.diag(1 / torch.sqrt(comp_var + 1e-6))
                    new_mu[comp] = comp_mean
                    new_cov[comp] = comp_var
                    new_weight[comp] = comp_weight
                    new_chol[comp] = cholesky_
                else:
                    new_mu[comp] = torch.tensor(means_[comp], device=device)
                    new_cov[comp] = torch.tensor(covariances_[comp], device=device)
                    new_weight[comp] = torch.tensor(weights_[comp], device=device)
                    new_chol[comp] = torch.tensor(precisions_cholesky_[comp], device=device)

            if not self.init:
                self.feature_counts[domain][cls][comp] = mask.sum().item() if mask.sum() > 0 else 0
            else:
                self.feature_counts[domain][cls][comp] += mask.sum().item() if mask.sum() > 0 else 0
        if not self.init:
            self.gmm_prototypes[cls].data.copy_(new_mu)
            self.gmm_covariances[cls].data.copy_(new_cov)
            new_weight = new_weight.cpu().numpy()
            new_weight = new_weight / np.sum(new_weight)
            new_weight = np.round(new_weight, decimals=3)
            new_weight[-1] = 1 - np.sum(new_weight[:-1])
            new_weight = torch.tensor(new_weight, device=device)
            self.gmm_weights[cls].data.copy_(new_weight)
            self.gmm_precisions_cholesky_[cls].data.copy_(new_chol)
        else:
            updated_mu = (1 - self.coff) * self.gmm_prototypes[cls] + self.coff * new_mu
            updated_cov = (1 - self.coff) * self.gmm_covariances[cls] + self.coff * new_cov
            updated_weight = (1 - self.coff) * self.gmm_weights[cls] + self.coff * new_weight
            self.gmm_prototypes[cls].data.copy_(updated_mu)
            self.gmm_covariances[cls].data.copy_(updated_cov)
            updated_weight = updated_weight.cpu().numpy()
            updated_weight = updated_weight / np.sum(updated_weight)
            updated_weight = np.round(updated_weight, decimals=3)
            updated_weight[-1] = 1 - np.sum(updated_weight[:-1])
            updated_weight = torch.tensor(updated_weight, device=device)
            self.gmm_weights[cls].data.copy_(updated_weight)
            updated_chol = (1 - self.coff) * self.gmm_precisions_cholesky_[cls] + self.coff * new_chol
            self.gmm_precisions_cholesky_[cls].data.copy_(updated_chol)

    @torch.no_grad()
    def ensure_non_empty_clusters(self, data, labels, n_clusters, cls):
        """
        如果某个聚类为空，则重新分配样本保证每个聚类至少有一个样本
        参数:
           data: numpy 数组，形状 (N, feature_dim)
           labels: numpy 数组，形状 (N,)
           n_clusters: 聚类个数
           cls: 当前类别 id，用于取出对应原型
        返回:
           labels: 修正后的 labels（tensor形式）
        """
        data_tensor = torch.tensor(data).float().to(self.gmm_prototypes[cls].device)
        labels_tensor = torch.tensor(labels).long().to(self.gmm_prototypes[cls].device)
        prototypes_tensor = self.gmm_prototypes[cls]  # (n_clusters, feature_dim)
        dist_matrix = torch.cdist(prototypes_tensor, data_tensor, p=2)  # (n_clusters, N)
        for cluster in range(n_clusters):
            if (labels_tensor == cluster).sum() == 0:
                min_val, min_idx = torch.min(dist_matrix[cluster], dim=0)
                labels_tensor[min_idx] = cluster
        return labels_tensor

    def orthogonality_regularization(self, embeddings):
        similarity_matrix = torch.mm(embeddings, embeddings.T)  # (N, N)
        mask = torch.eye(similarity_matrix.size(0), device=embeddings.device).bool()
        off_diagonal = similarity_matrix.masked_fill(mask, 0)
        loss = off_diagonal.pow(2).mean()
        return loss

    def pairwise_dissimilarity_loss(self, embeddings, margin=1.0):
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.T)
        N = embeddings.size(0)
        mask = torch.eye(N, device=embeddings.device).bool()
        similarity_matrix = similarity_matrix.masked_fill(mask, 0)
        reg_loss = self.orthogonality_regularization(embeddings)
        loss = similarity_matrix.mean()
        if loss.item() < margin:
            use_loss = False
        else:
            use_loss = True
        return loss + reg_loss * 0.1, use_loss
        # 三元组替换

    def inter_pairwise_dissimilarity_loss(self, embeddings1, embeddings2, margin=1.0):
        # 对两个输入分别归一化
        embeddings1 = F.normalize(embeddings1, p=2, dim=1)
        embeddings2 = F.normalize(embeddings2, p=2, dim=1)
        
        # 计算两个集合之间的余弦相似度矩阵
        # 注意：这里直接计算所有 pair，无需 mask，因为两个集合本身不重叠
        similarity_matrix = torch.mm(embeddings1, embeddings2.T)
        
        # 计算所有 pair 的平均相似度作为损失
        loss = similarity_matrix.mean()
        
        # 为了与 pairwise_dissimilarity_loss 风格保持一致，加入正交性正则
        reg_loss = similarity_matrix.pow(2).mean()
        
        total_loss = loss + reg_loss * 0.1
        
        # 当损失低于 margin 时认为不需要施加该损失
        if loss.item() < margin:
            use_loss = False
        else:
            use_loss = True
        
        return total_loss, use_loss

    def align_pairwise_similarity_loss(self, embeddings1, embeddings2):
        embeddings1 = F.normalize(embeddings1, p=2, dim=1)
        embeddings2 = F.normalize(embeddings2, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings1, embeddings2.T)
        avg_sim = similarity_matrix.mean()
        loss = (1 - avg_sim).mean()

        return loss

    def pairwise_similarity_loss(self, embeddings):
        N = embeddings.size(0)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.T)  # (N, N)
        mask = torch.eye(similarity_matrix.size(0), device=embeddings.device).bool()
        off_diagonal = similarity_matrix.masked_fill(mask, 0)  # 排除对角线
        loss = off_diagonal.pow(2).mean()  # 惩罚 off-diagonal 项
        l1_reg = embeddings.abs().mean()
        avg_sim = similarity_matrix.mean()
        
        if avg_sim.item() > 0.4: 
            use_loss = False
        else:
            use_loss = True
        sim_loss = -off_diagonal.mean()
        return sim_loss + loss + l1_reg * 0.01, use_loss

    def pairwise_cross_loss(self, feats, current_cls, margin=0.5):
        """
        计算跨类别特征 loss：
          对于当前类别的样本 feats，
          收集除 current_cls 外的所有类别原型，归一化后计算余弦相似度，
          转换为距离（1 - sim），并取平均。loss 定义为负的平均距离，
          当平均距离小于 margin 时施加该 loss。
        返回：(loss, use_loss)
        """
        # feats: (N, D)
        norm_feats = F.normalize(feats, p=2, dim=1)
        neg_protos = []
        for other_cls in self.prototypes.keys():
            if int(other_cls) != int(current_cls):
                neg_protos.append(self.prototypes[other_cls])
        if len(neg_protos) == 0:
            return torch.tensor(0.0, device=feats.device), False
        neg_protos = torch.cat(neg_protos, dim=0)  # (P_neg, D)
        norm_neg_protos = F.normalize(neg_protos, p=2, dim=1)
        # 计算样本与负原型之间的余弦相似度，并转换为距离
        sim_matrix = torch.mm(norm_feats, norm_neg_protos.T)  # (N, P_neg)
        # dists = 1 - sim_matrix  # 越大表示样本与其他类别原型距离越大 
        # avg_dist = dists.mean()
        # 定义 loss 为负的平均距离，从而最大化 avg_dist
        # loss = -avg_dist
        # 当 avg_dist 小于 margin 时，说明样本与其他类别原型距离较近，需施加损失
        # use_loss = avg_dist.item() < margin
        avg_sim = sim_matrix.mean()
        if avg_sim.item() < margin:
            use_loss = False
        else:
            use_loss = True
        # use_loss = avg_sim.item() < margin
        return avg_sim, use_loss

    def compute_loss(self, features_dict, group_labels):
        """
        计算原型相关 loss（基于余弦相似度）：
        1. 特征–原型对齐 loss：对于每个域和每个类别，
            根据 group_labels，将每个样本与其所属组（对应原型）的归一化向量计算余弦相似度，
            换算为距离（1 - sim），并取平均；（该部分用来拉近同类特征和对应原型）
        2. 类内原型 loss（hinge loss）：对于同一类别内不同原型，
            先归一化后计算余弦距离（1 - sim）；若小于 margin_intra 则施加惩罚；
        3. 类间原型 loss（最大化损失）：对于不同类别的原型，
            先归一化后计算余弦距离，并取负的平均值（即鼓励不同类别的原型之间距离越大越好）；
        4. 跨类别特征 loss（最大化损失）：对于每个域中每个类别的样本，
            计算其与其他类别所有原型之间的余弦距离，并取负的平均值。
        
        总 loss = λ_feat * loss_feat + λ_intra * loss_intra + λ_inter * loss_inter + λ_cross * loss_cross
        
        返回:
        total_loss, loss_feat, loss_intra, loss_inter, loss_cross
        """
        # --------------------
        # 计算类别权重（逆频率加权）
        # freq_dict = {}
        # for domain in features_dict:
        #     for cls, feats in features_dict[domain].items():
        #         if feats.size(0) == 0: continue
        #         if cls not in freq_dict:
        #             freq_dict[cls] = len(feats)
        #         else:
        #             freq_dict[cls] = freq_dict[cls] + len(feats)

        # # 计算逆频率权重（样本越少权重越高）
        # sum_freq = sum(freq_dict.values())
        # class_weight = {
        #     cls: sum_freq / (freq + 1e-6)  # 逆频率，避免除零
        #     for cls, freq in freq_dict.items()
        # }
        class_weight = {}
        for cls in self.prototype_nums.keys():
            if cls not in class_weight:
                class_weight[cls] = 1.0

        if self.use_backgroud:
            class_weight[str(self.num_classes)] = self.lambda_bg
        # ------------------------------------------------------------
        # 1.特征-原型对齐 loss
        # --------------------
        # 1.1 对应特征–可学习原型对齐 loss（使用余弦距离 1 - sim）
        loss_feat = 0.0
        count_feat = 0
        for domain in ['source', 'target']:
            if domain in features_dict:
                for cls, feats in features_dict[domain].items():
                    if feats.size(0) == 0:
                        continue
                    # if self.use_backgroud and cls == str(self.num_classes):
                    #     continue
                    # 获取对应原型
                    prototypes = self.learnable_prototypes[cls]
                    assigned = group_labels[domain][cls]
                    
                    # 计算每个样本与其分配原型的距离
                    norm_feats = F.normalize(feats, p=2, dim=1)
                    norm_proto = F.normalize(prototypes, p=2, dim=1)
                    sim = torch.mm(norm_feats, norm_proto.T)  # (N, P)

                    # 对于每个样本，获取它与对应原型的相似度
                    sim = sim[range(len(sim)), assigned]

                    # 计算与分配原型的相似度（距离）
                    loss_feat += (1 - sim).mean() * class_weight[cls]
                    count_feat += 1


                    # gmm 均值与特征之间的损失：
                    # norm_feats = F.normalize(feats, p=2, dim=1)
                    # norm_proto = F.normalize(self.gmm_prototypes[cls], p=2, dim=1)
                    # sim = (norm_feats * norm_proto[assigned]).sum(dim=1)
                    # loss_feat += (1 - sim).mean()
                    # count_feat += 1
        loss_feat = loss_feat / count_feat if count_feat > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        # --------------------
        # 1.2 同类别特征-跨原型特征 loss（最大化样本与同类别其他原型之间的距离）
        loss_feat_intra = 0.0
        count_feat_intra = 0
        for domain in ['source', 'target']:
            if domain in features_dict:
                for cls, feats in features_dict[domain].items():
                    if feats.size(0) == 0:
                        continue
                    # if self.use_backgroud and cls == str(self.num_classes):
                    #     continue

                    # 获取当前类别的所有原型
                    prototypes = self.learnable_prototypes[cls]  # (P, D)
                    if prototypes.size(0) < 2:
                        continue  # 至少需要两个原型才计算类内损失
                    norm_prototypes = F.normalize(prototypes, p=2, dim=1)  # (P, D)

                    # 获取当前批次的特征
                    norm_feats = F.normalize(feats, p=2, dim=1)  # (N, D)

                    # 计算当前特征与同类别所有原型的相似度（余弦相似度）
                    sim_matrix = torch.mm(norm_feats, norm_prototypes.T)  # (N, P)

                    # 获取当前特征所属原型的索引
                    assigned_proto_idx = group_labels[domain][cls]  # (N,)

                    # 对每个特征，将其所属原型的相似度设为负无穷，避免它与自己的原型计算损失
                    sim_matrix[range(len(sim_matrix)), assigned_proto_idx] = -1

                    sim_matrix = sim_matrix[sim_matrix != -1].view(-1, len(prototypes) - 1)  # (N, P-1)

                    penalties = sim_matrix[sim_matrix > self.margin_feat_intra] - self.margin_feat_intra
                    
                    if penalties.size(0) > 0:
                        # 计算损失
                        loss_feat_intra += penalties.mean() * class_weight[cls]
                        count_feat_intra += 1

        # 平均损失
        loss_feat_intra = loss_feat_intra / count_feat_intra if count_feat_intra > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
                    
        
        # --------------------
        # 1.3 跨类别特征原型 loss（最大化样本与其他类别原型之间的距离）
        loss_feat_inter = 0.0
        count_cross = 0
        loss_feat_inter_bg = 0.0
        for domain in ['source', 'target']:
            if domain in features_dict:
                for cls, feats in features_dict[domain].items():
                    if feats.size(0) == 0:
                        continue
                    # if self.use_backgroud and cls == str(self.num_classes):
                    #     continue
                    # 收集其他类别所有原型
                    other_protos = []
                    for other_cls in self.learnable_prototypes:
                        if other_cls != cls:
                            other_protos.append(self.learnable_prototypes[other_cls])
                            
                    if len(other_protos) == 0: continue
                    other_protos = torch.cat(other_protos, dim=0)
                    
                    # 计算特征与异类原型的相似度
                    norm_feats = F.normalize(feats, p=2, dim=1)
                    norm_other = F.normalize(other_protos, p=2, dim=1)
                    sim_matrix = torch.mm(norm_feats, norm_other.T)
                    
                    # 计算hinge loss
                    # penalties = F.relu(sim_matrix - self.margin_feat_inter)
                    penalties = sim_matrix[sim_matrix > self.margin_feat_inter] - self.margin_feat_inter

                    if penalties.size(0) > 0:
                        loss_feat_inter += penalties.mean() * class_weight[cls]
                        count_cross += 1

                    if self.use_backgroud and cls == str(self.num_classes):
                        loss_feat_inter_bg = penalties.mean()

        loss_feat_inter = loss_feat_inter / count_cross if loss_feat_inter > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        loss_feat_inter_bg = loss_feat_inter_bg if loss_feat_inter_bg > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        
        # if self.use_backgroud:
        #     loss_feat_inter_bg = 0.0
        #     count_cross_bg = 0
        #     for domain in ['source', 'target']:
        #         if domain in features_dict:
        #             feats = features_dict[domain][str(self.num_classes)]
        #             if feats.size(0) == 0:
        #                 continue
        #             # 收集其他类别所有原型
        #             other_protos = []
        #             for other_cls in self.learnable_prototypes:
        #                 if other_cls != str(self.num_classes):
        #                     other_protos.append(self.learnable_prototypes[other_cls])
                            
        #             if len(other_protos) == 0: continue
        #             other_protos = torch.cat(other_protos, dim=0)
                    
        #             # 计算特征与异类原型的相似度
        #             norm_feats = F.normalize(feats, p=2, dim=1)
        #             norm_other = F.normalize(other_protos, p=2, dim=1)
        #             sim_matrix = torch.mm(norm_feats, norm_other.T)
                    
        #             # 计算hinge loss
        #             # penalties = F.relu(sim_matrix - self.margin_feat_inter)
        #             penalties = sim_matrix[sim_matrix > self.margin_feat_inter] - self.margin_feat_inter

        #             if penalties.size(0) > 0:
        #                 loss_feat_inter_bg += penalties.mean() * class_weight[str(self.num_classes)]
        #                 count_cross_bg += 1
            
        #     loss_feat_inter_bg = loss_feat_inter_bg / count_cross_bg if count_cross_bg > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
            # loss_feat_inter = loss_feat_inter + self.lambda_bg * loss_feat_inter_bg
        # ------------------------------------------------------------
        # 2. 原型之间loss
        # --------------------
        # 2.1 可学习原型与GMM原型的对齐损失
        loss_align = 0.0
        count_align = 0
        for cls in self.learnable_prototypes.keys():
            # if self.use_backgroud and cls == str(self.num_classes):
            #     continue
            learnable_proto = self.learnable_prototypes[cls]
            gmm_proto = self.gmm_prototypes[cls]
            assert len(learnable_proto) == len(gmm_proto), "Prototype counts must match"
            # 建议在原型初始化代码中添加检查
            norm_learnable_proto = F.normalize(learnable_proto, p=2, dim=1)
            norm_gmm_proto = F.normalize(gmm_proto, p=2, dim=1)
            # if self.use_gmm:
            # 计算对应原型的余弦相似度，对角线
            sim = torch.mm(norm_learnable_proto, norm_gmm_proto.T).diag()
            loss_align += (1 - sim).mean() * class_weight[cls]
            # else:
            #     # 匈牙利匹配 learnable 原型与 GMM 原型
            #     # 1. 计算余弦相似度矩阵
            #     sim_matrix = torch.mm(norm_learnable_proto, norm_gmm_proto.T)
            #     # 2. 匈牙利匹配
            #     from scipy.optimize import linear_sum_assignment
            #     row_ind, col_ind = linear_sum_assignment(-sim_matrix.cpu().numpy())
            #     # 3. 计算匹配的损失
            #     loss_align += 1 - sim_matrix[row_ind, col_ind].mean()
                
            # sim = (norm_learnable_proto * norm_gmm_proto).sum(dim=1)
            # loss_align += (1 - sim).mean()
            count_align += 1

        loss_align = loss_align / count_align if count_align > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)

        # --------------------
        # 2.2 类内可学习原型 loss
        loss_align_intra = 0.0
        count_intra = 0
        for cls in self.learnable_prototypes.keys():
            # if self.use_backgroud and cls == str(self.num_classes):
            #     continue
            prototypes = self.learnable_prototypes[cls]  # (P, D)
            if prototypes.size(0) < 2:
                continue
            norm_prototypes = F.normalize(prototypes, p=2, dim=1)
            sim_matrix = torch.mm(norm_prototypes, norm_prototypes.T)

            # 对角线设为 -1，避免计算自己与自己的损失
            sim_matrix = sim_matrix.masked_fill(torch.eye(sim_matrix.size(0), device=sim_matrix.device).bool(), -1)
            sim_matrix = sim_matrix[sim_matrix != -1].view(-1, len(prototypes) - 1)  # (N, P-1)
            penalties = sim_matrix[sim_matrix > self.margin_align_intra] - self.margin_align_intra
            if penalties.size(0) > 0:
                loss_align_intra += penalties.mean() * class_weight[cls]
                count_intra += 1

            # loss_intra_cur, use_loss = self.pairwise_dissimilarity_loss(proto, margin=self.margin_intra)
            # if use_loss:
            #     loss_intra += loss_intra_cur
            #     count_intra += 1
            # norm_proto = F.normalize(proto, p=2, dim=1)  # (P, D)
            # 计算 pairwise cosine similarity矩阵 (P, P)
            # sim_matrix = torch.mm(norm_proto, norm_proto.T)
            # cosine_distance = 1 - sim_matrix  # (P, P)
            # # 排除对角线
            # mask = torch.eye(cosine_distance.size(0), device=cosine_distance.device).bool()
            # cosine_distance = cosine_distance.masked_fill(mask, 1.0)
            # # 对于每对原型，如果距离小于 margin_intra，则施加惩罚
            # penalties = F.relu(self.margin_intra - cosine_distance)
            # loss_intra += torch.mean(penalties)
            # count_intra += 1
        loss_align_intra = loss_align_intra / count_intra if count_intra > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        
        # # --------------------
        # # 2.3 类间可学习原型 loss（最大化距离）
        loss_align_inter = 0.0
        count_inter = 0
        classes = list(self.learnable_prototypes.keys())
        for i in range(len(classes)):
            for j in range(i+1, len(classes)):
                # if self.use_backgroud and (classes[i] == str(self.num_classes) or classes[j] == str(self.num_classes)):
                #     continue
                proto_i = self.learnable_prototypes[classes[i]]  # (P_i, D)
                proto_j = self.learnable_prototypes[classes[j]]  # (P_j, D)
                norm_proto_i = F.normalize(proto_i, p=2, dim=1)
                norm_proto_j = F.normalize(proto_j, p=2, dim=1)
                sim_matrix = torch.mm(norm_proto_i, norm_proto_j.T)  # (P_i, P_j)
                penalties = sim_matrix[sim_matrix > self.margin_align_inter] - self.margin_align_inter
                if penalties.size(0) > 0:
                    loss_align_inter += penalties.mean() * class_weight[classes[i]]
                    count_inter += 1


        loss_align_inter = loss_align_inter / count_inter if count_inter > 0 else torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        #         cosine_distance = 1 - sim_matrix  # (P_i, P_j)
        #         inter_losses.append(torch.mean(cosine_distance))
        #         count_inter += 1

        # if inter_losses:
        #     # loss_inter = - torch.mean(torch.stack(inter_losses))  # 负号：鼓励距离最大化
        #     loss_inter = torch.mean(torch.stack(inter_losses))
        #     loss_inter = loss_inter / count_inter
        # else:
        #     loss_inter = torch.tensor(0.0, device=list(self.prototypes.values())[0].device)
        
        # 总 loss 加权组合，各项权重可根据需要调整
        total_loss = self.lambda_feat * loss_feat + self.lambda_feat_intra * loss_feat_intra + self.lambda_align * loss_align + self.lambda_feat_inter * loss_feat_inter + self.lambda_align_intra * loss_align_intra + self.lambda_align_inter * loss_align_inter
        # if self.use_backgroud:
        #     total_loss = total_loss + self.lambda_bg * loss_feat_inter_bg
        # else:
        #     loss_feat_inter_bg = torch.tensor(0.0, device=list(self.learnable_prototypes.values())[0].device)
        return total_loss, loss_feat, loss_feat_intra, loss_feat_inter, loss_align, loss_align_intra, loss_align_inter, loss_feat_inter_bg



    def filter_pseudo_label_by_score(self, rois, roi_scores, roi_labels, roi_head_features):
        """
        Filter teacher ROIs into pseudo labels by a per-class score threshold and NMS.

        For every sample in the batch:
          * ROIs whose sigmoid score is above the ``cfg.SELF_TRAIN.NEG_THRESH``
            entry of their predicted class are kept.
          * The kept ROIs are passed through class-agnostic NMS configured by
            ``cfg.MODEL.ROI_HEAD.NMS_CONFIG.TEST``.
          * When ``self.use_backgroud`` is enabled, ROIs at or below the
            threshold are background candidates. For each class with kept
            positives, up to as many background ROIs *predicted as that class*
            are randomly sampled, and their features are returned as negatives.

        Args:
            rois: per-sample ROI boxes, indexed as ``rois[b]``
                (presumably (B, N, box_dim) -- TODO confirm against caller).
            roi_scores: raw ROI scores, indexed as ``roi_scores[b]``;
                ``torch.sigmoid`` is applied here.
            roi_labels: ROI class labels, indexed as ``roi_labels[b]``; used as
                1-based indices into ``NEG_THRESH``.
            roi_head_features: per-ROI features, indexed as
                ``roi_head_features[b]`` (first per-sample dim is the ROI dim).

        Returns:
            tuple ``(rois, roi_scores, roi_labels, roi_head_features, neg_features)``.
            The first four are the kept pseudo labels of all samples,
            concatenated along dim 0 (scores are sigmoid-activated).
            ``neg_features`` holds the sampled background features,
            concatenated along dim 0. When nothing was sampled or background is
            disabled, it is an empty tensor with the same trailing shape, dtype
            and device as the ROI features.
        """
        roi_scores = torch.sigmoid(roi_scores)
        batch_size = rois.shape[0]
        nms_config = cfg.MODEL.ROI_HEAD.NMS_CONFIG.TEST
        # Built via numpy (float64) so threshold comparisons keep float64 precision.
        neg_thresh = torch.tensor(np.array(cfg.SELF_TRAIN.NEG_THRESH)).to(roi_scores.device)

        kept_boxes, kept_scores, kept_labels, kept_features = [], [], [], []
        neg_chunks = []
        for idx in range(batch_size):
            sample_scores = roi_scores[idx]
            sample_labels = roi_labels[idx]
            sample_features = roi_head_features[idx]
            # Threshold of each ROI, looked up by its 1-based predicted class.
            per_roi_thresh = neg_thresh[sample_labels - 1]

            fg_mask = sample_scores > per_roi_thresh
            box_preds = rois[idx][fg_mask]
            cur_scores = sample_scores[fg_mask]
            cur_labels = sample_labels[fg_mask]
            cur_features = sample_features[fg_mask]
            selected, _ = class_agnostic_nms(
                box_scores=cur_scores, box_preds=box_preds, nms_config=nms_config
            )
            selected_labels = cur_labels[selected]

            kept_boxes.append(box_preds[selected])
            kept_scores.append(cur_scores[selected])
            kept_labels.append(selected_labels)
            kept_features.append(cur_features[selected])

            if not self.use_backgroud:
                continue

            bg_mask = sample_scores <= per_roi_thresh
            bg_features = sample_features[bg_mask]
            bg_labels = sample_labels[bg_mask]
            if bg_features.size(0) == 0:
                continue
            for cls in range(1, self.num_classes + 1):
                num_pos = int((selected_labels == cls).sum())
                if num_pos == 0:
                    continue
                cls_bg_mask = bg_labels == cls
                if cls_bg_mask.sum() == 0:
                    continue
                cls_bg_idx = cls_bg_mask.nonzero().squeeze(1)
                # Sample as many background ROIs as there are kept positives of this class.
                cls_bg_idx = cls_bg_idx[torch.randperm(cls_bg_idx.size(0))[:num_pos]]
                neg_chunks.append(bg_features[cls_bg_idx])

        rois = torch.cat(kept_boxes, dim=0)
        roi_scores = torch.cat(kept_scores, dim=0)
        roi_labels = torch.cat(kept_labels, dim=0)
        roi_head_features = torch.cat(kept_features, dim=0)
        if neg_chunks:
            neg_features = torch.cat(neg_chunks, dim=0)
        else:
            # Correctly shaped and typed empty tensor, instead of a 1-D float32
            # torch.tensor([]) that only concatenates via a legacy special case.
            neg_features = roi_head_features.new_zeros((0,) + tuple(roi_head_features.shape[1:]))

        return rois, roi_scores, roi_labels, roi_head_features, neg_features


    def forward(self, batch_dict):
        """
        前向传播：
          1. 根据 batch_dict 中的 ROI 特征和 ROI 标签，对每个类别调用 gmm_update 更新原型
          2. 利用 batch_dict 中的特征（可由外部构造成字典 loss_features_dict）计算原型 loss
        batch_dict 要求包含：
          'roi_head_features': Tensor, shape (N, D)
          'roi_labels': Tensor, shape (N,), 类别标签（1～num_classes）
          可选：'domain' 表示该 batch 来自 'source' 或 'target'
          可选：'loss_features_dict': dict, 格式同 compute_loss() 中要求的
        返回:
          loss_tuple: (total_loss, loss_feat, loss_intra, loss_inter)
        """
        # if self.rank_cons:
        #     print('==> Check prototype manager parameters')
        #     import torch.distributed as dist
        #     print('==> rank: {}'.format(dist.get_rank()))
        #     for key in self.learnable_prototypes.keys():
        #         print('==> {} protos: {}'.format(key, self.learnable_prototypes[key][0][:5]))
        #         break

        features = batch_dict['roi_head_features_mt']
        labels = batch_dict['roi_labels_mt']
        scores = batch_dict['roi_scores_mt']
        rois = batch_dict['rois_mt']

        rois, scores, labels, features, neg_features = self.filter_pseudo_label_by_score(rois, scores, labels, features)
        if self.use_backgroud:
            features = torch.cat([features, neg_features], dim=0)
            labels = torch.cat([labels, torch.tensor([self.num_classes] * neg_features.size(0)).to(labels.device)], dim=0)


        domain = batch_dict.get('domain', 'source')
        self.domain = domain
        group_labels = {domain: {}}
 

        # 对每个类别更新原型
        for cls in range(1, self.num_classes + 1):
            mask = (labels == cls)
            if mask.sum() == 0:
                continue
            cls = str(cls)
            cls_features = features[mask]

            # 若处于分布式模式，收集全局特征前释放无用内存
            with torch.cuda.device(features.device):
                torch.cuda.empty_cache()

            group_labels_cur = self.gmm_update_cls(cls_features, cls, update=False)

            group_labels_cur = torch.from_numpy(group_labels_cur).long().to(features.device)
            group_labels[domain][cls] = group_labels_cur

        loss_features_dict = {domain: {}}
        for cls in range(1, self.num_classes + 1):
            mask = (labels == cls)
            cls = str(cls)
            if mask.sum() > 0:
                loss_features_dict[domain][cls] = features[mask]
        # 计算 loss
        loss_tuple = self.compute_loss(loss_features_dict, group_labels)
        return loss_tuple
