import copy
import gc
import os
import pickle
from pathlib import Path
import SharedArray as SA
import torch
from torch.utils.data import Dataset
import torchmetrics

import numpy as np
from ..dataset import DatasetTemplate
from pcdet.utils import common_utils


from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils, polar_utils
from ..dataset import DatasetTemplate
from pcdet.utils.box_utils import boxes_to_corners_3d
from ...utils.vis_utils import write_ply_color

from scipy.sparse.csgraph import connected_components
from scipy.sparse import csr_matrix
from torch_cluster import radius_graph
from sklearn.neighbors import NearestNeighbors as NN

# Per-split tables keyed by SemanticKITTI sequence id. The integer values are
# presumably the number of scans in each sequence -- verify against the dataset.
TRAIN_DICT = {
    '00': 4541,
    '01': 1101,
    '02': 4661,
    '03': 801,
    '04': 271,
    '05': 2761,
    '06': 1101,
    '07': 1101,
    '09': 1591,
    '10': 1201,
}
VAL_DICT = {
    '08': 4071,
}
# Training sequences first, followed by the validation sequence.
TRAINVAL_DICT = dict(TRAIN_DICT)
TRAINVAL_DICT.update(VAL_DICT)
TEST_DICT = {
    '11': 921,
    '12': 1061,
    '13': 3281,
    '14': 631,
    '15': 1901,
    '16': 1731,
    '17': 491,
    '18': 1801,
    '19': 4981,
    '20': 831,
    '21': 2721,
}
TOY_DICT = {
    '04': 271,
}
NUM_CLASS = 20

# Raw label id -> human readable class name. Iteration order matters: the
# dataset class walks this dict in order when it assigns training ids, and the
# "moving-*" entries must come after their static counterparts.
label_name_mapping = {
    0: "unlabeled",
    1: "outlier",
    # static vehicles
    10: "car",
    11: "bicycle",
    13: "bus",
    15: "motorcycle",
    16: "on-rails",
    18: "truck",
    20: "other-vehicle",
    # static humans
    30: "person",
    31: "bicyclist",
    32: "motorcyclist",
    # ground
    40: "road",
    44: "parking",
    48: "sidewalk",
    49: "other-ground",
    # structures
    50: "building",
    51: "fence",
    52: "other-structure",
    60: "lane-marking",
    # nature
    70: "vegetation",
    71: "trunk",
    72: "terrain",
    # objects
    80: "pole",
    81: "traffic-sign",
    99: "other-object",
    # moving variants
    252: "moving-car",
    253: "moving-bicyclist",
    254: "moving-person",
    255: "moving-motorcyclist",
    256: "moving-on-rails",
    257: "moving-bus",
    258: "moving-truck",
    259: "moving-other-vehicle",
}

# Raw label id -> training class id (0..19). Several raw classes are merged
# into one training class; those entries are marked "merged".
learning_map = {
    0: 0,      # "unlabeled"
    1: 0,      # "outlier" -> "unlabeled" (merged)
    10: 1,     # "car"
    11: 2,     # "bicycle"
    13: 5,     # "bus" -> "other-vehicle" (merged)
    15: 3,     # "motorcycle"
    16: 5,     # "on-rails" -> "other-vehicle" (merged)
    18: 4,     # "truck"
    20: 5,     # "other-vehicle"
    30: 6,     # "person"
    31: 7,     # "bicyclist"
    32: 8,     # "motorcyclist"
    40: 9,     # "road"
    44: 10,    # "parking"
    48: 11,    # "sidewalk"
    49: 12,    # "other-ground"
    50: 13,    # "building"
    51: 14,    # "fence"
    52: 0,     # "other-structure" -> "unlabeled" (merged)
    60: 9,     # "lane-marking" -> "road" (merged)
    70: 15,    # "vegetation"
    71: 16,    # "trunk"
    72: 17,    # "terrain"
    80: 18,    # "pole"
    81: 19,    # "traffic-sign"
    99: 0,     # "other-object" -> "unlabeled" (merged)
    252: 1,    # "moving-car" -> "car" (merged)
    253: 7,    # "moving-bicyclist" -> "bicyclist" (merged)
    254: 6,    # "moving-person" -> "person" (merged)
    255: 8,    # "moving-motorcyclist" -> "motorcyclist" (merged)
    256: 5,    # "moving-on-rails" -> "other-vehicle" (merged)
    257: 5,    # "moving-bus" -> "other-vehicle" (merged)
    258: 4,    # "moving-truck" -> "truck" (merged)
    259: 5,    # "moving-other-vehicle" -> "other-vehicle" (merged)
}

# Training class id -> one representative raw label id (inverse of
# learning_map; merged raw classes collapse onto a single raw id).
learning_map_inv = {
    0: 0,      # "unlabeled", and others ignored
    1: 10,     # "car"
    2: 11,     # "bicycle"
    3: 15,     # "motorcycle"
    4: 18,     # "truck"
    5: 20,     # "other-vehicle"
    6: 30,     # "person"
    7: 31,     # "bicyclist"
    8: 32,     # "motorcyclist"
    9: 40,     # "road"
    10: 44,    # "parking"
    11: 48,    # "sidewalk"
    12: 49,    # "other-ground"
    13: 50,    # "building"
    14: 51,    # "fence"
    15: 70,    # "vegetation"
    16: 71,    # "trunk"
    17: 72,    # "terrain"
    18: 80,    # "pole"
    19: 81,    # "traffic-sign"
}
# Names of the raw classes that receive their own training id; any raw class
# whose name is not listed here is mapped to id 0 by the dataset class.
kept_labels = [
    'unlabeled',
    'road',
    'sidewalk',
    'parking',
    'other-ground',
    'building',
    'car',
    'truck',
    'bicycle',
    'motorcycle',
    'other-vehicle',
    'vegetation',
    'trunk',
    'terrain',
    'person',
    'bicyclist',
    'motorcyclist',
    'fence',
    'pole',
    'traffic-sign',
]

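# Training class ids (see learning_map_inv): 0: unlabeled,
# 1: car, 2: bicycle, 3: motorcycle, 4: truck, 5: other-vehicle, 6: person, 7: bicyclist, 8: motorcyclist,
# 9: road, 10: parking, 11: sidewalk, 12: other-ground, 13: building, 14: fence, 15: vegetation,
# 16: trunk, 17: terrain, 18: pole, 19: traffic-sign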

class SemanticKITTIDataset(DatasetTemplate):
    def __init__(self, dataset_cfg, training=True, root_path=None, logger=None):
        """Index the scans of the configured split and build the label look-up tables.

        Args:
            dataset_cfg: dataset config. Read here for ``DATASET``,
                ``DATA_SPLIT`` and ``SAMPLED_INTERVAL`` (both indexed by
                ``self.mode``) and the optional ``NUM_SWEEPS``,
                ``CYL_LABEL_MAPPING``, ``HIDE_OTHER_CLASS``, ``HIDE_CLASS``
                and ``DS_SAMPLE`` entries.
            training: forwarded to the base class.
            root_path: forwarded to the base class. ``self.root_path`` is then
                used as the directory holding one ``<seq>/velodyne`` folder
                per sequence.
            logger: forwarded to the base class.

        Raises:
            ValueError: if the configured split is not one of
                train / val / trainval / test / toy.
        """
        super().__init__(
            dataset_cfg=dataset_cfg, training=training, root_path=root_path, logger=logger,
        )
        self.dataset_name = self.dataset_cfg.DATASET
        self.data_path = self.root_path
        # self.mode is not set in this method; presumably provided by the base class.
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]

        # Explicit table instead of eval() on a config-provided string.
        split_to_seq_dict = {
            'train': TRAIN_DICT,
            'val': VAL_DICT,
            'trainval': TRAINVAL_DICT,
            'test': TEST_DICT,
            'toy': TOY_DICT,
        }
        if self.split not in split_to_seq_dict:
            raise ValueError(
                "Unknown SemanticKITTI split %r, expected one of %s"
                % (self.split, sorted(split_to_seq_dict))
            )
        self.seq_dict = split_to_seq_dict[self.split]
        # Sorted sequence ids; for train / val / test this is exactly the list
        # that used to be hard-coded, and it also covers trainval / toy.
        self.seqs = sorted(self.seq_dict)

        self.data_idx = []
        self.num_sweeps = self.dataset_cfg.get("NUM_SWEEPS", 1)
        self.ignore_index = [0]
        self.cyl_label_mapping = self.dataset_cfg.get("CYL_LABEL_MAPPING", False)
        self.hide_other_class = self.dataset_cfg.get("HIDE_OTHER_CLASS", False)
        self.hide_class = self.dataset_cfg.get("HIDE_CLASS", 5)
        self.ds_sample = self.dataset_cfg.get("DS_SAMPLE", False)

        # Collect every scan file, sequence by sequence, in sorted order.
        self.files = []
        for seq in self.seqs:
            velodyne_dir = os.path.join(self.data_path, seq, 'velodyne')
            self.files.extend(
                os.path.join(velodyne_dir, file_name)
                for file_name in sorted(os.listdir(velodyne_dir))
            )

        sampled_interval = self.dataset_cfg.SAMPLED_INTERVAL[self.mode]
        if sampled_interval > 1:
            self.files = self.files[::sampled_interval]

        # label_map: raw label id (0..259) -> training id.
        # label_map_inv: training id -> raw label id.
        # Both stay float arrays as before; __getitem__ casts after the lookup.
        self.label_map = np.zeros(260)
        self.label_map_inv = np.zeros(NUM_CLASS)
        if self.cyl_label_mapping:
            for raw_id, train_id in learning_map.items():
                self.label_map[raw_id] = train_id
            for train_id, raw_id in learning_map_inv.items():
                self.label_map_inv[train_id] = raw_id
            # NOTE(review): in this branch the attribute holds id -> raw id,
            # while the other branch stores name -> id. Kept as is.
            self.reverse_label_name_mapping = learning_map_inv
        else:
            reverse_label_name_mapping = {}
            next_train_id = 0
            for raw_id, raw_name in label_name_mapping.items():
                if raw_id > 250:
                    # "moving-X" shares the id of "X". This relies on the
                    # static classes appearing earlier in label_name_mapping.
                    static_name = raw_name.replace('moving-', '')
                    if static_name in kept_labels:
                        self.label_map[raw_id] = reverse_label_name_mapping[static_name]
                elif raw_name in kept_labels:
                    self.label_map[raw_id] = next_train_id
                    reverse_label_name_mapping[raw_name] = next_train_id
                    self.label_map_inv[next_train_id] = raw_id
                    next_train_id += 1
                # Every other raw id keeps the initial value 0.
            self.reverse_label_name_mapping = reverse_label_name_mapping

        self.num_classes = NUM_CLASS
        self.angle = 0.0

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index, mix3d=False, ins_cur=None):
        """Load one scan with its semantic and instance labels.

        Args:
            index: position in ``self.files``.
            mix3d: True when this call fetches the second scene for mixing;
                stops the mixing branch from recursing again.
            ins_cur: optional offset added to instance ids that are > 0.

        Returns:
            The result of ``self.prepare_data`` applied to a dict with
            ``point_wise``, ``object_wise`` and ``scene_wise`` entries. When
            mixing is triggered, ``point_wise`` and ``object_wise`` are
            concatenated with those of a second, randomly chosen sample.
        """
        velodyne_file = self.files[index]
        with open(velodyne_file, 'rb') as scan_fp:
            scan = np.fromfile(scan_fp, dtype=np.float32).reshape(-1, 4)
        coord = scan[:, 0:3]
        feat = scan[:, 3:].reshape(-1, 1)

        # Derive sibling paths from path components rather than replacing
        # substrings of the whole path, so a root directory whose name
        # contains "velodyne" or ".bin" cannot corrupt the label path.
        velodyne_dir, scan_name = os.path.split(velodyne_file)
        seq_dir = os.path.dirname(velodyne_dir)
        frame_stem = os.path.splitext(scan_name)[0]
        label_dir = os.path.join(seq_dir, 'labels')
        label_file = os.path.join(label_dir, frame_stem + '.label')

        if os.path.exists(label_file):
            with open(label_file, 'rb') as label_fp:
                all_labels = np.fromfile(label_fp, dtype=np.int32).reshape(-1)
        else:
            # No label file for this scan: every point becomes label 0.
            all_labels = np.zeros(coord.shape[0], dtype=np.int32)
        # Only the lower 16 bits are used to index the label table.
        label = self.label_map[all_labels & 0xFFFF].astype(np.int64)

        ins_back_file = os.path.join(label_dir, frame_stem + '_ins.npy')
        if os.path.exists(ins_back_file):
            instance_label = np.load(ins_back_file)
        else:
            # Computed on the fly; the result is not written back to disk.
            instance_label = self._cluster_instances(coord, label)

        if ins_cur is not None:
            # NOTE(review): ids generated on the fly start at 0, so `> 0`
            # leaves instance 0 un-shifted; confirm whether `>= 0` is intended.
            instance_label[instance_label > 0] += ins_cur

        name = "%s_%s" % (os.path.basename(seq_dir), frame_stem)

        or_label = label.copy()

        if self.ds_sample and self.hide_other_class and self.split == 'train':
            minimum_pts_thre = 300
            max_instance = instance_label.max() if instance_label.size else 0

            # NOTE(review): this range skips instance 0 and the largest id;
            # kept unchanged because it may be intentional.
            for instance_idx in range(1, max_instance):
                # The order of the random draws matches the original code.
                rnd = np.random.rand()
                ins_mask = instance_label == instance_idx
                if (rnd > 0.5 or ins_mask.sum() < minimum_pts_thre
                        or label[ins_mask][0] == self.hide_class):
                    continue

                # Rescale the instance about its centroid, either enlarging
                # (x1.5-3.0) or shrinking (x0.25-0.5), and relabel it.
                obj_ins = coord[ins_mask]
                obj_ins_center = np.mean(obj_ins, axis=0)
                scale_ds_large = np.random.rand() * 1.5 + 1.5
                scale_ds_small = np.random.rand() * 0.25 + 0.25
                rnd = np.random.rand()
                scale_ds = scale_ds_large if rnd > 0.5 else scale_ds_small
                coord[ins_mask] = (obj_ins - obj_ins_center) * scale_ds + obj_ins_center
                label[ins_mask] = self.hide_class

        if self.hide_other_class and self.split == 'train':
            # Points whose original label is the hidden class become label 0.
            label[or_label == self.hide_class] = 0

        data_dict = dict(
            point_wise=dict(
                point_xyz=coord,
                point_feat=feat,
                segmentation_label=label,
                instance_label=instance_label,
            ),
            object_wise=dict(
                gt_box_attr=np.zeros((0, 7)),
                gt_box_cls_label=np.zeros((0)).astype(str),
            ),
            scene_wise=dict(
                frame_id=name,
            ),
        )

        data_dict = self.prepare_data(data_dict)

        # self.mix3d_cfg is not set in this class; presumably provided by the base class.
        if (self.mix3d_cfg is not None) and (not mix3d) and self.training:
            prob = self.mix3d_cfg.get("PROB", 1.0)
            if np.random.rand() < prob:
                rand_idx = np.random.randint(0, len(self))
                # NOTE(review): `ins_cur` is not passed, so instance ids of
                # the two mixed scenes may overlap.
                data_dict2 = self.__getitem__(rand_idx, mix3d=True)
                data_dict['point_wise'] = common_utils.concat_dicts(
                    [data_dict['point_wise'], data_dict2['point_wise']])
                data_dict['object_wise'] = common_utils.concat_dicts(
                    [data_dict['object_wise'], data_dict2['object_wise']])

        return data_dict

    @staticmethod
    def _cluster_instances(points, seg_label):
        """Assign instance ids by clustering the points of selected classes.

        For each selected class, points are linked to their 10 nearest
        same-class neighbours when closer than 1.0, and every connected
        component with at least 10 points becomes one instance.

        Args:
            points: (N, 3) point coordinates.
            seg_label: (N,) training class id per point.

        Returns:
            (N,) int32 array; -1 for points without an instance, otherwise
            ids numbered consecutively from 0.
        """
        num_neighbors = 10   # also the minimum class size needed for the k-NN query
        min_points = 10      # smaller classes and components are ignored
        max_edge_len = 1.0   # neighbours farther apart than this are not linked
        # With the label tables of this file: 1-8 are car .. motorcyclist,
        # 14 fence, 16 trunk, 18 pole, 19 traffic-sign.
        clustered_classes = [1, 2, 3, 4, 5, 6, 7, 8, 14, 16, 18, 19]

        instance_label = np.full(seg_label.shape[0], -1, dtype=np.int32)
        next_instance_id = 0

        for cls_id in clustered_classes:
            seg_ids = (seg_label == cls_id).nonzero()[0]
            if len(seg_ids) < min_points:
                continue
            cls_points = points[seg_ids, :]
            num_points = cls_points.shape[0]

            knn = NN(n_neighbors=num_neighbors).fit(cls_points)
            dists, indices = knn.kneighbors(cls_points)
            src = np.arange(num_points).repeat(num_neighbors)
            dst = indices.reshape(-1)
            close_enough = dists.reshape(-1) < max_edge_len
            src, dst = src[close_enough], dst[close_enough]

            graph = csr_matrix((np.ones_like(src), (src, dst)), shape=(num_points, num_points))
            _, components = connected_components(graph, directed=False)
            for component_id in np.unique(components):
                ids = seg_ids[(components == component_id).nonzero()[0]]
                if len(ids) < min_points:
                    continue
                instance_label[ids] = next_instance_id
                next_instance_id += 1

        return instance_label
    
    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
        """Turn per-sample network outputs into prediction dicts with IoU counts.

        Args:
            batch_dict: not used by this method; kept for interface compatibility.
            pred_dicts: list with one dict per sample. Each dict is read for
                ``scene_wise['frame_id']`` (``"<sequence>_<sample>"``) and
                ``object_wise`` / ``point_wise``, plus an optional
                ``voxel_wise`` entry.
            class_names: box class names, indexed with ``pred_box_cls_label - 1``.
            output_path: optional directory (str or Path). When given, the
                per-point predictions of every sample are saved to
                ``<output_path>/<sequence>/<sample>_pred.npy``.

        Returns:
            A list with one dict per sample, each holding ``object_wise``,
            ``point_wise``, ``scene_wise`` and ``voxel_wise`` sub-dicts. The
            ``ups`` / ``downs`` entries are per-class intersection / union counts.
        """

        def count_intersection_union(gt, pred, num_classes):
            """Per-class counts of (gt == c & pred == c) and (gt == c | pred == c)."""
            ups = torch.zeros(num_classes, dtype=torch.long)
            downs = torch.zeros(num_classes, dtype=torch.long)
            for cls_id in range(num_classes):
                ups[cls_id] = ((gt == cls_id) & (pred == cls_id)).sum()
                downs[cls_id] = ((gt == cls_id) | (pred == cls_id)).sum()
            return ups, downs

        def generate_single_sample_dict(cur_dict, output_path=None):
            frame_id = str(cur_dict['scene_wise']['frame_id'].reshape(-1)[-1])
            sequence_id, sample_idx = frame_id.split('_')

            pred_dict = dict(
                object_wise=dict(),
                point_wise=dict(),
                scene_wise=dict(
                    frame_id=frame_id,
                    sequence_id=sequence_id,
                    sample_idx=sample_idx,
                ),
                voxel_wise=dict(),
            )

            object_wise_dict = cur_dict['object_wise']
            point_wise_dict = cur_dict['point_wise']

            if 'pred_box_attr' in object_wise_dict:
                pred_scores = object_wise_dict['pred_box_scores'].cpu().numpy()
                pred_boxes = object_wise_dict['pred_box_attr'].cpu().numpy()
                pred_labels = object_wise_dict['pred_box_cls_label'].cpu().numpy()

                pred_dict['object_wise'].update(
                    dict(
                        box_scores=pred_scores,
                        box_attr=pred_boxes,
                        box_name=np.array(class_names)[pred_labels - 1],
                    ))

            if 'pred_segmentation_label' in point_wise_dict:
                # Propagate the prediction from the network's points to every
                # point of the original scan. The scan is read from the same
                # root the dataset was indexed from, not a hard-coded path.
                seq_dir = os.path.join(self.data_path, sequence_id)
                xyz_path = os.path.join(seq_dir, 'velodyne', f'{sample_idx}.bin')
                label_file = os.path.join(seq_dir, 'labels', f'{sample_idx}.label')
                with open(xyz_path, 'rb') as scan_fp:
                    point_xyz = np.fromfile(scan_fp, dtype=np.float32).reshape(-1, 4)[:, :3]

                if os.path.exists(label_file):
                    with open(label_file, 'rb') as label_fp:
                        all_labels = np.fromfile(label_fp, dtype=np.int32).reshape(-1)
                else:
                    all_labels = np.zeros(point_xyz.shape[0], dtype=np.int32)
                segmentation_label = self.label_map[all_labels & 0xFFFF].astype(np.int64)

                # Nearest predicted point for every original point.
                tree = NN(n_neighbors=1).fit(point_wise_dict['point_xyz'].detach().cpu().numpy())
                _, indices = tree.kneighbors(point_xyz)
                nearest = indices[:, 0]
                pred_segmentation_label = point_wise_dict['pred_segmentation_label'].detach().cpu().numpy()[nearest]

                pred_dict['point_wise'].update(
                    dict(
                        pred_segmentation_label=pred_segmentation_label,
                    ))

                has_binary_label = 'gt_binary_label' in point_wise_dict
                if has_binary_label:
                    # Histogram of the binary scores over n_hist bins of equal
                    # width, together with the positive-label count per bin.
                    # The last bin is open-ended.
                    n_hist = 100
                    binary_hist = torch.zeros(n_hist)
                    binary_label_hist = torch.zeros(n_hist)

                    gt_binary_label = point_wise_dict['gt_binary_label'][nearest].cpu()
                    pred_binary_logits = point_wise_dict['pred_binary_logits'][nearest].cpu()
                    interval = float(1/n_hist)

                    for ii in range(n_hist):
                        mask_ii = pred_binary_logits >= ii*interval
                        if ii < n_hist-1:
                            mask_ii = mask_ii & (pred_binary_logits < (ii+1)*interval)
                        binary_hist[ii] += mask_ii.sum()
                        binary_label_hist[ii] += gt_binary_label[mask_ii].sum()

                if self.split != 'test':
                    # Points whose ground truth is ignored get the ignored id
                    # as prediction (also changes pred_dict['point_wise']).
                    for ignore_index in self.ignore_index:
                        pred_segmentation_label[segmentation_label == ignore_index] = ignore_index

                point_wise_dict['pred_segmentation_label'] = torch.from_numpy(pred_segmentation_label)

                ups, downs = count_intersection_union(segmentation_label, pred_segmentation_label, 20)

                if output_path is not None:
                    seq_out_dir = Path(output_path) / sequence_id
                    os.makedirs(seq_out_dir, exist_ok=True)
                    # sample_idx is a string, so it is used verbatim; a
                    # numeric format spec such as ":03d" would raise ValueError.
                    np.save(str(seq_out_dir / f"{sample_idx}_pred.npy"),
                            pred_segmentation_label.astype(np.uint8))

                pred_dict['scene_wise'].update(
                    ups=ups,
                    downs=downs,
                )
                if has_binary_label:
                    pred_dict['scene_wise'].update(
                        binary_hist=binary_hist,
                        binary_label_hist=binary_label_hist,
                    )

                if 'pred_segmentation_label_refine' in point_wise_dict:
                    tree = NN(n_neighbors=1).fit(point_wise_dict['point_xyz_refine'].detach().cpu().numpy())
                    _, indices = tree.kneighbors(point_xyz)

                    pred_segmentation_label_refine = point_wise_dict['pred_segmentation_label_refine'].detach().cpu().numpy()[indices[:, 0]]

                    for ignore_index in self.ignore_index:
                        pred_segmentation_label_refine[segmentation_label == ignore_index] = ignore_index

                    point_wise_dict['pred_segmentation_label_refine'] = torch.from_numpy(pred_segmentation_label_refine)

                    # NOTE(review): 23 classes are counted here, although the
                    # non-refined statistics above use 20; kept unchanged.
                    num_refine_classes = 23
                    ups, downs = count_intersection_union(
                        segmentation_label, pred_segmentation_label_refine, num_refine_classes)

                    # Confusion statistics for class 5.
                    # pred_distribution: predicted labels of points whose ground truth is 5.
                    # gt_distribution: ground-truth labels of points predicted as 5.
                    pred_distribution = torch.zeros(num_refine_classes, dtype=torch.long)
                    gt_distribution = torch.zeros(num_refine_classes, dtype=torch.long)

                    gt_of_pred_5 = segmentation_label[pred_segmentation_label_refine == 5]
                    pred_of_gt_5 = pred_segmentation_label_refine[segmentation_label == 5]

                    for i in range(num_refine_classes):
                        pred_distribution[i] = (pred_of_gt_5 == i).sum()
                        gt_distribution[i] = (gt_of_pred_5 == i).sum()

                    pred_dict['scene_wise'].update(
                        ups_refine=ups,
                        downs_refine=downs,
                        pred_distribution_refine=pred_distribution,
                        gt_distribution_refine=gt_distribution,
                    )

            if 'voxel_wise' in cur_dict and 'pred_bev_segmentation_label' in cur_dict['voxel_wise']:
                voxel_wise_dict = cur_dict['voxel_wise']
                ups, downs = count_intersection_union(
                    voxel_wise_dict['gt_bev_segmentation_label'],
                    voxel_wise_dict['pred_bev_segmentation_label'],
                    4,
                )
                pred_dict['voxel_wise'].update(
                    ups=ups,
                    downs=downs,
                )

            if 'gt_mask' in object_wise_dict:
                ups, downs = count_intersection_union(
                    object_wise_dict['gt_mask'], object_wise_dict['pred_mask'], 2)
                pred_dict['scene_wise'].update(
                    ups_mask=ups,
                    downs_mask=downs,
                )

                ups, downs = count_intersection_union(
                    object_wise_dict['gt_box_cls'], object_wise_dict['pred_box_cls'], 8)
                pred_dict['scene_wise'].update(
                    ups_box_cls=ups,
                    downs_box_cls=downs,
                )

                if 'pred_box_points_seg_label' in object_wise_dict:
                    ups, downs = count_intersection_union(
                        object_wise_dict['gt_box_points_seg_label'],
                        object_wise_dict['pred_box_points_seg_label'],
                        8,
                    )
                    pred_dict['scene_wise'].update(
                        ups_box_points_cls=ups,
                        downs_box_points_cls=downs,
                    )

                if 'gt_box_cls_num' in object_wise_dict:
                    pred_dict['scene_wise'].update(
                        gt_box_cls_num=object_wise_dict['gt_box_cls_num'].detach().cpu(),
                        gt_box_points_num=object_wise_dict['gt_box_points_num'].detach().cpu(),
                        detected_box_cls_num=object_wise_dict['detected_box_cls_num'].detach().cpu(),
                        detected_box_points_num=object_wise_dict['detected_box_points_num'].detach().cpu(),
                    )

            return pred_dict

        if output_path is not None:
            # Accept both str and Path; the per-sequence folders are created on demand.
            output_path = Path(output_path)
            os.makedirs(output_path, exist_ok=True)

        return [
            generate_single_sample_dict(cur_dict, output_path=output_path)
            for cur_dict in pred_dicts
        ]


    def evaluation(self, pred_dicts, box_class_names, **kwargs):
      
        def kitti_eval(eval_det_annos, eval_gt_annos):
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            map_name_to_kitti = {
                'Vehicle': 'Car',
                'Pedestrian': 'Pedestrian',
                'Cyclist': 'Cyclist',
                'Sign': 'Sign',
                'Car': 'Car'
            }
            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            kitti_class_names = [map_name_to_kitti[x] for x in class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            eval = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = eval.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=box_class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

            return ap_result_str, ap_dict
        
        res_strs = []
        res_dicts = [] 

        if 'box' in self.evaluation_list:
            eval_gt_annos = []
            for i in range(self.__len__()):
                annos = self.infos[i]['annos']
                box_label = annos['name']
                box_difficulty = annos['difficulty']
                box_attr = annos['gt_boxes_lidar']
                box_npoints = annos['num_points_in_gt']
                #box_attr, box_label, box_difficulty, box_npoints = self.get_box3d(i)
                eval_gt_annos.append(
                    dict(
                        difficulty=box_difficulty,
                        num_points_in_gt=box_npoints,
                        name=box_label,
                        gt_boxes_lidar=box_attr
                    )
                )
            eval_det_annos = copy.deepcopy(pred_dicts)
            #eval_det_annos = translate_names(eval_det_annos)
            #eval_gt_annos = translate_names(eval_gt_annos)
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
            res_strs.append(ap_result_str)
            res_dicts.append(ap_dict)
        
        if 'bev' in self.evaluation_list:
            
            total_ups, total_downs = None, None
            for pred_dict in pred_dicts:
                ups, downs = pred_dict['voxel_wise']['ups'], pred_dict['voxel_wise']['downs']
                if total_ups is None:
                    total_ups = ups.clone()
                    total_downs = downs.clone()
                else:
                    total_ups += ups
                    total_downs += downs
            seg_result_str = '\n'
            iou_dict = {}
            ious = []
            for cls in range(total_ups.shape[0]):
                iou = total_ups[cls]/np.clip(total_downs[cls], 1, None)
                seg_result_str += f'IoU for bev class {cls} {iou:.4f} \n'
                iou_dict[f'IoU_{cls}'] = iou
                ious.append(iou)
            ious = np.array(ious).reshape(-1)[1:]
            iou_dict['mIoU_bev'] = ious.mean()
            seg_result_str += f'mIoU_bev={ious.mean():.4f} \n'
            
            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)

        if 'refine' in self.evaluation_list:
        #if 'refine' in self.evaluation_list and 'ups_mask' in pred_dict['scene_wise'].keys(): 
            total_ups, total_downs = None, None
            for pred_dict in pred_dicts:
                ups, downs = pred_dict['scene_wise']['ups_mask'], pred_dict['scene_wise']['downs_mask']
                if total_ups is None:
                    total_ups = ups.clone()
                    total_downs = downs.clone()
                else:
                    total_ups += ups
                    total_downs += downs
            seg_result_str = '\n'
            iou_dict = {}
            ious = []
            for cls in range(total_ups.shape[0]):
                iou = total_ups[cls]/np.clip(total_downs[cls], 1, None)
                seg_result_str += f'IoU for box points mask class {cls} {iou:.4f} \n'
                iou_dict[f'Box_points_mask_IoU_{cls}'] = iou
                ious.append(iou)
            ious = np.array(ious).reshape(-1)[1:]
            iou_dict['mIoU_box_points_mask'] = ious.mean()
            seg_result_str += f'mIoU_box_points_mask={ious.mean():.4f} \n'
            
            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)

            total_ups, total_downs = None, None
            for pred_dict in pred_dicts:
                ups, downs = pred_dict['scene_wise']['ups_box_cls'], pred_dict['scene_wise']['downs_box_cls']
                if total_ups is None:
                    total_ups = ups.clone()
                    total_downs = downs.clone()
                else:
                    total_ups += ups
                    total_downs += downs
            seg_result_str = '\n'
            iou_dict = {}
            ious = []
            for cls in range(total_ups.shape[0]):
                iou = total_ups[cls]/np.clip(total_downs[cls], 1, None)
                seg_result_str += f'IoU for box class {cls} {iou:.4f} \n'
                iou_dict[f'Box_class_IoU_{cls}'] = iou
                ious.append(iou)
            ious = np.array(ious).reshape(-1)[1:]
            iou_dict['mIoU_box_cls'] = ious.mean()
            seg_result_str += f'mIoU_box_cls={ious.mean():.4f} \n'
            
            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)
           
            if 'ups_box_points_cls' in pred_dicts[0]['scene_wise'].keys():

                total_ups, total_downs = None, None
                for pred_dict in pred_dicts:
                    ups, downs = pred_dict['scene_wise']['ups_box_points_cls'], pred_dict['scene_wise']['downs_box_points_cls']
                    if total_ups is None:
                        total_ups = ups.clone()
                        total_downs = downs.clone()
                    else:
                        total_ups += ups
                        total_downs += downs
                seg_result_str = '\n'
                iou_dict = {}
                ious = []
                for cls in range(total_ups.shape[0]):
                    iou = total_ups[cls]/np.clip(total_downs[cls], 1, None)
                    seg_result_str += f'IoU for box points class {cls} {iou:.4f} \n'
                    iou_dict[f'Box_class_points_IoU_{cls}'] = iou
                ious.append(iou)
                ious = np.array(ious).reshape(-1)[1:]
                iou_dict['mIoU_box_points_cls'] = ious.mean()
                seg_result_str += f'mIoU_box_points_cls={ious.mean():.4f} \n'

                res_strs.append(seg_result_str)
                res_dicts.append(iou_dict)
            
            if 'gt_box_cls_num' in pred_dicts[0]['scene_wise'].keys():
                total_gt_cls_num = None 
                total_gt_points_num = None 
                total_detected_cls_num = None 
                total_detected_points_num = None 
                seg_result_str = '\n'
                for pred_dict in pred_dicts:
                    if total_gt_cls_num is None:
                        total_gt_cls_num = pred_dict['scene_wise']['gt_box_cls_num'].clone()
                        total_gt_points_num = pred_dict['scene_wise']['gt_box_points_num'].clone()
                        total_detected_cls_num = pred_dict['scene_wise']['detected_box_cls_num'].clone()
                        total_detected_points_num = pred_dict['scene_wise']['detected_box_points_num'].clone()
                    else:
                        total_gt_cls_num +=  pred_dict['scene_wise']['gt_box_cls_num'] 
                        total_gt_points_num += pred_dict['scene_wise']['gt_box_points_num']
                        total_detected_cls_num += pred_dict['scene_wise']['detected_box_cls_num']
                        total_detected_points_num += pred_dict['scene_wise']['detected_box_points_num']
                for cls in range(total_gt_cls_num.shape[0]):
                    seg_result_str += f'Statistics class {cls}: gt_box: {total_gt_cls_num[cls]}, detected_box: {total_detected_cls_num[cls]}, gt_points: {total_gt_points_num[cls]}, detected_points: {total_detected_points_num[cls]}\n'

                res_strs.append(seg_result_str)

        if 'seg' in self.evaluation_list:
            total_ups, total_downs = None, None
            total_ups_refine, total_downs_refine = None, None
            
            total_ap = None
            total_auroc = None
            total_num_ap = 0

            total_binary_hist = None 
            total_binary_label = None 

            for pred_dict in pred_dicts:
                ups, downs = pred_dict['scene_wise']['ups'], pred_dict['scene_wise']['downs']

                if total_ups is None:
                    total_ups = ups.clone()
                    total_downs = downs.clone()
                else:
                    total_ups += ups
                    total_downs += downs
                
                if 'ap' in pred_dict['scene_wise'].keys():
                    if  pred_dict['scene_wise']['ap']!=-1:
                        total_num_ap += 1
                        if total_ap is None:
                            total_ap = pred_dict['scene_wise']['ap'].clone()
                            total_auroc = pred_dict['scene_wise']['auroc'].clone()
                        else:
                            total_ap += pred_dict['scene_wise']['ap']
                            total_auroc += pred_dict['scene_wise']['auroc']
                
                if 'binary_hist' in pred_dict['scene_wise'].keys():
                    if total_binary_hist is None:
                        total_binary_hist = pred_dict['scene_wise']['binary_hist'].clone()
                        total_binary_label =  pred_dict['scene_wise']['binary_label_hist'].clone()
                    else:
                        total_binary_hist +=  pred_dict['scene_wise']['binary_hist'] 
                        total_binary_label +=  pred_dict['scene_wise']['binary_label_hist']

                if 'ups_refine' in pred_dict['scene_wise'].keys():
                    ups_rf, downs_rf = pred_dict['scene_wise']['ups_refine'], pred_dict['scene_wise']['downs_refine']
                    if total_ups_refine is None:
                        total_ups_refine = ups_rf.clone()
                        total_downs_refine = downs_rf.clone()
                    else:
                        total_ups_refine += ups_rf
                        total_downs_refine += downs_rf

            seg_result_str = '\n'
            iou_dict = {}
            ious = []
            for cls in range(total_ups.shape[0]):
                iou = total_ups[cls]/np.clip(total_downs[cls], 1, None)
                seg_result_str += f'IoU for class {cls} {iou:.4f} \n'
                iou_dict[f'IoU_{cls}'] = iou
                ious.append(iou)
            
            ious = np.array(ious).reshape(-1)[1:]
            iou_dict['mIoU'] = ious.mean()
            iou_dict['IoU_FG'] = total_ups[1:8].sum() / np.clip(total_downs[1:8].sum(), 1, None)
            iou_dict['IoU_BG'] = total_ups[8:].sum() / np.clip(total_downs[8:].sum(), 1, None)
            seg_result_str += f'mIoU={ious.mean():.4f} \n'
            seg_result_str += f"IoU_FG={iou_dict['IoU_FG']:.4f} \n"
            seg_result_str += f"IoU_BG={iou_dict['IoU_BG']:.4f} \n"
            
            
            if total_ap is not None:
                seg_result_str += '\n'
                seg_result_str += 'AP: %f\n'%(total_ap/total_num_ap)
                seg_result_str += 'AUROC: %f\n'%(total_auroc/total_num_ap)
            
            if total_binary_hist is not None:
                preds = torch.zeros(total_binary_hist.sum().int().item())
                gts = torch.zeros(total_binary_hist.sum().int().item())
                cnt = 0
                total_binary_hist = total_binary_hist.int() 
                total_binary_label = total_binary_label.int()
                n_hist = 100
                interval = float(1/n_hist)
                for ii in range(n_hist):
                    if total_binary_hist[ii]==0:
                        continue 
                    preds[cnt:cnt+total_binary_hist[ii]] = (ii+0.5)*interval 
                    gts[cnt:cnt+total_binary_label[ii]] = 1.0 
                    cnt += total_binary_hist[ii]

                gts = gts.long()
                ap_eval = torchmetrics.AveragePrecision(task="binary")
                ap = ap_eval(preds, gts).detach().cpu()
                auroc_eval = torchmetrics.AUROC(task="binary")
                auroc = auroc_eval(preds, gts).detach().cpu()
                seg_result_str += '\n'
                seg_result_str += 'AP: %f\n'%(ap)
                seg_result_str += 'AUROC: %f\n'%(auroc)

            if total_ups_refine is not None:
                ious = []
                for cls in range(total_ups_refine.shape[0]):
                    iou = total_ups_refine[cls]/np.clip(total_downs_refine[cls], 1, None)
                    seg_result_str += f'IoU for class refined {cls} {iou:.4f} \n'
                    iou_dict[f'IoU_{cls}_refined'] = iou
                    ious.append(iou)
                    
                ious = np.array(ious).reshape(-1)[1:]
                iou_dict['mIoU_refined'] = ious.mean()
                iou_dict['IoU_FG_refined'] = total_ups_refine[1:8].sum() / np.clip(total_downs_refine[1:8].sum(), 1, None)
                iou_dict['IoU_BG_refined'] = total_ups_refine[8:].sum() / np.clip(total_downs_refine[8:].sum(), 1, None)
                seg_result_str += f'mIoU_refined={ious.mean():.4f} \n'
                seg_result_str += f"IoU_FG_refined={iou_dict['IoU_FG_refined']:.4f} \n"
                seg_result_str += f"IoU_BG_refined={iou_dict['IoU_BG_refined']:.4f} \n"

            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)
        
        return res_strs, res_dicts 


