import os
import time
import torch
import logging
from typing import List, Dict, Any
import pandas as pd
from vllm import LLM, SamplingParams
from simulate_utils import (
    set_seed, 
    load_humaneval_prompts, 
    load_mtbench_prompts, 
    calculate_stats, 
    save_results, 
    set_args_sim,
)

def run_simulate_benchmark(
    model_name: str,
    dataset_name: str,
    temperature: float,
    max_tokens: int,
    tensor_parallel_size: int,
    num_prompts: int,
    seed: int,
    num_iter: int,
    num_k: int,
    gpu_memory_utilization: float,
    max_model_len: int,
    exp_num: int,
) -> Dict[str, Any]:
    """Measure vLLM prefill throughput by repeatedly generating one token per prompt.

    A single ``LLM`` instance is built once, and ``llm.generate`` is called
    ``num_iter`` times on the same prompt set. Per-iteration timings are then
    summarised by ``calculate_stats`` over the last ``num_k`` iterations.

    Args:
        model_name: Model identifier/path forwarded to ``vllm.LLM``.
        dataset_name: Prompt source; ``"humaneval"`` or ``"mtbench"``.
        temperature: Sampling temperature.
        max_tokens: Tokens generated per prompt; must be 1 (prefill-only run).
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
        num_prompts: Number of prompts requested from the dataset loader.
        seed: Used for ``set_seed``, ``SamplingParams`` and ``vllm.LLM``.
        num_iter: Number of timed ``generate`` calls; must be >= 1.
        num_k: Number of trailing iterations passed to ``calculate_stats``.
        gpu_memory_utilization: Forwarded to ``vllm.LLM``.
        max_model_len: Forwarded to ``vllm.LLM``.
        exp_num: Experiment id copied verbatim into the result row.

    Returns:
        One result row (a dict) holding the run configuration, the token
        counts of the final iteration, and the averaged prefill/decode speed
        and total time reported by ``calculate_stats``.

    Raises:
        ValueError: If ``dataset_name`` is unknown, ``max_tokens != 1``, or
            ``num_iter < 1``.
    """
    # Validate everything before seeding or building the (expensive) engine.
    # Use ``raise`` rather than ``assert`` so the checks survive ``python -O``.
    if dataset_name not in ("humaneval", "mtbench"):
        raise ValueError(f"Unknown dataset: {dataset_name}")
    if max_tokens != 1:
        raise ValueError("max_tokens should be 1 for prefill time")
    if num_iter < 1:
        # With zero iterations the token counters used in the result row
        # would never be bound.
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    # Set seed for reproducibility
    set_seed(seed)

    # Load prompts (dataset_name was validated above).
    load_prompts = (
        load_humaneval_prompts if dataset_name == "humaneval" else load_mtbench_prompts
    )
    prompts = load_prompts(num_prompts)

    # ignore_eos=True: generation is not cut short by an EOS token, so each
    # request produces max_tokens output tokens.
    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_tokens, seed=seed, ignore_eos=True
    )

    # Build the engine once; every iteration below reuses it.
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        seed=seed,
        disable_log_stats=False,
    )

    # One (prefill_speed, decode_speed, total_time) tuple per iteration.
    per_iteration = []
    for i in range(num_iter):
        print(f" >>> Iteration {i+1}/{num_iter}")

        start_time = time.perf_counter()
        outputs = llm.generate(prompts, sampling_params)
        total_time = time.perf_counter() - start_time

        # Both speeds divide by the wall-clock time of the whole generate()
        # call, so neither isolates the prefill or decode phase on its own.
        total_prompt_tokens = sum(len(o.prompt_token_ids) for o in outputs)
        total_output_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
        prefill_speed = total_prompt_tokens / total_time
        decode_speed = total_output_tokens / total_time

        per_iteration.append((prefill_speed, decode_speed, total_time))

        print(f" >>> Iteration {i+1} - Prefill: {prefill_speed:.4f}tok/s ({total_prompt_tokens}toks), Decode: {decode_speed:.4f}tok/s ({total_output_tokens}toks), elapse_time: {total_time:.4f}s")

    # Calculate statistics from the last k runs
    stats = calculate_stats(per_iteration, num_k)

    # Token counts come from the final iteration; prompts and sampling params
    # are identical across iterations, so they presumably repeat each time.
    result = {
        "dataset": dataset_name,
        "temperature": temperature,
        "seqlen": max_tokens,
        "tensor_parallel_size": tensor_parallel_size,
        "num_prompts": num_prompts,
        "total_prefill_tokens": total_prompt_tokens,
        "total_decode_tokens": total_output_tokens,
        "prefill_speed": stats["avg_prefill_speed"],
        "decode_speed": stats["avg_decode_speed"],
        "total_time": stats["avg_total_time"],
        "seed": seed,
        "exp_num": exp_num,
    }

    return result

def main():
    """CLI entry point: run one prefill benchmark and persist its row as CSV.

    Arguments come from ``set_args_sim``. The output file lives in
    ``args.csv_dir`` and is named from the dataset, temperature,
    tensor-parallel size, prompt count, seed and experiment number.
    """
    args = set_args_sim()
    # Prefill measurement: exactly one token is generated per prompt.
    args.seqlen = 1

    # Ensure the destination directory exists before doing any heavy work.
    os.makedirs(args.csv_dir, exist_ok=True)
    csv_name = (
        f"{args.dataset}_temp{args.temperature}_tp{args.tensor_parallel_size}"
        f"_np{args.num_prompts}_seed{args.seed}_expnum{args.n}.csv"
    )
    output_file_name = os.path.join(args.csv_dir, csv_name)

    benchmark_row = run_simulate_benchmark(
        model_name=args.model_name,
        dataset_name=args.dataset,
        temperature=args.temperature,
        max_tokens=args.seqlen,
        tensor_parallel_size=args.tensor_parallel_size,
        num_prompts=args.num_prompts,
        seed=args.seed,
        num_iter=args.num_iter,
        num_k=args.num_k,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_model_len=args.max_model_len,
        exp_num=args.n,
    )

    # A single benchmark run becomes a one-row DataFrame.
    save_results(pd.DataFrame([benchmark_row]), output_file_name)
    print(f" >>> All results saved to {output_file_name}")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()