"""
Benchmark runner for PyTorch / HuggingFace Transformers inference engine.
Only model forward is benchmarked.
"""

import os
import argparse

from tqdm import tqdm
from pathlib import Path
from datetime import timedelta

import torch
import torch.distributed as dist

from PyTorch.backend import PyTorchBackend
from FlashInfer.utils import setup_seed
from profiler import Profiler, ProfilerConfig, register_active_decode_profiler


def print_on_rank0(message):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(message)
    else:
        print(message)


def dist_is_initialized():
    return dist.is_available() and dist.is_initialized()


def init_dist_if_needed():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1 and not dist_is_initialized():
        dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)


def get_rank_world():
    if dist_is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def parse_args():
    parser = argparse.ArgumentParser(description='FlashInfer Benchmark')

    group = parser.add_argument_group('Model Arguments')
    group.add_argument('--model_name', type=str, required=True, help='Model name (e.g., Qwen/Qwen3-8B). Assumed to be a repo_id from HuggingFace.')
    group.add_argument('--model_path', type=Path, default=None, help='Model path (e.g., Qwen3-8B/model.pth).')
    group.add_argument('--dtype', type=str, default="float16", choices=["float16", "bfloat16"], help='Data type for model execution (bfloat16, float16).')
    group.add_argument('--seed', type=int, default=42, help='Seed for random number generator.')
    group.add_argument('--attn_impl', type=str, default="eager", choices=["eager", "sdpa", "flash_attention_2"], help='Attention implementation (eager, sdpa, flash_attention_2).')
    group.add_argument('--fuse_weights', action='store_true', help='Fuse weights for model.')
    
    group = parser.add_argument_group('Benchmark Arguments')
    group.add_argument(
        '--batch_size_list',
        type=int,
        nargs='+',
        default=[16, 32, 64, 128],
        help='A list of batch sizes to benchmark (e.g., --batch_size_list 16 32 64 128).'
    )
    group.add_argument(
        '--dec_len_list', 
        type=int, 
        nargs='+', 
        default=[1, 2, 5, 10], 
        help='A list of decoding lengths to benchmark (e.g., --dec_len_list 1 2 5 10).'
    )
    group.add_argument(
        '--seq_len_list', 
        type=int, 
        nargs='+',
        default=[1024, 2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624, 28672, 30720, 32768], 
        help='A list of sequence lengths to benchmark (e.g., --seq_len_list 1024 2048 4096).'
    )

    group = parser.add_argument_group('Distributed Inference Arguments')
    group.add_argument('--tp_size', type=int, default=1, help='Tensor parallel size.')

    group = parser.add_argument_group('Profiling Arguments')
    group.add_argument('--profile_dir', type=str, default='profiler_out')
    group.add_argument('--detailed_profiling', action='store_true', help='Enable detailed profiling.')
    
    args = parser.parse_args()

    if torch.cuda.is_available():
        args.device = 'cuda'
        args.dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
    else:
        raise ValueError('CUDA is not available. Please check your environment.')
    
    return args


class BenchmarkRunner:
    def __init__(self, args, engine, profiler):
        self.args = args
        self.engine = engine
        self.profiler = profiler

    def run(self):
        for bsz in self.args.batch_size_list:
            for declen in self.args.dec_len_list:
                for seqlen in self.args.seq_len_list:
                    print_on_rank0(f"Running benchmark for batch_size {bsz}, dec_len {declen}, seq_len {seqlen} ...")
                    self.engine.setup_caches(max_batch_size=bsz, max_seq_length=seqlen+128)
                    self.run_with_seqlen(bsz, declen, seqlen)

    def run_with_seqlen(self, bsz, declen, seqlen):
        self.engine.clear_kv()
        self.engine.insert_kv(seqlen)

        # Warmup for compilation
        for _ in tqdm(range(10), desc="Warmup"):
            input_ids = torch.randint(0, self.engine.model.config.vocab_size, (bsz, declen), device=self.engine.device)
            self.engine.decode(input_ids=input_ids)
            self.engine.delete_kv(declen)

        self.profiler.begin_run(bsz=bsz, declen=declen, seqlen=seqlen, label="decode")
        register_active_decode_profiler(self.profiler)

        for _ in tqdm(range(100), desc=f"Standard Decode (seq_len {seqlen})"):
            input_ids = torch.randint(0, self.engine.model.config.vocab_size, (bsz, declen), device=self.engine.device)
            with self.profiler.time_decode():
                self.engine.decode(input_ids=input_ids)
            self.engine.delete_kv(declen)

        register_active_decode_profiler(None)
        self.profiler.end_run()


def main():
    args = parse_args()
    setup_seed(args.seed)

    init_dist_if_needed()
    rank, world_size = get_rank_world()
    print_on_rank0(f"attn_impl={args.attn_impl}, tp_size={args.tp_size}, world_size={world_size}")

    print(f"[rank {rank}] Initializing PyTorch engine ...")
    engine = PyTorchBackend(dtype=args.dtype, device=torch.device("cuda"))
    engine.load_model(
        model_name=args.model_name,
        checkpoint_path=args.model_path,
        attn_impl=args.attn_impl,
        tp_size=args.tp_size,
        world_size=world_size,
        fuse_weights=args.fuse_weights,
    )
    # engine.model.fuse_weights()

    profiler = Profiler(ProfilerConfig(output_dir=args.profile_dir, detailed_profiling=args.detailed_profiling))
    profiler.attach_model(engine.model)

    runner = BenchmarkRunner(args, engine, profiler)
    runner.run()

    profiler.save_all()

    if args.tp_size > 1 and dist.is_initialized():
        print(f"[rank {rank}] Cleaning up distributed process group ...")
        dist.destroy_process_group()

if __name__ == "__main__":
    main()