import os
import time
import json
import argparse
import logging
from typing import List, Dict, Any
import pandas as pd
from vllm import LLM, SamplingParams
from benchmark_utils import (
    set_seed, 
    load_humaneval_prompts, 
    load_mtbench_prompts, 
    calculate_stats, 
    save_results, 
    set_args,
)

def run_benchmark(
    model_name: str,
    speculative_model_name: str,
    dataset_name: str,
    temperature: float,
    max_tokens: int,
    tensor_parallel_size: int,
    num_prompts: int,
    num_speculative_tokens: int,
    gpu_memory_utilization: float,
    speculative_draft_tp_size: int,
    max_model_len: int,
    seed: int,
    num_iter: int,
    num_k: int
) -> Dict[str, Any]:
    """Run a vLLM speculative-decoding throughput benchmark.

    Loads ``num_prompts`` prompts from ``dataset_name``, builds an ``LLM``
    with a draft (speculative) model, then calls ``llm.generate`` on the
    same prompt set ``num_iter`` times, timing each call with
    ``time.perf_counter``. Per-iteration speeds are handed to
    ``calculate_stats(times, num_k)``; presumably it averages the last
    ``num_k`` iterations (see its own docs — not visible here).

    Args:
        model_name: Target model passed to ``LLM(model=...)``.
        speculative_model_name: Draft model passed as ``speculative_model``.
        dataset_name: ``"humaneval"`` or ``"mtbench"``.
        temperature: Sampling temperature.
        max_tokens: Max generated tokens per request (reported as ``seqlen``).
        tensor_parallel_size: TP size for the target model.
        num_prompts: Number of prompts to load.
        num_speculative_tokens: Tokens proposed by the draft model per step.
        gpu_memory_utilization: Fraction of GPU memory vLLM may use.
        speculative_draft_tp_size: TP size for the draft model.
        max_model_len: Max context length given to vLLM.
        seed: Seed for ``set_seed``, ``SamplingParams`` and ``LLM``.
        num_iter: Number of timed ``generate`` calls; must be >= 1.
        num_k: Forwarded to ``calculate_stats``.

    Returns:
        A flat dict of the run configuration plus ``prefill_speed``,
        ``decode_speed`` (tokens/s) and ``total_time`` (s) taken from
        ``calculate_stats``. ``total_prefill_tokens`` /
        ``total_decode_tokens`` are the counts from the last iteration.

    Raises:
        ValueError: If ``dataset_name`` is unknown or ``num_iter < 1``.
    """
    # Validate cheap arguments before seeding and loading a model, so a typo
    # fails immediately instead of after an expensive model initialisation.
    if dataset_name not in ("humaneval", "mtbench"):
        raise ValueError(f"Unknown dataset: {dataset_name}")
    if num_iter < 1:
        # With zero iterations the token totals below would never be bound.
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    # Set seed for reproducibility
    set_seed(seed)

    # Load prompts
    if dataset_name == "humaneval":
        prompts = load_humaneval_prompts(num_prompts)
    else:
        prompts = load_mtbench_prompts(num_prompts)

    # ignore_eos=True keeps generation from stopping early at EOS, so output
    # lengths are driven by max_tokens rather than by model content.
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens, seed=seed, ignore_eos=True)

    # Initialize the model with speculative decoding
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        speculative_model=speculative_model_name,
        speculative_draft_tensor_parallel_size=speculative_draft_tp_size,
        num_speculative_tokens=num_speculative_tokens,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        seed=seed,
        disable_log_stats=False
    )

    # One (prefill_speed, decode_speed, total_time) tuple per iteration.
    times = []

    # Repeat the same workload num_iter times for stable results.
    for i in range(num_iter):
        print(f" >>> Iteration {i+1}/{num_iter}")

        start_time = time.perf_counter()
        outputs = llm.generate(prompts, sampling_params)
        end_time = time.perf_counter()
        total_time = end_time - start_time

        # Count prompt and generated tokens across all requests.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for o in outputs:
            total_prompt_tokens += len(o.prompt_token_ids)
            total_output_tokens += len(o.outputs[0].token_ids)
        # NOTE: both rates divide by the wall time of the whole generate()
        # call (prefill + decode), not by separately measured phases.
        prefill_speed = total_prompt_tokens / total_time    # toks/s
        decode_speed = total_output_tokens / total_time     # toks/s

        times.append((prefill_speed, decode_speed, total_time))

        print(f" >>> Iteration {i+1} - Prefill: {prefill_speed:.4f}tok/s ({total_prompt_tokens}toks), Decode: {decode_speed:.4f}tok/s ({total_output_tokens}toks), elapse_time: {total_time:.4f}s")

    # Calculate statistics from the last k runs
    stats = calculate_stats(times, num_k)

    # Token totals are from the final iteration (every iteration uses the
    # same prompts and sampling params).
    results = {
        "dataset": dataset_name,
        "temperature": temperature,
        "seqlen": max_tokens,
        "tensor_parallel_size": tensor_parallel_size,
        "num_prompts": num_prompts,
        "num_speculative_tokens": num_speculative_tokens,
        "total_prefill_tokens": total_prompt_tokens,
        "total_decode_tokens": total_output_tokens,
        "prefill_speed": stats["avg_prefill_speed"],
        "decode_speed": stats["avg_decode_speed"],
        "total_time": stats["avg_total_time"],
        "seed": seed
    }

    return results

def main():
    """Entry point: parse CLI args, run one benchmark config, write a CSV row.

    The CSV path is derived from the dataset, temperature, TP size, prompt
    count, speculative-token count and seed, under ``args.csv_dir`` (created
    if missing). The draft model's tensor-parallel size is fixed at 1.
    """
    cli = set_args()

    # Make sure the destination directory exists before benchmarking.
    out_dir = cli.csv_dir
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(
        out_dir,
        f"{cli.dataset}_temp{cli.temperature}_tp{cli.tensor_parallel_size}_np{cli.num_prompts}_st{cli.num_speculative_tokens}_seed{cli.seed}.csv"
    )

    bench_kwargs = dict(
        model_name=cli.model_name,
        speculative_model_name=cli.speculative_model_name,
        dataset_name=cli.dataset,
        temperature=cli.temperature,
        max_tokens=cli.seqlen,
        tensor_parallel_size=cli.tensor_parallel_size,
        num_prompts=cli.num_prompts,
        num_speculative_tokens=cli.num_speculative_tokens,
        gpu_memory_utilization=cli.gpu_memory_utilization,
        speculative_draft_tp_size=1,
        max_model_len=cli.max_model_len,
        seed=cli.seed,
        num_iter=cli.num_iter,
        num_k=cli.num_k,
    )
    row = run_benchmark(**bench_kwargs)

    # A single-row frame: one benchmark configuration per CSV file.
    save_results(pd.DataFrame([row]), csv_path)

# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()