import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from mmcv.runner import force_fp32, auto_fp16
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import builder
from mmdet3d.models.builder import build_neck
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh, bbox_cxcywh_to_xyxy
from mmdet.core import multi_apply


def denormalize_3d_pts(pts, pc_range):
    """Map normalized x/y/z coordinates back into the ``pc_range`` box.

    Args:
        pts: Tensor whose last axis starts with (x, y, z). Any further
            channels on the last axis are copied through unchanged.
        pc_range: Indexable of six numbers, read as
            ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

    Returns:
        A new tensor (the input is not modified) where each of the first
        three channels is ``value * (max - min) + min``.
    """
    restored = pts.clone()
    for axis in range(3):
        lower = pc_range[axis]
        upper = pc_range[axis + 3]
        # Slice (not integer) indexing keeps the channel axis in place.
        channel = slice(axis, axis + 1)
        restored[..., channel] = pts[..., channel] * (upper - lower) + lower
    return restored


def normalize_3d_pts(pts, pc_range):
    """Normalize x/y/z coordinates by the extent of the ``pc_range`` box.

    Args:
        pts: Tensor whose last axis holds (x, y, z); it is divided by a
            three-element factor, so the last axis must broadcast against 3.
        pc_range: Indexable of six numbers, read as
            ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

    Returns:
        A new tensor with ``(value - min) / (max - min)`` per axis.
    """
    # Extents in (x, y, z) order: width, height, depth of the range box.
    extents = [pc_range[axis + 3] - pc_range[axis] for axis in range(3)]

    shifted = pts.clone()
    for axis in range(3):
        channel = slice(axis, axis + 1)
        shifted[..., channel] = pts[..., channel] - pc_range[axis]

    return shifted / pts.new_tensor(extents)


def normalize_2d_bbox(bboxes, pc_range):
    """Convert xyxy boxes to cxcywh and normalize them by the BEV extent.

    Args:
        bboxes: Tensor of boxes handed to ``bbox_xyxy_to_cxcywh``
            (presumably ``(..., 4)`` in x1, y1, x2, y2 order -- confirm
            against the caller).
        pc_range: Indexable of six numbers; only the x/y bounds
            (indices 0, 1, 3, 4) are read.

    Returns:
        Tensor of (cx, cy, w, h) where the centre is shifted by the range
        minimum and all four values are divided by the range width/height.
    """
    x_min, y_min = pc_range[0], pc_range[1]
    extent_x = pc_range[3] - x_min
    extent_y = pc_range[4] - y_min

    boxes = bbox_xyxy_to_cxcywh(bboxes)
    # Only the centre is translated; width/height are translation-invariant.
    for axis, origin in enumerate((x_min, y_min)):
        channel = slice(axis, axis + 1)
        boxes[..., channel] = boxes[..., channel] - origin

    scale = bboxes.new_tensor([extent_x, extent_y, extent_x, extent_y])
    return boxes / scale


def normalize_2d_pts(pts, pc_range):
    """Normalize x/y coordinates by the BEV extent of ``pc_range``.

    Args:
        pts: Tensor whose last axis holds (x, y); it is divided by a
            two-element factor, so the last axis must broadcast against 2.
        pc_range: Indexable of six numbers; only the x/y bounds
            (indices 0, 1, 3, 4) are read.

    Returns:
        A new tensor with ``(value - min) / (max - min)`` per axis.
    """
    extents = [pc_range[axis + 3] - pc_range[axis] for axis in range(2)]

    shifted = pts.clone()
    for axis in range(2):
        channel = slice(axis, axis + 1)
        shifted[..., channel] = pts[..., channel] - pc_range[axis]

    return shifted / pts.new_tensor(extents)


def denormalize_2d_bbox(bboxes, pc_range):
    """Convert normalized cxcywh boxes to xyxy boxes in ``pc_range`` units.

    Args:
        bboxes: Tensor of boxes handed to ``bbox_cxcywh_to_xyxy``
            (presumably ``(..., 4)`` normalized cx, cy, w, h -- confirm
            against the caller).
        pc_range: Indexable of six numbers; only the x/y bounds
            (indices 0, 1, 3, 4) are read.

    Returns:
        The converted tensor with every x channel (even indices) and every
        y channel (odd indices) rescaled as ``value * (max - min) + min``.
    """
    corners = bbox_cxcywh_to_xyxy(bboxes)
    # start=0 selects the x channels, start=1 the y channels.
    for start in (0, 1):
        lower = pc_range[start]
        upper = pc_range[start + 3]
        corners[..., start::2] = corners[..., start::2] * (upper - lower) + lower
    return corners


def denormalize_2d_pts(pts, pc_range):
    """Map normalized x/y coordinates back into the ``pc_range`` BEV box.

    Args:
        pts: Tensor whose last axis starts with (x, y). Any further
            channels on the last axis are copied through unchanged.
        pc_range: Indexable of six numbers; only the x/y bounds
            (indices 0, 1, 3, 4) are read.

    Returns:
        A new tensor (the input is not modified) where each of the first
        two channels is ``value * (max - min) + min``.
    """
    restored = pts.clone()
    for axis in range(2):
        lower = pc_range[axis]
        upper = pc_range[axis + 3]
        channel = slice(axis, axis + 1)
        restored[..., channel] = pts[..., channel] * (upper - lower) + lower
    return restored


@DETECTORS.register_module()
class M3TR(MVXTwoStageDetector):
    """M3TR.
    Args:
        video_test_mode (bool): Decide whether to use temporal information during inference.
    """

    def __init__(self,
                 use_grid_mask=False,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False,
                 modality='vision',
                 lidar_encoder=None,
                 osm_map_encoder=None
                 ):
        """Build the detector and print a table of its trainable parameters.

        Args:
            use_grid_mask (bool): Apply ``GridMask`` to images in
                ``extract_img_feat``.
            pts_voxel_layer ... pretrained: Forwarded positionally, in this
                order, to ``MVXTwoStageDetector.__init__``.
            video_test_mode (bool): When False, ``forward_test`` discards the
                cached previous BEV on every call.
            modality (str): ``'fusion'`` enables the LiDAR branch (only if
                ``lidar_encoder`` is also given).
            lidar_encoder (dict, optional): Must provide ``"voxelize"`` and
                ``"backbone"`` sub-configs; ``"voxelize_reduce"`` is optional
                and defaults to True.
            osm_map_encoder (dict, optional): Config passed to ``build_neck``.
        """
        super(M3TR, self).__init__(
            pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
            pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
            pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg,
            pretrained)

        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # Optional map encoder; the attribute always exists so callers can
        # test it against None.
        self.osm_map_encoder = (
            None if osm_map_encoder is None else build_neck(osm_map_encoder))

        # State carried between consecutive forward_test calls.
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

        self.modality = modality
        if self.modality == 'fusion' and lidar_encoder is not None:
            voxelize_cfg = lidar_encoder["voxelize"]
            # A positive max_num_points selects hard voxelization; otherwise
            # the dynamic scatter op is used.
            uses_hard_voxelization = voxelize_cfg.get("max_num_points", -1) > 0
            voxelizer_cls = Voxelization if uses_hard_voxelization else DynamicScatter
            voxelize_module = voxelizer_cls(**voxelize_cfg)
            self.lidar_modal_extractor = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
                }
            )
            self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)

        # Local import: prettytable is only needed for this summary.
        from prettytable import PrettyTable

        summary = PrettyTable(["Modules", "Parameters"])
        trainable_total = 0
        for param_name, param in self.named_parameters():
            if param.requires_grad:
                param_count = param.numel()
                summary.add_row([param_name, param_count])
                trainable_total += param_count
        print(summary)
        print(f"Total Trainable Params: {trainable_total}")

    def extract_img_feat(self, img, img_metas, len_queue=None):
        """Run the image backbone (and neck) and restore the camera axis.

        Args:
            img (torch.Tensor | None): Image batch. A 5-D input is treated as
                ``(B, N, C, H, W)`` and flattened to ``(B * N, C, H, W)``
                before the backbone.
            img_metas: Unused here; kept for interface compatibility.
            len_queue (int, optional): When given, the leading batch axis is
                assumed to be ``bs * len_queue`` and is split back into
                ``(bs, len_queue)`` in the output.

        Returns:
            list[torch.Tensor] | None: One tensor per feature level, viewed as
            ``(B, N, C, H, W)`` or ``(B // len_queue, len_queue, N, C, H, W)``.
            ``None`` if ``img`` is ``None``.
        """
        # Check for None before touching the tensor; previously img.size(0)
        # was read first, so a None input raised instead of returning None.
        if img is None:
            return None

        B = img.size(0)
        if img.dim() == 5:
            if B == 1:
                # Drop only the batch axis, out of place. A bare in-place
                # squeeze_() would also collapse any other size-1 axis (e.g.
                # a single camera) and reshape the caller's tensor.
                img = img.squeeze(0)
            elif B > 1:
                _, N, C, H, W = img.size()
                img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            num_cams = BN // B
            if len_queue is not None:
                img_feats_reshaped.append(img_feat.view(
                    B // len_queue, len_queue, num_cams, C, H, W))
            else:
                img_feats_reshaped.append(img_feat.view(B, num_cams, C, H, W))
        return img_feats_reshaped

    # NOTE: apply_to must be a tuple of argument names. ('img') without the
    # trailing comma is just the string 'img', which only worked because a
    # membership test on a str is a substring match.
    @auto_fp16(apply_to=('img',), out_fp32=True)
    def extract_feat(self, img, img_metas=None, len_queue=None):
        """Extract image features; thin wrapper around ``extract_img_feat``.

        Args:
            img (torch.Tensor): Image batch, forwarded unchanged.
            img_metas (list[dict], optional): Forwarded unchanged.
            len_queue (int, optional): Forwarded unchanged.

        Returns:
            Whatever ``extract_img_feat`` returns (a list of per-level
            feature tensors, or ``None``).
        """
        return self.extract_img_feat(img, img_metas, len_queue=len_queue)

    def extract_map_feat(self, osm_map_data):
        """Encode the OSM map inputs with ``self.osm_map_encoder``.

        The encoder must return exactly two values; they are returned as a
        ``(sd_map_feats, sd_map_bev_feats)`` tuple. Callers are expected to
        check that ``self.osm_map_encoder`` is not ``None`` first.
        """
        map_feats, map_bev_feats = self.osm_map_encoder(osm_map_data)
        return map_feats, map_bev_feats

    def forward_pts_train(self,
                          pts_feats,
                          lidar_feat,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None,
                          prev_bev=None,
                          gt_depth=None,
                          gt_seg_mask=None,
                          gt_pv_seg_mask=None,
                          gt_masked=None,
                          osm_map_feats=None,
                          osm_map_bev_feats=None):
        """Run the head and compute one-to-one and one-to-many losses.

        Args:
            pts_feats (list[torch.Tensor]): Features passed to the head as
                its first argument.
            lidar_feat: LiDAR feature (or None), passed through to the head.
            gt_bboxes_3d (list): Per-sample ground-truth containers; each must
                expose ``instance_list``, ``instance_labels`` and
                ``instance_label_names``.
            gt_labels_3d (list[torch.Tensor]): Per-sample ground-truth labels.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore: Unused.
            prev_bev (torch.Tensor, optional): BEV features of previous frame.
            gt_depth (optional): When given, a depth loss is added.
            gt_seg_mask, gt_pv_seg_mask: Forwarded to the head's ``loss``.
            gt_masked (list[torch.Tensor]): Per-sample masks, indexed and
                repeated below.
            osm_map_feats, osm_map_bev_feats (optional): Map features
                forwarded to the head.

        Returns:
            dict: One-to-one losses under the head's own keys, one-to-many
            losses under ``<key>_one2many`` (scaled by ``lambda_one2many``),
            plus ``loss_depth`` when ``gt_depth`` is given.
        """
        head = self.pts_bbox_head
        outs = head(
            pts_feats, lidar_feat, img_metas, prev_bev,
            gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d,
            gt_masked=gt_masked, osm_map_feats=osm_map_feats,
            osm_map_bev_feats=osm_map_bev_feats)

        # 'depth' is always removed from outs, even when no depth loss is
        # computed.
        depth = outs.pop('depth')
        losses = dict()
        if gt_depth is not None:
            loss_depth = head.transformer.encoder.get_depth_loss(gt_depth, depth)
            if digit_version(TORCH_VERSION) >= digit_version('1.8'):
                loss_depth = torch.nan_to_num(loss_depth)
            losses.update(loss_depth=loss_depth)

        # One-to-one branch.
        losses.update(head.loss(
            gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, outs,
            gt_masked, img_metas=img_metas))

        # One-to-many branch: every ground truth is repeated k_one2many times
        # on deep copies so the caller's targets stay untouched.
        repeats = head.k_one2many
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
        multi_gt_masked = copy.deepcopy(gt_masked)
        for sample_idx, (gt_vecs, labels) in enumerate(
                zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
            gt_vecs.instance_list = gt_vecs.instance_list * repeats
            gt_vecs.instance_labels = gt_vecs.instance_labels * repeats
            gt_vecs.instance_label_names = gt_vecs.instance_label_names * repeats
            multi_gt_labels_3d[sample_idx] = labels.repeat(repeats)

            sample_mask = multi_gt_masked[sample_idx]
            if head.mask_assign_gt_one2many:
                multi_gt_masked[sample_idx] = sample_mask.repeat(repeats)
            else:
                # Keep the original mask for the first copy and use an
                # all-True mask for the remaining copies.
                all_true = torch.ones_like(sample_mask, dtype=bool)
                multi_gt_masked[sample_idx] = torch.cat(
                    [sample_mask] + [all_true] * (repeats - 1))

        one2many_outs = outs['one2many_outs']
        one2many_losses = head.loss(
            multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask,
            gt_pv_seg_mask, one2many_outs, multi_gt_masked,
            img_metas=img_metas)

        one2many_weight = head.lambda_one2many
        for key, value in one2many_losses.items():
            aux_key = key + "_one2many"
            weighted = value * one2many_weight
            if aux_key in losses:
                losses[aux_key] += weighted
            else:
                losses[aux_key] = weighted
        return losses

    def forward_dummy(self, img):
        """Call ``forward_test`` with a placeholder ``img_metas`` of ``[[None]]``.

        NOTE(review): ``forward_test`` indexes ``img_metas[0][0]`` like a dict
        and reads ground-truth entries from its kwargs, so this placeholder
        looks incompatible with it -- confirm before relying on this method.
        """
        placeholder_metas = [[None]]
        return self.forward_test(img=img, img_metas=placeholder_metas)

    def forward(self, return_loss=True, **kwargs):
        """Dispatch to ``forward_train`` or ``forward_test``.

        The OSM map inputs are gathered from ``kwargs`` into a single
        ``osm_map_data`` dict, which is passed on in addition to the
        untouched ``kwargs``. Every key listed below must be present,
        otherwise a ``KeyError`` is raised.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when ``return_loss=False``, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        osm_keys = (
            # tokenized tags: input ids
            'osm_map_nodes_tags_input_ids',
            'osm_map_ways_tags_input_ids',
            'osm_map_relations_tags_input_ids',
            'osm_map_relations_node_member_tags_input_ids',
            'osm_map_relations_way_member_tags_input_ids',
            'osm_map_relations_relation_member_tags_input_ids',
            # tokenized tags: token type ids
            'osm_map_nodes_tags_token_type_ids',
            'osm_map_ways_tags_token_type_ids',
            'osm_map_relations_tags_token_type_ids',
            'osm_map_relations_node_member_tags_token_type_ids',
            'osm_map_relations_way_member_tags_token_type_ids',
            'osm_map_relations_relation_member_tags_token_type_ids',
            # tokenized tags: attention masks
            'osm_map_nodes_tags_attention_mask',
            'osm_map_ways_tags_attention_mask',
            'osm_map_relations_tags_attention_mask',
            'osm_map_relations_node_member_tags_attention_mask',
            'osm_map_relations_way_member_tags_attention_mask',
            'osm_map_relations_relation_member_tags_attention_mask',
            # relation member indices
            'osm_map_relations_node_member_indices',
            'osm_map_relations_way_member_indices',
            'osm_map_relations_relation_member_indices',
            # geometry and way classes
            'osm_map_nodes_pts',
            'osm_map_ways_pts',
            'osm_map_ways_smerf_classes',
        )
        osm_map_data = {key: kwargs[key] for key in osm_keys}

        if return_loss:
            return self.forward_train(**kwargs, osm_map_data=osm_map_data)
        return self.forward_test(**kwargs, osm_map_data=osm_map_data)

    def obtain_history_bev(self, imgs_queue, img_metas_list, lidar_feat):
        """Obtain history BEV features iteratively, without gradients.

        The module is switched to eval mode for the duration of the call and
        is put back into whatever mode it was in before, even if an exception
        is raised. Previously it was forced into train mode unconditionally
        and left in eval mode on error.

        Args:
            imgs_queue (torch.Tensor): Unpacked as
                ``(bs, len_queue, num_cams, C, H, W)``.
            img_metas_list (list): Per-sample containers indexed by queue
                position (``each[i]``).
            lidar_feat: LiDAR feature (or None), passed to the head for every
                queue step.

        Returns:
            The BEV produced by the head for the last queue step, or ``None``
            when ``len_queue`` is 0.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                prev_bev = None
                bs, len_queue, num_cams, C, H, W = imgs_queue.shape
                # Fold the queue into the batch axis so the backbone runs once.
                imgs_queue = imgs_queue.reshape(
                    bs * len_queue, num_cams, C, H, W)
                img_feats_list = self.extract_feat(
                    img=imgs_queue, len_queue=len_queue)
                for i in range(len_queue):
                    img_metas = [each[i] for each in img_metas_list]
                    img_feats = [each_scale[:, i]
                                 for each_scale in img_feats_list]
                    # Each step consumes the BEV of the previous step.
                    prev_bev = self.pts_bbox_head(
                        img_feats, lidar_feat, img_metas, prev_bev,
                        only_bev=True)
                return prev_bev
        finally:
            self.train(was_training)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Voxelize each point cloud and concatenate the results.

        Args:
            points: Iterable of per-sample point tensors.

        Returns:
            tuple: ``(feats, coords, sizes)``. ``coords`` gets the sample
            index prepended as an extra leading column. ``sizes`` is the
            concatenated per-voxel point count when the voxelizer returns
            three values, otherwise an empty list. When sizes are available
            and ``self.voxelize_reduce`` is set, ``feats`` is reduced to the
            per-voxel mean.
        """
        feats, coords, sizes = [], [], []
        for batch_idx, cloud in enumerate(points):
            voxel_out = self.lidar_modal_extractor["voxelize"](cloud)
            if len(voxel_out) == 3:
                # hard voxelize: (features, coordinates, points per voxel)
                voxel_feats, voxel_coords, num_points = voxel_out
            else:
                assert len(voxel_out) == 2
                voxel_feats, voxel_coords = voxel_out
                num_points = None

            feats.append(voxel_feats)
            # Prepend the sample index as a new first coordinate column.
            coords.append(
                F.pad(voxel_coords, (1, 0), mode="constant", value=batch_idx))
            if num_points is not None:
                sizes.append(num_points)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if sizes:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                # Sum over the points axis, then divide by the point count.
                counts = sizes.type_as(feats).view(-1, 1)
                feats = (feats.sum(dim=1, keepdim=False) / counts).contiguous()

        return feats, coords, sizes

    # NOTE: apply_to must be a tuple of argument names. ('points') without
    # the trailing comma is just the string 'points', which only worked
    # because a membership test on a str is a substring match.
    @auto_fp16(apply_to=('points',), out_fp32=True)
    def extract_lidar_feat(self, points):
        """Voxelize the point clouds and run the LiDAR backbone.

        Args:
            points: Iterable of per-sample point tensors, forwarded to
                ``self.voxelize``.

        Returns:
            The output of ``self.lidar_modal_extractor["backbone"]``.
        """
        feats, coords, sizes = self.voxelize(points)
        # Column 0 of coords is the sample index written by voxelize(); the
        # last row belongs to the last sample, so index + 1 is the batch size.
        # This stays a 0-dim tensor, as before.
        batch_size = coords[-1, 0] + 1
        lidar_feat = self.lidar_modal_extractor["backbone"](
            feats, coords, batch_size, sizes=sizes)

        return lidar_feat

    # @auto_fp16(apply_to=('img', 'points'))
    @force_fp32(apply_to=('img', 'points', 'prev_bev'))
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      img_depth=None,
                      img_mask=None,
                      gt_depth=None,
                      gt_seg_mask=None,
                      gt_pv_seg_mask=None,
                      gt_map_data_idx=None,
                      gt_masked=None,
                      osm_map_data=None,
                      **kwargs
                      ):
        """Forward training function.

        Args:
            points (list[torch.Tensor], optional): Points of each sample;
                only used when ``self.modality == 'fusion'``.
            img_metas (list, optional): Per-sample containers indexed by
                queue position; the entry at ``len_queue - 1`` is the
                current frame.
            gt_bboxes_3d (list, optional): Ground truth passed to
                ``forward_pts_train``.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                passed to ``forward_pts_train``.
            img (torch.Tensor): Image queue. Axis 1 is the queue axis: the
                last entry is the current frame, all earlier entries are
                history frames.
            gt_bboxes_ignore (optional): Forwarded to ``forward_pts_train``.
            gt_depth, gt_seg_mask, gt_pv_seg_mask, gt_masked (optional):
                Forwarded to ``forward_pts_train``.
            osm_map_data (dict, optional): Input of the OSM map encoder; only
                used when ``self.osm_map_encoder`` is set.
            gt_labels, gt_bboxes, proposals, img_depth, img_mask,
            gt_map_data_idx, **kwargs: Accepted but unused.

        Returns:
            dict: Losses returned by ``forward_pts_train``.
        """
        lidar_feat = (
            self.extract_lidar_feat(points)
            if self.modality == 'fusion' else None)

        # Split the queue into history frames and the current (last) frame.
        len_queue = img.size(1)
        history_imgs, img = img[:, :-1, ...], img[:, -1, ...]

        history_metas = copy.deepcopy(img_metas)
        prev_bev = None
        if len_queue > 1:
            prev_bev = self.obtain_history_bev(
                history_imgs, history_metas, lidar_feat)

        img_metas = [each[len_queue - 1] for each in img_metas]
        # The first sample's flag decides for the whole batch.
        if not img_metas[0]['prev_bev_exists']:
            prev_bev = None
        img_feats = self.extract_feat(img=img, img_metas=img_metas)

        osm_map_feats = osm_map_bev_feats = None
        if self.osm_map_encoder is not None:
            osm_map_feats, osm_map_bev_feats = self.extract_map_feat(
                osm_map_data=osm_map_data)

        losses = dict()
        losses.update(self.forward_pts_train(
            img_feats, lidar_feat, gt_bboxes_3d, gt_labels_3d, img_metas,
            gt_bboxes_ignore, prev_bev, gt_depth, gt_seg_mask,
            gt_pv_seg_mask, gt_masked, osm_map_feats, osm_map_bev_feats))
        return losses

    def forward_test(self, img_metas, img=None, points=None, osm_map_data=None, **kwargs):
        """Run inference on one frame and update the cached temporal state.

        Only the first augmentation / first sample (``img_metas[0][0]``,
        ``img[0]``, ``points[0]``) is used. The frame's ``can_bus`` entry is
        modified in place: it becomes the delta to the previous frame when a
        previous BEV is cached, and is zeroed otherwise.

        Args:
            img_metas (list[list[dict]]): Double-nested meta information.
            img (list[torch.Tensor], optional): Double-nested images.
            points (list, optional): Double-nested point clouds.
            osm_map_data (dict, optional): Forwarded to ``simple_test``.
            **kwargs: Must contain ``gt_bboxes_3d``, ``gt_labels_3d`` and
                ``gt_masked``; only their first element is kept.

        Returns:
            The ``bbox_list`` produced by ``simple_test``.

        Raises:
            TypeError: If ``img_metas`` is not a list.
        """
        if not isinstance(img_metas, list):
            raise TypeError('{} must be a list, but got {}'.format(
                'img_metas', type(img_metas)))
        if img is None:
            img = [img]
        if points is None:
            points = [points]

        frame_meta = img_metas[0][0]
        state = self.prev_frame_info

        if frame_meta['scene_token'] != state['scene_token']:
            # the first sample of each scene is truncated
            state['prev_bev'] = None
        state['scene_token'] = frame_meta['scene_token']

        # do not use temporal information
        if not self.video_test_mode:
            state['prev_bev'] = None

        # Remember the absolute ego pose before can_bus is overwritten with
        # the delta between two timestamps.
        can_bus = frame_meta['can_bus']
        current_pos = copy.deepcopy(can_bus[:3])
        current_angle = copy.deepcopy(can_bus[-1])
        if state['prev_bev'] is None:
            can_bus[-1] = 0
            can_bus[:3] = 0
        else:
            can_bus[:3] -= state['prev_pos']
            can_bus[-1] -= state['prev_angle']

        # hack to adapt to other input
        for gt_key in ('gt_bboxes_3d', 'gt_labels_3d', 'gt_masked'):
            kwargs[gt_key] = [kwargs[gt_key][0]]

        new_prev_bev, bbox_results = self.simple_test(
            img_metas[0], img[0], points[0], prev_bev=state['prev_bev'],
            osm_map_data=osm_map_data, **kwargs)

        # During inference, we save the BEV features and ego motion of each
        # timestamp.
        state['prev_pos'] = current_pos
        state['prev_angle'] = current_angle
        state['prev_bev'] = new_prev_bev
        return bbox_results

    def pred2result(self, bboxes, scores, labels, pts, attrs=None):
        """Pack one sample's predictions into a dict of CPU objects.

        Args:
            bboxes: Predicted boxes; moved with ``.to('cpu')``.
            scores (torch.Tensor): Prediction scores.
            labels (torch.Tensor): Predicted labels.
            pts: Predicted points; moved with ``.to('cpu')``.
            attrs (torch.Tensor, optional): Attributes. Defaults to None.

        Returns:
            dict: Results in cpu mode.

                - boxes_3d: The boxes.
                - scores_3d (torch.Tensor): Prediction scores.
                - labels_3d (torch.Tensor): Labels.
                - pts_3d: The points.
                - attrs_3d (torch.Tensor, optional): Only present when
                  ``attrs`` is given.
        """
        result_dict = {
            'boxes_3d': bboxes.to('cpu'),
            'scores_3d': scores.cpu(),
            'labels_3d': labels.cpu(),
            'pts_3d': pts.to('cpu'),
        }
        if attrs is None:
            return result_dict

        result_dict['attrs_3d'] = attrs.cpu()
        return result_dict

    def test_assign_gt(self, outs, gt_bboxes_3d=None, gt_labels_3d=None, gt_masked=None):
        """Match the final predictions of the first sample to its ground truth.

        Uses the ``[-1][-1]`` slice of ``all_bbox_preds`` / ``all_pts_preds`` /
        ``all_cls_scores`` (presumably last decoder layer, last sample --
        confirm against the head). Predictions are de-normalized with the
        head's ``pc_range``; points use the 3-D variant when
        ``z_cfg['gt_z_flag']`` is set.

        Args:
            outs (dict): Head outputs with ``all_cls_scores``,
                ``all_bbox_preds`` and ``all_pts_preds``.
            gt_bboxes_3d (list): Ground-truth containers; only element 0 is
                used.
            gt_labels_3d (list[torch.Tensor]): Only element 0 is used.
            gt_masked (list[torch.Tensor]): Only element 0 is used.

        Returns:
            dict: ``assigned_gt_inds`` (per prediction, -1 when unmatched),
            ``assigned_pred_inds`` (per ground truth, -1 when unmatched),
            ``bbox_preds``, ``pts_preds``, ``cls_scores``, ``gt_labels``,
            ``gt_pts`` and ``gt_mask``. When the first sample has no
            ground-truth instances, all ground-truth related entries are
            empty tensors.
        """
        head = self.pts_bbox_head
        has_gt = len(gt_bboxes_3d[0].instance_list) > 0

        all_cls_scores = outs['all_cls_scores']
        final_bbox_preds = outs['all_bbox_preds'][-1][-1]
        final_pts_preds = outs['all_pts_preds'][-1][-1]

        if has_gt:
            gt_vecs = copy.deepcopy(gt_bboxes_3d)
            device = gt_labels_3d[0].device
            gt_bboxes = gt_vecs[0].bbox.to(device)
            gt_pts = gt_vecs[0].fixed_num_sampled_points.to(device)
            gt_labels = gt_labels_3d[0]
            gt_shifts_pts = gt_vecs[0].shift_fixed_num_sampled_points_v2.to(
                device)
            gt_mask = gt_masked[0]

            assign_result, _ = head.assigner.assign(
                final_bbox_preds, all_cls_scores[-1][-1], final_pts_preds,
                gt_bboxes, gt_labels, gt_shifts_pts, gt_mask)
            # The assigner reports 1-based ground-truth indices; shift them
            # so that -1 marks an unmatched prediction.
            assign_result.gt_inds = assign_result.gt_inds - 1
            assigned_gt_inds = assign_result.gt_inds

            # Invert the mapping: for every ground truth, the index of the
            # prediction assigned to it (-1 when there is none).
            matched_pred_inds = torch.argwhere(assigned_gt_inds != -1).squeeze()
            assigned_pred_inds = torch.ones(
                gt_pts.shape[0], dtype=torch.long, device=device) * -1
            assigned_pred_inds[assigned_gt_inds[matched_pred_inds]] = matched_pred_inds
        else:
            assigned_gt_inds = torch.tensor([])
            assigned_pred_inds = torch.tensor([])
            gt_labels = torch.tensor([])
            gt_pts = torch.tensor([])
            gt_mask = torch.tensor([])

        if head.z_cfg['gt_z_flag']:
            denormalized_pts_preds = denormalize_3d_pts(
                final_pts_preds, head.pc_range)
        else:
            denormalized_pts_preds = denormalize_2d_pts(
                final_pts_preds, head.pc_range)
        denormalized_bbox_preds = denormalize_2d_bbox(
            final_bbox_preds, head.pc_range)

        return {'assigned_gt_inds': assigned_gt_inds,
                'assigned_pred_inds': assigned_pred_inds,
                'bbox_preds': denormalized_bbox_preds,
                'pts_preds': denormalized_pts_preds,
                'cls_scores': all_cls_scores[-1, -1],
                'gt_labels': gt_labels,
                'gt_pts': gt_pts,
                'gt_mask': gt_mask}

    def simple_test_pts(self, x, lidar_feat, img_metas, prev_bev=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None, gt_masked=None, osm_map_feats=None, osm_map_bev_feats=None):
        """Run the head on extracted features and decode its predictions.

        Returns:
            tuple: ``(bev_embed, bbox_results, assigned_preds)`` where
            ``bev_embed`` is ``outs['bev_embed']``, ``bbox_results`` holds one
            ``pred2result`` dict per decoded sample and ``assigned_preds`` is
            the dict returned by ``test_assign_gt``.
        """
        outs = self.pts_bbox_head(
            x, lidar_feat, img_metas,
            prev_bev=prev_bev,
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_masked=gt_masked,
            osm_map_feats=osm_map_feats,
            osm_map_bev_feats=osm_map_bev_feats)

        assigned_preds = self.test_assign_gt(
            outs,
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_masked=gt_masked)

        decoded = self.pts_bbox_head.get_bboxes(outs, img_metas, rescale=rescale)
        bbox_results = []
        for bboxes, scores, labels, pts in decoded:
            bbox_results.append(self.pred2result(bboxes, scores, labels, pts))

        return outs['bev_embed'], bbox_results, assigned_preds

    def simple_test(self, img_metas, img=None, points=None, prev_bev=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None, gt_masked=None, osm_map_data=None, **kwargs):
        """Test function without augmentation.

        Extracts LiDAR (fusion modality only), image and optional OSM map
        features, runs ``simple_test_pts`` and packs the results.

        Returns:
            tuple: ``(new_prev_bev, bbox_list)``. ``bbox_list`` has one dict
            per entry of ``img_metas`` with the predictions under
            ``'pts_bbox'``; the ground-truth assignment entries are added (on
            CPU) to the first sample's ``'pts_bbox'`` dict only.
        """
        lidar_feat = (
            self.extract_lidar_feat(points)
            if self.modality == 'fusion' else None)
        img_feats = self.extract_feat(img=img, img_metas=img_metas)

        osm_map_feats = osm_map_bev_feats = None
        if self.osm_map_encoder is not None:
            osm_map_feats, osm_map_bev_feats = self.extract_map_feat(
                osm_map_data=osm_map_data)

        bbox_list = [dict() for _ in img_metas]
        new_prev_bev, bbox_pts, assigned_preds = self.simple_test_pts(
            img_feats, lidar_feat, img_metas, prev_bev,
            rescale=rescale,
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_masked=gt_masked,
            osm_map_feats=osm_map_feats,
            osm_map_bev_feats=osm_map_bev_feats)
        for sample_result, pts_bbox in zip(bbox_list, bbox_pts):
            sample_result['pts_bbox'] = pts_bbox

        for key, val in assigned_preds.items():
            bbox_list[0]['pts_bbox'][key] = val.cpu()

        return new_prev_bev, bbox_list
