# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
                         replace_cfg_vals, rfnext_init_model,
                         setup_multi_processes, update_data_root)
from mqb_general_process import make_qmodel_for_mmd, prepocess
from mqbench.utils.state import *
import global_placeholder
from mqb_general_process import *
from copy import deepcopy
import numpy as np
from plot_curve_v2 import save_distribution
from tqdm import tqdm


def parse_args():
    """Parse the command line of this analysis script.

    Returns:
        argparse.Namespace: the parsed arguments. When the deprecated
        ``--options`` is given, its value is copied into ``eval_options``.
        ``LOCAL_RANK`` is exported from ``--local_rank`` if not already set.

    Raises:
        ValueError: if both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    # Required positional. argparse ignores ``default`` for a positional
    # without ``nargs='?'``; making it optional would change how argparse
    # matches positionals interleaved with options, so it stays required.
    parser.add_argument(
        'quant_config',
        help='quant config file path (only read with --quantize or '
        '--range-reg)')
    parser.add_argument(
        '--aqd-mode',
        type=int,
        default=0,
        help='when greater than 0, switches on AQD; the value must equal the '
        'number of neck output levels')
    parser.add_argument(
        '--range-reg',
        action='store_true', help='introduce range regularization')
    parser.add_argument(
        '--amplify-outputs',
        action='store_true', help='introduce amplify outputs mode')
    parser.add_argument(
        '--quantize',
        action='store_true', help='quant flag')
    parser.add_argument('--seed', type=int, default=None, help='random seed')

    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--dir-name',
        help='the directory to save the file')

    parser.add_argument(
        '--draw-distribution',
        action='store_true',
        help='switch the mode to draw distribution for the first img(batch=1)')
    parser.add_argument(
        '--full-eval',
        action='store_true', help='switch the mode to full-eval')
    parser.add_argument(
        '--save-act',
        action='store_true', help='switch the mode to save-act')

    parser.add_argument(
        '--save-num',
        type=int,
        default=-1,
        help='the number of the first k samples to save act data')

    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def main():
    """Build the detector, attach statistics hooks and run one analysis mode.

    Steps, all driven by the CLI (see ``parse_args``):

    1. Load and patch the mmdet config, then build the test dataset,
       dataloader and detector.
    2. With ``--quantize``/``--range-reg``, read the quant config, push its
       settings into ``global_placeholder`` and convert the model with
       ``make_qmodel_for_mmd``.
    3. Load the checkpoint (``load_ckpt_with_revise_keys``), optionally fuse
       conv+bn, and call ``enable_quantization`` when ``--quantize`` is set.
    4. Attach forward hooks to ``nn.Conv2d`` layers (``hook_hooker``) and run
       exactly one of ``--draw-distribution`` / ``--full-eval`` /
       ``--save-act``, checked in that order. Results are written under
       ``<work_dir>/<dir_name>``.

    Raises:
        ValueError: on invalid CLI combinations, an unsupported model type,
            or an ``--aqd-mode`` that disagrees with ``model.neck.num_outs``.
        NotImplementedError: for distributed launchers.
    """
    args = parse_args()
    from global_placeholder import logger
    # NOTE(review): none of --out/--eval/--format-only/--show/--show-dir is
    # consumed later in this function, but one of them is still required.
    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    # Fail fast, before the expensive dataset/model construction below:
    # hook_hooker needs a mode to pick a hook class, and every mode writes
    # under <work_dir>/<dir_name>.
    if not (args.draw_distribution or args.full_eval or args.save_act):
        raise ValueError(
            'Please specify one analysis mode: "--draw-distribution", '
            '"--full-eval" or "--save-act"')
    if args.work_dir is None or args.dir_name is None:
        raise ValueError(
            '"--work-dir" and "--dir-name" are required: results are written '
            'to <work-dir>/<dir-name>')

    cfg = Config.fromfile(args.config)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)
    set_random_seed(args.seed)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if 'pretrained' in cfg.model:
        cfg.model.pretrained = None
    elif 'init_cfg' in cfg.model.backbone:
        cfg.model.backbone.init_cfg = None

    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]
    cfg.device = get_device()
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    test_dataloader_default_args = dict(
        samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    test_loader_cfg = {
        **test_dataloader_default_args,
        **cfg.data.get('test_dataloader', {})
    }

    rank, _ = get_dist_info()
    if rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    # global setting
    if args.aqd_mode != 0:
        global_placeholder.modify_AQD_mode(args.aqd_mode)

    # For a QAT model the quantized structure must exist before the
    # checkpoint is loaded further below.
    if args.quantize or args.range_reg:
        quant_config = prepocess(args.quant_config)
        logger.info(quant_config)
        global_placeholder.modify_quant_bit(quant_config.extra_prepare_dict.extra_qconfig_dict.w_qscheme.bit)
        global_placeholder.modify_quant_algorithm(quant_config.quantize.quant_algorithm)
        global_placeholder.modify_buff_flag(quant_config.training.my_buff_flag)
        global_placeholder.modify_range_reg_flag(quant_config.training.range_reg_flag)
        global_placeholder.modify_fold_bn_flag(quant_config.training.fold_bn_flag)

        model.train()
        model = make_qmodel_for_mmd(model, quant_config, cfg.trace_config)

    # Check that the AQD level count matches the neck's number of outputs.
    if global_placeholder.aqd_mode != 0 and model.neck.num_outs != global_placeholder.aqd_mode:
        raise ValueError(f'num levels给的不对! aqd_mode={global_placeholder.aqd_mode} 而 neck.num_outs={model.neck.num_outs}')

    # init rfnext if 'RFSearchHook' is defined in cfg
    rfnext_init_model(model, cfg=cfg)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')
    if fp16_cfg is not None:
        wrap_fp16_model(model)

    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    checkpoint = load_ckpt_with_revise_keys(model, cfg.model.type, args.checkpoint, map_location='cpu', is_logging=False)
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if args.quantize:
        enable_quantization(model)

    model.eval()
    if distributed:
        raise NotImplementedError(
            'Statistics collection only supports non-distributed runs '
            '(--launcher none).')

    # How many times a head conv fires per image; the hooks use it to group
    # calls. Presumably the number of FPN levels sharing the head -- confirm
    # per architecture.
    if cfg.model.type in ['RetinaNet', 'ATSS']:
        args.max_share_count = 5
    elif cfg.model.type in ['SingleStageDetector', 'YOLOX']:
        args.max_share_count = 1
    else:
        raise ValueError(
            f'Unsupported model type {cfg.model.type!r}: max_share_count is '
            'only defined for RetinaNet, ATSS, SingleStageDetector and YOLOX.')

    forward_hooker = hook_hooker(args, model)

    if args.draw_distribution:
        _run_draw_distribution(args, model, data_loader, forward_hooker)
    elif args.full_eval:
        _run_full_eval(args, model, data_loader, forward_hooker)
    elif args.save_act:
        _run_save_act(args, model, data_loader, forward_hooker)


def _run_draw_distribution(args, model, data_loader, forward_hooker):
    """Run one batch, plot every hooked tensor and dump per-tensor statistics."""
    # A single forward pass fills forward_hooker.record_data_dict.
    for data in data_loader:
        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)
        break

    out_root = os.path.join(args.work_dir, args.dir_name)
    fuse_subdir = 'fuse' if args.fuse_conv_bn else 'ori'
    # Head conv weights grouped by branch, for the comparison plot.
    weight_data_in_head = {
        'cls': [],
        'reg': [],
        'other': [],  # centerness/objectness
    }
    record_data_dict = forward_hooker.record_data_dict
    with tqdm(total=len(record_data_dict)) as pbar:
        for name, data in record_data_dict.items():
            pbar.update(1)
            # One sub-folder per part of the network.
            if 'backbone' in name:
                save_path = os.path.join(out_root, 'backbone')
            elif 'neck' in name:
                save_path = os.path.join(out_root, 'neck')
            elif 'head' in name:
                save_path = os.path.join(out_root, 'head')
                if 'weight' in name:
                    if 'cls' in name:
                        weight_data_in_head['cls'].append(data)
                    elif 'reg' in name:
                        weight_data_in_head['reg'].append(data)
                    elif 'centerness' in name or 'objectness' in name:
                        weight_data_in_head['other'].append(data)
            else:
                save_path = out_root

            # Separate fused and original (unfused) runs.
            save_path = os.path.join(save_path, fuse_subdir)
            os.makedirs(save_path, exist_ok=True)
            save_distribution(data, f'{save_path}/{name}.jpg', name)

    # Weight comparison plot of the head branches.
    draw_weight_line_plot(
        weight_data_in_head, os.path.join(out_root, 'weight_compa.jpg'))

    mmcv.dump(
        forward_hooker.record_statistics_dict,
        f'{out_root}/statistics_fuse_{args.fuse_conv_bn}.json',
    )
    print('Dump statistics')


def _run_full_eval(args, model, data_loader, forward_hooker):
    """Run the whole test set and dump statistics aggregated over all images."""
    with tqdm(total=len(data_loader)) as pbar:
        for data in data_loader:
            with torch.no_grad():
                model(return_loss=False, rescale=True, **data)
            pbar.update(1)
    parsed_data_dict = forward_hooker.parse_data(forward_hooker.record_statistics_dict)
    save_path = os.path.join(args.work_dir, args.dir_name)
    os.makedirs(save_path, exist_ok=True)
    mmcv.dump(
        parsed_data_dict,
        f'{save_path}/full-eval-statistics_fuse_{args.fuse_conv_bn}.json',
    )


def _run_save_act(args, model, data_loader, forward_hooker):
    """Run the first ``--save-num`` batches and pickle the recorded activations."""
    import pickle

    with tqdm(total=len(data_loader)) as pbar:
        for i, data in enumerate(data_loader):
            with torch.no_grad():
                model(return_loss=False, rescale=True, **data)
            pbar.update(1)
            # With save_num <= 0 this never matches: the whole set is used.
            if i == args.save_num - 1:
                break
    save_path = os.path.join(args.work_dir, args.dir_name)
    os.makedirs(save_path, exist_ok=True)
    # NOTE: despite the .npz suffix this file is a plain pickle of
    # forward_hooker.record_data_dict; read it back with pickle.load.
    with open(f'{save_path}/save-act_{args.save_num}_fuse_{args.fuse_conv_bn}.npz', 'wb') as tarfile:
        pickle.dump(forward_hooker.record_data_dict, tarfile)
        
def hook_hooker(args, model):
    """Attach the statistics hook selected by ``args`` to ``model``'s convs.

    The mode is chosen in the order ``draw_distribution``, ``full_eval``,
    ``save_act``:

    * draw/full-eval: every ``torch.nn.Conv2d`` gets a hook instance; convs
      whose name contains ``'head'`` get ``max_share_count=args.max_share_count``.
    * save-act: only convs whose name contains one of the hard-coded target
      layer names below get a ``Save_act_hooker``.

    Args:
        args: parsed CLI namespace. Reads ``draw_distribution``, ``full_eval``,
            ``save_act``, ``max_share_count`` and, in save-act mode, ``save_num``.
        model (torch.nn.Module): model whose conv submodules get forward hooks.

    Returns:
        type: the hook *class* used. Recorded data lives in its class-level
        dicts (``record_data_dict`` / ``record_statistics_dict``).

    Raises:
        ValueError: if none of the three analysis modes is enabled.
    """
    if args.draw_distribution:
        # Single-image visualization.
        forward_hooker = Draw_distribution_hooker
    elif args.full_eval:
        forward_hooker = Full_eval_hooker
    elif args.save_act:
        forward_hooker = Save_act_hooker
        # Only the selected head layers are hooked in this mode.
        target_layer_names = [
            'module.bbox_head.cls_convs.3.conv',
            'module.bbox_head.reg_convs.3.conv',
        ]
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                for tname in target_layer_names:
                    if tname in name:
                        module.register_forward_hook(forward_hooker(
                            name,
                            max_share_count=args.max_share_count,
                            max_save_count=args.save_num))
        return forward_hooker
    else:
        # Without this branch forward_hooker would be unbound below.
        raise ValueError(
            'No analysis mode selected: pass one of --draw-distribution, '
            '--full-eval or --save-act.')

    # Regular hooking: every conv; head convs may fire several times per image.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            if 'head' in name:
                module.register_forward_hook(
                    forward_hooker(name, max_share_count=args.max_share_count))
            else:
                module.register_forward_hook(forward_hooker(name))
    return forward_hooker







def draw_weight_line_plot(weight_data_in_head, compa_save_path):
    """Plot per-layer max|w| and std(w) of the head's cls/reg branches and save it.

    ``'<'`` markers show ``max(|w|)`` and ``'o'`` markers show ``std(w)`` for
    each layer index (1-based). Blue is the cls branch and orange the reg branch.

    Args:
        weight_data_in_head (dict): maps ``'cls'`` and ``'reg'`` to equally
            long lists of per-layer weight arrays (NumPy). Other keys, such as
            ``'other'``, are ignored.
        compa_save_path (str): path of the image file to write.

    Raises:
        ValueError: if the cls and reg lists differ in length.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    cls_weights = weight_data_in_head['cls']
    reg_weights = weight_data_in_head['reg']
    if len(cls_weights) != len(reg_weights):
        raise ValueError(
            f'cls/reg head weight counts differ: '
            f'{len(cls_weights)} vs {len(reg_weights)}')
    branches = (('cls', cls_weights, 'blue'), ('reg', reg_weights, 'orange'))
    x_idces = list(range(1, len(cls_weights) + 1))

    # A dedicated figure, so the plot does not depend on the current figure.
    fig, ax = plt.subplots()
    for branch, weights, color in branches:
        ax.plot(x_idces, [np.abs(w).max() for w in weights],
                marker='<', color=color, alpha=0.65, label=f'{branch} max|w|')
    for branch, weights, color in branches:
        ax.plot(x_idces, [w.std() for w in weights],
                marker='o', color=color, alpha=0.65, label=f'{branch} std')
    ax.set_xlabel('Layer Index')
    ax.set_ylabel('Weight range')
    ax.legend()

    # Save before show(): a blocking show() closes the figure, so a later
    # savefig would write an empty canvas.
    fig.savefig(compa_save_path)
    plt.show()
    plt.close(fig)
    
def draw_weight_box_plot(weight_data_in_head, compa_save_path):
    """Draw one box per head layer's weights and save the figure.

    Boxes alternate cls (blue) / reg (orange) per layer. The ``'other'``
    layers (black) are appended at the end.

    Args:
        weight_data_in_head (dict): maps ``'cls'``, ``'reg'`` and ``'other'``
            to lists of per-layer weight arrays (NumPy). ``'cls'`` and
            ``'reg'`` must have the same length.
        compa_save_path (str): path of the image file to write.

    Raises:
        ValueError: if the cls and reg lists differ in length.
    """
    import matplotlib.pyplot as plt

    cls_weights = weight_data_in_head['cls']
    reg_weights = weight_data_in_head['reg']
    if len(cls_weights) != len(reg_weights):
        raise ValueError(
            f'cls/reg head weight counts differ: '
            f'{len(cls_weights)} vs {len(reg_weights)}')

    # Display order: cls and reg alternate per layer, 'other' layers last.
    data_sequence = []
    color_sequence = []
    for cls_w, reg_w in zip(cls_weights, reg_weights):
        data_sequence.extend([cls_w.flatten(), reg_w.flatten()])
        color_sequence.extend(['blue', 'orange'])
    for other_w in weight_data_in_head['other']:
        data_sequence.append(other_w.flatten())
        color_sequence.append('black')

    fig, ax = plt.subplots()
    box = ax.boxplot(data_sequence, patch_artist=True,
                     flierprops=dict(marker='*', color='black', markersize=2, alpha=0.2))
    for patch, color in zip(box['boxes'], color_sequence):
        patch.set_facecolor(color)
    ax.set_xlabel('Layer Index')
    ax.set_ylabel('Weight range')

    # Save before show(): a blocking show() closes the figure, so a later
    # savefig would write an empty canvas.
    fig.savefig(compa_save_path)
    plt.show()
    plt.close(fig)
    
class Save_act_hooker(object):
    """Forward hook that keeps raw input/output activations of one conv layer.

    Recordings go into the class-level ``record_data_dict``, which is shared
    by every instance. It maps ``<module_name>_in`` / ``<module_name>_out``
    to a list of NumPy arrays, one entry per saved cycle. Creating an
    instance resets both lists for its ``module_name``.

    A head conv may be called several times per image (presumably once per
    pyramid level). Calls are grouped into cycles of ``max_share_count``,
    and only the first call of each cycle is saved.

    Args:
        module_name (str): key prefix in ``record_data_dict``.
        max_share_count (int): number of calls forming one cycle.
        max_save_count (int): stop saving after this many cycles. A value
            ``<= 0`` means unlimited, matching ``--save-num``'s default of -1.
    """
    record_data_dict = {}

    def __init__(self, module_name, max_share_count=5, max_save_count=1) -> None:
        self.module_name = module_name
        self.max_save_count = max_save_count
        self.max_share_count = max_share_count
        self.save_count = 0
        self.share_count = 0
        self.record_data_dict[module_name + '_in'] = []
        self.record_data_dict[module_name + '_out'] = []

    def __call__(self, module, input, output):
        self.share_count = self.share_count + 1
        under_limit = self.max_save_count <= 0 or self.save_count < self.max_save_count
        if self.share_count == 1 and under_limit:
            # Head activations are saved only once per inference (cycle).
            self.record_data_dict[self.module_name + '_in'].append(input[0].detach().cpu().numpy())
            self.record_data_dict[self.module_name + '_out'].append(output.detach().cpu().numpy())
            self.save_count += 1

        if self.share_count == self.max_share_count:
            # Cycle complete: reset.
            self.share_count = 0

class Full_eval_hooker(object):
    """Forward hook that accumulates per-call statistics over a whole dataset.

    Every call appends ``get_statistics(tensor, draw_distribution_flag)`` for
    the output and ``input[0]`` to the class-level ``record_statistics_dict``,
    which is shared by all instances. Keys are ``<module_name>_out`` /
    ``<module_name>_in``. When ``max_share_count > 1``, ``_<k>`` is appended,
    where k is the 1-based call index within the current cycle of
    ``max_share_count`` calls.
    """
    record_statistics_dict = {}

    def __init__(self, module_name, max_share_count=1, draw_distribution_flag=False) -> None:
        self.module_name = module_name
        self.max_share_count = max_share_count
        self.count = 0
        self.draw_distribution_flag = draw_distribution_flag

    def __call__(self, module, input, output):
        self.count += 1
        suffix = '' if self.max_share_count == 1 else f'_{self.count}'
        tagged_tensors = (
            (self.module_name + '_out' + suffix, output),
            (self.module_name + '_in' + suffix, input[0]),
        )
        for key, tensor in tagged_tensors:
            stats = get_statistics(tensor, self.draw_distribution_flag)
            self.record_statistics_dict.setdefault(key, []).append(stats)

        # Cycle complete: start counting again.
        if self.count == self.max_share_count:
            self.count = 0

    @staticmethod
    def parse_data(record_statistics_dict):
        """Summarize each title's list of statistics dicts.

        For each of ``'max'``, ``'std'`` and ``'max/std_3'``, the per-call
        values are reduced with ``max``, ``mean``, ``min`` and ``std``.

        Returns:
            dict: ``title -> {'<stat>_<reduction>': numpy scalar}``.
        """
        wanted_keys = ['max', 'std', 'max/std_3']
        reductions = ['max', 'mean', 'min', 'std']

        summary_by_title = {}
        for title, per_call_stats in record_statistics_dict.items():
            summary = {}
            for key in wanted_keys:
                values = np.array([entry[key] for entry in per_call_stats])
                for op_name in reductions:
                    summary[f'{key}_{op_name}'] = getattr(values, op_name)()
            summary_by_title[title] = summary
        return summary_by_title

class Draw_distribution_hooker(object):
    """Forward hook that records raw tensors and statistics of one conv layer.

    Results are stored in class-level dicts shared by all instances:

    * ``record_data_dict``: key -> NumPy copy of the output, ``input[0]`` or
      the module weight (overwritten on every call).
    * ``record_statistics_dict``: key -> ``get_statistics`` result. The
      weight entry is computed only once.

    Keys are ``<module_name>_out`` / ``<module_name>_in``. When
    ``max_share_count > 1``, ``_<k>`` is appended for the k-th call of the
    current cycle. The weight key is ``<module_name>weight``.
    """
    record_data_dict = {}
    record_statistics_dict = {}

    def __init__(self, module_name, max_share_count=1) -> None:
        self.module_name = module_name
        self.max_share_count = max_share_count
        self.count = 0

    def __call__(self, module, input, output):
        self.count += 1
        suffix = '' if self.max_share_count == 1 else f'_{self.count}'
        out_key = self.module_name + '_out' + suffix
        in_key = self.module_name + '_in' + suffix
        weight_key = self.module_name + 'weight'

        # Raw copies, recorded in the order out, in, weight.
        for key, tensor in ((out_key, output), (in_key, input[0]), (weight_key, module.weight)):
            self.record_data_dict[key] = tensor.detach().cpu().numpy()
            print(f'REC {key}')

        stats = self.record_statistics_dict
        stats[in_key] = get_statistics(input[0])
        stats[out_key] = get_statistics(output)
        if weight_key not in stats:
            stats[weight_key] = get_statistics(module.weight)

        # Cycle complete: start counting again.
        if self.count == self.max_share_count:
            self.count = 0


def get_statistics(data, draw_distribution_flag=True):
    """Compute scalar statistics of a tensor as NumPy values.

    Args:
        data (torch.Tensor): tensor to summarize (all elements).
        draw_distribution_flag (bool): also compute the kurtosis and the
            99.99% quantile. Default: True.

    Returns:
        dict: ``'std'`` (``data.std()``), ``'max'`` (max absolute value),
        ``'std_3'`` (3 * std) and ``'max/std_3'``. With
        ``draw_distribution_flag``, it also holds ``'kurt'`` (mean of the 4th
        power of the standardized values), ``'quantile9999'`` and
        ``'max/quantile'``.
    """
    EPS = 1e-8  # keeps the kurtosis finite for constant tensors
    # Named abs_max to avoid shadowing the builtin max().
    abs_max = data.abs().max().detach().cpu().numpy()
    std_t = data.std()
    std = std_t.detach().cpu().numpy()
    statistics_dict = {
        'std': std,
        'max': abs_max,
        'std_3': std * 3,
    }
    statistics_dict['max/std_3'] = abs_max / statistics_dict['std_3']

    if draw_distribution_flag:
        standardized = (data - data.mean()) / (std_t + EPS)
        statistics_dict['kurt'] = torch.mean(standardized ** 4).detach().cpu().numpy()
        # NumPy rather than torch.quantile, which limits its input size.
        quantile = np.quantile(data.detach().cpu().numpy(), 0.9999)
        statistics_dict['quantile9999'] = quantile
        statistics_dict['max/quantile'] = abs_max / quantile

    return statistics_dict

def get_std(inputs, mean=0):
    """Return the Bessel-corrected RMS deviation of ``inputs`` around ``mean``.

    Computes ``sqrt(sum((x - mean) ** 2) / (N - 1))`` over all N elements.
    With the default ``mean=0`` this is *not* the sample standard deviation.
    Pass ``mean=inputs.mean()`` to get it.

    Args:
        inputs (torch.Tensor): values to reduce.
        mean (float | torch.Tensor): center deviations are measured from.
            Default: 0.

    Returns:
        torch.Tensor: 0-dim result.
    """
    return (torch.sum((inputs - mean) ** 2) / (inputs.numel() - 1)).sqrt()
    

def set_random_seed(seed, deterministic=True):
    """Seed the Python, NumPy and PyTorch RNGs and optionally force deterministic cuDNN.

    Args:
        seed (int | None): Seed to use. ``None`` (the ``--seed`` default)
            leaves the RNGs unseeded, because ``torch.manual_seed`` does not
            accept ``None``.
        deterministic (bool): Whether to set
            ``torch.backends.cudnn.deterministic`` to True and
            ``torch.backends.cudnn.benchmark`` to False. Applied even when
            ``seed`` is None. Default: True.
    """
    import random

    import numpy as np
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False



if __name__ == '__main__':
    # Run the analysis only when executed as a script, not on import.
    main()
