# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from functools import partial
from pathlib import Path

import numpy as np
import torch
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.utils import digit_version

from mmdet.registry import MODELS

# The complexity helpers are required: without them the script cannot run, so
# a failed import is re-raised with an upgrade hint. FlopAnalyzer is optional
# and only enables the per-decoder FLOPs breakdown.
try:
    from mmengine.analysis import get_model_complexity_info
    from mmengine.analysis.print_helper import _format_size
    # Check whether FlopAnalyzer is available.
    try:
        from mmengine.analysis import FlopAnalyzer
        HAS_FLOP_ANALYZER = True
    except ImportError:
        HAS_FLOP_ANALYZER = False
        print("Warning: FlopAnalyzer not available in this version of mmengine")
except ImportError:
    raise ImportError('Please upgrade mmengine >= 0.6.0')


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Holds ``config`` (positional path),
        ``num_images`` (int, default 100) and ``cfg_options`` (config
        overrides collected by ``DictAction``, ``None`` when not given).
    """
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    cli = argparse.ArgumentParser(description='Get a detector flops')
    cli.add_argument('config', help='train config file path')
    cli.add_argument(
        '--num-images',
        type=int,
        default=100,
        help='num images of calculate model flops')
    cli.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help=cfg_options_help)
    return cli.parse_args()


def sum_by_keywords(flops_dict, *keywords):
    """Sum the FLOPs of the top-most modules whose name has every keyword.

    ``flops_dict`` maps dotted module paths (e.g. ``'decoder.layers.0'``) to
    FLOPs. In a hierarchical per-module dict such as the one documented for
    ``FlopAnalyzer.by_module()``, a parent's value already includes all of
    its children. Because a child's path contains its parent's path, every
    descendant of a matching module matches too; adding them all would count
    the same operations once per nesting level. Only matching modules that
    have no matching ancestor in ``flops_dict`` are therefore summed.

    Args:
        flops_dict (dict): ``{module path: FLOPs}``. The empty string, when
            present, is treated as the root and as an ancestor of every
            other entry.
        *keywords (str): Substrings that must all occur in a module path.
            With no keywords every entry matches.

    Returns:
        The summed FLOPs, ``0`` when nothing matches.
    """
    matched = {
        name for name in flops_dict
        if all(keyword in name for keyword in keywords)
    }

    def has_matched_ancestor(name):
        # The root entry ('') covers every other module.
        if name and '' in matched:
            return True
        parts = name.split('.')
        return any('.'.join(parts[:depth]) in matched
                   for depth in range(1, len(parts)))

    # Iterate the dict (not the set) so the summation order is deterministic.
    return sum(flops_val for module_name, flops_val in flops_dict.items()
               if module_name in matched
               and not has_matched_ancestor(module_name))


def inference(args, logger):
    """Measure the average FLOPs and parameter count of a detector.

    The model and the validation dataloader are built from ``args.config``
    and up to ``args.num_images`` batches are analysed one by one. When
    ``FlopAnalyzer`` is importable, the FLOPs of modules whose path mentions
    the decoder are averaged as well.

    Args:
        args: Parsed command-line namespace; ``config``, ``num_images`` and
            ``cfg_options`` are read.
        logger: Logger used for the torch-version warning and the
            missing-config error.

    Returns:
        dict: ``ori_shape`` and ``pad_shape`` of the last analysed image,
        ``compute_type``, and the formatted strings ``flops``, ``params``
        and ``decoder_flops`` (``"N/A"`` when ``FlopAnalyzer`` is
        unavailable).

    Raises:
        ValueError: If no image was analysed (``num_images`` is 0 or the
            dataloader yields nothing), since no average can be computed.
    """
    if digit_version(torch.__version__) < digit_version('1.12'):
        logger.warning(
            'Some config files, such as configs/yolact and configs/detectors,'
            'may have compatibility issues with torch.jit when torch<1.12. '
            'If you want to calculate flops for these models, '
            'please make sure your pytorch version is >=1.12.')

    config_name = Path(args.config)
    if not config_name.exists():
        # Only logged here; loading the config below is still attempted.
        logger.error(f'{config_name} not found.')

    cfg = Config.fromfile(args.config)
    # Set before the CLI overrides are merged, so --cfg-options can still
    # change it.
    cfg.val_dataloader.batch_size = 1
    # Only the path string is kept; the TemporaryDirectory object itself is
    # not retained.
    cfg.work_dir = tempfile.TemporaryDirectory().name

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmdet'))

    # TODO: The following usage is temporary and not safe
    # use hard code to convert mmSyncBN to SyncBN. This is a known
    # bug in mmengine, mmSyncBN requires a distributed environment,
    # this question involves models like configs/strong_baselines
    if hasattr(cfg, 'head_norm_cfg'):
        cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)

    result = {}
    avg_flops = []
    avg_decoder_flops = []  # per-image decoder FLOPs

    data_loader = Runner.build_dataloader(cfg.val_dataloader)
    model = MODELS.build(cfg.model)
    if torch.cuda.is_available():
        model = model.cuda()
    model = revert_sync_batchnorm(model)
    model.eval()
    _forward = model.forward

    for idx, data_batch in enumerate(data_loader):
        if idx == args.num_images:
            break

        data = model.data_preprocessor(data_batch)
        result['ori_shape'] = data['data_samples'][0].ori_shape
        result['pad_shape'] = data['data_samples'][0].pad_shape
        if hasattr(data['data_samples'][0], 'batch_input_shape'):
            result['pad_shape'] = data['data_samples'][0].batch_input_shape

        # Bind the data samples so the analysers can call forward() with the
        # image tensor alone.
        model.forward = partial(_forward, data_samples=data['data_samples'])

        # Total FLOPs and parameter count for this image.
        outputs = get_model_complexity_info(
            model,
            None,
            inputs=data['inputs'],
            show_table=False,
            show_arch=False)
        avg_flops.append(outputs['flops'])
        params = outputs['params']
        result['compute_type'] = 'dataloader: load a picture from the dataset'

        # Per-module FLOPs, aggregated over the decoder modules.
        if HAS_FLOP_ANALYZER:
            fa = FlopAnalyzer(model, data['inputs'])
            by_module = fa.by_module()  # dict: {module path: FLOPs}

            # Strict match first: the path must mention both 'transformer'
            # and 'decoder' (e.g. a 'transformer.decoder' sub-module).
            dec_flops = sum_by_keywords(by_module, 'transformer', 'decoder')

            # Fallback for implementations that only use 'decoder'.
            if dec_flops == 0:
                dec_flops = sum_by_keywords(by_module, 'decoder')

            avg_decoder_flops.append(dec_flops)
        else:
            # Placeholder; the result reports "N/A" in this case.
            avg_decoder_flops.append(0)

    del data_loader

    # Without this guard an empty run fails later with an opaque
    # "cannot convert float NaN to integer" and `params` is never bound.
    if not avg_flops:
        raise ValueError(
            'No image was processed, so FLOPs cannot be computed '
            f'(num_images={args.num_images}). Check --num-images and that '
            'the validation dataloader is not empty.')

    mean_flops = _format_size(int(np.average(avg_flops)))
    params = _format_size(params)

    result['flops'] = mean_flops
    result['params'] = params

    # Report decoder FLOPs only when they were actually measured.
    if HAS_FLOP_ANALYZER and avg_decoder_flops:
        mean_decoder_flops = _format_size(int(np.average(avg_decoder_flops)))
        result['decoder_flops'] = mean_decoder_flops
    else:
        result['decoder_flops'] = "N/A"

    return result


def main():
    """Run the FLOPs analysis for the CLI arguments and print the report."""
    cli_args = parse_args()
    result = inference(cli_args, MMLogger.get_instance(name='MMLogger'))
    separator = 30 * '='

    # Read every field up front so a missing key fails before any output.
    ori_shape, pad_shape, flops, params = (
        result[key] for key in ('ori_shape', 'pad_shape', 'flops', 'params'))
    decoder_flops = result.get('decoder_flops', 'N/A')
    compute_type = result['compute_type']

    if pad_shape != ori_shape:
        print(f'{separator}\nUse size divisor set input shape '
              f'from {ori_shape} to {pad_shape}')

    report = [
        separator,
        f'Compute type: {compute_type}',
        f'Input shape: {pad_shape}',
        f'Flops: {flops}',
        f'Params: {params}',
    ]
    # The decoder line is shown only when a value was measured.
    if decoder_flops != 'N/A':
        report.append(f'Decoder Flops: {decoder_flops}')
    report.append(separator)
    print('\n'.join(report))

    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify '
          'that the flops computation is correct.')


# Run the FLOPs report only when this file is executed as a script.
if __name__ == '__main__':
    main()