# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
                         replace_cfg_vals, rfnext_init_model,
                         setup_multi_processes, update_data_root)
from mqb_general_process import make_qmodel_for_mmd, prepocess
from mqbench.utils.state import *
import global_placeholder
from mqb_general_process import *
from copy import deepcopy


def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            previous behaviour of this function.

    Returns:
        argparse.Namespace: Namespace with ``config`` (test config file
        path) and ``prediction_path`` (path of the pickled predictions).
    """
    parser = argparse.ArgumentParser(
        description='Format MMDet test predictions into result files')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'prediction_path',
        help='path to the .pkl file holding the test predictions')

    return parser.parse_args(argv)


def main():
    """Format a pickled prediction file with the test dataset's formatter.

    Loads the config given on the command line and builds the test dataset
    from ``cfg.data.test`` in test mode. It then loads the predictions from
    ``prediction_path`` and calls ``dataset.format_results`` with the
    prefix ``<dirname(prediction_path)>/result``. No inference is run.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.device = get_device()

    # This script never calls `compat_cfg`, so `cfg.data.test_dataloader`
    # may be absent. Read it with `.get` instead of attribute access, which
    # could raise for a missing key.
    samples_per_gpu = cfg.data.get('test_dataloader', {}).get(
        'samples_per_gpu', 1)

    # NOTE: setting test_mode is critical: it decides whether the dataset
    # contains 5000 or 4952 images (presumably this count has to match the
    # number of predictions in the .pkl file -- verify for your dataset).
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        # The test dataset is a concatenation of several dataset configs.
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
            if samples_per_gpu > 1:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # Only the dataset object is needed to format results; no dataloader is
    # built because nothing is iterated here.
    dataset = build_dataset(cfg.data.test)

    # NOTE(review): mmcv.load picks its handler from the file extension; for
    # a .pkl file this presumably unpickles, so only load trusted files.
    outputs = mmcv.load(args.prediction_path)

    # Result files are written next to the prediction file, using the
    # prefix 'result'.
    out_dir = os.path.dirname(args.prediction_path)
    jsonfile_prefix = os.path.join(out_dir, 'result')
    print(jsonfile_prefix)
    dataset.format_results(outputs, jsonfile_prefix)


def set_random_seed(seed, deterministic=True):
    """Set random seed for Python, NumPy and PyTorch (CPU and all GPUs).

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            When False, the CUDNN flags are left untouched.
            Default: True.
    """
    import random

    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Entry point: only run when executed as a script, not on import.
    main()
