import os

from generate import SbatchGenerator

# Generate sbatch scripts for the minimal DQC reproduction sweep.
# Writes one debug script (tiny smoke-test settings) and one full script,
# sbatch/<run_group>_debug.sh and sbatch/<run_group>.sh respectively.

run_group = "reproduce-minimum"
dataset_root = "[my_data_root]"  # TODO: fill in the root directory of your dataset

num_jobs_per_gpu = 1
gpu_limit = 16

domains = [
    "humanoidmaze-giant-navigate-oraclerep-v0",
    "puzzle-4x5-play-oraclerep-v0",
]
seeds = [1]

# Dataset sizes to sweep, keyed by a substring of the env name.
# `None` means no `dataset_dir` override is passed for that run.
dataset_sizes = {
    "cube-triple-play-oraclerep-v0": ["100m"],
    "cube-quadruple-play-oraclerep-v0": ["100m"],
    "cube-octuple-play-oraclerep-v0": ["1b"],
    "humanoidmaze-giant-navigate-oraclerep-v0": [None],
    "puzzle-4x5": [None],
    "puzzle-4x6": ["1b"],
}

# Dataset directory stem, keyed by env-name substring.
# The final path is f"{dataset_root}/{stem}-{size}-v0".
dataset_stems = {
    "puzzle-4x5": "puzzle-4x5-play",
    "puzzle-4x6": "puzzle-4x6-play",
    "cube-quadruple": "cube-quadruple-play",
    "cube-triple": "cube-triple-play",
    "cube-octuple": "cube-octuple-play",
}

# Per-environment-family hyperparameters, keyed by env-name substring.
# Checked in insertion order; the first matching family wins (as the
# original if/elif chain did).
env_hparams = {
    "humanoid": {"q_agg": "mean", "distill_tau": 0.5, "backup_horizon": 25, "policy_chunk_size": 1},
    "cube": {"q_agg": "min", "distill_tau": 0.8, "backup_horizon": 25, "policy_chunk_size": 5},
    "puzzle": {"q_agg": "mean", "distill_tau": 0.5, "backup_horizon": 25, "policy_chunk_size": 1},
}


def _match(table, domain, what):
    """Return the value of the first key in *table* that is a substring of *domain*.

    Raises:
        ValueError: if no key matches. Without this, an unknown domain would
            silently inherit settings from the previous loop iteration.
    """
    for key, value in table.items():
        if key in domain:
            return value
    raise ValueError(f"no {what} configured for domain {domain!r}")


def _build_run_kwargs(seed, domain, size, debug):
    """Return the command-line kwargs for a single DQC run.

    Keys are inserted in the same order as the original script. In case
    SbatchGenerator renders them in dict order, the generated commands stay
    unchanged (NOTE(review): ordering dependence not verified).
    """
    hp = _match(env_hparams, domain, "hyperparameters")
    backup_horizon = hp["backup_horizon"]
    policy_chunk_size = hp["policy_chunk_size"]

    kwargs = {
        "seed": seed,
        "agent": "agents/dqc.py",
        "agent.num_qs": 2,
        "agent.policy_chunk_size": policy_chunk_size,
        "agent.backup_horizon": backup_horizon,
        "agent.distill_method": "expectile",
        "agent.distill_tau": hp["distill_tau"],
        "env_name": domain,
        "agent.q_agg": hp["q_agg"],
    }
    if size is not None:
        # A requested size with no known dataset stem is a config error, not a
        # reason to silently fall back to the default dataset.
        stem = _match(dataset_stems, domain, "dataset directory")
        kwargs["dataset_dir"] = f"{dataset_root}/{stem}-{size}-v0"

    if debug:
        kwargs["agent.batch_size"] = 8

    kwargs.update({
        "agent.implicit_backup_type": "quantile",
        "agent.backup_tau": 0.9,
        "tags": f'"DQC,h={backup_horizon},ha={policy_chunk_size}"',
    })
    return kwargs


for debug in [True, False]:
    gen = SbatchGenerator(
        j=num_jobs_per_gpu,
        limit=gpu_limit,
        prefix=("MUJOCO_GL=egl", "python main.py"),
        comment=run_group,
    )
    if debug:
        # Short smoke-test settings: few steps, no eval/video episodes.
        # NOTE(review): unlike the full run, "online_steps": 0 is not set here — confirm intended.
        gen.add_common_prefix({"run_group": run_group + "_debug", "offline_steps": 100, "eval_episodes": 0,
                               "video_episodes": 0, "eval_interval": 20, "log_interval": 10, "dataset_replace_interval": 10})
    else:
        gen.add_common_prefix({"run_group": run_group, "offline_steps": 1000000, "eval_interval": 250000, "online_steps": 0})

    for seed in seeds:
        for domain in domains:
            for size in _match(dataset_sizes, domain, "dataset sizes"):
                gen.add_run(_build_run_kwargs(seed, domain, size, debug))

    # Render before opening the file so a generation error doesn't leave an empty script behind.
    sbatch_str = gen.generate_str()
    os.makedirs("sbatch", exist_ok=True)  # open(..., "w") fails if the directory is missing
    suffix = "_debug" if debug else ""
    with open(f"sbatch/{run_group}{suffix}.sh", "w") as f:
        f.write(sbatch_str)
