import os
import shlex


def get_filenames_in_a_folder(folder: str):
    """
    Return the paths of all entries directly inside ``folder``.

    Each path is built as ``<folder>/<entry name>``, with any trailing
    slashes on ``folder`` removed first so the result never contains a
    doubled separator. Entries are whatever ``os.listdir`` reports, so
    subdirectories are included alongside regular files. The result is
    sorted by entry name so repeated calls return the same order.

    Args:
        folder: path of the directory to list. Must be non-empty.

    Returns:
        A sorted list of path strings, one per directory entry.

    Raises:
        ValueError: if ``folder`` is an empty string.
        OSError: propagated from ``os.listdir`` (e.g. the folder is missing).
    """
    if not folder:
        raise ValueError("folder must be a non-empty path")

    # For the filesystem root ("/") stripping leaves "", so the prefix
    # becomes "/" again and entries come out as "/name".
    prefix = folder.rstrip("/") + "/"

    # List the folder as given (not the stripped form) so "/" still works.
    return [prefix + name for name in sorted(os.listdir(folder))]

# --- Cluster resource request -------------------------------------------
# Account the job is submitted under (alternative: "gts-XXXX-6").
PACE_ACCOUNT = "gts-XXXX-6-paid"
# Value passed to sbatch's -C option (alternative: "A100-80GB").
GPU = "H100"
NUM_GPU = 1
# Whole hours substituted into the sbatch -t option.
HOURS = 72
# Gigabytes substituted into the sbatch --mem option.
RAM_GB = 64

# Header shared by every generated job script: the shebang, the #SBATCH
# directives, a blank separator line, then the conda environment activation.
# Every line, including the last one, is terminated by a newline.
slurm_prefix = "\n".join(
    (
        "#!/bin/bash",
        f"#SBATCH -A {PACE_ACCOUNT}",
        "#SBATCH -q inferno",
        "#SBATCH -N 1",
        "#SBATCH --ntasks-per-node=1",
        "#SBATCH --job-name=topo-gpt",
        f"#SBATCH --gres=gpu:{NUM_GPU} -C {GPU}",
        f"#SBATCH -t {HOURS}:00:00",
        f"#SBATCH --mem={RAM_GB}G",
        "",
        "source /storage/home/hcoda1/4/XXXX-4/p-XXXX-6-0/miniconda3/bin/activate",
        "conda activate nesim",
    )
) + "\n"

# Sort the config paths: os.listdir() order is arbitrary, and the generated
# script names (0.sh, 1.sh, ...) are derived from each config's position in
# this list, so a stable order keeps the index-to-config mapping reproducible.
nesim_configs = sorted(get_filenames_in_a_folder(folder="./nesim_configs"))
# Folder the generated shell scripts are written to.
slurm_commands_folder = "./slurm"

"""
Reference:
XXXX
"""
# Training hyperparameters forwarded to train.py as command-line flags.
effective_batch_size = 512
checkpoint_every_n_steps = 100
num_warmup_steps = 3000
batch_size = 32
# Validate with an explicit raise rather than `assert`, which is stripped
# when Python runs with -O. The effective batch size must be a whole
# multiple of the per-step batch size for the integer division below to
# be exact.
if effective_batch_size % batch_size != 0:
    raise ValueError(
        f"Expected effective_batch_size ({effective_batch_size}) to be divisible by "
        f"batch_size ({batch_size}). This is required for gradient accumulation to work properly"
    )
context_length = 128
# Number of batches of `batch_size` that add up to `effective_batch_size`.
gradient_accumulation_steps = effective_batch_size // batch_size
apply_nesim_every_n_steps = 1
dataset_name = "openwebtext"
num_train_epochs = 5
learning_rate = 1e-4

# Build one `python3 train.py ...` command line per config path. All runs
# share the same hyperparameters; only --nesim-config differs.
all_commands = []
for nesim_config in nesim_configs:
    # String-valued arguments are shell-quoted because the command is later
    # written into a bash script; quoting leaves plain paths/names unchanged
    # and protects ones containing spaces or shell metacharacters.
    command_parts = [
        "python3 train.py",
        f"--num-train-epochs {num_train_epochs}",
        f"--learning-rate {learning_rate}",
        f"--nesim-config {shlex.quote(nesim_config)}",
        f"--checkpoint-every-n-steps {checkpoint_every_n_steps}",
        f"--num-warmup-steps {num_warmup_steps}",
        f"--batch-size {batch_size}",
        f"--context-length {context_length}",
        f"--gradient-accumulation-steps {gradient_accumulation_steps}",
        f"--dataset-name {shlex.quote(dataset_name)}",
        f"--apply-nesim-every-n-steps {apply_nesim_every_n_steps}",
    ]
    command = " ".join(command_parts)
    all_commands.append(command)

slurm_shell_files = []

# Create the output folder if needed so the writes below cannot fail on a
# fresh checkout.
os.makedirs(slurm_commands_folder, exist_ok=True)

# Remove scripts left over from a previous run. This is done in-process
# rather than by shelling out to `rm`, so an empty folder produces no error
# output and the folder path is never interpreted by a shell. Like the shell
# glob `*.sh`, names starting with "." are left alone.
for stale_name in os.listdir(slurm_commands_folder):
    if stale_name.startswith(".") or not stale_name.endswith(".sh"):
        continue
    stale_path = os.path.join(slurm_commands_folder, stale_name)
    # Only unlink files and symlinks; a directory named "*.sh" is skipped.
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)

# Write one script per command: the shared header followed by the command.
# Scripts are named after the command's index in `all_commands`.
for command_index, command in enumerate(all_commands):
    # End the script with a newline so the last line is properly terminated.
    shell_commands = slurm_prefix + command + "\n"
    filename = os.path.join(slurm_commands_folder, f"{command_index}.sh")
    with open(filename, "w", encoding="utf-8") as file:
        file.write(shell_commands)
    slurm_shell_files.append(filename)

print(f"Saved {len(slurm_shell_files)} shell files")
