import itertools
import os
import subprocess
import sys
from pathlib import Path
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Input artifacts; every file is looked up under CACHE_DIR.
CACHE_DIR = Path("cache")  # TTSRouter/cache/
DATA_PATH = CACHE_DIR.joinpath("combined_dataset_210.jsonl")
OFFLINE_EMBEDDINGS = CACHE_DIR.joinpath("query_embeddings_combined.json")
OFFLINE_RESULTS = CACHE_DIR.joinpath("pruned_question_trials_combined.csv")
ACTION_EMB_FILE = CACHE_DIR.joinpath("action_embeddings_semantic.json")

# Forwarded to main.py as --max_problems and --warm_up respectively.
MAX_PROBLEMS = 210
WARM_UP = 50

# Reward weights, one inner list per sweep setting.
# Layout of each entry: [w_acc, w_ver, w_cost, bias]
REWARD_WEIGHTS_LIST = [
    # Cost-Sensitive
    [0.1, 0.1, 0.8, 0.8],
    # Quality-Priority
    [0.4, 0.4, 0.2, 0.2],
]

# Action spaces to sweep over (forwarded as --allowed_action_space).
ACTION_SPACES = ["uis"]

# Methods to sweep over. Each entry is a dict with:
#   "name"  - label used when building the experiment name
#   "algo"  - value forwarded as --algorithm
#   "param" - (param_name, param_value) pair forwarded as --<param_name> <value>
# Note: For Random/Oracle/kNN, param is ignored or fixed.
METHODS = [
    dict(name="LinUCB_alpha10", algo="lin_ucb", param=("alpha", 10)),
]

# Cost metric (woDA means without Difficulty Awareness, so the standard
# Normalized_EFLOPS is used).
COST_METRIC = "Normalized_EFLOPS"

def run_experiment(method_conf, action_space, weights):
    """
    Run a single experiment configuration by launching ``main.py``.

    Args:
        method_conf: Mapping describing the method. ``"name"`` is used in the
            experiment label and ``"algo"`` is forwarded as ``--algorithm``.
            An optional ``"param"`` entry holds a ``(param_name, param_value)``
            pair forwarded as ``--<param_name> <param_value>``; a missing or
            falsy entry adds no extra argument.
        action_space: Value forwarded as ``--allowed_action_space`` and used
            as the prefix of the experiment label.
        weights: Sequence of reward weights. Every element is forwarded to
            ``--reward_weights``; only the first three appear in the label.

    Returns:
        None. A non-zero exit status from the child process is printed and
        swallowed rather than re-raised.
    """
    method_name = method_conf["name"]
    algo = method_conf["algo"]

    # Only the first three weights go into the label; the fourth (bias) is
    # passed on the command line but does not affect the experiment name.
    w_str = f"{weights[0]}_{weights[1]}_{weights[2]}"

    # Label format: {action_space}_{method_name}_w{w_str}
    exp_label = f"{action_space}_{method_name}_w{w_str}"

    # Launch with the interpreter running this script so the child uses the
    # same environment; fall back to a PATH lookup if it is unknown.
    interpreter = sys.executable or "python"

    # Every argv element is an explicit str: Path objects inside an argument
    # list are not accepted by subprocess on every platform/Python version.
    cmd = [
        interpreter, "main.py",
        "--algorithm", algo,
        "--virtual_dataset",
        "--data_path", str(DATA_PATH),
        "--offline_embeddings", str(OFFLINE_EMBEDDINGS),
        "--offline_results", str(OFFLINE_RESULTS),
        "--max_problems", str(MAX_PROBLEMS),
        "--warm_up", str(WARM_UP),
        "--allowed_action_space", action_space,
        "--exp_label", exp_label,
        "--action_embedding_file", str(ACTION_EMB_FILE),
        "--cost_metric", COST_METRIC,
        "--reward_weights",
    ]
    cmd.extend(str(w) for w in weights)

    # Algorithm-specific parameter; methods without one may omit the key.
    param = method_conf.get("param")
    if param:
        p_name, p_val = param
        cmd.extend([f"--{p_name}", str(p_val)])

    print(f"\n[Experiment] Running: {exp_label}")

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        # Report the failure but do not raise, so a caller looping over many
        # configurations can carry on with the next one.
        print(f"Error running experiment {exp_label}: {e}")

def main():
    """
    Build the full experiment grid and run every configuration in order.

    The grid is the Cartesian product of REWARD_WEIGHTS_LIST, ACTION_SPACES
    and METHODS, with reward weights varying slowest and methods fastest.
    Each configuration is handed to run_experiment one after another.
    """
    grid = itertools.product(REWARD_WEIGHTS_LIST, ACTION_SPACES, METHODS)
    jobs = [
        {"method": method, "action_space": space, "weights": weight_set}
        for weight_set, space, method in grid
    ]

    total = len(jobs)
    print(f"Total experiments to run: {total}")

    for position, job in enumerate(jobs, start=1):
        print(f"\n--- Progress: {position}/{total} ---")
        run_experiment(job["method"], job["action_space"], job["weights"])

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
