'''
Launch command (the `-m src....` module path assumes the working directory contains `src/`):

export OMP_NUM_THREADS=16 && torchrun --nproc_per_node=8 -m src.run.orchestrate.experiment.exp_02_realistic_maxent
'''

from src.run.orchestrate.config import RealisticBaseArgs, calc_realistic_model_params
from src.run.utils import get_timestamp
from src.run.main import run

from pathlib import Path
from copy import deepcopy

if __name__ == '__main__':
    # Experiment driver: for every baseline checkpoint listed below, build one
    # run configuration whose "baseline" stage is loaded from that checkpoint
    # (no training, no evaluation) and is followed by a "maxent" stage, then
    # execute the runs one after another via run().
    # The number of runs is len(checkpoint_paths); there is no separate
    # run-count setting.

    # Work on a private copy so the shared RealisticBaseArgs is never mutated.
    base_args = deepcopy(RealisticBaseArgs)
    configs = []

    # Every run of this launch writes below one timestamped "combined_*" folder.
    # NOTE(review): when launched with torchrun, each process evaluates
    # get_timestamp() on its own; if the value can differ between ranks they
    # would disagree on res_root — confirm how run() deals with this.
    root_dir = Path("src").absolute()
    res_root = root_dir / f"results/realistic/02/combined_{get_timestamp()}"

    base_args['aux_labels'] = ["bigcode", "biology", "nuclear", "cyber"]

    # Baseline checkpoints to start from. The paths are relative, so they are
    # resolved against the current working directory by whoever opens them.
    checkpoint_paths = [
        "src/results/realistic/02/combined_2026-01-11_05-40-00/results_2026-01-11_05-40-00/baseline/baseline_model.pth",
        "src/results/realistic/02/combined_2026-01-11_05-40-00/results_2026-01-13_15-48-37/baseline/baseline_model.pth",
        "src/results/realistic/02/combined_2026-01-11_05-40-00/results_2026-01-16_01-43-52/baseline/baseline_model.pth",
    ]

    # Model hyper-parameters for the 700e6 setting; applied after aux_labels,
    # so any key returned here takes precedence over what is already in
    # base_args.
    model_params = calc_realistic_model_params(700e6)
    base_args.update(model_params)

    # One configuration per checkpoint; the checkpoint's position in the list
    # doubles as the run's seed (0, 1, 2, ...).
    for seed, checkpoint_path in enumerate(checkpoint_paths):
        run_config = deepcopy(base_args)
        # Use stage-level checkpoint config: load baseline from checkpoint, skip train & eval
        run_config['stages'] = [
            {"name": "baseline", "checkpoint": checkpoint_path, "do_train": False, "do_eval": False, "ft_forget": False},
            {"name": "maxent", "ft_forget": True, "me_alpha_retain": 100, "me_steps": 400, "me_lr": 5e-5},
        ]
        run_config["seed"] = seed
        configs.append(run_config)

    # Execute the runs sequentially. The timestamp is taken right before each
    # run starts, so each run gets its own "results_<timestamp>" sub-folder.
    last_index = len(configs) - 1
    for i, config in enumerate(configs):
        timestamp = get_timestamp()
        config['timestamp'] = timestamp
        config['res_dir'] = res_root / f"results_{timestamp}"
        # Only cleanup distributed on the last run (reuse process group between runs)
        config['do_cleanup_distributed'] = (i == last_index)
        run(**config)
