import os
import argparse

# Command-line interface: the only option is where the generated scripts go.
parser = argparse.ArgumentParser(
    description=(
        "Create batch scripts for training Brax environments "
        "with different hyperparameters."
    ),
)
parser.add_argument(
    "--output_dir",
    default="./evarl_continuous_scripts/",
    type=str,
    help="Directory to save the batch scripts.",
)

# Hyperparameter sweep axes. Each list is one axis; the script later takes the
# Cartesian product of all of them (wandb_project is a single fixed string).

# A2C training
ENV_NAME = ["ant", "halfcheetah", "reacher"]
SEED = [*range(5)]
NUM_ENVS = [2048]
NUM_STEPS = [10]
TOTAL_TIMESTEPS = [2e7]

# Data collection
driver_traj_len = [25]
general_traj_len = [1000]
num_drivertest_states = [5]

# Predictability transformer architecture / optimisation
batch_size = [256]
learning_rate = [1e-3]
num_epochs = [100]
num_heads = [4]
num_layers = [4]
hidden_dim = [16]

# Logging and predictability-regularisation settings
wandb_project = "evarl_continuous"
PREDICTABILITY_COEF = [0, 1e-3, 1e-2]
# How many times the predictability transformer is updated per A2C update.
NUM_PRED_UPDATES = [5]
PRED_LR = [0, 1e-5]
use_pretrained_transformer = [0, 1]

# A larger transformer (num_heads=4, num_layers=8, hidden_dim=128) was an
# alternative configuration kept here for reference.

# Slurm batch-script template, filled once per hyperparameter combination with
# str.format. Placeholders: experiment_name, wandb_project,
# PREDICTABILITY_COEF, PRED_LR, NUM_PRED_UPDATES, use_pretrained_transformer.
# Because str.format is used, any literal `{` or `}` added here must be doubled.
#
# NOTE: this is a regular (non-raw) triple-quoted string, so each trailing
# backslash-newline in the `python ...` command is a Python line-continuation
# escape and is removed. The generated script therefore has the whole command
# on a single line, with no trailing newline. That is still a valid bash line.
#
# NOTE(review): only the predictability settings are forwarded as CLI flags.
# ENV_NAME, SEED and the other swept values appear only inside
# experiment_name. Unless train_evarl_continuous.py derives them from the name
# (TODO confirm), different env/seed jobs may run identical configurations.
batch_template = """#!/bin/bash
#SBATCH -p gpu-preempt
#SBATCH -t 01:00:00
#SBATCH --gpus=1
#SBATCH --mem=32g
#SBATCH --constraint=vram23
#SBATCH --job-name="evarl_continuous_{experiment_name}"
#SBATCH --output="./runs/evarl_continuous/training-%A_%a.out"

## Load the Python environment
eval "$(conda shell.bash hook)"   # initialise conda
conda activate pred310

python train_evarl_continuous.py \
  --experiment_name {experiment_name} \
  --wandb_project {wandb_project} \
  --PREDICTABILITY_COEF {PREDICTABILITY_COEF} \
  --PRED_LR {PRED_LR} \
  --NUM_PRED_UPDATES {NUM_PRED_UPDATES} \
  --use_pretrained_transformer {use_pretrained_transformer} \
"""


# Cartesian product over every sweep axis. `combinations` is a one-shot
# iterator; its consumer reads each tuple positionally, so this axis order
# must stay in sync with that consumer.
from itertools import product

_sweep_axes = (
    ENV_NAME,                    # 0
    NUM_ENVS,                    # 1
    NUM_STEPS,                   # 2
    TOTAL_TIMESTEPS,             # 3
    SEED,                        # 4
    driver_traj_len,             # 5
    general_traj_len,            # 6
    num_drivertest_states,       # 7
    batch_size,                  # 8
    learning_rate,               # 9
    num_epochs,                  # 10
    num_heads,                   # 11
    num_layers,                  # 12
    hidden_dim,                  # 13
    PREDICTABILITY_COEF,         # 14
    PRED_LR,                     # 15
    use_pretrained_transformer,  # 16
    NUM_PRED_UPDATES,            # 17
)
combinations = product(*_sweep_axes)

def _empty_then_create(directory):
    """Delete the regular files directly inside *directory*, then ensure it exists.

    Subdirectories and their contents are left untouched.
    """
    if os.path.exists(directory):
        for entry in os.listdir(directory):
            candidate = os.path.join(directory, entry)
            if os.path.isfile(candidate):
                os.remove(candidate)
    os.makedirs(directory, exist_ok=True)


# Start from a clean output directory so scripts from an earlier sweep don't
# sit next to the newly generated ones.
output_dir = parser.parse_args().output_dir
_empty_then_create(output_dir)

# Shared Slurm log directory (the template's --output path points here).
# Files already in it are deleted too.
run_dir = "./runs/evarl_continuous"
_empty_then_create(run_dir)


# Emit one sbatch script per hyperparameter combination, then record every
# experiment name in experiment_names.txt (one per line, in script order).
#
# The experiment name must encode EVERY swept hyperparameter. Otherwise
# distinct jobs share a name (and thus a wandb run name / Slurm job name).
exp_names = []
seen_names = set()
print("Creating batch files for the following combinations:")
for i, combo in enumerate(combinations):
    print(combo)

    # Unpack in the same order the axes are passed to product(). Local names
    # deliberately differ from the module-level axis lists so those lists are
    # not overwritten.
    (
        env_name, n_envs, n_steps, total_ts, seed,
        drv_traj_len, gen_traj_len, n_drivertest_states,
        bs, lr, n_epochs, n_heads, n_layers, hid_dim,
        pred_coef, pred_lr, pretrained, n_pred_updates,
    ) = combo

    experiment_name = (
        f"{env_name}_{n_envs}envs_{n_steps}steps_"
        f"{total_ts}ts_{seed}seed_"
        f"{drv_traj_len}dtr_{gen_traj_len}gtr_"
        f"{n_drivertest_states}dt_"
        f"{bs}bs_{lr}lr_"
        f"{n_epochs}ep_{n_heads}h_"
        f"{n_layers}l_{hid_dim}hd_"
        f"{pred_coef}pc_{pred_lr}plr_"
        f"{pretrained}pt_{n_pred_updates}pu"
    )
    # Fail fast if the naming scheme ever stops covering a swept axis.
    if experiment_name in seen_names:
        raise ValueError(
            f"Duplicate experiment name {experiment_name!r}: every swept "
            "hyperparameter must appear in the experiment name."
        )
    seen_names.add(experiment_name)
    exp_names.append(experiment_name)

    batch_template_filled = batch_template.format(
        experiment_name=experiment_name,
        wandb_project=wandb_project,
        PREDICTABILITY_COEF=pred_coef,
        PRED_LR=pred_lr,
        use_pretrained_transformer=pretrained,
        NUM_PRED_UPDATES=n_pred_updates,
    )
    # Scripts are numbered by combination index; experiment_names.txt maps
    # line i (0-based) to sbatch_{i}.sh.
    filename = os.path.join(output_dir, f"sbatch_{i}.sh")
    with open(filename, "w", encoding="utf-8") as f:
        f.write(batch_template_filled)
    print(f"Created {filename}")

# Write the experiment names to a file, one per line.
with open(os.path.join(output_dir, "experiment_names.txt"), "w", encoding="utf-8") as f:
    f.writelines(f"{name}\n" for name in exp_names)