# Copyright (c) OpenMMLab. All rights reserved.
import os
import random

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, load_checkpoint,
                         Fp16OptimizerHook, OptimizerHook, build_runner,
                         get_dist_info)
from mmdet.core import DistEvalHook, EvalHook, build_optimizer
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.utils import (build_ddp, build_dp, compat_cfg,
                         find_latest_checkpoint, get_root_logger)
from mmdet.apis.test import single_gpu_test, multi_gpu_test
from torch.fx import Tracer
import torch.fx as fx
from mqb_general_process import *
from mqbench.utils.state import *
from mqbench.advanced_ptq import ptq_reconstruction
import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
import global_placeholder
from mmdet.core import encode_mask_results
from copy import deepcopy

def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed that every process should use.

    A caller-supplied seed is returned untouched. Otherwise a seed is drawn
    locally and, when more than one process is running, the value drawn on
    rank 0 is broadcast so that all ranks agree on it.

    Args:
        seed (int, Optional): Explicit seed. Default to None.
        device (str): Device on which the tensor used for the broadcast
            is allocated. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        # All ranks must share one seed, see
        # https://github.com/open-mmlab/mmdetection/issues/6339
        rank, world_size = get_dist_info()
        seed = np.random.randint(2**31)
        if world_size != 1:
            # Only rank 0 contributes its draw; the other ranks send a
            # placeholder that the broadcast overwrites.
            payload = seed if rank == 0 else 0
            shared = torch.tensor(payload, dtype=torch.int32, device=device)
            dist.broadcast(shared, src=0)
            seed = shared.item()
    return seed


def set_random_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): When True, additionally set
            `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def auto_scale_lr(cfg, distributed, logger):
    """Rescale ``cfg.optimizer.lr`` to the effective total batch size.

    The learning rate is multiplied by ``batch_size / base_batch_size``
    (linear scaling rule) and written back into ``cfg`` in place.

    Args:
        cfg (config): Training config.
        distributed (bool): Using distributed or not.
        logger (logging.Logger): Logger.
    """
    # Scaling is opt-in through cfg.auto_scale_lr.enable.
    scaling_enabled = ('auto_scale_lr' in cfg) and \
        cfg.auto_scale_lr.get('enable', False)
    if not scaling_enabled:
        logger.info('Automatic scaling of learning rate (LR)'
                    ' has been disabled.')
        return

    # Nothing to compare against without a reference batch size.
    base_batch_size = cfg.auto_scale_lr.get('base_batch_size', None)
    if base_batch_size is None:
        return

    # Number of GPUs taking part in training.
    if not distributed:
        num_gpus = len(cfg.gpu_ids)
    else:
        _, world_size = get_dist_info()
        num_gpus = len(range(world_size))

    # Effective total batch size.
    samples_per_gpu = cfg.data.train_dataloader.samples_per_gpu
    batch_size = num_gpus * samples_per_gpu
    logger.info(f'Training with {num_gpus} GPU(s) with {samples_per_gpu} '
                f'samples per GPU. The total batch size is {batch_size}.')

    current_lr = cfg.optimizer.lr
    if batch_size == base_batch_size:
        logger.info('The batch size match the '
                    f'base batch size: {base_batch_size}, '
                    f'will not scaling the LR ({current_lr}).')
        return

    # [linear scaling rule](https://arxiv.org/abs/1706.02677)
    scaled_lr = (batch_size / base_batch_size) * current_lr
    logger.info('LR has been automatically scaled '
                f'from {current_lr} to {scaled_lr}')
    cfg.optimizer.lr = scaled_lr

def calibrate(model, data_loader, cali_num, retain_cali_data_flag=False):
    """Run forward passes over ``data_loader`` so ``model`` can calibrate.

    The model is put in eval mode and called as
    ``model(return_loss=False, rescale=True, **data)`` for up to
    ``cali_num`` batches under ``torch.no_grad()``. Each batch is unpacked
    first: ``img`` and ``img_metas`` are replaced by their ``.data``
    payload, the first image entry is moved to ``cuda:0`` and the
    ground-truth boxes/labels are dropped.

    Args:
        model: Callable detector supporting the call signature above.
        data_loader: An iterable of batches, or a list of such iterables,
            in which case only the first one is used.
        cali_num (int): Number of batches to run.
        retain_cali_data_flag (bool): When True, every unpacked batch is
            kept and returned to the caller. Default: False.

    Returns:
        list: The retained batches (empty when ``retain_cali_data_flag`` is
        False). The list is returned even when the loader runs out before
        ``cali_num`` batches were seen.
    """
    from global_placeholder import logger
    model.eval()
    cali_data = []
    if isinstance(data_loader, list):
        data_loader = data_loader[0]

    rank, _ = get_dist_info()
    # Only rank 0 draws a progress bar.
    prog_bar = mmcv.ProgressBar(cali_num) if rank == 0 else None

    steped_num = 0
    for data in data_loader:
        with torch.no_grad():
            data['img'] = data['img'].data
            # NOTE(review): the device is hard-coded; assumes the model
            # (or its primary replica) lives on cuda:0 -- confirm.
            data['img'][0] = data['img'][0].to('cuda:0')
            data['img_metas'] = data['img_metas'].data
            # Ground truth is not accepted by the inference call; tolerate
            # batches that do not carry it.
            data.pop('gt_bboxes', None)
            data.pop('gt_labels', None)
            if retain_cali_data_flag:
                cali_data.append(data)

            model(return_loss=False, rescale=True, **data)

        steped_num += 1
        if prog_bar is not None:
            prog_bar.update()
        if steped_num >= cali_num:
            break
    else:
        # The loader ran dry before the requested number of batches.
        logger.warning(
            f'Calibration loader exhausted after {steped_num} batches '
            f'(requested {cali_num}).')

    logger.info(f'Truly calibrate num {steped_num}')
    if dist.is_initialized():
        dist.barrier()
    return cali_data

        
# Flip to True to dump per-conv activations of the quantized and FP32
# models and write their MSE instead of running the evaluation.
_DUMP_QUANT_MSE_DEBUG = False


class _SaveActHook(object):
    """Forward hook that saves a module's input and output as ``.npy``.

    Files are written to ``<save_dir>/<title>/batch_<k>_<mode>.npy`` and
    their paths are recorded in the class-level ``record_data_dict`` so
    that runs in different modes can be paired up afterwards.

    Args:
        module_name (str): Name of the hooked module; ``'module.'`` is
            prepended when building titles.
        max_share_count (int): How many times the module is expected to be
            called per forward pass. Values above 1 give every call its own
            title suffix and batch counter.
        mode (str): Tag stored in file names, ``'fp'`` or ``'q'``.
        save_dir (str): Root directory for the dumped arrays.
    """

    # Shared by all hook instances: title -> {mode: [file paths]}.
    record_data_dict = {}

    def __init__(self, module_name, max_share_count=1, mode='fp',
                 save_dir='activation_data'):
        self.module_name = 'module.' + module_name
        self.max_share_count = max_share_count
        self.count = 0
        self.mode = mode
        self.save_dir = save_dir
        if max_share_count == 1:
            self.batch_count = 0
        else:
            # Index 0 is unused; calls are numbered from 1.
            self.batch_count = [0] * (max_share_count + 1)

    def __call__(self, module, input, output):
        self.count += 1
        if self.max_share_count == 1:
            self.batch_count += 1
            batch_idx = self.batch_count
            suffix = ''
        else:
            self.batch_count[self.count] += 1
            batch_idx = self.batch_count[self.count]
            suffix = f'_{self.count}'

        dumps = (
            (self.module_name + '_out' + suffix, output),
            (self.module_name + '_in' + suffix, input[0]),
        )
        for title, data in dumps:
            target_dir = os.path.join(self.save_dir, title)
            os.makedirs(target_dir, exist_ok=True)
            file_path = os.path.join(
                target_dir, f"batch_{batch_idx}_{self.mode}.npy")
            # Never overwrite an existing dump.
            if not os.path.exists(file_path):
                np.save(file_path, data.cpu().numpy().astype(np.float32))
            self.record_data_dict.setdefault(title, {}).setdefault(
                self.mode, []).append(file_path)

        # Start numbering again once the expected calls were all seen.
        if self.count == self.max_share_count:
            self.count = 0


def _dump_quant_mse(model, fp32_model, dataset, cali_loader_cfg, cfg):
    """Debug helper: write per-conv MSE between quantized and FP32 runs.

    Hooks every ``torch.nn.Conv2d`` of both models, pushes one calibration
    batch through each, and dumps the mean squared error of the paired
    activations to ``<cfg.work_dir>/q4b_fuse_True.json``.
    """
    print('Warn!!!!!!!!!!!!!!! Switch on the MSE Q loss!')
    cali_data_loader = build_dataloader(
        dataset[0], **cali_loader_cfg, pin_memory=True)

    task_name = cfg.work_dir.split('/')[-1]
    save_dir = f'/workspace/whole_world/rdata/long.huang/exp_out/temp/{task_name}'
    os.makedirs(save_dir, exist_ok=True)

    for mode, net in (('q', model), ('fp', fp32_model)):
        for name, module in net.named_modules():
            if not isinstance(module, torch.nn.Conv2d):
                continue
            # Modules with 'head' in their name get one dump per call
            # (up to 5 calls per forward pass).
            share = 5 if 'head' in name else 1
            module.register_forward_hook(
                _SaveActHook(name, max_share_count=share, mode=mode,
                             save_dir=save_dir))

    with torch.no_grad():
        calibrate(model, cali_data_loader, 1, False)
        calibrate(fp32_model, cali_data_loader, 1, False)

    record_statistics_dict = {}
    print('Quant MSE---->')
    for name, files in _SaveActHook.record_data_dict.items():
        mse_loss_list = []
        for fp_file, q_file in zip(files['fp'], files['q']):
            fp_t = np.load(fp_file)
            q_t = np.load(q_file)
            mse_loss_list.append(((fp_t - q_t) ** 2).mean())
        record_statistics_dict[name] = {
            'mse_loss': sum(mse_loss_list) / len(mse_loss_list)}
        print(f'{name}: 处理了{len(mse_loss_list)}个样本')

    mmcv.dump(
        record_statistics_dict,
        f'{cfg.work_dir}/q4b_fuse_True.json',
    )


def ptq_detector(model,
                   dataset,
                   cfg,
                   quant_config,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Post-training quantize a detector, evaluate it and save the result.

    Steps: optionally load weights from ``cfg.load_from``, convert the
    model (and a deep copy of it) with ``make_qmodel_for_mmd``, calibrate
    on the first training dataset, enable quantization, optionally run
    ``ptq_reconstruction`` (``advanced_ptq``), evaluate on ``cfg.data.val``
    with the ``bbox`` metric, and on rank 0 write ``results.pkl`` and
    ``ptq.pth`` into ``cfg.work_dir``.

    Args:
        model: Detector to quantize.
        dataset: Training dataset or list/tuple of datasets; only the first
            one feeds calibration.
        cfg: Config; mutated in place (``resume_from``, ``tune_from``,
            possibly ``optimizer.lr`` and ``data.val.pipeline``).
        quant_config: Quantization config; ``quantize.quantize_type`` must
            be ``'naive_ptq'`` or ``'advanced_ptq'``.
        distributed (bool): Forwarded to the data loaders and hooks.
        validate (bool): Whether to also register an eval hook on the
            runner. Evaluation at the end of this function always runs.
        timestamp: Stored on the runner so log file names match.
        meta: Forwarded to the runner.

    Raises:
        NotImplementedError: If ``cfg.tune_from`` is set.
        RuntimeError: If ``advanced_ptq`` is requested but no calibration
            data was collected.
        ValueError: If the quantize type is not supported.
    """
    cfg = compat_cfg(cfg)
    logger = global_placeholder.logger

    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    train_dataloader_default_args = dict(
        samples_per_gpu=1,
        workers_per_gpu=1,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)

    train_loader_cfg = {
        **train_dataloader_default_args,
        **cfg.data.get('train_dataloader', {})
    }

    # The calibration loader reuses the training loader settings but is
    # forced to one sample and one worker per GPU. PTQ never iterates the
    # training loaders themselves, so they are not built.
    cali_loader_cfg = deepcopy(train_loader_cfg)
    cali_loader_cfg['samples_per_gpu'] = 1
    cali_loader_cfg['workers_per_gpu'] = 1
    cali_data_loader = build_dataloader(
        dataset[0], **cali_loader_cfg, pin_memory=True)

    # Default to an empty checkpoint so the CLASSES lookup below also works
    # on the auto-resume path, where no checkpoint is loaded here.
    checkpoint = {}
    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from
    elif cfg.load_from:
        checkpoint = load_ckpt_with_revise_keys(
            model, cfg.model.type, cfg.load_from, map_location='cpu')

    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset[0].CLASSES

    # The FP32 reference is copied *before* conversion and then converted
    # separately: the converted model carries dynamically bound methods and
    # cannot be deep-copied.
    fp32_model = deepcopy(model)
    # Both models must be in train mode before conversion.
    model.train()
    fp32_model.train()
    model = make_qmodel_for_mmd(model, quant_config, cfg.trace_config)
    fp32_model = make_qmodel_for_mmd(fp32_model, quant_config,
                                     cfg.trace_config)
    if not hasattr(cfg, 'tune_from'):
        cfg.tune_from = False
    if cfg.tune_from:
        raise NotImplementedError(
            'tune_from is not supported by ptq_detector')

    model = model.to('cuda:0')
    fp32_model = fp32_model.to('cuda:0')

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # The validation loader is always needed because the final evaluation
    # below runs unconditionally; only the eval hook depends on `validate`.
    val_dataloader_default_args = dict(
        samples_per_gpu=1,
        workers_per_gpu=1,
        dist=distributed,
        shuffle=False,
        persistent_workers=False)

    val_dataloader_args = {
        **val_dataloader_default_args,
        **cfg.data.get('val_dataloader', {})
    }
    # Support batch_size > 1 in validation
    if val_dataloader_args['samples_per_gpu'] > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg.data.val.pipeline = replace_ImageToTensor(
            cfg.data.val.pipeline)
    val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
    val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)

    if validate:
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    quantize_type = quant_config.quantize.quantize_type
    # Reconstruction needs the calibration batches, so keep them around.
    retain_cali_data_flag = quantize_type == 'advanced_ptq'

    cali_data = None
    if not cfg.resume_from and not cfg.tune_from:
        print('\nCalibrating\n')
        enable_calibration(model)
        cali_data = calibrate(model, cali_data_loader,
                              quant_config.quantize.cali_batchnum,
                              retain_cali_data_flag)
    # Drop the calibration loader to release memory.
    del cali_data_loader

    enable_quantization(model)

    if quantize_type == 'advanced_ptq':
        if not cali_data:
            raise RuntimeError(
                'advanced_ptq requires retained calibration data, but none '
                'was collected (calibration was skipped or produced no '
                'batches).')
        logger.info('Advance PTQ Mode On!')
        # NOTE(review): `runner` still references the model object passed
        # in above; this assumes ptq_reconstruction updates and returns
        # that same object -- confirm before relying on 'ptq.pth'.
        model = ptq_reconstruction(
            fp32_model,
            model,
            cali_data=cali_data,
            config=quant_config['quantize'].reconstruction,
            graph_module_list=['backbone', 'neck', 'bbox_head'])
    elif quantize_type != 'naive_ptq':
        raise ValueError('未实现该ptq模式！！')

    if _DUMP_QUANT_MSE_DEBUG:
        _dump_quant_mse(model, fp32_model, dataset, cali_loader_cfg, cfg)
        return

    # The data-parallel wrapper is required here: per the original author's
    # note, evaluation errors out without it.
    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
    outputs = single_gpu_test(model, val_dataloader)

    rank, _ = get_dist_info()
    if rank == 0:
        out_path = os.path.join(cfg.work_dir, 'results.pkl')
        print(f'\nwriting results to {out_path}')
        mmcv.dump(outputs, out_path)
        metric = val_dataset.evaluate(outputs, metric=['bbox'])
        logger.info(f'{metric}')
        print(metric)
        runner.save_checkpoint(cfg.work_dir, 'ptq.pth')

def qat_detector(model,
                   dataset,
                   cfg,
                   quant_config,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Quantization-aware training entry point for a detector.

    Steps: optionally load weights from ``cfg.load_from``, convert the
    model with ``make_qmodel_for_mmd``, optionally load ``cfg.tune_from``,
    wrap it for (distributed) data parallelism, build the runner and hooks,
    calibrate on the first training dataset unless resuming or tuning,
    enable quantization, and finally run the training workflow.

    Args:
        model: Detector to train.
        dataset: Training dataset or list/tuple of datasets.
        cfg: Config; mutated in place (``resume_from``, ``tune_from``,
            possibly ``optimizer.lr`` and ``data.val.pipeline``).
        quant_config: Quantization config; ``quantize.cali_batchnum`` sets
            the number of calibration batches.
        distributed (bool): Use distributed data parallel training.
        validate (bool): Whether to register an eval hook on ``cfg.data.val``.
        timestamp: Stored on the runner so log file names match.
        meta: Forwarded to the runner.
    """
    cfg = compat_cfg(cfg)
    logger = global_placeholder.logger

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    train_dataloader_default_args = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)

    train_loader_cfg = {
        **train_dataloader_default_args,
        **cfg.data.get('train_dataloader', {})
    }

    data_loaders = [
        build_dataloader(ds, **train_loader_cfg, pin_memory=True)
        for ds in dataset
    ]

    # The calibration loader reuses the training loader settings but is
    # forced to one sample and one worker per GPU.
    cali_loader_cfg = deepcopy(train_loader_cfg)
    cali_loader_cfg['samples_per_gpu'] = 1
    cali_loader_cfg['workers_per_gpu'] = 1
    cali_data_loader = build_dataloader(
        dataset[0], **cali_loader_cfg, pin_memory=True)

    # Default to an empty checkpoint so the CLASSES lookup below also works
    # on the auto-resume path, where no checkpoint is loaded here.
    checkpoint = {}
    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from
    elif cfg.load_from:
        checkpoint = load_ckpt_with_revise_keys(
            model, cfg.model.type, cfg.load_from, map_location='cpu')

    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset[0].CLASSES

    # The model must be in train mode before conversion.
    model.train()
    model = make_qmodel_for_mmd(model, quant_config, cfg.trace_config)

    if not hasattr(cfg, 'tune_from'):
        cfg.tune_from = False
    if cfg.tune_from:
        # Continue tuning from an already converted checkpoint.
        _ = load_ckpt_with_revise_keys(
            model, cfg.model.type, cfg.tune_from, map_location='cpu')

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        # NOTE(review): buffers are broadcast here (train_detector passes
        # False); presumably to keep quantizer buffers in sync -- confirm.
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=True,
            find_unused_parameters=find_unused_parameters)
    else:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)

        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation

        if val_dataloader_args['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # Calibration is skipped when resuming or tuning from a checkpoint.
    if not cfg.resume_from and not cfg.tune_from:
        print('\nCalibrating\n')
        enable_calibration(model)
        calibrate(model, cali_data_loader,
                  quant_config.quantize.cali_batchnum)
    # Drop the calibration loader to release memory.
    del cali_data_loader

    enable_quantization(model)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    runner.run(data_loaders, cfg.workflow)
    
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Train a detector with the runner described by ``cfg``.

    Builds the training data loaders, wraps the model for (distributed)
    data parallelism, creates the optimizer and runner, registers the
    training and optional evaluation hooks, restores a checkpoint when one
    is configured, and runs ``cfg.workflow``.

    Args:
        model: Detector to train.
        dataset: Training dataset or list/tuple of datasets.
        cfg: Training config; mutated in place (``resume_from``, possibly
            ``optimizer.lr`` and ``data.val.pipeline``).
        distributed (bool): Use distributed data parallel training.
        validate (bool): Whether to register an eval hook on ``cfg.data.val``.
        timestamp: Stored on the runner so log file names match.
        meta: Forwarded to the runner.
    """
    cfg = compat_cfg(cfg)
    logger = global_placeholder.logger

    # Normalise to a sequence of datasets.
    datasets = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    if 'runner' in cfg:
        runner_type = cfg.runner['type']
    else:
        runner_type = 'EpochBasedRunner'

    # Defaults first, then whatever the config overrides.
    loader_kwargs = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)
    loader_kwargs.update(cfg.data.get('train_dataloader', {}))

    data_loaders = []
    for ds in datasets:
        data_loaders.append(build_dataloader(ds, **loader_kwargs))

    # put model on gpus
    if not distributed:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
    else:
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner_defaults = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting; the 'npu' device gets dynamic loss scaling by default
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None and cfg.get('device', None) == 'npu':
        fp16_cfg = dict(loss_scale='dynamic')

    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_loader_kwargs = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)
        val_loader_kwargs.update(cfg.data.get('val_dataloader', {}))

        # Support batch_size > 1 in validation
        if val_loader_kwargs['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(val_dataset, **val_loader_kwargs)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        hook_cls = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            hook_cls(val_dataloader, **eval_cfg), priority='LOW')

    # With auto_resume and no explicit resume target, pick up the latest
    # checkpoint found in the work dir.
    if cfg.resume_from is None and cfg.get('auto_resume'):
        latest_ckpt = find_latest_checkpoint(cfg.work_dir)
        if latest_ckpt is not None:
            cfg.resume_from = latest_ckpt

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
