import copy
import numpy as np
import os
from os import path as osp
import torch
import random
import json, pickle
import tempfile
import cv2
from pyquaternion import Quaternion
import pyquaternion
from projects.mmdet3d_plugin.datasets import DATASETS
from projects.mmdet3d_plugin.utils import save_tensor
from mmcv.parallel import DataContainer as DC
from projects.mmdet3d_plugin.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
from projects.mmdet3d_plugin.fileio.io import load, dump
from projects.mmdet3d_plugin.utils import track_iter_progress, mkdir_or_exist
from projects.mmdet3d_plugin.datasets.pipelines import to_tensor
from .custom_3d import Custom3DDataset
from .pipelines import Compose
from .nuscenes_styled_eval_utils import DetectionMetrics, EvalBoxes, DetectionBox,center_distance,accumulate,DetectionMetricDataList,calc_ap, calc_tp, quaternion_yaw
from prettytable import PrettyTable
from .pipelines import VectorizeMap
import laspy
import mmcv
from nuscenes.utils.data_classes import Box as NuScenesBox
from nuscenes.eval.detection.config import config_factory as det_configs
from nuscenes.eval.common.config import config_factory as track_configs
from .utils import (
    draw_lidar_bbox3d_on_img,
    draw_lidar_bbox3d_on_bev,
)
from .pipelines import ResizeCropFlipImage



@DATASETS.register_module()
class B2D_E2E_Dataset(Custom3DDataset):
    def __init__(self, queue_length=4, bev_size=(200, 200), overlap_test=False,
                 with_velocity=True, sample_interval=5, name_mapping=None,
                 eval_cfg=None, map_root=None, map_file=None, past_frames=4,
                 future_frames=4, predict_frames=12, planning_frames=6,
                 patch_size=[102.4, 102.4],
                 point_cloud_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
                 occ_receptive_field=3, occ_n_future=6,
                 occ_filter_invalid_sample=False, occ_filter_by_valid_flag=False,
                 eval_mod=None, keep_consistent_seq_aug=True,
                 with_seq_flag=True,
                 sequences_split_num=1,
                 data_aug_conf=None,
                 *args, **kwargs):
        """Bench2Drive end-to-end dataset.

        Args:
            queue_length (int): Length of the temporal frame queue.
            bev_size (tuple[int, int]): Shape of the rasterised BEV map masks
                built by ``get_map_info``.
            overlap_test (bool): Stored as ``self.overlap_test``.
            with_velocity (bool): If False, only the first 7 box parameters
                are kept in the 3D box annotations.
            sample_interval (int): Frame stride used when gathering past /
                future trajectories and planning targets.
            name_mapping (dict | None): Raw-name -> class-name mapping applied
                to ``gt_names``.
            eval_cfg, eval_mod: Stored for evaluation.
            map_root (str | None): Stored as ``self.map_root``.
            map_file (str): Pickle file with per-town map infos. It is loaded
                with ``pickle``, so only point this at trusted files.
            past_frames, future_frames, predict_frames, planning_frames (int):
                Trajectory horizons used by the annotation helpers.
            patch_size (list[float]): Metric x/y extent covered by the BEV mask.
            point_cloud_range (list[float]): [x_min, y_min, z_min, x_max,
                y_max, z_max] used to clip map elements.
            occ_receptive_field (int): Past + current frames for occupancy.
            occ_n_future (int): Future frames for occupancy.
            occ_filter_invalid_sample, occ_filter_by_valid_flag (bool): Stored
                occupancy filtering switches.
            keep_consistent_seq_aug (bool): Stored sequence-augmentation switch.
            with_seq_flag (bool): If True, ``self.flag`` is computed per driving
                sequence by ``_set_sequence_group_flag``.
            sequences_split_num (int | str): Forwarded to
                ``_set_sequence_group_flag`` (-1, 1, "all" or an int > 1).
            data_aug_conf (dict | None): Image augmentation config used by
                ``get_augmentation``.
            *args, **kwargs: Forwarded to ``Custom3DDataset``.
        """
        super().__init__(*args, **kwargs)
        self.queue_length = queue_length
        # Honour the configured BEV grid (this was previously hard-coded to
        # (200, 200), silently ignoring the argument).
        self.bev_size = tuple(bev_size)
        self.overlap_test = overlap_test
        self.with_velocity = with_velocity
        self.NameMapping = name_mapping
        self.eval_cfg = eval_cfg
        self.sample_interval = sample_interval
        self.past_frames = past_frames
        self.future_frames = future_frames
        self.predict_frames = predict_frames
        self.planning_frames = planning_frames
        self.map_root = map_root
        self.map_file = map_file
        self.point_cloud_range = np.array(point_cloud_range)
        self.patch_size = patch_size
        self.occ_receptive_field = occ_receptive_field  # past + current
        self.occ_n_future = occ_n_future  # future only
        self.occ_filter_invalid_sample = occ_filter_invalid_sample
        self.occ_filter_by_valid_flag = occ_filter_by_valid_flag
        self.occ_only_total_frames = 7  # NOTE: hardcode, not influenced by planning
        self.eval_mod = eval_mod
        self.map_element_class = {'Broken': 0, 'Solid': 1, 'SolidSolid': 2,
                                  'Center': 3, 'TrafficLight': 4, 'StopSign': 5}
        with open(self.map_file, 'rb') as f:
            self.map_infos = pickle.load(f)

        self.keep_consistent_seq_aug = keep_consistent_seq_aug
        # Plain assignments. The previous trailing commas stored 1-tuples,
        # which made ``with_seq_flag`` always truthy.
        self.with_seq_flag = with_seq_flag
        self.sequences_split_num = sequences_split_num

        if self.with_seq_flag:
            self._set_sequence_group_flag()
        # NOTE(review): machine-specific output path; consider making this
        # configurable.
        self.work_dir = "/mnt/private-user-data/ed/Sparsedrivev12/work_dirs"
        self.det3d_eval_version = "detection_cvpr_2019"
        self.det3d_eval_configs = det_configs(self.det3d_eval_version)
        self.det3d_eval_configs.class_names = list(self.det3d_eval_configs.class_range.keys())
        self.track3d_eval_version = "tracking_nips_2019"
        self.track3d_eval_configs = track_configs(self.track3d_eval_version)
        self.track3d_eval_configs.class_names = list(self.track3d_eval_configs.class_range.keys())
        if self.modality is None:
            self.modality = dict(
                use_camera=False,
                use_lidar=True,
                use_radar=False,
                use_map=False,
                use_external=False,
            )
        self.vis_score_threshold = 0.25

        self.data_aug_conf = data_aug_conf


    def _set_sequence_group_flag(self):
        """
        Set each sequence to be a different group
        """
        
        if self.sequences_split_num == -1:
            self.flag = np.arange(len(self.data_infos))
            return
        
        res = []

        curr_sequence = 0
        # import pdb; pdb.set_trace()
        for idx in range(len(self.data_infos)):
            # if idx != 0 and len(self.data_infos[idx]["sweeps"]) == 0:
            
            if idx != 0 and self.data_infos[idx]["folder"] != self.data_infos[idx-1]["folder"]:
                # Not first frame and # of sweeps is 0 -> new sequence
                curr_sequence += 1
            res.append(curr_sequence)
        # import pdb; pdb.set_trace()
        self.flag = np.array(res, dtype=np.int64)

        if self.sequences_split_num != 1:
            if self.sequences_split_num == "all":
                self.flag = np.array(
                    range(len(self.data_infos)), dtype=np.int64
                )
            else:
                bin_counts = np.bincount(self.flag)
                new_flags = []
                curr_new_flag = 0
                for curr_flag in range(len(bin_counts)):
                    curr_sequence_length = np.array(
                        list(
                            range(
                                0,
                                bin_counts[curr_flag],
                                math.ceil(
                                    bin_counts[curr_flag]
                                    / self.sequences_split_num
                                ),
                            )
                        )
                        + [bin_counts[curr_flag]]
                    )

                    for sub_seq_idx in (
                        curr_sequence_length[1:] - curr_sequence_length[:-1]
                    ):
                        for _ in range(sub_seq_idx):
                            new_flags.append(curr_new_flag)
                        curr_new_flag += 1

                assert len(new_flags) == len(self.flag)
                assert (
                    len(np.bincount(new_flags))
                    == len(np.bincount(self.flag)) * self.sequences_split_num
                )
                self.flag = np.array(new_flags, dtype=np.int64)

    def invert_pose(self, pose):
        inv_pose = np.eye(4)
        inv_pose[:3, :3] = np.transpose(pose[:3, :3])
        inv_pose[:3, -1] = - inv_pose[:3, :3] @ pose[:3, -1]
        return inv_pose

    def get_augmentation(self):
        if self.data_aug_conf is None:
            return None
        H, W = self.data_aug_conf["H"], self.data_aug_conf["W"]
        fH, fW = self.data_aug_conf["final_dim"]
        if not self.test_mode:
            resize = np.random.uniform(*self.data_aug_conf["resize_lim"])
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = (
                int(
                    (1 - np.random.uniform(*self.data_aug_conf["bot_pct_lim"]))
                    * newH
                )
                - fH
            )
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.data_aug_conf["rand_flip"] and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*self.data_aug_conf["rot_lim"])
            rotate_3d = np.random.uniform(*self.data_aug_conf["rot3d_range"])
        else:
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = (
                int((1 - np.mean(self.data_aug_conf["bot_pct_lim"])) * newH)
                - fH
            )
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate = 0
            rotate_3d = 0
        aug_config = {
            "resize": resize,
            "resize_dims": resize_dims,
            "crop": crop,
            "flip": flip,
            "rotate": rotate,
            "rotate_3d": rotate_3d,
        }
        return aug_config
        # return None

    def prepare_train_data(self, index, aug_config):
        """Prepare one training sample.

        Only the frame at ``index`` is processed; no temporal history queue
        is built.

        Args:
            index (int): Index of the sample in ``self.data_infos``
                (negative values are clamped to 0).
            aug_config (dict | None): Augmentation parameters (e.g. from
                ``get_augmentation``), stored under ``"aug_config"`` in the
                input dict for the pipeline.

        Returns:
            dict | None: The pipeline output. Returns None instead if
                ``get_data_info`` returns None, if the pipeline returns None,
                or if ``filter_empty_gt`` is set and no box has a valid label
                (label != -1).
        """
        input_dict = self.get_data_info(max(0, index))
        # Check before dereferencing. Reading 'lidar2img' first turned a
        # missing sample into a TypeError instead of a None return.
        if input_dict is None:
            print("====input_dict== is none")
            return None
        input_dict['projection_mat'] = np.array(input_dict['lidar2img'], dtype=np.float32)

        # NOTE(review): the lidar root is hard-coded; consider deriving it
        # from ``self.data_root``.
        lidar_path = os.path.join(
            "/mnt/private-user-data/ed/Sparsedrivev12/data/bench2drive",
            input_dict['folder'], 'lidar',
            str(input_dict['frame_idx']).zfill(5) + ".laz")
        las = laspy.read(lidar_path)
        xyz = np.stack([np.array(las.x), np.array(las.y), np.array(las.z)], axis=1)
        # Homogeneous (N, 4) points: x, y, z, 1.
        input_dict['lidar_points'] = np.concatenate([xyz, np.ones((xyz.shape[0], 1))], axis=1)
        input_dict["aug_config"] = aug_config

        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)

        if self.filter_empty_gt:
            # Handle None separately: printing its labels used to raise.
            if example is None:
                return None
            if not (example['gt_labels_3d']._data != -1).any():
                print(" data label ", example['gt_labels_3d'])
                return None
        return example
    
    def prepare_test_data(self, index):
        """Prepare one test / evaluation sample.

        Only the frame at ``index`` is processed. Previously, history frames
        were also loaded (each with a LAZ read and a full pipeline pass), but
        only the last, current-frame result was ever returned.

        Args:
            index (int): Index of the sample in ``self.data_infos``
                (negative values are clamped to 0).

        Returns:
            dict | None: The pipeline output. Returns None instead if
                ``get_data_info`` returns None, or if ``filter_empty_gt`` is
                set and the output is None or has no valid label
                (label != -1).
        """
        input_dict = self.get_data_info(max(0, index))
        # Check before dereferencing input_dict.
        if input_dict is None:
            return None
        # Use the full lidar->image projection, consistent with
        # prepare_train_data. lidar2cam lacks the camera intrinsics.
        input_dict['projection_mat'] = np.array(input_dict['lidar2img'], dtype=np.float32)

        # NOTE(review): the lidar root is hard-coded; consider deriving it
        # from ``self.data_root``.
        lidar_path = os.path.join(
            "/mnt/private-user-data/ed/Sparsedrivev12/data/bench2drive",
            input_dict['folder'], 'lidar',
            str(input_dict['frame_idx']).zfill(5) + ".laz")
        las = laspy.read(lidar_path)
        xyz = np.stack([np.array(las.x), np.array(las.y), np.array(las.z)], axis=1)
        # Homogeneous (N, 4) points: x, y, z, 1.
        input_dict['lidar_points'] = np.concatenate([xyz, np.ones((xyz.shape[0], 1))], axis=1)

        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        if self.filter_empty_gt and (
                example is None or not (example['gt_labels_3d']._data != -1).any()):
            return None
        return example
    
    def union2one(self, queue):
        """Merge a temporal queue of pipeline outputs into its last element.

        Per-frame fields are gathered from every element into lists wrapped
        in ``DC``. Future-trajectory fields are taken from the last element
        only. Images are stacked along a new leading dimension.

        For the ``can_bus`` in each frame's ``img_metas``, the first three
        entries and the last entry are rewritten in place as deltas relative
        to the previous frame of the same ``folder``. The first frame of a
        folder gets zeros and ``prev_bev_exists=False``.

        Args:
            queue (list[dict]): Pipeline outputs, oldest first. Elements are
                mutated in place.

        Returns:
            dict: ``queue[-1]``, updated with the merged fields.
        """
        def _data_field(key):
            # Unwrap the DataContainer payload of ``key`` for every frame.
            return [sample[key].data for sample in queue]

        def _tensor_field(key):
            # Convert the raw value of ``key`` to a tensor for every frame.
            return [to_tensor(sample[key]) for sample in queue]

        imgs = _data_field('img')
        labels_3d = _data_field('gt_labels_3d')
        sdc_labels = _data_field('gt_sdc_label')
        inds = _tensor_field('gt_inds')
        bboxes_3d = _data_field('gt_bboxes_3d')
        past_traj = _tensor_field('gt_past_traj')
        past_traj_mask = _tensor_field('gt_past_traj_mask')
        sdc_bboxes = _data_field('gt_sdc_bbox')
        l2g_r_mats = _tensor_field('l2g_r_mat')
        l2g_ts = _tensor_field('l2g_t')
        timestamps = _tensor_field('timestamp')

        last = queue[-1]
        fut_traj = to_tensor(last['gt_fut_traj'])
        fut_traj_mask = to_tensor(last['gt_fut_traj_mask'])
        future_boxes = last['gt_future_boxes'] if 'gt_future_boxes' in last else None
        if 'gt_future_labels' in last:
            future_labels = [to_tensor(item) for item in last['gt_future_labels']]
        else:
            future_labels = None

        metas_map = {}
        prev_folder = None
        prev_pos = None
        prev_angle = None
        for i, sample in enumerate(queue):
            meta = sample['img_metas'].data
            metas_map[i] = meta
            can_bus = meta['can_bus']
            # Snapshot the absolute pose before it is overwritten in place.
            cur_pos = copy.deepcopy(can_bus[:3])
            cur_angle = copy.deepcopy(can_bus[-1])
            if meta['folder'] == prev_folder:
                meta['prev_bev_exists'] = True
                can_bus[:3] -= prev_pos
                can_bus[-1] -= prev_angle
            else:
                meta['prev_bev_exists'] = False
                prev_folder = meta['folder']
                can_bus[:3] = 0
                can_bus[-1] = 0
            prev_pos = cur_pos
            prev_angle = cur_angle

        last['img'] = DC(torch.stack(imgs), cpu_only=False, stack=True)
        last['img_metas'] = DC(metas_map, cpu_only=True)
        last['gt_labels_3d'] = DC(labels_3d)
        last['gt_sdc_label'] = DC(sdc_labels)
        last['gt_inds'] = DC(inds)
        last['gt_bboxes_3d'] = DC(bboxes_3d, cpu_only=True)
        last['gt_sdc_bbox'] = DC(sdc_bboxes, cpu_only=True)
        last['l2g_r_mat'] = DC(l2g_r_mats)
        last['l2g_t'] = DC(l2g_ts)
        last['timestamp'] = DC(timestamps)
        last['gt_fut_traj'] = DC(fut_traj)
        last['gt_fut_traj_mask'] = DC(fut_traj_mask)
        last['gt_past_traj'] = DC(past_traj)
        last['gt_past_traj_mask'] = DC(past_traj_mask)
        if future_boxes is not None:
            last['gt_future_boxes'] = DC(future_boxes, cpu_only=True)
        if future_labels is not None:
            last['gt_future_labels'] = DC(future_labels)

        return last

    def get_data_info(self, index):
        """Build the raw input dict for one frame, before the pipeline runs.

        Args:
            index (int): Index of the sample in ``self.data_infos``.

        Returns:
            dict: Frame information for the data pipeline. It contains:

                - ego state: ``ego_yaw``, ``ego_translation``, ``ego_vel``,
                  ``ego_accel``, ``ego_rotation_rate`` and an 18-d ``can_bus``;
                - raw annotations (``gt_ids``, ``gt_boxes``, ``gt_names``,
                  ``npc2world``) and ``ann_info`` from ``get_ann_info``;
                - map elements from ``get_map_info`` (``gt_lane_*``,
                  ``gt_map_pts``, ``gt_map_labels``);
                - occupancy data (``occ_*`` plus the entries returned by
                  ``occ_get_transforms``);
                - the ego planning target (``sdc_planning``,
                  ``sdc_planning_mask``) and the zero-based ``command``;
                - when ``self.modality['use_camera']`` is set, per-camera
                  ``img_filename``, ``lidar2img``, ``cam_intrinsic`` and
                  ``lidar2cam``, plus the lidar-to-global ``l2g_r_mat`` /
                  ``l2g_t``.
        """
        info = self.data_infos[index]

        # Map raw names to class names in place, so data_infos is updated too.
        # ``name_mapping`` defaults to None, so skip when no mapping is set.
        if self.NameMapping:
            gt_names = info['gt_names']
            for i, name in enumerate(gt_names):
                if name in self.NameMapping:
                    gt_names[i] = self.NameMapping[name]

        gt_masks, gt_labels, gt_bboxes, gt_line_points = self.get_map_info(index)

        input_dict = dict(
            folder=info['folder'],
            scene_token=info['folder'],
            frame_idx=info['frame_idx'],
            # A NaN yaw is replaced by pi/2.
            ego_yaw=np.nan_to_num(info['ego_yaw'], nan=np.pi / 2),
            ego_translation=info['ego_translation'],
            sensors=info['sensors'],
            world2lidar=info['sensors']['LIDAR_TOP']['world2lidar'],
            gt_ids=info['gt_ids'],
            gt_boxes=info['gt_boxes'],
            gt_names=info['gt_names'],
            ego_vel=info['ego_vel'],
            ego_accel=info['ego_accel'],
            ego_rotation_rate=info['ego_rotation_rate'],
            npc2world=info['npc2world'],
            gt_lane_labels=gt_labels,
            gt_lane_bboxes=gt_bboxes,
            gt_lane_masks=gt_masks,
            timestamp=info['frame_idx'],
            gt_map_pts=gt_line_points,
            gt_map_labels=gt_labels,
        )

        if self.modality['use_camera']:
            image_paths = []
            lidar2img_rts = []
            lidar2cam_rts = []
            cam_intrinsics = []
            lidar2ego = info['sensors']['LIDAR_TOP']['lidar2ego']
            for sensor_type, cam_info in info['sensors'].items():
                if 'CAM' not in sensor_type:
                    continue
                image_paths.append(osp.join(self.data_root, cam_info['data_path']))
                intrinsic = cam_info['intrinsic']
                # Embed the intrinsic matrix in the top-left of a 4x4 identity
                # so it can be chained with the 4x4 extrinsics.
                intrinsic_pad = np.eye(4)
                intrinsic_pad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2cam = np.linalg.inv(cam_info['cam2ego']) @ lidar2ego
                lidar2img_rts.append(intrinsic_pad @ lidar2cam)
                cam_intrinsics.append(intrinsic_pad)
                lidar2cam_rts.append(lidar2cam)
            ego2world = np.eye(4)
            ego2world[0:3, 0:3] = Quaternion(axis=[0, 0, 1], radians=input_dict['ego_yaw']).rotation_matrix
            ego2world[0:3, 3] = input_dict['ego_translation']
            lidar2global = ego2world @ lidar2ego
            input_dict.update(
                dict(
                    img_filename=image_paths,
                    lidar2img=lidar2img_rts,
                    cam_intrinsic=cam_intrinsics,
                    lidar2cam=lidar2cam_rts,
                    l2g_r_mat=lidar2global[0:3, 0:3],
                    l2g_t=lidar2global[0:3, 3],
                ))

        input_dict['ann_info'] = self.get_ann_info(index)

        # can_bus layout: [0:3] translation, [3:7] yaw quaternion (w, x, y, z),
        # [7:10] velocity, [10:13] acceleration, [13:16] rotation rate,
        # [16] yaw in [0, 2*pi), [17] the same yaw in degrees.
        yaw = input_dict['ego_yaw']
        rotation = list(Quaternion(axis=[0, 0, 1], radians=yaw))
        if yaw < 0:
            yaw += 2 * np.pi
        yaw_in_degree = yaw / np.pi * 180
        can_bus = np.zeros(18)
        can_bus[:3] = input_dict['ego_translation']
        can_bus[3:7] = rotation
        can_bus[7:10] = input_dict['ego_vel']
        can_bus[10:13] = input_dict['ego_accel']
        can_bus[13:16] = input_dict['ego_rotation_rate']
        can_bus[16] = yaw
        can_bus[17] = yaw_in_degree
        input_dict['can_bus'] = can_bus

        # Occupancy window: indices of past/current/future frames within the
        # same folder. -1 marks frames outside the dataset or the sequence.
        all_frames = []
        for adj_idx in range(index - self.occ_receptive_field + 1, index + self.occ_n_future + 1):
            if adj_idx < 0 or adj_idx >= len(self.data_infos):
                all_frames.append(-1)
            elif self.data_infos[adj_idx]['folder'] != self.data_infos[index]['folder']:
                all_frames.append(-1)
            else:
                all_frames.append(adj_idx)

        future_frames = all_frames[self.occ_receptive_field - 1:]
        input_dict['occ_has_invalid_frame'] = (-1 in all_frames[:self.occ_only_total_frames])
        input_dict['occ_img_is_valid'] = np.array(all_frames) >= 0
        input_dict['occ_future_ann_infos'] = [
            self.get_ann_boxes_only(frame) if frame >= 0 else None
            for frame in future_frames
        ]

        input_dict.update(self.occ_get_transforms(future_frames))
        sdc_planning, sdc_planning_mask = self.get_ego_future_xy(
            index, self.sample_interval, self.planning_frames)
        input_dict['sdc_planning'] = sdc_planning
        input_dict['sdc_planning_mask'] = sdc_planning_mask

        # Negative commands are mapped to 4; the result is made zero-based.
        command = info['command_near']
        if command < 0:
            command = 4
        command -= 1
        input_dict['command'] = command

        return input_dict


    def get_map_info(self, index):

        gt_masks = []
        gt_labels = []
        gt_bboxes = []
        gt_line_points = []
        

        ann_info = self.data_infos[index]
        town_name = ann_info['town_name']
        map_info = self.map_infos[town_name]
        lane_points = map_info['lane_points']
        lane_sample_points = map_info['lane_sample_points']
        lane_types = map_info['lane_types']
        trigger_volumes_points = map_info['trigger_volumes_points']
        trigger_volumes_sample_points = map_info['trigger_volumes_sample_points']
        trigger_volumes_types = map_info['trigger_volumes_types']
        world2lidar = np.array(ann_info['sensors']['LIDAR_TOP']['world2lidar'])
        ego_xy = np.linalg.inv(world2lidar)[0:2,3]

        #1st search
        max_distance = 100
        chosed_idx = []
        # chose_map = []
        for idx in range(len(lane_sample_points)):
            single_sample_points = lane_sample_points[idx]
            distance = np.linalg.norm((single_sample_points[:,0:2]-ego_xy),axis=-1)
            if np.min(distance) < max_distance:
                chosed_idx.append(idx)

        # import pdb; pdb.set_trace()
        for idx in chosed_idx:
            if not lane_types[idx] in self.map_element_class.keys():
                continue
            points = lane_points[idx]
            points = np.concatenate([points,np.ones((points.shape[0],1))],axis=-1)
            points_in_ego = (world2lidar @ points.T).T
            #print(points_in_ego)
            mask = (points_in_ego[:,0]>self.point_cloud_range[0]) & (points_in_ego[:,0]<self.point_cloud_range[3]) & (points_in_ego[:,1]>self.point_cloud_range[1]) & (points_in_ego[:,1]<self.point_cloud_range[4])
            points_in_ego_range = points_in_ego[mask,0:2]
            if len(points_in_ego_range) > 1:
                gt_mask = np.zeros(self.bev_size,dtype=np.uint8)
                normalized_points = np.zeros_like(points_in_ego_range)
                normalized_points[:,0] = (points_in_ego_range[:,0] + self.patch_size[0]/2)*(self.bev_size[0]/self.patch_size[0])
                normalized_points[:,1] = (points_in_ego_range[:,1] + self.patch_size[1]/2)*(self.bev_size[1]/self.patch_size[1])
                cv2.polylines(gt_mask, [normalized_points.astype(np.int32)], False, color=1, thickness=2)
                gt_label =  self.map_element_class[lane_types[idx]]
                gt_masks.append(gt_mask)
                gt_labels.append(gt_label)
                ys, xs = np.where(gt_mask==1)
                gt_bboxes.append([min(xs), min(ys), max(xs), max(ys)]) 
                # add
                line_points = points_in_ego[mask, 0:3]
                def uniform_sample(lst, k):
                    # indices = np.random.choice(len(lst), k, replace=False)
                    if len(lst) < 21:
                        import pdb; pdb.set_trace()
                    indices = np.linspace(0, len(lst), len(lst))
                    diff = int(len(lst)/20)
                    indices_s = indices[::diff]
                    indices_s = indices_s[:k]
                    # import pdb; pdb.set_trace()
                    return [lst[int(i)] for i in indices_s.tolist()]
                if len(line_points) < 21:
                    # import pdb; pdb.set_trace()
                    miss = 21 - len(line_points)
                    for i in range(miss +1):
                        line_points = np.concatenate((line_points,np.array([line_points[-1,:]])),axis=0)

                line_points_select = np.array(uniform_sample(line_points, 20))
                gt_line_points.append(line_points_select)

        for idx in range(len(trigger_volumes_points)):
            if not trigger_volumes_types[idx] in self.map_element_class.keys():
                continue
            points = trigger_volumes_points[idx]
            points = np.concatenate([points,np.ones((points.shape[0],1))],axis=-1)
            points_in_ego = (world2lidar @ points.T).T
            mask = (points_in_ego[:,0]>self.point_cloud_range[0]) & (points_in_ego[:,0]<self.point_cloud_range[3]) & (points_in_ego[:,1]>self.point_cloud_range[1]) & (points_in_ego[:,1]<self.point_cloud_range[4])
            points_in_ego_range = points_in_ego[mask,0:2]
            if mask.all():
                gt_mask = np.zeros(self.bev_size,dtype=np.uint8)
                normalized_points = np.zeros_like(points_in_ego_range)
                normalized_points[:,0] = (points_in_ego_range[:,0] + self.patch_size[0]/2)*(self.bev_size[0]/self.patch_size[0])
                normalized_points[:,1] = (points_in_ego_range[:,1] + self.patch_size[1]/2)*(self.bev_size[1]/self.patch_size[1])
                cv2.fillConvexPoly(gt_mask, normalized_points.astype(np.int32), color=1)
                gt_label = self.map_element_class[trigger_volumes_types[idx]]
                gt_masks.append(gt_mask)
                gt_labels.append(gt_label)
                ys, xs = np.where(gt_mask==1)
                gt_bboxes.append([min(xs), min(ys), max(xs), max(ys)]) 
                # add 
                volum_points = points_in_ego[mask, 0:3]
                def uniform_sample(lst, k):
                    # indices = np.random.choice(len(lst), k, replace=False)
                    indices = np.linspace(1, len(lst), len(lst))
                    diff = int(len(lst)/20)
                    indices_s = indices[::diff]
                    indices_s = indices_s[:k]
                    # import pdb; pdb.set_trace()
                    if len(lst) < 20:
                        import pdb; pdb.set_trace()
                    return [lst[int(i)] for i in indices_s.tolist()]
                if len(volum_points) < 20:
                    # import pdb; pdb.set_trace()
                    miss = 20 - len(volum_points)
                    for i in range(miss +1):
                        volum_points = np.concatenate((volum_points,np.array([volum_points[-1,:]])),axis=0)

                volum_points_select = np.array(uniform_sample(volum_points, 20))
                gt_line_points.append(volum_points_select)


        if len(gt_masks) == 0:
            gt_masks.append(np.zeros(self.bev_size,dtype=np.uint8))
            gt_labels.append(-1)
            gt_bboxes.append([0,0,0,0])
            # add
            gt_line_points.append([[0,0]])

        gt_masks = np.stack(gt_masks)
        gt_labels = np.array(gt_labels)
        gt_bboxes = np.array(gt_bboxes)
        # add
        gt_line_points = np.array(gt_line_points)

        # return gt_masks,gt_labels,gt_bboxes
        return gt_masks,gt_labels,gt_bboxes, gt_line_points


    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: Annotation information with the following keys:

                - gt_bboxes_3d (torch.Tensor): tensor of the 3D boxes after
                  ``LiDARInstance3DBoxes(...).convert_to(self.box_mode_3d)``.
                - gt_labels_3d (np.ndarray): class indices into
                  ``self.CLASSES`` (-1 for unknown names).
                - gt_names (np.ndarray): class names of the kept boxes.
                - gt_inds: instance ids (``info['gt_ids']``).
                - gt_fut_traj / gt_fut_traj_mask: agent future trajectories.
                - gt_past_traj / gt_past_traj_mask: past trajectory followed
                  by the first ``future_frames`` future steps. Agents with no
                  valid past step get their future part zeroed.
                - gt_sdc_bbox / gt_sdc_label: ego-vehicle box and label.
                - gt_sdc_fut_traj / gt_sdc_fut_traj_mask: ego future x/y.
        """
        info = self.data_infos[index]

        # Map raw names in place; ``name_mapping`` defaults to None.
        if self.NameMapping:
            gt_names = info['gt_names']
            for i, name in enumerate(gt_names):
                if name in self.NameMapping:
                    gt_names[i] = self.NameMapping[name]

        # ``num_points >= -1`` keeps every box (no point-count filtering).
        keep = (info['num_points'] >= -1)
        gt_bboxes_3d = info['gt_boxes'][keep]
        gt_names_3d = info['gt_names'][keep]
        gt_inds = info['gt_ids']
        gt_labels_3d = np.array([
            self.CLASSES.index(cat) if cat in self.CLASSES else -1
            for cat in gt_names_3d
        ])
        if not self.with_velocity:
            gt_bboxes_3d = gt_bboxes_3d[:, 0:7]
        gt_bboxes_3d = LiDARInstance3DBoxes(
            gt_bboxes_3d,
            box_dim=gt_bboxes_3d.shape[-1],
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
        gt_bboxes_3d = gt_bboxes_3d.tensor

        ego_future_track, ego_future_mask = self.get_ego_future_xy(
            index, self.sample_interval, self.predict_frames)
        past_track, past_mask = self.get_past_or_future_xy(
            index, self.sample_interval, self.past_frames, past_or_future='past', local_xy=True)
        predict_track, predict_mask = self.get_past_or_future_xy(
            index, self.sample_interval, self.predict_frames, past_or_future='future', local_xy=False)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # dtype it aliased.
        has_past = (past_mask.sum((1, 2)) > 0).astype(int)
        future_track = predict_track[:, 0:self.future_frames, :] * has_past[:, None, None]
        future_mask = predict_mask[:, 0:self.future_frames, :] * has_past[:, None, None]
        full_past_track = np.concatenate([past_track, future_track], axis=1)
        full_past_mask = np.concatenate([past_mask, future_mask], axis=1)
        gt_sdc_bbox, gt_sdc_label = self.generate_sdc_info(index)
        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_names=gt_names_3d,
            gt_inds=gt_inds,
            gt_fut_traj=predict_track,
            gt_fut_traj_mask=predict_mask,
            gt_past_traj=full_past_track,
            gt_past_traj_mask=full_past_mask,
            gt_sdc_bbox=gt_sdc_bbox,
            gt_sdc_label=gt_sdc_label,
            gt_sdc_fut_traj=ego_future_track[:, :, 0:2],
            gt_sdc_fut_traj_mask=ego_future_mask,
        )
        return anns_results

    def get_ann_boxes_only(self, index):
        """Return only the 3D box annotations of one sample.

        Args:
            index (int): Index into ``self.data_infos``.

        Returns:
            dict: ``gt_bboxes_3d`` (box tensor after conversion to
            ``self.box_mode_3d``), ``gt_labels_3d`` (``np.ndarray`` of
            indices into ``self.CLASSES``, ``-1`` for names not listed there)
            and ``gt_inds`` (``info['gt_ids']`` as stored).
        """
        info = self.data_infos[index]
        names = info['gt_names']
        # Raw names are replaced by their mapped names *in place*, so the
        # stored info entry is modified as a side effect of this call.
        for pos, raw_name in enumerate(names):
            if raw_name in self.NameMapping:
                names[pos] = self.NameMapping[raw_name]

        labels = np.array([
            self.CLASSES.index(cat) if cat in self.CLASSES else -1
            for cat in names
        ])

        boxes = info['gt_boxes']
        if not self.with_velocity:
            # Keep only the first 7 box parameters (drop the velocity columns).
            boxes = boxes[:, 0:7]
        # Boxes are handed over with origin (0.5, 0.5, 0.5) and converted to
        # the dataset's configured box mode; only the raw tensor is returned.
        boxes = LiDARInstance3DBoxes(
            boxes,
            box_dim=boxes.shape[-1],
            origin=(0.5, 0.5, 0.5),
        ).convert_to(self.box_mode_3d).tensor

        return dict(
            gt_bboxes_3d=boxes,
            gt_labels_3d=labels,
            gt_inds=info['gt_ids'],
        )

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Args:
            idx (int | dict): Sample index, or a dict ``{"idx": int,
                "aug_config": ...}`` carrying a pre-sampled augmentation
                config (otherwise one is drawn via ``get_augmentation``).

        Returns:
            dict: In test mode, the output of ``prepare_test_data``.
            Otherwise the training sample with unused keys removed and the
            agent/ego/map targets wrapped as tensor ``DataContainer``s. When
            ``prepare_train_data`` returns ``None`` a different index is
            retried (via ``_rand_another``) until a sample is produced.
        """
        if isinstance(idx, dict):
            aug_config = idx["aug_config"]
            idx = idx["idx"]
        else:
            aug_config = self.get_augmentation()

        if self.test_mode:
            return self.prepare_test_data(idx)

        # Pipeline outputs that are dropped before the sample is returned.
        unused_keys = (
            'gt_inds', 'gt_fut_traj', 'gt_fut_traj_mask',
            'sdc_planning', 'sdc_planning_mask',
            'l2g_r_mat', 'l2g_t',
            'gt_past_traj', 'gt_past_traj_mask',
            'gt_sdc_bbox', 'gt_sdc_label',
            'gt_lane_labels', 'gt_lane_bboxes', 'gt_lane_masks',
            'command',
        )
        # Raw targets that must be converted to tensors and wrapped in DC.
        tensor_keys = (
            'gt_agent_fut_trajs',
            'gt_agent_fut_masks',
            'gt_ego_fut_cmd',
            'gt_map_pts',
            'gt_map_labels',
        )

        while True:
            data = self.prepare_train_data(idx, aug_config)
            if data is None:
                print("---idx---- is none", idx)
                idx = self._rand_another(idx)
                continue
            for key in unused_keys:
                data.pop(key, None)
            for key in tensor_keys:
                data[key] = DC(to_tensor(data[key]))
            return data
        
    def generate_sdc_info(self, idx):
        """Build a single pseudo box describing the ego (self-driving) car.

        The box sits at the origin with size ``info['ego_size'][0:3]``, yaw
        ``-pi`` and, when ``self.with_velocity`` is set, velocity
        ``(ego_vel[1], ego_vel[0])``.

        Args:
            idx (int): Index into ``self.data_infos``.

        Returns:
            tuple: ``(gt_bboxes_3d, gt_labels_3d)``. The first is a CPU-only
            ``DataContainer`` holding the converted box tensor. The second
            wraps the tensor of the ``'car'`` class index, or ``-1`` if
            ``'car'`` is not in ``self.CLASSES``.
        """
        info = self.data_infos[idx]
        size = info['ego_size']
        vel = info['ego_vel']
        # NOTE(review): velocity components are stored swapped
        # (ego_vel[1], ego_vel[0]); presumably a frame-axis convention --
        # confirm against how ego_vel is produced.
        pseudo_box = np.array([
            0.0, 0.0, 0.0,
            size[0], size[1], size[2],
            -np.pi,
            vel[1], vel[0],
        ])
        if not self.with_velocity:
            pseudo_box = pseudo_box[0:7]
        boxes = pseudo_box[np.newaxis].astype(np.float32)

        car_label = self.CLASSES.index('car') if 'car' in self.CLASSES else -1
        labels = np.array([car_label])

        # The box is handed over with origin (0.5, 0.5, 0.5) and converted to
        # the dataset's configured box mode.
        boxes = LiDARInstance3DBoxes(
            boxes,
            box_dim=boxes.shape[-1],
            origin=(0.5, 0.5, 0.5),
        ).convert_to(self.box_mode_3d).tensor

        return DC(boxes, cpu_only=True), DC(to_tensor(labels))

    def get_past_or_future_xy(self, idx, sample_rate, frames, past_or_future, local_xy=False):
        """Collect per-agent xy positions in adjacent past or future frames.

        Each agent listed in ``self.data_infos[idx]['gt_ids']`` is looked up
        by id in up to ``frames`` frames spaced ``sample_rate`` apart. Step 0
        is the nearest frame. Positions are expressed in the current frame's
        LIDAR_TOP coordinates: the current ``world2lidar`` is applied to the
        agent's ``npc2world`` in the adjacent frame. The search stops at the
        dataset boundary or at the first frame from a different ``folder``.
        Steps where the agent is absent stay zero.

        Args:
            idx (int): Index of the current frame in ``self.data_infos``.
            sample_rate (int): Index stride between consecutive steps.
            frames (int): Number of steps to collect.
            past_or_future (str): ``'past'`` walks backwards, ``'future'``
                walks forwards.
            local_xy (bool): If True, subtract the agent's current xy so
                positions are offsets from where the agent is now.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(track, mask)``, each of shape
            ``(num_agents, frames, 2)``. ``mask`` is 1 where a position was
            found.

        Raises:
            AssertionError: If ``past_or_future`` is invalid, or if an agent
                id appears more than once in an adjacent frame.
        """
        assert past_or_future in ['past', 'future']

        step = -sample_rate if past_or_future == 'past' else sample_rate
        adj_idx_list = range(idx + step, idx + (frames + 1) * step, step)

        cur_frame = self.data_infos[idx]
        box_ids = cur_frame['gt_ids']
        adj_track = np.zeros((len(box_ids), frames, 2))
        adj_mask = np.zeros((len(box_ids), frames, 2))
        world2lidar_ego_cur = cur_frame['sensors']['LIDAR_TOP']['world2lidar']

        # The stop conditions do not depend on the agent, so resolve the
        # usable adjacent frames once instead of once per agent. Position j
        # in this prefix is step j of the output.
        adj_frames = []
        for adj_idx in adj_idx_list:
            if adj_idx < 0 or adj_idx >= len(self.data_infos):
                break
            adj_frame = self.data_infos[adj_idx]
            if adj_frame['folder'] != cur_frame['folder']:
                break
            adj_frames.append(adj_frame)

        for i, box_id in enumerate(box_ids):
            cur_xy = (world2lidar_ego_cur @ cur_frame['npc2world'][i])[0:2, 3]
            for j, adj_frame in enumerate(adj_frames):
                matches = np.where(adj_frame['gt_ids'] == box_id)[0]
                if len(matches) == 0:
                    continue  # agent not visible in this frame
                assert len(matches) == 1, matches
                adj_box2lidar = world2lidar_ego_cur @ adj_frame['npc2world'][matches[0]]
                adj_xy = adj_box2lidar[0:2, 3]
                if local_xy:
                    adj_xy = adj_xy - cur_xy
                adj_track[i, j, :] = adj_xy
                adj_mask[i, j, :] = 1
        return adj_track, adj_mask

    def get_ego_future_xy(self, idx, sample_rate, frames):
        """Collect the ego vehicle's future poses relative to the current frame.

        Each future frame's pose is ``world2lidar_cur @ inv(world2lidar_fut)``
        for its LIDAR_TOP sensor. Collection stops at the dataset boundary or
        at the first frame from a different ``folder``.

        Args:
            idx (int): Index of the current frame in ``self.data_infos``.
            sample_rate (int): Index stride between future steps.
            frames (int): Number of future steps.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``track`` of shape
            ``(1, frames, 3)`` holding ``[x, y, yaw]``, and ``mask`` of
            shape ``(1, frames, 2)`` set to 1 for filled steps.
        """
        cur = self.data_infos[idx]
        world2lidar_cur = cur['sensors']['LIDAR_TOP']['world2lidar']
        track = np.zeros((1, frames, 3))
        mask = np.zeros((1, frames, 2))

        future_ids = range(idx + sample_rate, idx + (frames + 1) * sample_rate, sample_rate)
        for step, fut_idx in enumerate(future_ids):
            if not 0 <= fut_idx < len(self.data_infos):
                break
            fut = self.data_infos[fut_idx]
            if fut['folder'] != cur['folder']:
                break
            fut2cur = world2lidar_cur @ np.linalg.inv(fut['sensors']['LIDAR_TOP']['world2lidar'])
            # Heading is negated and shifted by -pi, then wrapped back into
            # range. NOTE(review): presumably matches the planner's heading
            # convention -- confirm.
            heading = -np.arctan2(fut2cur[1, 0], fut2cur[0, 0]) - np.pi
            while heading > np.pi:
                heading -= np.pi * 2
            while heading < -np.pi:
                heading += np.pi * 2
            track[0, step, 0:2] = fut2cur[0:2, 3]
            track[0, step, 2] = heading
            mask[0, step, :] = 1

        return track, mask

    def occ_get_transforms(self, indices, data_type=torch.float32):
        """Gather lidar->ego and ego->global rotations/translations per index.

        Args:
            indices (Iterable[int]): Indices into ``self.data_infos``; ``-1``
                marks a missing frame and yields ``None`` in every list.
            data_type (torch.dtype): dtype of the returned tensors.

        Returns:
            dict: Lists aligned with ``indices`` under ``'occ_l2e_r_mats'``
            (3x3), ``'occ_l2e_t_vecs'`` (3,), ``'occ_e2g_r_mats'`` (3x3) and
            ``'occ_e2g_t_vecs'`` (3,). Ego->global is the inverse of
            ``info['world2ego']``.
        """
        def as_tensor(array):
            return torch.tensor(array).to(data_type)

        res = {
            'occ_l2e_r_mats': [],
            'occ_l2e_t_vecs': [],
            'occ_e2g_r_mats': [],
            'occ_e2g_t_vecs': [],
        }
        for index in indices:
            if index == -1:
                for values in res.values():
                    values.append(None)
                continue

            info = self.data_infos[index]
            lidar2ego = info['sensors']['LIDAR_TOP']['lidar2ego']
            ego2global = np.linalg.inv(info['world2ego'])
            res['occ_l2e_r_mats'].append(as_tensor(lidar2ego[0:3, 0:3]))
            res['occ_l2e_t_vecs'].append(as_tensor(lidar2ego[0:3, 3]))
            res['occ_e2g_r_mats'].append(as_tensor(ego2global[0:3, 0:3]))
            res['occ_e2g_t_vecs'].append(as_tensor(ego2global[0:3, 3]))

        return res

# NOTE: The nuScenes-protocol evaluation code below is disabled because it does not work with CARLA data.

    # def evaluate(self,
    #              results,
    #              eval_mode, # new from unscenes
    #              metric='bbox',
    #              logger=None,
    #              jsonfile_prefix=None,
    #              result_names=['pts_bbox'],
    #              show=False,
    #              out_dir=None,
    #              pipeline=None):
    #     """Evaluation in nuScenes protocol.

    #     Args:
    #         results (list[dict]): Testing results of the dataset.
    #         metric (str | list[str]): Metrics to be evaluated.
    #         logger (logging.Logger | str | None): Logger used for printing
    #             related information during evaluation. Default: None.
    #         jsonfile_prefix (str | None): The prefix of json files. It includes
    #             the file path and the prefix of filename, e.g., "a/b/prefix".
    #             If not specified, a temp file will be created. Default: None.
    #         show (bool): Whether to visualize.
    #             Default: False.
    #         out_dir (str): Path to save the visualization results.
    #             Default: None.
    #         pipeline (list[dict], optional): raw data loading for showing.
    #             Default: None.

    #     Returns:
    #         dict[str, float]: Results of each evaluation metric.
    #     """

    #     # NOTE: Currently we only support evaluation on detection and planning.

    #     #  if eval_mode['with_det']?
    #     import pdb; pdb.set_trace()

    #     result_files, tmp_dir = self.format_results(results['bbox_results'], jsonfile_prefix)    
    #     result_path = result_files
    #     with open(result_path) as f:
    #         result_data = json.load(f)
    #     pred_boxes = EvalBoxes.deserialize(result_data['results'], DetectionBox)
    #     meta = result_data['meta']

    #     gt_boxes = self.load_gt()

    #     metric_data_list = DetectionMetricDataList()
    #     for class_name in self.eval_cfg['class_names']:
    #         for dist_th in self.eval_cfg['dist_ths']:
    #             md = accumulate(gt_boxes, pred_boxes, class_name, center_distance, dist_th)
    #             metric_data_list.set(class_name, dist_th, md)
    #             metrics = DetectionMetrics(self.eval_cfg)

    #     for class_name in self.eval_cfg['class_names']:
    #         # Compute APs.
    #         for dist_th in self.eval_cfg['dist_ths']:
    #             metric_data = metric_data_list[(class_name, dist_th)]
    #             ap = calc_ap(metric_data, self.eval_cfg['min_recall'], self.eval_cfg['min_precision'])
    #             metrics.add_label_ap(class_name, dist_th, ap)

    #         # Compute TP metrics.
    #         for metric_name in self.eval_cfg['tp_metrics']:
    #             metric_data = metric_data_list[(class_name, self.eval_cfg['dist_th_tp'])]
    #             tp = calc_tp(metric_data, self.eval_cfg['min_recall'], metric_name)
    #             metrics.add_label_tp(class_name, metric_name, tp)

    #     metrics_summary = metrics.serialize()
    #     metrics_summary['meta'] = meta.copy()
    #     print('mAP: %.4f' % (metrics_summary['mean_ap']))
    #     err_name_mapping = {
    #         'trans_err': 'mATE',
    #         'scale_err': 'mASE',
    #         'orient_err': 'mAOE',
    #         'vel_err': 'mAVE',
    #     }
    #     for tp_name, tp_val in metrics_summary['tp_errors'].items():
    #         print('%s: %.4f' % (err_name_mapping[tp_name], tp_val))
    #     print('NDS: %.4f' % (metrics_summary['nd_score']))
    #     #print('Eval time: %.1fs' % metrics_summary['eval_time'])

    #     # Print per-class metrics.
    #     print()
    #     print('Per-class results:')
    #     print('Object Class\tAP\tATE\tASE\tAOE\tAVE')
    #     class_aps = metrics_summary['mean_dist_aps']
    #     class_tps = metrics_summary['label_tp_errors']
    #     for class_name in class_aps.keys():
    #         print('%s\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f'
    #               % (class_name, class_aps[class_name],
    #                  class_tps[class_name]['trans_err'],
    #                  class_tps[class_name]['scale_err'],
    #                  class_tps[class_name]['orient_err'],
    #                  class_tps[class_name]['vel_err']))        

    #     detail = dict()
    #     metric_prefix = 'bbox_NuScenes'
    #     for name in self.eval_cfg['class_names']:
    #         for k, v in metrics_summary['label_aps'][name].items():
    #             val = float('{:.4f}'.format(v))
    #             detail['{}/{}_AP_dist_{}'.format(metric_prefix, name, k)] = val
    #         for k, v in metrics_summary['label_tp_errors'][name].items():
    #             val = float('{:.4f}'.format(v))
    #             detail['{}/{}_{}'.format(metric_prefix, name, k)] = val
    #         for k, v in metrics_summary['tp_errors'].items():
    #             val = float('{:.4f}'.format(v))
    #             detail['{}/{}'.format(metric_prefix,self.eval_cfg['err_name_maping'][k])] = val
    #     detail['{}/NDS'.format(metric_prefix)] = metrics_summary['nd_score']
    #     detail['{}/mAP'.format(metric_prefix)] = metrics_summary['mean_ap']

    #     if 'planning_results_computed' in results.keys():
    #             planning_results_computed = results['planning_results_computed']
    #             planning_tab = PrettyTable()
    #             planning_tab.field_names = [
    #                 "metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s"]
    #             for key in planning_results_computed.keys():
    #                 value = planning_results_computed[key]
    #                 row_value = []
    #                 row_value.append(key)
    #                 for i in range(len(value)):
    #                     row_value.append('%.4f' % float(value[i]))
    #                 planning_tab.add_row(row_value)
    #             print(planning_tab)


    #     return detail

    # def load_gt(self):
    #     all_annotations = EvalBoxes()
    #     for i in range(len(self.data_infos)):
    #         sample_boxes = []
    #         sample_data = self.data_infos[i]

    #         gt_boxes = sample_data['gt_boxes']
            
    #         for j in range(gt_boxes.shape[0]):
    #             class_name = self.NameMapping[sample_data['gt_names'][j]]
    #             if not class_name in self.eval_cfg['class_range'].keys():
    #                 continue
    #             range_x, range_y = self.eval_cfg['class_range'][class_name]
    #             if abs(gt_boxes[j,0]) > range_x or abs(gt_boxes[j,1]) > range_y:
    #                 continue
    #             sample_boxes.append(DetectionBox(
    #                                             sample_token=sample_data['folder']+'_'+str(sample_data['frame_idx']),
    #                                             translation=gt_boxes[j,0:3],
    #                                             size=gt_boxes[j,3:6],
    #                                             rotation=list(Quaternion(axis=[0, 0, 1], radians=-gt_boxes[j,6]-np.pi/2)),
    #                                             velocity=gt_boxes[j,7:9],
    #                                             num_pts=int(sample_data['num_points'][j]),
    #                                             detection_name=self.NameMapping[sample_data['gt_names'][j]],
    #                                             detection_score=-1.0,  
    #                                             attribute_name=self.NameMapping[sample_data['gt_names'][j]]
    #                                             ))
    #         all_annotations.add_boxes(sample_data['folder']+'_'+str(sample_data['frame_idx']), sample_boxes)
    #     return all_annotations
    
    # def _format_bbox(self, results, jsonfile_prefix=None):
    #     """Convert the results to the standard format.

    #     Args:
    #         results (list[dict]): Testing results of the dataset.
    #         jsonfile_prefix (str): The prefix of the output jsonfile.
    #             You can specify the output directory/filename by
    #             modifying the jsonfile_prefix. Default: None.

    #     Returns:
    #         str: Path of the output json file.
    #     """


    #     nusc_annos = {}
    #     mapped_class_names = self.CLASSES

    #     print('Start to convert detection format...')
    #     for sample_id, det in enumerate(track_iter_progress(results)):
    #         #pdb.set_trace()
    #         annos = []
    #         box3d = det['boxes_3d']
    #         scores = det['scores_3d']
    #         labels = det['labels_3d']
    #         box_gravity_center = box3d.gravity_center
    #         box_dims = box3d.dims
    #         box_yaw = box3d.yaw.numpy()
    #         box_yaw = -box_yaw - np.pi / 2
    #         sample_token = self.data_infos[sample_id]['folder'] + '_' + str(self.data_infos[sample_id]['frame_idx'])



    #         for i in range(len(box3d)):
    #             #import pdb;pdb.set_trace()
    #             quat = list(Quaternion(axis=[0, 0, 1], radians=box_yaw[i]))
    #             velocity = [box3d.tensor[i, 7].item(),box3d.tensor[i, 8].item()]
    #             name = mapped_class_names[labels[i]]
    #             nusc_anno = dict(
    #                 sample_token=sample_token,
    #                 translation=box_gravity_center[i].tolist(),
    #                 size=box_dims[i].tolist(),
    #                 rotation=quat,
    #                 velocity=velocity,
    #                 detection_name=name,
    #                 detection_score=scores[i].item(),
    #                 attribute_name=name)
    #             annos.append(nusc_anno)
    #         nusc_annos[sample_token] = annos
    #     nusc_submissions = {
    #         'meta': self.modality,
    #         'results': nusc_annos,
    #     }

    #     mkdir_or_exist(jsonfile_prefix)
    #     res_path = osp.join(jsonfile_prefix, 'results_nusc.json')
    #     print('Results writes to', res_path)
    #     dump(nusc_submissions, res_path)
    #     return res_path  

    # def format_results(self, results, jsonfile_prefix=None):
    #     """Format the results to json (standard format for COCO evaluation).

    #     Args:
    #         results (list[dict]): Testing results of the dataset.
    #         jsonfile_prefix (str | None): The prefix of json files. It includes
    #             the file path and the prefix of filename, e.g., "a/b/prefix".
    #             If not specified, a temp file will be created. Default: None.

    #     Returns:
    #         tuple: Returns (result_files, tmp_dir), where `result_files` is a \
    #             dict containing the json filepaths, `tmp_dir` is the temporal \
    #             directory created for saving json files when \
    #             `jsonfile_prefix` is not specified.
    #     """
    #     assert isinstance(results, list), 'results must be a list'
    #     # assert len(results) == len(self), (
    #     #     'The length of results is not equal to the dataset len: {} != {}'.
    #     #     format(len(results), len(self)))

    #     if jsonfile_prefix is None:
    #         tmp_dir = tempfile.TemporaryDirectory()
    #         jsonfile_prefix = osp.join(tmp_dir.name, 'results')
    #     else:
    #         tmp_dir = None

    #     if not ('pts_bbox' in results[0] or 'img_bbox' in results[0]):
    #         result_files = self._format_bbox(results, jsonfile_prefix)
    #     else:
    #         # should take the inner dict out of 'pts_bbox' or 'img_bbox' dict
    #         result_files = dict()
    #         for name in results[0]:
    #             print(f'\nFormating bboxes of {name}')
    #             results_ = [out[name] for out in results]
    #             tmp_file_ = osp.join(jsonfile_prefix, name)
    #             result_files.update(
    #                 {name: self._format_bbox(results_, tmp_file_)})
    #     return result_files, tmp_dir
    

    ### ours

    def _format_bbox(self, results, jsonfile_prefix=None, tracking=False):
        """Write detections (or tracks) as a nuScenes-style submission JSON.

        Args:
            results (list): Per-sample model outputs, aligned index-by-index
                with ``self.data_infos``.
            jsonfile_prefix (str): Directory that receives
                ``results_nusc.json`` (created if missing).
            tracking (bool): Decode boxes with ``self.tracking_threshold``,
                skip barrier / traffic_cone / construction_vehicle
                predictions, and emit ``tracking_*`` fields instead of
                ``detection_*`` ones.

        Returns:
            str: Path of the written JSON file.
        """
        class_names = self.CLASSES
        untracked_names = ("barrier", "traffic_cone", "construction_vehicle")
        vehicle_names = ("car", "construction_vehicle", "bus", "truck", "trailer")

        def pick_attribute(name, velocity):
            # A planar speed above 0.2 counts as "moving".
            # NOTE(review): `NuScenes3DDataset` is not imported in the visible
            # file header; confirm it is in scope, otherwise the fallback
            # branches raise NameError.
            if np.sqrt(velocity[0] ** 2 + velocity[1] ** 2) > 0.2:
                if name in vehicle_names:
                    return "vehicle.moving"
                if name in ("bicycle", "motorcycle"):
                    return "cycle.with_rider"
            elif name == "pedestrian":
                return "pedestrian.standing"
            elif name == "bus":
                return "vehicle.stopped"
            return NuScenes3DDataset.DefaultAttribute[name]

        print("Start to convert detection format...")
        nusc_annos = {}
        for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
            boxes = output_to_nusc_box(
                det, threshold=self.tracking_threshold if tracking else None
            )
            info = self.data_infos[sample_id]
            # Tokens are built from the sequence folder and frame index.
            sample_token = info["folder"] + str(info["frame_idx"])
            boxes = lidar_nusc_box_to_global(
                info,
                boxes,
                class_names,
                self.det3d_eval_configs,
                self.det3d_eval_version,
            )

            annos = []
            for box in boxes:
                name = class_names[box.label]
                if tracking and name in untracked_names:
                    continue
                attr = pick_attribute(name, box.velocity)
                anno = dict(
                    sample_token=sample_token,
                    translation=box.center.tolist(),
                    size=box.wlh.tolist(),
                    rotation=box.orientation.elements.tolist(),
                    velocity=box.velocity[:2].tolist(),
                )
                if tracking:
                    anno.update(
                        tracking_name=name,
                        tracking_score=box.score,
                        tracking_id=str(box.token),
                    )
                else:
                    anno.update(
                        detection_name=name,
                        detection_score=box.score,
                        attribute_name=attr,
                    )
                annos.append(anno)
            nusc_annos[sample_token] = annos

        submission = {"meta": self.modality, "results": nusc_annos}
        mmcv.mkdir_or_exist(jsonfile_prefix)
        res_path = osp.join(jsonfile_prefix, "results_nusc.json")
        print("Results writes to", res_path)
        mmcv.dump(submission, res_path)
        return res_path

    def _evaluate_single(
        self, result_path, logger=None, result_name="img_bbox", tracking=False
    ):
        """Run the nuScenes devkit evaluation on one submission file.

        Args:
            result_path (str): Submission JSON (as written by
                ``_format_bbox``). The devkit writes ``metrics_summary.json``
                into the same directory, and it is read back from there.
            logger: Unused; kept for interface compatibility.
            result_name (str): Prefix used in the returned metric keys.
            tracking (bool): Run ``TrackingEval`` instead of ``NuScenesEval``.

        Returns:
            dict: Flat mapping ``"<result_name>_NuScenes/<metric>" -> value``.

        Raises:
            KeyError: If ``self.version`` is neither ``"v1.0-mini"`` nor
                ``"v1.0-trainval"``.
        """
        output_dir = osp.dirname(result_path)
        eval_set_map = {
            "v1.0-mini": "mini_val",
            "v1.0-trainval": "val",
        }
        metric_prefix = f"{result_name}_NuScenes"
        detail = dict()

        if not tracking:
            from nuscenes import NuScenes
            from nuscenes.eval.detection.evaluate import NuScenesEval

            # Only the detection evaluator needs a NuScenes instance; loading
            # the database is expensive, so it is not built for tracking.
            nusc = NuScenes(
                version=self.version, dataroot=self.data_root, verbose=False
            )
            nusc_eval = NuScenesEval(
                nusc,
                config=self.det3d_eval_configs,
                result_path=result_path,
                eval_set=eval_set_map[self.version],
                output_dir=output_dir,
                verbose=True,
            )
            nusc_eval.main(render_curves=False)

            metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
            for name in self.CLASSES:
                for k, v in metrics["label_aps"][name].items():
                    detail[f"{metric_prefix}/{name}_AP_dist_{k}"] = float(f"{v:.4f}")
                for k, v in metrics["label_tp_errors"][name].items():
                    detail[f"{metric_prefix}/{name}_{k}"] = float(f"{v:.4f}")
            # Mean TP errors do not depend on the class: record them once.
            for k, v in metrics["tp_errors"].items():
                detail[f"{metric_prefix}/{self.ErrNameMapping[k]}"] = float(f"{v:.4f}")

            detail[f"{metric_prefix}/NDS"] = metrics["nd_score"]
            detail[f"{metric_prefix}/mAP"] = metrics["mean_ap"]
        else:
            from nuscenes.eval.tracking.evaluate import TrackingEval

            # TrackingEval loads the database itself from version/dataroot.
            nusc_eval = TrackingEval(
                config=self.track3d_eval_configs,
                result_path=result_path,
                eval_set=eval_set_map[self.version],
                output_dir=output_dir,
                verbose=True,
                nusc_version=self.version,
                nusc_dataroot=self.data_root,
            )
            nusc_eval.main()

            metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
            print(metrics)
            keys = [
                "amota",
                "amotp",
                "recall",
                "motar",
                "gt",
                "mota",
                "motp",
                "mt",
                "ml",
                "faf",
                "tp",
                "fp",
                "fn",
                "ids",
                "frag",
                "tid",
                "lgd",
            ]
            for key in keys:
                detail[f"{metric_prefix}/{key}"] = metrics[key]

        return detail

    def format_results(self, results, jsonfile_prefix=None, tracking=False):
        """Format model outputs into nuScenes-style submission JSON file(s).

        Args:
            results (list[dict]): Per-sample outputs. If the first entry has a
                ``'pts_bbox'`` or ``'img_bbox'`` key, each of its keys is
                treated as a separate result branch and formatted on its own.
            jsonfile_prefix (str | None): Output directory. If ``None``, a
                temporary directory is created and returned so the caller can
                clean it up.
            tracking (bool): Forwarded to ``_format_bbox``.

        Returns:
            tuple: ``(result_files, tmp_dir)``. ``result_files`` is a single
            path, or a ``{branch_name: path}`` dict for branched results.
            ``tmp_dir`` is the ``TemporaryDirectory`` or ``None``.
        """
        assert isinstance(results, list), "results must be a list"

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, "results")
        else:
            tmp_dir = None

        if not ("pts_bbox" in results[0] or "img_bbox" in results[0]):
            result_files = self._format_bbox(
                results, jsonfile_prefix, tracking=tracking
            )
        else:
            names = list(results[0])
            # Every branch writes ``results_nusc.json``. With several branches,
            # sharing one directory made each overwrite the previous one, so
            # each gets its own sub-directory. A single branch keeps writing
            # directly under ``jsonfile_prefix`` as before.
            shared_dir = len(names) == 1
            result_files = dict()
            for name in names:
                print(f"\nFormating bboxes of {name}")
                results_ = [out[name] for out in results]
                out_dir = jsonfile_prefix if shared_dir else osp.join(jsonfile_prefix, name)
                result_files[name] = self._format_bbox(
                    results_, out_dir, tracking=tracking
                )
        return result_files, tmp_dir

    def format_map_results(self, results, prefix=None):
        """Dump vectorized map predictions to ``<prefix>/submission_vector.json``.

        For each sample, the result is formatted as
        ``{'vectors': [...], 'scores': [...], 'labels': [...]}``:

        - ``'vectors'``: predicted polylines, each an array
          ``[[x1, y1], [x2, y2], ...]``. Vectors with fewer than two points
          are dropped.
        - ``'scores'``: the matching float scores.
        - ``'labels'``: the matching int labels.

        Args:
            results (list): Per-sample outputs aligned with
                ``self.data_infos``. Each non-``None`` entry holds its
                prediction under ``'img_bbox'``; ``None`` entries (empty
                predictions) are skipped.
            prefix (str | None): Output directory (created if missing). If
                ``None``, a fresh temporary directory is used.

        Returns:
            str: Path of the written JSON file.
        """
        if prefix is None:
            # Previously crashed in osp.join; mirror format_results' temp-dir
            # fallback. mkdtemp is not auto-deleted, so the file survives.
            prefix = tempfile.mkdtemp()

        submissions = {'results': {}}
        for j, pred in enumerate(results):
            if pred is None:  # empty prediction
                continue
            pred = pred['img_bbox']

            single_case = {'vectors': [], 'scores': [], 'labels': []}
            # NOTE(review): relies on data_infos carrying a 'token' key;
            # _format_bbox builds its token from folder + frame_idx instead.
            token = self.data_infos[j]['token']
            for i, score in enumerate(pred['scores']):
                label = pred['labels'][i]
                vector = pred['vectors'][i]
                # A line needs at least two points.
                if len(vector) < 2:
                    continue
                single_case['vectors'].append(vector)
                single_case['scores'].append(score)
                single_case['labels'].append(label)

            submissions['results'][token] = single_case

        out_path = osp.join(prefix, 'submission_vector.json')
        print(f'saving submissions results to {out_path}')
        out_dir = osp.dirname(out_path)
        if out_dir:  # os.makedirs('') raises for an empty prefix
            os.makedirs(out_dir, exist_ok=True)
        mmcv.dump(submissions, out_path)
        return out_path

    def format_motion_results(self, results, jsonfile_prefix=None, tracking=False, thresh=None):
        """Convert model outputs into a nuScenes-style submission dict with trajectories.

        Args:
            results (list[dict]): Per-sample outputs, indexed in the same order
                as ``self.data_infos``. Each entry's ``'img_bbox'`` dict must hold
                the keys read by :func:`output_to_nusc_box` (``boxes_3d``,
                ``scores_3d``, ``labels_3d``, optionally ``instance_ids``) plus
                ``trajs_3d``, which is indexed per box.
            jsonfile_prefix (str | None): Not used; kept for interface
                compatibility with the other ``format_*`` methods.
            tracking (bool): When True, ``barrier`` / ``traffic_cone`` /
                ``construction_vehicle`` boxes are skipped and ``tracking_*``
                fields are written instead of ``detection_*`` fields.
            thresh (float | None): Boxes with ``score < thresh`` are dropped.

        Returns:
            dict: ``{"meta": self.modality, "results": {sample_token: [anno, ...]}}``.
            Nothing is written to disk.
        """
        nusc_annos = {}
        mapped_class_names = self.CLASSES

        # ``NuScenes3DDataset`` is not available in this module, so resolve the
        # default attribute table here: prefer one defined on the dataset
        # class, else fall back to the nuScenes defaults (the values used by
        # mmdet3d's NuScenesDataset.DefaultAttribute). Unknown class names map
        # to "" instead of raising.
        default_attribute = getattr(self, "DefaultAttribute", None) or {
            "car": "vehicle.parked",
            "pedestrian": "pedestrian.moving",
            "trailer": "vehicle.parked",
            "truck": "vehicle.parked",
            "bus": "vehicle.moving",
            "motorcycle": "cycle.without_rider",
            "construction_vehicle": "vehicle.parked",
            "bicycle": "cycle.without_rider",
            "barrier": "",
            "traffic_cone": "",
        }
        tracking_skip = {"barrier", "traffic_cone", "construction_vehicle"}
        vehicle_names = {"car", "construction_vehicle", "bus", "truck", "trailer"}
        cycle_names = {"bicycle", "motorcycle"}

        print("Start to convert detection format...")
        for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
            img_bbox = det['img_bbox']
            boxes = output_to_nusc_box(img_bbox, threshold=None)
            sample_token = self.data_infos[sample_id]["token"]
            # filter_with_cls_range=False keeps every box, so the enumerate
            # index below stays aligned with the rows of ``trajs_3d``.
            boxes = lidar_nusc_box_to_global(
                self.data_infos[sample_id],
                boxes,
                mapped_class_names,
                self.det3d_eval_configs,
                self.det3d_eval_version,
                filter_with_cls_range=False,
            )
            annos = []
            for i, box in enumerate(boxes):
                if thresh is not None and box.score < thresh:
                    continue
                name = mapped_class_names[box.label]
                if tracking and name in tracking_skip:
                    continue

                # Planar speed threshold of 0.2 separates moving from static.
                if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
                    if name in vehicle_names:
                        attr = "vehicle.moving"
                    elif name in cycle_names:
                        attr = "cycle.with_rider"
                    else:
                        attr = default_attribute.get(name, "")
                else:
                    if name == "pedestrian":
                        attr = "pedestrian.standing"
                    elif name == "bus":
                        attr = "vehicle.stopped"
                    else:
                        attr = default_attribute.get(name, "")

                nusc_anno = dict(
                    sample_token=sample_token,
                    translation=box.center.tolist(),
                    size=box.wlh.tolist(),
                    rotation=box.orientation.elements.tolist(),
                    velocity=box.velocity[:2].tolist(),
                )
                if not tracking:
                    nusc_anno.update(
                        dict(
                            detection_name=name,
                            detection_score=box.score,
                            attribute_name=attr,
                        )
                    )
                else:
                    nusc_anno.update(
                        dict(
                            tracking_name=name,
                            tracking_score=box.score,
                            tracking_id=str(box.token),
                        )
                    )
                nusc_anno["trajs"] = img_bbox['trajs_3d'][i].numpy()
                annos.append(nusc_anno)
            nusc_annos[sample_token] = annos

        return {
            "meta": self.modality,
            "results": nusc_annos,
        }

    def _evaluate_single_motion(self,
                         results,
                         result_path,
                         logger=None,
                         metric='bbox',
                         result_name='pts_bbox'):
        """Run UniAD-style motion-forecasting evaluation and log a summary table.

        Args:
            results (dict): Submission dict (as produced by
                ``format_motion_results``). It is handed to the motion
                evaluator as its ``result_path`` argument.
            result_path (str): Directory used as the evaluator's ``output_dir``.
            logger (logging.Logger | str | None): Logger used for printing
                the summary table. Default: None.
            metric (str): Unused; kept for interface compatibility.
            result_name (str): Unused; kept for interface compatibility.

        Returns:
            dict: Metrics returned by the evaluator's ``main``. Keys of the form
            ``f'{class_name}_{metric}'`` are read to build the table.

        Raises:
            KeyError: If ``self.version`` is not one of the supported nuScenes
                versions in ``eval_set_map``.
        """
        from mmcv.utils import print_log
        from nuscenes import NuScenes
        from .evaluation.motion.motion_eval_uniad import NuScenesEval as NuScenesEvalMotion

        output_dir = result_path
        nusc = NuScenes(
            version=self.version, dataroot=self.data_root, verbose=False)
        eval_set_map = {
            'v1.0-mini': 'mini_val',
            'v1.0-trainval': 'val',
        }
        nusc_eval = NuScenesEvalMotion(
            nusc,
            config=copy.deepcopy(self.det3d_eval_configs),
            result_path=results,
            eval_set=eval_set_map[self.version],
            output_dir=output_dir,
            verbose=False,
            seconds=6)
        metrics = nusc_eval.main(render_curves=False)

        motion_metrics = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err']
        class_names = ['car', 'pedestrian']

        # ``PrettyTable`` is the name imported at the top of the file; the
        # ``prettytable`` module object itself is not in scope.
        table = PrettyTable()
        table.field_names = ["class names"] + motion_metrics
        for class_name in class_names:
            table.add_row(
                [class_name]
                + ['%.4f' % metrics[f'{class_name}_{m}'] for m in motion_metrics]
            )
        print_log('\n' + str(table), logger=logger)
        return metrics

    def evaluate(
        self,
        results,
        eval_mode,
        metric=None,
        logger=None,
        jsonfile_prefix=None,
        result_names=["img_bbox"],
        show=False,
        out_dir=None,
        pipeline=None,
    ):
        """Dump raw results, run every evaluation enabled in ``eval_mode``, and log a summary.

        Args:
            results (list[dict]): Per-sample model outputs.
            eval_mode (dict): Flags selecting the evaluations. Keys read here:
                ``with_det``, ``with_map``, ``with_motion``, ``with_planning``;
                also ``with_tracking`` and ``tracking_threshold`` when
                ``with_det`` is truthy, and ``motion_threshhold`` (key spelled
                as-is) when ``with_motion`` is truthy.
            metric: Unused; kept for interface compatibility.
            logger (logging.Logger | str | None): Passed to ``print_log`` and
                the map/motion/planning evaluators.
            jsonfile_prefix: Unused; formatted results go under ``self.work_dir``.
            result_names (list[str]): Keys of ``result_files`` to evaluate when
                ``format_results`` returns a dict. Not mutated.
            show (bool): Forwarded to ``self.show``.
            out_dir (str | None): Forwarded to ``self.show`` as ``save_dir``.
            pipeline: Forwarded to ``self.show``.

        Returns:
            dict: Merged metrics from every enabled evaluation.
        """
        from mmcv.utils import print_log

        res_path = osp.join(self.work_dir, "results.pkl")
        print('All Results write to', res_path)
        mmcv.dump(results, res_path)

        results_dict = dict()
        if eval_mode['with_det']:
            self.tracking = eval_mode["with_tracking"]
            self.tracking_threshold = eval_mode["tracking_threshold"]
            # ``eval_type`` (not ``metric``) to avoid shadowing the parameter.
            for eval_type in ("detection", "tracking"):
                tracking = eval_type == "tracking"
                if tracking and not self.tracking:
                    continue
                result_files, tmp_dir = self.format_results(
                    results, jsonfile_prefix=self.work_dir, tracking=tracking
                )
                try:
                    if isinstance(result_files, dict):
                        # Merge metrics for every requested name, not only the last.
                        for name in result_names:
                            results_dict.update(
                                self._evaluate_single(
                                    result_files[name], tracking=tracking
                                )
                            )
                    elif isinstance(result_files, str):
                        results_dict.update(
                            self._evaluate_single(result_files, tracking=tracking)
                        )
                finally:
                    # Release the temporary directory even if evaluation fails.
                    if tmp_dir is not None:
                        tmp_dir.cleanup()

        if eval_mode['with_map']:
            from .evaluation.map.vector_eval import VectorEvaluate
            self.map_evaluator = VectorEvaluate(self.eval_config)
            result_path = self.format_map_results(results, prefix=self.work_dir)
            map_results_dict = self.map_evaluator.evaluate(result_path, logger=logger)
            results_dict.update(map_results_dict)

        if eval_mode['with_motion']:
            thresh = eval_mode["motion_threshhold"]
            result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
            motion_results_dict = self._evaluate_single_motion(result_files, self.work_dir, logger=logger)
            results_dict.update(motion_results_dict)

        if eval_mode['with_planning']:
            from .evaluation.planning.planning_eval import planning_eval
            planning_results_dict = planning_eval(results, self.eval_config, logger=logger)
            results_dict.update(planning_results_dict)

        if show or out_dir:
            self.show(results, save_dir=out_dir, show=show, pipeline=pipeline)

        # Print the headline metrics of each evaluation that produced results.
        metric_str = '\n'
        if "img_bbox_NuScenes/NDS" in results_dict:
            metric_str += f'mAP: {results_dict.get("img_bbox_NuScenes/mAP"):.4f}\n'
            metric_str += f'mATE: {results_dict.get("img_bbox_NuScenes/mATE"):.4f}\n'
            metric_str += f'mASE: {results_dict.get("img_bbox_NuScenes/mASE"):.4f}\n'
            metric_str += f'mAOE: {results_dict.get("img_bbox_NuScenes/mAOE"):.4f}\n'
            metric_str += f'mAVE: {results_dict.get("img_bbox_NuScenes/mAVE"):.4f}\n'
            metric_str += f'mAAE: {results_dict.get("img_bbox_NuScenes/mAAE"):.4f}\n'
            metric_str += f'NDS: {results_dict.get("img_bbox_NuScenes/NDS"):.4f}\n\n'

        if "img_bbox_NuScenes/amota" in results_dict:
            metric_str += f'AMOTA: {results_dict["img_bbox_NuScenes/amota"]:.4f}\n'
            metric_str += f'AMOTP: {results_dict["img_bbox_NuScenes/amotp"]:.4f}\n'
            metric_str += f'RECALL: {results_dict["img_bbox_NuScenes/recall"]:.4f}\n'
            metric_str += f'MOTAR: {results_dict["img_bbox_NuScenes/motar"]:.4f}\n'
            metric_str += f'MOTA: {results_dict["img_bbox_NuScenes/mota"]:.4f}\n'
            metric_str += f'MOTP: {results_dict["img_bbox_NuScenes/motp"]:.4f}\n'
            metric_str += f'IDS: {results_dict["img_bbox_NuScenes/ids"]}\n\n'

        if "mAP_normal" in results_dict:
            metric_str += f'ped_crossing= {results_dict["ped_crossing"]:.4f}\n'
            metric_str += f'divider= {results_dict["divider"]:.4f}\n'
            metric_str += f'boundary= {results_dict["boundary"]:.4f}\n'
            metric_str += f'mAP_normal= {results_dict["mAP_normal"]:.4f}\n\n'

        if "car_EPA" in results_dict:
            metric_str += 'Car / Ped\n'
            metric_str += f'epa= {results_dict["car_EPA"]:.4f} / {results_dict["pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["car_min_ade_err"]:.4f} / {results_dict["pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["car_min_fde_err"]:.4f} / {results_dict["pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["car_miss_rate_err"]:.4f} / {results_dict["pedestrian_miss_rate_err"]:.4f}\n\n'

        if "L2" in results_dict:
            metric_str += f'obj_box_col: {(results_dict["obj_box_col"]*100):.3f}%\n'
            metric_str += f'L2: {results_dict["L2"]:.4f}\n\n'

        print_log(metric_str, logger=logger)
        return results_dict


def output_to_nusc_box(detection, threshold=None):
    """Convert one sample's 3D detections into a list of nuScenes ``Box`` objects.

    Args:
        detection (dict): Must contain ``boxes_3d`` (a box structure exposing
            ``gravity_center`` / ``dims`` / ``yaw`` / ``tensor``, or an object
            with ``.numpy()``), ``scores_3d`` and ``labels_3d`` (tensors).
            Optional ``instance_ids`` become each box's ``token``; optional
            ``cls_scores`` are used instead of ``scores_3d`` for thresholding.
        threshold (float | None): If set, only detections whose score is
            ``>= threshold`` are kept.

    Returns:
        list[NuScenesBox]: One box per kept detection. Columns 7:9 of the box
        array are used as (vx, vy), with vz fixed to 0. Dimensions are
        reordered with ``[1, 0, 2]`` before being passed as ``size``.
    """
    box3d = detection["boxes_3d"]
    scores = detection["scores_3d"].numpy()
    labels = detection["labels_3d"].numpy()
    has_ids = "instance_ids" in detection
    ids = detection["instance_ids"] if has_ids else None
    if threshold is not None:
        if "cls_scores" in detection:
            mask = detection["cls_scores"].numpy() >= threshold
        else:
            mask = scores >= threshold
        box3d = box3d[mask]
        scores = scores[mask]
        labels = labels[mask]
        # ``ids`` only exists when the detection carries instance ids.
        if has_ids:
            ids = ids[mask]

    is_box_structure = hasattr(box3d, "gravity_center")
    if is_box_structure:
        box_gravity_center = box3d.gravity_center.numpy()
        box_dims = box3d.dims.numpy()
        nus_box_dims = box_dims[:, [1, 0, 2]]
        box_yaw = box3d.yaw.numpy()
    else:
        box3d = box3d.numpy()
        box_gravity_center = box3d[..., :3].copy()
        box_dims = box3d[..., 3:6].copy()
        nus_box_dims = box_dims[..., [1, 0, 2]]
        box_yaw = box3d[..., 6].copy()

    # TODO: check whether this is necessary
    # with dir_offset & dir_limit in the head
    # box_yaw = -box_yaw - np.pi / 2

    box_list = []
    for i in range(len(box3d)):
        quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
        if is_box_structure:
            velocity = (*box3d.tensor[i, 7:9], 0.0)
        else:
            velocity = (*box3d[i, 7:9], 0.0)
        box = NuScenesBox(
            box_gravity_center[i],
            nus_box_dims[i],
            quat,
            label=labels[i],
            score=scores[i],
            velocity=velocity,
        )
        if has_ids:
            box.token = ids[i]
        box_list.append(box)
    return box_list


def lidar_nusc_box_to_global(
    info,
    boxes,
    classes,
    eval_configs,
    eval_version="detection_cvpr_2019",
    filter_with_cls_range=True,
):
    """Move boxes from the LiDAR frame to the global frame, optionally range-filtering.

    Each box is transformed in place: lidar -> ego first, then (if
    ``filter_with_cls_range``) dropped when its planar distance from the ego
    origin exceeds ``eval_configs.class_range`` for its class, then
    ego -> global. Boxes that are filtered out have still been moved into the
    ego frame.

    Args:
        info (dict): Sample info with ``lidar2ego_rotation``,
            ``lidar2ego_translation``, ``ego2global_rotation`` and
            ``ego2global_translation``.
        boxes (list[NuScenesBox]): Boxes in the LiDAR frame.
        classes (Sequence[str]): Class names indexed by ``box.label``.
        eval_configs: Object with a ``class_range`` mapping (read only when
            filtering).
        eval_version (str): Unused.
        filter_with_cls_range (bool): Enable the per-class range filter.

    Returns:
        list[NuScenesBox]: The kept boxes, now in the global frame.
    """
    kept_boxes = []
    for nusc_box in boxes:
        # LiDAR -> ego.
        nusc_box.rotate(pyquaternion.Quaternion(info["lidar2ego_rotation"]))
        nusc_box.translate(np.array(info["lidar2ego_translation"]))

        if filter_with_cls_range:
            range_by_class = eval_configs.class_range
            planar_dist = np.linalg.norm(nusc_box.center[:2], 2)
            if planar_dist > range_by_class[classes[nusc_box.label]]:
                continue

        # Ego -> global.
        nusc_box.rotate(pyquaternion.Quaternion(info["ego2global_rotation"]))
        nusc_box.translate(np.array(info["ego2global_translation"]))
        kept_boxes.append(nusc_box)
    return kept_boxes


def get_T_global(info):
    """Return the 4x4 homogeneous LiDAR-to-global transform for a sample.

    Builds lidar->ego and ego->global matrices from the quaternion and
    translation entries of ``info`` and returns ``ego2global @ lidar2ego``.
    """
    def _pose_matrix(rotation_key, translation_key):
        # Rotation in the top-left 3x3, translation in the last column.
        pose = np.eye(4)
        pose[:3, :3] = pyquaternion.Quaternion(info[rotation_key]).rotation_matrix
        pose[:3, 3] = np.array(info[translation_key])
        return pose

    lidar_to_ego = _pose_matrix("lidar2ego_rotation", "lidar2ego_translation")
    ego_to_global = _pose_matrix("ego2global_rotation", "ego2global_translation")
    return ego_to_global @ lidar_to_ego


