from ..builder import DETECTORS
from .two_stage import TwoStageDetector
from ..builder import build_backbone, build_head, build_neck
import torch.nn as nn
import torch
from mmdet.core import bbox2result, bbox2roi, bbox_xyxy_to_cxcywh, bboxbbox2result, bboxpoint2results
from mmdet.core.bbox.iou_calculators import bbox_overlaps
import numpy as np
from mmcv.ops import point_sample
from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
import cv2
from mmcv.ops.nms import batched_nms                
from mmcv.ops import point_sample
import torch.nn.functional as F


@DETECTORS.register_module()
class DPM(TwoStageDetector):
    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 lss_head,
                 keypoint_head,
                 train_cfg,
                 test_cfg,
                 roi_skip_fpn=False,
                 test_wo_detector=False,
                 test_on_fcos=False,
                 test_point_sam=False,
                 test_part_sam=False,
                 train_sam=False,
                 neck=None,
                 init_cfg=None,
                 *args,
                 sam_checkpoint="/home/GuoZonghao/sam_vit_b_01ec64.pth",
                 sam_model_type="vit_b",
                 sam_device="cuda",
                 **kwargs):
        """Build the detector, its extra heads and (optionally) a SAM predictor.

        Args:
            backbone, neck, rpn_head, roi_head, train_cfg, test_cfg, init_cfg:
                Forwarded unchanged to ``TwoStageDetector.__init__``.
            lss_head: Config for the extra head built with ``build_head`` and
                stored as ``self.lss_head``; ``train_cfg.lss`` / ``test_cfg.lss``
                are injected into it. Skipped when ``None``.
            keypoint_head: Config built with ``build_head`` and stored as
                ``self.keypoint_head``. Skipped when ``None``.
            roi_skip_fpn, test_wo_detector, test_on_fcos, test_point_sam,
            test_part_sam, train_sam: Behaviour flags stored on the instance.
                If any of ``test_point_sam``, ``test_part_sam`` or ``train_sam``
                is true, a SAM model is loaded and wrapped in ``SamPredictor``.
            sam_checkpoint (str): Path of the SAM checkpoint to load. The
                default is the path that used to be hard-coded here.
            sam_model_type (str): Key into ``sam_model_registry``.
            sam_device (str): Device the SAM model is moved to.
        """
        super(DPM, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            *args, **kwargs)

        if lss_head is not None:
            # Guard both configs the same way so a missing one does not raise
            # an AttributeError on ``None``.
            lss_train_cfg = train_cfg.lss if train_cfg is not None else None
            lss_test_cfg = test_cfg.lss if test_cfg is not None else None
            lss_head.update(train_cfg=lss_train_cfg)
            lss_head.update(test_cfg=lss_test_cfg)
            self.lss_head = build_head(lss_head)
        if keypoint_head is not None:
            self.keypoint_head = build_head(keypoint_head)

        self.roi_skip_fpn = roi_skip_fpn
        self.test_wo_detector = test_wo_detector
        self.test_on_fcos = test_on_fcos
        self.test_point_sam = test_point_sam
        self.test_part_sam = test_part_sam
        self.train_sam = train_sam

        # Always define the attribute so users can test ``self.predictor is None``
        # instead of hitting an AttributeError when no SAM flag was enabled.
        self.predictor = None
        if self.test_point_sam or self.test_part_sam or self.train_sam:
            # Imported lazily: segment_anything is only needed for SAM modes.
            from segment_anything import sam_model_registry, SamPredictor

            sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
            sam.to(device=sam_device)
            self.predictor = SamPredictor(sam)
            
    @property
    def with_lss_head(self):
        return hasattr(self, 'lss_head') and self.lss_head is not None
    @property
    def with_decoder_head(self):
        return hasattr(self, 'decoder_head') and self.decoder_head is not None
            
    def get_roi_feat(self, x, vit_feat):
        B, _, H, W = x[2].shape
        x = [
            vit_feat.transpose(1, 2).reshape(B, -1, H, W).contiguous()
        ]
        return x
    
    def extract_feat(self, img, gt_points=None, gt_labels=None):
        # 说明只有一个用于imted的fpn
        x = self.backbone(img)
        x = list(x)
        x[0] = self.neck(x[0])
        return x
            
    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_points=None,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Run one training step and return the dictionary of losses.

        Pipeline, as implemented below:
          1. De-normalise ``img[0]`` to a uint8 image for SAM.
          2. Pad every ground-truth mask to the batch input shape.
          3. ``lss_head.forward_train_point`` -> point losses / pseudo proposals.
          4. ``lss_head.forward_train_mil`` -> MIL losses and pseudo GT boxes.
          5. ``keypoint_head.forward_train`` and the part-offset loss.
          6. RPN training on the pseudo boxes.
          7. SAM is prompted with the visible part points and pseudo box of each
             object of image 0; foreground / background points are sampled from
             the resulting masks and handed to ``roi_head.forward_train``.

        Args:
            img (Tensor): Batched input images; only ``img[0]`` is sent to SAM,
                so the SAM branch effectively assumes one image per batch.
            img_metas (list[dict]): Per-image meta info. ``img_norm_cfg`` of the
                first entry and ``img_shape`` of every entry are read;
                ``batch_input_shape`` is written into every entry.
            gt_bboxes (list[Tensor]): Per-image boxes. Box centres are used as
                the point annotations.
            gt_labels (list[Tensor]): Per-image labels.
            gt_points: Ignored; the points are always recomputed from
                ``gt_bboxes``.
            gt_bboxes_ignore: Forwarded to the RPN and RoI heads.
            gt_masks: Per-image mask containers exposing a ``.masks`` array of
                shape ``(n, h, w)``. Required in practice although the default
                is ``None``.
            proposals: Unused.
            **kwargs: Forwarded to ``roi_head.forward_train``.

        Returns:
            dict: Loss name -> loss value. If any image in the batch has no
            ground-truth box, placeholder annotations are used and every loss is
            multiplied by zero.

        Raises:
            NotImplementedError: If ``roi_skip_fpn`` is false or the feature
                extractor does not return six items (both paths were
                ``assert False`` before).
            RuntimeError: If no SAM predictor was created in ``__init__``.
        """
        if not self.roi_skip_fpn:
            raise NotImplementedError(
                'DPM.forward_train only supports roi_skip_fpn=True')
        predictor = getattr(self, 'predictor', None)
        if predictor is None:
            raise RuntimeError(
                'DPM.forward_train needs a SAM predictor; build the detector '
                'with train_sam=True')

        image = self._denormalize_first_image(img, img_metas)
        img_h, img_w = image.shape[:2]

        batch_input_shape = tuple(img[0].size()[-2:])
        for meta in img_metas:
            meta['batch_input_shape'] = batch_input_shape

        gt_masks = self._pad_gt_masks(gt_masks, batch_input_shape,
                                      gt_bboxes[0].device)

        # Augmentation may leave an image without any GT. If a single image of
        # the batch is empty, the whole batch is replaced by placeholders and
        # its losses are zeroed at the end.
        empty = any(len(boxes) == 0 for boxes in gt_bboxes)
        if empty:
            num_imgs = len(img_metas)
            box_dtype, box_device = gt_bboxes[0].dtype, gt_bboxes[0].device
            label_dtype, label_device = gt_labels[0].dtype, gt_labels[0].device
            gt_bboxes = [
                torch.as_tensor([[20, 20, 40, 40]],
                                dtype=box_dtype, device=box_device)
                for _ in range(num_imgs)
            ]
            gt_labels = [
                torch.as_tensor([0], dtype=label_dtype, device=label_device)
                for _ in range(num_imgs)
            ]
            gt_masks = [
                torch.zeros((1, *batch_input_shape), device=box_device).bool()
                for _ in range(num_imgs)
            ]

        feats = self.extract_feat(img)
        if len(feats) != 6:
            raise NotImplementedError(
                'DPM.forward_train expects extract_feat to return 6 items, '
                'got {}'.format(len(feats)))
        x, vit_feat, point_tokens, attns, _scale_features, vit_feat_be_norm = feats

        # Box centres serve as the point annotations.
        gt_points = [
            torch.stack([boxes[:, 0::2].mean(-1), boxes[:, 1::2].mean(-1)],
                        dim=-1)
            for boxes in gt_bboxes
        ]

        # Per-image (w, h), shaped (num_imgs, 1, 2).
        imgs_whwh = torch.cat([
            x[0].new_tensor([[meta['img_shape'][1], meta['img_shape'][0]]])
            for meta in img_metas
        ], dim=0)[:, None, :]

        losses = dict()

        # Point training / pseudo GT generation.
        points_results = self.lss_head.forward_train_point(
            x,
            vit_feat,
            point_tokens,
            attns,
            img_metas,
            gt_bboxes,
            gt_labels,
            gt_points,
            imgs_whwh=imgs_whwh,
            gt_masks=gt_masks)
        losses.update(points_results['loss_point'])

        # First-stage WSDDN-style MIL head.
        mil_results = self.lss_head.forward_train_mil(
            x,
            vit_feat,
            point_tokens,
            None,
            None,
            points_results['multiple_cams'],
            points_results['refined_multiple_masks'],
            points_results['pseudo_proposals'],
            img_metas,
            points_results['gt_bboxes'],
            points_results['gt_labels'],
            points_results['gt_points'],
            imgs_whwh=imgs_whwh)
        losses.update(mil_results['loss_mil'])
        losses.update(dict(iou_metric=mil_results['iou_metric']))

        pseudo_gt_bboxes = mil_results['pseudo_gt_bboxes']
        pseudo_gt_labels = mil_results['pseudo_gt_labels']

        attnshift_results, losses_hinge = self.keypoint_head.forward_train(
            x,
            vit_feat,
            mil_results['matched_cams'],
            img_metas,
            pseudo_gt_bboxes,
            pseudo_gt_labels,
            points_results['gt_points'],
            mil_results['semantic_scores'],
            vit_feat_be_norm,
            imgs_whwh=imgs_whwh,
            gt_masks=points_results['gt_masks'],
            sam=predictor if self.train_sam else None,
        )
        losses.update(losses_hinge)

        losses_part_offset = self.lss_head.point_head.part_offset_loss(
            points_results['pos_part_offset_pred'],
            points_results['pos_point_pred'],
            attnshift_results['true_part_targets'],
            attnshift_results['true_part_visibles'],
            imgs_whwh=imgs_whwh,
        )
        losses.update(losses_part_offset)

        # RPN is trained on the pseudo boxes when they are available.
        rpn_targets = gt_bboxes if pseudo_gt_bboxes is None else pseudo_gt_bboxes
        roi_labels = gt_labels if pseudo_gt_labels is None else pseudo_gt_labels
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            rpn_targets,
            gt_labels=None,
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        # Prompt SAM with the visible part points + pseudo box of every object
        # of image 0.
        predictor.set_image(image)
        part_targets = attnshift_results['true_part_targets'][0]
        part_visibles = attnshift_results['true_part_visibles'][0]
        pseudo_boxes_img0 = mil_results['pseudo_gt_bboxes'][0]

        fg_masks = []
        for part_points, bbox, visible in zip(part_targets, pseudo_boxes_img0,
                                              part_visibles):
            point_coords = part_points[visible].detach().cpu().numpy()
            masks, _, _ = predictor.predict(
                point_coords=point_coords,
                # ``np.int`` no longer exists in NumPy >= 1.24; builtin int
                # yields the same integer dtype.
                point_labels=np.ones(len(point_coords), dtype=int),
                box=bbox.detach().cpu().numpy(),
                multimask_output=False,
            )
            fg_masks.append(masks)

        # NOTE(review): ``box2mask`` is not imported in the visible part of
        # this file -- presumably defined further down in the module; confirm.
        box_masks = box2mask(pseudo_boxes_img0, [img_h, img_w], default_val=0)

        num_samples = 100
        fg_points, bg_points = [], []
        for box_mask, fg_mask in zip(box_masks, fg_masks):
            inside_box = box_mask.bool().detach().cpu().numpy()
            # Foreground: inside the box AND inside the SAM mask.
            fg_points.append(
                self._sample_mask_points(inside_box & fg_mask, num_samples))
            # Background: inside the box but outside the SAM mask.
            bg_points.append(
                self._sample_mask_points(inside_box & ~fg_mask, num_samples))

        # Follow the input device instead of hard-coding ``.cuda()``.
        device = img.device
        final_fg_points = [torch.stack(fg_points).to(device).float()]
        final_bg_points = [torch.stack(bg_points).to(device).float()]
        # Every slot is labelled, padded ones included (as before).
        final_fg_labels = [
            torch.ones(len(fg_points), num_samples, dtype=torch.bool).to(device)
        ]
        final_bg_labels = [
            torch.zeros(len(bg_points), num_samples, dtype=torch.bool).to(device)
        ]

        # RCNN stage (imTED style: RoI features come straight from the ViT).
        roi_losses = self.roi_head.forward_train(
            self.get_roi_feat(x, vit_feat),
            img_metas,
            proposal_list,
            rpn_targets,
            roi_labels,
            final_bg_points,
            final_bg_labels,
            final_fg_points,
            final_fg_labels,
            gt_bboxes_ignore,
            img=img,
            **kwargs)
        losses.update(roi_losses)

        if empty:
            # Placeholder annotations were used: multiply every loss by zero.
            for key, value in losses.items():
                if key in ('loss_rpn_cls', 'loss_rpn_bbox'):
                    losses[key] = [item * 0 for item in value]
                else:
                    losses[key] = value * 0
        return losses

    @staticmethod
    def _denormalize_first_image(img, img_metas):
        """Undo mean/std normalisation of ``img[0]``; return an (H, W, C) uint8 array."""
        norm_cfg = img_metas[0]['img_norm_cfg']
        mean, std = norm_cfg['mean'], norm_cfg['std']
        image = img[0].clone()
        for idx, channel in enumerate(img[0]):
            image[idx] = channel * std[idx] + mean[idx]
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        image = image.clamp(0, 255)
        return image.permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    @staticmethod
    def _pad_gt_masks(gt_masks, batch_input_shape, device):
        """Zero-pad each ``.masks`` array to ``batch_input_shape``; return bool tensors."""
        padded = []
        for gt_mask in gt_masks:
            bitmap = gt_mask.masks
            num, ori_h, ori_w = bitmap.shape
            # ``np.bool`` no longer exists in NumPy >= 1.24; use builtin bool.
            canvas = np.zeros((num, *batch_input_shape), dtype=bool)
            canvas[:, :ori_h, :ori_w] = bitmap
            padded.append(torch.as_tensor(canvas, device=device).bool())
        return padded

    @staticmethod
    def _sample_mask_points(mask, num_points=100, pad_value=-1000):
        """Randomly pick up to ``num_points`` (x, y) pixels from ``mask[0]``.

        Missing slots are filled with ``pad_value`` so the result always has
        ``num_points`` rows.
        """
        # nonzero() yields (row, col); flip to (x, y).
        coords = torch.from_numpy(mask)[0].nonzero().flip(-1)
        points = coords[torch.randperm(len(coords))[:num_points]]
        num_pad = num_points - len(points)
        if num_pad > 0:
            points = torch.cat(
                [points, coords.new_full((num_pad, 2), pad_value)])
        return points
    
    def simple_test(self, img, img_metas, 
                    gt_bboxes=None, 
                    gt_labels=None,
                    gt_points=None, 
                    gt_masks=None,
                    proposals=None, 
                    rescale=False):
        
        """Test without augmentation."""
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)
        
#         if len(x) == 5:
        if len(x) == 6:
            x, vit_feat, point_tokens, attns, scale_features, vit_feat_be_norm = x
            # if False:
            #     pass
            
            if self.test_point_sam or self.test_part_sam:
                
                mean=[123.675, 116.28, 103.53]
                std=[58.395, 57.12, 57.375]
                image = img[0].clone()
                for i, c in enumerate(img[0]):
                    image[i] = c * std[i] + mean[i]
                image = image.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
                image_ = image.copy()
        
                self.predictor.set_image(image_)
                
                B, _, H, W = x[2].shape
                num_query = point_tokens.size(1)
                cls_score, point_pred, \
                part_offset_pred = self.lss_head.point_head(point_tokens)

                imgs_whwh = []
                for meta in img_metas:
                    h, w, _ = meta['img_shape']
                    ori_h, ori_w, _ = meta['ori_shape']
                    imgs_whwh.append(x[0].new_tensor([[w, h]]))
                imgs_whwh = torch.cat(imgs_whwh, dim=0)
                imgs_whwh = imgs_whwh[:, None, :] # 1, 1, 2

                # 获得所有inference的结果
                point_det_score, point_det_labels = cls_score.sigmoid().max(-1)
                point_det = point_pred * imgs_whwh # 1, 100, 2
                point_part_det = point_det.unsqueeze(-2) + part_offset_pred.reshape(B, num_query, 5, -1)
                object_points = torch.cat([point_part_det,
                                           point_det.unsqueeze(-2),
                                           point_det.unsqueeze(-2),
                                           point_det.unsqueeze(-2),
                                           point_det.unsqueeze(-2),
                                           point_det.unsqueeze(-2)], dim=-2) # b, 100, 5, 2
                object_points_norm = object_points / imgs_whwh.unsqueeze(-2)
                # 获得所有inference的结果

                root_scores, _, _ = self.keypoint_head(vit_feat) # bs, N, 20 # bs, 
                root_scores = root_scores.softmax(-1) # b, n, c
                root_scores = root_scores.reshape(B, H, W, -1).permute(0, 3, 1, 2)

                point_scores = point_sample(root_scores, 
                                            object_points_norm.reshape(B, -1, 2)).reshape(B, -1, num_query, 5 + 5)
                # 取正例分数
                object_scores = point_scores[..., -1] #.mean(-1)
    
    
                obj_cls, obj_label = object_scores.max(1)
                obj_points = object_points[:, :, :6]
                point_scores_ = point_scores.permute(0, 2, 3, 1)
                
                remained_obj_label = obj_label[obj_label != 20]
                remained_obj_cls = obj_cls[obj_label != 20]
                remained_point_score = point_scores_[obj_label != 20][:, :6]
                remained_obj_points = obj_points[obj_label != 20]
                
                if len(remained_obj_label) == 0:
                    
                    cls_segms = [[] for _ in range(self.roi_head.bbox_head.num_classes)]  # BG is not included in num_classes
                    boxes = [[np.zeros((0, 5), dtype=np.float32) for i in range(self.roi_head.bbox_head.num_classes)]]
                    
#                     return list(zip(boxes, [cls_segms]))
                    return boxes
#                     return [boxes, [cls_segms]]
                
                sam_points = remained_obj_points.clone()
                sam_labels = (remained_point_score.max(-1)[1] == remained_obj_label.unsqueeze(-1)).long()
                outer = (sam_points[..., 0] > w) | (sam_points[..., 0] < 0) | \
                        (sam_points[..., 1] > h) | (sam_points[..., 1] < 0)
                sam_labels[outer] = 0

#                 all_masks, all_scores, all_logits = [], [], []
                final_masks, final_bboxes = [], []
#                 all_input_points, all_input_labels = [], []
                
                obj_labels = []
                obj_clses = []
                
                for i_gt, (input_label, input_point, obj_label, obj_cls) in \
                    enumerate(zip(sam_labels, sam_points, remained_obj_label, remained_obj_cls)):

                    input_label = input_label.detach().cpu().numpy()
                    input_point = input_point.detach().cpu().numpy()

                    masks, scores, logits = self.predictor.predict(
                            point_coords=input_point,
                            point_labels=input_label,
                            multimask_output=True)
                    
                    final_mask = masks[np.argmax(scores)].astype(np.uint8)
                    final_mask = cv2.resize(final_mask, (ori_w, ori_h), interpolation=cv2.INTER_CUBIC) > 0
                    ys, xs = np.where(final_mask==True)
                    if len(ys) == 0 or len(xs) == 0:
                        continue
                    xmin, xmax = xs.min(), xs.max()
                    ymin, ymax = ys.min(), ys.max()
                    final_masks.append(final_mask)
                    final_bboxes.append(np.array([xmin, ymin, xmax, ymax], dtype=np.float32))
                    obj_labels.append(obj_label.reshape(1))
                    obj_clses.append(obj_cls.reshape(1))
#                 print(obj_labels)
#                     all_input_points.append(input_point)
#                     all_input_labels.append(input_label)
#                     all_masks.append(masks)
#                     all_scores.append(scores)
                
#                 if len(obj_labels) == 0:
                    
#                     cls_segms = [[] for _ in range(self.roi_head.bbox_head.num_classes)]  # BG is not included in num_classes
#                     boxes = [[np.zeros((0, 5), dtype=np.float32) for i in range(self.roi_head.bbox_head.num_classes)]]
                    
#                     return [boxes, [cls_segms]]
                
                # final_masks, final_bboxes, remained_obj_label, remained_obj_cls
                # rescale
                remained_obj_label_ = torch.cat(obj_labels).detach().cpu().numpy()
                remained_obj_cls_ = torch.cat(obj_clses).detach().cpu().numpy()
                final_bboxes_ = np.stack(final_bboxes)
                bboxes = torch.from_numpy(final_bboxes_).cuda()
                scores = torch.from_numpy(remained_obj_cls_).cuda()
                labels = torch.from_numpy(remained_obj_label_).cuda()
                nms_cfg = dict(type='nms', iou_threshold=0.5)
                dets, keep = batched_nms(bboxes, 
                                         scores,
                                         labels, 
                                         nms_cfg)

#                 final_bboxes_ = np.stack(final_bboxes)
#                 final_dets = np.hstack([final_bboxes_, remained_obj_cls_.reshape(-1, 1)])
                
                labels = labels[keep]
                keep = keep.detach().cpu().numpy()

                final_masks_ = np.array(final_masks)[keep]
                
                cls_segms = [[] for _ in range(self.roi_head.bbox_head.num_classes)]  # BG is not included in num_classes
                for i in range(len(keep)):
                    cls_segms[labels[i]].append(final_masks_[i])
                    
                bbox_results = [
                    bbox2result(dets[i], labels[i],
                                self.roi_head.bbox_head.num_classes)
                ]
#                 return list(zip(bbox_results, [cls_segms]))
                return bbox_results
    
    
#                 import pdb
#                 pdb.set_trace()
                
            elif self.test_part_sam:
                pass
            elif not self.test_wo_detector:
                proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
                if self.roi_skip_fpn: # imted
                    return self.roi_head.simple_test(
                        self.get_roi_feat(x, vit_feat), proposal_list, img_metas, rescale=rescale)
                else:
                    return self.roi_head.simple_test(
                        x, proposal_list, img_metas, rescale=rescale)
            else:
                # 去掉多尺度测试
                gt_bboxes = gt_bboxes[0]
                gt_labels = gt_labels[0]

                # print(len(img_metas), img_metas, len(img), len(img[0]))
                batch_input_shape = tuple(img[0].size()[-2:])
                for i, _ in enumerate(img_metas):
                    img_metas[i]['batch_input_shape'] = batch_input_shape
            
                gt_masks_ = []
                for gt_mask in gt_masks[0]:
                    gt_mask = gt_mask.masks
                    n, ori_h, ori_w = gt_mask.shape
                    padding_gt_mask = np.zeros((n, *batch_input_shape), dtype=np.bool)
                    padding_gt_mask[:, :ori_h, :ori_w] = gt_mask
                    padding_gt_mask = torch.as_tensor(padding_gt_mask, device=gt_bboxes[0].device).bool()
                    gt_masks_.append(padding_gt_mask)
                gt_masks = gt_masks_
                
                        
                # center points as gt_points
                gt_points = [torch.cat([
                    bboxes[:, 0::2].mean(-1).unsqueeze(-1), 
                    bboxes[:, 1::2].mean(-1).unsqueeze(-1)
                ], dim=-1) for bboxes in gt_bboxes]
                
                # point settings
                imgs_whwh = []
                for meta in img_metas:
                    h, w, _ = meta['img_shape']
                    imgs_whwh.append(x[0].new_tensor([[w, h]]))
                imgs_whwh = torch.cat(imgs_whwh, dim=0)
                imgs_whwh = imgs_whwh[:, None, :]
                
                # point training / pseudo gt generation
                points_results = self.lss_head.forward_train_point(x,
                                                                vit_feat,
                                                                point_tokens,
                                                                attns,
                                                                img_metas,
                                                                gt_bboxes,
                                                                gt_labels,
                                                                gt_points,
                                                                imgs_whwh=imgs_whwh,
                                                                gt_masks=gt_masks)
                # 第一阶段wsddn
                mil_results = self.lss_head.forward_train_mil(x,
                                                            vit_feat,
                                                            point_tokens,
                                                            None,
                                                            None,
                                                            points_results['multiple_cams'],
                                                            points_results['refined_multiple_masks'],
                                                            points_results['pseudo_proposals'],
                                                            img_metas,
                                                            points_results['gt_bboxes'],
                                                            points_results['gt_labels'],
                                                            points_results['gt_points'],
                                                            imgs_whwh=imgs_whwh) 
                pseudo_gt_bboxes = mil_results['pseudo_gt_bboxes']
                pseudo_gt_labels = mil_results['pseudo_gt_labels']
                
                attnshift_results, losses_hinge = self.keypoint_head.forward_train(x,
                                                                   vit_feat,
                                                                   mil_results['matched_cams'],
                                                                   img_metas,
                                                                   pseudo_gt_bboxes,
                                                                   pseudo_gt_labels,
                                                                   points_results['gt_points'],
                                                                   mil_results['semantic_scores'],
                                                                   vit_feat_be_norm,
                                                                   imgs_whwh=imgs_whwh,
                                                                   gt_masks=points_results['gt_masks'],
                                                                )
                dpm_points = attnshift_results['all_dpm_points']
                gt_points_vis = attnshift_results['all_dpm_visible']
                
                fg_masks = attnshift_results['fg_masks']

                seg_results = []
                # for i_img in range(len(fg_masks)):
                for i_img in range(len(points_results['gt_masks'])):
                    cls_segms = [[] for _ in 
                                 range(self.lss_head.point_head.num_classes)]  # BG is not included in num_classes
                    for attn_mask, label in zip(points_results['gt_masks'][i_img], pseudo_gt_labels[i_img]):
                        cls_segms[label].append(attn_mask.detach().cpu().numpy())
                    seg_results.append(cls_segms)

            
                return seg_results
            
def box2mask(bboxes, img_size, default_val=0.0):
    """Rasterize boxes into one rectangular mask per box.

    Args:
        bboxes (Tensor): Boxes indexed as ``bboxes[n][0..3]`` and read as
            ``(x1, y1, x2, y2)``. Only the first four values of each row
            are used.
        img_size (Sequence[int]): ``(height, width)`` of every mask.
        default_val (float): Value written to locations outside the box.

    Returns:
        Tensor: Mask of shape ``(N, img_size[0], img_size[1])`` on the
        device of ``bboxes``. Pixels inside a box (the ``x2``/``y2`` pixel
        included) are 1.0, all others are ``default_val``. Boxes reaching
        outside the image are clipped to it.
    """
    num_boxes = bboxes.shape[0]
    # Built as ``zeros + default_val`` on purpose: the scalar addition keeps
    # the dtype promotion callers already get for integer-typed boxes.
    mask = torch.zeros(
        num_boxes, img_size[0], img_size[1],
        device=bboxes.device, dtype=bboxes.dtype) + default_val
    for n, box in enumerate(bboxes):
        # Clamp at 0: a negative slice bound would be counted from the far
        # edge of the mask and select the wrong region (or nothing) instead
        # of clipping the box. Bounds past the far edge are already clipped
        # by slicing itself.
        left = max(int(box[0]), 0)
        top = max(int(box[1]), 0)
        right = max(int(box[2] + 1), 0)
        bottom = max(int(box[3] + 1), 0)
        mask[n, top:bottom, left:right] = 1.0
    return mask
