# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_ASSIGNERS
from ..match_costs import build_match_cost
from ..transforms import bbox_cxcywh_to_xyxy
from .assign_result import AssignResult
from .base_assigner import BaseAssigner

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


@BBOX_ASSIGNERS.register_module()
class DapAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are weighted sum of three components:
    classification cost, regression L1 cost and regression iou cost. The
    targets don't include the no_object, so generally there are more
    predictions than targets. After the one-to-one matching, the un-matched
    are treated as backgrounds. Thus each query prediction will be assigned
    with `0` or a positive integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (dict, optional): Config of the classification match cost.
            Default ``dict(type='ClassificationCost', weight=1.)``.
        reg_cost (dict, optional): Config of the regression L1 match cost.
            Default ``dict(type='BBoxL1Cost', weight=1.0)``.
        iou_cost (dict, optional): Config of the IoU match cost.
            Default ``dict(type='IoUCost', iou_mode='iou', weight=1.0)``.
        giou_cost (dict, optional): Config of the generalized IoU match
            cost. Default
            ``dict(type='IoUCost', iou_mode='giou', weight=1.0)``.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', iou_mode='iou', weight=1.0),
                 giou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
        """Build the match costs and reset the running statistics.

        Args:
            cls_cost (dict): Config handed to ``build_match_cost`` for the
                classification cost.
            reg_cost (dict): Config handed to ``build_match_cost`` for the
                box regression cost.
            iou_cost (dict): Config handed to ``build_match_cost`` for the
                IoU cost.
            giou_cost (dict): Config handed to ``build_match_cost`` for the
                GIoU cost.
        """
        # Build the cost modules in declaration order; each one is stored
        # under the attribute that carries its argument name.
        cost_cfgs = (
            ('cls_cost', cls_cost),
            ('reg_cost', reg_cost),
            ('iou_cost', iou_cost),
            ('giou_cost', giou_cost),
        )
        for attr_name, cfg in cost_cfgs:
            setattr(self, attr_name, build_match_cost(cfg))

        # Switches read by ``assign`` to pick the cost formulation.
        self.use_fp_cost = True
        self.mode_A = True

        # Score grid: ``dim_emb`` bin centres located at (k + 0.5) / dim_emb.
        self.dim_emb = 128
        bin_ids = torch.arange(0, self.dim_emb)
        self.i_S = (bin_ids + 0.5) / self.dim_emb
        # Zero-mean Gaussian (std 0.05) used to spread a score over the grid.
        self.normal_dist = torch.distributions.Normal(0, 0.05)

        # Running statistics maintained by ``update`` / ``update_batch``.
        self.delta_matrix = None
        self.iter_step = 0

    @torch.no_grad()
    def cal_tp_cum_matrix(self, mem_dis_tp):
        tp_sum_all = mem_dis_tp.sum(dim=1,keepdim=True)
        tp_cumsum = torch.cumsum(mem_dis_tp, dim=1)
        tp_cumsum = tp_sum_all-tp_cumsum
        tp_cumsum = torch.clamp(tp_cumsum, min=0.0)
        # tp_cumsum[-1] = 1.0
        return tp_cumsum
    @torch.no_grad()
    def update(self, assigned_gt_inds, assigned_labels, ious):
        gt_scores =  self.cls_pred[assigned_gt_inds>0,assigned_labels[assigned_gt_inds>0]]
        pos_scores , _ = torch.topk(self.cls_pred.flatten(0), 300, axis=0)
        gt_score_v = self.score_to_svec(gt_scores) #[N,256]
        tp_50_mem = (gt_score_v * (ious>=0.5).unsqueeze(1)).sum(dim=0)
        tp_55_mem = (gt_score_v * (ious>=0.55).unsqueeze(1)).sum(dim=0)
        tp_60_mem = (gt_score_v * (ious>=0.60).unsqueeze(1)).sum(dim=0)
        tp_65_mem = (gt_score_v * (ious>=0.65).unsqueeze(1)).sum(dim=0)
        tp_70_mem = (gt_score_v * (ious>=0.70).unsqueeze(1)).sum(dim=0)
        tp_75_mem = (gt_score_v * (ious>=0.75).unsqueeze(1)).sum(dim=0)
        tp_80_mem = (gt_score_v * (ious>=0.80).unsqueeze(1)).sum(dim=0)
        tp_85_mem = (gt_score_v * (ious>=0.85).unsqueeze(1)).sum(dim=0)
        tp_90_mem = (gt_score_v * (ious>=0.90).unsqueeze(1)).sum(dim=0)
        tp_95_mem = (gt_score_v * (ious>=0.95).unsqueeze(1)).sum(dim=0)
        # [10, 256]
        tp_matrix=torch.stack([tp_50_mem,tp_55_mem,tp_60_mem,tp_65_mem,tp_70_mem,tp_75_mem,tp_80_mem,tp_85_mem,tp_90_mem,tp_95_mem])
        tp_cumsum_matrix = self.cal_tp_cum_matrix(tp_matrix.double()+1e-14) # [10, 256]
        
        #pos
        pos_score_v = self.score_to_svec(pos_scores.view(-1))
        pos_sum_v = pos_score_v.sum(dim=0)
        pos_sum_all = pos_sum_v.sum(dim=0,keepdim=True)
        pos_cumsum = torch.cumsum(pos_sum_v, dim=0)
        pos_cumsum = pos_sum_all-pos_cumsum
        pos_cumsum = torch.clamp(pos_cumsum, min=0.0)

        # P_matrix
        P_matrix = tp_cumsum_matrix/(pos_cumsum+1e-8).unsqueeze(0)
        P_matrix[:, 1:]= torch.where(P_matrix[:, 1:]>P_matrix[:, :-1], P_matrix[:, 1:], P_matrix[:, :-1])
        P_matrix = torch.cumsum(P_matrix, dim=0)
        
        if self.delta_matrix is None:
            self.delta_matrix = P_matrix.double()
            
        else:
            self.iter_step += 1
            warmup_step = 1000
            m = 1e-3
            if self.iter_step<warmup_step:
                m += 0.5*(1.0-self.iter_step/warmup_step)
            self.delta_matrix = (1.0-m)*self.delta_matrix + m*P_matrix.data.double()
            self.delta_matrix = torch.clamp(self.delta_matrix, min=1e-14)
    @torch.no_grad()  
    def interp(self, x, left_x, right_x, left_y, right_y):
        bias = (x-left_x)/(right_x-left_x)
        bias = bias.unsqueeze(1)
        # x tensor: [], left&right is float 
        return bias*right_y.unsqueeze(0) + (1.0-bias)*left_y.unsqueeze(0)
    @torch.no_grad()
    def score_to_svec(self, score):
        N = score.shape[0]
        res_s = self.i_S.to(score.device).unsqueeze(0).repeat([N,1])-score.unsqueeze(1).repeat([1,self.dim_emb])
        result = self.normal_dist.log_prob(res_s).exp().double()
        result = result/self.dim_emb
        result[:, 0] += self.normal_dist.cdf(-1.0*score)
        result[:, -1] += (1.0 - self.normal_dist.cdf(1.0-score))
        return result

    def update_batch(self, P_matrix):
        if self.delta_matrix is None:
            self.delta_matrix = P_matrix.double()
            
        else:
            self.iter_step += 1
            warmup_step = 333
            m = 0.98
            if self.iter_step<warmup_step:
                m -= 0.5*(1.0-self.iter_step/warmup_step)
            self.delta_matrix = m*self.delta_matrix + (1.0-m)*P_matrix.data.double()
            self.delta_matrix = torch.clamp(self.delta_matrix, min=1e-14)
        

    @torch.no_grad()
    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               img_meta,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or to
        background. In `assigned_gt_inds`, -1 means don't care, 0 means
        negative sample, and a positive number is the index (1-based) of the
        assigned gt.
        The assignment is done in the following steps, the order matters.

        1. assign every prediction to -1
        2. compute the weighted costs; when `self.use_fp_cost` is set the
           weighted sum is replaced by a score/IoU based cost (fixed
           formula when `self.mode_A` is set or no `delta_matrix` is
           available yet, otherwise weights interpolated from
           `self.delta_matrix`)
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        As a side effect, the sigmoid of `cls_pred` is stored in
        `self.cls_pred` when `self.use_fp_cost` is set.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
            img_meta (dict): Meta information for current image; only
                `img_shape` is read.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`. Must be None.
            eps (int | float, optional): Kept for interface compatibility;
                currently not used by this method. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.

        Raises:
            ImportError: If scipy is not installed and a matching has to be
                computed.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
        img_h, img_w, _ = img_meta['img_shape']
        factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)

        # 2. compute the weighted costs
        # classification cost
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost on normalized boxes
        normalize_gt_bboxes = gt_bboxes / factor
        reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
        # IoU / GIoU costs on absolute xyxy boxes
        bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
        iou_cost = self.iou_cost(bboxes, gt_bboxes)
        giou_cost = self.giou_cost(bboxes, gt_bboxes)
        # weighted sum of classification, L1 and GIoU costs
        cost = cls_cost + reg_cost + giou_cost

        if self.use_fp_cost:
            # Kept on the instance because `update` reads it afterwards.
            self.cls_pred = cls_pred.sigmoid()
            pred_prob = self.cls_pred[:, gt_labels]  # [num_query, num_gt]
            # NOTE(review): treats the IoU cost as the negated IoU, i.e.
            # assumes `iou_cost` is configured with weight 1.0 -- confirm.
            ious = -1.0 * iou_cost  # [num_query, num_gt]

            if self.delta_matrix is None or self.mode_A:
                # Fixed formula: pairs with IoU >= 0.5 are scored by
                # 2 * IoU - 0.9, the others by -0.2 * GIoU cost.
                hit_iou = ious >= 0.5
                score_iou = torch.where(hit_iou, 2.0 * ious - 0.9,
                                        -0.2 * giou_cost)
                cost = -1 * pred_prob * score_iou
            else:
                # Learned formula: weight the score embedding of every pair
                # with a row interpolated from `delta_matrix` by its IoU.
                scores_vector = self.score_to_svec(pred_prob.view(-1))
                scores_vector = scores_vector.reshape(
                    [pred_prob.size(0), pred_prob.size(1), -1])
                iou_threshold_list = [0.5, 0.55, 0.60, 0.65, 0.70,
                                      0.75, 0.80, 0.85, 0.90, 0.95]
                result = torch.zeros_like(
                    ious, dtype=torch.float64, device=ious.device)

                # 0.50 <= iou < 0.95: interpolate between adjacent rows
                for i in range(len(iou_threshold_list) - 1):
                    lower = iou_threshold_list[i]
                    upper = iou_threshold_list[i + 1]
                    mask = (ious >= lower) & (ious < upper)
                    if not mask.any():
                        continue
                    weight = self.interp(ious[mask], lower, upper,
                                         self.delta_matrix[i],
                                         self.delta_matrix[i + 1])
                    result[mask] = (scores_vector[mask] * weight).sum(dim=1)

                # iou < 0.50: interpolate between an all-zero row and the
                # first row
                mask = ious < iou_threshold_list[0]
                if mask.any():
                    zero_row = torch.zeros_like(self.delta_matrix[0])
                    weight = self.interp(ious[mask], 0,
                                         iou_threshold_list[0], zero_row,
                                         self.delta_matrix[0])
                    result[mask] = (scores_vector[mask] * weight).sum(dim=1)

                # iou >= 0.95: extend the last segment. The comparison is
                # inclusive so that an IoU of exactly 0.95, which the loop
                # above leaves out, does not keep a zero score.
                mask = ious >= iou_threshold_list[-1]
                if mask.any():
                    weight = self.interp(ious[mask], iou_threshold_list[-2],
                                         iou_threshold_list[-1],
                                         self.delta_matrix[-2],
                                         self.delta_matrix[-1])
                    result[mask] = (scores_vector[mask] * weight).sum(dim=1)

                cost = -1.0 * result

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        # NOTE: the statistics refresh is currently disabled here:
        # self.update(assigned_gt_inds, assigned_labels,
        #             ious=-1 * iou_cost[matched_row_inds, matched_col_inds])

        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)

