import pickle
import time
import os 

import numpy as np
import torch
import tqdm
from sklearn.metrics import confusion_matrix

from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils
from pcdet.utils.vis_utils import write_ply_color
from .vis_output import write_output

# SemanticKITTI test sequences. The evaluation code in this file only uses the
# keys (sequence ids) to create one `<sequence>/predictions` directory per
# entry; the values are never read here -- presumably the number of frames in
# each sequence (TODO confirm against the dataset).
KITTI_TEST_DICT = {
    '11': 921,
    '12': 1061,
    '13': 3281,
    '14': 631,
    '15': 1901,
    '16': 1731,
    '17': 491,
    '18': 1801,
    '19': 4981,
    '20': 831,
    '21': 2721,
}

def statistics_info(cfg, ret_dict, metric, disp_dict):
    """Fold one batch's recall counters into the running totals.

    Nothing happens unless ``cfg.MODEL`` contains a ``POST_PROCESSING`` entry.

    Args:
        cfg: Config object; ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST``
            supplies the recall thresholds.
        ret_dict: Per-batch counters keyed ``roi_<thresh>``, ``rcnn_<thresh>``
            and ``gt``; a missing key counts as 0.
        metric: Running totals, updated in place. It must already hold
            ``gt_num`` plus ``recall_roi_<thresh>`` / ``recall_rcnn_<thresh>``
            for every threshold.
        disp_dict: Updated in place with a ``recall_<first thresh>`` summary
            string of the form ``(roi, rcnn) / gt``.
    """
    if "POST_PROCESSING" not in cfg.MODEL:
        return

    thresholds = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    for thresh in thresholds:
        suffix = str(thresh)
        metric['recall_roi_%s' % suffix] += ret_dict.get('roi_%s' % suffix, 0)
        metric['recall_rcnn_%s' % suffix] += ret_dict.get('rcnn_%s' % suffix, 0)
    metric['gt_num'] += ret_dict.get('gt', 0)

    # Only the first threshold in the list is surfaced on the progress bar.
    shown = str(thresholds[0])
    roi_total = metric['recall_roi_%s' % shown]
    rcnn_total = metric['recall_rcnn_%s' % shown]
    disp_dict['recall_%s' % shown] = '(%d, %d) / %d' % (roi_total, rcnn_total, metric['gt_num'])

def _prepare_dump_dirs(dataloader, result_dir, rank):
    """Create the directory tree for dumped test predictions and return its root.

    ``exist_ok=True`` is used throughout so that several ranks (or a re-run)
    creating the same directory concurrently cannot fail with FileExistsError.
    """
    dataset = dataloader.dataset
    # Nothing in this module writes into dump_vis; the directory is still
    # created so the on-disk layout of result_dir is unchanged.
    vis_dir = os.path.join(result_dir, 'dump_vis')

    if dataset.dataset_name == 'SemanticKITTIDataset':
        dump_dir = os.path.join(result_dir, 'dump_test', 'sequences')
        for test_id in KITTI_TEST_DICT:
            os.makedirs(os.path.join(dump_dir, test_id, 'predictions'), exist_ok=True)
        os.makedirs(vis_dir, exist_ok=True)
    else:
        sample_sequence_list = dataset.sample_sequence_list
        dump_dir = os.path.join(result_dir, 'dump_test')
        if rank == 0:
            os.makedirs(vis_dir, exist_ok=True)
            for name in sample_sequence_list:
                # The last 9 characters of each entry are dropped to get the
                # directory name -- presumably a fixed-length suffix (TODO confirm).
                os.makedirs(os.path.join(dump_dir, name[:-9]), exist_ok=True)

    print(dump_dir)
    return dump_dir


def _dump_kitti_predictions(annos, dump_dir, label_map_inv):
    """Write one ``<seq>/predictions/<frame>.label`` file per SemanticKITTI frame.

    Frame ids are expected to look like ``<seq>_<frame>``. Labels are passed
    through ``label_map_inv`` and stored as uint32; when ``label_map_inv`` is
    None no file is written.
    """
    for pred_dict in annos:
        frame_name = pred_dict['scene_wise']['frame_id']
        seq_id, name_id = frame_name.split('_')
        pred_label = pred_dict['point_wise']['pred_segmentation_label']

        if label_map_inv is not None:
            mapped_pred_label = label_map_inv[pred_label].astype(np.uint32)
            mapped_pred_label.tofile(os.path.join(dump_dir, seq_id, 'predictions', name_id + '.label'))


def _dump_sequence_predictions(annos, dump_dir):
    """Write one ``<seq>/<timestamp>.label`` file (uint32) per predicted frame.

    The sequence directory is the frame id with its last 4 characters removed.
    """
    for pred_dict in annos:
        scene_wise = pred_dict['scene_wise']
        frame_id = scene_wise['frame_id']
        timestamp = scene_wise['timestamp']

        pred_label = pred_dict['point_wise']['pred_segmentation_label'].astype(np.uint32)

        out_dump_dir = os.path.join(dump_dir, frame_id[:-4])
        # Every rank may reach this for the same sequence; exist_ok avoids the race.
        os.makedirs(out_dump_dir, exist_ok=True)
        pred_label.tofile(os.path.join(out_dump_dir, timestamp + '.label'))


def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, save_to_file=False, result_dir=None,
                   dump_test=False, label_map_inv=None, rank=0):
    """Run the model over ``dataloader`` and evaluate the predictions.

    For every batch the model is called under ``torch.no_grad()``, a 23x23
    confusion matrix is accumulated from the point-wise ground-truth and
    predicted segmentation labels, and the dataset converts the outputs into
    prediction dicts. When ``dump_test`` is set, per-frame ``.label`` files are
    also written below ``result_dir/dump_test``. At the end,
    ``confusion_matrix.npy`` and ``result.pkl`` are written to ``result_dir``
    and ``dataset.evaluation`` is called.

    Args:
        cfg: Config object; ``MODEL``, ``LOCAL_RANK`` and ``DATA_CONFIG`` are read.
        model: Callable returning ``(pred_dicts, ret_dict)`` for a batch dict.
        dataloader: Iterable of batch dicts whose ``dataset`` provides
            ``dataset_name``, ``generate_prediction_dicts`` and ``evaluation``.
        epoch_id: Only used in log messages.
        logger: Logger used for progress and result messages.
        dist_test: Wrap the model in ``DistributedDataParallel`` and merge
            annotations and recall metrics across ranks.
        save_to_file: Create ``final_result/data`` and pass it to
            ``generate_prediction_dicts`` as the output path.
        result_dir: Output directory. Required despite the ``None`` default;
            it must support ``.mkdir`` and the ``/`` operator (e.g. a
            ``pathlib.Path``).
        dump_test: Write per-frame label files for test-set submission.
        label_map_inv: Lookup indexed by the predicted labels before they are
            dumped for SemanticKITTI -- presumably a numpy array (TODO confirm).
        rank: Process rank; only rank 0 creates the sequence directories for
            datasets other than SemanticKITTI.

    Returns:
        An empty dict when ``cfg.LOCAL_RANK != 0``. Otherwise a dict holding
        ``recall/roi_<t>`` / ``recall/rcnn_<t>`` (when post-processing is
        configured) merged with every dict returned by ``dataset.evaluation``.
    """
    result_dir.mkdir(parents=True, exist_ok=True)

    final_output_dir = result_dir / 'final_result' / 'data'
    if save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = dataloader.dataset
    is_semantic_kitti = dataset.dataset_name == 'SemanticKITTIDataset'
    dump_dir = _prepare_dump_dirs(dataloader, result_dir, rank) if dump_test else None

    metric = {
        'gt_num': 0,
    }
    if "POST_PROCESSING" in cfg.MODEL:
        for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST:
            metric['recall_roi_%s' % str(cur_thresh)] = 0
            metric['recall_rcnn_%s' % str(cur_thresh)] = 0

    det_annos = []

    # Number of segmentation label ids tracked by the confusion matrix.
    num_seg_classes = 23
    conf_matrix = np.zeros((num_seg_classes, num_seg_classes))

    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
    if dist_test:
        num_gpus = torch.cuda.device_count()
        local_rank = cfg.LOCAL_RANK % num_gpus
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False
        )
    model.eval()

    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
    start_time = time.time()

    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)
        with torch.no_grad():
            pred_dicts, ret_dict = model(batch_dict)
        disp_dict = {}

        # NOTE(review): only the first entry of pred_dicts contributes to the
        # confusion matrix -- confirm that it covers the whole batch.
        point_wise = pred_dicts[0]['point_wise']
        conf_matrix_i = confusion_matrix(
            point_wise['gt_segmentation_label'].cpu().numpy(),
            point_wise['pred_segmentation_label'].cpu().numpy(),
            labels=range(num_seg_classes),
        )
        # Counts are accumulated in units of 1000 points.
        conf_matrix += conf_matrix_i / 1000.0

        if ret_dict is not None:
            statistics_info(cfg, ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, cfg.DATA_CONFIG.get("BOX_CLASSES", None),
            output_path=final_output_dir if save_to_file else None
        )
        det_annos += annos

        if dump_test:
            if is_semantic_kitti:
                _dump_kitti_predictions(annos, dump_dir, label_map_inv)
            else:
                _dump_sequence_predictions(annos, dump_dir)

        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()

    if dist_test:
        _, world_size = common_utils.get_dist_info()
        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

    logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
    sec_per_example = (time.time() - start_time) / len(dataloader.dataset)
    logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example)

    # Only the process with LOCAL_RANK 0 reports and writes results.
    if cfg.LOCAL_RANK != 0:
        return {}

    ret_dict = {}
    if dist_test:
        # After merging, `metric` holds one dict per rank; sum them into the first.
        merged_metric = metric[0]
        for key in merged_metric:
            for k in range(1, world_size):
                merged_metric[key] += metric[k][key]
        metric = merged_metric

    gt_num_cnt = metric['gt_num']
    if "POST_PROCESSING" in cfg.MODEL:
        for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST:
            # max(..., 1) guards against division by zero when there is no ground truth.
            cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
            cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
            logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall))
            logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall))
            ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall
            ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall

    if len(det_annos) > 0 and ('name' in det_annos[0]):
        total_pred_objects = sum(len(anno['name']) for anno in det_annos)
        logger.info('Average predicted number of objects(%d samples): %.3f'
                    % (len(det_annos), total_pred_objects / max(1, len(det_annos))))

    # NOTE(review): conf_matrix is not merged across ranks, so with dist_test
    # the saved matrix only covers the batches this process evaluated.
    np.save(os.path.join(result_dir, 'confusion_matrix.npy'), conf_matrix)

    with open(result_dir / 'result.pkl', 'wb') as f:
        pickle.dump(det_annos, f)
    result_str, result_dict = dataset.evaluation(
        det_annos, cfg.DATA_CONFIG.get("BOX_CLASSES", None),
        output_path=final_output_dir
    )

    for r_str in result_str:
        logger.info(r_str)

    for r_dict in result_dict:
        ret_dict.update(r_dict)

    logger.info('Result is save to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict


# This module only provides evaluation helpers; executing it directly is a no-op.
if __name__ == '__main__':
    pass
