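'''
Experiment 07 (arbsub): launches the runs configured in `run_experiment`.

Usage (run from the directory that contains `src/`, since results are written
relative to the current working directory):
export OMP_NUM_THREADS=16 && torchrun --nproc_per_node=8 -m src.run.orchestrate.experiment.exp_07_arbsub
'''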
from src.run.main import run
from src.run.orchestrate.config import RealisticBaseArgs, calc_realistic_model_params
from src.run.utils import get_timestamp

from pathlib import Path
from copy import deepcopy

def run_experiment():
    """Run the exp-07 "arbsub" configuration once per seed.

    Starts from a deep copy of ``RealisticBaseArgs``, switches ``arbsub`` on,
    sets the auxiliary labels and a three-stage schedule (one LoRA stage
    followed by two MoE stages), and merges in whatever
    ``calc_realistic_model_params(700e6)`` returns. One config per seed in
    ``range(3)`` is then handed to ``run`` sequentially, each with its own
    timestamp and its own results directory beneath a shared, timestamped
    root. ``do_cleanup_distributed`` is True only for the final run.

    Returns:
        None. All effects happen through the calls to ``run``.
    """
    num_runs = 3

    shared_args = deepcopy(RealisticBaseArgs)
    shared_args['arbsub'] = True

    # Resolved against the current working directory, so the launch directory
    # determines where results are written.
    results_root = (
        Path("src").absolute()
        / f"results/realistic/07/combined_{get_timestamp()}"
    )

    # NOTE(review): all three stages share the name "routed" — presumably
    # intentional; confirm against how `run` consumes stage names.
    lora_stage = dict(
        name="routed", arch="lora", ordered=True, ft_forget=False,
        core_prc=0.9, aux_prc=0.5,
        lora_attn=True, lora_mlp=True,
        expert_dist="prc_one", equal_compute=True, train_arbsub=False,
    )

    # Leading fields common to both MoE stages; the per-stage fields are
    # appended afterwards so key order matches the literal definitions.
    moe_prefix = dict(
        name="routed", arch="moe", ordered=False, ft_forget=False,
        aux_route_prc=0.75,
    )
    # NOTE(review): this stage uses a single `robust_prc`, while the next one
    # uses separate `core_robust_prc` / `aux_robust_prc` — verify both
    # spellings are accepted downstream.
    moe_stage = {
        **moe_prefix,
        "robust_prc": 0.5,
        "expert_dist": "prc_one", "train_arbsub": False,
    }
    moe_arbsub_stage = {
        **moe_prefix,
        "core_robust_prc": 0.5, "aux_robust_prc": 1.0,
        "expert_dist": "prc_one", "train_arbsub": True,
    }

    shared_args['aux_labels'] = ["bigcode", "biology", "nuclear", "cyber"]
    shared_args['stages'] = [lora_stage, moe_stage, moe_arbsub_stage]

    shared_args.update(calc_realistic_model_params(700e6))

    def seeded(seed):
        # Each run gets an independent deep copy so nothing is shared
        # between the configs passed to successive runs.
        cfg = deepcopy(shared_args)
        cfg["seed"] = seed
        return cfg

    pending = [seeded(seed) for seed in range(num_runs)]

    final_idx = len(pending) - 1
    for idx, cfg in enumerate(pending):
        # Timestamp is taken right before each run starts, not up front.
        stamp = get_timestamp()
        cfg['log_level'] = "DEBUG"
        cfg['timestamp'] = stamp
        cfg['res_dir'] = results_root / f"results_{stamp}"
        # Only the last run is flagged to clean up the distributed state.
        cfg['do_cleanup_distributed'] = idx == final_idx
        run(**cfg)

# Script entry point: launch the experiment only when this file is executed
# directly (e.g. via `-m`), not when it is imported.
if __name__ == "__main__":
    run_experiment()