import os
import json
import argparse
import random
import numpy as np
import pandas as pd
from datasets import load_dataset
from typing import List, Dict, Any, Tuple

def check_num_experts_per_tok(path):
    """Check ``num_experts_per_tok`` in ``config.json`` and restore it if needed.

    Reads ``<path>/config.json``. When ``num_experts_per_tok`` differs from
    the expected value (8), a warning is printed and the file is rewritten
    with the value reset to 8. A config that already holds the expected value
    is left untouched on disk.

    Args:
        path: Directory that contains the model's ``config.json``.

    Returns:
        True if the config was read (and, when necessary, restored)
        successfully; False if the file could not be read, parsed or written,
        or if the ``num_experts_per_tok`` key is missing.
    """
    # 8 is the standard value for Qwen
    expected_value = 8
    config_path = os.path.join(path, "config.json")

    try:
        with open(config_path, 'r', encoding="utf-8") as file:
            config = json.load(file)

        if config["num_experts_per_tok"] != expected_value:
            print("WARNING: config has been changed without recovery!")
            config["num_experts_per_tok"] = expected_value

            # Serialize before opening for writing so that a serialization
            # failure cannot leave a truncated config file behind.
            serialized = json.dumps(config, indent=2)
            with open(config_path, 'w', encoding="utf-8") as file:
                file.write(serialized)

        return True

    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: missing/unreadable/unwritable file; ValueError: invalid
        # JSON or encoding; KeyError/TypeError: unexpected config structure.
        print(f"Error updating config file: {e}")
        return False


def set_args(argv=None):
    """Parse command line arguments for the benchmark.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options. Note that
        ``num_prompts`` has no default and is None unless given explicitly.
    """
    parser = argparse.ArgumentParser(description="Benchmarking script for MoESD")

    parser.add_argument("--model_name", type=str, default="../models/Qwen2-57B-A14B-Instruct",
                        help="Model name")
    parser.add_argument("--speculative_model_name", type=str, default="../models/Qwen2-0.5B-Instruct",
                        help="Speculative model name")

    # args to be swept
    # Help texts use %(default)s so the documented default always matches the
    # real one.
    parser.add_argument("--dataset", type=str, choices=["humaneval", "mtbench"], default="humaneval",
                        help="Dataset to use for benchmarking")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (default: %(default)s)")
    parser.add_argument("--seqlen", type=int, default=512,
                        help="Maximum number of tokens to generate (default: %(default)s)")
    parser.add_argument("--tensor_parallel_size", type=int, default=2,
                        help="Tensor parallel size for target model (default: %(default)s)")
    parser.add_argument("--num_prompts", type=int,
                        help="Number of prompts in a batch")
    parser.add_argument("--num_speculative_tokens", type=int, default=2,
                        help="Number of speculative tokens (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed (default: %(default)s)")
    parser.add_argument("--csv_dir", type=str, default="csv_dir",
                        help="Directory to save results (default: %(default)s)")

    # recording policy
    parser.add_argument("--num_iter", type=int, default=10,
                        help="Number of iterations for profiling (default: %(default)s)")
    parser.add_argument("--num_k", type=int, default=5,
                        help="Number of last runs to calculate statistics (default: %(default)s)")

    # other engine hyperparameters
    parser.add_argument("--speculative_draft_tp_size", type=int, default=1,
                        help="Tensor parallel size for draft model (default: %(default)s)")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8,
                        help="GPU memory utilization (default: %(default)s for speculative decode)")
    parser.add_argument("--max_model_len", type=int, default=1024,
                        help="Maximum model length (default: %(default)s)")

    return parser.parse_args(argv)


def set_seed(seed: int = 0):
    """Seed the global ``random`` and NumPy generators with the same value."""
    # Python's generator first, then NumPy's legacy global generator.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def load_humaneval_prompts(num_prompts: int = 100) -> List[str]:
    """Load the HumanEval dataset and return ``num_prompts`` prompts.

    The prompts are taken from the ``"prompt"`` field of the ``"test"`` split
    of the dataset stored at ``../datasets/openai_humaneval``.

    Args:
        num_prompts: Number of prompts to return. When it exceeds the number
            of prompts in the dataset, the prompt list is repeated from the
            start as many times as needed to reach the requested count.

    Returns:
        A list of prompt strings; empty if the dataset contains no samples.
    """
    dataset = load_dataset("../datasets/openai_humaneval")
    prompts = [sample["prompt"] for sample in dataset["test"]]
    if prompts and num_prompts > len(prompts):
        # Ceiling division: a single extra copy is not enough once
        # num_prompts exceeds twice the dataset size.
        repeats = -(-num_prompts // len(prompts))
        prompts = prompts * repeats
    return prompts[:num_prompts]


def load_mtbench_prompts(num_prompts: int = 100) -> List[str]:
    """Load the MT-Bench dataset and return ``num_prompts`` prompts.

    The prompts are taken from the ``"train"`` split of the dataset stored at
    ``../datasets/mt_bench_prompts``.

    Args:
        num_prompts: Number of prompts to return. When it exceeds the number
            of available prompts, the prompt list is repeated from the start
            as many times as needed to reach the requested count.

    Returns:
        A list of prompt strings; empty if the dataset contains no samples.
    """
    dataset = load_dataset("../datasets/mt_bench_prompts")
    # MT-Bench prompts come in pairs; both prompts of each pair are used,
    # as two separate entries.
    prompts = []
    for sample in dataset["train"]:
        prompts.append(sample["prompt"][0])
        prompts.append(sample["prompt"][1])
    if prompts and num_prompts > len(prompts):
        # Ceiling division: a single extra copy is not enough once
        # num_prompts exceeds twice the number of available prompts.
        repeats = -(-num_prompts // len(prompts))
        prompts = prompts * repeats
    return prompts[:num_prompts]


def calculate_stats(times: List[Tuple[float, float, float]], k: int) -> Dict[str, float]:
    """Average the measurements of the last ``k`` runs.

    Args:
        times: One tuple per run; index 0 is averaged into
            ``avg_prefill_speed``, index 1 into ``avg_decode_speed`` and
            index 2 into ``avg_total_time``.
        k: Maximum number of trailing runs to include. Must be positive.

    Returns:
        A dict with the keys ``avg_prefill_speed``, ``avg_decode_speed`` and
        ``avg_total_time``. If ``times`` holds fewer than ``k`` runs, the
        averages cover the runs that exist; if it is empty, all are 0.0.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        # times[-0:] would silently select every run, and dividing by k
        # would then fail or give a meaningless result.
        raise ValueError(f"k must be a positive integer, got {k}")

    last_k_runs = times[-k:]
    # Divide by the number of runs actually selected, not by k, so that
    # fewer-than-k runs still produce a true average.
    num_runs = len(last_k_runs)
    if num_runs == 0:
        return {
            "avg_prefill_speed": 0.0,
            "avg_decode_speed": 0.0,
            "avg_total_time": 0.0,
        }

    avg_prefill = sum(run[0] for run in last_k_runs) / num_runs
    avg_decode = sum(run[1] for run in last_k_runs) / num_runs
    avg_total_time = sum(run[2] for run in last_k_runs) / num_runs

    return {
        "avg_prefill_speed": avg_prefill,
        "avg_decode_speed": avg_decode,
        "avg_total_time": avg_total_time
    }


def save_results(df: pd.DataFrame, output_file: str):
    """Write ``df`` to ``output_file`` as CSV (no index column) and report the path."""
    df.to_csv(path_or_buf=output_file, index=False)
    confirmation = f"Results saved to {output_file}"
    print(confirmation)


if __name__ == "__main__":

    # Standalone entry point: verify (and if necessary restore) the
    # num_experts_per_tok value in the target model's config.json.
    parser = argparse.ArgumentParser()
    # --model is required: without it args.model would be None and the path
    # join inside check_num_experts_per_tok would raise a TypeError.
    parser.add_argument('--model', type=str, required=True, help="target_model_path")
    args = parser.parse_args()

    check_num_experts_per_tok(args.model)

