from src.run.main import run
from src.run.orchestrate.config import StoriesBaseArgs
from src.run.utils import get_timestamp

from copy import deepcopy
from pathlib import Path
import json

def run_experiment():
    """Configure and launch the 'stories' experiment with four stages.

    Starts from a deep copy of ``StoriesBaseArgs``, overrides epochs, the run
    timestamp, the results directory, the auxiliary labels, the stage list,
    logging level, distributed cleanup and seed, then forwards every entry
    to ``run`` as keyword arguments.
    """
    # Deep copy so the overrides below never mutate the shared default config
    # (including any nested lists/dicts it may contain).
    base_args = deepcopy(StoriesBaseArgs)

    # Take the timestamp exactly once so that ``timestamp`` and the results
    # directory name always agree. Calling get_timestamp() twice could yield
    # two different values if the clock ticks between the calls.
    timestamp = get_timestamp()

    # NOTE(review): "src" is resolved against the current working directory,
    # so this presumably must be launched from the repository root - confirm.
    root_dir = Path("src").absolute()
    base_args['epochs'] = 1
    base_args['timestamp'] = timestamp
    base_args['res_dir'] = root_dir / f"results/stories/10/results_{timestamp}"

    base_args['aux_labels'] = [
        'alien-encounters',
        'space-exploration',
        'dinosaurs',
        'gardens',
        'pirates',
        'sports',
        'superheroes',
        'time-travel',
        'virtual-worlds',
    ]

    # Stage dicts are interpreted by ``run``; the meaning of each key is
    # defined there, not in this file.
    base_args['stages'] = [
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": False,
            "aux_route_prc": 0.5, "robust_prc": 0.25,
            "expert_dist": "prc_one",  # core + one aux sum to baseline
            "aux_exp_prc": 0.05,
            "gen_samples": True,
        },
        {"name": "baseline", "ft_forget": False, "gen_samples": True},
        {"name": "maxent", "ft_forget": False, "me_alpha_retain": 30, "gen_samples": True},
        # A single target group holding every aux label, sorted so the tuple
        # is deterministic regardless of the list order above.
        {"name": "filtering", "ft_forget": False, "targets": [tuple(sorted(base_args['aux_labels']))], "gen_samples": True}
    ]

    base_args['log_level'] = "DEBUG"
    base_args['do_cleanup_distributed'] = True
    base_args['seed'] = 0

    run(**base_args)

if __name__ == "__main__":
    # Launch the experiment only when executed as a script, never on import.
    run_experiment()