from src.run.orchestrate.parallel import run_parallel
from src.run.orchestrate.config import StoriesBaseArgs
from src.run.utils import get_timestamp

from copy import deepcopy
import json
from pathlib import Path

def run_experiment():
    """Build the stories/06 "combined" sweep configs and launch them in parallel.

    Steps:
      1. Read ``metadata.json`` from the first entry of
         ``StoriesBaseArgs["data_dirs"]`` and take the first ``num_labels``
         labels (after sorting) as ``aux_labels``.
      2. Replace ``stages`` with a single ``"routed"`` stage
         (``arch="moe"``, ``expert_dist="prc_one"``).
      3. Create one config per (seed, robust_prc) pair, with
         ``aux_route_prc`` fixed at 0.75.
      4. Pass all configs to ``run_parallel`` with ``res_root`` set to a
         timestamped ``combined_*`` directory under
         ``<cwd>/src/results/stories/06/``.
    """

    NUM_RUNS = 3

    configs = []

    # ``Path("src")`` is relative to the current working directory, so this
    # script is expected to be launched from the repository root.
    root_dir = Path("src").absolute()
    res_root = root_dir / f"results/stories/06/combined_{get_timestamp()}"

    # Use a context manager so the file handle is closed deterministically
    # instead of leaking until garbage collection; JSON text is UTF-8.
    metadata_path = StoriesBaseArgs["data_dirs"][0] / "metadata.json"
    with open(metadata_path, encoding="utf-8") as metadata_file:
        dataset_metadata = json.load(metadata_file)
    # Sort so the selected auxiliary labels are stable across runs.
    all_labels = sorted(dataset_metadata["all"]["labels"])
    num_labels = 4
    base_args = deepcopy(StoriesBaseArgs)
    base_args['aux_labels'] = all_labels[:num_labels]
    base_args['stages'] = [
        {"name": "routed", "arch": "moe", "ordered": False, "ft_forget": True, "expert_dist": "prc_one"}
    ]

    for seed in range(NUM_RUNS):

        # for aux_route_prc in [0.0, 0.25, 0.5, 0.75, 1.0]:
        #     run_config = deepcopy(base_args)
        #     run_config["seed"] = seed
        #     run_config["stages"][0]["aux_route_prc"] = aux_route_prc
        #     run_config["stages"][0]["robust_prc"] = 0.5
        #     configs.append(run_config)

        # for robust_prc in [0.0, 0.125, 0.25, 0.375, 0.50, 0.75, 1.0]:
        for robust_prc in [0.125, 0.375]:
            # Deep copy per config: the nested ``stages`` dict is mutated
            # below and must not be shared between configs.
            run_config = deepcopy(base_args)
            # NOTE(review): seeds are offset by NUM_RUNS (3..5), unlike the
            # commented-out sweep above, which used 0..2. This is presumably
            # to avoid reusing seeds from an earlier batch; confirm it is
            # intended when comparing against prior results.
            run_config["seed"] = seed + NUM_RUNS
            run_config["stages"][0]["aux_route_prc"] = 0.75
            run_config["stages"][0]["robust_prc"] = robust_prc
            configs.append(run_config)

    run_parallel(
        configs=configs,
        res_root=res_root,
        log_level="DEBUG",
    )

# Launch the sweep only when executed as a script, not on import.
if __name__ == "__main__":
    run_experiment()