import json
import argparse
import random
import os
import numpy as np
import pandas as pd
from datasets import load_dataset
from typing import List, Dict, Any, Tuple

def update_num_experts_per_tok(new_value, model):
    """
    Update the 'num_experts_per_tok' value in a model's config.json file.

    The file ``<model>/config.json`` is read, the key is set to ``new_value``
    and the whole config is written back with a 2-space indent.

    Args:
        new_value (int): The new value to set for 'num_experts_per_tok'.
        model (str): Directory that contains the ``config.json`` to edit.

    Returns:
        True if the update was successful, False otherwise. Failures while
        reading, updating or writing the config are reported on stdout
        instead of being raised.
    """
    config_path = os.path.join(model, "config.json")

    try:
        # Read the existing config file
        with open(config_path, 'r', encoding="utf-8") as file:
            config = json.load(file)

        # Update the num_experts_per_tok value
        config["num_experts_per_tok"] = new_value

        # Serialize BEFORE opening the file for writing: opening with 'w'
        # truncates immediately, so a serialization error raised afterwards
        # would leave a corrupted (partial) config.json on disk.
        serialized = json.dumps(config, indent=2)

        # Write the updated config back to the file
        with open(config_path, 'w', encoding="utf-8") as file:
            file.write(serialized)

        print(f"Successfully updated 'num_experts_per_tok' to {new_value}")
        return True

    except Exception as e:
        # Deliberately broad: the contract is "report and return False"
        # for any failure while touching the config file.
        print(f"Error updating config file: {e}")
        return False
    

def set_args_sim():
    """Parse the command line arguments for the benchmark.

    Arguments are read from ``sys.argv`` via ``argparse``.

    Returns:
        argparse.Namespace: The parsed arguments. ``--n`` and
        ``--num_prompts`` have no default and are ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(description="Simulating script for MoESD")

    # model name
    parser.add_argument("--model_name", type=str, default="../models/Qwen2-57B-A14B-Instruct",
                        help="Model name")
    parser.add_argument("--speculative_model_name", type=str, default="../models/Qwen2-0.5B-Instruct",
                        help="Speculative model name")

    # args to be swept (namely, appeared in the shell script)
    parser.add_argument("--dataset", type=str, choices=["humaneval", "mtbench"], default="humaneval",
                        help="Dataset to use for benchmarking")
    parser.add_argument("--temperature", type=float, default=0,
                        help="Sampling temperature (default: 0)")
    parser.add_argument("--seqlen", type=int, default=512,
                        help="Maximum number of tokens to generate (default: 512)")
    parser.add_argument("--tensor_parallel_size", type=int, default=2,
                        help="Tensor parallel size for target model (default: 2)")
    # --n for different number of experts per token
    parser.add_argument("--n", type=int, help="num_experts_per_tok")
    parser.add_argument("--num_prompts", type=int,
                        help="Number of prompts to use")
    parser.add_argument("--num_speculative_tokens", type=int, default=2,
                        help="Number of speculative tokens (default: 2)")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed (default: 0)")
    parser.add_argument("--csv_dir", type=str, default="csv_dir",
                        help="Directory to save results (default: csv_dir)")

    # recording policy
    parser.add_argument("--num_iter", type=int, default=10,
                        help="Number of iterations for profiling (default: 10)")
    # The help text previously advertised a default of 3 while the actual
    # default is 5; keep the two in sync.
    parser.add_argument("--num_k", type=int, default=5,
                        help="Number of last runs to calculate statistics (default: 5)")

    # other engine hyperparameters
    parser.add_argument("--speculative_draft_tp_size", type=int, default=1,
                        help="Tensor parallel size for draft model (default: 1)")
    parser.add_argument("--max_model_len", type=int, default=1024,
                        help="Maximum model length (default: 1024)")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8,
                        help="GPU memory utilization (default: 0.8 for speculative decode)")

    args = parser.parse_args()

    return args


# The following code is the same as the code in benchmark_utils.py

def set_seed(seed: int = 0):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    # Python's generator is seeded first, then NumPy's.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

def load_humaneval_prompts(num_prompts: int = 100) -> List[str]:
    """Load the HumanEval dataset and return ``num_prompts`` prompts.

    Prompts are taken from the "prompt" field of the "test" split of the
    dataset loaded from ``../datasets/openai_humaneval``.

    Args:
        num_prompts: Number of prompts to return. When it exceeds the number
            of prompts in the dataset, the prompts are repeated cyclically
            until the requested count is reached.

    Returns:
        A list of prompt strings.
    """
    dataset = load_dataset("../datasets/openai_humaneval")
    prompts = [sample["prompt"] for sample in dataset["test"]]
    available = len(prompts)
    if num_prompts <= available or available == 0:
        return prompts[:num_prompts]
    # Cycle through the dataset as many times as needed. Padding with a
    # single extra copy would silently return fewer prompts than requested
    # whenever num_prompts is more than twice the dataset size.
    return [prompts[i % available] for i in range(num_prompts)]

def load_mtbench_prompts(num_prompts: int = 100) -> List[str]:
    """Load the MT-Bench dataset and return ``num_prompts`` prompts.

    Samples are read from the "train" split of the dataset loaded from
    ``../datasets/mt_bench_prompts``.

    Args:
        num_prompts: Number of prompts to return. When it exceeds the number
            of prompts collected from the dataset, the prompts are repeated
            cyclically until the requested count is reached.

    Returns:
        A list of prompt strings.
    """
    dataset = load_dataset("../datasets/mt_bench_prompts")
    # MT-Bench prompts are pairs; both prompts of each pair are used, in order.
    prompts = []
    for sample in dataset["train"]:
        prompts.append(sample["prompt"][0])
        prompts.append(sample["prompt"][1])
    available = len(prompts)
    if num_prompts <= available or available == 0:
        return prompts[:num_prompts]
    # Cycle through the prompts as many times as needed. Padding with a
    # single extra copy would silently return fewer prompts than requested
    # whenever num_prompts is more than twice the number available.
    return [prompts[i % available] for i in range(num_prompts)]

def calculate_stats(times: List[Tuple[float, float, float]], k) -> Dict[str, float]:
    """Calculate averages over the last ``k`` runs.

    Args:
        times: One 3-tuple per run. Element 0 feeds "avg_prefill_speed",
            element 1 feeds "avg_decode_speed" and element 2 feeds
            "avg_total_time".
        k: Number of most recent runs to average over. Must be positive.
            If fewer than ``k`` runs are available, all of them are used.

    Returns:
        A dict with the keys "avg_prefill_speed", "avg_decode_speed" and
        "avg_total_time". All values are 0.0 when ``times`` is empty.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    # times[-0:] would select every run and a negative k would slice from the
    # front, so reject non-positive k explicitly.
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Extract only the last k runs
    last_k_runs = times[-k:]

    # Average over the runs actually selected; dividing by k would skew the
    # result low whenever fewer than k runs were recorded.
    count = len(last_k_runs)
    if count == 0:
        return {
            "avg_prefill_speed": 0.0,
            "avg_decode_speed": 0.0,
            "avg_total_time": 0.0
        }

    avg_prefill = sum(run[0] for run in last_k_runs) / count
    avg_decode = sum(run[1] for run in last_k_runs) / count
    avg_total_time = sum(run[2] for run in last_k_runs) / count

    return {
        "avg_prefill_speed": avg_prefill,
        "avg_decode_speed": avg_decode,
        "avg_total_time": avg_total_time
    }

def save_results(df: pd.DataFrame, output_file: str):
    """Save results to a CSV file, creating its parent directory if needed.

    Args:
        df: The results to write. The index is not written.
        output_file: Destination path of the CSV file.
    """
    # to_csv does not create missing directories, so make sure the target
    # directory exists before writing. Only path-like targets have a parent
    # directory; anything else is handed to pandas untouched.
    if isinstance(output_file, (str, os.PathLike)):
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    
if __name__ == "__main__":
    # Command-line entry point: set num_experts_per_tok in <model>/config.json.

    parser = argparse.ArgumentParser(description="Update num_experts_per_tok in config.json")
    # Required: without it args.n would be None and "null" would be written
    # into the model config.
    parser.add_argument('--n', type=int, required=True, help="New value for num_experts_per_tok")
    parser.add_argument('--model', type=str, default="../models/Qwen2-57B-A14B-Instruct",
                        help="Directory containing the config.json to update")
    args = parser.parse_args()

    # Surface a failed update through the exit status so calling scripts do
    # not carry on with an unchanged config.
    if not update_num_experts_per_tok(args.n, args.model):
        raise SystemExit(1)