import os
import numpy as np
import torch
from maskrcnn_benchmark.modeling import registry
from torch import nn
from torch.nn import functional as F
import random

from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.modeling.make_layers import make_fc
from .utils_motifs import rel_vectors, obj_edge_vectors, to_onehot, nms_overlaps, encode_box_info
from .utils_relation import layer_init
from maskrcnn_benchmark.data import get_dataset_statistics
from maskrcnn_benchmark.data.datasets.inverse_augmentation import InverseAugmentationConfig
from .inverse_relation_utils import InverseRelationMapper

from .model_msg_passing import IMPContext
from .model_motifs import LSTMContext, FrequencyBias
from .model_vctree import VCTreeLSTMContext

from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.simple_tokenizer import SimpleTokenizer as clip_tokenizer
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.text_encoder import TextEncoder
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip.prompt import Clip_PromptLearner


@registry.ROI_RELATION_PREDICTOR.register("PeNet_HP_Predictor")
class PeNet_HP_Predictor(nn.Module):
    """Prototype-embedding relation predictor whose predicate prototypes come from a CLIP text encoder.

    Entities are represented by RoI features plus GloVe word embeddings. Each
    subject/object pair is fused with ``fusion_func`` and corrected by the union-box
    feature. Relation logits are the scaled cosine similarity between that relation
    representation and per-predicate prototypes. The prototypes are produced by a
    CLIP ``TextEncoder`` fed with a learnable prompt (``self.prompt_learner``).

    Extras visible in this class:
      * inverse-relation data augmentation for training (``maybe_augment_relations``);
      * optional test-time inverse-relation fusion (``TEST.INVERSE_RELATION``). Each
        pair is also scored with subject/object swapped, the swapped scores are
        remapped by ``InverseRelationMapper.remap_scores``, and they are blended into
        the foreground logits with weight ``alpha``.
    """

    # Predicates whose wrong test-time predictions are printed by the error analysis.
    _ERROR_ANALYSIS_RELATIONS = ('holding', 'walking on', 'sitting on', 'lying on', 'wearing', 'with')

    def __init__(self, config, in_channels):
        super(PeNet_HP_Predictor, self).__init__()

        self.num_obj_cls = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.num_att_cls = config.MODEL.ROI_ATTRIBUTE_HEAD.NUM_ATTRIBUTES
        self.num_rel_cls = config.MODEL.ROI_RELATION_HEAD.NUM_CLASSES
        self.cfg = config

        assert in_channels is not None
        self.in_channels = in_channels
        self.obj_dim = in_channels

        self.use_vision = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_VISION
        statistics = get_dataset_statistics(config)

        # Zero the background-predicate column. NOTE(review): this writes into
        # statistics['fg_matrix'] in place, and FrequencyBias(config, statistics) below
        # receives that same dict.
        self.fg_matrix = statistics['fg_matrix']
        if self.fg_matrix is not None:
            self.fg_matrix[:, :, 0] = 0
            self.fg_matrix = self.fg_matrix.cuda()

        obj_classes = statistics['obj_classes']
        rel_classes = statistics['rel_classes']
        att_classes = statistics['att_classes']
        assert self.num_obj_cls == len(obj_classes)
        assert self.num_att_cls == len(att_classes)
        assert self.num_rel_cls == len(rel_classes)
        self.obj_classes = obj_classes
        self.rel_classes = rel_classes
        self.num_obj_classes = len(obj_classes)

        # Inverse-relation data augmentation (used by maybe_augment_relations in training).
        self.inverse_aug = InverseAugmentationConfig(config.DATASETS.INVERSE_AUG, rel_classes)

        self.hidden_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.pooling_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_POOLING_DIM

        self.mlp_dim = 2048  # config.MODEL.ROI_RELATION_HEAD.PENET_MLP_DIM
        self.post_emb = nn.Linear(self.obj_dim, self.mlp_dim * 2)

        self.embed_dim = 300  # config.MODEL.ROI_RELATION_HEAD.PENET_EMBED_DIM
        dropout_p = 0.2  # config.MODEL.ROI_RELATION_HEAD.PENET_DROPOUT

        obj_embed_vecs = obj_edge_vectors(obj_classes, wv_dir=self.cfg.GLOVE_DIR,
                                          wv_dim=self.embed_dim)  # load Glove for objects
        rel_embed_vecs = rel_vectors(rel_classes, wv_dir=config.GLOVE_DIR,
                                     wv_dim=self.embed_dim)  # load Glove for predicates
        self.obj_embed = nn.Embedding(self.num_obj_cls, self.embed_dim)
        self.rel_embed = nn.Embedding(self.num_rel_cls, self.embed_dim)
        with torch.no_grad():
            self.obj_embed.weight.copy_(obj_embed_vecs, non_blocking=True)
            self.rel_embed.weight.copy_(rel_embed_vecs, non_blocking=True)

        # CLIP text side: the encoder is frozen (see frozen()); one tokenized prompt per predicate.
        with torch.no_grad():
            clip_model, _ = clip.load('RN101', device='cpu')

            self.text_encoder = TextEncoder(clip_model)
            self.relation_text = clip.tokenize(list(rel_classes))
            self.text_features = self.text_encoder(self.relation_text)

        self.prompt_learner_kl = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model)
        self.prompt_learner_kl_harder = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model)
        self.prompt_learner = Clip_PromptLearner(config, rel_classes, clip_tokenizer(), clip_model)

        self.W_sub = MLP(self.embed_dim, self.mlp_dim // 2, self.mlp_dim, 2)
        self.W_obj = MLP(self.embed_dim, self.mlp_dim // 2, self.mlp_dim, 2)
        self.W_pred = MLP(512, self.mlp_dim // 2, self.mlp_dim, 2)
        self.w_pvtp = MLP(512, self.mlp_dim // 2, self.mlp_dim, 2)

        self.gate_sub = nn.Linear(self.mlp_dim * 2, self.mlp_dim)
        self.gate_obj = nn.Linear(self.mlp_dim * 2, self.mlp_dim)
        self.gate_pred = nn.Linear(self.mlp_dim * 2, self.mlp_dim)

        self.vis2sem = nn.Sequential(*[
            nn.Linear(self.mlp_dim, self.mlp_dim * 2), nn.ReLU(True),
            nn.Dropout(dropout_p), nn.Linear(self.mlp_dim * 2, self.mlp_dim)
        ])

        self.project_head = MLP(self.mlp_dim, self.mlp_dim, self.mlp_dim * 2, 2)

        self.linear_sub = nn.Linear(self.mlp_dim, self.mlp_dim)
        self.linear_obj = nn.Linear(self.mlp_dim, self.mlp_dim)
        self.linear_pred = nn.Linear(self.mlp_dim, self.mlp_dim)
        self.linear_rel_rep = nn.Linear(self.mlp_dim, self.mlp_dim)

        self.norm_sub = nn.LayerNorm(self.mlp_dim)
        self.norm_obj = nn.LayerNorm(self.mlp_dim)
        self.norm_rel_rep = nn.LayerNorm(self.mlp_dim)

        self.dropout_sub = nn.Dropout(dropout_p)
        self.dropout_obj = nn.Dropout(dropout_p)
        self.dropout_rel_rep = nn.Dropout(dropout_p)

        self.dropout_rel = nn.Dropout(dropout_p)
        self.dropout_pred = nn.Dropout(dropout_p)

        self.down_samp = MLP(self.pooling_dim, self.mlp_dim, self.mlp_dim, 2)

        # Learnable inverse temperature, initialised to 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        ##### refine object labels
        self.pos_embed = nn.Sequential(*[
            nn.Linear(9, 32), nn.BatchNorm1d(32, momentum=0.001),
            nn.Linear(32, 128), nn.ReLU(inplace=True),
        ])

        self.obj_embed1 = nn.Embedding(self.num_obj_classes, self.embed_dim)
        with torch.no_grad():
            self.obj_embed1.weight.copy_(obj_embed_vecs, non_blocking=True)

        self.out_obj = make_fc(self.hidden_dim, self.num_obj_classes)
        self.lin_obj_cyx = make_fc(self.obj_dim + self.embed_dim + 128, self.hidden_dim)

        if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX:
            if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL:
                self.mode = 'predcls'
            else:
                self.mode = 'sgcls'
        else:
            self.mode = 'sgdet'

        self.nms_thresh = self.cfg.TEST.RELATION.LATER_NMS_PREDICTION_THRES

        # NOTE(review): freq_bias is built but not used by forward() in this class.
        self.freq_bias = FrequencyBias(config, statistics)

        # =============================================
        # Test-time inverse-relation configuration
        # =============================================
        test_cfg = getattr(config, "TEST", None)
        inverse_cfg = getattr(test_cfg, "INVERSE_RELATION", None) if test_cfg else None

        self.inverse_test_enabled = getattr(inverse_cfg, "ENABLED", False) if inverse_cfg else False
        self.inverse_alpha = getattr(inverse_cfg, "ALPHA", 0.5) if inverse_cfg else 0.5
        inverse_csv_path = getattr(inverse_cfg, "CSV_PATH", "") if inverse_cfg else ""

        # Maps scores predicted for the swapped (obj, subj) pair back onto predicate ids.
        self.inverse_mapper = InverseRelationMapper(rel_classes, inverse_csv_path)
        # Counters capping how many debug lines are printed over the process lifetime.
        self._inverse_debug_logged = 0
        self._error_debug_logged = 0

        # Predicate ids watched by the error analysis. The first case-insensitive name
        # match wins; names absent from rel_classes are skipped.
        lowered_rel_names = [name.lower() for name in rel_classes]
        self._error_target_rel_ids = frozenset(
            lowered_rel_names.index(name) for name in self._ERROR_ANALYSIS_RELATIONS
            if name in lowered_rel_names)

        if self.inverse_test_enabled:
            print(f"[PeNet_HP] inverse_test_enabled={self.inverse_test_enabled}, alpha={self.inverse_alpha}")
            self.inverse_mapper.log_mapping_info()

        self.frozen()

    def frozen(self):
        """Disable gradients for every parameter of the CLIP text encoder."""
        for _, param in self.text_encoder.named_parameters():
            param.requires_grad_(False)

    def maybe_augment_relations(self, rel_pair_idxs, rel_labels, rel_binarys=None):
        """Apply inverse-relation augmentation (strategies ewai / tihuan / both) when training.

        Returns ``(rel_pair_idxs, rel_labels, rel_binarys, inverse_flags)``. When not
        training, or when augmentation is disabled, the inputs are returned unchanged.
        In that case ``inverse_flags`` holds an all-False bool tensor per image, or
        ``None`` where that image's labels are ``None``.
        """
        if self.training and self.inverse_aug.enabled:
            return self.inverse_aug.augment_sampled_pairs(rel_pair_idxs, rel_labels, rel_binarys)
        inverse_flags = []
        for labels in rel_labels:
            if labels is None:
                inverse_flags.append(None)
            else:
                inverse_flags.append(labels.new_zeros(labels.size(0), dtype=torch.bool, device=labels.device))
        return rel_pair_idxs, rel_labels, rel_binarys, inverse_flags

    def get_rel_dist_vtp(self, prod_rep, prompt_learner, logit_scale):
        """Score ``prod_rep`` against the prompt-driven predicate prototypes.

        Returns ``logit_scale.exp() * cos_sim(prod_rep, prototypes)`` with one column
        per predicate. NOTE(review): ``prompt_learner`` is passed to the text encoder
        as-is, whereas forward() passes ``self.prompt_learner()`` (the call result).
        Confirm what callers of this method supply.
        """
        tokenized_prompts = self.relation_text.to(prod_rep.device)
        text_features = self.w_pvtp(self.text_encoder(prompt_learner, tokenized_prompts).float())
        text_features = self.project_head(self.dropout_pred(torch.relu(text_features)))
        image_features = F.normalize(prod_rep, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        rel_dists = logit_scale.exp() * image_features @ text_features.T
        return rel_dists

    def forward(self, proposals, rel_pair_idxs, rel_labels, rel_binarys, rel_inverse_flags, roi_features,
                union_features, logger=None):
        """Predict object and relation logits for a batch of images.

        Args:
            proposals: per-image box lists. Reads the fields "labels", "predict_logits"
                and "boxes_per_cls" depending on mode.
            rel_pair_idxs: per-image index tensors; column 0 is the subject and column 1
                the object, both indexing that image's proposals.
            rel_labels: per-image predicate labels (needed in training; optional at test).
            rel_binarys, rel_inverse_flags, logger: accepted for interface compatibility,
                not used here.
            roi_features: object features, one row per proposal across the batch.
            union_features: union-box features, one row per relation pair across the batch.

        Returns:
            ``(entity_dists, rel_dists, add_losses, add_data)``. ``entity_dists`` and
            ``rel_dists`` are split per image. ``add_losses`` holds the prototype losses
            in training. ``add_data['rel_dists_2']`` is a copy of the un-split relation
            logits.
        """
        add_losses = {}
        add_data = {}

        entity_dists, entity_preds = self.refine_obj_labels(roi_features, proposals)

        # Two views of each RoI feature: slot 1 acts as subject (xs), slot 0 as object (xo).
        entity_rep = self.post_emb(roi_features)
        entity_rep = entity_rep.view(entity_rep.size(0), 2, self.mlp_dim)
        sub_rep = entity_rep[:, 1].contiguous().view(-1, self.mlp_dim)  # xs
        obj_rep = entity_rep[:, 0].contiguous().view(-1, self.mlp_dim)  # xo

        entity_embeds = self.obj_embed(entity_preds)  # GloVe embedding of predicted labels

        num_rels = [r.shape[0] for r in rel_pair_idxs]
        num_objs = [len(b) for b in proposals]
        assert len(num_rels) == len(num_objs)

        sub_reps = sub_rep.split(num_objs, dim=0)
        obj_reps = obj_rep.split(num_objs, dim=0)
        entity_preds = entity_preds.split(num_objs, dim=0)
        entity_embeds = entity_embeds.split(num_objs, dim=0)

        use_inverse_test = (not self.training) and self.inverse_test_enabled

        fusion_so = []
        fusion_so_inv = []  # same pairs with subject and object swapped (test-time only)
        pair_preds = []
        for pair_idx, img_sub_rep, img_obj_rep, entity_pred, entity_embed in zip(
                rel_pair_idxs, sub_reps, obj_reps, entity_preds, entity_embeds):
            subj_idx, obj_idx = pair_idx[:, 0], pair_idx[:, 1]
            fusion_so.append(self._fuse_pair(img_sub_rep, img_obj_rep, entity_embed, subj_idx, obj_idx))
            pair_preds.append(torch.stack((entity_pred[subj_idx], entity_pred[obj_idx]), dim=1))
            if use_inverse_test:
                fusion_so_inv.append(self._fuse_pair(img_sub_rep, img_obj_rep, entity_embed, obj_idx, subj_idx))

        fusion_so = cat(fusion_so, dim=0)
        pair_pred = cat(pair_preds, dim=0)

        sem_pred = self.vis2sem(self.down_samp(union_features))  # h(xu)

        tokenized_prompts = self.relation_text.to(fusion_so.device)
        predicate_proto = self.w_pvtp(self.text_encoder(self.prompt_learner(), tokenized_prompts).float())

        rel_rep = self._relation_rep(fusion_so, sem_pred)
        predicate_proto = self.project_head(self.dropout_pred(torch.relu(predicate_proto)))

        rel_rep_norm = rel_rep / rel_rep.norm(dim=1, keepdim=True)  # r_norm
        predicate_proto_norm = predicate_proto / predicate_proto.norm(dim=1, keepdim=True)  # c_norm

        # Prototype-based relation prediction: <r_norm, c_norm> / tau
        rel_dists = rel_rep_norm @ predicate_proto_norm.t() * self.logit_scale.exp()

        if use_inverse_test and len(fusion_so_inv) > 0:
            rel_dists = self._fuse_inverse_scores(rel_dists, cat(fusion_so_inv, dim=0), sem_pred,
                                                  predicate_proto_norm, rel_labels, pair_pred)

        if self.training:
            add_losses.update(self._prototype_losses(rel_rep, predicate_proto, predicate_proto_norm, rel_labels))

        # [RWT] keep the full, un-split logits for the coarse-grained branch loss.
        add_data['rel_dists_2'] = rel_dists.clone()

        rel_dists = rel_dists.split(num_rels, dim=0)
        entity_dists = entity_dists.split(num_objs, dim=0)

        return entity_dists, rel_dists, add_losses, add_data

    def _fuse_pair(self, sub_rep, obj_rep, entity_embed, subj_idx, obj_idx):
        """Return F(s, o) for the pairs (subj_idx[k], obj_idx[k]) of one image."""
        s_embed = self.W_sub(entity_embed[subj_idx])  # Ws x ts
        o_embed = self.W_obj(entity_embed[obj_idx])  # Wo x to

        sem_sub = self.vis2sem(sub_rep[subj_idx])  # h(xs)
        sem_obj = self.vis2sem(obj_rep[obj_idx])  # h(xo)

        gate_sem_sub = torch.sigmoid(self.gate_sub(cat((s_embed, sem_sub), dim=-1)))  # gs
        gate_sem_obj = torch.sigmoid(self.gate_obj(cat((o_embed, sem_obj), dim=-1)))  # go

        sub = s_embed + sem_sub * gate_sem_sub  # s = Ws x ts + gs · h(xs)
        obj = o_embed + sem_obj * gate_sem_obj  # o = Wo x to + go · h(xo)

        # Residual block + LayerNorm "for the model convergence".
        sub = self.norm_sub(self.dropout_sub(torch.relu(self.linear_sub(sub))) + sub)
        obj = self.norm_obj(self.dropout_obj(torch.relu(self.linear_obj(obj))) + obj)
        return fusion_func(sub, obj)

    def _relation_rep(self, fusion_so, sem_pred):
        """Subtract the gated union feature from F(s, o) and project into prototype space."""
        gate_sem_pred = torch.sigmoid(self.gate_pred(cat((fusion_so, sem_pred), dim=-1)))  # gp
        rel_rep = fusion_so - sem_pred * gate_sem_pred  # r = F(s,o) - gp · h(xu)
        rel_rep = self.norm_rel_rep(self.dropout_rel_rep(torch.relu(self.linear_rel_rep(rel_rep))) + rel_rep)
        return self.project_head(self.dropout_rel(torch.relu(rel_rep)))

    def _fuse_inverse_scores(self, rel_dists, fusion_so_inv, sem_pred, predicate_proto_norm, rel_labels, pair_pred):
        """Blend swapped-pair scores into ``rel_dists`` (foreground columns only, in place).

        The union feature ``sem_pred`` is shared with the forward direction. This path
        only runs in eval mode, where recomputing it would give the same tensor.
        """
        rel_rep_inv = self._relation_rep(fusion_so_inv, sem_pred)
        rel_rep_inv_norm = rel_rep_inv / rel_rep_inv.norm(dim=1, keepdim=True)
        rel_dists_inv_raw = rel_rep_inv_norm @ predicate_proto_norm.t() * self.logit_scale.exp()

        # Reorder the inverse-direction scores onto their corresponding forward predicates.
        rel_dists_inv = self.inverse_mapper.remap_scores(rel_dists_inv_raw)

        rel_dists_orig = rel_dists.clone()  # pre-fusion scores, kept for debug output
        alpha = self.inverse_alpha
        rel_dists[:, 1:] = (1 - alpha) * rel_dists[:, 1:] + alpha * rel_dists_inv[:, 1:]

        self._log_inverse_debug(rel_dists_orig, rel_dists_inv, rel_dists)
        self._log_error_analysis(rel_dists_orig, rel_dists_inv, rel_dists, rel_labels, pair_pred)
        return rel_dists

    def _top_rel_names(self, fg_scores, k):
        """Names of the ``k`` highest foreground scores (``fg_scores`` excludes background)."""
        _, inds = torch.topk(fg_scores, k=k)
        return [self.rel_classes[i + 1] for i in inds.tolist()]

    def _log_inverse_debug(self, rel_dists_orig, rel_dists_inv, rel_dists):
        """Print top-3 original / inverse / fused predicates for the first 5 relations overall."""
        if self._inverse_debug_logged >= 5 or rel_dists.size(0) == 0:
            return
        k = min(3, rel_dists.size(1) - 1)
        num_log = min(5 - self._inverse_debug_logged, rel_dists.size(0))
        for idx in range(num_log):
            orig_names = self._top_rel_names(rel_dists_orig[idx, 1:], k)
            fused_names = self._top_rel_names(rel_dists[idx, 1:], k)
            inv_names = self._top_rel_names(rel_dists_inv[idx, 1:], k)
            print(f"[PeNet_HP Inverse Debug #{self._inverse_debug_logged + idx}] "
                  f"orig_top3: {orig_names} | "
                  f"inv_top3: {inv_names} | "
                  f"fused_top3: {fused_names}")
        self._inverse_debug_logged += num_log

    def _log_error_analysis(self, rel_dists_orig, rel_dists_inv, rel_dists, rel_labels, pair_pred):
        """Print up to 30 mispredictions whose ground truth is in ``_ERROR_ANALYSIS_RELATIONS``."""
        if self._error_debug_logged >= 30 or rel_labels is None or len(rel_labels) == 0:
            return
        if isinstance(rel_labels, list):
            if any(labels is None for labels in rel_labels):
                return  # no ground truth available to compare against
            rel_labels_cat = cat(rel_labels, dim=0)
        else:
            rel_labels_cat = rel_labels

        num_rel_cls = rel_dists.size(1)
        pred_labels = rel_dists[:, 1:].argmax(dim=1) + 1  # +1 because background was skipped

        for idx in range(min(rel_labels_cat.size(0), rel_dists.size(0))):
            gt_label = rel_labels_cat[idx].item()
            pred_label = pred_labels[idx].item()
            if gt_label not in self._error_target_rel_ids or gt_label == pred_label:
                continue

            gt_name = self.rel_classes[gt_label]
            pred_name = self.rel_classes[pred_label]

            top5_vals, top5_inds = torch.topk(rel_dists[idx, 1:], k=min(5, num_rel_cls - 1))
            top5_names = [self.rel_classes[i + 1] for i in top5_inds.tolist()]
            top5_scores = [f"{v:.3f}" for v in top5_vals.tolist()]
            orig_top3_names = self._top_rel_names(rel_dists_orig[idx, 1:], min(3, num_rel_cls - 1))
            inv_top3_names = self._top_rel_names(rel_dists_inv[idx, 1:], min(3, num_rel_cls - 1))

            subj_label = pair_pred[idx, 0].item()
            obj_label = pair_pred[idx, 1].item()
            subj_name = self.obj_classes[subj_label] if subj_label < len(self.obj_classes) else f"unk_{subj_label}"
            obj_name = self.obj_classes[obj_label] if obj_label < len(self.obj_classes) else f"unk_{obj_label}"

            print(f"[Error Analysis #{self._error_debug_logged}] "
                  f"GT: '{gt_name}' | Pred: '{pred_name}' | "
                  f"Triple: ({subj_name}, {gt_name}, {obj_name}) | "
                  f"Fused_top5: {list(zip(top5_names, top5_scores))} | "
                  f"Orig_top3: {orig_top3_names} | Inv_top3: {inv_top3_names}")

            self._error_debug_logged += 1
            if self._error_debug_logged >= 30:
                print("[Error Analysis] 已打印30条错误分析，停止输出。")
                break

    def _prototype_losses(self, rel_rep, predicate_proto, predicate_proto_norm, rel_labels):
        """Return the training losses ``l21_loss``, ``dist_loss2`` and ``loss_dis``.

        Zero/one helper tensors are created on the inputs' device rather than via a
        hard-coded ``.cuda()``.
        """
        losses = {}

        # Prototype regularization, cosine similarity: Le_sim = ||S||_{2,1}, S = C_norm @ C_norm.T
        target_predicate_proto_norm = predicate_proto_norm.clone().detach()
        simil_mat = predicate_proto_norm @ target_predicate_proto_norm.t()
        l21 = torch.norm(torch.norm(simil_mat, p=2, dim=1), p=1) / (self.num_rel_cls * self.num_rel_cls)
        losses["l21_loss"] = l21

        # Prototype regularization, Euclidean distance: Lr_euc = max(0, -(d-) + gamma2)
        gamma2 = 7.0
        proto_device = predicate_proto.device
        predicate_proto_a = predicate_proto.unsqueeze(dim=1).expand(-1, self.num_rel_cls, -1)
        predicate_proto_b = predicate_proto.detach().unsqueeze(dim=0).expand(self.num_rel_cls, -1, -1)
        proto_dis_mat = (predicate_proto_a - predicate_proto_b).norm(dim=2) ** 2  # dij = ||ci - cj||_2^2
        sorted_proto_dis_mat, _ = torch.sort(proto_dis_mat, dim=1)
        # The diagonal (self-distance) is 0 and sorts first, so two columns give d- with k2 = 1.
        topK_proto_dis = sorted_proto_dis_mat[:, :2].sum(dim=1)
        zeros_proto = torch.zeros(self.num_rel_cls, device=proto_device)
        losses["dist_loss2"] = torch.max(zeros_proto, -topK_proto_dis + gamma2).mean()

        # Prototype-based learning, Euclidean distance: Le_euc = max(0, (g+) - (g-) + gamma1)
        rel_labels_cat = cat(rel_labels, dim=0)
        num_samples = rel_labels_cat.size(0)
        rel_device = rel_rep.device
        rows = torch.arange(num_samples, device=rel_device)
        gamma1 = 1.0
        rel_rep_expand = rel_rep.unsqueeze(dim=1).expand(-1, self.num_rel_cls, -1)  # r
        predicate_proto_expand = predicate_proto.unsqueeze(dim=0).expand(num_samples, -1, -1)  # ci
        distance_set = (rel_rep_expand - predicate_proto_expand).norm(dim=2) ** 2  # gi = ||r - ci||_2^2
        mask_neg = torch.ones(num_samples, self.num_rel_cls, device=rel_device)
        mask_neg[rows, rel_labels_cat] = 0
        distance_set_neg = distance_set * mask_neg
        distance_set_pos = distance_set[rows, rel_labels_cat]  # g+
        sorted_distance_set_neg, _ = torch.sort(distance_set_neg, dim=1)
        # The masked positive contributes a 0 that sorts first; the next 10 entries give g- (k1 = 10).
        topK_sorted_distance_set_neg = sorted_distance_set_neg[:, :11].sum(dim=1) / 10
        zeros_rel = torch.zeros(num_samples, device=rel_device)
        losses["loss_dis"] = torch.max(zeros_rel, distance_set_pos - topK_sorted_distance_set_neg + gamma1).mean()
        return losses

    def refine_obj_labels(self, roi_features, proposals):
        """Return ``(obj_dists, obj_preds)`` for every proposal in the batch.

        predcls: one-hot ground-truth labels. Otherwise: ``out_obj`` logits computed from
        RoI features + label embedding + box-position embedding. Predictions come from
        per-class NMS in sgdet testing, or from the foreground argmax elsewhere.
        """
        use_gt_label = self.training or self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL
        obj_labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0) if use_gt_label else None

        # Label embedding from GT labels, or from the detector's soft class distribution.
        if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL:
            obj_labels = obj_labels.long()
            obj_embed = self.obj_embed1(obj_labels)
        else:
            obj_logits = cat([proposal.get_field("predict_logits") for proposal in proposals], dim=0).detach()
            obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed1.weight

        assert proposals[0].mode == 'xyxy'

        # Computed once: pos_embed contains BatchNorm, so an extra call would also
        # update its running statistics a second time during training.
        pos_embed = self.pos_embed(encode_box_info(proposals))
        num_objs = [len(p) for p in proposals]
        obj_pre_rep_for_pred = self.lin_obj_cyx(cat([roi_features, obj_embed, pos_embed], -1))

        if self.mode == 'predcls':
            obj_labels = obj_labels.long()
            obj_preds = obj_labels
            obj_dists = to_onehot(obj_preds, self.num_obj_classes)
        else:
            obj_dists = self.out_obj(obj_pre_rep_for_pred)  # hidden_dim -> num_obj_classes
            use_decoder_nms = self.mode == 'sgdet' and not self.training
            if use_decoder_nms:
                boxes_per_cls = [proposal.get_field('boxes_per_cls') for proposal in proposals]
                obj_preds = self.nms_per_cls(obj_dists, boxes_per_cls, num_objs).long()
            else:
                obj_preds = (obj_dists[:, 1:].max(1)[1] + 1).long()

        return obj_dists, obj_preds

    def nms_per_cls(self, obj_dists, boxes_per_cls, num_objs):
        """Greedy class-aware NMS over the softmax of ``obj_dists``; returns one label per box.

        Per image, the code repeatedly takes the globally best (box, class) score and
        assigns that class to the box. It then zeroes that class's score for boxes
        overlapping at or above ``nms_thresh`` and marks the chosen box as used.
        """
        obj_dists = obj_dists.split(num_objs, dim=0)
        obj_preds = []
        for img_idx, num_boxes in enumerate(num_objs):
            # Per the original note: overlap tensor is (#box, #box, #class).
            is_overlap = nms_overlaps(boxes_per_cls[img_idx]).cpu().numpy() >= self.nms_thresh

            out_dists_sampled = F.softmax(obj_dists[img_idx], -1).cpu().numpy()
            out_dists_sampled[:, 0] = -1  # never pick background

            out_label = obj_dists[img_idx].new(num_boxes).fill_(0)

            for _ in range(num_boxes):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape)
                out_label[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind, :, cls_ind], cls_ind] = 0.0
                out_dists_sampled[box_ind] = -1.0  # This way we won't re-sample

            obj_preds.append(out_label.long())
        obj_preds = torch.cat(obj_preds, dim=0)
        return obj_preds


class MLP(nn.Module):
    """Feed-forward stack of ``num_layers`` Linear layers with ReLU between them.

    All hidden layers have width ``hidden_dim``. The last layer maps to ``output_dim``
    and has no activation. A ``num_layers`` of 1 (or less) yields a single Linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_widths = [input_dim] + hidden_widths
        out_widths = hidden_widths + [output_dim]
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_widths, out_widths)])

    def forward(self, x):
        last_hidden = self.num_layers - 1
        for depth, linear in enumerate(self.layers):
            x = linear(x)
            if depth < last_hidden:
                x = F.relu(x)
        return x


def fusion_func(x, y):
    """Fuse two equally shaped tensors as ReLU(x + y) - (x - y)^2, elementwise."""
    gated_sum = F.relu(x + y)
    squared_gap = (x - y).pow(2)
    return gated_sum - squared_gap

