import torch
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult
from mmdet.core.bbox.assigners import BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost
import torch.nn.functional as F
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh, bbox_cxcywh_to_xyxy
try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None

def denormalize_3d_pts(pts, pc_range):
    """Map normalized 3-D points back into the coordinate range of ``pc_range``.

    For each channel ``c`` in 0..2 a value ``v`` becomes
    ``v * (pc_range[c + 3] - pc_range[c]) + pc_range[c]``. Channels beyond
    the third are copied through unchanged, and ``pts`` itself is not
    modified.

    Args:
        pts (Tensor): Points whose last dimension has at least 3 channels.
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: A new tensor with the same shape as ``pts``.
    """
    restored = pts.clone()
    for axis in range(3):
        low = pc_range[axis]
        span = pc_range[axis + 3] - low
        restored[..., axis:axis + 1] = pts[..., axis:axis + 1] * span + low
    return restored

def normalize_3d_pts(pts, pc_range):
    """Map 3-D points inside the ``pc_range`` box to normalized coordinates.

    Channel ``c`` (0..2) is shifted by ``pc_range[c]`` and then divided by
    the extent ``pc_range[c + 3] - pc_range[c]``. The divisor is a
    3-element tensor broadcast over the last dimension, so ``pts`` must have
    exactly 3 channels there. ``pts`` itself is not modified.

    Args:
        pts (Tensor): Points with last dimension 3.
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: Normalized points, same shape as ``pts``.
    """
    spans = pts.new_tensor(
        [pc_range[axis + 3] - pc_range[axis] for axis in range(3)])
    shifted = pts.clone()
    for axis in range(3):
        shifted[..., axis:axis + 1] = pts[..., axis:axis + 1] - pc_range[axis]
    return shifted / spans

def normalize_2d_bbox(bboxes, pc_range):
    """Convert (x1, y1, x2, y2) boxes into normalized (cx, cy, w, h) boxes.

    The boxes are first converted with ``bbox_xyxy_to_cxcywh``. The centre
    is then shifted by ``(pc_range[0], pc_range[1])``. Finally all four
    values are divided by the extents
    ``(pc_range[3] - pc_range[0], pc_range[4] - pc_range[1])``, with the
    x-extent applied to cx and w and the y-extent to cy and h.

    Args:
        bboxes (Tensor): Boxes with last dimension 4, in (x1, y1, x2, y2).
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: Normalized (cx, cy, w, h) boxes.
    """
    span_x = pc_range[3] - pc_range[0]
    span_y = pc_range[4] - pc_range[1]
    centred = bbox_xyxy_to_cxcywh(bboxes)
    for axis in (0, 1):
        centred[..., axis:axis + 1] = (
            centred[..., axis:axis + 1] - pc_range[axis])
    scale = bboxes.new_tensor([span_x, span_y, span_x, span_y])
    return centred / scale

def normalize_2d_pts(pts, pc_range):
    """Map 2-D points inside the ``pc_range`` rectangle to normalized coordinates.

    Channel 0 is shifted by ``pc_range[0]`` and divided by
    ``pc_range[3] - pc_range[0]``. Channel 1 is shifted by ``pc_range[1]``
    and divided by ``pc_range[4] - pc_range[1]``. The divisor is a 2-element
    tensor broadcast over the last dimension, so ``pts`` must have exactly
    2 channels there. ``pts`` itself is not modified.

    Args:
        pts (Tensor): Points with last dimension 2.
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: Normalized points, same shape as ``pts``.
    """
    spans = pts.new_tensor(
        [pc_range[axis + 3] - pc_range[axis] for axis in (0, 1)])
    shifted = pts.clone()
    for axis in (0, 1):
        shifted[..., axis:axis + 1] = pts[..., axis:axis + 1] - pc_range[axis]
    return shifted / spans

def denormalize_2d_bbox(bboxes, pc_range):
    """Convert normalized (cx, cy, w, h) boxes into (x1, y1, x2, y2) boxes in ``pc_range`` units.

    The boxes are first converted with ``bbox_cxcywh_to_xyxy``. The x
    columns (0, 2) are then scaled by ``pc_range[3] - pc_range[0]`` and
    shifted by ``pc_range[0]``. The y columns (1, 3) are scaled by
    ``pc_range[4] - pc_range[1]`` and shifted by ``pc_range[1]``. The
    rescaling writes in place into the converter's result.
    NOTE(review): this assumes the converter returns a new tensor rather
    than a view of ``bboxes`` — confirm.

    Args:
        bboxes (Tensor): Boxes with last dimension 4, in normalized (cx, cy, w, h).
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: Boxes in (x1, y1, x2, y2).
    """
    corners = bbox_cxcywh_to_xyxy(bboxes)
    for start in (0, 1):
        low = pc_range[start]
        span = pc_range[start + 3] - low
        corners[..., start::2] = corners[..., start::2] * span + low
    return corners

def denormalize_2d_pts(pts, pc_range):
    """Map normalized 2-D points back into the coordinate range of ``pc_range``.

    Channel 0 becomes ``v * (pc_range[3] - pc_range[0]) + pc_range[0]``.
    Channel 1 becomes ``v * (pc_range[4] - pc_range[1]) + pc_range[1]``.
    Any further channels are copied through unchanged, and ``pts`` itself
    is not modified.

    Args:
        pts (Tensor): Points whose last dimension has at least 2 channels.
        pc_range (Sequence[float]): (min0, min1, min2, max0, max1, max2).

    Returns:
        Tensor: A new tensor with the same shape as ``pts``.
    """
    restored = pts.clone()
    for axis in (0, 1):
        low = pc_range[axis]
        span = pc_range[axis + 3] - low
        restored[..., axis:axis + 1] = pts[..., axis:axis + 1] * span + low
    return restored

@BBOX_ASSIGNERS.register_module()
class MapTRAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth.

    Ground truths are split into two groups by the ``gt_masked`` flag that is
    passed to :meth:`assign`:

    - GTs whose flag is False are assigned directly, in index order, to the
      leading queries. Query ``i`` gets the ``i``-th such GT.
    - GTs whose flag is True are matched to the remaining queries with the
      Hungarian algorithm. The matching cost is the sum of four components:
      classification cost, bbox regression cost, IoU cost and point-set
      cost. Each component is built from its config with
      ``build_match_cost``.

    Unmatched queries are treated as background. Each query is therefore
    assigned ``0`` or a positive integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        z_cfg (dict): Must contain ``gt_z_flag``. When it is True, GT points
            are normalized as three channels with :func:`normalize_3d_pts`;
            otherwise as two channels with :func:`normalize_2d_pts`.
            ``pred_z_flag`` is stored but not read by this class.
        cls_cost (dict): Config of the classification match cost.
        reg_cost (dict): Config of the bbox regression match cost. It is
            applied to normalized (cx, cy, w, h) boxes.
        iou_cost (dict): Config of the IoU match cost. It is applied to
            denormalized (x1, y1, x2, y2) boxes. Default weight is 0.
        pts_cost (dict): Config of the point-set match cost. Its output is
            reshaped to (num_query, num_gt, num_orders) and minimized over
            the orders.
        pc_range (Sequence[float] | None): (min0, min1, min2, max0, max1,
            max2) for the coordinate channels, used to (de)normalize boxes
            and points. Presumably [x_min, y_min, z_min, x_max, y_max,
            z_max] — confirm against the config.
    """

    def __init__(self,
                 z_cfg = dict(
                    pred_z_flag=False,
                    gt_z_flag=False,
                 ),
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', weight=0.0),
                 pts_cost=dict(type='ChamferDistance',loss_src_weight=1.0,loss_dst_weight=1.0),
                 pc_range=None):
        self.z_cfg = z_cfg
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)
        self.pts_cost = build_match_cost(pts_cost)
        self.pc_range = pc_range

    def _normalize_gt_pts(self, gt_pts):
        """Normalize GT points with ``pc_range``: xyz if ``gt_z_flag``, else xy."""
        if self.z_cfg['gt_z_flag']:
            return normalize_3d_pts(gt_pts, self.pc_range)
        return normalize_2d_pts(gt_pts, self.pc_range)

    @staticmethod
    def _resample_pred_pts(pts_pred, num_pts):
        """Linearly resample every predicted polyline to ``num_pts`` points.

        ``pts_pred`` is returned unchanged when it already has ``num_pts``
        points along dim 1.
        """
        if pts_pred.size(1) == num_pts:
            return pts_pred
        resampled = F.interpolate(pts_pred.permute(0, 2, 1), size=num_pts,
                                  mode='linear', align_corners=True)
        return resampled.permute(0, 2, 1).contiguous()

    def _min_order_pts_cost(self, pred_pts, gt_pts, num_orders):
        """Point-set cost between predictions and GTs, minimized over GT point orders.

        Returns:
            tuple[Tensor, Tensor]: ``(cost, order)``, both of shape
            (num_pred, num_gt). ``order`` holds the index of the GT order
            with the smallest cost. When either input is empty, both are
            zero tensors of that shape.
        """
        num_pred, num_gt = pred_pts.size(0), gt_pts.size(0)
        if pred_pts.numel() and gt_pts.numel():
            cost_ordered = self.pts_cost(pred_pts, gt_pts).view(
                num_pred, num_gt, num_orders)
            cost, order = torch.min(cost_ordered, 2)
            return cost, order
        cost = pred_pts.new_zeros((num_pred, num_gt))
        order = torch.zeros((num_pred, num_gt), dtype=torch.long,
                            device=pred_pts.device)
        return cost, order

    def assign(self,
               bbox_pred,
               cls_pred,
               pts_pred,
               gt_bboxes,
               gt_labels,
               gt_pts,
               gt_masked,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        The `assigned_gt_inds` value -1 means don't care, 0 means negative
        sample, and a positive number is the index (1-based) of the assigned
        gt. The steps, in order (the order matters):

        1. Assign every prediction to -1.
        2. Assign GTs whose ``gt_masked`` flag is False directly to the
           leading queries.
        3. Compute the summed costs between the remaining queries and the
           GTs whose flag is True.
        4. Do Hungarian matching on CPU based on those costs.
        5. Assign the remaining queries to 0 (background), then write the
           matched gt index (plus 1) into each matched query.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h). Shape [num_query, 4].
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            pts_pred (Tensor): Predicted points, presumably normalized,
                shape [num_query, num_pts_pred, num_coords]. They are
                linearly resampled when num_pts_pred differs from the
                number of GT points per line.
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
            gt_pts (Tensor): Unnormalized GT points, shape
                [num_gt, num_orders, num_pts_gt, num_coords]. The
                ``num_orders`` entries are presumably alternative point
                orderings of the same element.
            gt_masked (Tensor): Mask over GTs, shape (num_gt,), or empty.
                True entries go through Hungarian matching. False entries
                are assigned directly to the leading queries in index order.
                If the mask is empty or all True, every GT is matched.
            gt_bboxes_ignore (Tensor, optional): Must be None.
            eps (int | float, optional): Unused; kept for interface
                compatibility. Default 1e-7.

        Returns:
            tuple[:obj:`AssignResult`, Tensor | None]: The assign result and
            ``order_index``. ``order_index`` has shape (num_query, num_gt)
            and holds, for every (query, gt) pair whose point cost was
            computed, the arg-min GT order; all other entries are 0.
            ``order_index`` is None when there are no GTs or no queries.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        assert bbox_pred.shape[-1] == 4, \
            'Only support bbox pred shape is 4 dims'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
        _, num_orders, num_pts_per_gtline, _ = gt_pts.shape

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels), None

        # Normalize all GT points and resample all predicted points once.
        # Both operations act independently per GT / per query, so slicing
        # the results afterwards gives the same values as slicing first.
        # Resampling the full tensor also avoids interpolating an empty
        # batch when every query is directly assigned.
        normalized_gt_pts = self._normalize_gt_pts(gt_pts)
        pts_pred_resampled = self._resample_pred_pts(pts_pred,
                                                     num_pts_per_gtline)

        # 2. assign existing (unmasked) map elements to the leading queries
        if gt_masked.numel() and \
                torch.count_nonzero(torch.logical_not(gt_masked)) > 0:
            query_mask = torch.logical_not(gt_masked)
            num_not_masked = int(torch.count_nonzero(query_mask))
            # If there are more unmasked GTs than queries, the surplus GTs
            # receive no query; they are also excluded from the matching
            # below.
            assign_num = min(num_not_masked, num_bboxes)
            direct_gt_inds = torch.nonzero(
                query_mask, as_tuple=False).view(-1)[:assign_num]
            # +1: gt indices in the assign result are 1-based (0 = background).
            assigned_gt_inds[:assign_num] = direct_gt_inds + 1
            assigned_labels[:assign_num] = gt_labels[direct_gt_inds]
        else:
            # Nothing to assign directly: every GT takes part in matching.
            query_mask = gt_bboxes.new_full((num_gts, ),
                                            False,
                                            dtype=torch.bool)
            gt_masked = torch.logical_not(query_mask)
            assign_num = 0

        # Queries and GTs that take part in Hungarian matching.
        match_bbox_pred = bbox_pred[assign_num:]
        match_cls_pred = cls_pred[assign_num:]
        match_pts_pred = pts_pred_resampled[assign_num:]
        match_gt_bboxes = gt_bboxes[gt_masked]
        match_gt_labels = gt_labels[gt_masked]
        match_gt_pts = normalized_gt_pts[gt_masked]

        # 3. compute the costs, each of shape (num_match_query, num_match_gt)
        cls_cost = self.cls_cost(match_cls_pred, match_gt_labels)
        normalized_gt_bboxes = normalize_2d_bbox(match_gt_bboxes,
                                                 self.pc_range)
        reg_cost = self.reg_cost(match_bbox_pred[:, :4],
                                 normalized_gt_bboxes[:, :4])
        pts_cost, order_index_masked = self._min_order_pts_cost(
            match_pts_pred, match_gt_pts, num_orders)
        bboxes = denormalize_2d_bbox(match_bbox_pred, self.pc_range)
        iou_cost = self.iou_cost(bboxes, match_gt_bboxes)
        cost = cls_cost + reg_cost + iou_cost + pts_cost

        # The best point order is also needed for the directly assigned
        # block; only the order is used there, not the cost.
        _, order_index_unmasked = self._min_order_pts_cost(
            pts_pred_resampled[:assign_num], normalized_gt_pts[query_mask],
            num_orders)

        order_index = torch.zeros((num_bboxes, num_gts), dtype=torch.long,
                                  device=bbox_pred.device)
        order_index[:assign_num, query_mask] = order_index_unmasked
        order_index[assign_num:, gt_masked] = order_index_masked

        # 4. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 5. Queries not assigned directly start as background. Matched rows
        # index the sliced queries and matched columns index the masked GT
        # subset, so both are mapped back to full-size indices.
        assigned_gt_inds[assign_num:] = 0
        masked_gt_inds = torch.nonzero(gt_masked, as_tuple=False).view(-1)
        matched_query_inds = matched_row_inds + assign_num
        assigned_gt_inds[matched_query_inds] = \
            masked_gt_inds[matched_col_inds] + 1
        assigned_labels[matched_query_inds] = match_gt_labels[matched_col_inds]

        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels), order_index

#     def assign(self,
#                bbox_pred,
#                cls_pred,
#                pts_pred,
#                gt_bboxes, 
#                gt_labels,
#                gt_pts,
#                gt_masked,
#                gt_bboxes_ignore=None,
#                eps=1e-7):
#         """Computes one-to-one matching based on the weighted costs.
#         This method assign each query prediction to a ground truth or
#         background. The `assigned_gt_inds` with -1 means don't care,
#         0 means negative sample, and positive number is the index (1-based)
#         of assigned gt.
#         The assignment is done in the following steps, the order matters.
#         1. assign every prediction to -1
#         2. compute the weighted costs
#         3. do Hungarian matching on CPU based on the costs
#         4. assign all to 0 (background) first, then for each matched pair
#            between predictions and gts, treat this prediction as foreground
#            and assign the corresponding gt index (plus 1) to it.
#         Args:
#             bbox_pred (Tensor): Predicted boxes with normalized coordinates
#                 (cx, cy, w, h), which are all in range [0, 1]. Shape
#                 [num_query, 4].
#             cls_pred (Tensor): Predicted classification logits, shape
#                 [num_query, num_class].
#             gt_bboxes (Tensor): Ground truth boxes with unnormalized
#                 coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
#             gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
#             gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
#                 labelled as `ignored`. Default None.
#             eps (int | float, optional): A value added to the denominator for
#                 numerical stability. Default 1e-7.
#         Returns:
#             :obj:`AssignResult`: The assigned result.
#         """
#         assert gt_bboxes_ignore is None, \
#             'Only case when gt_bboxes_ignore is None is supported.'
#         assert bbox_pred.shape[-1] == 4, \
#             'Only support bbox pred shape is 4 dims'
#         num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
# 
#         # 1. assign -1 by default
#         assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
#                                               -1,
#                                               dtype=torch.long)
#         assigned_labels = bbox_pred.new_full((num_bboxes, ),
#                                              -1,
#                                              dtype=torch.long)
#         if num_gts == 0 or num_bboxes == 0:
#             # No ground truth or boxes, return empty assignment
#             if num_gts == 0:
#                 # No ground truth, assign all to background
#                 assigned_gt_inds[:] = 0
#             return AssignResult(
#                 num_gts, assigned_gt_inds, None, labels=assigned_labels), None
# 
#         # 2. compute the weighted costs
#         # classification and bboxcost.
#         cls_cost = self.cls_cost(cls_pred, gt_labels)
#         # regression L1 cost
#     
#         normalized_gt_bboxes = normalize_2d_bbox(gt_bboxes, self.pc_range)
#         # normalized_gt_bboxes = gt_bboxes
#         # import pdb;pdb.set_trace()
#         reg_cost = self.reg_cost(bbox_pred[:, :4], normalized_gt_bboxes[:, :4])
# 
#         _, num_orders, num_pts_per_gtline, num_coords = gt_pts.shape
#         normalized_gt_pts = normalize_2d_pts(gt_pts, self.pc_range) if not self.z_cfg['gt_z_flag'] \
#                         else normalize_3d_pts(gt_pts, self.pc_range)
#         num_pts_per_predline = pts_pred.size(1)
#         if num_pts_per_predline != num_pts_per_gtline:
#             pts_pred_interpolated = F.interpolate(pts_pred.permute(0,2,1),size=(num_pts_per_gtline),
#                                             mode='linear', align_corners=True)
#             pts_pred_interpolated = pts_pred_interpolated.permute(0,2,1).contiguous()
#         else:
#             pts_pred_interpolated = pts_pred
#         # num_q, num_pts, 2 <-> num_gt, num_pts, 2
#         pts_cost_ordered = self.pts_cost(pts_pred_interpolated, normalized_gt_pts)
#         pts_cost_ordered = pts_cost_ordered.view(num_bboxes, num_gts, num_orders)
#         pts_cost, order_index = torch.min(pts_cost_ordered, 2)
# 
#         bboxes = denormalize_2d_bbox(bbox_pred, self.pc_range)
#         iou_cost = self.iou_cost(bboxes, gt_bboxes)
#         # weighted sum of above three costs
#         cost = cls_cost + reg_cost + iou_cost + pts_cost
# 
#         # 3. do Hungarian matching on CPU using linear_sum_assignment
#         cost = cost.detach().cpu()
#         if linear_sum_assignment is None:
#             raise ImportError('Please run "pip install scipy" '
#                               'to install scipy first.')
#         matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
#         matched_row_inds = torch.from_numpy(matched_row_inds).to(
#             bbox_pred.device)
#         matched_col_inds = torch.from_numpy(matched_col_inds).to(
#             bbox_pred.device)
# 
#         # 4. assign backgrounds and foregrounds
#         # assign all indices to backgrounds first
#         assigned_gt_inds[:] = 0
#         # assign foregrounds based on matching results
#         assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
#         assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
# 
#         # print("ASSIGNMENTS:")
#         # print(assigned_gt_inds)
#         # print(assigned_labels)
#         # torch.set_printoptions(threshold=10_000)
#         # print(order_index)
# 
#         return AssignResult(
#             num_gts, assigned_gt_inds, None, labels=assigned_labels), order_index
