# Copyright (c) OpenMMLab. All rights reserved.

import argparse
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import os.path as osp
import warnings
from copy import deepcopy

from mmengine import ConfigDict
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.evaluation import DumpDetResults
from mmdet.registry import RUNNERS
from mmdet.utils import setup_cache_size_limit_of_dynamo
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
import torch

# TODO: support fuse_conv_bn and format_only
def parse_args():
    """Build the command-line interface of the test script and parse it.

    Besides parsing, this exports ``LOCAL_RANK`` into the process
    environment (from ``--local_rank`` / ``--local-rank``) when the launcher
    has not already set it.

    Returns:
        argparse.Namespace: The parsed command-line options.
    """
    show_dir_help = ('directory where painted images will be saved. '
                     'If specified, it will be automatically saved '
                     'to the work_dir/timestamp/show_dir')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    cli = argparse.ArgumentParser(description='MMDet test (and eval) a model')

    # Required positionals.
    cli.add_argument('config', help='test config file path')
    cli.add_argument('checkpoint', help='checkpoint file')

    # Output locations.
    cli.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    cli.add_argument(
        '--out',
        type=str,
        help='dump predictions to a pickle file for offline evaluation')

    # Visualisation switches.
    cli.add_argument(
        '--show', action='store_true', help='show prediction results')
    cli.add_argument('--show-dir', help=show_dir_help)
    cli.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')

    # Config overrides and runtime environment.
    cli.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)
    cli.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    cli.add_argument('--tta', action='store_true')

    # With PyTorch >= 2.0.0, `torch.distributed.launch` passes
    # `--local-rank` rather than `--local_rank`, so both spellings are
    # accepted.
    cli.add_argument('--local_rank', '--local-rank', type=int, default=0)

    parsed = cli.parse_args()

    # Keep an existing LOCAL_RANK; otherwise publish the one given on the CLI.
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed


def _enable_tta(cfg):
    """Switch ``cfg`` to test-time augmentation, modifying it in place.

    If ``tta_model`` or ``tta_pipeline`` is missing from the config, a
    default is installed and a warning is emitted. The configured model is
    then wrapped by the TTA model and the test dataset pipeline is replaced
    by the TTA pipeline.

    Args:
        cfg (Config): The loaded test configuration.
    """
    if 'tta_model' not in cfg:
        warnings.warn('Cannot find ``tta_model`` in config, '
                      'we will set it as default.')
        cfg.tta_model = dict(
            type='DetTTAModel',
            tta_cfg=dict(
                nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
    if 'tta_pipeline' not in cfg:
        warnings.warn('Cannot find ``tta_pipeline`` in config, '
                      'we will set it as default.')
        # Descend through nested dataset wrappers to the innermost dataset,
        # whose pipeline is used as the template.
        test_data_cfg = cfg.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']
        cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline)
        flip_tta = dict(
            type='TestTimeAug',
            transforms=[
                [
                    dict(type='RandomFlip', prob=1.),
                    dict(type='RandomFlip', prob=0.)
                ],
                [
                    dict(
                        type='PackDetInputs',
                        meta_keys=('img_id', 'img_path', 'ori_shape',
                                   'img_shape', 'scale_factor', 'flip',
                                   'flip_direction'))
                ],
            ])
        # The last pipeline step is replaced by the flip-based TTA step.
        cfg.tta_pipeline[-1] = flip_tta
    cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
    cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline


def _attach_firing_rate_hooks(model):
    """Register forward hooks that accumulate per-layer output statistics.

    Every ``MultiStepLIFNode`` and ``torch.nn.Identity`` sub-module gets a
    ``name`` attribute (its qualified name, which is also printed) and a
    forward hook. On each forward call the hook adds the mean output value
    and the fraction of non-zero output elements to running sums.

    Args:
        model (torch.nn.Module): The model whose sub-modules are observed.

    Returns:
        tuple[dict, dict]: ``sums`` maps ``(module name, size of the last
        output dimension)`` to ``[sum of output means, sum of non-zero
        fractions]``; ``calls`` maps a module name to the number of forward
        calls observed for it.
    """
    sums = {}
    calls = {}

    def forward_hook_fn(module, inputs, output):
        acts = output.detach()
        calls[module.name] = calls.get(module.name, 0) + 1
        # Accumulate as tensors: no index tensor is materialised and no
        # host/device synchronisation is forced on every call. The mean is
        # taken in float32 so low-precision outputs do not lose accuracy.
        mean_value = acts.float().mean()
        nonzero_fraction = torch.count_nonzero(acts) / acts.numel()
        entry = sums.setdefault((module.name, acts.shape[-1]), [0, 0])
        entry[0] = entry[0] + mean_value
        entry[1] = entry[1] + nonzero_fraction

    for name, module in model.named_modules():
        if isinstance(module, (MultiStepLIFNode, torch.nn.Identity)):
            print(name)
            module.name = name
            module.register_forward_hook(forward_hook_fn)
    return sums, calls


def _average_firing_stats(sums, calls):
    """Turn the accumulated sums into per-call averages.

    Each entry is divided by the total number of forward calls of its module
    (over all output widths), so the entries of one module add up to that
    module's average over the whole run.

    Args:
        sums (dict): Running sums from ``_attach_firing_rate_hooks``.
        calls (dict): Per-module call counts from the same function.

    Returns:
        tuple[dict, dict]: ``fr_dict`` (average output mean) and ``nz_dict``
        (average non-zero fraction), both keyed by
        ``'<module name>_<last output dimension>'`` with float values.
    """
    fr_dict = {}
    nz_dict = {}
    for (name, width), (mean_sum, nonzero_sum) in sums.items():
        key = f'{name}_{width}'
        fr_dict[key] = float(mean_sum) / calls[name]
        nz_dict[key] = float(nonzero_sum) / calls[name]
    return fr_dict, nz_dict


def main():
    """Test a detector and report its complexity and per-layer firing rates.

    Loads the config and checkpoint given on the command line, builds the
    runner, prints FLOPs and parameter count for a 3x224x224 input, runs the
    test loop and finally prints the averaged firing statistics collected
    from the spiking-neuron (and identity) layers.
    """
    args = parse_args()

    # Reduce the number of repeated compilations and improve
    # testing speed.
    setup_cache_size_limit_of_dynamo()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.load_from = args.checkpoint

    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    if args.tta:
        _enable_tta(cfg)

    # build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    # add `DumpResults` dummy metric
    if args.out is not None:
        assert args.out.endswith(('.pkl', '.pickle')), \
            'The dump file must be a pkl file.'
        runner.test_evaluator.metrics.append(
            DumpDetResults(out_file_path=args.out))

    # Measure complexity BEFORE the statistics hooks are attached: the
    # analysis runs the model on a dummy input, and that pass must not be
    # counted as test data in the firing statistics.
    complexity = get_model_complexity_info(
        runner.model,
        input_shape=(3, 224, 224),
        show_table=False,
        show_arch=False)
    print("flops:{}".format(_format_size(complexity['flops'])))
    print("params:{}".format(_format_size(complexity['params'])))

    sums, calls = _attach_firing_rate_hooks(runner.model)

    # start testing
    runner.test()

    # Normalise by the number of forward calls actually observed instead of
    # a hard-coded iteration count, so any dataset size gives true averages.
    fr_dict, nz_dict = _average_firing_stats(sums, calls)
    print("fr_dict", fr_dict)
    print("nz_dict", nz_dict)



# Run the test pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

