import argparse
import subprocess
import shlex

# Repeat counts, forwarded to main.py as --mcts_repeats and --test_repeats.
MCTS_REPEATS = 5
TEST_REPEATS = 10

# Default flag values for main.py. Each key is later rendered as a
# "--key=value" command-line flag; the per-function configurations below
# start from this dict and override only the keys that differ.
base_args = dict(
    matrix_function="inv",
    distribution_name="wishart",
    c=0.25,
    d=5000,
    epsilon=1e-3,
    precision="float",
    custom_loss=False,
    device="cuda",
    test_repeats=TEST_REPEATS,
    make_algorithm=True,
    mcts_budget=150000,
    mcts_c_ucb=5,
    use_recent=False,
    mcts_alpha_pw=0.4,
    mcts_repeats=MCTS_REPEATS,
    test_algorithm=True,
    plot_y_low=1e-6,
    result_dir="results/inv_d5000",
)

def build_cmd(arg_dict, script="main.py", interpreter="python"):
    """Build a shell-style command line running *script* with *arg_dict* as flags.

    Each ``key: value`` pair becomes one ``--key=value`` token, using
    ``str(value)`` for the value. Booleans therefore render as
    ``True``/``False`` and floats use Python's default formatting, e.g.
    ``1e-06``.

    Every token is passed through :func:`shlex.quote`, so a value that
    contains spaces or shell metacharacters survives a later ``shlex.split``
    as a single argument. Tokens made only of safe characters (letters,
    digits, ``_ - = . / : , + @ %``) are emitted unchanged.

    Args:
        arg_dict: Mapping of flag name to flag value.
        script: Script to run. Defaults to ``"main.py"``.
        interpreter: Interpreter command. Defaults to ``"python"``.

    Returns:
        The full command as one space-separated string,
        e.g. ``"python main.py --d=5000 --make_algorithm=True"``.
    """
    tokens = [shlex.quote(interpreter), shlex.quote(script)]
    # str(value) already renders bools as "True"/"False", so no bool special-case.
    tokens.extend(shlex.quote(f"--{key}={value}") for key, value in arg_dict.items())
    return " ".join(tokens)

def modify_args(base, updates):
    """Return a fresh dict: *base* overlaid with *updates*.

    Neither argument is mutated. On a key clash the value from *updates*
    wins, and key order follows *base* first, then any new keys from *updates*.
    """
    return {**base, **updates}

# Per-function base configurations derived from base_args; each one
# overrides only the keys that differ from the shared defaults.
inv_base_arg = modify_args(base_args, {"distribution_name" : "unif"})
sign_base_arg = modify_args(
    base_args,
    {
        "matrix_function": "sign",
        "result_dir": "results/sign_d5000",
        "mcts_budget": 80000,
        "distribution_name": "unif",
    },
)
sqrt_base_arg = modify_args(
    base_args,
    {
        "matrix_function": "sqrt",
        "result_dir": "results/sqrt_d5000",
        "mcts_budget": 150000,
    },
)
proot_base_arg = modify_args(
    base_args,
    {
        "matrix_function": "proot",
        "result_dir": "results/proot_d5000",
        "mcts_budget": 80000,
    },
)


# Experiments to run, keyed by a label that is written to times.txt.
# Enable/disable an experiment by (un)commenting its line. Every entry,
# including the last, ends with a comma so toggling lines never breaks the
# dict syntax.
cmd_variants = {
    # "sqrt_cifar": build_cmd(modify_args(sqrt_base_arg, {"result_dir":"results/sqrt_cifar_zca", "distribution_name":"CIFAR"})),
    # "inv": build_cmd(inv_base_arg),
    # "inv_wishart": build_cmd(modify_args(inv_base_arg, {"distribution_name": "wishart", "result_dir": "results/inv_wishart_d5000"})),
    # "sign": build_cmd(sign_base_arg),
    # "sqrt": build_cmd(sqrt_base_arg),
    "sqrt_generalize": build_cmd(modify_args(sqrt_base_arg, {"distribution_name":"wishart_unif","result_dir":"results/sqrt_wishart_unif"})),
    # "proot": build_cmd(proot_base_arg),
    # "sign_d1500": build_cmd(modify_args(sign_base_arg, {"d": 1500, "result_dir": "results/sign_d1500"})),
    # "sign_d3000": build_cmd(modify_args(sign_base_arg, {"d": 3000, "result_dir": "results/sign_d3000"})),
    # "sign_d10000": build_cmd(modify_args(sign_base_arg, {"d": 10000, "result_dir": "results/sign_d10000"})),
    # "sign_d5000_cpu": build_cmd(modify_args(sign_base_arg, {"device": "cpu", "result_dir": "results/sign_d5000_cpu"})),
    # "sign_d5000_double": build_cmd(modify_args(sign_base_arg, {"precision": "double", "result_dir": "results/sign_d5000_double"})),
    # "sign_hessian": build_cmd(modify_args(sign_base_arg, {"distribution_name": "quartic_saddle", "result_dir": "results/sign_quartic_saddle"})),
    # "sqrt_d10000": build_cmd(modify_args(sqrt_base_arg, {"d": 10000, "result_dir": "results/sqrt_d10000"})),
    # "proot_graphLaplacian": build_cmd(modify_args(proot_base_arg, {"result_dir": "results/proot_graphLaplacian", "distribution_name": "Erdos_Renyi", "c": 0.4})),
}

if __name__ == "__main__":
    import sys

    # Run every enabled experiment sequentially. A failing run is reported on
    # stderr but does not stop the remaining experiments.
    for key, cmd in cmd_variants.items():

        # Append a header line for this run to times.txt. The file is closed
        # (and therefore flushed) before the subprocess is launched.
        with open("times.txt", "a", encoding="utf-8") as f:
            f.write(f">>> Running: {key}\n")
        result = subprocess.run(shlex.split(cmd), shell=False)
        if result.returncode != 0:
            # The exit status used to be ignored silently; surface it instead.
            print(f"[run_experiments] {key} exited with code {result.returncode}", file=sys.stderr)

        # Test and Plot code
        #subprocess.run(shlex.split(cmd.replace("make_algorithm=True","make_algorithm=False")), shell=False)

        # Test for different distribution
        #subprocess.run(shlex.split(cmd.replace("make_algorithm=True","make_algorithm=False")), shell=False)

        # Plot only code
        #subprocess.run(shlex.split(cmd.replace("make_algorithm=True","make_algorithm=False").replace("test_algorithm=True","test_algorithm=False")), shell=False)