# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
                         replace_cfg_vals, rfnext_init_model,
                         setup_multi_processes, update_data_root)
from mqb_general_process import make_qmodel_for_mmd, prepocess
from mqbench.utils.state import *
import global_placeholder
from mqb_general_process import *
from copy import deepcopy


def parse_args():
    """Parse command-line arguments for testing a (optionally quantized) detector.

    Side effects: exports ``LOCAL_RANK`` to the environment when it is not
    already set. It also folds the deprecated ``--options`` into
    ``--eval-options``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    # Positional, hence always required; the ``default`` is never used.
    parser.add_argument('quant_config', default=None, help='quant config file path')
    parser.add_argument('--aqd-mode', type=int, default=0, help='when bigger than 0 , it means switch on aqd, and equals the neck output level num')
    parser.add_argument('--range-reg',
        action='store_true', help='introduce range regularization')
    parser.add_argument('--amplify-outputs',
        action='store_true', help='introduce amplify outputs mode')
    parser.add_argument('--quantize',
        action='store_true', help='quant flag')
    parser.add_argument('--seed', type=int, default=None, help='random seed')

    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer torch.distributed launchers pass ``--local-rank`` instead of
    # ``--local_rank``; accept both spellings (dest stays ``local_rank``).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def main():
    """Test (and optionally evaluate) a possibly quantized MMDetection model.

    The steps are:

    1. Parse the CLI and normalize the mmdet config.
    2. Build the test dataset and the detector.
    3. With ``--quantize``, optionally convert the detector into a quantized
       model via ``make_qmodel_for_mmd``.
    4. Load the checkpoint and run single-GPU inference.
    5. On rank 0, dump, format and/or evaluate the results.

    Raises:
        ValueError: On incompatible CLI options, or when ``--aqd-mode`` does
            not match the neck's ``num_outs``.
        NotImplementedError: When a distributed launcher is requested; only
            non-distributed testing is implemented.
    """
    args = parse_args()
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    global_placeholder.modify_logger(timestamp, 'work_dirs/debug_garbage')
    # Imported after ``modify_logger`` so the name binds to the logger it
    # configured (presumably modify_logger (re)creates it — confirm).
    from global_placeholder import logger

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    # Only single-process testing is implemented below, so fail before any
    # expensive work (dataset/model building, checkpoint loading) is done.
    if args.launcher != 'none':
        raise NotImplementedError(
            'Distributed testing is not supported by this script; '
            'use --launcher none.')
    distributed = False

    cfg = Config.fromfile(args.config)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)
    # torch.manual_seed rejects None, so only seed when --seed is given.
    if args.seed is not None:
        set_random_seed(args.seed)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # Weights are loaded from the checkpoint below, so drop pretrained / init
    # settings from the model config.
    if 'pretrained' in cfg.model:
        cfg.model.pretrained = None
    elif 'init_cfg' in cfg.model.backbone:
        cfg.model.backbone.init_cfg = None

    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]
    cfg.device = get_device()

    test_dataloader_default_args = dict(
        samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    test_loader_cfg = {
        **test_dataloader_default_args,
        **cfg.data.get('test_dataloader', {})
    }

    rank, _ = get_dist_info()
    # Target file for the evaluation metrics; only rank 0 with --work-dir
    # writes one.
    json_file = None
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # Process-wide flags consumed elsewhere through ``global_placeholder``.
    if args.amplify_outputs:
        global_placeholder.modify_amplify_outputs_flag(True)
    global_placeholder.modify_model_type(cfg.model.type)

    # build the model
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    if args.aqd_mode != 0:
        global_placeholder.modify_AQD_mode(args.aqd_mode)

    # For a QAT model the quantized model structure must exist before the
    # checkpoint is loaded, so convert the float model here.
    if args.quantize:
        quant_config = prepocess(args.quant_config)
        logger.info(quant_config)
        global_placeholder.modify_quant_bit(
            quant_config.extra_prepare_dict.extra_qconfig_dict.w_qscheme.bit)
        global_placeholder.modify_quant_algorithm(
            quant_config.quantize.quant_algorithm)
        global_placeholder.modify_buff_flag(quant_config.training.my_buff_flag)
        global_placeholder.modify_fold_bn_flag(
            quant_config.training.fold_bn_flag)

        # NOTE(review): presumably the conversion/tracing expects train mode;
        # confirm before changing.
        model.train()
        model = make_qmodel_for_mmd(model, quant_config, cfg.trace_config)

    # AQD mode must match the number of neck output levels.
    if global_placeholder.aqd_mode != 0 and model.neck.num_outs != global_placeholder.aqd_mode:
        raise ValueError(f'num levels给的不对! aqd_mode={global_placeholder.aqd_mode} 而 neck.num_outs={model.neck.num_outs}')

    # init rfnext if 'RFSearchHook' is defined in cfg
    rfnext_init_model(model, cfg=cfg)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')
    if fp16_cfg is not None:
        wrap_fp16_model(model)

    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # Helper from mqb_general_process; presumably remaps state-dict keys per
    # model type before loading — confirm against its definition.
    checkpoint = load_ckpt_with_revise_keys(
        model, cfg.model.type, args.checkpoint, map_location='cpu',
        is_logging=False)
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    # Old versions did not save class info in checkpoints; fall back to the
    # dataset's classes for backward compatibility.
    classes = (checkpoint.get('meta') or {}).get('CLASSES', dataset.CLASSES)
    model.CLASSES = classes
    # ``model`` is wrapped by ``build_dp`` at this point. mmdet's
    # single_gpu_test draws results via ``model.module.show_result``, which
    # reads ``CLASSES`` from the wrapped detector, so set it there too.
    getattr(model, 'module', model).CLASSES = classes

    if args.quantize:
        enable_quantization(model)

    model.eval()
    outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                              args.show_score_thr)

    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule', 'dynamic_intervals'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            metric = dataset.evaluate(outputs, **eval_kwargs)
            print(metric)
            metric_dict = dict(config=args.config, metric=metric)
            if json_file is not None:
                mmcv.dump(metric_dict, json_file)


def set_random_seed(seed, deterministic=True):
    """Set random seed.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int | None): Seed to be used. ``None`` leaves the random number
            generators untouched; ``torch.manual_seed`` would reject it.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: True.
    """
    import random

    import numpy as np

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: run the test / evaluation pipeline.
    main()
