import os
import json
import argparse
import random
import numpy as np
import pandas as pd
from datasets import load_dataset
from typing import List, Dict, Any, Tuple


def set_args(argv=None):
    """Parse command-line arguments for the MoESD benchmark.

    Every "(default: ...)" in the help text is rendered from the option's
    actual default via argparse's ``%(default)s`` placeholder, so help and
    behaviour cannot drift apart.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Benchmarking script for MoESD")

    parser.add_argument("--model_name", type=str, default="../models/opt-30b",
                        help="Model name")
    parser.add_argument("--speculative_model_name", type=str, default="../models/opt-350m",
                        help="Speculative model name")

    # args to be swept
    parser.add_argument("--dataset", type=str, choices=["humaneval", "mtbench"], default="humaneval",
                        help="Dataset to use for benchmarking")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (default: %(default)s)")
    parser.add_argument("--seqlen", type=int, default=512,
                        help="Maximum number of tokens to generate (default: %(default)s)")
    parser.add_argument("--tensor_parallel_size", type=int, default=2,
                        help="Tensor parallel size for target model (default: %(default)s)")
    # NOTE(review): no default, so this is None unless passed explicitly; the
    # prompt loaders in this module compare it against an int and would fail
    # on None. Confirm every caller/sweep supplies it.
    parser.add_argument("--num_prompts", type=int,
                        help="Number of prompts in a batch")
    parser.add_argument("--num_speculative_tokens", type=int, default=2,
                        help="Number of speculative tokens (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed (default: %(default)s)")
    parser.add_argument("--csv_dir", type=str, default="csv_dir",
                        help="Directory to save results (default: %(default)s)")

    # recording policy
    parser.add_argument("--num_iter", type=int, default=10,
                        help="Number of iterations for profiling (default: %(default)s)")
    parser.add_argument("--num_k", type=int, default=5,
                        help="Number of last runs to calculate statistics (default: %(default)s)")

    # other engine hyperparameters
    parser.add_argument("--speculative_draft_tp_size", type=int, default=1,
                        help="Tensor parallel size for draft model (default: %(default)s)")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8,
                        help="GPU memory utilization (default: %(default)s for speculative decode)")
    parser.add_argument("--max_model_len", type=int, default=1024,
                        help="Maximum model length (default: %(default)s)")

    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args


def set_seed(seed: int = 0):
    """Seed the global random number generators used by this script.

    Python's ``random`` module is reseeded first, then NumPy's global
    ``np.random`` generator, both with the same ``seed``. Repeated calls with
    the same value therefore reproduce the same draws from either source.
    """
    seeders = (random.seed, np.random.seed)
    for seeder in seeders:
        seeder(seed)


def load_humaneval_prompts(num_prompts: int = 100) -> List[str]:
    """Load HumanEval and return exactly ``num_prompts`` prompt strings.

    Prompts are read from the ``"prompt"`` field of every sample in the
    ``"test"`` split of the dataset at ``../datasets/openai_humaneval``. This
    is a relative path, so it presumably resolves against the current working
    directory.

    Args:
        num_prompts: Number of prompts to return. If the dataset holds fewer,
            the prompt list is repeated cyclically, in dataset order, until it
            is long enough.

    Returns:
        A list of ``num_prompts`` prompts for non-negative ``num_prompts``.
        The list is shorter only if the dataset is empty.
    """
    dataset = load_dataset("../datasets/openai_humaneval")
    prompts = [sample["prompt"] for sample in dataset["test"]]
    if prompts and num_prompts > len(prompts):
        # Repeat the whole list enough times to cover num_prompts. Appending
        # a single prefix copy would at most double the list and come up
        # short whenever num_prompts > 2 * len(prompts).
        repeats = (num_prompts + len(prompts) - 1) // len(prompts)
        prompts = prompts * repeats
    return prompts[:num_prompts]


def load_mtbench_prompts(num_prompts: int = 100) -> List[str]:
    """Load MT-Bench and return exactly ``num_prompts`` prompt strings.

    Each sample in the ``"train"`` split of ``../datasets/mt_bench_prompts``
    contributes two prompts, ``sample["prompt"][0]`` and
    ``sample["prompt"][1]``. Both are used as independent prompts, in order:
    turn 0 then turn 1 of sample i, then sample i + 1.

    Args:
        num_prompts: Number of prompts to return. If the flattened list holds
            fewer, it is repeated cyclically until it is long enough.

    Returns:
        A list of ``num_prompts`` prompts for non-negative ``num_prompts``.
        The list is shorter only if the dataset is empty.
    """
    dataset = load_dataset("../datasets/mt_bench_prompts")
    # Flatten both turns of every sample into one prompt list.
    # NOTE(review): the second turn is used without the first turn as
    # conversational context; confirm this is intended for the benchmark.
    prompts = []
    for sample in dataset["train"]:
        prompts.append(sample["prompt"][0])
        prompts.append(sample["prompt"][1])
    if prompts and num_prompts > len(prompts):
        # Repeat the whole list enough times to cover num_prompts. Appending
        # a single prefix copy would at most double the list and come up
        # short whenever num_prompts > 2 * len(prompts).
        repeats = (num_prompts + len(prompts) - 1) // len(prompts)
        prompts = prompts * repeats
    return prompts[:num_prompts]


def calculate_stats(times: List[Tuple[float, float, float]], k: int) -> Dict[str, float]:
    """Average per-run timing tuples over the last ``k`` runs.

    Args:
        times: One 3-tuple per run, oldest first. Element 0 feeds
            ``avg_prefill_speed``, element 1 feeds ``avg_decode_speed`` and
            element 2 feeds ``avg_total_time``.
            NOTE(review): the first two keys say "speed", but the values are
            averaged as given. Whether they are durations or throughputs is up
            to the caller; confirm.
        k: Number of most recent runs to average. If fewer than ``k`` runs
            exist, all available runs are averaged.

    Returns:
        Dict with keys ``avg_prefill_speed``, ``avg_decode_speed`` and
        ``avg_total_time``.

    Raises:
        ValueError: If ``k < 1``, or if ``times`` is empty.
    """
    if k < 1:
        # Guard explicitly: times[-0:] would silently select *all* runs, and a
        # negative k would select from the front of the list instead.
        raise ValueError(f"k must be a positive integer, got {k!r}")

    last_k_runs = times[-k:]
    if not last_k_runs:
        raise ValueError("times is empty; there are no runs to average")

    # Divide by the actual window size: times[-k:] holds fewer than k runs
    # when len(times) < k, and dividing by k would bias the averages low.
    n_runs = len(last_k_runs)
    avg_prefill = sum(run[0] for run in last_k_runs) / n_runs
    avg_decode = sum(run[1] for run in last_k_runs) / n_runs
    avg_total_time = sum(run[2] for run in last_k_runs) / n_runs

    return {
        "avg_prefill_speed": avg_prefill,
        "avg_decode_speed": avg_decode,
        "avg_total_time": avg_total_time
    }


def save_results(df: pd.DataFrame, output_file: str):
    """Write ``df`` to ``output_file`` as CSV (without the index) and report it.

    The parent directory of ``output_file`` is created first if it does not
    exist. ``DataFrame.to_csv`` refuses to write into a missing directory.

    Args:
        df: Results table to save.
        output_file: Destination CSV path. A bare filename writes into the
            current working directory.
    """
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        # to_csv raises OSError for a non-existent parent directory;
        # exist_ok makes this a no-op when it is already there.
        os.makedirs(parent_dir, exist_ok=True)
    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")


