from generate import SbatchGenerator

# Name of this experiment batch. It is passed to SbatchGenerator as `comment`,
# forwarded to every run as the "run_group" setting ("_debug" is appended for
# the debug variant), and used as the stem of the generated sbatch/*.sh files.
run_group = "reproduce"
# Root directory prepended to "/cube-quadruple-play-100m-v0/" to form the
# "ogbench_dataset_dir" setting of the cube-quadruple runs. The value below is
# a placeholder and must be replaced before the generated scripts are used.
data_root = "[]" # TODO: fill me in!

# Forwarded to SbatchGenerator as `j` and `limit` respectively.
# NOTE(review): presumably jobs packed per GPU and the cap on GPUs used at
# once -- confirm against generate.SbatchGenerator.
num_jobs_per_gpu = 1
gpu_limit = 50

def _select_by_env(env_name, table, what):
    """Return the value of the first table entry whose key occurs in env_name.

    Args:
        env_name: Environment name to classify.
        table: Ordered sequence of ``(substring, value)`` pairs. Order matters:
            the first pair whose substring is contained in ``env_name`` wins.
        what: Human-readable name of the setting, used in the error message.

    Raises:
        ValueError: If no substring in ``table`` occurs in ``env_name``. This
            replaces the silent reuse of a value left over from a previous
            loop iteration (or a bare NameError) when a new environment is
            added without tuned hyperparameters.
    """
    for pattern, value in table:
        if pattern in env_name:
            return value
    raise ValueError(f"no {what} configured for environment {env_name!r}")


# Environments and seeds swept by both the debug and the full script.
env_names = [
    "cube-double-play-singletask-v0",
    "cube-triple-play-singletask-v0",
    "cube-quadruple-play-singletask-v0",
    "antmaze-large-navigate-singletask-v0",
    "antmaze-giant-navigate-singletask-v0",
]
_seeds = [10001, 20002, 30003, 40004]

# Per-environment hyperparameters, keyed by a substring of the environment
# name and checked in the listed order.
_fbrac_alpha = [
    ("cube-double-play", 0.1),
    ("antmaze", 10.0),
    ("cube-triple-play", 0.01),
    ("cube-quadruple-play", 1.0),
]
# QAM values are (tau, clip_critic_grad). The humanoidmaze rows are not
# reachable with the current env_names; the other tables have no humanoidmaze
# rows, so adding such an environment now fails loudly until they are filled in.
_qam_tau_and_clip = [
    ("cube-double-play", (1.0, 10.0)),
    ("antmaze-large-navigate", (10.0, 100.0)),
    ("antmaze-giant-navigate", (3.0, 100.0)),
    ("cube-triple-play", (10.0, 10.0)),
    ("cube-quadruple-play", (3.0, 10.0)),
    ("humanoidmaze-medium-navigate", (10.0, 10.0)),
    ("humanoidmaze-large-navigate", (3.0, 10.0)),
]
_fql_alpha = [
    ("cube-double-play", 300.0),
    ("antmaze", 10.0),
    ("cube-triple-play", 100.0),
    ("cube-quadruple-play", 30.0),
]
_dsrl_noise_scale = [
    ("cube-double-play", 0.5),
    ("antmaze", 0.25),
    ("cube-triple-play", 1.0),
    ("cube-quadruple-play", 1.25),
]
_fedit_edit_scale = [
    ("cube-double-play", 0.25),
    ("antmaze", 0.25),
    ("cube-triple-play", 0.5),
    ("cube-quadruple-play", 0.5),
]

# Two scripts are written: a short "debug" variant first, then the full one.
for debug in [True, False]:
    gen = SbatchGenerator(j=num_jobs_per_gpu, limit=gpu_limit, prefix=("MUJOCO_GL=egl", "python main.py"), comment=run_group)
    if debug:
        # Tiny step counts and frequent eval/logging for the debug variant.
        gen.add_common_prefix(
            {"run_group": run_group + "_debug", "offline_steps": 100, "eval_episodes": 1, "eval_interval": 5,
             "start_training": 50, "online_steps": 100, "log_interval": 25})
    else:
        gen.add_common_prefix({"run_group": run_group, "offline_steps": 1000000, "online_steps": 1000000, "save_interval": 100000})

    for seed in _seeds:
        for env_name in env_names:
            # Any environment whose name contains "ant" uses horizon 1;
            # everything else uses horizon 5.
            horizon_length = 1 if "ant" in env_name else 5

            # Settings shared by every agent for this (seed, env) pair. Each
            # run dict below lists its agent-specific keys first and then
            # unpacks these, so key order matches across all runs.
            base_kwargs = {
                "seed": seed,
                "utd_ratio": 1,
                'agent.num_qs': 10,
                "env_name": env_name,
                "horizon_length": horizon_length,
                "agent.discount": 0.999 if "giant" in env_name else 0.99,
                "agent.action_chunking": True,
                "agent.actor_hidden_dims": '"(512, 512, 512, 512)"',
                "agent.value_hidden_dims": '"(512, 512, 512, 512)"',
                "agent.batch_size": 256,
                "agent.best_of_n": 1,
                "agent.rho": 0.5,
            }
            if "cube-quadruple" in env_name:
                # Only cube-quadruple gets an explicit dataset directory,
                # built from the module-level data_root.
                base_kwargs["ogbench_dataset_dir"] = f"{data_root}/cube-quadruple-play-100m-v0/"

            # FBRAC
            alpha = _select_by_env(env_name, _fbrac_alpha, "FBRAC alpha")
            gen.add_run({"agent": "agents/fbrac.py", "agent.alpha": alpha,
                         "tags": f'"FBRAC,alpha={alpha}"', **base_kwargs})

            # CGQL
            guidance_coef = 0.1
            gen.add_run({"agent": "agents/cgql.py", "agent.simple": True,
                         "agent.target_guidance": True, "agent.guidance_coef": guidance_coef,
                         "agent.noisy_coef": 0., "tags": f'"CGQL,gc={guidance_coef}"',
                         **base_kwargs})

            # adjoint matching: qam
            tau, clip_critic_grad = _select_by_env(
                env_name, _qam_tau_and_clip, "QAM tau/clip_critic_grad")
            gen.add_run({"agent": "agents/qam.py", "agent.clip_adj": True,
                         "agent.clip_critic_grad": clip_critic_grad,
                         "agent.target_actor": True,
                         "agent.tau": tau,
                         "tags": f'"QAM,ccg={clip_critic_grad},tau={tau}"', **base_kwargs})

            # FQL
            alpha = _select_by_env(env_name, _fql_alpha, "FQL alpha")
            gen.add_run({"agent": "agents/fql.py", "agent.alpha": alpha,
                         "tags": f'"FQL,alpha={alpha}"', **base_kwargs})

            # DSRL
            noise_scale = _select_by_env(env_name, _dsrl_noise_scale, "DSRL noise_scale")
            gen.add_run({"agent": "agents/dsrl.py", "agent.noise_scale": noise_scale,
                         "tags": f'"DSRL,ns={noise_scale}"', **base_kwargs})

            # FEdit
            edit_scale = _select_by_env(env_name, _fedit_edit_scale, "FEdit edit_scale")
            gen.add_run({"agent": "agents/fedit.py", "agent.edit_scale": edit_scale,
                         "tags": f'"FEDIT,edit_scale={edit_scale}"', **base_kwargs})

            # FAWAC: the same inverse temperature for every environment.
            tau = 10
            gen.add_run({"agent": "agents/fawac.py", "agent.inv_temp": tau,
                         "tags": f'"FAWAC,tau={tau}"', **base_kwargs})

            # IFQL: the same sample count for every environment.
            N = 32
            gen.add_run({"agent": "agents/ifql.py", "agent.num_samples": N,
                         "tags": f'"IFQL,N={N}"', **base_kwargs})

    # Write sbatch/<run_group>_debug.sh for the debug pass and
    # sbatch/<run_group>.sh for the full pass.
    sbatch_str = gen.generate_str()
    suffix = "_debug" if debug else ""
    with open(f"sbatch/{run_group}{suffix}.sh", "w") as f:
        f.write(sbatch_str)
