"""
Benchmark runner for FlashInfer inference engine.
Only the model forward pass is benchmarked.
"""

import os
import sys
import argparse
import warnings

from tqdm import tqdm
from pathlib import Path
from itertools import permutations

import torch
import torch.distributed as dist

from FlashInfer.backend import FlashInferBackend
from FlashInfer.utils import init_dist, setup_seed
from profiler import Profiler, ProfilerConfig, register_active_decode_profiler


def print_on_rank0(message):
    """Print ``message`` once per job instead of once per process.

    When ``torch.distributed`` has been initialized, only the process whose
    rank is 0 prints. When it has not been initialized, the message is always
    printed.
    """
    # Guard clause: every non-zero rank in an initialized group stays silent.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message)
    

def parse_args(argv=None):
    """Parse the benchmark command line and derive runtime settings.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers are unaffected.

    Returns:
        The parsed ``argparse.Namespace`` with these derived fields:

        * ``device`` -- always ``'cuda'``.
        * ``dtype`` -- converted from its string form to ``torch.bfloat16``
          or ``torch.float16``.
        * ``rank_group`` -- ``list(range(tp_size))`` when ``tp_size > 1``,
          otherwise ``None``.

    Raises:
        ValueError: If CUDA is not available.
    """
    parser = argparse.ArgumentParser(description='FlashInfer Benchmark')

    group = parser.add_argument_group('Model Arguments')
    group.add_argument('--model_name', type=str, required=True, help='Model name (e.g., Qwen/Qwen3-8B). Assumed to be a repo_id from HuggingFace.')
    group.add_argument('--model_path', type=Path, default=None, help='Model path (e.g., Qwen3-8B/model.pth).')
    group.add_argument('--dtype', type=str, default="float16", choices=["float16", "bfloat16"], help='Data type for model execution (bfloat16, float16).')
    group.add_argument('--seed', type=int, default=42, help='Seed for random number generator.')

    group = parser.add_argument_group('Benchmark Arguments')
    group.add_argument(
        '--batch_size_list',
        type=int,
        nargs='+',
        default=[16, 32, 64, 128],
        help='A list of batch sizes to benchmark (e.g., --batch_size_list 16 32 64 128).'
    )
    group.add_argument(
        '--dec_len_list',
        type=int,
        nargs='+',
        default=[1, 2, 5, 10],
        help='A list of decoding lengths to benchmark (e.g., --dec_len_list 1 2 5 10).'
    )
    group.add_argument(
        '--seq_len_list',
        type=int,
        nargs='+',
        default=[1024, 2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432, 20480, 22528, 24576, 26624, 28672, 30720, 32768],
        help='A list of sequence lengths to benchmark (e.g., --seq_len_list 1024 2048 4096).'
    )
    group.add_argument('--non_causal_mask', action='store_true', help='Use non-causal mask for inference.')

    group = parser.add_argument_group('Distributed Inference Arguments')
    group.add_argument('--tp_size', type=int, default=1, help='Tensor parallel size.')

    group = parser.add_argument_group('Profiling Arguments')
    group.add_argument('--profile_dir', type=str, default='profiler_out', help='Output directory passed to the profiler.')
    group.add_argument('--detailed_profiling', action='store_true', help='Enable detailed profiling.')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Checked after parsing so that `--help` and usage errors still work on
    # machines without a GPU.
    if not torch.cuda.is_available():
        raise ValueError('CUDA is not available. Please check your environment.')
    args.device = 'cuda'
    args.dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16

    # One rank id per tensor-parallel shard; None when tensor parallelism is off.
    args.rank_group = list(range(args.tp_size)) if args.tp_size > 1 else None

    if args.non_causal_mask:
        warnings.warn('FlashInfer benchmark with non-causal mask may not be reliable since it is unclear how different shapes of custom mask will affect the performance.')

    return args


class BenchmarkRunner:
    """Run decode benchmarks over a grid of batch sizes, decode lengths and sequence lengths.

    For every combination taken from ``args.batch_size_list``,
    ``args.dec_len_list`` and ``args.seq_len_list`` the runner re-creates the
    engine caches, pre-fills the KV cache, runs untimed warmup decodes and
    then runs timed decodes under the profiler.

    Args:
        args: Parsed benchmark arguments; reads ``batch_size_list``,
            ``dec_len_list``, ``seq_len_list`` and ``non_causal_mask``.
        engine: Inference backend exposing ``setup_caches``, ``clear_kv``,
            ``insert_kv``, ``decode``, ``delete_kv``, ``device`` and
            ``model.config.vocab_size``.
        profiler: Profiler exposing ``begin_run``, ``time_decode`` and
            ``end_run``.
        warmup_iters: Number of untimed decode iterations per configuration.
        num_iters: Number of timed decode iterations per configuration.
    """

    # Extra cache slots requested beyond the pre-filled sequence length.
    # NOTE(review): presumably this must be >= dec_len so the decoded tokens
    # fit in the cache -- confirm against the backend's setup_caches.
    KV_CACHE_HEADROOM = 128

    def __init__(self, args, engine, profiler, *, warmup_iters=10, num_iters=100):
        self.args = args
        self.engine = engine
        self.profiler = profiler
        self.warmup_iters = warmup_iters
        self.num_iters = num_iters

    def run(self):
        """Benchmark every (batch size, decode length, sequence length) combination."""
        for bsz in self.args.batch_size_list:
            for declen in self.args.dec_len_list:
                for seqlen in self.args.seq_len_list:
                    print_on_rank0(f"Running benchmark for batch_size {bsz}, dec_len {declen}, seq_len {seqlen} ...")
                    self.engine.setup_caches(
                        max_batch_size=bsz,
                        max_seq_length=seqlen + self.KV_CACHE_HEADROOM,
                        dec_len=declen,
                        non_causal_mask=self.args.non_causal_mask,
                    )
                    self.run_with_seqlen(bsz, declen, seqlen)

    def _make_inputs(self, bsz, declen):
        """Return random ``(input_ids, custom_mask)`` for one decode call.

        ``custom_mask`` is a random boolean ``(bsz, declen, declen)`` tensor
        when ``args.non_causal_mask`` is set, otherwise ``None``.
        """
        device = self.engine.device
        input_ids = torch.randint(0, self.engine.model.config.vocab_size, (bsz, declen), device=device)
        custom_mask = None
        if self.args.non_causal_mask:
            custom_mask = torch.randint(0, 2, (bsz, declen, declen), dtype=torch.bool, device=device)
        return input_ids, custom_mask

    def run_with_seqlen(self, bsz, declen, seqlen):
        """Warm up and then time decode calls for a single configuration.

        Args:
            bsz: Batch size of the random decode inputs.
            declen: Number of tokens fed to each decode call.
            seqlen: Length passed to ``engine.insert_kv`` before decoding.
        """
        self.engine.clear_kv()
        self.engine.insert_kv(seqlen)

        # Warmup for compilation
        for _ in tqdm(range(self.warmup_iters), desc="Warmup"):
            input_ids, custom_mask = self._make_inputs(bsz, declen)
            self.engine.decode(input_ids=input_ids, custom_mask=custom_mask)
            # Presumably rolls the cache back by the tokens just decoded so
            # each iteration starts from the same KV length -- TODO confirm.
            self.engine.delete_kv(declen)

        self.profiler.begin_run(bsz=bsz, declen=declen, seqlen=seqlen, label="decode")
        try:
            register_active_decode_profiler(self.profiler)

            for _ in tqdm(range(self.num_iters), desc=f"Standard Decode (seq_len {seqlen})"):
                # Inputs are generated outside the timed region.
                input_ids, custom_mask = self._make_inputs(bsz, declen)
                with self.profiler.time_decode():
                    self.engine.decode(input_ids=input_ids, custom_mask=custom_mask)
                self.engine.delete_kv(declen)
        finally:
            # Always unregister the profiler and close the run, even when a
            # decode call raises, so no stale global registration or open run
            # is left behind.
            register_active_decode_profiler(None)
            self.profiler.end_run()

def main():
    """Entry point: parse arguments, build the engine, run the benchmarks, save results.

    With ``tp_size > 1`` the distributed environment is initialized first and
    every non-zero rank redirects its stdout/stderr to
    ``logs/system_logs/rank_<rank>.log``. The process group is destroyed at
    the end if it was initialized.
    """
    args = parse_args()
    use_tp = args.tp_size > 1

    if use_tp:
        rank, process_group = init_dist()
    else:
        rank, process_group = 0, None

    if use_tp and rank != 0:
        # Keep the console readable: only rank 0 writes to the terminal.
        os.makedirs("logs/system_logs", exist_ok=True)
        rank_log = open(f"logs/system_logs/rank_{rank}.log", 'w', buffering=1)
        sys.stdout = sys.stderr = rank_log

    setup_seed(args.seed)

    print(f"[RANK {rank}] Initializing FlashInfer engine ...")
    engine = FlashInferBackend(dtype=args.dtype, device=args.device)
    engine.load_model(
        model_name=args.model_name,
        checkpoint_path=args.model_path,
        use_tp=use_tp,
        rank_group=args.rank_group,
        group=process_group,
    )

    profiler_config = ProfilerConfig(
        output_dir=args.profile_dir,
        detailed_profiling=args.detailed_profiling,
    )
    profiler = Profiler(profiler_config)
    profiler.attach_model(engine.model)

    BenchmarkRunner(args, engine, profiler).run()
    profiler.save_all()

    # Nothing to tear down unless a process group was actually created.
    if not (use_tp and dist.is_initialized()):
        return
    print(f"[RANK {rank}] Cleaning up distributed process group ...")
    dist.destroy_process_group()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()