import os
import sys
import time
import argparse
import warnings

from tqdm import tqdm
from pathlib import Path
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, DynamicCache, Qwen3Config,
)

from FlashInfer.utils import setup_seed, sample
from PyTorch.models.modeling_qwen3 import Qwen3ForCausalLM

# Map Qwen3Config onto the project-local Qwen3ForCausalLM imported above, so
# that AutoModelForCausalLM resolves Qwen3 configs to that implementation.
# NOTE(review): exist_ok=True presumably allows replacing a mapping that is
# already registered instead of raising -- confirm against the installed
# transformers version.
AutoModelForCausalLM.register(Qwen3Config, Qwen3ForCausalLM, exist_ok=True)


def apply_chat_template(tokenizer, text: str, system_prompt: Optional[str], enable_thinking: bool) -> str:
    """Render a single-turn conversation through the tokenizer's chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template(messages, **kwargs)``.
        text: Content of the user turn.
        system_prompt: Optional system turn; skipped when falsy (``None`` or "").
        enable_thinking: Forwarded to the template as ``enable_thinking``.

    Returns:
        The templated prompt as returned by the tokenizer (``tokenize=False``,
        ``add_generation_prompt=True``).

    If the tokenizer raises ``TypeError`` on the first attempt (e.g. because it
    does not accept ``enable_thinking``), a notice is printed and the call is
    retried without that keyword.
    """
    user_turn = {"role": "user", "content": text}
    if system_prompt:
        conversation = [{"role": "system", "content": system_prompt}, user_turn]
    else:
        conversation = [user_turn]

    # Keyword arguments shared by both the first attempt and the fallback.
    base_kwargs = {"tokenize": False, "add_generation_prompt": True}

    try:
        return tokenizer.apply_chat_template(
            conversation, **base_kwargs, enable_thinking=enable_thinking
        )
    except TypeError:
        print("Tokenizer does not support enable_thinking flag")
        return tokenizer.apply_chat_template(conversation, **base_kwargs)


def load_GSM8k_dataset(tokenizer):
    """Load the GSM8K test split and tokenize its chat-templated questions.

    Each question is wrapped with ``apply_chat_template`` (no system prompt,
    ``enable_thinking=True``) and then tokenized in batches through
    ``Dataset.map``. The raw ``question`` column is dropped, and the dataset
    is formatted to yield torch tensors for ``input_ids`` and
    ``attention_mask``.

    NOTE(review): ``padding=True`` is applied per ``map`` batch, so rows are
    presumably padded to the longest prompt of their map batch rather than of
    the DataLoader batch -- confirm if prompt-length accounting matters.

    Args:
        tokenizer: Hugging Face tokenizer used for templating and encoding.

    Returns:
        The tokenized dataset.
    """

    def tokenize_fn(examples):
        prompts = []
        for question in examples["question"]:
            prompts.append(apply_chat_template(tokenizer, question, None, True))
        return tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

    test_split = load_dataset("openai/gsm8k", "main", split="test")
    tokenized = test_split.map(tokenize_fn, batched=True, remove_columns=["question"])
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
    return tokenized


def parse_args():
    """Parse the benchmark's command-line options.

    Returns:
        The ``argparse.Namespace`` with ``dtype`` already converted from its
        string form to the matching ``torch.dtype``.

    Raises:
        ValueError: If CUDA is not available (checked after argument parsing).
    """
    parser = argparse.ArgumentParser(description='PyTorch Benchmark (FA2 + TP enabled)')

    # Model
    model_opts = parser.add_argument_group('Model Arguments')
    model_opts.add_argument(
        '--model_name', type=str, required=True,
        help='HF repo_id (e.g., Qwen/Qwen2.5-7B-Instruct)',
    )
    model_opts.add_argument(
        '--model_path', type=Path, default=None,
        help='Local path (optional). If set, overrides --model_name',
    )
    model_opts.add_argument(
        '--dtype', type=str, default="bfloat16", choices=["float16", "bfloat16"],
        help='Model dtype',
    )
    model_opts.add_argument('--seed', type=int, default=42, help='Random seed')

    # Inference
    bench_opts = parser.add_argument_group('Benchmark Arguments')
    bench_opts.add_argument('--batch_size', type=int, default=16, help='Batch size')
    bench_opts.add_argument('--max_gen_len', type=int, default=1024, help='Max new tokens to generate')
    bench_opts.add_argument('--dataset', type=str, default="GSM8K", help='Dataset name (GSM8K)')

    # Sampling
    sampling_opts = parser.add_argument_group('Sampling Arguments')
    sampling_opts.add_argument('--temperature', type=float, default=0.0, help='Temperature')
    sampling_opts.add_argument('--top_p', type=float, default=0.95, help='Top-p')
    sampling_opts.add_argument('--top_k', type=int, default=0, help='Top-k')

    # FlashAttention / TP
    perf_opts = parser.add_argument_group('Performance')
    perf_opts.add_argument(
        '--attn_impl', type=str, default="flash_attention_2",
        choices=["flash_attention_2", "sdpa", "eager", "paged_attention"],
        help='Attention backend to use inside the model',
    )
    perf_opts.add_argument(
        '--tp_size', type=int, default=1,
        help='Tensor parallel world size. If >1, launch with torchrun and tp_plan auto.',
    )
    perf_opts.add_argument(
        '--tp_plan', type=str, default="auto",
        choices=["auto", "none"],
        help='TP partitioning plan. Use "auto" for supported models; "none" disables TP even if world_size>1.',
    )

    parsed = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError('CUDA is not available. Please check your environment.')

    # argparse restricts --dtype to exactly these two spellings.
    dtype_by_name = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    parsed.dtype = dtype_by_name[parsed.dtype]

    return parsed


def dist_is_initialized():
    """Return True only when torch.distributed is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()


def init_dist_if_needed():
    """Initialize the NCCL process group when launched with more than one process.

    Reads ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment (defaulting
    to "1" and "0"). Does nothing for single-process runs or when a process
    group is already initialized.
    """
    # `torch` exposes no `timedelta`; init_process_group expects a
    # datetime.timedelta for its timeout. Imported locally so this function
    # stays self-contained.
    from datetime import timedelta

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size <= 1 or dist_is_initialized():
        return

    # Bind this process to its GPU before creating the NCCL group so that
    # communicators and later `torch.device("cuda")` allocations use the
    # rank-local device.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))


def get_rank_world():
    """Return ``(rank, world_size)``; ``(0, 1)`` when no process group is initialized."""
    if not dist_is_initialized():
        return 0, 1
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size


def bcast_batch_tensors(input_ids, attention_mask, src=0):
    """Broadcast a batch from rank ``src`` to every rank in the process group.

    The source rank first broadcasts the 2-D shape of ``input_ids``; other
    ranks then allocate ``torch.long`` CUDA buffers of that shape and receive
    both tensors. Without an initialized process group the inputs are
    returned unchanged.

    Args:
        input_ids: 2-D tensor on the source rank; ignored (replaced) elsewhere.
        attention_mask: Same shape as ``input_ids`` on the source rank;
            ignored (replaced) elsewhere.
        src: Rank that owns the real data.

    Returns:
        ``(input_ids, attention_mask)`` holding the source rank's data.
    """
    if not dist_is_initialized():
        return input_ids, attention_mask

    my_rank, _ = get_rank_world()
    is_source = my_rank == src
    cuda = torch.device("cuda")

    # Step 1: agree on the batch shape so receivers can size their buffers.
    if is_source:
        dims = torch.tensor(
            [input_ids.size(0), input_ids.size(1)],
            device=input_ids.device,
            dtype=torch.int64,
        )
    else:
        dims = torch.empty(2, device=cuda, dtype=torch.int64)
    dist.broadcast(dims, src=src)
    num_rows, num_cols = dims.tolist()

    # Step 2: receivers swap their placeholder tensors for right-sized buffers.
    if not is_source:
        input_ids = torch.empty((num_rows, num_cols), dtype=torch.long, device=cuda)
        attention_mask = torch.empty((num_rows, num_cols), dtype=torch.long, device=cuda)

    # Step 3: ship the payload.
    for payload in (input_ids, attention_mask):
        dist.broadcast(payload, src=src)
    return input_ids, attention_mask


def generate(args, model, tokenizer, input_ids, attention_mask, is_main_process: bool):
    """Run one prefill pass followed by a token-by-token decode loop.

    Args:
        args: Namespace providing ``max_gen_len``, ``eot_token``,
            ``temperature``, ``top_k`` and ``top_p``.
        model: Causal LM called as ``model(input_ids=..., past_key_values=...,
            use_cache=True)`` and returning ``.logits`` / ``.past_key_values``.
        tokenizer: Unused here; kept for interface compatibility.
        input_ids: ``[B, S]`` prompt token ids. The slicing below takes the
            trailing columns, i.e. it assumes left padding.
        attention_mask: ``[B, S]`` mask; its row sums are used as the
            per-sequence prompt lengths.
        is_main_process: Unused here; kept for interface compatibility.

    Returns:
        Tuple ``(num_generated_tokens, model_steps, output, num_total_tokens)``:
        the total number of new tokens across the batch (Python int), the
        number of model forward passes, the ``[B, W]`` token buffer (prompt
        followed by generated tokens), and the per-row token counts.

    NOTE(review): the prompt copy and the KV-cache trim both use the prompt
    length of row 0 only, so rows whose prompt length differs from row 0 are
    not laid out/cached exactly. The caller only decodes row 0.
    NOTE(review): the loop bound compares *total* tokens (prompt + generated)
    against ``max_gen_len``, and decoding stops as soon as *any* row emits
    ``eot_token``. Both are kept as-is to preserve benchmark behavior.
    """
    max_gen_len = args.max_gen_len
    eot_token = args.eot_token
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p

    def pick_next_tokens(logits):
        """Reduce logits to the last position and choose one token per row ([B, 1])."""
        # The model returns [B, T, V]; keep only the final position so prefill
        # and decode share the same [B, V] -> [B, 1] path. Already-2-D logits
        # pass through untouched.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        # `sample` is assumed to return a [B, 1] tensor, as the `[:, 0]`
        # indexing below requires -- TODO confirm.
        return sample(logits, top_p=top_p, top_k=top_k, temperature=temperature)

    model_device = getattr(model, "device", torch.device("cuda"))
    input_ids = input_ids.to(model_device)
    attention_mask = attention_mask.to(model_device)

    bsz = input_ids.shape[0]
    query_lens = attention_mask.sum(dim=-1).to(torch.int32)
    # Row 0's prompt length drives the prompt copy and the cache trim below.
    first_len = int(query_lens[0])

    batch_indices = torch.arange(bsz, device=model_device)
    # The buffer stores prompt + generated tokens, so it must be at least as
    # wide as the (padded) prompt plus one slot for the prefill token;
    # otherwise long prompts overflow it.
    buffer_width = max(max_gen_len, input_ids.shape[1]) + 1
    output = torch.zeros(bsz, buffer_width, device=model_device, dtype=torch.long)
    output[:, :first_len] = input_ids[:, -first_len:]
    num_total_tokens = query_lens.clone()
    model_steps = 0

    with torch.inference_mode():
        prefill_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            use_cache=True,
        )
    past_key_values = prefill_output.past_key_values

    # Keep only the trailing `first_len` cache positions (drops row 0's left
    # padding) and rebuild a DynamicCache for the decode steps, which run
    # without an attention mask.
    new_past_key_values = tuple(
        (k[..., -first_len:, :].contiguous(), v[..., -first_len:, :].contiguous())
        for k, v in past_key_values
    )
    past_key_values = DynamicCache.from_legacy_cache(past_key_values=new_past_key_values)

    next_tokens = pick_next_tokens(prefill_output.logits)  # [B, 1]

    output[batch_indices, num_total_tokens] = next_tokens[:, 0]
    num_total_tokens += 1
    model_steps += 1

    terminal = False
    while num_total_tokens.max() < max_gen_len and not terminal:
        with torch.inference_mode():
            decode_output = model(
                input_ids=next_tokens,
                past_key_values=past_key_values,
                use_cache=True,
            )
        past_key_values = decode_output.past_key_values
        next_tokens = pick_next_tokens(decode_output.logits)  # [B, 1]

        output[batch_indices, num_total_tokens] = next_tokens[:, 0]
        num_total_tokens += 1
        model_steps += 1

        # Stop the whole batch once any row produces the end-of-text token.
        if (next_tokens[:, 0] == eot_token).any():
            terminal = True

    num_generated_tokens = (num_total_tokens - query_lens).sum().item()
    return num_generated_tokens, model_steps, output, num_total_tokens


def main():
    """Benchmark entry point: load tokenizer/data/model, generate, report throughput.

    Rank 0 owns the dataset and broadcasts each batch (and the number of
    evaluation steps) to the other ranks; every rank runs ``generate`` so
    that tensor-parallel collectives line up. Timing and statistics are
    accumulated and printed on rank 0 only.

    Raises:
        ValueError: If ``--tp_size > 1`` does not match the launched world size
            (or from ``parse_args`` when CUDA is unavailable).
    """
    args = parse_args()
    setup_seed(args.seed)

    init_dist_if_needed()
    rank, world_size = get_rank_world()
    is_main = (rank == 0)
    cuda = torch.device("cuda")

    if is_main:
        print("Loading tokenizer ...")

    pretrained_id = str(args.model_path) if args.model_path is not None else args.model_name

    # Left padding keeps the real prompt tokens at the end of each row, which
    # is what generate() slices on.
    tokenizer = AutoTokenizer.from_pretrained(pretrained_id, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    args.eot_token = tokenizer.eos_token_id

    if is_main:
        if args.dataset.upper() != "GSM8K":
            warnings.warn(f"Unknown dataset {args.dataset}; falling back to GSM8K")
        dataset = load_GSM8k_dataset(tokenizer)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)
        num_eval_steps = min(5, len(dataloader))
    else:
        dataloader = None
        num_eval_steps = 0

    # Only rank 0 knows the real step count; share it so all ranks loop in lockstep.
    num_eval_steps_t = torch.tensor([num_eval_steps], device=cuda, dtype=torch.int64)
    if dist_is_initialized():
        dist.broadcast(num_eval_steps_t, src=0)
    num_eval_steps = int(num_eval_steps_t.item())

    if is_main:
        print("Initializing model ...")

    attn_impl = args.attn_impl  # "flash_attention_2" | "sdpa" | "eager" | "paged_attention"

    tp_kw = {}
    if args.tp_size > 1:
        if world_size != args.tp_size:
            raise ValueError(f"--tp_size ({args.tp_size}) must equal WORLD_SIZE ({world_size}) when using TP.")
        if args.tp_plan == "auto":
            tp_kw["tp_plan"] = "auto"
        else:
            tp_kw["tp_plan"] = None

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_id,
        dtype=args.dtype,
        attn_implementation=attn_impl,
        **tp_kw
    )

    # Without an active TP plan nothing places the weights on the GPU, so move
    # the model explicitly. This also covers `--tp_size > 1 --tp_plan none`,
    # which previously left the model on its default load device.
    if tp_kw.get("tp_plan") is None:
        model = model.to(cuda)

    if is_main:
        print(f"[rank0] attn_impl={attn_impl}, tp_plan={tp_kw.get('tp_plan', None)}, world_size={world_size}")

    total_time = 0.0
    total_model_steps = 0
    num_gen_tokens = 0

    if is_main:
        data_iter = iter(dataloader)

    for _ in tqdm(range(num_eval_steps), disable=not is_main):
        if is_main:
            batch = next(data_iter)
            input_ids = batch['input_ids'].to(cuda)
            attention_mask = batch['attention_mask'].to(cuda)
        else:
            # Placeholders; bcast_batch_tensors replaces them with right-sized buffers.
            input_ids = torch.empty((1, 1), dtype=torch.long, device=cuda)
            attention_mask = torch.empty((1, 1), dtype=torch.long, device=cuda)

        input_ids, attention_mask = bcast_batch_tensors(input_ids, attention_mask, src=0)

        # Synchronize around the timed region so pending GPU work is not
        # attributed to the wrong interval.
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        generated_tokens, model_steps, final_output, num_total_tokens = generate(
            args, model, tokenizer, input_ids, attention_mask, is_main_process=is_main
        )

        torch.cuda.synchronize()
        end_time = time.perf_counter()

        if is_main:
            # Only the first sequence of each batch is decoded and shown.
            print(f"=================== Sequence {0} ===================")
            print(f"{tokenizer.decode(final_output[0, :num_total_tokens[0]].tolist())}")
            print("=" * 50)

            total_time += (end_time - start_time)
            num_gen_tokens += generated_tokens
            total_model_steps += model_steps

    if is_main:
        print(f"Total time: {total_time:.2f}s")
        print(f"Total generated tokens: {num_gen_tokens}")
        print(f"Total model steps: {total_model_steps}")
        if total_time > 0:
            print(f"Throughput: {num_gen_tokens / total_time:.2f} tokens/s")
        if total_model_steps > 0:
            # NOTE(review): this ratio is tokens per model step summed over the
            # batch, not a per-sequence length -- confirm the intended label.
            print(f"Mean Generated Length: {num_gen_tokens / total_model_steps:.2f} tokens")

    if dist_is_initialized():
        dist.barrier()
        dist.destroy_process_group()


# Run the benchmark only when executed as a script (e.g. via python/torchrun),
# not when this module is imported.
if __name__ == "__main__":
    main()
