# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import numpy.random as npr

from maskrcnn_benchmark.layers import smooth_l1_loss, Label_Smoothing_Regression
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.modeling.utils import cat


class RelationLossComputation(object):
    """
    Computes the loss for relation triplet.
    Also supports FPN

    Besides the predicate (relation) classification loss, this object also
    computes the refined object classification loss and, when attributes are
    enabled, a multi-label attribute loss. The predicate loss can optionally
    be re-weighted per class (RWT) using the predicate proportions.
    """

    def __init__(
            self,
            cfg,
            attri_on,
            num_attri_cat,
            max_num_attri,
            attribute_sampling,
            attribute_bgfg_ratio,
            use_label_smoothing,
            predicate_proportion,
    ):
        """
        Arguments:
            cfg: config object; ``USE_RWT``, ``RWT_ALPHA``, ``RWT_EPS`` and
                ``RWT_NORMALIZE`` are read from ``cfg.MODEL.ROI_RELATION_HEAD``.
            attri_on (bool): whether the attribute loss is computed.
            num_attri_cat (int): width of the multi-hot attribute target.
            max_num_attri (int): number of attribute slots per object; must
                equal ``attributes.shape[1]`` in ``generate_attributes_target``.
            attribute_sampling (bool): forwarded to ``attribute_loss`` as
                ``fg_bg_sample``.
            attribute_bgfg_ratio: forwarded to ``attribute_loss`` as
                ``bg_fg_ratio``.
            use_label_smoothing (bool): use ``Label_Smoothing_Regression``
                instead of cross entropy as the classification criterion.
            predicate_proportion (list[float]): proportion of every predicate
                class, background excluded. Only used when RWT is enabled; an
                empty/None value disables RWT.
        """
        self.attri_on = attri_on
        self.num_attri_cat = num_attri_cat
        self.max_num_attri = max_num_attri
        self.attribute_sampling = attribute_sampling
        self.attribute_bgfg_ratio = attribute_bgfg_ratio
        self.use_label_smoothing = use_label_smoothing

        # RWT (re-weighting) configuration.
        self.use_rwt = cfg.MODEL.ROI_RELATION_HEAD.USE_RWT
        # Suggested range [0.0, 1.0]; larger values focus more on tail classes.
        self.rwt_alpha = cfg.MODEL.ROI_RELATION_HEAD.RWT_ALPHA
        self.rwt_eps = cfg.MODEL.ROI_RELATION_HEAD.RWT_EPS
        self.rwt_normalize = cfg.MODEL.ROI_RELATION_HEAD.RWT_NORMALIZE

        if self.use_rwt and predicate_proportion:
            weight = self._build_predicate_weights(predicate_proportion)
            # Keep the weights on the GPU when one exists, but do not crash on
            # CPU-only hosts; __call__ aligns the device with the logits anyway.
            self.pred_weight = weight.cuda() if torch.cuda.is_available() else weight
        else:
            self.pred_weight = None

        if self.use_label_smoothing:
            self.criterion_loss = Label_Smoothing_Regression(e=0.01)
        elif self.use_rwt and self.pred_weight is not None:
            # With RWT the weighted loss is computed by hand in __call__, so
            # the criterion must return one loss value per sample.
            self.criterion_loss = nn.CrossEntropyLoss(reduction='none')
        else:
            self.criterion_loss = nn.CrossEntropyLoss()

    def _build_predicate_weights(self, predicate_proportion):
        """
        Build the per-class RWT weight vector as a CPU float tensor.

        Index 0 is the background class, followed by one entry per predicate.
        The weight is ``(1 / (freq + eps)) ** alpha``:
            - alpha = 0: no re-weighting, every class has weight 1
            - alpha = 1: plain inverse frequency (standard RWT)
            - 0 < alpha < 1: smoothed version (e.g. 0.5 = square root)
        """
        # The background class is given a fixed proportion of 0.5.
        prop = torch.FloatTensor([0.5] + list(predicate_proportion))

        if self.rwt_alpha == 0.0:
            weight = torch.ones_like(prop)
        else:
            inv_prop = 1.0 / (prop + self.rwt_eps)
            if self.rwt_alpha != 1.0:
                # Smoothed version: the power tames extreme inverse frequencies.
                weight = inv_prop.pow(self.rwt_alpha)
            else:
                weight = inv_prop

            # Clip so that the largest weight is at most 50x the smallest one.
            min_weight = weight.min()
            max_weight = weight.max()
            if max_weight > min_weight * 50:
                # Scalar bounds keep this working on torch releases whose
                # clamp() does not accept tensor bounds.
                weight = torch.clamp(weight, min=min_weight.item(), max=(min_weight * 50).item())
                print(f"[RWT] Clipped extreme weights: max/min ratio = {max_weight/min_weight:.1f} -> 50.0")

        # Normalise to mean 1 so the loss scale stays comparable to other terms.
        if self.rwt_normalize:
            weight = weight / weight.mean()

        print(f"[RWT] Enabled with alpha={self.rwt_alpha:.2f}")
        print(f"[RWT] Weight stats: min={weight.min():.3f}, max={weight.max():.3f}, "
              f"mean={weight.mean():.3f}, std={weight.std():.3f}")
        return weight

    @staticmethod
    def _to_scalar(loss):
        """Average a per-sample loss vector; pass an already-reduced loss through."""
        if hasattr(loss, "dim") and loss.dim() > 0:
            return loss.mean()
        return loss

    def __call__(self, proposals, rel_labels, relation_logits, refine_logits, add_losses=None):
        """
        Computes the loss for relation triplet.
        This requires that the subsample method has been called beforehand.

        Arguments:
            proposals (list): per-image objects exposing ``get_field("labels")``
                (and ``get_field("attributes")`` when attributes are on).
            rel_labels (list[Tensor]): per-image predicate labels, 0 = background.
            relation_logits (list[Tensor]): per-image predicate logits.
            refine_logits: list[Tensor] of refined object logits, or a pair
                ``(object_logits, attribute_logits)`` when attributes are on.
            add_losses (dict, optional): extra information; may contain
                'rel_dists_fip', 'rel_dists_iip' and 'tail_indices'.

        Returns:
            predicate_loss (Tensor)
            finetune_obj_loss (Tensor), or the tuple
            ``(finetune_obj_loss, attribute_loss)`` when attributes are on.
        """
        if self.attri_on:
            if isinstance(refine_logits[0], (list, tuple)):
                refine_obj_logits, refine_att_logits = refine_logits
            else:
                # just use attribute feature, do not actually predict attribute
                # NOTE: this switches the attribute loss off for all later calls too.
                self.attri_on = False
                refine_obj_logits = refine_logits
        else:
            refine_obj_logits = refine_logits

        relation_logits = cat(relation_logits, dim=0)
        refine_obj_logits = cat(refine_obj_logits, dim=0)

        fg_labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)
        rel_labels = cat(rel_labels, dim=0)

        # === Main predicate loss (optionally RWT re-weighted) ===
        if self.use_rwt and self.pred_weight is not None:
            # Make sure the weights live where the logits live (CPU / any GPU).
            if self.pred_weight.device != relation_logits.device:
                self.pred_weight = self.pred_weight.to(relation_logits.device)
            rel_targets = rel_labels.long()
            # NOTE(review): with label smoothing enabled the criterion may already
            # return a reduced scalar, in which case the weighting below is only a
            # batch-level scale rather than a per-sample one - confirm.
            loss_per_sample = self.criterion_loss(relation_logits, rel_targets)
            sample_weights = self.pred_weight[rel_targets]
            loss_relation = (loss_per_sample * sample_weights).mean()
        else:
            loss_relation = self._to_scalar(self.criterion_loss(relation_logits, rel_labels.long()))

        # === Extra losses on the hierarchical logits, when provided ===
        if add_losses is not None and 'rel_dists_fip' in add_losses:
            rel_dists_fip = add_losses.get('rel_dists_fip')
            rel_dists_iip = add_losses.get('rel_dists_iip')
            tail_indices = add_losses.get('tail_indices', [])

            N = len(rel_labels)
            omega = 1.0  # hyper-parameter scaling the extra losses

            # FIP loss (foreground classes only).
            if rel_dists_fip is not None:
                logits_fip = cat(rel_dists_fip, dim=0)
                fg_mask = rel_labels != 0
                if fg_mask.sum() > 0:
                    loss_fip = self._to_scalar(
                        self.criterion_loss(logits_fip[fg_mask], rel_labels[fg_mask].long()))
                    N_fg = fg_mask.sum().float()
                    # Out-of-place add: do not mutate a tensor autograd tracks.
                    loss_relation = loss_relation + omega * loss_fip * N / (N_fg + 1e-6)

            # IIP loss (tail classes only).
            if rel_dists_iip is not None and len(tail_indices) > 0:
                logits_iip = cat(rel_dists_iip, dim=0)
                tail_mask = torch.zeros_like(rel_labels, dtype=torch.bool)
                for idx in tail_indices:
                    tail_mask |= (rel_labels == idx)

                if tail_mask.sum() > 0:
                    loss_iip = self._to_scalar(
                        self.criterion_loss(logits_iip[tail_mask], rel_labels[tail_mask].long()))
                    N_tail = tail_mask.sum().float()
                    loss_relation = loss_relation + omega * loss_iip * N / (N_tail + 1e-6)

        # Object refinement loss: with RWT the criterion uses reduction='none',
        # so the per-sample values are averaged explicitly.
        loss_refine_obj = self._to_scalar(self.criterion_loss(refine_obj_logits, fg_labels.long()))

        # The following code is used to calculate the sampled attribute loss
        if self.attri_on:
            refine_att_logits = cat(refine_att_logits, dim=0)
            fg_attributes = cat([proposal.get_field("attributes") for proposal in proposals], dim=0)

            attribute_targets, fg_attri_idx = self.generate_attributes_target(fg_attributes)
            if float(fg_attri_idx.sum()) > 0:
                # have at least one bbox got fg attributes
                refine_att_logits = refine_att_logits[fg_attri_idx > 0]
                attribute_targets = attribute_targets[fg_attri_idx > 0]
            else:
                # No box has attributes: fall back to the first box so the
                # loss is still computed on a non-empty input.
                refine_att_logits = refine_att_logits[0].view(1, -1)
                attribute_targets = attribute_targets[0].view(1, -1)

            loss_refine_att = self.attribute_loss(refine_att_logits, attribute_targets,
                                                  fg_bg_sample=self.attribute_sampling,
                                                  bg_fg_ratio=self.attribute_bgfg_ratio)
            return loss_relation, (loss_refine_obj, loss_refine_att)
        else:
            return loss_relation, loss_refine_obj

    def generate_attributes_target(self, attributes):
        """
        from list of attribute indexs to [1,0,1,0,0,1] form

        Arguments:
            attributes (Tensor): (num_obj, max_num_attri) attribute ids per
                object; a 0 entry ends the list of that object.

        Returns:
            attribute_targets (Tensor): (num_obj, num_attri_cat) float multi-hot.
            fg_attri_idx (Tensor): (num_obj,) long, 1 where the row sum of
                ``attributes`` is positive.
        """
        assert self.max_num_attri == attributes.shape[1]
        device = attributes.device
        num_obj = attributes.shape[0]

        fg_attri_idx = (attributes.sum(-1) > 0).long()
        attribute_targets = torch.zeros((num_obj, self.num_attri_cat), device=device).float()

        fg_rows = torch.nonzero(fg_attri_idx).squeeze(1)
        # One host transfer for all foreground rows instead of one per element.
        for idx, row in zip(fg_rows.tolist(), attributes[fg_rows].tolist()):
            for att_id in row:
                att_id = int(att_id)
                if att_id == 0:
                    # 0 terminates the list; later slots are ignored.
                    break
                attribute_targets[idx, att_id] = 1
        return attribute_targets, fg_attri_idx

    def attribute_loss(self, logits, labels, fg_bg_sample=True, bg_fg_ratio=3):
        """
        Multi-label attribute loss (binary cross entropy with logits).

        Arguments:
            logits (Tensor): attribute logits.
            labels (Tensor): multi-hot float targets, same shape as ``logits``.
            fg_bg_sample (bool): if True, keep every positive entry and a random
                subset of the negative ones; otherwise average over all entries
                and scale by ``num_attri_cat / 20``.
            bg_fg_ratio: number of sampled negatives per positive.

        Returns:
            Tensor: scalar loss.
        """
        if fg_bg_sample:
            loss_matrix = F.binary_cross_entropy_with_logits(logits, labels, reduction='none').view(-1)
            flat_labels = labels.view(-1)
            fg_loss = loss_matrix[flat_labels > 0]
            bg_loss = loss_matrix[flat_labels <= 0]

            num_fg = fg_loss.shape[0]
            # if there is no fg, add at least one bg
            num_bg = max(int(num_fg * bg_fg_ratio), 1)
            perm = torch.randperm(bg_loss.shape[0], device=bg_loss.device)[:num_bg]
            bg_loss = bg_loss[perm]

            return torch.cat([fg_loss, bg_loss], dim=0).mean()
        else:
            attri_loss = F.binary_cross_entropy_with_logits(logits, labels)
            attri_loss = attri_loss * self.num_attri_cat / 20.0
            return attri_loss


class FocalLoss(nn.Module):
    """
    Focal loss on softmax probabilities: ``-(1 - p_t) ** gamma * log(p_t)``.

    Arguments:
        gamma: focusing exponent; 0 gives plain cross entropy.
        alpha (float, optional): balance factor; samples with ``target > 0``
            are scaled by ``alpha`` and the others by ``1 - alpha``. ``None``
            (the default) disables this balancing.
        size_average (bool): return the mean over samples if True, else the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        """
        Arguments:
            input (Tensor): (N, C) class logits.
            target (Tensor): N integer class indices (any shape, flattened).

        Returns:
            Tensor: scalar loss.
        """
        target = target.view(-1).long()

        # Explicit dim: the implicit-dim form is deprecated and warns.
        logpt = F.log_softmax(input, dim=1)
        # Pick the log-probability of the target class for each row directly,
        # rather than building an N x N matrix and taking its diagonal.
        logpt = logpt.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()

        # alpha=None means "no class balancing"; multiplying by None would raise.
        if self.alpha is not None:
            logpt = logpt * self.alpha * (target > 0).float() + logpt * (1 - self.alpha) * (target <= 0).float()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()


def make_roi_relation_loss_evaluator(cfg):
    """
    Build the relation loss evaluator from the config.

    Attribute settings are read from ``cfg.MODEL.ROI_ATTRIBUTE_HEAD``; the
    label-smoothing switch and the predicate proportions (``REL_PROP``) are
    read from ``cfg.MODEL.ROI_RELATION_HEAD``. The full ``cfg`` is forwarded
    as well.
    """
    attribute_cfg = cfg.MODEL.ROI_ATTRIBUTE_HEAD
    relation_cfg = cfg.MODEL.ROI_RELATION_HEAD

    return RelationLossComputation(
        cfg,
        attri_on=cfg.MODEL.ATTRIBUTE_ON,
        num_attri_cat=attribute_cfg.NUM_ATTRIBUTES,
        max_num_attri=attribute_cfg.MAX_ATTRIBUTES,
        attribute_sampling=attribute_cfg.ATTRIBUTE_BGFG_SAMPLE,
        attribute_bgfg_ratio=attribute_cfg.ATTRIBUTE_BGFG_RATIO,
        use_label_smoothing=relation_cfg.LABEL_SMOOTHING_LOSS,
        predicate_proportion=relation_cfg.REL_PROP,
    )