import ray


@ray.remote(num_gpus=1)
def ray_run_exp_one(config):
    """Ray task that runs a single experiment ``config``.

    Scheduled with ``num_gpus=1``, so Ray reserves one GPU per task. The
    package-level runner is imported inside the task body, i.e. when the
    task executes, not when this module is imported.
    """
    from . import run_exp_one as _run_single

    _run_single(config)


def _flatten_configs(configs):
    """Return the dict configs in ``configs``, flattening nested lists/tuples in order.

    Raises:
        TypeError: if an item is neither a dict nor a list/tuple of configs.
    """
    flat = []
    for item in configs:
        if isinstance(item, dict):
            flat.append(item)
        elif isinstance(item, (list, tuple)):
            flat.extend(_flatten_configs(item))
        else:
            raise TypeError(
                "experiment config must be a dict or a list/tuple of configs, "
                f"got {type(item).__name__}: {item!r}"
            )
    return flat


def ray_run_all(*configs):
    """Run every experiment config as a Ray task and wait for all of them.

    Args:
        *configs: dict configs, or (arbitrarily nested) lists/tuples of dict
            configs. Tasks are submitted in the order the configs appear.

    Returns:
        The list of task results in submission order. If no configs were
        given, this is an empty list and Ray is not initialized.

    Raises:
        TypeError: if any item is neither a dict nor a list/tuple. This is
            raised before Ray is started, so no task is scheduled.
    """
    # Validate and flatten up front. Unsupported items raise instead of being
    # skipped, so an experiment can never silently go unrun.
    config_list = _flatten_configs(configs)
    if not config_list:
        # Nothing to do: don't spin up a Ray cluster for zero tasks.
        return []

    # ray.init() raises if Ray is already initialized (e.g. a second call in
    # the same process), so reuse an existing session instead.
    if not ray.is_initialized():
        ray.init()

    obj_refs = [ray_run_exp_one.remote(config) for config in config_list]
    # Block until every submitted task has finished.
    return ray.get(obj_refs, timeout=None)



def run_all_exps():
    """Launch the currently selected experiments in parallel via ``ray_run_all``.

    Edit the ``experiments`` tuple below to choose what runs. To run one
    experiment sequentially instead, call ``run_exp(config.<name>)``. The
    commented-out sequential candidates were ``debug_exp1`` .. ``debug_exp13``,
    ``main_exp``, and ``ablation_exp1`` .. ``ablation_exp3``.
    """
    from . import config, run_exp  # noqa: F401 -- run_exp is for sequential runs

    experiments = (
        config.ablation_exp3,
        config.ablation_exp2,
        config.ablation_exp2_a1,
        config.main_exp,
        # config.ablation_exp1,
    )
    ray_run_all(*experiments)


if __name__ == "__main__":
    # run_all_exps() relies on package-relative imports (`from . import ...`),
    # so this module must be executed as part of its package, e.g.
    # `python -m <package>.<module>`, rather than as a plain script path.
    run_all_exps()

