from src.run.orchestrate.parallel import run_parallel
from src.run.orchestrate.config import StoriesBaseArgs
from src.run.utils import get_timestamp

from copy import deepcopy
from pathlib import Path
import json

def run_experiment(*, num_runs=3, num_labels=4):
    """Build per-seed configs for the combined "stories" experiment and run them.

    A copy of ``StoriesBaseArgs`` is specialised for the ``src/data/stories``
    dataset. The first ``num_labels`` labels, taken from the alphabetically
    sorted label list in that dataset's ``metadata.json``, become
    ``aux_labels``. A fixed list of stages is attached. One deep copy of the
    config is made per seed, and all copies go to ``run_parallel`` together
    with a timestamped ``res_root``
    (``src/results/stories/08/combined_<timestamp>``).

    Args:
        num_runs: Number of configs to launch; seeds are ``0 .. num_runs - 1``.
            Defaults to 3 (the previously hard-coded value).
        num_labels: Number of sorted labels to use as ``aux_labels``.
            Defaults to 4 (the previously hard-coded value).

    Raises:
        FileNotFoundError: If ``src/data/stories/metadata.json`` is missing.
        KeyError: If the metadata lacks the ``["all"]["labels"]`` entry.
    """
    root_dir = Path("src").absolute()
    res_root = root_dir / f"results/stories/08/combined_{get_timestamp()}"

    data_dir = root_dir / "data/stories"
    base_args = deepcopy(StoriesBaseArgs)
    base_args['data_dirs'] = [data_dir]

    # Use a context manager so the metadata file handle is closed promptly
    # (the previous bare ``open`` leaked it). JSON is UTF-8 by spec, so don't
    # rely on the platform's locale encoding.
    with open(data_dir / "metadata.json", encoding="utf-8") as metadata_file:
        dataset_metadata = json.load(metadata_file)
    # Sort so the selected auxiliary labels are deterministic across runs.
    all_labels = sorted(dataset_metadata["all"]["labels"])

    base_args['aux_labels'] = all_labels[:num_labels]
    base_args['test_ood'] = True

    base_args['stages'] = [
        {"name": "baseline",  "ft_forget": False, "do_checkpoint": True},
        {"name": "filtering", "ft_forget": True},
        {"name": "coreftaux", "ft_forget": True, "alpha": 1.0, "beta": 0.5},
        # NOTE(review): two stages share the name "routed" and differ by
        # "arch" ("lora" vs "moe"); presumably run_parallel / the stage runner
        # tells them apart via "arch" — confirm results don't collide.
        {
            "name": "routed", "arch": "lora", "ordered": True, "ft_forget": True,
            "alpha": 1.0, "beta": 0.5, "lora_attn": True, "lora_mlp": True, "lora_rank": 32
        },
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": True,
            "aux_route_prc": 0.75, "robust_prc": 0.5, "expert_dist": "prc_one"
        },
    ]

    # Deep-copy per seed so runs never share mutable state (e.g. the stage
    # dicts or the data_dirs list).
    configs = []
    for seed in range(num_runs):
        run_config = deepcopy(base_args)
        run_config["seed"] = seed
        configs.append(run_config)

    run_parallel(
        configs=configs,
        res_root=res_root,
    )

if __name__ == "__main__":
    # Launch the experiment only when this file is executed directly,
    # not when it is imported as a module.
    run_experiment()