from src.run.main import run
from src.run.orchestrate.parallel import run_parallel
from src.run.orchestrate.config import StoriesBaseArgs
from src.run.utils import get_timestamp

from copy import deepcopy
import json, random
from pathlib import Path

def run_experiment():
    """Build the sweep of run configurations and pass them to ``run_parallel``.

    The sweep is repeated ``NUM_RUNS`` times over paired values of
    (number of auxiliary labels, ``aux_batch_limit``).  Every configuration
    is a deep copy of ``StoriesBaseArgs`` with ``epochs`` and ``stages``
    overridden, a unique consecutive ``seed`` (starting at 0), and a fresh
    random draw of labels taken from the dataset's ``metadata.json``.  The
    first 24 drawn labels become ``core_labels``; the remainder become
    ``aux_labels``.

    Label sampling uses the global ``random`` module seeded with 42, so the
    same metadata file produces the same label draws on every invocation.

    Results are rooted at
    ``<cwd>/src/results/stories/04/combined_<timestamp>``.

    Raises:
        OSError: if ``metadata.json`` cannot be opened.
        ValueError: raised by ``random.sample`` if the dataset provides
            fewer labels than a configuration requests.
    """
    NUM_RUNS = 5

    num_core_labels = 24
    configs = []
    base_args = deepcopy(StoriesBaseArgs)

    # Path("src") is resolved against the current working directory.
    root_dir = Path("src").absolute()
    res_root = root_dir / f"results/stories/04/combined_{get_timestamp()}"

    # Read the metadata inside a context manager so the file handle is
    # closed deterministically instead of being left to the garbage
    # collector.
    metadata_path = StoriesBaseArgs["data_dirs"][0] / "metadata.json"
    with open(metadata_path) as metadata_file:
        dataset_metadata = json.load(metadata_file)
    # Sorted so that the seeded sampling below does not depend on the order
    # in which labels happen to be stored in the metadata file.
    all_labels = sorted(dataset_metadata["all"]["labels"])

    base_args['epochs'] = 2
    # Every run executes two stages: "baseline" followed by "routed".
    base_args['stages'] = [
        {
            "name": "baseline", "ft_forget": False, "do_checkpoint": True,
        },
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": True,
            "aux_route_prc": 0.75, "robust_prc": 0.5, "expert_dist": "prc_one", "aux_exp_prc": 5/105,
        },
    ]

    # Swept in lockstep: the i-th auxiliary label count is always paired
    # with the i-th batch limit.
    num_aux_labels = [4, 8, 12, 16, 20]
    aux_batch_limits = [0.05, 0.10, 0.15, 0.20, 0.25]

    random.seed(42)
    seed = 0
    for _ in range(NUM_RUNS):
        for num_aux, batch_limit in zip(num_aux_labels, aux_batch_limits):

            run_config = deepcopy(base_args)
            run_config["seed"] = seed
            seed += 1

            # Draw core and auxiliary labels in a single sample so the two
            # sets are guaranteed to be disjoint.
            num_labels = num_core_labels + num_aux
            chosen_labels = random.sample(all_labels, num_labels)
            core_labels = chosen_labels[:num_core_labels]
            aux_labels = chosen_labels[num_core_labels:]

            run_config['core_labels'] = core_labels
            run_config["aux_labels"] = aux_labels
            run_config['aux_batch_limit'] = batch_limit

            configs.append(run_config)

    run_parallel(
        configs=configs,
        res_root=res_root,
    )

# Launch the experiment only when this file is executed as a script, so
# importing the module has no side effects.
if __name__ == "__main__":
    run_experiment()