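'''
Experiment 03: model-size scaling sweep.

For each model size in `model_sizes` and each seed in range(NUM_RUNS), runs the
"baseline", "routed" (MoE) and "filtering" stages. The 2B model runs seed 0 only.

Usage:
    export OMP_NUM_THREADS=16 && torchrun --nproc_per_node=8 -m src.run.orchestrate.experiment.exp_03_scaling
'''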

from src.run.orchestrate.config import RealisticBaseArgs, calc_realistic_model_params
from src.run.utils import get_timestamp
from src.run.main import run

from pathlib import Path
from copy import deepcopy
import gc
import torch

if __name__ == '__main__':

    # Number of seeds per model size (models >= 2B only run seed 0, see below).
    NUM_RUNS = 3
    # Deep-copy so the edits below don't mutate the imported RealisticBaseArgs.
    base_args = deepcopy(RealisticBaseArgs)
    configs = []

    root_dir = Path("src").absolute()
    # Every run of this sweep is written under one shared, timestamped directory.
    res_root = root_dir / f"results/realistic/03/combined_{get_timestamp()}"
    
    base_args['aux_labels'] = ["bigcode", "biology", "nuclear", "cyber"]

    # Stage definitions passed through to `run`; every stage sets do_checkpoint.
    base_args['stages'] = [
        {
            "name": "baseline", "ft_forget": True, "do_checkpoint": True,
        },
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": True,
            "aux_route_prc": 0.75, "robust_prc": 0.5, "expert_dist": "prc_one",
            "do_checkpoint": True,
        },
        {
            "name": "filtering", "ft_forget": True,
            "targets": [("bigcode", "cyber", "nuclear")],
            "do_checkpoint": True,
        },
    ]

    # Per-size settings, read in parallel: index i describes one model size.
    model_sizes = [50e6, 96.7e6, 187.6e6, 363.9e6, 700e6, 2e9]
    batch_sizes = [64, 64, 32, 32, 16, 16]
    acc_steps = [1, 1, 1, 1, 1, 1]
    # zip() silently truncates to the shortest list, which would drop model
    # sizes without warning if these lists were edited out of sync.
    assert len(model_sizes) == len(batch_sizes) == len(acc_steps), \
        "model_sizes, batch_sizes and acc_steps must have the same length"

    for seed in range(NUM_RUNS):
        for model_size, batch_size, acc_step in zip(model_sizes, batch_sizes, acc_steps):

            if model_size >= 2e9 and seed != 0:
                continue  # only run a single seed for >=2B model

            run_config = deepcopy(base_args)
            run_config["seed"] = seed
            run_config["batch_size"] = batch_size
            run_config["accumulation_steps"] = acc_step
            # Architecture params derived from the target parameter count.
            model_params = calc_realistic_model_params(model_size)
            run_config.update(model_params)
            configs.append(run_config)

    for i, config in enumerate(configs):
        timestamp = get_timestamp()
        config['log_level'] = "DEBUG"
        config['timestamp'] = timestamp
        config['res_dir'] = res_root / f"results_{timestamp}"
        # NOTE(review): presumably tears down the distributed process group;
        # only requested on the last run so earlier runs can reuse it.
        # Confirm against run().
        config['do_cleanup_distributed'] = (i == len(configs) - 1)
        run(**config)

        # Release memory between runs. Collect unreachable Python objects first
        # so the CUDA tensors they hold are freed, then return the now-unused
        # cached blocks from PyTorch's caching allocator to the device.
        gc.collect()
        torch.cuda.empty_cache()