# OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset
# Reference https://github.com/open-mmlab/OpenPCDet
# Written by Shaoshuai Shi, Chaoxu Guo
# All Rights Reserved 2019-2020.

import os
import pickle
import copy
import numpy as np
import torch
import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict

from sklearn.neighbors import NearestNeighbors as NN
import torchmetrics
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils, polar_utils
from ..dataset import DatasetTemplate
from pcdet.utils.box_utils import boxes_to_corners_3d
from ...utils.vis_utils import write_ply_color

from scipy.sparse.csgraph import connected_components
from scipy.sparse import csr_matrix
from torch_cluster import radius_graph
from sklearn.neighbors import NearestNeighbors as NN

class WaymoDataset(DatasetTemplate):
    def __init__(self, dataset_cfg, training=True, root_path=None, logger=None):
        """Build the Waymo dataset for the split selected by ``self.mode``.

        Args:
            dataset_cfg: dataset config (attribute access and ``.get`` are both used).
            training: whether the dataset is used for training.
            root_path: dataset root directory; forwarded to ``DatasetTemplate``.
            logger: logger used for progress messages.
        """
        super().__init__(
            dataset_cfg=dataset_cfg, training=training, root_path=root_path, logger=logger
        )
        self.dataset_name = self.dataset_cfg.DATASET
        self.processed_data_tag = self.dataset_cfg.PROCESSED_DATA_TAG

        self.data_path = self.root_path / self.dataset_cfg.PROCESSED_DATA_TAG
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        with open(split_dir) as f:
            self.sample_sequence_list = [x.strip() for x in f.readlines()]
        self.num_sweeps = self.dataset_cfg.get('NUM_SWEEPS', 1)
        self.sweep_interval = self.dataset_cfg.get('SWEEP_INTERVAL', 1)

        # With SET_FRAMES, the number of sweeps is the length of SAMPLED_FRAMES.
        self.set_frames = self.dataset_cfg.get('SET_FRAMES', False)
        self.sampled_frames = self.dataset_cfg.get('SAMPLED_FRAMES', [0])
        if self.set_frames:
            self.num_sweeps = len(self.sampled_frames)

        self.limit_points_per_frame = self.dataset_cfg.get('LIMIT_POINTS_PER_FRAME', False)
        self.max_num_per_frame = self.dataset_cfg.get('MAX_NUM_POINTS_PER_FRAME', 1000000)

        self.max_num_points *= self.num_sweeps
        self._merge_all_iters_to_one_epoch = dataset_cfg.get("MERGE_ALL_ITERS_TO_ONE_EPOCH", False)
        self.more_cls5 = self.segmentation_cfg.get('MORE_CLS5', False)
        self.use_spherical_resampling = self.dataset_cfg.get("SPHERICAL_RESAMPLING", False)
        self.ignore_index = dataset_cfg.get("IGNORE_INDEX", [0])
        self.extend_offsets = dataset_cfg.get("EXTEND_OFFSETS", [])
        self.back_ins = dataset_cfg.get("BACK_INS", False)

        self.drop_points_by_seg_len = dataset_cfg.get("DROP_POINTS_BY_SEG_LEN", False)
        filter_foreground = dataset_cfg.get("FILTER_FOREGROUND", None)
        self.filter_foreground = (filter_foreground is not None)
        if filter_foreground is not None:
            self.pred_label_path = filter_foreground["PRED_LABEL_PATH"]
            self.filter_class = filter_foreground["FILTER_CLASS"]
        self.filtered_class = dataset_cfg.get("FILTERED_CLASS", None)

        self.infos = []
        self.include_waymo_data(self.mode)
        if self.num_sweeps > 1:
            self.logger.info(f"Sequence Dataset: {self.num_sweeps} sweeps")

        if len(self.extend_offsets) > 0:
            # For every info, insert the neighbouring frames at -offset and +offset (when they
            # exist in the same sequence) right before the info itself.
            new_infos = []
            for info in self.infos:
                g_sample_idx = info['point_cloud']['sample_idx']
                sequence_id = info['point_cloud']['lidar_sequence']
                for offset in self.extend_offsets:
                    for neighbor_idx in (g_sample_idx - offset, g_sample_idx + offset):
                        if (sequence_id, neighbor_idx) in self.info_pool:
                            new_infos.append(self.info_pool[(sequence_id, neighbor_idx)])
                new_infos.append(info)
            self.infos = new_infos

        if 'REPEAT' in dataset_cfg:
            repeat = dataset_cfg.get("REPEAT", 1)
            # Shallow repetition: the same info objects appear `repeat` times.
            self.infos = self.infos * repeat
            self.logger.info(f"Repeating data by {repeat} times")

        self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
        if self.use_shared_memory:
            self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
            self.load_data_to_shared_memory()


    def set_split(self, split):
        """Switch the dataset to ``split`` and reload its infos.

        Re-runs the parent ``__init__`` with the stored config, reads
        ``<root_path>/ImageSets/<split>.txt`` and reloads ``self.infos`` through
        ``include_waymo_data``. EXTEND_OFFSETS / REPEAT from ``__init__`` are not re-applied.

        Args:
            split: name of the split file (without ``.txt``) under ``ImageSets``.
        """
        super().__init__(
            dataset_cfg=self.dataset_cfg, training=self.training,
            root_path=self.root_path, logger=self.logger
        )
        self.split = split
        if self.logger is not None:
            self.logger.info(f"{self.__class__} Dataset switched to {self.mode} mode, split={self.split}")
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        # Close the split file deterministically instead of leaking the handle.
        with open(split_dir) as f:
            self.sample_sequence_list = [x.strip() for x in f.readlines()]
        self.infos = []
        self.include_waymo_data(self.mode)

    def spherical_resampling(self, point_wise_dict, config=None):
        """Densify each range-image row by interpolating points between azimuth neighbours.

        Only rows ``h`` in ``0..63`` of ``point_rimage_h`` are considered. Every point of a row is
        kept. For rows with at least 10 points, each point is linked to the neighbour (among its 10
        Euclidean nearest neighbours in the row) with the smallest azimuth difference of at least
        1e-6. If that edge is shorter than 0.3, points are interpolated linearly along it, at most
        0.1 apart, for both ``point_xyz`` and ``point_feat``. Every other per-point entry of the
        result is copied from the nearest original point.

        Args:
            point_wise_dict: dict with at least ``point_xyz``, ``point_feat`` and
                ``point_rimage_h``. Note: ``point_rimage_h`` is popped from this dict.
            config: unused; kept for interface compatibility.

        Returns:
            A new point-wise dict (without ``point_rimage_h``); float64 xyz/feat arrays produced
            by the interpolation are cast to float32.
        """
        max_h = 64
        point_xyz = point_wise_dict['point_xyz']
        point_feat = point_wise_dict['point_feat']
        point_rimage_h = point_wise_dict.pop('point_rimage_h')
        new_point_wise_dict = dict(
            point_xyz=[],
            point_feat=[],
        )
        for h in range(max_h):
            mask_h = np.where(point_rimage_h == h)[0]
            num_points = mask_h.shape[0]
            if num_points == 0:
                continue
            point_xyz_h = point_xyz[mask_h]
            point_feat_h = point_feat[mask_h]
            new_point_wise_dict['point_xyz'].append(point_xyz_h)
            new_point_wise_dict['point_feat'].append(point_feat_h)
            if num_points < 10:
                # Too few points for the 10-NN query below.
                continue
            _, _, azimuth = polar_utils.cartesian2spherical_np(point_xyz_h)
            tree = NN(n_neighbors=10).fit(point_xyz_h)
            dists, e1 = tree.kneighbors(point_xyz_h)
            e0 = np.arange(num_points)[:, np.newaxis]
            azimuth_diff = azimuth[e0] - azimuth[e1]
            # Exclude neighbours with (near-)non-positive azimuth difference, incl. the point itself.
            azimuth_diff[azimuth_diff < 1e-6] = 1e10
            nn_index = azimuth_diff.argmin(axis=-1)
            e0 = e0[:, 0]
            dists = dists[(e0, nn_index)]
            e1 = e1[(e0, nn_index)]

            mask = dists < 0.3
            e0, e1, dists = e0[mask], e1[mask], dists[mask]
            if e0.shape[0] == 0:
                # No edge is short enough; `.max()` below would fail on an empty array.
                continue

            # Each edge gets (n - 2) interior samples with n = ceil(d / 0.1) + 1.
            num_samples_per_edge = np.ceil((dists + 1e-6) / 0.1) + 1
            max_sample_per_edge = int(num_samples_per_edge.max())

            for sample_idx in range(max_sample_per_edge):
                edge_mask = sample_idx <= num_samples_per_edge - 1
                ratio = (sample_idx / (num_samples_per_edge - 1))
                # Skip the endpoints (ratio 0 and 1), which are original points.
                edge_mask = edge_mask & (ratio > 1e-6) & (ratio < 1 - 1e-6)
                if edge_mask.any():
                    ratio = ratio[edge_mask, np.newaxis]
                    new_xyz = point_xyz_h[e0[edge_mask]] * ratio + point_xyz_h[e1[edge_mask]] * (1.0 - ratio)
                    new_feat = point_feat_h[e0[edge_mask]] * ratio + point_feat_h[e1[edge_mask]] * (1.0 - ratio)
                    new_point_wise_dict['point_xyz'].append(new_xyz)
                    new_point_wise_dict['point_feat'].append(new_feat)

        for key in list(new_point_wise_dict.keys()):
            merged = np.concatenate(new_point_wise_dict[key], axis=0)
            if merged.dtype == np.float64:
                merged = merged.astype(np.float32)
            new_point_wise_dict[key] = merged

        # Copy remaining per-point attributes (labels etc.) from the nearest original point.
        tree = NN(n_neighbors=1).fit(point_xyz)
        _, indices = tree.kneighbors(new_point_wise_dict['point_xyz'])
        indices = indices[:, 0]

        for key in point_wise_dict.keys():
            if key not in new_point_wise_dict:
                new_point_wise_dict[key] = point_wise_dict[key][indices]

        return new_point_wise_dict

    def include_waymo_data(self, mode):
        self.logger.info('Loading Waymo dataset')
        waymo_infos = []

        num_skipped_infos = 0
        for k in range(0, len(self.sample_sequence_list), self.dataset_cfg.SAMPLED_SEQUENCE_INTERVAL[mode]):
            sequence_name = os.path.splitext(self.sample_sequence_list[k])[0]
            info_path = self.data_path / sequence_name / ('%s.pkl' % sequence_name)
            info_path = self.check_sequence_name_with_all_version(info_path)
            if not info_path.exists():
                num_skipped_infos += 1
                continue
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                waymo_infos.extend(infos)

        sampled_idx_range = self.dataset_cfg.get("SAMPLED_IDX_RANGE", None)
        if sampled_idx_range is not None:
            new_infos = []
            min_idx, max_idx = sampled_idx_range[mode]
            for info in waymo_infos:
                sample_idx = info['point_cloud']['sample_idx'] 
                if min_idx <= sample_idx < max_idx:
                    new_infos.append(info)
            self.logger.info('Sample Index Range Filtering for Waymo dataset: %d -> %d' % (len(waymo_infos), len(new_infos)))
            waymo_infos = new_infos

        self.infos.extend(waymo_infos[:])
        self.logger.info('Total skipped info %s' % num_skipped_infos)
        self.logger.info('Total samples for Waymo dataset: %d' % (len(waymo_infos)))
        
        self.info_pool = {}
        for index, info in enumerate(self.infos):
            pc_info = info['point_cloud']
            sequence_id = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']
            self.info_pool[(sequence_id, sample_idx)] = info
        
        #with open('/home/ubuntu/projects/PCPerception/data/waymo/seq_nums.pkl', 'rb') as fseq:
        #    seq_num_dict = pickle.load(fseq)
        
        if  self.use_only_samples_with_seg_labels and self.split=='test':
            testing_set_frame_file = '/home/ubuntu/projects/PCPerception/data/waymo/ImageSets/3d_semseg_test_set_frames.txt'
            test_names = [x.rstrip() for x in (open(testing_set_frame_file, 'r').readlines())]
            new_infos = [info for info in self.infos if ("%s,%s"%(info['metadata']['context_name'], info['metadata']['timestamp_micros']) in test_names)]
            self.logger.info(f'Getting samples with valid test timestamps {len(self.infos)} -> {len(new_infos)}')
            self.infos = new_infos 

        if self.use_only_samples_with_seg_labels and self.split!='test':
            new_infos = [info for info in self.infos if info['annos'].get('seg_label_path', None) is not None]
            new_infos = [info for info in new_infos if '_propseg.npy' not in info['annos'].get('seg_label_path', None)]
            
            if not self.set_frames: 
                new_infos = [info for info in new_infos if (info['point_cloud']['sample_idx'] >= (self.num_sweeps - 1)*self.sweep_interval)]
            else:
                new_infos = [info for info in new_infos if (info['point_cloud']['sample_idx'] + np.array(self.sampled_frames).min() >=0 and \
                        info['point_cloud']['sample_idx'] + np.array(self.sampled_frames).max() < seq_num_dict[info['frame_id'][:-4]])]
            

            self.logger.info(f'Dropping samples without segmentation labels {len(self.infos)} -> {len(new_infos)}')
            
            #self.refinement_data_dir = 'dump_first_stage/tmp_ohem_weight_30'
            #new_infos = [info for info in new_infos if os.path.exists(os.path.join(self.refinement_data_dir, \
            #    info['frame_id'][:-4], info['frame_id'][-3:]+'.pkl'))]

            #self.logger.info(f'Dropping samples without refinement dump path -> {len(new_infos)}')

            new_infos = [info for info in new_infos if info['annos']]
            
            if self.filtered_class:
                new_infos = [info for info in new_infos if self.filtered_class in info['annos']['name']]
                print('Filter frames for class %s'%(self.filtered_class))

            self.infos = new_infos

        if self.more_cls5 and self.training:
            with open('../data/waymo/cls5.txt', 'r') as fin:
                frame_ids = [line.strip() for line in fin.readlines()]
            new_infos = [info for info in self.infos if info['frame_id'] in frame_ids]

        if self.dataset_cfg.SAMPLED_INTERVAL[mode] > 1:
            sampled_waymo_infos = []
            for k in range(0, len(self.infos), self.dataset_cfg.SAMPLED_INTERVAL[mode]):
                sampled_waymo_infos.append(self.infos[k])
            self.infos = sampled_waymo_infos
            self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
        
        if self.more_cls5 and self.training:
            self.logger.info(f'repeating {len(new_infos)} scenes for cls 5')
            self.infos += new_infos

    def load_data_to_shared_memory(self):
        """Copy this rank's share of point clouds and segmentation labels into ``/dev/shm``.

        The first ``shared_memory_file_limit`` infos are split across ranks (``rank::num_gpus``).
        For each one, ``<sequence>___<idx>`` (from ``get_lidar``) and
        ``<sequence>___seglabel___<idx>`` (from ``get_seg_label``) are created unless a file with
        that name already exists in ``/dev/shm``. Ranks are synchronised afterwards when running
        with more than one process.
        """
        self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')

        cur_rank, num_gpus = common_utils.get_dist_info()
        all_infos = self.infos[:self.shared_memory_file_limit] \
            if self.shared_memory_file_limit < len(self.infos) else self.infos
        cur_infos = all_infos[cur_rank::num_gpus]
        for info in cur_infos:
            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']

            sa_key = f'{sequence_name}___{sample_idx}'
            if not os.path.exists(f"/dev/shm/{sa_key}"):
                points = self.get_lidar(sequence_name, sample_idx)
                common_utils.sa_create(f"shm://{sa_key}", points)

            sa_key = f'{sequence_name}___seglabel___{sample_idx}'
            if not os.path.exists(f"/dev/shm/{sa_key}"):
                seg_labels = self.get_seg_label(sequence_name, sample_idx)
                common_utils.sa_create(f"shm://{sa_key}", seg_labels)

        # `dist.barrier()` needs an initialised process group; only synchronise when running
        # distributed (same guard as in `clean_shared_memory`).
        if num_gpus > 1:
            dist.barrier()
        self.logger.info('Training data has been saved to shared memory')

    def clean_shared_memory(self):
        """Remove this rank's shared-memory copies of points and segmentation labels.

        Uses the same partitioning as ``load_data_to_shared_memory``: the first
        ``shared_memory_file_limit`` infos, strided by rank. For each info, the entries
        ``<sequence>___<idx>`` and ``<sequence>___seglabel___<idx>`` are deleted if they exist
        in ``/dev/shm``. When more than one process is running, ranks then meet at a barrier.
        """
        self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')

        rank, world_size = common_utils.get_dist_info()
        limit = self.shared_memory_file_limit
        limited_infos = self.infos if limit >= len(self.infos) else self.infos[:limit]

        for info in limited_infos[rank::world_size]:
            seq = info['point_cloud']['lidar_sequence']
            idx = info['point_cloud']['sample_idx']
            # Point cloud first, then segmentation labels.
            for key in (f'{seq}___{idx}', f'{seq}___seglabel___{idx}'):
                if os.path.exists(f"/dev/shm/{key}"):
                    SharedArray.delete(f"shm://{key}")

        if world_size > 1:
            dist.barrier()
        self.logger.info('Training data has been deleted from shared memory')

    @staticmethod
    def check_sequence_name_with_all_version(sequence_file):
        if not sequence_file.exists():
            found_sequence_file = sequence_file
            for pre_text in ['training', 'validation', 'testing']:
                if not sequence_file.exists():
                    temp_sequence_file = Path(str(sequence_file).replace('segment', pre_text + '_segment'))
                    if temp_sequence_file.exists():
                        found_sequence_file = temp_sequence_file
                        break
            if not found_sequence_file.exists():
                found_sequence_file = Path(str(sequence_file).replace('_with_camera_labels', ''))
            if found_sequence_file.exists():
                sequence_file = found_sequence_file
        return sequence_file

    def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1):
        """Generate per-frame infos for every sequence of ``self.sample_sequence_list``.

        Each sequence file (resolved with ``check_sequence_name_with_all_version`` under
        ``raw_data_path``) is processed by ``waymo_utils.process_single_sequence`` in a
        ``multiprocessing.Pool`` of ``num_workers`` processes.

        Returns:
            A flat list with the infos of all sequences, in sequence order.
        """
        from functools import partial
        from . import waymo_utils
        num_sequences = len(self.sample_sequence_list)
        print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
              % (sampled_interval, num_sequences))

        worker = partial(
            waymo_utils.process_single_sequence,
            save_path=save_path, sampled_interval=sampled_interval, has_label=has_label
        )
        sequence_files = [
            self.check_sequence_name_with_all_version(raw_data_path / name)
            for name in self.sample_sequence_list
        ]

        with multiprocessing.Pool(num_workers) as pool:
            progress = tqdm(pool.imap(worker, sequence_files), total=len(sequence_files))
            per_sequence_infos = list(progress)

        return [info for sequence_infos in per_sequence_infos for info in sequence_infos]

    def get_lidar(self, sequence_name, sample_idx):
        lidar_file = self.data_path / sequence_name / ('%04d.npy' % sample_idx)
        point_features = np.load(lidar_file)  # (N, 8): [x, y, z, intensity, elongation, range, rimage_w, rimage_h]

        points_all = point_features[:, [0,1,2,3,4,5,6,7]] # [x, y, z, intensity, elongation, range, rimage_w, rimage_h]
        points_all[:, 3] = np.tanh(points_all[:, 3])
        points_all[:, 5] /= 75.
        points_all[:, 7] *= 64
        points_all[:, 6] *= 2650
        return points_all

    def get_seg_label(self, sequence_name, sample_idx, ignore=False):
        seg_file = str(self.data_path / sequence_name / ('%04d_seg.npy' % sample_idx))
        if not os.path.exists(seg_file):
            seg_file = seg_file.replace('_seg.npy', '_propseg.npy')
        
        seg_labels = np.load(seg_file)  # (N, 2): [instance_label, segmentation_label]
        if ('_propseg.npy' in seg_file) and self.ignore_prop_labels:
            seg_labels = np.zeros_like(seg_labels)
        
        if ignore: 
            seg_labels = np.zeros_like(seg_labels)

        return seg_labels

    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.infos) * self.total_epochs

        return len(self.infos)

    def load_data(self, info, ignore=False):
        """Load one frame and split it into point-, scene- and object-wise dicts.

        Args:
            info: one entry of ``self.infos``.
            ignore: forwarded to ``get_seg_label``; if True the loaded labels are zeroed.

        Returns:
            ``dict(point_wise=..., scene_wise=..., object_wise=...)``. ``object_wise`` is empty
            when ``info`` has no ``annos``.

        Raises:
            NotImplementedError: if shared-memory loading is enabled (that read path is disabled).
        """
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']

        if self.use_shared_memory:
            # This path used to be disabled with `assert False`. Under `python -O` the assert is
            # stripped and execution fell into half-finished code. Fail explicitly instead.
            raise NotImplementedError('Reading samples from shared memory is disabled in load_data')

        points = self.get_lidar(sequence_name, sample_idx)
        if self.load_seg:
            if self.split == 'test':
                # Placeholder labels: zeros, one row per point of the first lidar.
                seg_labels = np.zeros((info['num_points_of_each_lidar'][0], 2), np.int32)
            else:
                seg_labels = self.get_seg_label(sequence_name, sample_idx, ignore=ignore)

        points = points.astype(np.float32)

        point_wise_dict = dict(
            point_xyz=points[:, :3],
            point_feat=points[:, 3:-2],
            point_rimage_h=points[:, -1].astype(np.int64),
        )
        if self.load_seg:
            point_wise_dict['segmentation_label'] = seg_labels[:, 1]
            point_wise_dict['instance_label'] = seg_labels[:, 0]

        # Optional precomputed background-instance labels; __getitem__ handles their absence.
        ins_back_file = str(self.data_path / sequence_name / ('%04d_ins_back_13.npy' % sample_idx))
        if os.path.exists(ins_back_file):
            try:
                point_wise_dict['instance_label_back'] = np.load(ins_back_file)
            except Exception as err:  # best-effort cache: report instead of silently ignoring
                if self.logger is not None:
                    self.logger.warning(f'Failed to load {ins_back_file}: {err}')

        if self.drop_points_by_lidar_index is not None:
            # Assumes the points of the lidars are stored back to back, in the order of
            # `num_points_of_each_lidar`. The offset must advance past every lidar, kept or
            # dropped; otherwise every kept range would start at index 0.
            keep_index_list = []
            offset = 0
            for i, num_points in enumerate(info['num_points_of_each_lidar']):
                if i not in self.drop_points_by_lidar_index:
                    keep_index_list.append(np.arange(offset, offset + num_points))
                offset += num_points
            if keep_index_list:
                lidar_point_indices = np.concatenate(keep_index_list, axis=0)
            else:
                lidar_point_indices = np.zeros(0, dtype=np.int64)
            point_wise_dict = common_utils.filter_dict(point_wise_dict, lidar_point_indices)

        if self.drop_points_by_seg_len:
            # Keep only the first len(segmentation_label) points.
            seg_len = np.arange(point_wise_dict['segmentation_label'].shape[0])
            point_wise_dict = common_utils.filter_dict(point_wise_dict, seg_len)

        scene_wise_dict = dict(
            frame_id=info['frame_id'],
            pose=info['pose'].reshape(4, 4),
            timestamp=str(info['metadata']['timestamp_micros']),
            context_name=info['metadata']['context_name']
        )

        if 'top_lidar_pose' in info['metadata']:
            top_lidar_pose = info['metadata']['top_lidar_pose'][4].reshape(4, 4)
            top_lidar_origin = top_lidar_pose[:3, 3]
            scene_wise_dict['top_lidar_origin'] = top_lidar_origin

        if 'annos' in info:
            annos = info['annos']
            annos = common_utils.drop_info_with_name(annos, name='unknown')
            annos = common_utils.drop_info_with_name(annos, name='Sign')

            if self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False):
                gt_boxes_lidar = box_utils.boxes3d_kitti_fakelidar_to_lidar(annos['gt_boxes_lidar'])
            else:
                gt_boxes_lidar = annos['gt_boxes_lidar']

            if self.training and self.dataset_cfg.get('FILTER_EMPTY_BOXES_FOR_TRAIN', False):
                mask = (annos['num_points_in_gt'] > 0)  # filter empty boxes
                annos['name'] = annos['name'][mask]
                gt_boxes_lidar = gt_boxes_lidar[mask]
                annos['num_points_in_gt'] = annos['num_points_in_gt'][mask]
                annos['obj_ids'] = annos['obj_ids'][mask]
                if 'transform' in annos:
                    annos['transform'] = annos['transform'][mask]

            point_masks = roiaware_pool3d_utils.points_in_boxes_cpu(point_wise_dict['point_xyz'][:, :3],
                                                                    gt_boxes_lidar).sum(0)
            point_wise_dict['is_foreground'] = point_masks > 0
            object_wise_dict = dict(
                gt_box_cls_label=annos['name'].astype(str),
                gt_box_attr=gt_boxes_lidar,
                augmented=np.zeros(annos['name'].shape[0], dtype=bool),
                obj_ids=annos['obj_ids'],
                num_points_in_gt=annos['num_points_in_gt'],
            )
        else:
            object_wise_dict = {}

        if self.use_spherical_resampling:
            point_wise_dict = self.spherical_resampling(point_wise_dict)
        else:
            point_wise_dict.pop('point_rimage_h')

        if self.filter_foreground:
            pred_label = np.load(f'{self.pred_label_path}/{sequence_name}/{sample_idx:03d}_pred.npy').astype(np.int64)
            # Drop points whose *predicted* label is one of `filter_class`.
            mask = ~np.isin(pred_label, self.filter_class)
            point_wise_dict = common_utils.filter_dict(point_wise_dict, mask)

        input_dict = dict(
            point_wise=point_wise_dict,
            scene_wise=scene_wise_dict,
            object_wise=object_wise_dict,
        )

        return input_dict

    def __getitem__(self, index, sweeping=False, mix3d=False, ins_cur=None):
        """Return the processed sample at ``index``.

        Args:
            index: sample index; taken modulo ``len(self.infos)`` when all epochs are merged.
            sweeping, mix3d, ins_cur: unused; kept for interface compatibility.

        Returns:
            The dict produced by ``self.prepare_data``, with ``scene_wise['num_sweeps']`` set.
        """
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.infos)

        info = copy.deepcopy(self.infos[index])
        input_dict = self.load_data(info)
        point_wise = input_dict['point_wise']

        if self.back_ins and 'instance_label_back' not in point_wise:
            point_wise['instance_label_back'] = self._compute_instance_label_back(point_wise)

        if 'instance_label_back' not in point_wise:
            point_wise['instance_label_back'] = point_wise['instance_label'].copy()

        data_dicts = [input_dict]
        num_obj = input_dict['object_wise']['gt_box_attr'].shape[0]

        input_dict = dict(
            point_wise=dict(common_utils.concat_dicts([dd['point_wise'] for dd in data_dicts])),
            object_wise=dict(common_utils.stack_dicts([dd['object_wise'] for dd in data_dicts],
                                                      pad_to_size=num_obj)),
            scene_wise=dict(common_utils.stack_dicts([dd['scene_wise'] for dd in data_dicts])),
        )

        # stack_dicts presumably adds a leading per-frame axis (one entry per element of
        # `data_dicts`); merge it with the object axis. Only one frame is loaded here, so use
        # len(data_dicts): `self.num_sweeps` made this reshape fail whenever num_sweeps > 1.
        num_frames = len(data_dicts)
        for key, val in input_dict['object_wise'].items():
            input_dict['object_wise'][key] = val.reshape(num_frames * num_obj, *val.shape[2:])

        data_dict = self.prepare_data(data_dict=input_dict)

        data_dict['scene_wise']['num_sweeps'] = self.num_sweeps
        return data_dict

    @staticmethod
    def _compute_instance_label_back(point_wise):
        """Cluster points of semantic classes 1-13 into ``instance_label_back`` ids.

        For each class with at least 10 points, every point is linked to those of its 10 nearest
        same-class neighbours that are closer than 1.0 (in ``point_xyz`` units). Each connected
        component with at least 10 points receives a new id, starting at 1; all other points
        stay 0.
        """
        # Author's class reference: 8: sign, 9: traffic light, 10: pole, 11: cone, 12: bicycle,
        # 13: motorcycle, 14: building, 15: vegetation, 16: tree trunk
        instance_label_back = np.zeros(point_wise['instance_label'].shape, np.int32)
        seg_label = point_wise['segmentation_label']
        points = point_wise['point_xyz']
        # Ids start at 1 (0 = no instance). The old `zeros.max() + 1` failed on empty frames.
        cur_num = 1
        for cls_id in range(1, 14):
            seg_ids = (seg_label == cls_id).nonzero()[0]
            if len(seg_ids) < 10:
                continue
            masked_points = points[seg_ids, :]
            num_points = masked_points.shape[0]

            tree = NN(n_neighbors=10).fit(masked_points)
            dists, indices = tree.kneighbors(masked_points)
            e0 = np.arange(num_points).repeat(10)
            e1 = indices.reshape(-1)
            mask = dists.reshape(-1) < 1.0
            e0, e1 = e0[mask], e1[mask]

            graph = csr_matrix((np.ones_like(e0), (e0, e1)), shape=(num_points, num_points))
            _, labels = connected_components(graph, directed=False)
            for lb in np.unique(labels):
                ids = seg_ids[labels == lb]
                if len(ids) < 10:
                    continue
                instance_label_back[ids] = cur_num
                cur_num += 1
        return instance_label_back

    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
        """
        Args:
            batch_dict:
                frame_id:
            pred_dicts: list of pred_dicts
                pred_boxes: (N, 7), Tensor
                pred_scores: (N), Tensor
                pred_labels: (N), Tensor
            class_names:
            output_path:
        Returns:
        """

        def generate_single_sample_dict(cur_dict, output_path=None):
            frame_id = str(cur_dict['scene_wise']['frame_id'].reshape(-1)[-1])
            sequence_id, sample_idx = frame_id[:-4], int(frame_id[-3:])
            timestamp = str(cur_dict['scene_wise']['timestamp'].reshape(-1)[-1])
            context_name = str(cur_dict['scene_wise']['context_name'].reshape(-1)[-1])
            pred_dict = dict(
                object_wise=dict(),
                point_wise=dict(),
                scene_wise=dict(
                    frame_id=frame_id,
                    sequence_id=sequence_id,
                    sample_idx=sample_idx,
                    timestamp=timestamp,
                    context_name=context_name,
                ),
                voxel_wise=dict(),
            )

            if 'pred_box_attr' in cur_dict['object_wise']:
                object_wise_dict = cur_dict['object_wise']

                pred_scores = object_wise_dict['pred_box_scores'].cpu().numpy()
                pred_boxes = object_wise_dict['pred_box_attr'].cpu().numpy()
                pred_labels = object_wise_dict['pred_box_cls_label'].cpu().numpy()

                pred_dict['object_wise'].update(
                    dict(
                        box_scores=pred_scores,
                        box_attr=pred_boxes,
                        box_name=np.array(class_names)[pred_labels - 1],
                    ))

            if 'pred_segmentation_label' in cur_dict['point_wise']:

                point_wise_dict = cur_dict['point_wise']

                # propagate segmentation from predicted points to gt points
                point_xyz = np.load(f'../data/waymo/{self.processed_data_tag}/{sequence_id}/{sample_idx:04d}.npy')

                path = f'../data/waymo/waymo_processed_data_v0_5_0/{sequence_id}/{sample_idx:04d}_seg.npy'
                if not os.path.exists(path):
                    path = f'../data/waymo/waymo_processed_data_v0_5_0/{sequence_id}/{sample_idx:04d}_propseg.npy'
                
                if not self.split=='test':
                    segmentation_label = np.load(path)[:, 1]
                    point_xyz = point_xyz[:segmentation_label.shape[0], :3]
                else:
                    segmentation_label = np.zeros(point_xyz.shape[0])
                
                tree = NN(n_neighbors=1).fit(point_wise_dict['point_xyz'].detach().cpu().numpy())
                dists, indices = tree.kneighbors(point_xyz[:, :3])
                pred_segmentation_label = point_wise_dict['pred_segmentation_label'].detach().cpu().numpy()[indices[:, 0]]

                pred_dict['point_wise'].update(
                    dict(
                        pred_segmentation_label=pred_segmentation_label,
                        point_xyz = point_xyz,
                    ))
                
                if 'instance_label_back' in point_wise_dict.keys():
                    instance_label_back = point_wise_dict['instance_label_back'].cpu().numpy()[indices[:, 0]]
                    pred_dict['point_wise'].update(
                        dict(
                            instance_label_back=instance_label_back,
                    ))
                
                if 'gt_binary_label' in point_wise_dict.keys():
                    gt_binary_label = point_wise_dict['gt_binary_label'][indices[:, 0]]
                    pred_binary_logits = point_wise_dict['pred_binary_logits'][indices[:, 0]]
                    
                    n_hist = 100
                    binary_hist = torch.zeros(n_hist)
                    binary_label_hist = torch.zeros(n_hist)

                    gt_binary_label = point_wise_dict['gt_binary_label'][indices[:, 0]].cpu()
                    pred_binary_logits = point_wise_dict['pred_binary_logits'][indices[:, 0]].cpu()
                    interval = float(1/n_hist)

                    for ii in range(n_hist):
                        if ii<n_hist-1:

                            mask_ii = (pred_binary_logits>=ii*interval) & (pred_binary_logits<(ii+1)*interval)
                        else:
                            mask_ii = (pred_binary_logits>=ii*interval)
                        if mask_ii.sum()==0:
                            continue
                        binary_hist[ii] +=  mask_ii.sum()
                        binary_label_hist[ii] += gt_binary_label[mask_ii].sum()


                    if 0:
                    # if gt_binary_label.max()>0:

                        ap_eval = torchmetrics.AveragePrecision(task="binary")
                        ap = ap_eval(pred_binary_logits, gt_binary_label).detach().cpu()
                        auroc_eval = torchmetrics.AUROC(task="binary")    
                        auroc = auroc_eval(pred_binary_logits, gt_binary_label).detach().cpu()
                        print(ap,'$$', auroc) 
                    else:
                        ap = -1 
                        auroc = -1 

                if self.filter_foreground:
                    pred_label = np.load(f'{self.pred_label_path}/{sequence_id}/{sample_idx:03d}_pred.npy').astype(np.int64)
                    mask = np.zeros(pred_label.shape[0], dtype=bool)
                    for cls in self.filter_class:
                        mask[pred_label == cls] = 1
                    if self.split!='test': 
                        pred_segmentation_label[mask] = pred_label[mask]

                for ignore_index in self.ignore_index:
                    if self.split!='test':
                        pred_segmentation_label[segmentation_label == ignore_index] = ignore_index

                point_wise_dict['pred_segmentation_label'] = torch.from_numpy(pred_segmentation_label)

                # compute statistics
                ups = torch.zeros(23, dtype=torch.long)
                downs = torch.zeros(23, dtype=torch.long)
                for i in range(23):
                    ups[i] = ((segmentation_label == i) & (pred_segmentation_label == i)).sum()
                    downs[i] = ((segmentation_label == i) | (pred_segmentation_label == i)).sum()
                
                # build class 5 statistics 
                pred_distribution = torch.zeros(23, dtype=torch.long) # pred=5, gt labels 
                gt_distribution = torch.zeros(23, dtype=torch.long) # gt=5, pred labels 
                
                point_number_stats = torch.zeros(23, dtype=torch.long)
                ins_number_stats = torch.zeros(23, dtype=torch.long) 
                if 'ins_cls_label' in cur_dict['scene_wise'].keys():
                    for i in range(23):
                        point_number_stats[i] = (segmentation_label==i).sum()
                        ins_number_stats[i] = (cur_dict['scene_wise']['ins_cls_label']==(i-1)).int().sum()

        
                gt_labels_5 = segmentation_label[pred_segmentation_label == 5] 
                pred_labels_5 = pred_segmentation_label[segmentation_label==5]

                for i in range(23):
                    pred_distribution[i] = (pred_labels_5==i).sum() 
                    gt_distribution[i] = (gt_labels_5==i).sum()

                if output_path is not None:
                    pred_labels = point_wise_dict['pred_segmentation_label'].detach().to(torch.uint8).cpu()
                    os.makedirs(output_path / sequence_id, exist_ok=True)
                    path = str(output_path / sequence_id / f"{sample_idx:03d}_pred.npy")
                    np.save(path, pred_labels)

                pred_dict['scene_wise'].update(
                    ups=ups.detach().cpu(),
                    downs=downs.detach().cpu(),
                    pred_distribution=pred_distribution.detach().cpu(),
                    gt_distribution=gt_distribution.detach().cpu(),
                    point_number_stats=point_number_stats.detach().cpu(),
                    ins_number_stats=ins_number_stats.detach().cpu(),
                )
                if 'gt_binary_label' in point_wise_dict.keys():
                    pred_dict['scene_wise'].update(
                            binary_hist=binary_hist,
                            binary_label_hist=binary_label_hist,

                            #ap=ap,
                            #auroc = auroc,
                        )

                if 'pred_segmentation_label_refine' in cur_dict['point_wise']:
                    tree = NN(n_neighbors=1).fit(point_wise_dict['point_xyz_refine'].detach().cpu().numpy())
                    dists, indices = tree.kneighbors(point_xyz)

                    pred_segmentation_label_refine = point_wise_dict['pred_segmentation_label_refine'].detach().cpu().numpy()[indices[:, 0]]

                    for ignore_index in self.ignore_index:
                        pred_segmentation_label_refine[segmentation_label == ignore_index] = ignore_index

                    point_wise_dict['pred_segmentation_label_refine'] = torch.from_numpy(pred_segmentation_label_refine)

                    # compute statistics
                    ups = torch.zeros(23, dtype=torch.long)
                    downs = torch.zeros(23, dtype=torch.long)
                    for i in range(23):
                        ups[i] = ((segmentation_label == i) & (pred_segmentation_label_refine == i)).sum()
                        downs[i] = ((segmentation_label == i) | (pred_segmentation_label_refine == i)).sum()
                    
                    # build class 5 statistics
                    pred_distribution = torch.zeros(23, dtype=torch.long) # pred=5, gt labels
                    gt_distribution = torch.zeros(23, dtype=torch.long) # gt=5, pred labels

                    gt_labels_5 = segmentation_label[pred_segmentation_label_refine==5]
                    pred_labels_5 = pred_segmentation_label_refine[segmentation_label==5]

                    for i in range(23):
                        pred_distribution[i] = (pred_labels_5==i).sum()
                        gt_distribution[i] = (gt_labels_5==i).sum()

                    pred_dict['scene_wise'].update(
                        ups_refine=ups.detach().cpu(),
                        downs_refine=downs.detach().cpu(),
                        pred_distribution_refine = pred_distribution.detach().cpu(),
                        gt_distribution_refine=gt_distribution.detach().cpu(),
                    )


            if 'voxel_wise' in cur_dict.keys() and 'pred_bev_segmentation_label' in cur_dict['voxel_wise']:
                voxel_wise_dict = cur_dict['voxel_wise']
                pred_segmentation_label = voxel_wise_dict['pred_bev_segmentation_label']
                segmentation_label = voxel_wise_dict['gt_bev_segmentation_label']

                # compute statistics
                ups = torch.zeros(4, dtype=torch.long)
                downs = torch.zeros(4, dtype=torch.long)
                for i in range(4):
                    ups[i] = ((segmentation_label == i) & (pred_segmentation_label == i)).sum()
                    downs[i] = ((segmentation_label == i) | (pred_segmentation_label == i)).sum()

                pred_dict['voxel_wise'].update(
                    ups=ups.detach().cpu(),
                    downs=downs.detach().cpu(),
                )
            
            if 'gt_mask' in cur_dict['object_wise']:
                object_wise_dict = cur_dict['object_wise']
                pred_mask_label = object_wise_dict['pred_mask']
                mask_label = object_wise_dict['gt_mask']

                # compute statistics
                ups = torch.zeros(2, dtype=torch.long)
                downs = torch.zeros(2, dtype=torch.long)
                for i in range(2):
                    ups[i] = ((mask_label == i) & (pred_mask_label == i)).sum()
                    downs[i] = ((mask_label == i) | (pred_mask_label == i)).sum()

                pred_dict['scene_wise'].update(
                    ups_mask=ups.detach().cpu(),
                    downs_mask=downs.detach().cpu(),
                )

                pred_cls_label = object_wise_dict['pred_box_cls']
                cls_label = object_wise_dict['gt_box_cls']

                # compute statistics
                ups = torch.zeros(8, dtype=torch.long)
                downs = torch.zeros(8, dtype=torch.long)
                for i in range(8):
                    ups[i] = ((cls_label == i) & (pred_cls_label == i)).sum()
                    downs[i] = ((cls_label == i) | (pred_cls_label == i)).sum()

                pred_dict['scene_wise'].update(
                    ups_box_cls=ups.detach().cpu(),
                    downs_box_cls=downs.detach().cpu(),
                )
                if 'pred_box_points_seg_label' in object_wise_dict.keys(): 
                    pred_cls_label = object_wise_dict['pred_box_points_seg_label']
                    cls_label = object_wise_dict['gt_box_points_seg_label']

                    # compute statistics
                    ups = torch.zeros(8, dtype=torch.long)
                    downs = torch.zeros(8, dtype=torch.long)
                    for i in range(8):
                        ups[i] = ((cls_label == i) & (pred_cls_label == i)).sum()
                        downs[i] = ((cls_label == i) | (pred_cls_label == i)).sum()

                    pred_dict['scene_wise'].update(
                        ups_box_points_cls=ups.detach().cpu(),
                        downs_box_points_cls=downs.detach().cpu(),
                    )

                if 'gt_box_cls_num' in object_wise_dict.keys():
                    pred_dict['scene_wise'].update(
                        gt_box_cls_num=object_wise_dict['gt_box_cls_num'].detach().cpu(),
                        gt_box_points_num=object_wise_dict['gt_box_points_num'].detach().cpu(),
                        detected_box_cls_num=object_wise_dict['detected_box_cls_num'].detach().cpu(),
                        detected_box_points_num=object_wise_dict['detected_box_points_num'].detach().cpu(),
                    )   
            return pred_dict

        annos = []

        if (output_path is not None) and (not os.path.exists(output_path)):
            os.makedirs(output_path, exist_ok=True)
        for index, box_dict in enumerate(pred_dicts):
            single_pred_dict = generate_single_sample_dict(box_dict, output_path=output_path)
            annos.append(single_pred_dict)

        return annos


    def evaluation(self, pred_dicts, box_class_names, **kwargs):
        """Compute the metrics selected by ``self.evaluation_list``.

        Sections (each appends a report string to ``res_strs`` and, except where
        noted, a metric dict to ``res_dicts``):

        * ``'box'``: Waymo 3D detection metrics against ``self.infos[i]['annos']``.
        * ``'bev'``: per-class IoU from ``pred_dict['voxel_wise']['ups'/'downs']``.
        * ``'refine'``: IoUs of the box-points mask, the box classes and (when
          present) the box-point classes from ``pred_dict['scene_wise']``, plus a
          per-class box/point count report when ``'gt_box_cls_num'`` is present
          (that count report is string-only: no dict is appended for it).
        * ``'seg'``: point-wise segmentation IoUs, FG/BG IoUs, class-5 confusion
          histograms, per-class point/instance counts, optional binary AP/AUROC
          and optional refined IoUs.

        Per-class IoU is ``ups / max(downs, 1)``; every ``mIoU*`` value averages
        the per-class IoUs excluding class 0.

        Args:
            pred_dicts: per-frame prediction dicts carrying the intersection
                (``ups*``) / union (``downs*``) counters and statistics read below.
            box_class_names: class names passed to the detection evaluator.
            **kwargs: unused.

        Returns:
            tuple: ``(res_strs, res_dicts)``. If ``self.infos[0]`` has no
            ``'annos'`` key, returns ``('No ground-truth boxes for evaluation', {})``.
        """
        if 'annos' not in self.infos[0].keys():
            return 'No ground-truth boxes for evaluation', {}

        def kitti_eval(eval_det_annos, eval_gt_annos):
            # KITTI-style evaluation helper; not called below.
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            map_name_to_kitti = {
                'Vehicle': 'Car',
                'Pedestrian': 'Pedestrian',
                'Cyclist': 'Cyclist',
                'Sign': 'Sign',
                'Car': 'Car'
            }
            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            # The class list is the ``box_class_names`` argument (the former
            # ``class_names`` reference was undefined and raised NameError).
            kitti_class_names = [map_name_to_kitti[x] for x in box_class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            eval = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = eval.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=box_class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

            return ap_result_str, ap_dict

        def _sum_stats(keys, group='scene_wise', optional=False):
            """Element-wise sum of ``pred_dict[group][key]`` over all frames, per key.

            The first frame's value is cloned so the inputs are never mutated.
            With ``optional=True`` frames whose ``group`` lacks ``keys[0]`` are
            skipped; a key that is never seen yields ``None``.
            """
            totals = [None] * len(keys)
            for pred_dict in pred_dicts:
                stats = pred_dict[group]
                if optional and keys[0] not in stats:
                    continue
                for j, key in enumerate(keys):
                    if totals[j] is None:
                        totals[j] = stats[key].clone()
                    else:
                        totals[j] += stats[key]
            return totals

        def _iou_report(total_ups, total_downs, label, key_fmt, mean_key, header='\n'):
            """Build the per-class IoU report text and metric dict.

            Text: ``header``, one ``'IoU for <label> <cls> <iou>'`` line per class,
            then ``'<mean_key>=<mean>'``. Dict: ``key_fmt.format(cls)`` per class,
            then ``mean_key`` (mean over classes 1..N-1).
            """
            result_str = header
            iou_dict = {}
            ious = []
            for cls in range(total_ups.shape[0]):
                iou = total_ups[cls] / np.clip(total_downs[cls], 1, None)
                result_str += f'IoU for {label} {cls} {iou:.4f} \n'
                iou_dict[key_fmt.format(cls)] = iou
                ious.append(iou)
            ious = np.array(ious).reshape(-1)[1:]
            iou_dict[mean_key] = ious.mean()
            result_str += f'{mean_key}={ious.mean():.4f} \n'
            return result_str, iou_dict

        res_strs = []
        res_dicts = []

        if 'box' in self.evaluation_list:
            eval_gt_annos = []
            for i in range(self.__len__()):
                annos = self.infos[i]['annos']
                eval_gt_annos.append(
                    dict(
                        difficulty=annos['difficulty'],
                        num_points_in_gt=annos['num_points_in_gt'],
                        name=annos['name'],
                        gt_boxes_lidar=annos['gt_boxes_lidar']
                    )
                )
            eval_det_annos = copy.deepcopy(pred_dicts)
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
            res_strs.append(ap_result_str)
            res_dicts.append(ap_dict)

        if 'bev' in self.evaluation_list:
            total_ups, total_downs = _sum_stats(['ups', 'downs'], group='voxel_wise')
            result_str, iou_dict = _iou_report(total_ups, total_downs, 'bev class', 'IoU_{}', 'mIoU_bev')
            res_strs.append(result_str)
            res_dicts.append(iou_dict)

        if 'refine' in self.evaluation_list:
            total_ups, total_downs = _sum_stats(['ups_mask', 'downs_mask'])
            result_str, iou_dict = _iou_report(total_ups, total_downs, 'box points mask class',
                                               'Box_points_mask_IoU_{}', 'mIoU_box_points_mask')
            res_strs.append(result_str)
            res_dicts.append(iou_dict)

            total_ups, total_downs = _sum_stats(['ups_box_cls', 'downs_box_cls'])
            result_str, iou_dict = _iou_report(total_ups, total_downs, 'box class',
                                               'Box_class_IoU_{}', 'mIoU_box_cls')
            res_strs.append(result_str)
            res_dicts.append(iou_dict)

            if 'ups_box_points_cls' in pred_dicts[0]['scene_wise'].keys():
                # Every class now contributes to the mean (previously only the last
                # class was collected, which made the mean NaN).
                total_ups, total_downs = _sum_stats(['ups_box_points_cls', 'downs_box_points_cls'])
                result_str, iou_dict = _iou_report(total_ups, total_downs, 'box points class',
                                                   'Box_class_points_IoU_{}', 'mIoU_box_points_cls')
                res_strs.append(result_str)
                res_dicts.append(iou_dict)

            if 'gt_box_cls_num' in pred_dicts[0]['scene_wise'].keys():
                (total_gt_cls_num, total_gt_points_num,
                 total_detected_cls_num, total_detected_points_num) = _sum_stats(
                    ['gt_box_cls_num', 'gt_box_points_num', 'detected_box_cls_num', 'detected_box_points_num'])
                seg_result_str = '\n'
                for cls in range(total_gt_cls_num.shape[0]):
                    seg_result_str += f'Statistics class {cls}: gt_box: {total_gt_cls_num[cls]}, detected_box: {total_detected_cls_num[cls]}, gt_points: {total_gt_points_num[cls]}, detected_points: {total_detected_points_num[cls]}\n'

                # String-only report: no matching entry is added to res_dicts.
                res_strs.append(seg_result_str)

        if 'seg' in self.evaluation_list:
            total_ups, total_downs = _sum_stats(['ups', 'downs'])
            total_pred_distribution, total_gt_distribution = _sum_stats(['pred_distribution', 'gt_distribution'])
            total_point_number_stats, total_ins_number_stats = _sum_stats(['point_number_stats', 'ins_number_stats'])
            total_binary_hist, total_binary_label = _sum_stats(['binary_hist', 'binary_label_hist'], optional=True)
            total_ups_refine, total_downs_refine = _sum_stats(['ups_refine', 'downs_refine'], optional=True)

            # Per-frame AP/AUROC are averaged only over frames with a valid (!= -1) score.
            total_ap, total_auroc, total_num_ap = None, None, 0
            for pred_dict in pred_dicts:
                scene_wise = pred_dict['scene_wise']
                if 'ap' not in scene_wise.keys() or scene_wise['ap'] == -1:
                    continue
                total_num_ap += 1
                if total_ap is None:
                    total_ap = scene_wise['ap'].clone()
                    total_auroc = scene_wise['auroc'].clone()
                else:
                    total_ap += scene_wise['ap']
                    total_auroc += scene_wise['auroc']

            seg_result_str, iou_dict = _iou_report(total_ups, total_downs, 'class', 'IoU_{}', 'mIoU')
            # 'FG' pools classes 1-7 and 'BG' classes 8+, as fixed by these slices.
            iou_dict['IoU_FG'] = total_ups[1:8].sum() / np.clip(total_downs[1:8].sum(), 1, None)
            iou_dict['IoU_BG'] = total_ups[8:].sum() / np.clip(total_downs[8:].sum(), 1, None)
            seg_result_str += f"IoU_FG={iou_dict['IoU_FG']:.4f} \n"
            seg_result_str += f"IoU_BG={iou_dict['IoU_BG']:.4f} \n"

            num_classes = total_ups.shape[0]

            def _count_line(counts):
                return ''.join('c=%d: %d, ' % (cls, counts[cls]) for cls in range(num_classes))

            # NOTE(review): the per-frame builder fills 'gt_distribution' with GT
            # labels of points predicted as 5 and 'pred_distribution' with predicted
            # labels of GT-5 points, which looks swapped w.r.t. the headers below —
            # confirm which side should change before relying on these numbers.
            seg_result_str += 'GT seg label is 5, pred label distribution: \n'
            seg_result_str += _count_line(total_gt_distribution)

            seg_result_str += '\n'
            seg_result_str += 'Pred seg label is 5, GT label distribution: \n'
            seg_result_str += _count_line(total_pred_distribution)

            seg_result_str += '\n'
            seg_result_str += 'point number for each class: \n'
            seg_result_str += _count_line(total_point_number_stats)

            seg_result_str += '\n'
            seg_result_str += 'ins number for each class: \n'
            seg_result_str += _count_line(total_ins_number_stats)

            if total_ap is not None:
                seg_result_str += '\n'
                seg_result_str += 'AP: %f\n' % (total_ap / total_num_ap)
                seg_result_str += 'AUROC: %f\n' % (total_auroc / total_num_ap)

            if total_binary_hist is not None:
                # Rebuild a (score, label) sample set from the per-bin histograms:
                # every point in bin ii gets the bin-centre score, and the first
                # ``total_binary_label[ii]`` of them are marked positive.
                num_samples = total_binary_hist.sum().int().item()
                preds = torch.zeros(num_samples)
                gts = torch.zeros(num_samples)
                total_binary_hist = total_binary_hist.int()
                total_binary_label = total_binary_label.int()
                n_hist = total_binary_hist.shape[0]
                interval = float(1 / n_hist)
                cnt = 0
                for ii in range(n_hist):
                    bin_count = int(total_binary_hist[ii])
                    if bin_count == 0:
                        continue
                    preds[cnt:cnt + bin_count] = (ii + 0.5) * interval
                    gts[cnt:cnt + int(total_binary_label[ii])] = 1.0
                    cnt += bin_count
                gts = gts.long()
                ap_eval = torchmetrics.AveragePrecision(task="binary")
                ap = ap_eval(preds, gts).detach().cpu()
                auroc_eval = torchmetrics.AUROC(task="binary")
                auroc = auroc_eval(preds, gts).detach().cpu()
                seg_result_str += '\n'
                seg_result_str += 'AP: %f\n' % (ap)
                seg_result_str += 'AUROC: %f\n' % (auroc)

            if total_ups_refine is not None:
                refine_str, refine_dict = _iou_report(total_ups_refine, total_downs_refine, 'class refined',
                                                      'IoU_{}_refined', 'mIoU_refined', header='')
                seg_result_str += refine_str
                iou_dict.update(refine_dict)
                iou_dict['IoU_FG_refined'] = total_ups_refine[1:8].sum() / np.clip(total_downs_refine[1:8].sum(), 1, None)
                iou_dict['IoU_BG_refined'] = total_ups_refine[8:].sum() / np.clip(total_downs_refine[8:].sum(), 1, None)
                seg_result_str += f"IoU_FG_refined={iou_dict['IoU_FG_refined']:.4f} \n"
                seg_result_str += f"IoU_BG_refined={iou_dict['IoU_BG_refined']:.4f} \n"

            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)

        return res_strs, res_dicts

    def create_groundtruth_database(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                    processed_data_tag=None):
        """Crop the points inside every GT box and build the GT-sampling database.

        Every ``sampled_interval``-th info from the pickle at ``info_path`` is
        processed. Vehicle boxes are kept only when the info index ``k`` is a
        multiple of 4 and Pedestrian boxes only when ``k`` is even (presumably to
        thin out the most frequent classes).

        Outputs, all under ``save_path``:
          * one ``.bin`` file per kept object holding its points, with xyz made
            relative to the box centre;
          * ``<tag>_waymo_dbinfos_<split>_sampled_<n>.pkl``: ``{class_name: [db_info, ...]}``;
          * ``<tag>_gt_database_<split>_sampled_<n>_global.npy``: all cropped points
            stacked, addressed by each db_info's ``global_data_offset``
            (skipped, with a message, if no points were collected).

        Args:
            info_path: path to the pickled info list.
            save_path: output directory (``pathlib.Path``).
            used_classes: class names to keep; ``None`` keeps all.
            split: split name used in the output file names.
            sampled_interval: stride over the info list.
            processed_data_tag: prefix used in the output file names.

        Note: points-in-boxes runs on the GPU (``.cuda()``), so CUDA is required
        whenever a frame has boxes.
        """
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d' % (processed_data_tag, split, sampled_interval))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d.pkl' % (processed_data_tag, split, sampled_interval))
        db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_global.npy' % (processed_data_tag, split, sampled_interval))
        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        point_offset_cnt = 0
        stacked_gt_points = []
        for k in range(0, len(infos), sampled_interval):
            print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            info = infos[k]

            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']
            points = self.get_lidar(sequence_name, sample_idx)

            annos = info['annos']
            names = annos['name']
            difficulty = annos['difficulty']
            gt_boxes = annos['gt_boxes_lidar']

            if k % 4 != 0 and len(names) > 0:
                mask = (names == 'Vehicle')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            if k % 2 != 0 and len(names) > 0:
                mask = (names == 'Pedestrian')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            num_obj = gt_boxes.shape[0]
            if num_obj == 0:
                continue

            # Per point: index of the box containing it, or a negative value.
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()

            for i in range(num_obj):
                filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
                filepath = database_save_path / filename
                # Boolean indexing copies, so the in-place shift leaves ``points`` intact.
                gt_points = points[box_idxs_of_pts == i]
                gt_points[:, :3] -= gt_boxes[i, :3]

                if (used_classes is None) or names[i] in used_classes:
                    # ndarray.tofile writes raw bytes, so the file must be binary.
                    with open(filepath, 'wb') as f:
                        gt_points.tofile(f)

                    db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                    db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                               'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                               'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

                    # it will be used if you choose to use shared memory for gt sampling
                    stacked_gt_points.append(gt_points)
                    db_info['global_data_offset'] = [point_offset_cnt, point_offset_cnt + gt_points.shape[0]]
                    point_offset_cnt += gt_points.shape[0]

                    all_db_infos.setdefault(names[i], []).append(db_info)
        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

        # it will be used if you choose to use shared memory for gt sampling
        if stacked_gt_points:
            np.save(db_data_save_path, np.concatenate(stacked_gt_points, axis=0))
        else:
            # np.concatenate raises on an empty list; report instead of crashing.
            print('No gt points were collected, skip saving %s' % db_data_save_path)
    
    def propagate_segmentation_labels(self, waymo_infos, save_path, num_workers=multiprocessing.cpu_count()):
        """Run ``waymo_utils.propagate_segmentation_labels`` once per sequence in parallel.

        Args:
            waymo_infos: info dicts; each one's ``['point_cloud']['lidar_sequence']``
                names the sequence it belongs to.
            save_path: forwarded unchanged to the per-sequence worker.
            num_workers: size of the process pool.

        Returns:
            list: the per-sequence result lists concatenated into a single list,
            in the order the sequences were dispatched.
        """
        from functools import partial
        from . import waymo_utils
        print('---------------Propagating Segmentation Labels------------------------')

        # Bind the shared arguments so the pool only needs to pass the sequence id.
        worker = partial(waymo_utils.propagate_segmentation_labels,
                         waymo_infos=waymo_infos, save_path=save_path)

        # Each distinct sequence is processed exactly once.
        unique_sequences = list(set(info['point_cloud']['lidar_sequence'] for info in waymo_infos))

        with multiprocessing.Pool(num_workers) as pool:
            progress = tqdm(pool.imap(worker, unique_sequences), total=len(unique_sequences))
            per_sequence_results = list(progress)

        return [entry for results in per_sequence_results for entry in results]

def create_waymo_infos(dataset_cfg, data_path, save_path,
                       raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
                       workers=min(16, multiprocessing.cpu_count())):
    """Build the Waymo info pickles and the ground-truth sampling database.

    The splits named by ``dataset_cfg.DATA_SPLIT['train']`` and
    ``dataset_cfg.DATA_SPLIT['test']`` are each written to
    ``save_path / '<processed_data_tag>_infos_<split>.pkl'``. Afterwards the
    ground-truth database is created from the train-split info file.

    Args:
        dataset_cfg: dataset config used to construct ``WaymoDataset``.
        data_path: dataset root; ``data_path / raw_data_tag`` holds the raw data.
        save_path: directory receiving the info pickles and the GT database.
        raw_data_tag: name of the raw-data sub-directory of ``data_path``.
        processed_data_tag: name of the processed-data sub-directory of
            ``save_path`` and prefix of the info file names.
        workers: worker count forwarded to ``WaymoDataset.get_infos``.
    """
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split = dataset_cfg.DATA_SPLIT['train']
    val_split = dataset_cfg.DATA_SPLIT['test']
    train_info_file = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))

    # (split name, output pickle, message printed once it is written), in processing order.
    jobs = (
        (train_split, train_info_file,
         '----------------Waymo info train file is saved to %s----------------'),
        (val_split, save_path / ('%s_infos_%s.pkl' % (processed_data_tag, val_split)),
         '----------------Waymo info val file is saved to %s----------------'),
    )
    raw_dir = data_path / raw_data_tag
    processed_dir = save_path / processed_data_tag

    # Info generation runs with CUDA_VISIBLE_DEVICES=-1 (no GPU exposed).
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('---------------Start to generate data infos---------------')

    for split, info_file, saved_msg in jobs:
        dataset.set_split(split)
        split_infos = dataset.get_infos(
            raw_data_path=raw_dir, save_path=processed_dir,
            num_workers=workers, has_label=True, sampled_interval=1
        )
        with open(info_file, 'wb') as fp:
            pickle.dump(split_infos, fp)
        print(saved_msg % info_file)

    print('---------------Start create groundtruth database for data augmentation---------------')
    # GPU 0 is exposed again for the ground-truth database step.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(
        info_path=train_info_file, save_path=save_path, split=train_split, sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
    )
    print('---------------Data preparation Done---------------')


def propagate_segmentation_labels(dataset_cfg, data_path, save_path,
                                  raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
                                  workers=min(16, multiprocessing.cpu_count())):
    """Build Waymo info pickles with propagated segmentation labels, then the GT database.

    For the train split (``DATA_SPLIT['train']``) and then the evaluation split
    (``DATA_SPLIT['test']``), the steps are:
    - generate infos with ``WaymoDataset.get_infos``;
    - pass them through ``WaymoDataset.propagate_segmentation_labels``;
    - pickle the result to ``save_path / '<processed_data_tag>_infos_<split>.pkl'``.

    These are the same file names ``create_waymo_infos`` writes. Finally the
    ground-truth database is created from the train-split file.

    Args:
        dataset_cfg: dataset config used to construct ``WaymoDataset``.
        data_path: dataset root; ``data_path / raw_data_tag`` holds the raw data.
        save_path: directory receiving the info pickles and the GT database.
        raw_data_tag: name of the raw-data sub-directory of ``data_path``.
        processed_data_tag: name of the processed-data sub-directory of
            ``save_path`` and prefix of the info file names.
        workers: worker count for info generation and label propagation.
    """
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split = dataset_cfg.DATA_SPLIT['train']
    val_split = dataset_cfg.DATA_SPLIT['test']
    processed_dir = save_path / processed_data_tag

    def build_split(split, saved_msg):
        # Generate, propagate and pickle the infos of one split; return the pickle path.
        out_file = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, split))
        dataset.set_split(split)
        raw_infos = dataset.get_infos(
            raw_data_path=data_path / raw_data_tag,
            save_path=processed_dir, num_workers=workers, has_label=True,
            sampled_interval=1
        )
        propagated = dataset.propagate_segmentation_labels(
            raw_infos, save_path=processed_dir, num_workers=workers
        )
        with open(out_file, 'wb') as fp:
            pickle.dump(propagated, fp)
        print(saved_msg % out_file)
        return out_file

    # Info generation and propagation run with CUDA_VISIBLE_DEVICES=-1 (no GPU exposed).
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('---------------Start to generate data infos---------------')

    train_filename = build_split(
        train_split, '----------------Waymo info train file is saved to %s----------------')
    build_split(
        val_split, '----------------Waymo info val file is saved to %s----------------')

    print('---------------Start create groundtruth database for data augmentation---------------')
    # GPU 0 is exposed again for the ground-truth database step.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(
        info_path=train_filename, save_path=save_path, split=train_split, sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
    )
    print('---------------Data preparation Done---------------')

if __name__ == '__main__':
    import argparse

    # Data-preparation entry points selectable via --func. Both accept the same
    # keyword arguments, so one call site below serves either.
    available_funcs = {
        'create_waymo_infos': create_waymo_infos,
        'propagate_segmentation_labels': propagate_segmentation_labels,
    }

    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
    # `choices` turns a mistyped --func into an argparse error instead of a silent no-op.
    parser.add_argument('--func', type=str, default='create_waymo_infos',
                        choices=sorted(available_funcs), help='data preparation step to run')
    parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
    args = parser.parse_args()

    if args.cfg_file is None:
        # Without a config, open(None) would fail with an opaque TypeError.
        parser.error('--cfg_file is required')

    import yaml
    from easydict import EasyDict

    # yaml.safe_load() takes only the stream and always uses SafeLoader, so no
    # Loader argument is passed. The file is closed by the context manager.
    with open(args.cfg_file) as f:
        yaml_config = yaml.safe_load(f)
    dataset_cfg = EasyDict(yaml_config)
    dataset_cfg.PROCESSED_DATA_TAG = args.processed_data_tag

    ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
    available_funcs[args.func](
        dataset_cfg=dataset_cfg,
        data_path=ROOT_DIR / 'data' / 'waymo',
        save_path=ROOT_DIR / 'data' / 'waymo',
        raw_data_tag='raw_data',
        processed_data_tag=args.processed_data_tag
    )
