from src.run.orchestrate.parallel import run_parallel
from src.run.orchestrate.config import StoriesBaseArgs
from src.run.utils import get_timestamp
# from src.run.main import run

from copy import deepcopy
from pathlib import Path
import json

def run_experiment():
    """Build one single-stage config per training method and run them in parallel.

    Every config starts as a deep copy of ``StoriesBaseArgs`` with:

    * ``aux_labels`` set to the first ``num_labels`` labels (alphabetical
      order) listed under ``metadata["all"]["labels"]`` in the first data
      directory's ``metadata.json``;
    * ``seed`` fixed to 0;
    * ``stages`` replaced by a one-element list holding a single entry of
      the ``stages`` table below.

    All configs are handed to ``run_parallel`` together with one shared,
    timestamped results directory.
    """
    base_args = deepcopy(StoriesBaseArgs)

    # NOTE(review): resolved against the current working directory, so the
    # script presumably has to be launched from the repository root — confirm.
    root_dir = Path("src").absolute()
    # NOTE(review): the "_DEBUG" suffix looks like a leftover debug tag; drop it
    # for real runs if that is the intent.
    res_root = root_dir / f"results/stories/01/combined_{get_timestamp()}_DEBUG"

    # Read the label inventory from the first data directory. The context
    # manager guarantees the file handle is closed (the previous
    # ``json.load(open(...))`` leaked it until garbage collection).
    metadata_path = base_args["data_dirs"][0] / "metadata.json"
    with open(metadata_path, encoding="utf-8") as metadata_file:
        dataset_metadata = json.load(metadata_file)

    # Sort so the chosen auxiliary labels are deterministic across runs.
    all_labels = sorted(dataset_metadata["all"]["labels"])
    num_labels = 4
    base_args["aux_labels"] = all_labels[:num_labels]
    base_args["seed"] = 0
    ft_forget = False

    # TODO: rerun with compute-equivalent LoRA.
    # TODO: rerun with lora and coreftaux parameterized via core_prc and aux_prc.

    stages = [
        {"name": "baseline", "ft_forget": ft_forget, "do_checkpoint": True},
        {"name": "maxent", "ft_forget": ft_forget, "me_alpha_retain": 15},
        {"name": "maxent", "ft_forget": ft_forget, "me_alpha_retain": 30},
        {"name": "coreftaux", "ft_forget": ft_forget, "alpha": 1.0, "beta": 0.5},
        {
            "name": "routed", "arch": "lora", "ordered": True, "ft_forget": ft_forget,
            "lora_attn": True, "lora_mlp": True, "lora_rank": 32,
            "alpha": 1.0, "beta": 0.5,
            "equal_compute": True,
        },
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": ft_forget,
            "aux_route_prc": 0.75, "robust_prc": 0.5, "expert_dist": "prc_one",
        },
        {
            "name": "routed", "arch": "demix", "ft_forget": ft_forget,
            "expert_dist": "equal_one",
        },
    ]

    def _config_for(stage):
        # Fresh deep copy per run so no two configs share mutable state
        # (e.g. the aux_labels list); same container type as base_args.
        config = deepcopy(base_args)
        config["stages"] = [stage]
        return config

    configs = [_config_for(stage) for stage in stages]

    run_parallel(
        configs=configs,
        res_root=res_root,
    )

    # Sequential alternative, kept for reference. It needs the
    # `from src.run.main import run` import at the top of the file.
    # for i, config in enumerate(configs):
    #     timestamp = get_timestamp()
    #     config['timestamp'] = timestamp
    #     config['res_dir'] = res_root / f"results_{timestamp}"
    #     # Only cleanup distributed on the last run (reuse process group between runs)
    #     config['do_cleanup_distributed'] = (i == len(configs) - 1)
    #     run(**config)

if __name__ == "__main__":
    # Launch the experiment only when executed as a script; merely importing
    # this module starts no runs.
    run_experiment()