# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
import subprocess
import shlex
import signal
import tempfile

import mmcv
import torch
from torch.utils.data import Dataset, DataLoader
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

import mmdet
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor

# Newer mmdet releases ship setup_multi_processes and compat_cfg themselves,
# older ones rely on the copies in mmdet3d. Probe by import rather than by
# comparing version strings: a plain string comparison is lexicographic
# (e.g. '2.9.0' > '2.23.0' evaluates to True) and therefore unreliable.
try:
    from mmdet.utils import setup_multi_processes
except ImportError:
    from mmdet3d.utils import setup_multi_processes

try:
    from mmdet.utils import compat_cfg
except ImportError:
    from mmdet3d.utils import compat_cfg

import time
import json
import numpy as np
from copy import deepcopy
from tqdm import tqdm

import sys
sys.path.extend(['.', '..'])
from projects.mmdet3d_plugin.utils.timer import CUDATimer
from parse_tegrastats import calculate_average_power_from_log


def parse_args():
    """Parse and validate the command-line arguments of this script.

    On top of the config/checkpoint pair, the options cover the sensor
    timing/power parameters, the optional tegrastats logging and the output
    files for power data, metas, raw outputs and evaluation results.

    Returns:
        argparse.Namespace: The parsed arguments.

    Invalid combinations (wrong output file extension, an already existing
    tegrastats log file, a non-positive sample count, a negative warmup)
    terminate the process through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--camera-min-duration', type=float, required=True)
    parser.add_argument('--camera-boot-time', type=float, required=True)
    parser.add_argument('--camera-power', type=float, required=True)
    parser.add_argument('--lidar-min-duration', type=float, required=True)
    parser.add_argument('--lidar-boot-time', type=float, required=True)
    parser.add_argument('--lidar-power', type=float, required=True)
    # NOTE(review): no default is given, so the value is None when the flag
    # is omitted -- confirm the model accepts that.
    parser.add_argument('--sensor-default-status', type=str, choices=['activated', 'off'])
    parser.add_argument('--reset-min-duration', action='store_true')
    parser.add_argument('--wait-for-booting', action='store_true')
    parser.add_argument('--save-power-data', type=str, default=None, help='json file path for power data')
    parser.add_argument('--save-results', type=str, default=None, help='Output json/pkl file path')
    parser.add_argument('--save-metas', type=str, default=None, help='Output json file path')
    parser.add_argument('--save-outputs', type=str, default=None, help='Model output pkl path')
    parser.add_argument('--samples', type=int, default=None, help='Number of samples to test')
    parser.add_argument('--no-eval', action='store_true', help='Disable evaluation, just test inference time')
    parser.add_argument('--tegrastats', action='store_true', help='Enable a subprocess to run tegrastats (root needed)')
    parser.add_argument('--tegrastats-warmup', type=int, default=20, help='Warmup iterations before running tegrastats')
    parser.add_argument('--tegrastats-interval', type=int, default=100, help='Logging interval for tegrastats in milliseconds')
    parser.add_argument('--tegrastats-logfile', type=str, default=None, help='Path to the logging file for tegrastats')
    # Adjacent string literals are concatenated verbatim, so every fragment
    # of a multi-line help text must carry its own separating space.
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    args = parser.parse_args()

    # Output paths are checked up front so that a long benchmark run does not
    # fail at the very end because of a wrong file extension.
    if args.save_results is not None:
        _, ext = os.path.splitext(args.save_results)
        if ext not in ['.json', '.pkl']:
            parser.error(f'--save-results must be a path to a json file or a pkl file, got {args.save_results}')
    if args.save_metas is not None:
        if not args.save_metas.endswith('.json'):
            parser.error(f'--save-metas must be a path to a json file, got {args.save_metas}')
    if args.save_power_data is not None:
        if not args.save_power_data.endswith('.json'):
            parser.error(f'--save-power-data must be a path to a json file, got {args.save_power_data}')

    # A non-positive sample count would otherwise only surface later as a
    # division by zero or as wrong averages.
    if args.samples is not None and args.samples <= 0:
        parser.error(f'--samples must be a positive integer, got {args.samples}')

    # tegrastats is launched at the iteration whose index equals the warmup
    # value; a negative value would therefore never launch it.
    if args.tegrastats_warmup < 0:
        parser.error(f'--tegrastats-warmup must be non-negative, got {args.tegrastats_warmup}')

    # Refuse to reuse an existing tegrastats log file.
    if args.tegrastats and args.tegrastats_logfile is not None:
        if os.path.exists(args.tegrastats_logfile):
            parser.error(f'logging file of tegrastats exists: {args.tegrastats_logfile}, please use another path')

    return args

def collate_fn(batch_list):
    """Unwrap a batch of exactly one sample and return that sample as-is.

    The data loaders in this script are created with ``batch_size=1``, so
    the batch list is expected to hold a single element.
    """
    batch_size = len(batch_list)
    assert batch_size == 1
    only_sample = batch_list[0]
    return only_sample

# Handle of the background tegrastats subprocess; stays None until main()
# launches it. Shared with signal_handler() and kill_tegrastats().
tegrastats_process = None
# Path of the file tegrastats logs to; main() sets it from
# --tegrastats-logfile or from a temporary file.
tegrastats_logfile = None

def signal_handler(sig, frame):
    """Signal handler that stops tegrastats (if any) and exits with status 1.

    Args:
        sig: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).
    """
    # Only a launched tegrastats process needs to be cleaned up.
    if tegrastats_process is not None:
        kill_tegrastats()
    sys.exit(1)

def kill_tegrastats():
    """Stop the background tegrastats process group, if one was started.

    SIGINT is sent to the whole process group first. If the process has not
    exited within 10 seconds, the group is killed with SIGKILL. Failures are
    reported on stdout and never raised, so this is safe to call from
    cleanup paths. Nothing is done when tegrastats was never launched or has
    already exited.
    """
    if tegrastats_process is None:
        # main() launches tegrastats only once the warmup iteration is
        # reached, so there may be nothing to stop.
        print('tegrastats was never started, nothing to kill.')
        return

    if tegrastats_process.poll() is not None:
        print('tegrastats exited prematurely, log may be incomplete.')
        return

    try:
        pgid = os.getpgid(tegrastats_process.pid)
        os.killpg(pgid, signal.SIGINT)
        tegrastats_process.wait(timeout=10)
        print(f'tegrastats is killed by SIGINT, log file: {tegrastats_logfile}')
    except ProcessLookupError:
        print(f'Cannot find the process group of tegrastats. It may exit prematurely and log may be incomplete.')
    except subprocess.TimeoutExpired:
        # TimeoutExpired can only come from wait(), so pgid is bound here.
        print(f'Killing tegrastats timeouted, using SIGKILL.')
        if tegrastats_process.poll() is None:  # check again
            try:
                os.killpg(pgid, signal.SIGKILL)
                tegrastats_process.wait()
                print(f'tegrastats is killed by SIGKILL, log file: {tegrastats_logfile}')
            except ProcessLookupError:
                # The group vanished between the poll() and the SIGKILL. An
                # exception raised inside this handler is not covered by the
                # sibling ``except`` clauses, so it is handled here.
                print(f'Cannot find the process group of tegrastats. It may exit prematurely and log may be incomplete.')
    except Exception as e:
        print(f'Error killing tegrastats: {e}')

def _import_plugin_module(cfg, config_path):
    """Import the plugin package referenced by the config so registries get updated."""
    import importlib

    # Prefer the explicit plugin_dir; otherwise the import dir is the
    # dirpath of the config file.
    if hasattr(cfg, 'plugin_dir'):
        module_dir = os.path.dirname(cfg.plugin_dir)
    else:
        module_dir = os.path.dirname(config_path)
    # Turn the directory path into a dotted module path (a/b/c -> a.b.c).
    module_path = '.'.join(module_dir.split('/'))
    print(module_path)
    importlib.import_module(module_path)


def _average(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float('nan')


def main():
    """Benchmark the scene detector and report timing, sensor power and accuracy.

    The function
      1. parses the arguments and, when tegrastats is requested, checks for
         root privileges and installs the SIGINT/SIGTERM handlers,
      2. loads the config, imports plugins, builds the three datasets and the
         model and loads the checkpoint,
      3. runs the model sample by sample while timing every forward pass with
         ``CUDATimer`` and launches tegrastats at the warmup iteration,
      4. derives per-sensor average power and forward-time statistics, parses
         the tegrastats log and optionally dumps everything to json,
      5. prints further averages found in the per-sample metas, and
      6. optionally saves outputs/metas and evaluates on ``cfg.data.test``.
    """
    global tegrastats_process, tegrastats_logfile
    from itertools import islice

    args = parse_args()
    tegrastats_logfile = args.tegrastats_logfile

    # check root and register signal handler for killing tegrastats
    if args.tegrastats:
        if os.geteuid() != 0:
            print('Error: Enabling tegrastats requires root privileges.')
            # sys.exit rather than the site-provided exit() helper, which is
            # intended for interactive use.
            sys.exit(1)

        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # import modules from plugin/xx, registry will be updated
    if hasattr(cfg, 'plugin') and cfg.plugin:
        _import_plugin_module(cfg, args.config)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build datasets for all models
    assert cfg.model.type in ['SceneDetectorV2', 'SceneDetectorV3']
    # V2 configs are always run with the V3 implementation.
    if cfg.model.type == 'SceneDetectorV2':
        cfg.model.type = 'SceneDetectorV3'
    # Forward the sensor timing options from the command line to the model.
    cfg.model.camera_min_duration = args.camera_min_duration
    cfg.model.camera_boot_time = args.camera_boot_time
    cfg.model.lidar_min_duration = args.lidar_min_duration
    cfg.model.lidar_boot_time = args.lidar_boot_time
    cfg.model.sensor_default_status = args.sensor_default_status
    cfg.model.reset_min_duration = args.reset_min_duration
    cfg.model.wait_for_booting = args.wait_for_booting
    # The dataset configs are replaced in place by the built dataset objects,
    # so the model config carries the datasets from here on.
    if cfg.model.camera_lidar_detector.get('dataset') is not None:
        cfg.model.camera_lidar_detector.dataset.test_mode = True
        cfg.model.camera_lidar_detector.dataset = build_dataset(cfg.model.camera_lidar_detector.dataset)
    if cfg.model.camera_only_detector.get('dataset') is not None:
        cfg.model.camera_only_detector.dataset.test_mode = True
        cfg.model.camera_only_detector.dataset = build_dataset(cfg.model.camera_only_detector.dataset)
    if cfg.model.scene_classifier.get('dataset') is not None:
        cfg.model.scene_classifier.dataset = build_dataset(cfg.model.scene_classifier.dataset)
    camera_lidar_dataloader = DataLoader(cfg.model.camera_lidar_detector.dataset, batch_size=1, num_workers=0, shuffle=False, collate_fn=collate_fn)
    camera_only_dataloader = DataLoader(cfg.model.camera_only_detector.dataset, batch_size=1, num_workers=0, shuffle=False, collate_fn=collate_fn)
    scene_classifier_dataloader = DataLoader(cfg.model.scene_classifier.dataset, batch_size=1, num_workers=0, shuffle=False, collate_fn=collate_fn)
    assert len(camera_lidar_dataloader) == len(camera_only_dataloader) and len(camera_lidar_dataloader) == len(scene_classifier_dataloader)

    # Never plan for more samples than the datasets can deliver; otherwise
    # the progress bar total and the averages below would be wrong.
    dataset_len = len(camera_lidar_dataloader)
    num_samples = dataset_len if args.samples is None else min(args.samples, dataset_len)
    if num_samples <= 0:
        raise ValueError(f'Number of samples to test must be positive, got {num_samples}')

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        model = model.to(torch.float16)
        wrap_fp16_model(model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    model.eval()
    model = model.cuda()

    outputs = []
    metas = []
    time_total = 0.
    temp_tegrastats_file = None
    loop_time_start = time.perf_counter()
    model.start_sensor_recording()
    try:
        batches = zip(camera_lidar_dataloader, camera_only_dataloader, scene_classifier_dataloader)
        # islice stops after num_samples batches without fetching one more
        # batch, whose loading time would otherwise count towards loop_time.
        for idx, (camera_lidar_data, camera_only_data, scene_classifier_data) in enumerate(
                tqdm(islice(batches, num_samples), total=num_samples)):
            # Launch tegrastats only after the warmup iterations.
            if args.tegrastats and idx == args.tegrastats_warmup:
                if tegrastats_logfile is None:
                    # Only the unique file name is needed; tegrastats writes
                    # to that path itself.
                    temp_tegrastats_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='tegrastats_logging_')
                    tegrastats_logfile = temp_tegrastats_file.name
                    temp_tegrastats_file.close()

                # Build the argument list directly so that a log path
                # containing spaces or quotes stays a single argument.
                tegrastats_command = [
                    'sudo', 'tegrastats',
                    '--interval', str(args.tegrastats_interval),
                    '--logfile', tegrastats_logfile,
                ]
                # Own session/process group, so kill_tegrastats() can signal
                # the whole group.
                tegrastats_process = subprocess.Popen(tegrastats_command, preexec_fn=os.setsid)
                print(f'runing tegrastats with command: {" ".join(tegrastats_command)}')

            with torch.no_grad():
                with CUDATimer('') as t:
                    output = model(camera_lidar_data, camera_only_data, scene_classifier_data)
                infer_time = t.get_time()
                time_total += infer_time
                metas_sample = output[0].pop('metas', None)
                if metas_sample is None:
                    # Still record the timing when the model returns no metas.
                    metas_sample = {}
                metas_sample['infer_time'] = infer_time
                metas.append(metas_sample)
                outputs.extend(output)
        sensor_running_time_dict = model.get_total_sensor_running_time_dict()
        loop_time = time.perf_counter() - loop_time_start
    except Exception:
        # Do not leave tegrastats running in the background when inference
        # fails. SIGINT/SIGTERM are covered by signal_handler instead.
        if tegrastats_process is not None:
            kill_tegrastats()
        raise

    tegrastats_started = tegrastats_process is not None
    if args.tegrastats:
        if tegrastats_started:
            kill_tegrastats()
        else:
            print(f'tegrastats was not started: only {len(metas)} samples were processed, '
                  f'which does not exceed --tegrastats-warmup ({args.tegrastats_warmup}).')

    power_data = dict()
    power_data['mmdet_config_file'] = args.config
    power_data['model_checkpoint_path'] = args.checkpoint
    power_data['lidar_power_single'] = args.lidar_power
    power_data['camera_power_single'] = args.camera_power
    power_data['lidar_min_duration'] = args.lidar_min_duration
    power_data['camera_min_duration'] = args.camera_min_duration
    power_data['total_loop_time'] = loop_time
    power_data['sensor_running_time'] = sensor_running_time_dict
    power_data['sensor_average_power'] = dict()
    for sensor_name, running_time in sensor_running_time_dict.items():
        if 'LIDAR' in sensor_name:
            sensor_power_single = args.lidar_power
        elif 'CAM' in sensor_name:
            sensor_power_single = args.camera_power
        else:
            raise ValueError(f'Unknown sensor type: {sensor_name}')

        # Average power = single-sensor power scaled by the fraction of the
        # loop time the sensor was running.
        power_data['sensor_average_power'][sensor_name] = sensor_power_single * running_time / loop_time
    print(f'averaged sensor power: {power_data["sensor_average_power"]}')

    # Average over the samples that were actually processed.
    num_processed = len(metas)
    time_avg = time_total / num_processed
    infer_time_all = [meta['infer_time'] for meta in metas]
    infer_time_all = np.array(infer_time_all, dtype=np.float64)
    time_std = infer_time_all.std().item()
    power_data['forward_time'] = dict()
    power_data['forward_time']['total'] = time_total
    power_data['forward_time']['avg'] = time_avg
    power_data['forward_time']['std'] = time_std
    print(f'total end-to-end time: {time_total:.3f}s')
    print(f'average end-to-end time: {time_avg:.3f}s')
    print(f'standard deviation of end-to-end time: {time_std:.5f}')

    if args.tegrastats and tegrastats_started:
        try:
            tegrastats_results = calculate_average_power_from_log(tegrastats_logfile)
            power_data['tegrastats'] = tegrastats_results

            # The temporary log is only removed after it was parsed
            # successfully; on failure it is kept.
            if temp_tegrastats_file is not None:
                print(f'removing temp logging file of tegrastats: {tegrastats_logfile}')
                os.remove(tegrastats_logfile)

        except Exception as e:
            # Report the path actually used; args.tegrastats_logfile is None
            # when a temporary file was created.
            print(f'Error parsing logging file of tegrastats ({tegrastats_logfile}): {e}')

    if args.save_power_data is not None:
        with open(args.save_power_data, 'w') as f:
            json.dump(power_data, f, indent=2)
        print(f'power data saved to {args.save_power_data}')

    # The remaining statistics are optional: each is printed only when the
    # first sample's metas carry the corresponding key.
    if metas[0].get('time_data_loading') is not None:
        time_data_loading = 0.
        for meta in metas:
            time_data_loading += meta['time_data_loading']
        time_data_loading_avg = time_data_loading / num_processed
        print(f'average data loading time: {time_data_loading_avg:.3f}')

    if metas[0].get('time_data_processing') is not None:
        time_data_processing = 0.
        for meta in metas:
            time_data_processing += meta['time_data_processing']
        time_data_processing_avg = time_data_processing / num_processed
        print(f'average data processing time: {time_data_processing_avg:.3f}')

    if metas[0].get('infer_time_dict') is not None:
        scene_classifier_infer_time = []
        cmt_infer_time = []
        streampetr_infer_time = []
        for meta in metas:
            infer_time_dict = meta['infer_time_dict']
            for k, v in infer_time_dict.items():
                elapsed_time = v['elapsed']
                if k.startswith('scene_classifier'):
                    scene_classifier_infer_time.append(elapsed_time)
                elif k.startswith('cmt'):
                    cmt_infer_time.append(elapsed_time)
                elif k.startswith('streampetr'):
                    streampetr_infer_time.append(elapsed_time)
                else:
                    raise RuntimeError(f'Unknown key: {k}')

        # A module that never ran has no timings; _average reports NaN for it
        # instead of raising ZeroDivisionError.
        scene_classifier_infer_time_avg = _average(scene_classifier_infer_time)
        cmt_infer_time_avg = _average(cmt_infer_time)
        streampetr_infer_time_avg = _average(streampetr_infer_time)

        print(f'average scene classifier inference time: {scene_classifier_infer_time_avg:.3f}')
        print(f'average CMT inference time: {cmt_infer_time_avg:.3f}')
        print(f'average StreamPETR inference time: {streampetr_infer_time_avg:.3f}')

    if metas[0].get('use_lidar') is not None:
        use_lidar_cnt = 0
        for meta in metas:
            use_lidar_cnt += meta['use_lidar']
        use_lidar_cnt_avg = use_lidar_cnt / len(metas)
        print(f'average used LiDARs: {use_lidar_cnt_avg:.5f} ({use_lidar_cnt}/{len(metas)})')

    if metas[0].get('used_cameras') is not None:
        use_camera_cnt = 0
        for meta in metas:
            use_camera_cnt += len(meta['used_cameras'])
        use_camera_cnt_avg = use_camera_cnt / len(metas)
        print(f'average used cameras: {use_camera_cnt_avg:.5f} ({use_camera_cnt}/{len(metas)})')

    if args.save_outputs is not None:
        mmcv.dump(outputs, args.save_outputs)
        print(f'model outputs saved to {args.save_outputs}')

    if args.save_metas is not None:
        with open(args.save_metas, 'w') as f:
            json.dump(metas, f, indent=2)
        print(f'metas saved to {args.save_metas}')

    if not args.no_eval:
        dataset = build_dataset(cfg.data.test)
        eval_results = dataset.evaluate(outputs, metric='bbox')
        print(eval_results)

        # NOTE(review): --eval-options are only forwarded to format_results(),
        # not to evaluate() above, and --format-only does not skip the
        # evaluation -- confirm this is intended.
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)

        if args.save_results is not None:
            if args.save_results.endswith('.json'):
                with open(args.save_results, 'w') as f:
                    json.dump(eval_results, f, indent=2)
            elif args.save_results.endswith('.pkl'):
                mmcv.dump(eval_results, args.save_results)
            else:
                raise NotImplementedError()
            print(f'eval results saved to {args.save_results}')


# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()
