

from hierarchical_pous.config_setup import parse_args, init_distrbutive_env


def main_experiment(base_seed=42):
    """Parse CLI arguments, initialise the distributed runtime, and run training.

    Steps, in the order they must happen:

    1. parse the runtime arguments and configure directories via
       ``fixed_params.set_dirs``;
    2. initialise the (possibly multi-process) environment via
       ``init_distrbutive_env``;
    3. import the parameter getters and the training runner, and start a
       training session with them.

    Args:
        base_seed: Seed forwarded to ``run_training_session`` as its
            ``base_seed`` keyword. Defaults to 42, the value this script
            has always used, so existing invocations are unchanged.
    """
    args = parse_args()

    from fixed_params import set_dirs
    set_dirs(args)

    # This call HAS to happen before any other call to JAX; otherwise
    # jax.distributed.initialize will not initialize correctly and the code
    # will default to serial execution mode.
    # NOTE(review): `fixed_params` is already imported above (for set_dirs),
    # i.e. before this call, so it must not use JAX at import time — confirm.
    init_var = init_distrbutive_env(args)

    # `fixed_params` is already loaded, so this import only binds names.
    from fixed_params import (
        get_model_setup_params,
        get_problem_params,
        get_training_params,
    )
    # Deferred until after the distributed init — presumably because the
    # training runner (transitively) imports JAX; verify if this is moved.
    from hierarchical_pous.training_runner import run_training_session

    run_training_session(
        init_var,
        runtime_args=args,
        net_setup_params_getter=get_model_setup_params,
        training_params_getter=get_training_params,
        problem_params_getter=get_problem_params,
        base_seed=base_seed,
    )


if __name__ == "__main__":
    main_experiment()
