# ------------------------------------------------------------------------
# Copyright (c) 2023 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
# Copyright (c) 2021 Wang, Yue
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------

import numpy as np
import mmcv
from mmdet.datasets import DATASETS
from mmdet3d.datasets import NuScenesDataset
from .custom_nuscenes_dataset import CustomNuScenesDataset

from nuscenes import NuScenes
from nuscenes.eval.detection.evaluate import NuScenesEval
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
    DetectionMetricDataList
from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
from nuscenes.eval.detection.algo import accumulate, calc_ap, calc_tp
from nuscenes.eval.detection.constants import TP_METRICS

import random
import os
import os.path as osp
import json
import time
from typing import Dict, Any, Tuple


class NuScenesEvalSingleScene(NuScenesEval):
    """NuScenes detection evaluation computed independently for every sample.

    Instead of pooling all predictions and ground truths into one global
    evaluation, every sample token found in the ground truth is evaluated on
    its own. The results are one ``DetectionMetrics`` and one
    ``DetectionMetricDataList`` per sample token.
    """

    def evaluate(self) -> Tuple[Dict[str, DetectionMetrics], Dict[str, DetectionMetricDataList]]:
        """
        Evaluate every ground-truth sample token in isolation.
        :return: A tuple ``(metrics, metric_data_list)`` of dicts keyed by sample token, holding the
            per-sample high-level metrics and the per-sample raw metric data respectively.
        """
        metrics = dict()
        metric_data_list = dict()

        # Report once here instead of once per sample inside evaluate_single(), which would
        # flood the output (and interleave with the progress bar) when verbose is enabled.
        if self.verbose:
            print('Evaluating %d samples individually...' % len(self.gt_boxes.sample_tokens))

        for token in mmcv.track_iter_progress(self.gt_boxes.sample_tokens):
            # Restrict both box collections to this single sample.
            pred_boxes = EvalBoxes()
            pred_boxes.add_boxes(token, self.pred_boxes[token])

            gt_boxes = EvalBoxes()
            gt_boxes.add_boxes(token, self.gt_boxes[token])

            metrics[token], metric_data_list[token] = self.evaluate_single(gt_boxes, pred_boxes)
        return metrics, metric_data_list

    def evaluate_single(self, gt_boxes: EvalBoxes, pred_boxes: EvalBoxes) -> Tuple[DetectionMetrics, DetectionMetricDataList]:
        """
        Run the standard nuScenes metric computation on the given box collections.
        :param gt_boxes: Ground-truth boxes (here: those of a single sample).
        :param pred_boxes: Predicted boxes (here: those of a single sample).
        :return: A tuple of the high-level metrics and the raw metric data.
        """
        start_time = time.time()

        # Step 1: accumulate matching data for every class and distance threshold.
        metric_data_list = DetectionMetricDataList()
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                md = accumulate(gt_boxes, pred_boxes, class_name, self.cfg.dist_fcn_callable, dist_th)
                metric_data_list.set(class_name, dist_th, md)

        # Step 2: derive AP and TP-error metrics from the accumulated data.
        metrics = DetectionMetrics(self.cfg)
        for class_name in self.cfg.class_names:
            # One AP per distance threshold.
            for dist_th in self.cfg.dist_ths:
                metric_data = metric_data_list[(class_name, dist_th)]
                ap = calc_ap(metric_data, self.cfg.min_recall, self.cfg.min_precision)
                metrics.add_label_ap(class_name, dist_th, ap)

            # TP errors are computed at the single TP distance threshold.
            for metric_name in TP_METRICS:
                metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
                # These class/metric combinations are deliberately excluded (reported as NaN).
                if class_name in ['traffic_cone'] and metric_name in ['attr_err', 'vel_err', 'orient_err']:
                    tp = np.nan
                elif class_name in ['barrier'] and metric_name in ['attr_err', 'vel_err']:
                    tp = np.nan
                else:
                    tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
                metrics.add_label_tp(class_name, metric_name, tp)

        metrics.add_runtime(time.time() - start_time)

        return metrics, metric_data_list

    def main(self,
             plot_examples: int = 0,
             render_curves: bool = False) -> Dict[str, Any]:
        """
        Run the per-sample evaluation and dump the results to ``self.output_dir``.
        Writes ``metrics_summary.json`` and ``metrics_details.json``.
        :param plot_examples: Must be 0; example visualizations are not implemented here.
        :param render_curves: Must be False; PR/TP curve rendering is not implemented here.
            Defaults to False so that ``main()`` works when called without arguments.
        :return: ``{'metrics': {sample_token: serialized DetectionMetrics}, 'meta': <copy of self.meta>}``,
            identical to the content written to ``metrics_summary.json``.
        :raises NotImplementedError: If ``plot_examples > 0`` or ``render_curves`` is True.
        """
        # Validate before running the (potentially long) evaluation.
        if plot_examples > 0:
            raise NotImplementedError('plot_examples is not supported for per-sample evaluation.')
        if render_curves:
            raise NotImplementedError('render_curves is not supported for per-sample evaluation.')

        metrics, metric_data_list = self.evaluate()

        if self.verbose:
            print('Saving metrics to: %s' % self.output_dir)

        metrics_summary = dict(
            metrics={token: metrics_sample.serialize() for token, metrics_sample in metrics.items()})
        metrics_summary['meta'] = self.meta.copy()
        with open(osp.join(self.output_dir, 'metrics_summary.json'), 'w') as f:
            json.dump(metrics_summary, f, indent=2)

        metrics_details = {
            token: metric_data_list_sample.serialize()
            for token, metric_data_list_sample in metric_data_list.items()
        }
        with open(osp.join(self.output_dir, 'metrics_details.json'), 'w') as f:
            json.dump(metrics_details, f, indent=2)

        # No global printout here: the summary holds one serialized metrics entry per sample
        # token rather than dataset-level aggregates.
        return metrics_summary


@DATASETS.register_module()
class SceneEvalNuScenesDataset(CustomNuScenesDataset):
    """NuScenes dataset whose evaluation reports metrics for every sample.

    Args:
        eval_set (str, optional): nuScenes split passed to the evaluator.
            When None, it is derived from ``self.version``
            ('v1.0-mini' -> 'mini_val', 'v1.0-trainval' -> 'val').
            Default: None.
    """

    def __init__(self, *args, eval_set=None, **kwargs):
        super(SceneEvalNuScenesDataset, self).__init__(*args, **kwargs)
        self.eval_set = eval_set

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 result_names=('pts_bbox',),
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluation in nuScenes protocol, computed per sample.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str], optional): Metrics to be evaluated.
                Currently unused. Default: 'bbox'.
            logger (logging.Logger | str, optional): Currently unused.
                Default: None.
            jsonfile_prefix (str, optional): The prefix of json files including
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            result_names (Sequence[str], optional): Keys of the formatted
                result files to evaluate when ``format_results`` returns a
                dict. Default: ('pts_bbox',).
            show (bool, optional): Whether to visualize.
                Default: False.
            out_dir (str, optional): Path to save the visualization results.
                Default: None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: The summary loaded from ``metrics_summary.json``, i.e.
            ``{'metrics': {sample_token: ...}, 'meta': ...}``.

        Raises:
            TypeError: If ``format_results`` returns neither a dict nor a str.
        """
        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

        try:
            if isinstance(result_files, dict):
                results_dict = dict()
                for name in result_names:
                    print('Evaluating bboxes of {}'.format(name))
                    ret_dict = self._evaluate_single(
                        result_files[name], result_name=name)
                    # Merge inside the loop so each evaluated result is
                    # accumulated and an empty ``result_names`` yields {}.
                    # NOTE: _evaluate_single returns the same top-level keys
                    # ('metrics', 'meta') for every name, so later names
                    # overwrite earlier ones.
                    results_dict.update(ret_dict)
            elif isinstance(result_files, str):
                results_dict = self._evaluate_single(result_files)
            else:
                raise TypeError(
                    'Unexpected result_files type: {}'.format(
                        type(result_files).__name__))
        finally:
            # Remove temporary result files even if evaluation fails.
            if tmp_dir is not None:
                tmp_dir.cleanup()

        if show or out_dir:
            self.show(results, out_dir, show=show, pipeline=pipeline)
        return results_dict

    def _evaluate_single(self,
                         result_path,
                         logger=None,
                         metric='bbox',
                         result_name='pts_bbox'):
        """Per-sample evaluation for a single model in nuScenes protocol.

        Args:
            result_path (str): Path of the result file. The evaluation
                outputs are written to its directory.
            logger (logging.Logger | str, optional): Currently unused.
                Default: None.
            metric (str, optional): Currently unused. Default: 'bbox'.
            result_name (str, optional): Currently unused.
                Default: 'pts_bbox'.

        Returns:
            dict: The content of ``metrics_summary.json`` written by
            ``NuScenesEvalSingleScene.main``:
            ``{'metrics': {sample_token: ...}, 'meta': ...}``.

        Raises:
            KeyError: If ``self.eval_set`` is None and ``self.version`` has
                no default evaluation split.
        """
        output_dir = osp.dirname(result_path)

        # Resolve the split before loading the (slow) NuScenes database.
        eval_set = self.eval_set
        if eval_set is None:
            eval_set_map = {
                'v1.0-mini': 'mini_val',
                'v1.0-trainval': 'val',
            }
            if self.version not in eval_set_map:
                raise KeyError(
                    'No default eval_set for nuScenes version {!r}; pass '
                    'eval_set explicitly.'.format(self.version))
            eval_set = eval_set_map[self.version]

        nusc = NuScenes(
            version=self.version, dataroot=self.data_root, verbose=False)
        nusc_eval = NuScenesEvalSingleScene(
            nusc,
            config=self.eval_detection_configs,
            result_path=result_path,
            eval_set=eval_set,
            output_dir=output_dir,
            verbose=False)
        nusc_eval.main(render_curves=False)

        # The summary is keyed per sample token, so the usual flat
        # "<prefix>/mAP"-style detail dict is not built here.
        return mmcv.load(osp.join(output_dir, 'metrics_summary.json'))