"""
支持逆关系测试的 Motifs 和 VCTree 预测器

在测试阶段（PredCls）对非 CLIP 模型进行逆关系性能指标测试。
核心实现：
1. 交换主客体的视觉特征 (head_rep <-> tail_rep)
2. 使用逆关系映射表将预测分数重排
3. 融合原始预测和逆关系预测
"""
import torch
from torch import nn
from torch.nn import functional as F

from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.data import get_dataset_statistics

from .utils_relation import layer_init
from .model_motifs import LSTMContext, FrequencyBias, AttributeLSTMContext
from .model_vctree import VCTreeLSTMContext
from .inverse_relation_utils import create_inverse_mapper


@registry.ROI_RELATION_PREDICTOR.register("MotifPredictor_Inverse")
class MotifPredictor_Inverse(nn.Module):
    """
    Motifs relation predictor with test-time inverse-relation fusion.

    In eval mode, when ``inverse_alpha > 0``, every (subject, object) pair is
    additionally scored with subject and object swapped. The swapped scores are
    re-ordered by ``inverse_mapper.remap_scores`` and blended into the original
    logits for all non-background classes. In training mode, or when
    ``inverse_alpha`` is 0, the inverse branch is skipped entirely.
    """

    # Limits for the one-off debug dump printed during the first eval batches.
    _INVERSE_DEBUG_MAX_PAIRS = 5
    _INVERSE_DEBUG_TOPK = 10

    def __init__(self, config, in_channels):
        """
        Build the context encoder, the post-decoding layers and the inverse mapper.

        Args:
            config: project config node; the ``MODEL.*`` entries read below must exist.
            in_channels: input feature size forwarded to the context layer (must not be None).
        """
        super(MotifPredictor_Inverse, self).__init__()
        self.attribute_on = config.MODEL.ATTRIBUTE_ON
        self.num_obj_cls = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.num_att_cls = config.MODEL.ROI_ATTRIBUTE_HEAD.NUM_ATTRIBUTES
        self.num_rel_cls = config.MODEL.ROI_RELATION_HEAD.NUM_CLASSES

        assert in_channels is not None
        self.use_vision = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_VISION
        self.use_bias = config.MODEL.ROI_RELATION_HEAD.PREDICT_USE_BIAS

        # Load the class dictionaries and check them against the configured sizes.
        statistics = get_dataset_statistics(config)
        obj_classes, rel_classes, att_classes = (
            statistics['obj_classes'],
            statistics['rel_classes'],
            statistics['att_classes']
        )
        assert self.num_obj_cls == len(obj_classes)
        assert self.num_att_cls == len(att_classes)
        assert self.num_rel_cls == len(rel_classes)

        # Contextual LSTM encoder (attribute-aware variant when attributes are enabled).
        if self.attribute_on:
            self.context_layer = AttributeLSTMContext(config, obj_classes, att_classes, rel_classes, in_channels)
        else:
            self.context_layer = LSTMContext(config, obj_classes, rel_classes, in_channels)

        # Post decoding: edge context -> (head, tail) halves -> pair feature -> relation logits.
        self.hidden_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.pooling_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_POOLING_DIM
        self.post_emb = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim)
        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rel_cls, bias=True)

        layer_init(self.post_emb, 10.0 * (1.0 / self.hidden_dim) ** 0.5, normal=True)
        layer_init(self.post_cat, xavier=True)
        layer_init(self.rel_compress, xavier=True)

        # Union features come in at MLP_HEAD_DIM; project them when that differs
        # from the pooling dimension used for the pair features.
        if self.pooling_dim != config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM:
            self.union_single_not_match = True
            self.up_dim = nn.Linear(config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM, self.pooling_dim)
            layer_init(self.up_dim, xavier=True)
        else:
            self.union_single_not_match = False

        if self.use_bias:
            self.freq_bias = FrequencyBias(config, statistics)

        # Inverse-relation test support. A missing HP node or INVERSE_ALPHA key
        # leaves the fusion weight at 0.0, which disables the inverse branch.
        hp_cfg = getattr(config.MODEL.ROI_RELATION_HEAD, "HP", None)
        self.inverse_alpha = getattr(hp_cfg, "INVERSE_ALPHA", 0.0) if hp_cfg else 0.0
        self.inverse_mapper = create_inverse_mapper(config, rel_classes)
        self._inverse_debug_logged = 0  # number of pairs already dumped by the debug print

        print(f"[MotifPredictor_Inverse] inverse_alpha = {self.inverse_alpha}")

    def maybe_augment_relations(self, rel_pair_idxs, rel_labels, rel_binarys=None):
        """
        Inverse-relation augmentation hook; this predictor applies no augmentation.

        Returns:
            The three inputs unchanged, plus one entry per element of
            ``rel_labels``: an all-False bool tensor with one flag per label,
            or None where the label entry is None.
        """
        inverse_flags = [
            None if labels is None
            else labels.new_zeros(labels.size(0), dtype=torch.bool, device=labels.device)
            for labels in rel_labels
        ]
        return rel_pair_idxs, rel_labels, rel_binarys, inverse_flags

    @staticmethod
    def _topk_lines(scores, k, rel_classes):
        """Format the top-k entries of one row of foreground logits as printable lines."""
        top_vals, top_inds = torch.topk(scores, k=k)
        probs = F.softmax(scores, dim=0)
        lines = []
        for rank, (ind, val) in enumerate(zip(top_inds, top_vals), start=1):
            cls_idx = ind.item() + 1  # +1 because the background column was sliced off
            rel_name = rel_classes[cls_idx]
            prob = probs[ind].item()
            lines.append(f"    {rank}. [{cls_idx}] {rel_name:20s} (logit={val.item():.3f}, prob={prob:.4f})")
        return lines

    def _log_inverse_debug(self, rel_dists, rel_dists_inv):
        """Print fused and inverse top predictions until the pair budget is used up."""
        remaining = self._INVERSE_DEBUG_MAX_PAIRS - self._inverse_debug_logged
        if remaining <= 0 or rel_dists.size(0) == 0:
            return

        num_log = min(remaining, rel_dists.size(0))
        rel_classes = self.inverse_mapper.rel_classes
        k = min(self._INVERSE_DEBUG_TOPK, rel_dists.size(1) - 1)
        for row in range(num_log):
            fused_lines = self._topk_lines(rel_dists[row, 1:], k, rel_classes)
            inverse_lines = self._topk_lines(rel_dists_inv[row, 1:], k, rel_classes)

            print(f"\n[Motifs Inverse Debug #{self._inverse_debug_logged + row}]")
            print("  Fused Top-10:")
            for line in fused_lines:
                print(line)
            print("  Inverse Top-10:")
            for line in inverse_lines:
                print(line)
        self._inverse_debug_logged += num_log

    def forward(self, proposals, rel_pair_idxs, rel_labels, rel_binarys, rel_inverse_flags=None, roi_features=None, union_features=None, logger=None):
        """
        Predict object and relation logits, optionally fused with inverse-relation scores.

        Args:
            proposals: per-image proposal containers; ``len()`` of each gives its object count.
            rel_pair_idxs (list[Tensor]): per image, a (num_rel, 2) tensor of (subject, object) indices.
            rel_labels, rel_binarys, rel_inverse_flags: accepted for interface
                compatibility; not read by this predictor.
            roi_features: per-object features passed to the context layer.
            union_features: per-pair union features; only read when ``use_vision`` is set.
            logger: forwarded to the context layer.

        Returns:
            obj_dists (tuple[Tensor]): per-image object logits; when attributes
                are enabled this is the pair ``(obj_dists, att_dists)``.
            rel_dists (tuple[Tensor]): per-image relation logits.
            add_losses (dict): always empty for this predictor.
        """
        add_losses = {}

        # Encode context information.
        if self.attribute_on:
            obj_dists, obj_preds, att_dists, edge_ctx = self.context_layer(roi_features, proposals, logger)
        else:
            obj_dists, obj_preds, edge_ctx, _ = self.context_layer(roi_features, proposals, logger)

        # Post decode: split the edge embedding into a head (subject) and a tail (object) half.
        edge_rep = self.post_emb(edge_ctx)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)
        head_rep = edge_rep[:, 0].contiguous().view(-1, self.hidden_dim)
        tail_rep = edge_rep[:, 1].contiguous().view(-1, self.hidden_dim)

        num_rels = [r.shape[0] for r in rel_pair_idxs]
        num_objs = [len(b) for b in proposals]
        assert len(num_rels) == len(num_objs)

        # The inverse branch only contributes at test time with a positive
        # weight, so none of its tensors are built otherwise.
        use_inverse = (not self.training) and self.inverse_alpha > 0

        head_reps = head_rep.split(num_objs, dim=0)
        tail_reps = tail_rep.split(num_objs, dim=0)
        obj_preds = obj_preds.split(num_objs, dim=0)

        prod_reps = []
        prod_reps_inv = []
        pair_preds = []
        for pair_idx, img_head, img_tail, img_pred in zip(rel_pair_idxs, head_reps, tail_reps, obj_preds):
            sub_idx, obj_idx = pair_idx[:, 0], pair_idx[:, 1]
            # Original direction: (subject, object).
            prod_reps.append(torch.cat((img_head[sub_idx], img_tail[obj_idx]), dim=-1))
            pair_preds.append(torch.stack((img_pred[sub_idx], img_pred[obj_idx]), dim=1))
            if use_inverse:
                # Inverse direction: (object, subject), i.e. the roles are swapped.
                prod_reps_inv.append(torch.cat((img_head[obj_idx], img_tail[sub_idx]), dim=-1))

        prod_rep = self.post_cat(cat(prod_reps, dim=0))
        pair_pred = cat(pair_preds, dim=0)

        # The union-feature gate is the same for both directions; compute it once.
        union_gate = None
        if self.use_vision:
            union_gate = self.up_dim(union_features) if self.union_single_not_match else union_features
            prod_rep = prod_rep * union_gate

        # Original relation prediction.
        rel_dists = self.rel_compress(prod_rep)
        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(pair_pred.long())

        # Test-time inverse-relation fusion.
        if use_inverse:
            self.inverse_mapper.log_mapping_info()

            prod_rep_inv = self.post_cat(cat(prod_reps_inv, dim=0))
            if union_gate is not None:
                prod_rep_inv = prod_rep_inv * union_gate

            rel_dists_inv_raw = self.rel_compress(prod_rep_inv)
            if self.use_bias:
                # The frequency bias is looked up with swapped (object, subject) labels.
                pair_pred_inv = pair_pred[:, [1, 0]]
                rel_dists_inv_raw = rel_dists_inv_raw + self.freq_bias.index_with_labels(pair_pred_inv.long())

            # Move each inverse-direction score to the column of the relation it
            # is the inverse of (e.g. the swapped pair's "below" score lands on "above").
            rel_dists_inv = self.inverse_mapper.remap_scores(rel_dists_inv_raw)

            # Blend original and inverse logits; column 0 (background) is left untouched.
            rel_dists[:, 1:] = (1 - self.inverse_alpha) * rel_dists[:, 1:] + self.inverse_alpha * rel_dists_inv[:, 1:]

            self._log_inverse_debug(rel_dists, rel_dists_inv)

        obj_dists = obj_dists.split(num_objs, dim=0)
        rel_dists = rel_dists.split(num_rels, dim=0)

        if self.attribute_on:
            att_dists = att_dists.split(num_objs, dim=0)
            return (obj_dists, att_dists), rel_dists, add_losses
        else:
            return obj_dists, rel_dists, add_losses


@registry.ROI_RELATION_PREDICTOR.register("VCTreePredictor_Inverse")
class VCTreePredictor_Inverse(nn.Module):
    """
    VCTree relation predictor with test-time inverse-relation fusion.

    In eval mode, when ``inverse_alpha > 0``, every (subject, object) pair is
    additionally scored with subject and object swapped. The swapped scores are
    re-ordered by ``inverse_mapper.remap_scores`` and blended into the original
    logits for all non-background classes. In training mode, or when
    ``inverse_alpha`` is 0, the inverse branch is skipped entirely.

    Attribute prediction (``MODEL.ATTRIBUTE_ON``) is not supported: the context
    layer call in ``forward`` yields no attribute logits.
    """

    # Limits for the one-off debug dump printed during the first eval batches.
    _INVERSE_DEBUG_MAX_PAIRS = 5
    _INVERSE_DEBUG_TOPK = 10

    def __init__(self, config, in_channels):
        """
        Build the context encoder, the post-decoding layers and the inverse mapper.

        Args:
            config: project config node; the ``MODEL.*`` entries read below must exist.
            in_channels: input feature size forwarded to the context layer (must not be None).
        """
        super(VCTreePredictor_Inverse, self).__init__()
        self.attribute_on = config.MODEL.ATTRIBUTE_ON
        self.num_obj_cls = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
        self.num_att_cls = config.MODEL.ROI_ATTRIBUTE_HEAD.NUM_ATTRIBUTES
        self.num_rel_cls = config.MODEL.ROI_RELATION_HEAD.NUM_CLASSES

        assert in_channels is not None

        # Load the class dictionaries and check them against the configured sizes.
        statistics = get_dataset_statistics(config)
        obj_classes, rel_classes, att_classes = (
            statistics['obj_classes'],
            statistics['rel_classes'],
            statistics['att_classes']
        )
        assert self.num_obj_cls == len(obj_classes)
        assert self.num_att_cls == len(att_classes)
        assert self.num_rel_cls == len(rel_classes)

        # Tree-structured contextual encoder.
        self.context_layer = VCTreeLSTMContext(config, obj_classes, rel_classes, statistics, in_channels)

        # Post decoding: edge context -> (head, tail) halves -> pair feature -> relation logits.
        self.hidden_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.pooling_dim = config.MODEL.ROI_RELATION_HEAD.CONTEXT_POOLING_DIM
        self.post_emb = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim)
        self.ctx_compress = nn.Linear(self.pooling_dim, self.num_rel_cls)

        layer_init(self.ctx_compress, xavier=True)
        layer_init(self.post_emb, 10.0 * (1.0 / self.hidden_dim) ** 0.5, normal=True)
        layer_init(self.post_cat, xavier=True)

        # Union features come in at MLP_HEAD_DIM; project them when that differs
        # from the pooling dimension used for the pair features.
        if self.pooling_dim != config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM:
            self.union_single_not_match = True
            self.up_dim = nn.Linear(config.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM, self.pooling_dim)
            layer_init(self.up_dim, xavier=True)
        else:
            self.union_single_not_match = False

        self.freq_bias = FrequencyBias(config, statistics)

        # Inverse-relation test support. A missing HP node or INVERSE_ALPHA key
        # leaves the fusion weight at 0.0, which disables the inverse branch.
        hp_cfg = getattr(config.MODEL.ROI_RELATION_HEAD, "HP", None)
        self.inverse_alpha = getattr(hp_cfg, "INVERSE_ALPHA", 0.0) if hp_cfg else 0.0
        self.inverse_mapper = create_inverse_mapper(config, rel_classes)
        self._inverse_debug_logged = 0  # number of pairs already dumped by the debug print

        print(f"[VCTreePredictor_Inverse] inverse_alpha = {self.inverse_alpha}")

    def maybe_augment_relations(self, rel_pair_idxs, rel_labels, rel_binarys=None):
        """
        Inverse-relation augmentation hook; this predictor applies no augmentation.

        Returns:
            The three inputs unchanged, plus one entry per element of
            ``rel_labels``: an all-False bool tensor with one flag per label,
            or None where the label entry is None.
        """
        inverse_flags = [
            None if labels is None
            else labels.new_zeros(labels.size(0), dtype=torch.bool, device=labels.device)
            for labels in rel_labels
        ]
        return rel_pair_idxs, rel_labels, rel_binarys, inverse_flags

    @staticmethod
    def _topk_lines(scores, k, rel_classes):
        """Format the top-k entries of one row of foreground logits as printable lines."""
        top_vals, top_inds = torch.topk(scores, k=k)
        probs = F.softmax(scores, dim=0)
        lines = []
        for rank, (ind, val) in enumerate(zip(top_inds, top_vals), start=1):
            cls_idx = ind.item() + 1  # +1 because the background column was sliced off
            rel_name = rel_classes[cls_idx]
            prob = probs[ind].item()
            lines.append(f"    {rank}. [{cls_idx}] {rel_name:20s} (logit={val.item():.3f}, prob={prob:.4f})")
        return lines

    def _log_inverse_debug(self, rel_dists, rel_dists_inv):
        """Print fused and inverse top predictions until the pair budget is used up."""
        remaining = self._INVERSE_DEBUG_MAX_PAIRS - self._inverse_debug_logged
        if remaining <= 0 or rel_dists.size(0) == 0:
            return

        num_log = min(remaining, rel_dists.size(0))
        rel_classes = self.inverse_mapper.rel_classes
        k = min(self._INVERSE_DEBUG_TOPK, rel_dists.size(1) - 1)
        for row in range(num_log):
            fused_lines = self._topk_lines(rel_dists[row, 1:], k, rel_classes)
            inverse_lines = self._topk_lines(rel_dists_inv[row, 1:], k, rel_classes)

            print(f"\n[VCTree Inverse Debug #{self._inverse_debug_logged + row}]")
            print("  Fused Top-10:")
            for line in fused_lines:
                print(line)
            print("  Inverse Top-10:")
            for line in inverse_lines:
                print(line)
        self._inverse_debug_logged += num_log

    def forward(self, proposals, rel_pair_idxs, rel_labels, rel_binarys, rel_inverse_flags=None, roi_features=None, union_features=None, logger=None):
        """
        Predict object and relation logits, optionally fused with inverse-relation scores.

        Args:
            proposals: per-image proposal containers; ``len()`` of each gives its object count.
            rel_pair_idxs (list[Tensor]): per image, a (num_rel, 2) tensor of (subject, object) indices.
            rel_labels, rel_inverse_flags: accepted for interface compatibility;
                not read by this predictor.
            rel_binarys: per-image binary relation targets; only read in training
                mode, for the auxiliary binary loss.
            roi_features: per-object features passed to the context layer.
            union_features: per-pair union features, multiplied into the pair features.
            logger: forwarded to the context layer.

        Returns:
            obj_dists (tuple[Tensor]): per-image object logits.
            rel_dists (tuple[Tensor]): per-image relation logits.
            add_losses (dict): holds ``"binary_loss"`` in training mode, otherwise empty.

        Raises:
            NotImplementedError: if ``MODEL.ATTRIBUTE_ON`` is enabled. The context
                layer call below returns no attribute logits, so the attribute
                output cannot be produced.
        """
        if self.attribute_on:
            # Fail up front with a clear message instead of a NameError on an
            # undefined `att_dists` after all the work has been done.
            raise NotImplementedError(
                "VCTreePredictor_Inverse does not support MODEL.ATTRIBUTE_ON: "
                "the VCTree context layer returns no attribute logits."
            )

        add_losses = {}

        # Encode context information.
        obj_dists, obj_preds, edge_ctx, binary_preds = self.context_layer(roi_features, proposals, rel_pair_idxs, logger)

        # Post decode: split the edge embedding into a head (subject) and a tail (object) half.
        edge_rep = F.relu(self.post_emb(edge_ctx))
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)
        head_rep = edge_rep[:, 0].contiguous().view(-1, self.hidden_dim)
        tail_rep = edge_rep[:, 1].contiguous().view(-1, self.hidden_dim)

        num_rels = [r.shape[0] for r in rel_pair_idxs]
        num_objs = [len(b) for b in proposals]
        assert len(num_rels) == len(num_objs)

        # The inverse branch only contributes at test time with a positive
        # weight, so none of its tensors are built otherwise.
        use_inverse = (not self.training) and self.inverse_alpha > 0

        head_reps = head_rep.split(num_objs, dim=0)
        tail_reps = tail_rep.split(num_objs, dim=0)
        obj_preds = obj_preds.split(num_objs, dim=0)

        prod_reps = []
        prod_reps_inv = []
        pair_preds = []
        for pair_idx, img_head, img_tail, img_pred in zip(rel_pair_idxs, head_reps, tail_reps, obj_preds):
            sub_idx, obj_idx = pair_idx[:, 0], pair_idx[:, 1]
            # Original direction: (subject, object).
            prod_reps.append(torch.cat((img_head[sub_idx], img_tail[obj_idx]), dim=-1))
            pair_preds.append(torch.stack((img_pred[sub_idx], img_pred[obj_idx]), dim=1))
            if use_inverse:
                # Inverse direction: (object, subject), i.e. the roles are swapped.
                prod_reps_inv.append(torch.cat((img_head[obj_idx], img_tail[sub_idx]), dim=-1))

        prod_rep = self.post_cat(cat(prod_reps, dim=0))
        pair_pred = cat(pair_preds, dim=0)

        # The union-feature gate is the same for both directions; compute it once.
        if self.union_single_not_match:
            union_features_up = self.up_dim(union_features)
        else:
            union_features_up = union_features

        # Original relation prediction: visual/context logits plus frequency bias.
        ctx_dists = self.ctx_compress(prod_rep * union_features_up)
        frq_dists = self.freq_bias.index_with_labels(pair_pred.long())
        rel_dists = ctx_dists + frq_dists

        # Test-time inverse-relation fusion.
        if use_inverse:
            self.inverse_mapper.log_mapping_info()

            prod_rep_inv = self.post_cat(cat(prod_reps_inv, dim=0))
            ctx_dists_inv = self.ctx_compress(prod_rep_inv * union_features_up)
            # The frequency bias is looked up with swapped (object, subject) labels.
            pair_pred_inv = pair_pred[:, [1, 0]]
            frq_dists_inv = self.freq_bias.index_with_labels(pair_pred_inv.long())
            rel_dists_inv_raw = ctx_dists_inv + frq_dists_inv

            # Move each inverse-direction score to the column of the relation it is the inverse of.
            rel_dists_inv = self.inverse_mapper.remap_scores(rel_dists_inv_raw)

            # Blend original and inverse logits; column 0 (background) is left untouched.
            rel_dists[:, 1:] = (1 - self.inverse_alpha) * rel_dists[:, 1:] + self.inverse_alpha * rel_dists_inv[:, 1:]

            self._log_inverse_debug(rel_dists, rel_dists_inv)

        obj_dists = obj_dists.split(num_objs, dim=0)
        rel_dists = rel_dists.split(num_rels, dim=0)

        if self.training:
            # Auxiliary loss on the context layer's binary (pair-connectivity) predictions.
            binary_loss = []
            for bi_gt, bi_pred in zip(rel_binarys, binary_preds):
                bi_gt = (bi_gt > 0).float()
                binary_loss.append(F.binary_cross_entropy_with_logits(bi_pred, bi_gt))
            add_losses["binary_loss"] = sum(binary_loss) / len(binary_loss)

        return obj_dists, rel_dists, add_losses



