import os
import subprocess
import time

import numpy as np

# Submit one `sbatch` job per (algorithm, environment, noise sampling
# frequency, seed) combination. Every job runs `train.py` with gSDE enabled
# (`use_sde:True`) and the given `sde_sample_freq`.

# Experiment grid.
ALGOS = ["custom_sac"]
ENVS = ["HalfCheetahBulletEnv-v0", "AntBulletEnv-v0", "HopperBulletEnv-v0", "Walker2DBulletEnv-v0"]
N_SEEDS = 10

# Evaluation settings forwarded to train.py.
EVAL_FREQ = 10000
N_EVAL_EPISODES = 20
N_EVAL_ENVS = 5

# Seed NumPy's global RNG so that the same list of training seeds is drawn
# every time this script is run.
np.random.seed(8)
# Full sweep (currently narrowed to a single value below):
# SAMPLE_FREQS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000]
SAMPLE_FREQS = [16]
SEEDS = np.random.randint(2 ** 20, size=(N_SEEDS,))
N_TIMESTEPS = int(1e6)
log_std_init = -3


os.makedirs(os.path.join("logs", "slurm"), exist_ok=True)

# Number of submissions for which sbatch returned a non-zero exit status.
n_failed = 0

for algo in ALGOS:
    for env_id in ENVS:
        for noise_sample_freq in SAMPLE_FREQS:
            # All seeds of a given sampling frequency share one log folder.
            log_folder = f"logs/paper/gsde_{noise_sample_freq}"
            for seed in SEEDS:
                args = [
                    "--algo",
                    algo,
                    "--env",
                    env_id,
                    "--hyperparams",
                    f"sde_sample_freq:{noise_sample_freq}",
                    "use_sde:True",
                    # The embedded double quotes are part of the command
                    # string handed to the batch script.
                    f'policy_kwargs:"dict(log_std_init={log_std_init}, net_arch=[400, 300])"',
                    "--eval-episodes",
                    N_EVAL_EPISODES,
                    "--eval-freq",
                    EVAL_FREQ,
                    "--n-eval-envs",
                    N_EVAL_ENVS,
                    "-f",
                    log_folder,
                    "--seed",
                    seed,
                    "--log-interval",
                    10,
                    "--num-threads",
                    2,
                    "-n",
                    N_TIMESTEPS,
                    "-uuid",
                ]
                # Mixed ints / NumPy ints / strings -> plain strings.
                args = [str(arg) for arg in args]

                # The whole training command is passed to the batch script as
                # a single positional argument.
                command = " ".join(["python", "-u", "train.py"] + args)

                return_code = subprocess.call(["sbatch", "cluster_torchy.sh", algo, env_id, "ablation", command])
                if return_code != 0:
                    # Best effort: keep submitting the remaining jobs, but make
                    # the failure visible instead of silently dropping a run.
                    n_failed += 1
                    print(f"sbatch exited with status {return_code} for: {command}")
                # Short pause between submissions
                # (presumably to avoid flooding the scheduler - TODO confirm).
                time.sleep(0.05)

if n_failed:
    print(f"{n_failed} job submission(s) failed")
