import os
import argparse

# Command-line interface: the only option is the directory that receives the
# generated batch scripts.
_OUTPUT_DIR_OPTION = {
    "type": str,
    "default": "./complete_discrete_long_run_scripts/",
    "help": "Directory to save the batch scripts.",
}

parser = argparse.ArgumentParser(
    description="Create batch scripts for training Brax environments with different hyperparameters."
)
parser.add_argument("--output_dir", **_OUTPUT_DIR_OPTION)


# Hyperparameter sweep axes.
# Every name below holds a list of candidate values; a single-element list
# pins that setting to one value.

# A2C
ENV_NAME = [
    "Freeway-MinAtar",
    "SpaceInvaders-MinAtar",
    "Asterix-MinAtar",
]
SEED = [0, 1, 2, 3, 4]
NUM_ENVS = [64]
NUM_STEPS = [100]
TOTAL_TIMESTEPS = [10_000_000.0]  # kept as a float, same value as 1e7

# Data collection
driver_traj_len = [10]
general_traj_len = [200]
num_drivertest_states = [1, 2, 5, 10]

# Pred Transformer
batch_size = [128, 256]
learning_rate = [0.001]
num_epochs = [100]
num_heads = [4]
num_layers = [4]
hidden_dim = [16]

# Not a sweep axis: a single project name shared by all runs.
wandb_project = "complete-discrete"

# Batch script template, filled in with str.format(); every {placeholder}
# must be supplied as a keyword argument.
#
# The shell line continuations are written as "\\" on purpose: in a normal
# (non-raw) Python string a single backslash followed by a newline is removed
# by Python itself, which would collapse the whole python command onto one
# line and leave the generated file without a trailing newline. Escaping the
# backslash keeps the command split one option per line in the generated
# script; the shell still sees exactly the same arguments.
#
# NOTE(review): the --output pattern uses the job-array placeholders %A/%a.
# If these scripts are not submitted as job arrays, %j may be what is
# intended -- confirm against how the scripts are submitted.
batch_template = """#!/bin/bash
#SBATCH -p gpu-preempt
#SBATCH -t 02:00:00
#SBATCH --gpus=1
#SBATCH --mem=32g
#SBATCH --constraint=vram40
#SBATCH --job-name="complete_discrete_{ENV_NAME}"
#SBATCH --output="./runs/complete_discrete_long_run/training-%A_%a.out"

## Load the Python environment
eval "$(conda shell.bash hook)"   # initialise conda
conda activate pred310

python generate_discrete_action_offline_data_evarl_pred_transformer.py \\
  --ENV_NAME {ENV_NAME} \\
  --NUM_ENVS {NUM_ENVS} \\
  --NUM_STEPS {NUM_STEPS} \\
  --TOTAL_TIMESTEPS {TOTAL_TIMESTEPS} \\
  --SEED {SEED} \\
  --driver_traj_len {driver_traj_len} \\
  --general_traj_len {general_traj_len} \\
  --num_drivertest_states {num_drivertest_states} \\
  --env_name {ENV_NAME} \\
  --batch_size {batch_size} \\
  --learning_rate {learning_rate} \\
  --num_epochs {num_epochs} \\
  --num_heads {num_heads} \\
  --num_layers {num_layers} \\
  --hidden_dim {hidden_dim} \\
  --experiment_name "{experiment_name}" \\
  --wandb_project {wandb_project}
"""


# Cartesian product over all sweep axes.
# The order of the axes below fixes the position of each value inside every
# tuple the product yields. `combinations` is a one-shot iterator: it can be
# walked exactly once.
from itertools import product

_sweep_axes = (
    ENV_NAME,               # position 0
    NUM_ENVS,               # position 1
    NUM_STEPS,              # position 2
    TOTAL_TIMESTEPS,        # position 3
    SEED,                   # position 4
    driver_traj_len,        # position 5
    general_traj_len,       # position 6
    num_drivertest_states,  # position 7
    batch_size,             # position 8
    learning_rate,          # position 9
    num_epochs,             # position 10
    num_heads,              # position 11
    num_layers,             # position 12
    hidden_dim,             # position 13
)
combinations = product(*_sweep_axes)

def _reset_directory(directory):
    """Remove the regular files directly inside *directory*, then ensure it exists.

    Only entries for which ``os.path.isfile`` is true are deleted;
    subdirectories and their contents are left in place. A missing directory
    is simply created.
    """
    if os.path.exists(directory):
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            if os.path.isfile(entry_path):
                os.remove(entry_path)
    os.makedirs(directory, exist_ok=True)


# Destination for the generated scripts, taken from the command line.
# Files left over from an earlier run are deleted first.
output_dir = parser.parse_args().output_dir
_reset_directory(output_dir)

# Common runs directory with a subdirectory for this script; files already
# in it are deleted as well.
run_dir = "./runs/complete_discrete_long_run"
_reset_directory(run_dir)

# Template placeholder for each position of a combination tuple, paired with
# the conversion applied to the value before it is substituted.
_TEMPLATE_FIELDS = (
    ("ENV_NAME", str),
    ("NUM_ENVS", int),
    ("NUM_STEPS", int),
    ("TOTAL_TIMESTEPS", int),
    ("SEED", int),
    ("driver_traj_len", int),
    ("general_traj_len", int),
    ("num_drivertest_states", int),
    ("batch_size", int),
    ("learning_rate", float),
    ("num_epochs", int),
    ("num_heads", int),
    ("num_layers", int),
    ("hidden_dim", int),
)


def _experiment_name(combo):
    """Return the experiment label for one hyperparameter tuple.

    Values are embedded exactly as they appear in the tuple (no conversion),
    and the environment name appears twice in the label.
    """
    (
        env,
        n_envs,
        n_steps,
        total_ts,
        seed,
        driver_len,
        general_len,
        n_drivertest,
        bs,
        lr,
        epochs,
        heads,
        layers,
        hidden,
    ) = combo
    pieces = [
        f"{env}",
        f"{n_envs}envs",
        f"{n_steps}steps",
        f"{total_ts}ts",
        f"{seed}seed",
        f"{driver_len}dtr",
        f"{general_len}gtr",
        f"{n_drivertest}dt",
        f"{env}",
        f"{bs}bs",
        f"{lr}lr",
        f"{epochs}ep",
        f"{heads}h",
        f"{layers}l",
        f"{hidden}hd",
    ]
    return "_".join(pieces)


exp_names = []
print("Creating batch files for the following combinations:")
for index, combo in enumerate(combinations):
    print(combo)

    experiment_name = _experiment_name(combo)
    exp_names.append(experiment_name)

    # Convert each tuple entry to the type its placeholder expects.
    template_values = {
        key: convert(value)
        for (key, convert), value in zip(_TEMPLATE_FIELDS, combo)
    }
    script_text = batch_template.format(
        experiment_name=experiment_name,
        wandb_project=wandb_project,
        **template_values,
    )

    # Scripts are named after the running index of the combination.
    filename = os.path.join(output_dir, f"sbatch_{index}.sh")
    with open(filename, "w") as script_file:
        script_file.write(script_text)
    print(f"Created {filename}")

# Record every experiment name, one per line, next to the scripts.
with open(os.path.join(output_dir, "experiment_names.txt"), "w") as names_file:
    names_file.writelines(name + "\n" for name in exp_names)
