# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from ..attribute_head.roi_attribute_feature_extractors import make_roi_attribute_feature_extractor
from ..box_head.roi_box_feature_extractors import make_roi_box_feature_extractor
from .roi_relation_feature_extractors import make_roi_relation_feature_extractor
from .roi_relation_predictors import make_roi_relation_predictor
from .inference import make_roi_relation_post_processor
from .loss import make_roi_relation_loss_evaluator
from .sampling import make_roi_relation_samp_processor

class ROIRelationHead(nn.Module):
    """
    Generic Relation Head class.

    Builds, from ``cfg``, the box / attribute / union feature extractors, the
    relation predictor, the test-time post-processor, the loss evaluator and
    the training-time sampler, and chains them together in ``forward``.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg: project config node; a clone is stored as ``self.cfg``.
            in_channels: forwarded unchanged to the feature-extractor factories.
        """
        super(ROIRelationHead, self).__init__()
        self.cfg = cfg.clone()
        # Same structure as the box head, but with its own parameters.
        # These parameters are meant to be trained with a slow learning rate,
        # while the parameters of the box head stay fixed.
        # Note: there is another such extractor inside union_feature_extractor.
        self.union_feature_extractor = make_roi_relation_feature_extractor(cfg, in_channels)
        if cfg.MODEL.ATTRIBUTE_ON:
            # Box and attribute extractors each emit half of the feature
            # vector; the predictor is sized for their concatenation.
            self.box_feature_extractor = make_roi_box_feature_extractor(cfg, in_channels, half_out=True)
            self.att_feature_extractor = make_roi_attribute_feature_extractor(cfg, in_channels, half_out=True)
            feat_dim = self.box_feature_extractor.out_channels * 2
        else:
            self.box_feature_extractor = make_roi_box_feature_extractor(cfg, in_channels)
            feat_dim = self.box_feature_extractor.out_channels
        self.predictor = make_roi_relation_predictor(cfg, feat_dim)
        self.post_processor = make_roi_relation_post_processor(cfg)
        self.loss_evaluator = make_roi_relation_loss_evaluator(cfg)
        self.samp_processor = make_roi_relation_samp_processor(cfg)

        # parameters
        # NOTE(review): use_union_box is stored but forward() currently always
        # runs the union feature extractor - confirm whether it should gate it.
        self.use_union_box = self.cfg.MODEL.ROI_RELATION_HEAD.PREDICT_USE_VISION
        self.rel_pair_box = self._build_rel_pair_generator(cfg)

    def _build_rel_pair_generator(self, cfg):
        """
        Return a callable that enumerates candidate (subject, object) pairs.

        The returned ``_gen_pairs(proposals, use_gt_boxes=False)`` yields, for
        every entry of ``proposals``, a long tensor of shape (K, 2) holding all
        ordered index pairs ``(i, j)`` with ``i != j``. ``cfg`` and
        ``use_gt_boxes`` are accepted for interface compatibility but are not
        consulted.
        """
        def _gen_pairs(proposals, use_gt_boxes=False):
            rel_pair_idxs = []
            for boxes in proposals:
                if boxes is None or len(boxes) == 0:
                    # No objects -> no pairs. Build the empty index tensor
                    # directly: ``boxes`` may be None, and the tensor (and its
                    # device) lives on ``boxes.bbox`` rather than on ``boxes``.
                    device = None if boxes is None else boxes.bbox.device
                    rel_pair_idxs.append(
                        torch.zeros((0, 2), dtype=torch.long, device=device)
                    )
                    continue
                num_obj = len(boxes)
                idx = torch.arange(num_obj, device=boxes.bbox.device)
                subj, obj = torch.meshgrid(idx, idx, indexing="ij")
                # drop self-relations (i, i)
                mask = subj != obj
                rel_pair_idxs.append(torch.stack([subj[mask], obj[mask]], dim=-1))
            return rel_pair_idxs
        return _gen_pairs

    def forward(self, features, proposals, targets=None, logger=None):
        """
        Arguments:
            features (list[Tensor]): feature-maps from possibly several levels
            proposals (list[BoxList]): proposal boxes. Note: it has been post-processed (regression, nms) in sgdet mode
            targets (list[BoxList], optional): the ground-truth targets.
            logger (optional): passed through to the predictor unchanged.

        Returns:
            x (Tensor): the result of the feature extractor (box features,
                concatenated with attribute features when
                ``cfg.MODEL.ATTRIBUTE_ON`` is set)
            proposals (list[BoxList]): during training, the subsampled proposals
                are returned. During testing, the predicted boxlists are returned
            losses (dict[Tensor]): During training, returns the losses for the
                head. During testing, returns an empty dict.
        """
        if self.training:
            proposals, rel_labels, rel_pair_idxs, rel_binarys = self.samp_processor(
                proposals, targets
            )
        else:
            rel_labels = None
            rel_binarys = None
            rel_pair_idxs = self.prepare_test_pairs(proposals)

        # Optional training-time augmentation hook exposed by some predictors.
        rel_inverse_flags = None
        if (
            self.training
            and hasattr(self.predictor, "maybe_augment_relations")
        ):
            rel_pair_idxs, rel_labels, rel_binarys, rel_inverse_flags = self.predictor.maybe_augment_relations(
                rel_pair_idxs, rel_labels, rel_binarys
            )

        roi_features = self.box_feature_extractor(features, proposals)
        if self.cfg.MODEL.ATTRIBUTE_ON:
            # __init__ sizes the predictor for box + attribute features
            # (feat_dim = out_channels * 2), so both halves must be supplied.
            att_features = self.att_feature_extractor(features, proposals)
            roi_features = torch.cat((roi_features, att_features), dim=-1)
        union_features = self.union_feature_extractor(features, proposals, rel_pair_idxs)
        predictor_result = self.predictor(
            proposals,
            rel_pair_idxs,
            rel_labels,
            rel_binarys,
            rel_inverse_flags,
            roi_features,
            union_features,
            logger,
        )
        # Support predictors that return either 3 or 4 values.
        if len(predictor_result) == 4:
            refine_logits, relation_logits, add_losses, add_data = predictor_result
        else:
            refine_logits, relation_logits, add_losses = predictor_result
            add_data = {}

        # for test
        if not self.training:
            result = self.post_processor((relation_logits, refine_logits), rel_pair_idxs, proposals)
            return roi_features, result, {}

        # Call the loss evaluator; RWT mode needs add_data.
        loss_result = self.loss_evaluator(proposals, rel_labels, relation_logits, refine_logits, add_data)

        # RWT mode returns 3 values (loss_relation, loss_refine, loss_ce);
        # non-RWT mode returns 2 values (loss_relation, loss_refine).
        if len(loss_result) == 3:
            loss_relation, loss_refine, loss_ce = loss_result
        else:
            loss_relation, loss_refine = loss_result
            loss_ce = None

        if self.cfg.MODEL.ATTRIBUTE_ON and isinstance(loss_refine, (list, tuple)):
            output_losses = dict(loss_rel=loss_relation, loss_refine_obj=loss_refine[0], loss_refine_att=loss_refine[1])
        else:
            output_losses = dict(loss_rel=loss_relation, loss_refine_obj=loss_refine)

        # Add the RWT loss_ce when present.
        if loss_ce is not None:
            output_losses['loss_ce'] = loss_ce

        output_losses.update(add_losses)

        return roi_features, proposals, output_losses

    def prepare_test_pairs(self, proposals):
        """
        Enumerate candidate relation pairs for every image at test time.

        Raises:
            RuntimeError: if called while the module is in training mode.
        """
        if self.training:
            raise RuntimeError("prepare_test_pairs 仅在测试阶段调用")
        return self.rel_pair_box(proposals, use_gt_boxes=False)


def build_roi_relation_head(cfg, in_channels):
    """
    Factory for the relation head.

    Always instantiates ROIRelationHead with the given ``cfg`` and
    ``in_channels``. To use a different head, register a new class and select
    it through the config.
    """
    relation_head = ROIRelationHead(cfg, in_channels)
    return relation_head
