import argparse
import json
import os
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.distributed as dist
from datasets import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

from spec_benchmark.Engine.utils import init_dist, setup_seed, sample
from spec_benchmark.Engine.backends import StandardLMBackend, MTPLMBackend
from spec_benchmark.Engine.models.base import LoRAConfig
from spec_benchmark.profiler import Profiler, ProfilerConfig, register_active_decode_profiler
from spec_benchmark.benchmark_utils import BatchSampler


def parse_args(argv=None):
    """Parse the benchmark command line and resolve device and dtype.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        The parsed ``argparse.Namespace`` with two derived fields: ``device``
        (``'cuda'`` or ``'cpu'``) and ``dtype`` converted from its string name
        to a ``torch.dtype``.
    """
    parser = argparse.ArgumentParser(description='spec_benchmark Unified Benchmark Runner')

    group = parser.add_argument_group('Common Arguments')
    # `choices` rejects a misspelled backend here instead of after the tokenizer and dataset were loaded.
    group.add_argument('--backend', type=str, default="standard", choices=["standard", "mtp"], help='Backend name (standard, mtp).')
    group.add_argument('--model_name', type=str, required=True, help='Model name (e.g., meta-llama/Meta-Llama-3.1-8B).')
    group.add_argument('--target', type=Path, required=True, help='Path to the target model weights.')
    group.add_argument('--tokenizer_path', type=Path, required=True, help='Path to the tokenizer.')
    group.add_argument('--dataset', type=str, default="AIME2025", help='Dataset name (AIME2025, GSM8K, MATH-500, LiveMathBench, GPQA-Diamond, LiveCodeBench-lite, CodeForces).')

    group.add_argument('--batch_size', type=int, default=16, help='Batch size for inference.')
    group.add_argument('--prefix_len_list', nargs='+', type=int, default=[1024, 2048, 4096, 8192, 12288, 16384, 20480, 24576, 28672], help='List of input prompt sequence lengths.')
    group.add_argument('--max_gen_len', type=int, default=128, help='Maximum number of tokens to generate per sequence.')

    # Without `choices`, any unknown name (e.g. "float32") was silently executed as float16.
    group.add_argument('--dtype', type=str, default="bfloat16", choices=["bfloat16", "float16"], help='Data type for model execution (bfloat16, float16).')
    group.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
    group.add_argument('--compile', action='store_true', help='Enable torch.compile() for the model.')

    group.add_argument('--printoutput', action='store_true', help='Print the generated output of each sequence.')
    group.add_argument('--force_budget', action='store_true', help='Force the generation until the budget (max_gen_len) is exhausted.')
    group.add_argument('--num_total_runs', type=int, default=10, help='Number of total runs for profiling.')

    group = parser.add_argument_group('Distributed Training Arguments')
    group.add_argument('--rank_group', nargs='+', type=int, help='Tensor parallel group ranks for the target model.')

    group = parser.add_argument_group('Sampling Arguments')
    group.add_argument('--temperature', type=float, default=0.6, help='Temperature for sampling. 0 means greedy decoding.')
    group.add_argument('--top_p', type=float, default=0.95, help='Top-p (nucleus) sampling.')
    group.add_argument('--top_k', type=int, default=20, help='Top-k sampling.')

    group = parser.add_argument_group('MTP Arguments')
    group.add_argument('--lora_adapter', type=Path, help='Path to the lora adapter weights.')
    group.add_argument('--lora_rank', type=int, default=16, help='Rank for the lora adapter.')
    group.add_argument('--lora_alpha', type=int, default=32, help='Alpha for the lora adapter.')
    group.add_argument('--lora_bias', action='store_true', help='Bias for the lora adapter.')
    group.add_argument('--lora_use_rslora', action='store_true', help='Use RSLoRA for the lora adapter.')
    group.add_argument('--draft_length', nargs='+', type=int, default=[4], help='Number of draft tokens to generate in one step. If multiple values are provided, the draft length will be switched between them in each step.')

    group = parser.add_argument_group('Profiler')
    group.add_argument('--profiling', action='store_true', help='Enable profiling.')
    group.add_argument('--warmup_runs', type=int, default=1, help='Number of warmup runs for profiling.')
    group.add_argument('--model_profiling', action='store_true', help='Enable model profiling.')
    group.add_argument('--backend_profiling', action='store_true', help='Enable backend profiling.')

    group.add_argument('--profiler_out', type=str, default='profiler_out/benchmark', help='Output directory for the profiler.')
    group.add_argument('--profiler_run_name', type=str, default=None, help='Name for the profiler output directory.')
    group.add_argument('--no_prof_strict_sync', dest='prof_strict_sync', action='store_false', help='Disable strict synchronization at step boundary.')
    group.add_argument('--prof_dist_barrier', action='store_true', help='Enable distributed barrier at step boundary.')
    group.add_argument('--prof_seq_len_bins', nargs='+', type=int, default=[0, 1024, 2048, 4096, 8192, 16384, 24576, 32768], help='Sequence length bins for profiling.')
    group.add_argument('--prof_seq_len_reduce', type=str, default='mean', help='How to reduce a per-batch 1D length tensor to a single scalar for binning: one of {"mean","max","p50","p90","sum"}.')
    group.set_defaults(prof_strict_sync=True)

    args = parser.parse_args(argv)
    if torch.cuda.is_available():
        args.device = 'cuda'
        args.dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
    else:
        # Without CUDA the requested dtype is ignored and float32 is used.
        args.device = 'cpu'
        args.dtype = torch.float32

    return args

class Runner:
    """Run repeated generation passes with one backend and print token counts.

    Profiling is optional: assign a profiler to ``runner.prof`` after
    construction and each decode step is wrapped in ``prof.time_decode()``.
    """

    def __init__(self, args, engine, tokenizer, batch_sampler):
        """Store collaborators and build the backend-name -> loop dispatch table.

        Args:
            args: Run configuration. The methods read ``backend``, ``batch_size``,
                ``num_total_runs``, ``max_len``, ``printoutput``, ``rank_group``,
                ``force_budget`` (optional) and, for MTP, ``temperature``,
                ``top_p``, ``top_k`` and ``max_total_tokens``.
            engine: Backend object; ``device``, ``encode`` and ``decode`` are used
                by the standard loop, the draft/verify methods by the MTP loop.
            tokenizer: Provides ``eos_token_id`` and ``decode``.
            batch_sampler: Provides ``sample_batch()`` returning prompt token ids.
        """
        self.args = args
        self.engine = engine
        self.tokenizer = tokenizer
        self.batch_sampler = batch_sampler
        self.device = engine.device
        self.eot_token = self.tokenizer.eos_token_id
        # Optional profiler; stays None unless a caller assigns `runner.prof`.
        self.prof = None

        self.loop_map = {
            "standard": self._run_standard_loop,
            "mtp": self._run_mtp_loop,
        }

    def _decode_step_ctx(self):
        """Return the profiler's per-step timing context, or None without a profiler."""
        prof = getattr(self, 'prof', None)
        return prof.time_decode() if prof else None

    def _begin_profiled_run(self, bsz):
        """Tell the profiler (if any) that a decode run over ``bsz`` sequences starts."""
        prof = getattr(self, 'prof', None)
        if prof:
            prof.begin_run(bsz=bsz, label="decode")

    def _end_profiled_run(self):
        """Tell the profiler (if any) that the current decode run finished."""
        prof = getattr(self, 'prof', None)
        if prof:
            prof.end_run()

    def run(self):
        """Execute ``args.num_total_runs`` generation runs and print aggregate counts.

        Each run samples a batch, dispatches to the loop selected by
        ``args.backend``, optionally prints the decoded continuations, and
        synchronizes ranks when more than one rank is configured.
        """
        total_gen_tokens = 0
        total_model_steps = 0
        # Fallback for the summary when no run executes (num_total_runs <= 0).
        bsz = self.args.batch_size

        for run_idx in tqdm(range(self.args.num_total_runs), total=self.args.num_total_runs):
            input_ids = self.batch_sampler.sample_batch().to(self.device)
            bsz = input_ids.shape[0]
            # Every sequence in the batch is given the same prompt length.
            query_lens = torch.ones(self.args.batch_size, device=self.device, dtype=torch.int32) * input_ids.shape[1]

            generated_token_ids, num_generated_tokens, num_total_tokens, model_steps = self.loop_map[self.args.backend](input_ids, query_lens)

            if self.args.printoutput:
                print("\n" + "="*50 + f" Run {run_idx} Output " + "="*50)
                for i in range(input_ids.shape[0]):
                    print(f"########## Sequence {i} ########## (Total generated tokens: {num_generated_tokens[i]}, mean generated tokens per step: {(num_generated_tokens[i] / model_steps).item():.2f})")
                    print(self.tokenizer.decode(generated_token_ids[i, query_lens[i]:num_total_tokens[i]], skip_special_tokens=True))

            total_gen_tokens += num_generated_tokens.sum().item()
            total_model_steps += model_steps

            if self.args.rank_group and len(self.args.rank_group) > 1:
                dist.barrier()

        # Guard the mean so an empty benchmark reports 0.00 instead of raising.
        denominator = total_model_steps * bsz
        mean_tokens = total_gen_tokens / denominator if denominator else 0.0
        print(f"Total generated tokens: {total_gen_tokens}")
        print(f"Total model steps (batch_size): {total_model_steps} (batch_size: {bsz})")
        print(f"➡️  Mean generated tokens: {mean_tokens:.2f}")

    def _run_standard_loop(self, input_ids, query_lens):
        """Standard autoregressive decoding: one new token per sequence per step.

        Args:
            input_ids: Prompt token ids, one row per sequence.
            query_lens: Per-sequence prompt lengths (1D integer tensor).

        Returns:
            ``(output, num_generated_tokens, num_total_tokens, model_steps)`` where
            ``output`` holds prompt plus generated ids, the two count tensors are
            per sequence, and ``model_steps`` counts the encode call plus every
            decode call.
        """
        bsz = input_ids.shape[0]
        device = self.device
        max_len = self.args.max_len
        # With --force_budget, an EOS token no longer ends the run early.
        force_budget = getattr(self.args, 'force_budget', False)
        model_steps = 0
        max_query_len = query_lens.max()
        num_total_tokens = query_lens.clone()

        batch_indices = torch.arange(bsz, device=device)
        output = torch.zeros(bsz, max_len+1, device=device, dtype=torch.long)
        output[:, :max_query_len] = input_ids[:, :max_query_len]

        # The encode call yields the first generated token; it is not profiled.
        next_tokens = self.engine.encode(input_ids=input_ids, query_lens=query_lens)
        output[batch_indices, num_total_tokens] = next_tokens[:, 0]
        num_total_tokens += 1
        model_steps += 1

        self._begin_profiled_run(bsz)

        terminal = False
        while num_total_tokens.max() < max_len and not terminal:
            step_ctx = self._decode_step_ctx()
            # `with` guarantees the timing context is closed and sees any exception.
            with step_ctx or nullcontext():
                if step_ctx:
                    self.prof.set_step_seq_len(num_total_tokens)

                next_tokens = self.engine.decode(input_ids=next_tokens)
                if step_ctx:
                    self.prof.set_step_tokens(bsz)

                output[batch_indices, num_total_tokens] = next_tokens[:, 0]
                num_total_tokens += 1
                model_steps += 1

            # The whole batch stops as soon as any sequence emits EOS.
            if not force_budget and (next_tokens[:, 0] == self.eot_token).any():
                terminal = True

        self._end_profiled_run()

        num_generated_tokens = num_total_tokens - query_lens
        return output, num_generated_tokens, num_total_tokens, model_steps

    def _run_mtp_loop(self, input_ids, query_lens):
        """Self-speculative decoding with MTP: draft ``k`` tokens, then verify them.

        ``k`` is ``engine.draft_lengths[0]``. The first loop iteration only drafts;
        every later iteration drafts and verifies in one engine call and advances
        each sequence by its ``accept_nums`` entry.

        Args:
            input_ids: Prompt token ids, one row per sequence.
            query_lens: Per-sequence prompt lengths (1D integer tensor).

        Returns:
            ``(output, num_generated_tokens, num_total_tokens, model_steps)`` with
            the same layout as the standard loop.
        """
        bsz = input_ids.shape[0]
        k = self.engine.draft_lengths[0]
        device = self.device
        greedy = self.args.temperature == 0.0
        max_total_tokens = self.args.max_total_tokens
        # With --force_budget, an EOS token no longer ends the run early.
        force_budget = getattr(self.args, 'force_budget', False)
        model_steps = 0
        max_query_len = query_lens.max()
        num_total_tokens = query_lens.clone()

        batch_indices = torch.arange(bsz, device=device)
        batch_indices_2d = torch.arange(bsz, device=device)[:, None]

        # k+1 extra columns leave room for a full buffer write at the last position.
        output = torch.zeros(bsz, self.args.max_len+k+1, device=device, dtype=torch.long)
        output[:, :max_query_len] = input_ids[:, :max_query_len]

        tokens_buffer = torch.zeros(bsz, 1+k, device=device, dtype=torch.long) # one is the prev step's next token, k for draft tokens
        next_tokens = self.engine.encode(input_ids=input_ids, query_lens=query_lens)
        output[batch_indices, num_total_tokens] = next_tokens[:, 0]
        num_total_tokens += 1
        model_steps += 1

        first_draft = True
        terminal = False
        self._begin_profiled_run(bsz)

        # The budget is the total number of generated tokens summed over the batch.
        while (num_total_tokens - query_lens).sum() < max_total_tokens and not terminal:
            step_ctx = self._decode_step_ctx()
            # `with` guarantees the timing context is closed and sees any exception.
            with step_ctx or nullcontext():
                if step_ctx:
                    self.prof.set_step_seq_len(num_total_tokens)

                if first_draft:
                    input_ids, gate_mask = self.engine.interleave_mask_tokens(input_ids=next_tokens) # [bsz, 1+k], [bsz, 1+k]
                    logits, hidden_states = self.engine.draft(input_ids=input_ids, gate_mask=gate_mask) # [bsz, 1+k, vocab_size], [bsz, 1+k, hidden_size]

                    # tokens_buffer[:, 0] : next_tokens / tokens_buffer[:, 1:] : draft_tokens
                    tokens_buffer[:, :1] = sample(logits[:, 0], top_p=self.args.top_p, top_k=self.args.top_k, temperature=self.args.temperature)
                    tokens_buffer[:, 1:] = self.engine.sampler_draft(tokens_buffer[:, :1], hidden_states[:, 1:]) # [bsz, k], [bsz, k, vocab_size]

                    # The draft-only step contributes one token per sequence.
                    if step_ctx:
                        self.prof.set_step_tokens(bsz)

                    output[batch_indices, num_total_tokens] = tokens_buffer[:, 0]
                    num_total_tokens += 1
                    first_draft = False
                else:
                    assert torch.all(num_total_tokens-1 == self.engine.cachelens), "The number of total tokens must be equal to the cachelens+1."
                    input_ids, gate_mask = self.engine.interleave_mask_tokens(input_ids=tokens_buffer) # [bsz, (k+1)^2], [bsz, (k+1)^2]
                    target_logits, hidden_states = self.engine.draft_and_verify(input_ids=input_ids, gate_mask=gate_mask) # [bsz, (k+1)^2, vocab_size], [bsz, (k+1)^2, hidden_size]

                    # Greedy decoding verifies against argmax ids, sampling against raw logits.
                    if greedy:
                        bonus_tokens, accept_nums, eot_accepted = self.engine.evaluate_posterior(tokens_buffer[:, 1:], target_logits.argmax(dim=-1), self.eot_token)
                    else:
                        bonus_tokens, accept_nums, eot_accepted = self.engine.evaluate_posterior(tokens_buffer[:, 1:], target_logits, self.eot_token)
                    self.engine.collate_accepted_kv_cache(accept_nums, num_total_tokens-1)

                    # register accept_nums to Profiler
                    if step_ctx:
                        self.prof.set_step_tokens(int(accept_nums.sum().item()))

                    # Write the whole buffer; only the first accept_nums entries are
                    # kept, the rest is overwritten by the next step's write.
                    # NOTE(review): right after the first-draft step tokens_buffer[:, 0]
                    # was already written at num_total_tokens-1, so this write appears to
                    # store it a second time — confirm against the engine's accounting.
                    write_indices = num_total_tokens[:, None] + torch.arange(k + 1, device=device)[None, :] # [B, k+1]
                    output[batch_indices_2d, write_indices] = tokens_buffer
                    num_total_tokens += accept_nums

                    # Prepare for next iteration
                    tokens_buffer[:, :1] = bonus_tokens
                    hidden_states = hidden_states.reshape(bsz, k+1, k+1, -1) # [bsz, k+1, k+1, hidden_size]
                    selected_hidden_states = hidden_states[batch_indices, accept_nums-1, 1:, :] # [bsz, k, hidden_size]
                    tokens_buffer[:, 1:] = self.engine.sampler_draft(tokens_buffer[:, :1], selected_hidden_states) # [bsz, k], [bsz, k, vocab_size]
                    if not force_budget and (eot_accepted).any():
                        terminal = True

            model_steps += 1
            # The whole batch stops as soon as any sequence's next token is EOS.
            if not force_budget and (tokens_buffer[:, 0] == self.eot_token).any():
                terminal = True

        self._end_profiled_run()

        num_generated_tokens = num_total_tokens - query_lens
        return output, num_generated_tokens, num_total_tokens, model_steps


# Prefill chunk size keyed by batch size; looked up through _prefill_chunk_size_for.
_PREFILL_CHUNK_SIZES = {2: 128, 4: 128, 8: 128, 16: 128, 32: 128, 64: 128, 128: 64, 256: 32}


def _prefill_chunk_size_for(batch_size):
    """Return the prefill chunk size to use for ``batch_size``.

    A batch size that is not listed in the table uses the entry of the next
    larger listed size, so e.g. 1 or 24 no longer fail with a bare KeyError.

    Raises:
        ValueError: If ``batch_size`` is below 1 or above the largest listed size.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for listed_size in sorted(_PREFILL_CHUNK_SIZES):
        if batch_size <= listed_size:
            return _PREFILL_CHUNK_SIZES[listed_size]
    raise ValueError(
        f"No prefill chunk size is defined for batch_size={batch_size}; "
        f"the largest supported batch size is {max(_PREFILL_CHUNK_SIZES)}"
    )


def setup_engine(engine, args, process_group):
    """Load weights into ``engine`` and configure its caches and sampling.

    Args:
        engine: Backend to configure; ``load_model``, ``compile``,
            ``setup_caches`` and ``setup_sampling_params`` are called on it.
        args: Run configuration; reads ``model_name``, ``target``, ``rank_group``,
            ``backend``, the ``lora_*`` fields (MTP only), ``compile``,
            ``batch_size``, ``max_len``, ``temperature``, ``top_p`` and ``top_k``.
        process_group: Distributed group forwarded to ``engine.load_model``.

    Raises:
        ValueError: If no prefill chunk size can be chosen for ``args.batch_size``.
    """
    # Tensor parallelism is used only when more than one rank is listed.
    use_tp = bool(args.rank_group) and len(args.rank_group) > 1
    load_params = {
        'model_name': args.model_name, 'target_checkpoint': args.target,
        'use_tp': use_tp,
        'rank_group': args.rank_group, 'group': process_group
    }

    if args.backend == "mtp":
        # Only the MTP backend receives the LoRA adapter and its configuration.
        lora_config = LoRAConfig(rank=args.lora_rank, alpha=args.lora_alpha, lora_bias=args.lora_bias, use_rslora=args.lora_use_rslora)
        load_params.update({'lora_checkpoint': args.lora_adapter, 'lora_config': lora_config})
    engine.load_model(**load_params)

    if args.compile:
        engine.compile()

    cache_params = {
        'max_batch_size': args.batch_size,
        'max_seq_length': args.max_len,
        'page_size': 16,
        'prefill_chunk_size': _prefill_chunk_size_for(args.batch_size),
    }
    engine.setup_caches(**cache_params)
    engine.setup_sampling_params(temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)


def load_benchmark_dataset(tokenizer, dataset_name, base_path=None):
    """Load ``<base_path>/<dataset_name>_1000_responses.json`` as a ``Dataset``.

    The dataset directory is resolved in this order: the ``base_path`` argument,
    a module-level ``DEFAULT_BASE_DATASET_PATH`` if one is defined or imported,
    then the ``SPEC_BENCHMARK_DATASET_PATH`` environment variable. The original
    code referenced ``DEFAULT_BASE_DATASET_PATH`` without it being defined in
    this file, which raised ``NameError`` on every call.

    Args:
        tokenizer: Unused; kept so existing callers keep working.
        dataset_name: Dataset name used as the file name prefix.
        base_path: Optional directory containing the dataset JSON files.

    Returns:
        A ``Dataset`` built from the ``"results"`` list of the JSON document
        (empty when that key is missing).

    Raises:
        ValueError: If no dataset directory can be resolved.
        FileNotFoundError: If the dataset file does not exist.
    """
    if base_path is None:
        base_path = globals().get("DEFAULT_BASE_DATASET_PATH") or os.environ.get("SPEC_BENCHMARK_DATASET_PATH")
    if not base_path:
        raise ValueError(
            "No dataset directory configured: pass base_path, define "
            "DEFAULT_BASE_DATASET_PATH, or set the SPEC_BENCHMARK_DATASET_PATH "
            "environment variable."
        )

    filepath = Path(base_path) / f"{dataset_name}_1000_responses.json"
    with filepath.open("r", encoding="utf-8") as f:
        data = json.load(f)
    return Dataset.from_list(data.get("results", []))


def main():
    """Script entry point: parse arguments, build the engine, benchmark every prefix length.

    When more than one rank is configured, distributed state is initialised and
    every rank other than ``rank_group[0]`` redirects its stdout/stderr to
    ``logs/system_logs/rank_<rank>.log``. Any exception is printed with its
    traceback and the process then exits with status 1; the distributed process
    group is destroyed in all cases.
    """
    args = parse_args()

    rank, process_group = 0, None
    use_tp = len(args.rank_group) > 1 if args.rank_group else False
    if use_tp:
        rank, process_group = init_dist()

        if rank != args.rank_group[0]:
            import sys
            os.makedirs("logs/system_logs", exist_ok=True)
            # The log file is never closed here; it stays the rank's stdout/stderr
            # for the rest of the process.
            log_file = open(f"logs/system_logs/rank_{rank}.log", 'w', buffering=1)
            sys.stdout = log_file
            sys.stderr = log_file

    try:
        setup_seed(args.seed)
        print("Initializing benchmark ...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right", local_files_only=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Prepare the dataset
        dataset = load_benchmark_dataset(tokenizer=tokenizer, dataset_name=args.dataset)

        # Initialize the engine
        if args.backend == "standard":
            engine = StandardLMBackend(dtype=args.dtype, device=args.device)
        elif args.backend == "mtp":
            engine = MTPLMBackend(dtype=args.dtype, device=args.device, draft_length=args.draft_length, tokenizer=tokenizer)
        else:
            raise ValueError(f"Unsupported backend: {args.backend}")

        # Initialize the Caches with the maximum prefix length
        print(f"Initializing the Caches with the maximum prefix length: {max(args.prefix_len_list)}")
        # Caches are sized once for the longest prefix plus the largest generation margin.
        args.max_len = max(args.prefix_len_list) + (args.max_gen_len * 3)
        setup_engine(engine, args, process_group)

        for prefix_len in args.prefix_len_list:
            args.prefix_len = prefix_len
            if args.backend == "mtp":
                # MTP budgets generated tokens over the whole batch, not per sequence.
                args.max_total_tokens = args.batch_size * args.max_gen_len
                args.max_len = args.prefix_len + (args.max_gen_len * 3) # Give some margin for the fastest sequence
            else:
                args.max_len = args.prefix_len + args.max_gen_len
            batch_sampler = BatchSampler(dataset=dataset, tokenizer=tokenizer,
                                         batch_size=args.batch_size, seq_len=args.prefix_len, margin_before_eos=5 * args.max_gen_len,
                                         pretokenize=True, seed=args.seed)

            # Initialize the runner
            runner = Runner(args, engine, tokenizer, batch_sampler)

            if args.profiling:
                # Initialize the profiler
                prof_cfg = ProfilerConfig(
                    output_dir=args.profiler_out,
                    warmup_runs=args.warmup_runs,
                    num_total_runs=args.num_total_runs,
                    model_profiling=args.model_profiling,
                    backend_profiling=args.backend_profiling,
                    run_name=args.profiler_run_name,
                    strict_sync=args.prof_strict_sync,
                    dist_barrier=args.prof_dist_barrier,
                    kv_bins=args.prof_seq_len_bins,
                    kv_len_reduce=args.prof_seq_len_reduce,
                )

                prof = Profiler(runner_args=args, cfg=prof_cfg)
                register_active_decode_profiler(prof)
                if args.model_profiling:
                    prof.attach_model(engine.model, use_gated_lora=args.backend == "mtp")
                if args.backend_profiling:
                    prof.attach_backend(engine)
                runner.prof = prof

            # Run the benchmark
            runner.run()

            if args.profiling:
                # Save the profiling results and unregister the profiler
                prof.save_config()
                prof.save_all()
                register_active_decode_profiler(None)

    except Exception as e:
        print(f"[Rank {rank}] Exception occurred: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so launchers and shell scripts see the failure; the
        # `finally` block below still runs before the process exits.
        raise SystemExit(1) from e

    finally:
        if use_tp and dist.is_initialized():
            print(f"[Rank {rank}] Cleaning up distributed process group ...")
            dist.destroy_process_group()

# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()