"""
Smoke-test script that checks model loading and generation work correctly before running the full benchmark.
"""

import os
import sys
import time
import argparse
import warnings

from tqdm import tqdm
from pathlib import Path
from typing import Optional
from itertools import permutations

import torch
import torch.distributed as dist

from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

from .backend import FlashInferBackend
from .utils import init_dist, setup_seed, sample


def print_on_rank0(message):
    """Print ``message`` from rank 0 only.

    When no ``torch.distributed`` process group is initialised, every caller is
    treated as rank 0 and the message is always printed.
    """
    # Short-circuit keeps get_rank() from being called without a process group.
    should_print = (not dist.is_initialized()) or dist.get_rank() == 0
    if should_print:
        print(message)


def apply_chat_template(tokenizer, text: str, system_prompt: Optional[str], enable_thinking: bool) -> str:
    """Render ``text`` (plus an optional system prompt) as a chat prompt string.

    The conversation is formatted with ``tokenizer.apply_chat_template`` without
    tokenizing and with the generation prompt appended.  ``enable_thinking`` is
    forwarded first (Qwen3 understands it).  If the tokenizer rejects that
    keyword with a ``TypeError``, a notice is printed on rank 0 and the template
    is rendered again without it.

    Returns:
        The templated prompt string.
    """
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.append({"role": "user", "content": text})

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        return tokenizer.apply_chat_template(
            conversation,
            enable_thinking=enable_thinking,
            **template_kwargs,
        )
    except TypeError:
        print_on_rank0("Tokenizer does not support enable_thinking flag")
    # Fallback for tokenizers whose apply_chat_template has no enable_thinking kwarg.
    return tokenizer.apply_chat_template(conversation, **template_kwargs)


def load_GSM8k_dataset(tokenizer, prefix_len=256):
    """Load the GSM8K ``main`` test split and tokenize every question as a chat prompt.

    Each question is wrapped with ``apply_chat_template`` (no system prompt,
    thinking enabled) and then tokenized to exactly ``prefix_len`` tokens
    (padded to ``max_length`` and truncated).

    Args:
        tokenizer: Hugging Face tokenizer used for both templating and tokenization.
        prefix_len: Fixed token length of every prompt.

    Returns:
        The dataset formatted as torch tensors exposing only ``input_ids`` and
        ``attention_mask``.
    """
    ds = load_dataset("openai/gsm8k", "main", split="test")

    def tokenize_fn(examples):
        texts = [apply_chat_template(tokenizer, q, None, True) for q in examples["question"]]
        # The chat template already emits the model's special tokens (e.g. a BOS
        # token for Llama-style templates).  Letting the tokenizer add them again
        # would duplicate them, so only the templated text is tokenized.
        # NOTE(review): with right-side truncation an over-long prompt loses its
        # trailing generation prompt -- confirm prefix_len is large enough.
        return tokenizer(
            texts,
            return_tensors="pt",
            max_length=prefix_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["question"])
    ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    return ds


def parse_args():
    """Parse the smoke-test command line and derive a few runtime fields.

    Besides the raw CLI options, the returned namespace carries:
      * ``device`` -- always ``'cuda'``; a ``ValueError`` is raised when CUDA is unavailable.
      * ``dtype`` -- the string choice converted to ``torch.bfloat16`` or ``torch.float16``.
      * ``rank_group`` -- ``[0, ..., tp_size - 1]`` when ``tp_size > 1``, otherwise ``None``.
    """
    # (group title, ((flag, add_argument kwargs), ...)) in registration order.
    option_table = (
        ('Model Arguments', (
            ('--model_name', dict(type=str, required=True, help='Model name (e.g., Qwen/Qwen3-8B). Assumed to be a repo_id from HuggingFace.')),
            ('--model_path', dict(type=Path, default=None, help='Model path (e.g., Qwen3-8B/model.pth).')),
            ('--dtype', dict(type=str, default="float16", choices=["float16", "bfloat16"], help='Data type for model execution (bfloat16, float16).')),
            ('--seed', dict(type=int, default=42, help='Seed for random number generator.')),
        )),
        ('Benchmark Arguments', (
            ('--batch_size', dict(type=int, default=16, help='Batch size for inference.')),
            ('--prefix_len', dict(type=int, default=256, help='Length of the input prompt sequence.')),
            ('--max_gen_len', dict(type=int, default=1024, help='Maximum number of tokens to generate.')),
            ('--dataset', dict(type=str, default="GSM8K", help='Dataset name (GSM8K).')),
        )),
        ('Distributed Inference Arguments', (
            ('--tp_size', dict(type=int, default=1, help='Tensor parallel size.')),
        )),
        ('Sampling Arguments', (
            ('--temperature', dict(type=float, default=0.0, help='Temperature for sampling.')),
            ('--top_p', dict(type=float, default=0.95, help='Top-p sampling.')),
            ('--top_k', dict(type=int, default=0, help='Top-k sampling.')),
        )),
    )

    parser = argparse.ArgumentParser(description='FlashInfer Benchmark')
    for title, options in option_table:
        section = parser.add_argument_group(title)
        for flag, spec in options:
            section.add_argument(flag, **spec)

    parsed = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError('CUDA is not available. Please check your environment.')
    parsed.device = 'cuda'
    parsed.dtype = torch.bfloat16 if parsed.dtype == 'bfloat16' else torch.float16

    parsed.rank_group = list(range(parsed.tp_size)) if parsed.tp_size > 1 else None
    return parsed


def generate(args, engine, tokenizer, input_ids, query_lens, bsz):
    """Generate continuations for a right-padded prompt batch, one token per step.

    Tokens are chosen greedily when ``args.temperature == 0.0``.  Otherwise they
    are drawn with ``sample`` using ``args.top_p`` / ``args.top_k`` /
    ``args.temperature``.  A row stops growing once it emits ``args.eot_token``
    (the EOT token itself is kept).  Generation ends when every row has finished
    or when the lock-step length ``max(query_lens) + steps`` reaches
    ``args.max_gen_len``.  The text of sequence 0 is printed on rank 0 after
    every decode step.

    Args:
        args: Parsed CLI namespace; reads ``max_gen_len``, ``eot_token``,
            ``temperature``, ``top_k`` and ``top_p``.
        engine: Backend exposing ``device``, ``prefill(input_ids=..., query_lens=...)``
            and ``decode(input_ids=...)``; both calls return logits for the next token.
        tokenizer: Used only to decode sequence 0 for progress output.
        input_ids: Right-padded prompt token ids, one row per sequence.
        query_lens: True (unpadded) prompt length of each row.
        bsz: Number of rows in the batch.

    Returns:
        ``(num_generated_tokens, model_steps, output, num_total_tokens)``:
        ``output[i, :num_total_tokens[i]]`` is row ``i``'s prompt followed by its
        generated tokens.  ``num_generated_tokens`` sums the generated part over
        all rows.  ``model_steps`` counts prefill plus decode calls.
    """
    max_gen_len = args.max_gen_len
    eot_token = args.eot_token
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    device = engine.device

    def _pick(logits):
        # Temperature 0 means greedy decoding.
        # NOTE(review): assumes logits are (bsz, 1, vocab) so picks are (bsz, 1) -- confirm.
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)
        return sample(logits, top_p=top_p, top_k=top_k, temperature=temperature)

    def _append(tokens):
        # Every row writes at its current end, but only unfinished rows advance.
        # Whatever a finished row produces therefore lands past its valid prefix
        # and is overwritten or ignored.
        new_tokens = tokens[:, 0]
        output[batch_indices, num_total_tokens] = new_tokens
        num_total_tokens.add_((~finished).to(num_total_tokens.dtype))
        finished.logical_or_(new_tokens == eot_token)

    batch_indices = torch.arange(bsz, device=device)
    output = torch.zeros(bsz, max_gen_len + 1, device=device, dtype=torch.long)
    # Copy every row's complete padded prompt, since rows can have different true
    # lengths.  Padding past a row's query length is overwritten by its generated
    # tokens and never falls inside its valid prefix.
    prompt_width = min(input_ids.shape[1], output.shape[1])
    output[:, :prompt_width] = input_ids[:, :prompt_width]
    num_total_tokens = query_lens.clone()
    finished = torch.zeros(bsz, dtype=torch.bool, device=device)
    # The engine advances all rows in lock-step, so bound the loop by the longest
    # prompt plus the number of steps taken.
    longest_prompt = int(query_lens.max())
    model_steps = 0

    next_tokens = _pick(engine.prefill(input_ids=input_ids, query_lens=query_lens))
    _append(next_tokens)
    model_steps += 1

    while longest_prompt + model_steps < max_gen_len and not bool(finished.all()):
        next_tokens = _pick(engine.decode(input_ids=next_tokens))
        _append(next_tokens)
        model_steps += 1

        print_on_rank0(f"{tokenizer.decode(output[0, :num_total_tokens[0]])}")

    num_generated_tokens = (num_total_tokens - query_lens).sum().item()
    return num_generated_tokens, model_steps, output, num_total_tokens


def main():
    """Smoke-test the FlashInfer backend on a few GSM8K batches and report timing.

    Steps:
      * Parse the CLI.
      * If ``--tp_size > 1``, join the tensor-parallel group via ``init_dist``.
      * Load the tokenizer, dataset and model.
      * Run ``generate`` on at most 5 batches and print every decoded sequence on rank 0.
      * Print aggregate statistics.

    Non-zero ranks redirect stdout/stderr to ``logs/system_logs/rank_<rank>.log``.
    The original streams are restored and the log file is closed when ``main``
    exits.
    """
    args = parse_args()

    rank, process_group = 0, None
    use_tp = args.tp_size > 1
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    log_file = None
    if use_tp:
        rank, process_group = init_dist()
        if rank != 0:
            os.makedirs("logs/system_logs", exist_ok=True)
            # Line-buffered; explicit UTF-8 so decoded model text is always writable.
            log_file = open(f"logs/system_logs/rank_{rank}.log", 'w', buffering=1, encoding="utf-8")
            sys.stdout = log_file
            sys.stderr = log_file
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cuda")

    try:
        setup_seed(args.seed)

        print_on_rank0("Loading dataset ...")
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, padding_side='right')
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        args.eot_token = tokenizer.eos_token_id

        dataset = load_GSM8k_dataset(tokenizer, prefix_len=args.prefix_len)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)

        print_on_rank0("Initializing engine ...")
        engine = FlashInferBackend(dtype=args.dtype, device=device)
        engine.load_model(model_name=args.model_name, checkpoint_path=args.model_path, use_tp=use_tp, rank_group=args.rank_group, group=process_group)
        # NOTE(review): the +256 headroom appears to mirror the default --prefix_len -- confirm.
        engine.setup_caches(max_batch_size=args.batch_size, max_seq_length=args.max_gen_len+256, dec_len=1)
        engine.setup_sampling_params(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)

        total_time = 0.0
        total_model_steps = 0
        num_gen_tokens = 0
        num_eval_steps = min(5, len(dataloader))
        for step, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
            if step >= num_eval_steps:
                break

            input_ids = batch['input_ids'].to(device)
            query_lens = batch['attention_mask'].to(device).sum(dim=-1).to(torch.int32)
            current_batch_size = input_ids.shape[0]

            # Synchronize so the timer measures completed GPU work only.
            torch.cuda.synchronize(device)
            start_time = time.perf_counter()

            generated_tokens, model_steps, final_output, num_total_tokens = generate(args, engine, tokenizer, input_ids, query_lens, current_batch_size)

            torch.cuda.synchronize(device)
            end_time = time.perf_counter()

            for i in range(current_batch_size):
                print_on_rank0(f"=================== Sequence {i} ===================")
                print_on_rank0(f"{tokenizer.decode(final_output[i, :num_total_tokens[i]])}")
                print_on_rank0("="*50)

            total_time += (end_time - start_time)
            num_gen_tokens += generated_tokens
            total_model_steps += model_steps

            if use_tp and dist.is_initialized():
                dist.barrier()

        print_on_rank0(f"Total time: {total_time:.2f}s")
        print_on_rank0(f"Total generated tokens: {num_gen_tokens}")
        print_on_rank0(f"Total model steps: {total_model_steps}")
        # An empty dataloader (e.g. batch_size larger than the dataset with
        # drop_last=True) leaves both denominators at zero.
        if total_time > 0 and total_model_steps > 0:
            print_on_rank0(f"Throughput: {num_gen_tokens / total_time:.2f} tokens/s")
            print_on_rank0(f"Mean Generated Length: {num_gen_tokens / total_model_steps:.2f} tokens")
        else:
            print_on_rank0("No batches were evaluated; throughput statistics are unavailable.")

        if use_tp and dist.is_initialized():
            print(f"[RANK {rank}] Cleaning up distributed process group ...")
            dist.destroy_process_group()
    finally:
        # Undo the per-rank stdout/stderr redirection and release the log file.
        if log_file is not None:
            sys.stdout, sys.stderr = saved_stdout, saved_stderr
            log_file.close()


# Script entry point.  The module uses relative imports (``from .backend import ...``),
# so it has to be launched as a module, e.g. ``python -m <package>.<this_module>``.
if __name__ == "__main__":
    main()