import os
import argparse

# Command-line interface. The single option, --output_dir, selects the
# directory that receives the generated sbatch scripts.
# NOTE(review): the description mentions Brax, while ENV_NAME in this file
# lists MinAtar environments -- confirm the wording is still what is wanted.
_OUTPUT_DIR_OPTION = {
    "type": str,
    "default": "./ope_discrete_long_run_scripts/",
    "help": "Directory to save the batch scripts.",
}
parser = argparse.ArgumentParser(
    description=(
        "Create batch scripts for training Brax environments "
        "with different hyperparameters."
    ),
)
parser.add_argument("--output_dir", **_OUTPUT_DIR_OPTION)

# File listing the experiments to generate scripts for, one name per line.
# The same path is also handed to ope_discrete.py via --experiment_name_file.
experiment_file_path = "./complete_discrete_long_run_scripts/experiment_names.txt"

with open(experiment_file_path, "r", encoding="utf-8") as names_file:
    # Strip surrounding whitespace and drop blank lines, so a stray empty line
    # in the file does not produce job scripts whose --experiment_name
    # argument is empty.
    experiment_names = [
        name
        for name in (line.strip() for line in names_file.read().splitlines())
        if name
    ]

# Hyperparameter grids.
# Each name is bound to a list of candidate values. The generation loop at the
# bottom of this file substitutes only PREDICTABILITY_COEF, PRED_LR and
# use_pretrained_transformer into the batch template; the other lists are not
# written into any generated script.

# --- A2C ---
ENV_NAME = [
    "Freeway-MinAtar",
    "SpaceInvaders-MinAtar",
    "Asterix-MinAtar",
]
SEED = [0, 1, 2, 3, 4]
NUM_ENVS = [64]
NUM_STEPS = [100]
TOTAL_TIMESTEPS = [10_000_000.0]

# --- Data collection ---
driver_traj_len = [10]
general_traj_len = [200]
num_drivertest_states = [2, 5, 10]

# --- Pred Transformer ---
batch_size = [256]
learning_rate = [0.001]
num_epochs = [100]
num_heads = [4]
num_layers = [4]
hidden_dim = [16]
PREDICTABILITY_COEF = [0, 0.01, 0.1]
# How many times the predictability transformer is updated per A2C update.
NUM_PRED_UPDATES = [5]
PRED_LR = [0, 1e-5]
use_pretrained_transformer = [0, 1]
# Alternative transformer size kept for reference (not active):
#   num_heads = [4], num_layers = [8], hidden_dim = [128]

# Batch script template.
# Filled in with str.format(); the placeholders are {experiment_name},
# {experiment_file_path}, {PRED_LR}, {PREDICTABILITY_COEF} and
# {use_pretrained_transformer}.
#
# NOTE: this is a regular (non-raw) string literal, so every backslash that is
# immediately followed by a newline below is consumed by Python as a line
# continuation. The generated script therefore carries the whole
# `python ope_discrete.py ...` command on a single line, ending in a space
# with no trailing newline.
batch_template = """#!/bin/bash
#SBATCH -p gpu-preempt
#SBATCH -t 01:00:00
#SBATCH --gpus=1
#SBATCH --mem=32g
#SBATCH --constraint=vram40
#SBATCH --job-name="ope_discrete_{experiment_name}"
#SBATCH --output="./runs/ope_discrete_long_run/training-%A_%a.out"

## Load the Python environment
eval "$(conda shell.bash hook)"   # initialise conda
conda activate pred310

python ope_discrete.py \
  --experiment_name {experiment_name} \
  --experiment_name_file {experiment_file_path} \
  --PRED_LR {PRED_LR} \
  --PREDICTABILITY_COEF {PREDICTABILITY_COEF} \
  --use_pretrained_transformer {use_pretrained_transformer} \
"""


# Create combinations using the list of hyperparameters
from itertools import product

def _reset_directory(directory):
    """Remove the regular files directly inside `directory`, then ensure it exists.

    Sub-directories (and anything inside them) are left untouched. If
    `directory` does not exist yet it is simply created.
    """
    if os.path.exists(directory):
        candidates = [
            os.path.join(directory, entry) for entry in os.listdir(directory)
        ]
        for candidate in candidates:
            if os.path.isfile(candidate):
                os.remove(candidate)
    os.makedirs(directory, exist_ok=True)


# Destination of the generated sbatch scripts (taken from --output_dir).
# Files already present there are deleted on every run of this script.
output_dir = parser.parse_args().output_dir
_reset_directory(output_dir)

# Directory named in the template's `#SBATCH --output` line; files already
# present there are deleted on every run as well.
run_dir = "./runs/ope_discrete_long_run"
_reset_directory(run_dir)


print("Creating batch files for the following combinations:")

# One script per (experiment, predictability coefficient, predictor learning
# rate, pretrained flag) tuple. product() varies the right-most list fastest,
# which matches the order of the equivalent nested for-loops.
script_grid = product(
    experiment_names,
    PREDICTABILITY_COEF,
    PRED_LR,
    use_pretrained_transformer,
)

# Running index used to number the output files: sbatch_0.sh, sbatch_1.sh, ...
file_counter = 0
for experiment_name, pred_coef, pred_lr, use_pretrained in script_grid:
    script_text = batch_template.format(
        experiment_name=experiment_name,
        experiment_file_path=experiment_file_path,
        PREDICTABILITY_COEF=pred_coef,
        PRED_LR=pred_lr,
        use_pretrained_transformer=use_pretrained,
    )

    filename = os.path.join(output_dir, f"sbatch_{file_counter}.sh")
    file_counter += 1

    # Write the filled-in template to its own script file.
    with open(filename, "w") as script_file:
        script_file.write(script_text)
    print(f"Created {filename}")

    