import os
import subprocess
import time

import numpy as np

# ---- Experiment grid --------------------------------------------------------
# Every algorithm listed here is launched on every environment in ENVS.
ALGOS = ["custom_sac"]
ENVS = [
    f"{robot}BulletEnv-v0"
    for robot in ("HalfCheetah", "Ant", "Hopper", "Walker2D")
]
# Quick debug run on the first environment only:
# ENVS = [ENVS[0]]

# Number of independent random seeds launched per configuration.
N_SEEDS = 10
# Quick debug run with a single seed:
# N_SEEDS = 1

# Values swept for the `param_noise_sample_freq` hyperparameter.
SAMPLE_FREQS = [1000]

# ---- Training / evaluation settings forwarded to train.py -------------------
N_TIMESTEPS = 1_000_000  # passed as `-n`
EVAL_FREQ = 10000  # passed as `--eval-freq`
N_EVAL_EPISODES = 20  # passed as `--eval-episodes`
N_EVAL_ENVS = 5  # passed as `--n-eval-envs`

# ---- Run seeds ---------------------------------------------------------------
# Seed NumPy's global RNG with a fixed value so that the same run seeds are
# drawn on every launch; each run seed lies in [0, 2**20).
np.random.seed(8)
SEEDS = np.random.randint(1 << 20, size=N_SEEDS)

# Create the logs/slurm directory up front (no error if it already exists).
os.makedirs(
    os.path.join("logs", "slurm"),
    exist_ok=True,
)

# Submit one Slurm job (via `sbatch`) for every combination of algorithm,
# environment, noise sample frequency and seed.
for algo in ALGOS:
    for env_id in ENVS:
        for noise_sample_freq in SAMPLE_FREQS:
            # The log folder is named after the configured frequency, i.e. the
            # value *before* the episodic remapping below.
            log_folder = f"logs/param_noise/param_noise_{noise_sample_freq}"
            # Episodic sampling: a configured frequency of 1000 is forwarded as -1.
            # NOTE(review): presumably train.py treats -1 as "resample once per
            # episode" -- confirm against the training script.
            param_noise_sample_freq = -1 if noise_sample_freq == 1000 else noise_sample_freq
            for seed in SEEDS:
                args = [
                    "--algo",
                    algo,
                    "--env",
                    env_id,
                    "--hyperparams",
                    f"param_noise_sample_freq:{param_noise_sample_freq}",
                    "use_sde:False",
                    "deterministic_exploration:True",
                    "use_param_noise:True",
                    # The inner single quotes are part of the command string and
                    # are left for whatever shell eventually runs it.
                    "policy_kwargs:'dict(net_arch=[400, 300], layer_norm=True)'",
                    "--eval-episodes",
                    N_EVAL_EPISODES,
                    "--eval-freq",
                    EVAL_FREQ,
                    "--n-eval-envs",
                    N_EVAL_ENVS,
                    "-f",
                    log_folder,
                    "--seed",
                    seed,
                    "--log-interval",
                    10,
                    "--num-threads",
                    2,
                    "-n",
                    N_TIMESTEPS,
                    "-uuid",
                ]
                # Ints and NumPy integers must become strings before joining.
                args = [str(arg) for arg in args]

                # The whole training command is handed to the cluster script as
                # one single argument.
                command = " ".join(["python", "-u", "train.py"] + args)

                return_code = subprocess.call(["sbatch", "cluster_torchy.sh", algo, env_id, "ablation", command])
                if return_code != 0:
                    # Do not lose failed submissions silently: report the run and
                    # keep going so the remaining jobs are still submitted.
                    print(
                        f"sbatch submission failed (exit code {return_code}) "
                        f"for algo={algo} env={env_id} seed={seed}"
                    )
                # Short pause between two consecutive submissions.
                time.sleep(0.05)
