import copy
import gc
import os
import pickle
from pathlib import Path
import SharedArray as SA
import torch
from torch.utils.data import Dataset
import glob

import numpy as np
from ..dataset import DatasetTemplate
from pcdet.utils import common_utils


from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils, polar_utils
from ..dataset import DatasetTemplate
from ...utils.vis_utils import write_ply_color

from scipy.sparse import csr_matrix
from torch_cluster import radius_graph
from sklearn.neighbors import NearestNeighbors as NN
from scipy.sparse.csgraph import connected_components


class ScanNetDataset(DatasetTemplate):
    """ScanNet semantic-segmentation dataset backed by per-scan ``.pth`` files.

    Every sample file is loaded with ``torch.load`` and indexed positionally:
    ``data[0]`` point coordinates, ``data[1]`` point features, ``data[2]``
    semantic labels. Semantic labels are shifted by +1 on load so that the raw
    value -100 becomes 0, which is listed in ``self.ignore_index``.
    """

    # Semantic class ids (after the +1 shift) whose points are split into
    # pseudo-instances by spatial connectivity in ``__getitem__``.
    # NOTE(review): an older comment here listed Waymo class names for these
    # ids; confirm which ScanNet classes they actually correspond to.
    BACKGROUND_CLASS_IDS = tuple(range(3, 20))

    def __init__(self, dataset_cfg, training=True, root_path=None, logger=None):
        """
        Args:
            dataset_cfg: config; reads ``DATASET``, ``DATA_SPLIT[mode]`` and
                optionally ``NUM_SWEEPS`` (default 1) / ``REMOVE_BACK`` (default 1).
            training: forwarded to ``DatasetTemplate``.
            root_path: dataset root holding ``train/`` and ``val/`` folders of ``*.pth`` files.
            logger: forwarded to ``DatasetTemplate``.

        Raises:
            ValueError: if the configured split is not ``train``, ``val`` or ``trainval``.
        """
        super().__init__(
            dataset_cfg=dataset_cfg, training=training, root_path=root_path, logger=logger,
        )
        self.dataset_name = self.dataset_cfg.DATASET
        self.data_path = self.root_path

        self.ignore_index = [0]
        self.num_sweeps = self.dataset_cfg.get("NUM_SWEEPS", 1)
        self.remove_back = self.dataset_cfg.get("REMOVE_BACK", 1)

        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        self.max_voxel = 120000
        self.num_classes = 21

        if self.split in ('train', 'val'):
            sub_dirs = [self.split]
        elif self.split == 'trainval':
            sub_dirs = ['train', 'val']
        else:
            raise ValueError("no such split: {}".format(self.split))

        # glob order is filesystem-dependent; sort so index -> scan is reproducible.
        self.data_list = []
        for sub_dir in sub_dirs:
            self.data_list += sorted(glob.glob(os.path.join(self.data_path, sub_dir, "*.pth")))

        # scan id -> file, used to find ground truth when the split folder is virtual (trainval).
        self._scan_paths = {self._scan_name(path): path for path in self.data_list}

        print("Totally {} samples in {} set.".format(len(self.data_list), self.split))

    def __len__(self):
        """Number of scan files in the selected split."""
        return len(self.data_list)

    @staticmethod
    def _scan_name(path):
        """Return the scan id used as ``frame_id``: the first 12 characters of the file name."""
        return os.path.basename(path)[:12]

    @staticmethod
    def _shift_semantic_label(label):
        """Shift raw semantic labels by +1 in place so raw -100 becomes 0; returns ``label``."""
        label += 1
        label[label == -99] = 0
        return label

    def _resolve_scan_path(self, frame_id):
        """Return the ``.pth`` file that holds the ground truth for ``frame_id``."""
        path = os.path.join(self.data_path, self.split, frame_id + '_inst_nostuff.pth')
        if os.path.exists(path):
            return path
        # For 'trainval' there is no such folder: the scan lives under train/ or val/.
        return self._scan_paths.get(frame_id, path)

    @staticmethod
    def _connected_instance_labels(points, seg_label, class_ids, k=10, max_dist=1.0, min_points=10):
        """Split the points of each class in ``class_ids`` into connected components.

        Two points of the same class are linked when one is among the other's
        ``k`` nearest neighbours (computed within that class) at distance
        below ``max_dist``. Classes with too few points, and components with
        fewer than ``min_points`` points, keep the label -1.

        Args:
            points: per-point coordinates, indexed as ``points[ids, :]``.
            seg_label: per-point semantic labels (array supporting ``==``).
            class_ids: semantic ids to process, in the order ids are assigned.
            k: neighbours queried per point.
            max_dist: edges at or beyond this distance are dropped.
            min_points: minimum size of a class / component to receive an id.

        Returns:
            int32 array with one entry per point: -1, or a consecutive instance
            id starting at 0, assigned class by class and, within a class, in
            ascending component order.
        """
        instance_label = np.full(seg_label.shape[0], -1, dtype=np.int32)
        next_id = 0
        for cls_id in class_ids:
            seg_ids = np.nonzero(seg_label == cls_id)[0]
            if len(seg_ids) < max(min_points, k):
                continue
            num_points = len(seg_ids)
            masked_points = points[seg_ids, :]

            dists, nbr_idx = NN(n_neighbors=k).fit(masked_points).kneighbors(masked_points)
            keep_edge = dists.reshape(-1) < max_dist
            src = np.arange(num_points).repeat(k)[keep_edge]
            dst = nbr_idx.reshape(-1)[keep_edge]
            graph = csr_matrix((np.ones_like(src), (src, dst)), shape=(num_points, num_points))
            n_comp, comp = connected_components(graph, directed=False)

            # Map each large-enough component to the next consecutive id; small ones to -1.
            large = np.bincount(comp, minlength=n_comp) >= min_points
            n_large = int(np.count_nonzero(large))
            comp_to_id = np.full(n_comp, -1, dtype=np.int64)
            comp_to_id[large] = next_id + np.arange(n_large)
            point_id = comp_to_id[comp]
            valid = point_id >= 0
            instance_label[seg_ids[valid]] = point_id[valid]
            next_id += n_large
        return instance_label

    def __getitem__(self, index, mix3d=False, ins_cur=None):
        """Load one scan, derive background pseudo-instances and run ``prepare_data``.

        Args:
            index: sample index; wrapped modulo ``len(self.data_list)``.
            mix3d: True when called recursively to fetch the Mix3D partner
                scan; suppresses further mixing.
            ins_cur: unused; kept for interface compatibility.

        Returns:
            dict with ``point_wise``, ``object_wise`` and ``scene_wise`` entries
            as returned by ``self.prepare_data``; when Mix3D triggers, the
            point-wise and object-wise parts of a random second scan are
            concatenated via ``common_utils.concat_dicts``.
        """
        data_path = self.data_list[index % len(self.data_list)]
        scan_name = self._scan_name(data_path)

        data = torch.load(data_path)
        coord, feat = data[0], data[1]
        # Every split accepted by __init__ is labeled, so labels are always read.
        label = self._shift_semantic_label(data[2])

        instance_label = self._connected_instance_labels(coord, label, self.BACKGROUND_CLASS_IDS)

        data_dict = dict(
            point_wise=dict(
                point_xyz=coord,
                point_feat=feat,
                segmentation_label=label,
                instance_label=instance_label,
            ),
            object_wise=dict(
                gt_box_attr=np.zeros((0, 7)),
                gt_box_cls_label=np.zeros((0)).astype(str),
            ),
            scene_wise=dict(
                frame_id=scan_name,
            ),
        )
        data_dict = self.prepare_data(data_dict)

        if self.mix3d_cfg is not None and not mix3d and self.training:
            prob = self.mix3d_cfg.get("PROB", 1.0)
            if np.random.rand() < prob:
                rand_idx = np.random.randint(0, len(self))
                data_dict2 = self.__getitem__(rand_idx, mix3d=True)
                for key in ('point_wise', 'object_wise'):
                    data_dict[key] = common_utils.concat_dicts([data_dict[key], data_dict2[key]])
        return data_dict

    def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
        """Convert per-scene network outputs into prediction dicts.

        Args:
            batch_dict: unused here; kept for interface compatibility.
            pred_dicts: list of per-scene dicts with ``scene_wise``,
                ``object_wise`` and ``point_wise`` entries. ``object_wise`` may
                hold ``pred_box_attr`` / ``pred_box_scores`` /
                ``pred_box_cls_label`` tensors; ``point_wise`` may hold
                ``point_xyz`` and ``pred_segmentation_label`` tensors.
            class_names: box class names, indexed by ``pred_box_cls_label - 1``.
            output_path: optional directory (str or Path). When given,
                point-wise predictions are saved as
                ``<output_path>/<sequence_id>/<sample_idx zero-padded to 3>_pred.npy``.

        Returns:
            list with one prediction dict per entry of ``pred_dicts``. When
            segmentation predictions are present, ``scene_wise`` also holds
            per-class intersection (``ups``) and union (``downs``) counts.

        Side effect: ``pred_dicts[i]['point_wise']['pred_segmentation_label']``
        is replaced by the prediction propagated onto the ground-truth points.
        """

        def generate_single_sample_dict(cur_dict, output_path=None):
            frame_id = str(cur_dict['scene_wise']['frame_id'].reshape(-1)[-1])
            sequence_id, sample_idx = frame_id.split('_')

            pred_dict = dict(
                object_wise=dict(),
                point_wise=dict(),
                scene_wise=dict(
                    frame_id=frame_id,
                    sequence_id=sequence_id,
                    sample_idx=sample_idx,
                ),
                voxel_wise=dict(),
            )

            object_wise_dict = cur_dict['object_wise']
            if 'pred_box_attr' in object_wise_dict:
                pred_labels = object_wise_dict['pred_box_cls_label'].cpu().numpy()
                pred_dict['object_wise'].update(
                    box_scores=object_wise_dict['pred_box_scores'].cpu().numpy(),
                    box_attr=object_wise_dict['pred_box_attr'].cpu().numpy(),
                    box_name=np.array(class_names)[pred_labels - 1],
                )

            point_wise_dict = cur_dict['point_wise']
            if 'pred_segmentation_label' in point_wise_dict:
                # Propagate predictions from the network's points onto the
                # ground-truth points of the scan via nearest-neighbour lookup.
                gt_data = torch.load(self._resolve_scan_path(frame_id))
                point_xyz = gt_data[0]
                segmentation_label = self._shift_semantic_label(gt_data[2])

                tree = NN(n_neighbors=1).fit(point_wise_dict['point_xyz'].detach().cpu().numpy())
                _, nn_idx = tree.kneighbors(point_xyz)
                pred_segmentation_label = (
                    point_wise_dict['pred_segmentation_label'].detach().cpu().numpy()[nn_idx[:, 0]]
                )

                # Force ignore-labelled ground-truth points to the ignore label in the
                # prediction so they cannot affect the IoU of any other class.
                if self.split != 'test':
                    for ignore_index in self.ignore_index:
                        pred_segmentation_label[segmentation_label == ignore_index] = ignore_index

                pred_dict['point_wise'].update(pred_segmentation_label=pred_segmentation_label)
                point_wise_dict['pred_segmentation_label'] = torch.from_numpy(pred_segmentation_label)

                # Per-class intersection / union point counts.
                ups = torch.zeros(self.num_classes, dtype=torch.long)
                downs = torch.zeros(self.num_classes, dtype=torch.long)
                for i in range(self.num_classes):
                    gt_i = segmentation_label == i
                    pred_i = pred_segmentation_label == i
                    ups[i] = (gt_i & pred_i).sum()
                    downs[i] = (gt_i | pred_i).sum()

                if output_path is not None:
                    out_dir = Path(output_path) / sequence_id
                    out_dir.mkdir(parents=True, exist_ok=True)
                    pred_labels = point_wise_dict['pred_segmentation_label'].detach().to(torch.uint8).cpu()
                    # sample_idx is a str (e.g. '00'); pad it directly, an int ':03d' format would raise.
                    np.save(str(out_dir / f"{sample_idx.zfill(3)}_pred.npy"), pred_labels.numpy())

                pred_dict['scene_wise'].update(ups=ups, downs=downs)

            return pred_dict

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
        return [generate_single_sample_dict(cur_dict, output_path=output_path) for cur_dict in pred_dicts]

    def _segmentation_metrics(self, pred_dicts):
        """Sum per-scene ``ups``/``downs`` and format per-class IoU, mIoU, FG/BG IoU.

        Returns:
            (result_str, iou_dict). Class 0 (the ignore label) is reported but
            excluded from ``mIoU``; unions are clamped to 1 so empty classes
            score 0 instead of NaN.
        """
        if pred_dicts:
            total_ups = torch.stack([p['scene_wise']['ups'] for p in pred_dicts]).sum(dim=0)
            total_downs = torch.stack([p['scene_wise']['downs'] for p in pred_dicts]).sum(dim=0)
        else:
            total_ups = torch.zeros(self.num_classes, dtype=torch.long)
            total_downs = torch.zeros(self.num_classes, dtype=torch.long)

        ious = total_ups / total_downs.clamp(min=1)

        seg_result_str = '\n'
        iou_dict = {}
        for cls in range(ious.shape[0]):
            seg_result_str += f'IoU for class {cls} {ious[cls].item():.4f} \n'
            iou_dict[f'IoU_{cls}'] = ious[cls]

        miou = ious.numpy()[1:].mean()
        iou_dict['mIoU'] = miou
        # NOTE(review): the 1:8 / 8: foreground/background split is kept as-is;
        # confirm it is meaningful for ScanNet class ids.
        iou_dict['IoU_FG'] = total_ups[1:8].sum() / total_downs[1:8].sum().clamp(min=1)
        iou_dict['IoU_BG'] = total_ups[8:].sum() / total_downs[8:].sum().clamp(min=1)
        seg_result_str += f'mIoU={miou:.4f} \n'
        seg_result_str += f"IoU_FG={iou_dict['IoU_FG']:.4f} \n"
        seg_result_str += f"IoU_BG={iou_dict['IoU_BG']:.4f} \n"
        return seg_result_str, iou_dict

    def evaluation(self, pred_dicts, box_class_names, **kwargs):
        """Aggregate metrics over all predicted scenes.

        Args:
            pred_dicts: outputs of ``generate_prediction_dicts``.
            box_class_names: class names forwarded to the box evaluator.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            (res_strs, res_dicts): one formatted string and one metric dict per
            enabled entry of ``self.evaluation_list`` ('box' first, then 'seg').
        """

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            evaluator = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = evaluator.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=box_class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])
            return ap_result_str, ap_dict

        res_strs = []
        res_dicts = []

        if 'box' in self.evaluation_list:
            # NOTE(review): reads ``self.infos``, which this class never sets; confirm
            # it is provided elsewhere before enabling box evaluation for ScanNet.
            eval_gt_annos = []
            for i in range(len(self)):
                annos = self.infos[i]['annos']
                eval_gt_annos.append(
                    dict(
                        difficulty=annos['difficulty'],
                        num_points_in_gt=annos['num_points_in_gt'],
                        name=annos['name'],
                        gt_boxes_lidar=annos['gt_boxes_lidar'],
                    )
                )
            eval_det_annos = copy.deepcopy(pred_dicts)
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
            res_strs.append(ap_result_str)
            res_dicts.append(ap_dict)

        if 'seg' in self.evaluation_list:
            seg_result_str, iou_dict = self._segmentation_metrics(pred_dicts)
            res_strs.append(seg_result_str)
            res_dicts.append(iou_dict)

        return res_strs, res_dicts

