import os
import argparse

# Command-line interface: the single option selects where the generated
# sbatch scripts are written.
parser = argparse.ArgumentParser(
    description=(
        "Create batch scripts for training Brax environments "
        "with different hyperparameters."
    ),
)
parser.add_argument(
    "--output_dir",
    default="./ope_continuous_long_run_scripts/",
    type=str,
    help="Directory to save the batch scripts.",
)

# Text file listing the experiments to sweep over, one name per line.
# The same path is also forwarded to ope_continuous.py through
# --experiment_name_file in the batch template.
experiment_file_path = "./complete_continuous_long_run_scripts/experiment_names.txt"

with open(experiment_file_path, 'r') as names_file:
    experiment_names = names_file.read().splitlines()

# Hyperparameter grids. Every value is a list so it can be used as one axis of
# a sweep.
# NOTE(review): in the visible code only ENV_NAME, PREDICTABILITY_COEF, PRED_LR
# and use_pretrained_transformer influence the generated scripts; the other
# lists are only fed into the `combinations` product.

# Experiments are kept only if the text before the first "_" in their name is
# one of these environments.
ENV_NAME = ["ant", "halfcheetah", "reacher"]
SEED = [0, 1, 2, 3, 4]
NUM_ENVS = [2048]
NUM_STEPS = [10]
TOTAL_TIMESTEPS = [5e7]

# Data collection
driver_traj_len = [25]
general_traj_len = [1000]
num_drivertest_states = [5]

# Pred Transformer
batch_size = [256]
learning_rate = [1e-3]
num_epochs = [100]
num_heads = [4]
num_layers = [4]
hidden_dim = [16]
# Forwarded to ope_continuous.py as --PREDICTABILITY_COEF.
PREDICTABILITY_COEF = [0, 5e-4, 5e-3]
# Number of times the predictability transformer would be updated per a2c update
NUM_PRED_UPDATES = [5]
# Forwarded to ope_continuous.py as --PRED_LR.
PRED_LR = [0, 1e-5]
# Forwarded to ope_continuous.py as --use_pretrained_transformer.
use_pretrained_transformer = [0, 1]

# Batch script template.
# Filled with str.format(); the placeholders are experiment_name,
# experiment_file_path, PRED_LR, PREDICTABILITY_COEF and
# use_pretrained_transformer.
#
# The shell line continuations are written as "\\" so that a literal backslash
# reaches the generated script. A single backslash directly before the newline
# would be consumed by Python itself (string-literal line continuation), which
# collapses the command onto one line and leaves the script without a final
# newline. The last argument deliberately has no continuation so the command
# ends cleanly.
batch_template = """#!/bin/bash
#SBATCH -p superpod-a100
#SBATCH -t 01:00:00
#SBATCH --gpus=1
#SBATCH --mem=32g
#SBATCH --constraint=a100
#SBATCH --job-name="ope_continuous_longrun_{experiment_name}"
#SBATCH --output="./runs/ope_continuous_longrun/training-%A_%a.out"

## Load the Python environment
eval "$(conda shell.bash hook)"   # initialise conda
conda activate pred310

python ope_continuous.py \\
  --experiment_name {experiment_name} \\
  --experiment_name_file {experiment_file_path} \\
  --PRED_LR {PRED_LR} \\
  --PREDICTABILITY_COEF {PREDICTABILITY_COEF} \\
  --use_pretrained_transformer {use_pretrained_transformer}
"""

# Cartesian product over every hyperparameter list (one axis per list).
# NOTE(review): nothing in the visible part of this file iterates
# `combinations`; the script-generation loop sweeps only PREDICTABILITY_COEF,
# PRED_LR and use_pretrained_transformer. Confirm whether this is still needed.
from itertools import product
_hyperparameter_axes = (
    ENV_NAME,
    NUM_ENVS,
    NUM_STEPS,
    TOTAL_TIMESTEPS,
    SEED,
    driver_traj_len,
    general_traj_len,
    num_drivertest_states,
    batch_size,
    learning_rate,
    num_epochs,
    num_heads,
    num_layers,
    hidden_dim,
    PREDICTABILITY_COEF,
    PRED_LR,
    use_pretrained_transformer,
    NUM_PRED_UPDATES,
)
combinations = product(*_hyperparameter_axes)

def _clear_directory(directory):
    """Ensure *directory* exists and contains no regular files.

    If the directory already exists, every regular file directly inside it is
    deleted; subdirectories (and their contents) are left untouched. The
    directory is then created if it was missing.
    """
    if os.path.exists(directory):
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            if os.path.isfile(entry_path):
                os.remove(entry_path)
    os.makedirs(directory, exist_ok=True)


# Destination for the generated sbatch scripts (taken from --output_dir).
# Scripts left over from a previous invocation are removed first.
output_dir = parser.parse_args().output_dir
_clear_directory(output_dir)

# Directory referenced by the template's "#SBATCH --output" path.
# WARNING: any regular files already in it are deleted here as well.
run_dir = "./runs/ope_continuous_longrun"
_clear_directory(run_dir)

print("Creating batch files for the following combinations:")
file_counter = 0  # running index used to name the generated scripts
for experiment_name in experiment_names:
    # if "128bs" in experiment_name: # Skip the experiment if it has 128 batch size
    #     continue

    # The environment is the text before the first underscore in the name.
    env_name = experiment_name.split("_")[0]
    # Keep only experiments for a selected environment whose name does not
    # contain "3dt".
    if not (env_name in ENV_NAME and "3dt" not in experiment_name):
        continue

    # One script per (coefficient, learning rate, pretrained flag) triple;
    # the rightmost list varies fastest.
    sweep = product(PREDICTABILITY_COEF, PRED_LR, use_pretrained_transformer)
    for pred_coef, pred_lr, use_pretrained in sweep:
        # Fill the sbatch template for this experiment and setting.
        script_text = batch_template.format(
            experiment_name=experiment_name,
            experiment_file_path=experiment_file_path,
            PREDICTABILITY_COEF=pred_coef,
            PRED_LR=pred_lr,
            use_pretrained_transformer=use_pretrained,
        )

        # Scripts are numbered consecutively across all experiments.
        filename = os.path.join(output_dir, f"sbatch_{file_counter}.sh")
        file_counter += 1
        with open(filename, "w") as script_file:
            script_file.write(script_text)
        print(f"Created {filename}")
