#!/usr/bin/env python3
"""
benchmark_infer.py
简单的推理速度基准脚本：
- 使用 HuggingFace 的 AutoModelForCausalLM / AutoTokenizer 加载模型
- 在指定 GPU 上进行若干次 warmup 后，执行多次生成并记录耗时
- 输出平均、最小、最大、P50、P90 等统计信息，并可选择保存为 JSON

设计目标：能快速验证 run+.sh 下模型在给定 GPU/精度下的生成速度。
"""
import os
import time
import json
import argparse
import statistics
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from types import SimpleNamespace
from utils.common import *

# Optional dependency: main.main loads the model after the processing
# configured for run+.sh. It is only needed for --use_mxfp4, so a failed
# import is tolerated here and only reported when that flag is used.
try:
    from main import main as prepare_model
except Exception:
    prepare_model = None

# main.parse_args is deliberately NOT imported: it would parse this script's
# own command line. Kept as a placeholder in case it is needed later.
parse_main_args = None


def parse_args():
    """Parse the benchmark command line from ``sys.argv``.

    --prefill_wikitext selects the forward-pass (prefill) benchmark; without it
    the generate benchmark runs, with lengths taken from --gen_lengths.
    There is no --max_new_tokens option.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Inference Benchmark: target device and model
    add('--gpu', type=str, default='0', help='GPU id, 会设置 CUDA_VISIBLE_DEVICES')
    add('--model', type=str, required=True, help='模型路径或 HuggingFace 名称')
    add('--prompt', type=str, default='Hello, introduce yourself.', help='测试用 prompt')

    # prefill (forward-only) benchmark on wikitext chunks
    add('--prefill_wikitext', action='store_true', help='使用 wikitext 数据做 prefill（前向）速度测试')
    add('--seqlen', type=int, default=2048, help='prefill 序列长度（token）')
    add('--nsamples', type=int, default=5, help='用于 prefill 测试的样本数量（切片数量）')

    # timing loop
    add('--repeat', type=int, default=10, help='计时的重复次数（不包括 warmup）')
    add('--warmup', type=int, default=2, help='预热次数')
    add('--batch_size', type=int, default=1, help='batch size for generation')

    # NOTE: --use_generate is accepted but not read by main() in this script;
    # generate mode is selected by the absence of --prefill_wikitext.
    add('--use_generate', action='store_true', help='使用 model.generate（默认是使用 generate）')
    add('--use_mxfp4', action='store_true', help='使用 main.main 来加载模型')
    add('--gen_lengths', type=str, default='', help='逗号分隔的多个生成长度，例如 "8,16,32"')

    # model loading
    add('--torch_dtype', type=str, default='auto',
        choices=['auto', 'bf16', 'fp16', 'float32'], help='加载模型时的 dtype')
    add('--trust_remote_code', action='store_true', help='从 remote repo 使用 trust_remote_code=True')

    # output
    add('--out', type=str, default=None, help='可选：将结果保存为 JSON 文件')
    return parser.parse_args()


def dtype_from_str(s: str):
    """Translate a --torch_dtype choice into a ``torch.dtype``.

    Returns None for 'auto' or any other unrecognised string, leaving the
    decision to the caller.
    """
    known = (
        ('bf16', torch.bfloat16),
        ('fp16', torch.float16),
        ('float32', torch.float32),
    )
    for name, dtype in known:
        if s == name:
            return dtype
    return None


def make_inputs(tokenizer, prompt: str, batch_size: int, device: torch.device):
    """Tokenise ``prompt`` and return its ``input_ids`` tensor on ``device``.

    When ``batch_size`` > 1 the encoded row is tiled ``batch_size`` times along
    dim 0, so every batch row carries the same prompt. Otherwise the tokenizer
    output is used unchanged.
    """
    ids = tokenizer(prompt, return_tensors='pt').input_ids
    ids = ids.repeat(batch_size, 1) if batch_size > 1 else ids
    return ids.to(device)


def run_once_generate(model, input_ids, max_new_tokens: int):
    """Run one ``model.generate`` call and return its output.

    The benchmark turns the elapsed time into throughput as
    ``max_new_tokens * batch / time``, so the call must produce exactly
    ``max_new_tokens`` new tokens and must not stop early at EOS.

    ``eos_token_id=None`` is intended to disable EOS stopping. Whether a ``None``
    keyword actually overrides the model's generation config may depend on the
    transformers version. Setting ``min_new_tokens`` to the same value pins the
    length explicitly.
    """
    return model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        min_new_tokens=max_new_tokens,
        eos_token_id=None,
    )


def stats(times: List[float]):
    """Summarise a list of samples (seconds for timings, bytes for memory).

    Returned keys:
      count   -- number of samples
      mean    -- mean after dropping the single largest sample, which damps one
                 outlier run; with <= 1 sample it is the plain mean
      min/max -- smallest / largest sample
      p50/p90 -- nearest-rank percentiles: the sample at rank
                 ceil(p/100 * count) in ascending order
      all     -- the samples sorted ascending
    An empty input gives count 0 and 0.0 for every other numeric field.
    """
    t = sorted(times)
    n = len(t)
    if n == 0:
        return {'count': 0, 'mean': 0.0, 'min': 0.0, 'max': 0.0, 'p50': 0.0, 'p90': 0.0, 'all': t}

    def nearest_rank(pct: int):
        # Integer ceil(pct * n / 100) avoids float rounding in the rank. The
        # previous t[int(n * p)] indexing was one rank too high: with 10
        # samples, p90 was always the maximum.
        rank = -(-pct * n // 100)
        return t[max(rank, 1) - 1]

    mean = statistics.mean(t) if n == 1 else statistics.mean(t[:-1])
    return {
        'count': n,
        'mean': mean,
        'min': t[0],
        'max': t[-1],
        'p50': nearest_rank(50),
        'p90': nearest_rank(90),
        'all': t,
    }


def _convert_time_stats_to_ms(s: dict):
    """Convert time fields in a stats dict from seconds to milliseconds in-place.
    Only adjusts numeric time fields: mean, min, max, p50, p90 and the 'all' list.
    """
    if not s:
        return s
    # convert scalar entries if present
    for k in ('mean', 'min', 'max', 'p50', 'p90'):
        if k in s and isinstance(s[k], (int, float)):
            s[k] = s[k] * 1000.0
    # convert list of samples
    if 'all' in s and isinstance(s['all'], list):
        s['all'] = [float(x) * 1000.0 for x in s['all']]
    return s


def _build_mxfp4_args(args):
    """Build the minimal namespace that ``main.main`` is called with.

    The namespace is built by hand so that main.parse_args() is not run, since
    it would parse this script's own command line. Only ``model`` and ``seed``
    come from this script's arguments. Everything else is a fixed test
    configuration; extend it here if more options are needed.
    """
    return SimpleNamespace(
        model=args.model,
        seed=getattr(args, 'seed', 0),
        pre_smooth=None,
        pre_smooth_alpha=0.5,
        rotate=True,
        rotate_mode='group_hadamard',
        rotate_kv=False,
        group_rotate_kv=False,
        kv_quant_only=False,
        kv_tokenwise=False,
        online_partial_had=True,
        r_path=None,
        spinquant=False,
        # quantization defaults (overridden with common test values)
        w_elem_format_linear='fp4_e2m1',
        a_elem_format_linear='fp4_e2m1',
        double_quant_linear=False,
        w_dq_only=False,
        scale_bits_linear=8,
        block_size_linear=32,
        w_scale_mode=0,
        a_scale_mode=0,
        A_elem_format_matmul='none',
        B_elem_format_matmul='none',
        double_quant_matmul=False,
        scale_bits_matmul=8,
        block_size_matmul=32,
        w_elem_format_ln='none',
        a_elem_format_ln='none',
        scale_bits_ln=8,
        block_size_ln=32,
        w_elem_format_head='none',
        a_elem_format_head='none',
        scale_bits_head=8,
        block_size_head=32,
        auto_dtype=False,
        custom_cuda=True,
        A_scale_mode=0,
        B_scale_mode=0,
        per_tensor=False,
        # GPTQ is disabled; its calibration fields are still provided
        gptq=False,
        gptq_percdamp=0.01,
        gptq_cal_dataset='wikitext2',
        gptq_cal_nsamples=1024,
        gptq_cal_seqlen=256,
        kurtail=False,
    )


def _load_model(args):
    """Return ``(model, tokenizer)``.

    The pair comes from ``main.main`` when --use_mxfp4 is set, otherwise from
    HF ``from_pretrained``.
    """
    if args.use_mxfp4:
        if prepare_model is None:
            raise RuntimeError('Cannot import main.main from main.py in this workspace')
        # --torch_dtype is ignored here; main.main decides from the namespace it receives.
        model, tokenizer = prepare_model(_build_mxfp4_args(args))
        return model, tokenizer

    print(
        f"Loading tokenizer and model: {args.model} (trust_remote_code={args.trust_remote_code}, dtype={args.torch_dtype})")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code, use_fast=False)
    # dtype_from_str() maps 'auto' to None, but from_pretrained(torch_dtype=None)
    # means "library default dtype", not "checkpoint dtype". Pass the literal
    # 'auto' so the weights load in the dtype stored with the checkpoint.
    requested = dtype_from_str(args.torch_dtype)
    load_dtype = requested if requested is not None else 'auto'
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=load_dtype, trust_remote_code=args.trust_remote_code)
    return model, tokenizer


def _build_wikitext_inputs(args, device):
    """Cut the wikitext-2 test tokens into up to --nsamples chunks of --seqlen tokens.

    Each chunk is moved to ``device`` and tiled to --batch_size rows when
    batch_size > 1.
    """
    print("Building wikitext dataset via utils.calib.get_wikitext2_test ...")
    from utils.calib import get_wikitext2_test
    ds = get_wikitext2_test(seed=0, seqlen=args.seqlen, model=args.model)
    enc_input_ids = ds.input_ids

    # At least one chunk is produced even if the corpus is shorter than seqlen.
    max_samples = max(1, enc_input_ids.numel() // args.seqlen)
    samples = []
    for i in range(min(args.nsamples, max_samples)):
        start = i * args.seqlen
        sample = enc_input_ids[:, start:start + args.seqlen].to(device)
        if args.batch_size > 1:
            sample = sample.repeat(args.batch_size, 1)
        samples.append(sample)
    return samples


def _sync_cuda_devices():
    """Wait for all queued work on every visible CUDA device to finish."""
    for idx in range(torch.cuda.device_count()):
        torch.cuda.synchronize(idx)


def _timed(device_is_cuda, fn, *fn_args):
    """Call ``fn(*fn_args)`` once and return ``(elapsed_seconds, peak_bytes)``.

    CUDA kernel launches are asynchronous. The devices are therefore
    synchronised before starting and before stopping the clock; otherwise a
    forward pass could be "timed" while its kernels are still running.

    Peak memory is ``torch.cuda.max_memory_allocated()`` for the current device,
    reset right before the call. It is 0 when not running on CUDA.
    """
    if device_is_cuda:
        _sync_cuda_devices()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    fn(*fn_args)
    if device_is_cuda:
        _sync_cuda_devices()
    elapsed = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated() if device_is_cuda else 0
    return elapsed, peak


def _bench_prefill(model, samples, args, device_is_cuda):
    """Time forward passes (KV cache off) over each sample.

    Prints per-sample lines and a summary, then returns
    {'samples': [...], 'overall_tokens_per_sec': float}.
    """
    prev_use_cache = getattr(model.config, 'use_cache', True)
    # A pure prefill benchmark does not need the KV cache.
    model.config.use_cache = False
    try:
        # Warmup always runs at least once, on the first sample only.
        print(f"Warmup prefill with {len(samples)} samples")
        for _ in range(max(1, args.warmup)):
            model(samples[0])

        print(f"Running prefill wikitext speed test on {len(samples)} samples (repeat={args.repeat})")
        per_sample_times = []
        for idx, sample in enumerate(samples, start=1):
            sample_times, sample_mems = [], []
            for _ in range(args.repeat):
                elapsed, peak = _timed(device_is_cuda, model, sample)
                sample_times.append(elapsed)
                sample_mems.append(peak)
            per_sample_times.append(sample_times)
            # print peak memory in GB
            print(
                f"sample {idx}/{len(samples)} mean={statistics.mean(sample_times):.4f}s min={min(sample_times):.4f}s max={max(sample_times):.4f}s peak_mem={max(sample_mems) / (1024**3):.4f}GB")
    finally:
        model.config.use_cache = prev_use_cache

    summary = []
    for sample, sample_times in zip(samples, per_sample_times):
        s = _convert_time_stats_to_ms(stats(sample_times))
        mean_sec = s['mean'] / 1000.0 if s.get('mean', 0) > 0 else 1e-12
        # numel() = rows * actual chunk length. A chunk is shorter than --seqlen
        # when the corpus has fewer than --seqlen tokens.
        s['tokens_per_sec'] = sample.numel() / mean_sec
        summary.append(s)
    overall_tps = sum(s['tokens_per_sec'] for s in summary) / len(summary) if summary else 0.0

    print('\nPrefill samples summary:')
    for i, s in enumerate(summary, start=1):
        print(f"sample {i}: count={s['count']}, mean={s['mean']:.2f}ms, tokens/s={s['tokens_per_sec']:.2f}")
    print(
        f"\nOverall avg tokens/sec: {overall_tps:.2f} tokens/s (seqlen={args.seqlen}, batch_size={args.batch_size})")
    return {'samples': summary, 'overall_tokens_per_sec': overall_tps}


def _bench_generate(model, input_ids, gen_targets, args, device_is_cuda):
    """Time ``run_once_generate`` for each target length.

    Prints per-iteration lines and a summary, then returns {gen_len: stats},
    with times in ms plus peak-memory (GB) and tokens_per_sec fields.
    """
    # Warm up so one-off start-up costs are not counted in the first timed runs.
    if args.warmup > 0:
        print(f"Warmup generate x{args.warmup} (max_new_tokens={gen_targets[0]})")
    for _ in range(args.warmup):
        run_once_generate(model, input_ids, gen_targets[0])

    gen_summary = {}
    for gl in gen_targets:
        print(f"Running generate speed test for max_new_tokens={gl} (repeat={args.repeat})")
        gl_times, gl_mems = [], []
        for i in range(args.repeat):
            elapsed, peak = _timed(device_is_cuda, run_once_generate, model, input_ids, gl)
            gl_times.append(elapsed)
            gl_mems.append(peak)
            print(f"iter {i+1}/{args.repeat}: {elapsed:.4f}s peak_mem={peak/(1024**3):.4f}GB")
        s = stats(gl_times)
        m = stats(gl_mems)
        # merge memory stats into the time stats dict (stored in GB)
        s['peak_memory_gb_mean'] = m['mean'] / (1024**3)
        s['peak_memory_gb_min'] = m['min'] / (1024**3)
        s['peak_memory_gb_max'] = m['max'] / (1024**3)
        _convert_time_stats_to_ms(s)
        mean_sec = s['mean'] / 1000.0 if s['mean'] > 0 else 1e-12
        # gl newly generated tokens for every row of the batch
        s['tokens_per_sec'] = (gl * input_ids.shape[0]) / mean_sec
        gen_summary[gl] = s

    print('\nGenerate summary:')
    for gl, s in gen_summary.items():
        # mean and p50 are in ms now
        print(f"gen_len={gl}: mean={s['mean']:.2f}ms, p50={s['p50']:.2f}ms, tokens/s={s['tokens_per_sec']:.2f}")
    return gen_summary


def main():
    """Benchmark entry point: load the model, then time prefill or generate.

    * --prefill_wikitext: forward passes with the KV cache off over wikitext-2
      chunks of --seqlen tokens. Reports per-chunk latency and tokens/sec.
    * otherwise: ``model.generate`` for each length in --gen_lengths. When that
      list is empty, a single length is used: ``args.max_new_tokens`` if the
      namespace has it, else 32. Reports latency, tokens/sec and CUDA peak
      memory.

    Times are reported in ms. The result dict is optionally written to --out
    as JSON.
    """
    args = parse_args()
    if args.repeat < 1:
        raise ValueError('--repeat must be >= 1')

    # Only takes effect if CUDA has not been initialised yet (e.g. by an import).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_is_cuda = device.type == 'cuda'
    print(f"Using device: {device}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

    # Parse --gen_lengths before the (slow) model load so bad input fails fast.
    gen_lengths = []
    if args.gen_lengths:
        try:
            gen_lengths = [int(x) for x in args.gen_lengths.split(',') if x.strip()]
        except ValueError as err:
            raise ValueError('Invalid --gen_lengths format, expected comma-separated ints') from err

    model, tokenizer = _load_model(args)
    from utils.common import distribute_model
    model.eval()
    distribute_model(model)

    result = {
        'model': args.model,
        'gpu': args.gpu,
        'batch_size': args.batch_size,
        'torch_dtype': args.torch_dtype,
        'repeat': args.repeat,
        'warmup': args.warmup,
        'prefill_wikitext': args.prefill_wikitext,
        'seqlen': args.seqlen,
        'nsamples': args.nsamples,
        'prefill_summary': None,
        'generate_summary': None,
    }

    with torch.no_grad():
        if args.prefill_wikitext:
            samples = _build_wikitext_inputs(args, device)
            if not samples:
                raise ValueError('--prefill_wikitext requires --nsamples >= 1')
            result['prefill_summary'] = _bench_prefill(model, samples, args, device_is_cuda)
        else:
            input_ids = make_inputs(tokenizer, args.prompt, args.batch_size, device)
            # Without this fallback an empty --gen_lengths (the default) would
            # leave nothing to benchmark.
            gen_targets = gen_lengths or [getattr(args, 'max_new_tokens', 32)]
            result['generate_summary'] = _bench_generate(model, input_ids, gen_targets, args, device_is_cuda)

    if args.out:
        with open(args.out, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2)
        print(f"Saved results to {args.out}")


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()
