import os
import time
import json
import argparse
from contextlib import nullcontext
from pathlib import Path
from tqdm import tqdm

import torch
import numpy as np
import pandas as pd
import torch.distributed as dist
from torch.utils.data import DataLoader
from datasets import Dataset, concatenate_datasets, load_from_disk
from transformers import AutoTokenizer

from spec_benchmark.Engine.utils import init_dist, setup_seed, sample
from spec_benchmark.Engine.backends import StandardLMBackend, MTPLMBackend
from spec_benchmark.Engine.models.base import LoRAConfig
from spec_benchmark.profiler import Profiler, ProfilerConfig, register_active_decode_profiler


def parse_args():
    """Build the CLI parser, parse ``sys.argv`` and derive runtime settings.

    Besides the raw flags, the returned namespace carries:

    * ``max_len``: ``prefix_len + max_gen_len``, the per-sequence token budget.
    * ``device``: ``'cuda'`` when CUDA is available, otherwise ``'cpu'``.
    * ``dtype``: replaced by the matching ``torch.dtype``. On CPU the flag is
      ignored and ``torch.float32`` is used.

    ``--backend``, ``--dtype`` and ``--prof_seq_len_reduce`` only accept the
    values the rest of the script handles. A typo now fails at parse time
    instead of silently falling back; an unknown dtype used to become float16.
    """
    parser = argparse.ArgumentParser(description='spec_benchmark Unified Benchmark Runner')

    common = parser.add_argument_group('Common Arguments')
    common.add_argument('--backend', type=str, default="standard", choices=["standard", "mtp"], help='Backend name (standard, mtp).')
    common.add_argument('--model_name', type=str, required=True, help='Model name (e.g., meta-llama/Meta-Llama-3.1-8B).')
    common.add_argument('--target', type=Path, required=True, help='Path to the target model weights.')
    common.add_argument('--tokenizer_path', type=Path, required=True, help='Path to the tokenizer.')
    # Names must match the keys accepted by load_dataset().
    common.add_argument('--dataset', type=str, default="AIME2025", help='Dataset name (AIME2025, GSM8K, MATH-500, LiveMathBench, GPQA-Diamond, LiveCodeBench, CodeForces).')
    common.add_argument('--num_questions_in_prompt', type=int, default=1, help='Number of questions in the prompt.')

    common.add_argument('--batch_size', type=int, default=16, help='Batch size for inference.')
    common.add_argument('--prefix_len', type=int, default=2048, help='Length of the input prompt sequence.')
    common.add_argument('--max_gen_len', type=int, default=30720, help='Maximum number of tokens to generate.')

    common.add_argument('--dtype', type=str, default="bfloat16", choices=["bfloat16", "float16"], help='Data type for model execution (bfloat16, float16).')
    common.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
    common.add_argument('--compile', action='store_true', help='Enable torch.compile() for the model.')

    common.add_argument('--printoutput', action='store_true', help='Print the generated output of each sequence.')
    common.add_argument('--force_budget', action='store_true', help='Force the generation until the budget (max_gen_len) is exhausted.')
    common.add_argument('--num_total_runs', type=int, default=10, help='Number of total runs for profiling.')

    distributed = parser.add_argument_group('Distributed Training Arguments')
    distributed.add_argument('--rank_group', nargs='+', type=int, help='Tensor parallel group ranks for the target model.')

    sampling = parser.add_argument_group('Sampling Arguments')
    sampling.add_argument('--temperature', type=float, default=0.6, help='Temperature for sampling. 0 means greedy decoding.')
    sampling.add_argument('--top_p', type=float, default=0.95, help='Top-p (nucleus) sampling.')
    sampling.add_argument('--top_k', type=int, default=20, help='Top-k sampling.')

    mtp = parser.add_argument_group('MTP Arguments')
    mtp.add_argument('--lora_adapter', type=Path, help='Path to the lora adapter weights.')
    mtp.add_argument('--lora_rank', type=int, default=16, help='Rank for the lora adapter.')
    mtp.add_argument('--lora_alpha', type=int, default=32, help='Alpha for the lora adapter.')
    mtp.add_argument('--lora_bias', action='store_true', help='Bias for the lora adapter.')
    mtp.add_argument('--lora_use_rslora', action='store_true', help='Use RSLoRA for the lora adapter.')
    mtp.add_argument('--draft_length', nargs='+', type=int, default=[4], help='Number of draft tokens to generate in one step. If multiple values are provided, the draft length will be switched between them in each step.')

    profiler = parser.add_argument_group('Profiler')
    profiler.add_argument('--profiling', action='store_true', help='Enable profiling.')
    profiler.add_argument('--warmup_runs', type=int, default=1, help='Number of warmup runs for profiling.')
    profiler.add_argument('--model_profiling', action='store_true', help='Enable model profiling.')
    profiler.add_argument('--backend_profiling', action='store_true', help='Enable backend profiling.')

    profiler.add_argument('--profiler_out', type=str, default='profiler_out/e2e', help='Output directory for the profiler.')
    profiler.add_argument('--profiler_run_name', type=str, default=None, help='Name for the profiler output directory.')
    # store_false: prof_strict_sync defaults to True and the flag turns it off.
    profiler.add_argument('--no_prof_strict_sync', dest='prof_strict_sync', action='store_false', help='Disable strict synchronization at step boundary.')
    profiler.add_argument('--prof_dist_barrier', action='store_true', help='Enable distributed barrier at step boundary.')
    profiler.add_argument('--prof_seq_len_bins', nargs='+', type=int, default=[0, 1024, 2048, 4096, 8192, 16384, 24576, 32768], help='Sequence length bins for profiling.')
    profiler.add_argument('--prof_seq_len_reduce', type=str, default='mean', choices=["mean", "max", "p50", "p90", "sum"], help='How to reduce a per-batch 1D length tensor to a single scalar for binning: one of {"mean","max","p50","p90","sum"}.')

    args = parser.parse_args()

    args.max_len = args.prefix_len + args.max_gen_len
    if torch.cuda.is_available():
        args.device = 'cuda'
        args.dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    else:
        args.device = 'cpu'
        args.dtype = torch.float32

    return args

class Runner:
    """Run benchmark generation batch by batch with a backend-specific decoding loop.

    The loop is chosen from ``args.backend``:

    * ``"standard"``: plain autoregressive decoding (``_run_standard_loop``).
    * ``"mtp"``: self-speculative decoding with one draft length
      (``_run_mtp_loop``).

    An MTP engine with several draft lengths would need a rectangular variant
    that is not implemented here. That case raises ``NotImplementedError`` at
    construction time instead of failing later.

    A profiler can be attached after construction by assigning ``runner.prof``.
    """

    def __init__(self, args, engine, tokenizer, dataloader):
        self.args = args
        self.engine = engine
        self.tokenizer = tokenizer
        self.dataloader = dataloader
        self.device = engine.device
        self.eot_token = self.tokenizer.eos_token_id
        # Optional profiler, assigned by the caller after construction.
        self.prof = None

        # Only loops that actually exist on this class are registered. The old
        # map referenced an undefined method and raised AttributeError here.
        self.loop_map = {
            "standard": self._run_standard_loop,
            "mtp": self._run_mtp_loop,
        }

        # The loop choice is kept on the runner rather than written back into
        # args.backend, because other code still compares args.backend with "mtp".
        self.loop_name = self.args.backend
        if self.loop_name == "mtp" and len(self.engine.draft_lengths) > 1:
            self.loop_name = "mtp_rectangular"

        if self.loop_name not in self.loop_map:
            raise NotImplementedError(f"No decoding loop is implemented for backend '{self.loop_name}'.")

    def _load_batch(self, batch):
        """Move a collated batch to the device; return (input_ids, per-row prompt lengths)."""
        input_ids = batch['input_ids'].to(self.device)
        query_lens = batch['attention_mask'].to(self.device).sum(dim=-1).to(torch.int32)
        return input_ids, query_lens

    def _stop_on_eot(self):
        """Return True unless --force_budget asks to ignore EOT and use the whole budget."""
        return not getattr(self.args, "force_budget", False)

    def run(self):
        """Generate for up to ``args.num_total_runs`` batches and print aggregate token stats."""
        total_gen_tokens = 0
        total_model_steps = 0
        bsz = 0
        run_loop = self.loop_map[self.loop_name]
        use_barrier = bool(self.args.rank_group) and len(self.args.rank_group) > 1

        num_total_runs = min(self.args.num_total_runs, len(self.dataloader))
        for run_idx, batch in tqdm(enumerate(self.dataloader), total=num_total_runs):
            if run_idx >= num_total_runs:
                break

            input_ids, query_lens = self._load_batch(batch)
            bsz = input_ids.shape[0]

            generated_token_ids, num_generated_tokens, num_total_tokens, model_steps = run_loop(input_ids, query_lens)

            if self.args.printoutput:
                print("\n" + "="*50 + f" Run {run_idx} Output " + "="*50)
                for i in range(bsz):
                    print(f"########## Sequence {i} ########## (Total generated tokens: {num_generated_tokens[i]}, mean generated tokens per step: {(num_generated_tokens[i] / model_steps).item():.2f})")
                    print(self.tokenizer.decode(generated_token_ids[i, query_lens[i]:num_total_tokens[i]], skip_special_tokens=True))

            total_gen_tokens += num_generated_tokens.sum().item()
            total_model_steps += model_steps

            if use_barrier:
                dist.barrier()

        print(f"Total generated tokens: {total_gen_tokens}")
        print(f"Total model steps (batch_size): {total_model_steps} (batch_size: {bsz})")
        # Guard against an empty dataloader. Previously input_ids was undefined
        # here and the division was by zero.
        if total_model_steps == 0 or bsz == 0:
            print("➡️  Mean generated tokens: n/a (no batches were run)")
        else:
            print(f"➡️  Mean generated tokens: {total_gen_tokens / (total_model_steps * bsz):.2f}")

    def _run_standard_loop(self, input_ids, query_lens):
        """Standard autoregressive decoding.

        Invariant: ``num_total_tokens[b]`` tokens of row ``b`` are written to
        ``output``, and the last one is the next decode input.

        Returns:
            (output, num_generated_tokens, num_total_tokens, model_steps).
            ``model_steps`` counts the prefill plus each decode step.
        """
        bsz = input_ids.shape[0]
        device = self.device
        max_len = self.args.max_len
        stop_on_eot = self._stop_on_eot()
        prof = getattr(self, "prof", None)
        model_steps = 0
        max_query_len = query_lens.max()
        num_total_tokens = query_lens.clone()

        batch_indices = torch.arange(bsz, device=device)
        output = torch.zeros(bsz, max_len+1, device=device, dtype=torch.long)
        output[:, :max_query_len] = input_ids[:, :max_query_len]

        next_tokens = self.engine.encode(input_ids=input_ids, query_lens=query_lens)
        output[batch_indices, num_total_tokens] = next_tokens[:, 0]
        num_total_tokens += 1
        model_steps += 1

        if prof:
            prof.begin_run(bsz=bsz, label="decode")

        terminal = False
        while num_total_tokens.max() < max_len and not terminal:
            # `with` passes exception info to the profiler context on failure.
            with (prof.time_decode() if prof else nullcontext()):
                if prof:
                    prof.set_step_seq_len(num_total_tokens)

                next_tokens = self.engine.decode(input_ids=next_tokens)
                if prof:
                    prof.set_step_tokens(bsz)

                output[batch_indices, num_total_tokens] = next_tokens[:, 0]
                num_total_tokens += 1
                model_steps += 1

            # The whole batch stops as soon as any row emits EOT, unless --force_budget is set.
            if stop_on_eot and (next_tokens[:, 0] == self.eot_token).any():
                terminal = True

        if prof:
            prof.end_run()

        num_generated_tokens = num_total_tokens - query_lens
        return output, num_generated_tokens, num_total_tokens, model_steps

    def _run_mtp_loop(self, input_ids, query_lens):
        """Self-speculative decoding with MTP and a single draft length ``k``.

        Invariant (asserted before each verify step): ``num_total_tokens - 1 ==
        engine.cachelens``. So ``tokens_buffer[:, 0]`` is the uncached last
        token and already sits at ``output[num_total_tokens - 1]``. The verify
        step writes the buffer starting at that slot, then writes the bonus
        token at the new last slot. The old code wrote the buffer one slot
        later, which duplicated the first sampled token and dropped the final
        bonus token.

        Returns:
            (output, num_generated_tokens, num_total_tokens, model_steps).
        """
        bsz = input_ids.shape[0]
        k = self.engine.draft_lengths[0]
        device = self.device
        greedy = self.args.temperature == 0.0
        max_len = self.args.max_len
        stop_on_eot = self._stop_on_eot()
        prof = getattr(self, "prof", None)
        model_steps = 0
        max_query_len = query_lens.max()
        num_total_tokens = query_lens.clone()

        batch_indices = torch.arange(bsz, device=device)
        batch_indices_2d = batch_indices[:, None]
        buffer_offsets = torch.arange(k + 1, device=device)[None, :]  # [1, k+1]

        # k extra slots so a full buffer write past the last counted position stays in bounds.
        output = torch.zeros(bsz, max_len+k+1, device=device, dtype=torch.long)
        output[:, :max_query_len] = input_ids[:, :max_query_len]

        # Column 0 holds the previous step's next token; columns 1..k hold draft tokens.
        tokens_buffer = torch.zeros(bsz, 1+k, device=device, dtype=torch.long)
        next_tokens = self.engine.encode(input_ids=input_ids, query_lens=query_lens)
        output[batch_indices, num_total_tokens] = next_tokens[:, 0]
        num_total_tokens += 1
        model_steps += 1

        first_draft = True
        terminal = False
        if prof:
            prof.begin_run(bsz=bsz, label="decode")

        while num_total_tokens.max() < max_len and not terminal:
            with (prof.time_decode() if prof else nullcontext()):
                if prof:
                    prof.set_step_seq_len(num_total_tokens)

                if first_draft:
                    step_ids, gate_mask = self.engine.interleave_mask_tokens(input_ids=next_tokens)  # [bsz, 1+k], [bsz, 1+k]
                    logits, hidden_states = self.engine.draft(input_ids=step_ids, gate_mask=gate_mask)  # [bsz, 1+k, vocab], [bsz, 1+k, hidden]

                    tokens_buffer[:, :1] = sample(logits[:, 0], top_p=self.args.top_p, top_k=self.args.top_k, temperature=self.args.temperature)
                    tokens_buffer[:, 1:] = self.engine.sampler_draft(tokens_buffer[:, :1], hidden_states[:, 1:])  # [bsz, k]

                    if prof:
                        prof.set_step_tokens(bsz)

                    output[batch_indices, num_total_tokens] = tokens_buffer[:, 0]
                    num_total_tokens += 1
                    first_draft = False
                else:
                    assert torch.all(num_total_tokens-1 == self.engine.cachelens), "The number of total tokens must be equal to the cachelens+1."
                    step_ids, gate_mask = self.engine.interleave_mask_tokens(input_ids=tokens_buffer)  # [bsz, (k+1)^2], [bsz, (k+1)^2]
                    target_logits, hidden_states = self.engine.draft_and_verify(input_ids=step_ids, gate_mask=gate_mask)  # [bsz, (k+1)^2, vocab], [bsz, (k+1)^2, hidden]

                    verify_target = target_logits.argmax(dim=-1) if greedy else target_logits
                    bonus_tokens, accept_nums, eot_accepted = self.engine.evaluate_posterior(tokens_buffer[:, 1:], verify_target, self.eot_token)
                    self.engine.collate_accepted_kv_cache(accept_nums, num_total_tokens-1)

                    if prof:
                        prof.set_step_tokens(int(accept_nums.sum().item()))

                    # tokens_buffer[:, 0] already occupies output[n-1]. Write from there
                    # so accepted drafts land at n, n+1, ... Slots past the accepted
                    # prefix hold rejected drafts and get overwritten later.
                    write_indices = (num_total_tokens - 1)[:, None] + buffer_offsets  # [bsz, k+1]
                    output[batch_indices_2d, write_indices] = tokens_buffer
                    num_total_tokens += accept_nums

                    # The bonus token becomes the new last (uncached) token.
                    tokens_buffer[:, :1] = bonus_tokens
                    output[batch_indices, num_total_tokens - 1] = tokens_buffer[:, 0]

                    hidden_states = hidden_states.reshape(bsz, k+1, k+1, -1)  # [bsz, k+1, k+1, hidden]
                    selected_hidden_states = hidden_states[batch_indices, accept_nums-1, 1:, :]  # [bsz, k, hidden]
                    tokens_buffer[:, 1:] = self.engine.sampler_draft(tokens_buffer[:, :1], selected_hidden_states)  # [bsz, k]
                    if stop_on_eot and eot_accepted.any():
                        terminal = True

            model_steps += 1
            if stop_on_eot and (tokens_buffer[:, 0] == self.eot_token).any():
                terminal = True

        if prof:
            prof.end_run()

        num_generated_tokens = num_total_tokens - query_lens
        return output, num_generated_tokens, num_total_tokens, model_steps


def load_dataset(tokenizer, dataset_name, seq_len=256, num_samples=None, num_questions_in_prompt=1, dataset_paths=None):
    """Load a benchmark dataset from disk and tokenize its prompts with the chat template.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        dataset_name: One of the keys of ``PROMPT_KEY_DICT`` below.
        seq_len: Every prompt is padded/truncated to exactly this many tokens.
        num_samples: If given, the dataset is truncated, or cyclically repeated,
            to exactly this many rows.
        num_questions_in_prompt: When > 1, consecutive questions are merged
            into one prompt ("Question 1: ...\\n\\nQuestion 2: ...").
        dataset_paths: Optional mapping of dataset key to on-disk path. Entries
            override ``DEFAULT_DATASET_PATH_DICT``. AIME2025 needs the two keys
            "AIME2025-I" and "AIME2025-II".

    Returns:
        A ``datasets.Dataset`` formatted as torch with ``input_ids`` and
        ``attention_mask`` columns.

    Raises:
        ValueError: Unsupported dataset name, no path configured for it, or
            an empty dataset when ``num_samples > 0``.
    """
    DEFAULT_DATASET_PATH_DICT = {}

    PROMPT_KEY_DICT = {
        "GSM8K": "question",
        "MATH-500": "problem",
        "AIME2025": "question",
        "LiveMathBench": "question",
        "LiveCodeBench": "question_content",
        "CodeForces": "prompt",
        "GPQA-Diamond": "Question",
    }

    if dataset_name not in PROMPT_KEY_DICT:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    path_dict = {**DEFAULT_DATASET_PATH_DICT, **(dataset_paths or {})}
    # AIME2025 is stored as two separate on-disk parts that are concatenated.
    if dataset_name == "AIME2025":
        path_keys = [dataset_name + "-I", dataset_name + "-II"]
    else:
        path_keys = [dataset_name]
    missing = [key for key in path_keys if key not in path_dict]
    if missing:
        raise ValueError(
            f"No on-disk path configured for dataset key(s) {missing}; "
            "add them to DEFAULT_DATASET_PATH_DICT or pass dataset_paths."
        )

    loaded = [load_from_disk(path_dict[key]) for key in path_keys]
    ds = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)
    prompt_key = PROMPT_KEY_DICT[dataset_name]

    def apply_chat_template(text):
        # Wrap the raw question as a single user turn and append the generation prompt.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )

    def tokenize_fn(examples):
        texts = [apply_chat_template(q) for q in examples[prompt_key]]
        return tokenizer(texts, return_tensors="pt", max_length=seq_len, padding="max_length", truncation=True)

    if num_questions_in_prompt > 1:
        all_questions = ds[prompt_key]
        new_prompts = [
            "\n\n".join(
                f"Question {j+1}: {q}"
                for j, q in enumerate(all_questions[i:i+num_questions_in_prompt])
            )
            for i in range(0, len(all_questions), num_questions_in_prompt)
        ]
        ds = Dataset.from_dict({prompt_key: new_prompts})

    ds = ds.map(tokenize_fn, batched=True, remove_columns=[prompt_key])
    ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    if num_samples is not None:
        n = len(ds)
        num_samples = int(num_samples)

        if num_samples == n:
            return ds
        if num_samples < n:
            ds = ds.select(range(num_samples))
        else:
            if n == 0:
                raise ValueError(f"Dataset {dataset_name} is empty; cannot draw {num_samples} samples.")
            # Repeat full copies, then a head slice for the remainder.
            times, rem = divmod(num_samples, n)
            parts = [ds] * times
            if rem > 0:
                parts.append(ds.select(range(rem)))
            ds = concatenate_datasets(parts)

        ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    return ds


def setup_engine(engine, args, process_group):
    """Load weights, optionally compile, and set up KV caches and sampling on ``engine``.

    For the ``"mtp"`` backend a ``LoRAConfig`` and the adapter checkpoint are
    also passed to ``engine.load_model``.

    The prefill chunk size comes from a table keyed by batch size. A batch size
    that is not in the table uses the entry for the next larger tabulated
    size, and anything above the largest key uses that key's entry. The old
    code raised KeyError for any batch size not in the table.
    """
    use_tp = bool(args.rank_group) and len(args.rank_group) > 1
    load_params = {
        'model_name': args.model_name, 'target_checkpoint': args.target,
        'use_tp': use_tp,
        'rank_group': args.rank_group, 'group': process_group,
    }

    if args.backend == "mtp":
        lora_config = LoRAConfig(rank=args.lora_rank, alpha=args.lora_alpha, lora_bias=args.lora_bias, use_rslora=args.lora_use_rslora)
        load_params.update({'lora_checkpoint': args.lora_adapter, 'lora_config': lora_config})
    engine.load_model(**load_params)

    if args.compile:
        engine.compile()

    prefill_chunk_size = {2: 128, 4: 128, 8: 128, 16: 128, 32: 128, 64: 64, 128: 16, 256: 8}
    known_sizes = sorted(prefill_chunk_size)
    # Smallest tabulated batch size >= the requested one. Extrapolating above
    # the table is an assumption: confirm it fits memory for very large batches.
    chunk_key = next((size for size in known_sizes if size >= args.batch_size), known_sizes[-1])

    cache_params = {'max_batch_size': args.batch_size, 'max_seq_length': args.max_len, 'page_size': 16,
                    'prefill_chunk_size': prefill_chunk_size[chunk_key]}
    engine.setup_caches(**cache_params)
    engine.setup_sampling_params(temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
    engine.setup_sampling_params(temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)


def main():
    """Entry point: parse args, build dataset/engine/runner, run the benchmark, clean up.

    With tensor parallelism, every rank except the first in ``--rank_group``
    redirects stdout/stderr to ``logs/system_logs/rank_<rank>.log``. The
    redirection is undone and the file closed on exit.

    Returns:
        Process exit status: 0 on success, or 1 if an exception was raised
        during setup or the run. The traceback is printed before returning.
    """
    import sys
    import traceback

    args = parse_args()

    rank, process_group = 0, None
    use_tp = bool(args.rank_group) and len(args.rank_group) > 1
    log_file = None
    failed = False
    if use_tp:
        rank, process_group = init_dist()

        if rank != args.rank_group[0]:
            os.makedirs("logs/system_logs", exist_ok=True)
            log_file = open(f"logs/system_logs/rank_{rank}.log", 'w', buffering=1)
            sys.stdout = log_file
            sys.stderr = log_file

    try:
        setup_seed(args.seed)
        print("Initializing end-to-end generation with FlashInfer...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right", local_files_only=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Prepare the dataset
        dataset = load_dataset(tokenizer=tokenizer, dataset_name=args.dataset, seq_len=args.prefix_len, num_samples=args.num_total_runs*args.batch_size, num_questions_in_prompt=args.num_questions_in_prompt)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)

        # Initialize the engine
        if args.backend == "standard":
            engine = StandardLMBackend(dtype=args.dtype, device=args.device)
        elif args.backend == "mtp":
            engine = MTPLMBackend(dtype=args.dtype, device=args.device, draft_length=args.draft_length, tokenizer=tokenizer)
        else:
            raise ValueError(f"Unsupported backend: {args.backend}")
        setup_engine(engine, args, process_group)

        # Record the CLI backend before building the Runner, so the profiler
        # hooks below do not depend on whether Runner rewrites args.backend.
        use_gated_lora = args.backend == "mtp"

        # Initialize the runner
        runner = Runner(args, engine, tokenizer, dataloader)

        if args.profiling:
            # Initialize the profiler
            prof_cfg = ProfilerConfig(
                output_dir=args.profiler_out,
                warmup_runs=args.warmup_runs,
                num_total_runs=args.num_total_runs,
                model_profiling=args.model_profiling,
                backend_profiling=args.backend_profiling,
                run_name=args.profiler_run_name,
                strict_sync=args.prof_strict_sync,
                dist_barrier=args.prof_dist_barrier,
                kv_bins=args.prof_seq_len_bins,
                kv_len_reduce=args.prof_seq_len_reduce,
            )

            prof = Profiler(runner_args=args, cfg=prof_cfg)
            register_active_decode_profiler(prof)
            if args.model_profiling:
                prof.attach_model(engine.model, use_gated_lora=use_gated_lora)
            if args.backend_profiling:
                prof.attach_backend(engine)
            runner.prof = prof

        # Run the benchmark
        runner.run()

        if args.profiling:
            # Save the profiling results and unregister the profiler
            prof.save_config()
            prof.save_all()
            register_active_decode_profiler(None)

    except Exception as e:
        failed = True
        print(f"[Rank {rank}] Exception occurred: {e}")
        traceback.print_exc()

    finally:
        if use_tp and dist.is_initialized():
            print(f"[Rank {rank}] Cleaning up distributed process group ...")
            dist.destroy_process_group()
        if log_file is not None:
            # Undo the per-rank redirection before closing the file.
            sys.stdout = sys.__stdout__
            sys.stderr = sys.__stderr__
            log_file.close()

    return 1 if failed else 0


if __name__ == "__main__":
    raise SystemExit(main())