import glob
import os
import shlex


def get_filenames_in_a_folder(folder: str):
    """
    Return the paths of all regular files directly inside ``folder``.

    Each entry is ``folder`` joined with the entry's name. For example,
    ``"./configs"`` or ``"./configs/"`` both yield ``"./configs/a.yaml"``.
    Subdirectories are skipped. The result is sorted by name so the order is
    reproducible, because ``os.listdir`` returns entries in arbitrary order.

    Args:
        folder: path of the directory to list; a trailing slash is allowed.

    Returns:
        A list of file paths, sorted by file name.

    Raises:
        FileNotFoundError / NotADirectoryError: propagated from ``os.listdir``
        when ``folder`` does not exist or is not a directory.
    """
    # os.path.join adds a separator only when one is missing. This also keeps
    # the root folder "/" working; stripping its trailing slash by hand would
    # leave an empty path.
    candidate_paths = (
        os.path.join(folder, entry) for entry in sorted(os.listdir(folder))
    )
    # Only keep regular files: a subdirectory is not a usable "file" path.
    return [path for path in candidate_paths if os.path.isfile(path)]


# Header prepended to every generated SLURM batch script. It contains the
# #SBATCH resource requests, followed by the environment setup that activates
# the "neuro" conda environment.
# The interpreter is bash, not /bin/sh. `source` is a bash builtin that POSIX
# sh implementations such as dash do not provide, so the script would fail
# there.
# NOTE(review): `filepath` and `gpu` are exported here, but nothing in this
# file reads them. Confirm whether train.py or the job command should use them.
slurm_prefix = """#!/bin/bash
#SBATCH --ntasks-per-node=16
#SBATCH --job-name=lora-finetuning
#SBATCH -t 8:00:0
#SBATCH --mem=32G
#SBATCH --gres=gpu:a100:1

source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate neuro
export filepath=/mindhive/nklab3/users/XXXX-1/repos/nesim/training/gpt_neo_125m/train.py
export gpu=0
"""

config_filenames = get_filenames_in_a_folder(folder="./configs")
slurm_commands_folder = "./slurm"

# Build one training invocation per config file. The config path is
# shell-quoted so names containing spaces or shell metacharacters survive
# intact; ordinary paths such as ./configs/a.yaml are left unchanged.
# NOTE(review): the SLURM prefix exports `filepath` pointing at a specific
# train.py, but this command runs whatever `train.py` resolves to relative to
# the job's working directory. Confirm which one is intended.
all_commands = [
    f"python3 train.py --config-filename {shlex.quote(config_filename)}"
    for config_filename in config_filenames
]

# Make sure the output folder exists, otherwise open(..., "w") below fails.
# Then delete stale *.sh scripts from a previous run, so the folder only
# holds the scripts generated now. This replaces `os.system("rm .../*.sh")`,
# which needs a shell and errors out when there is nothing to delete.
os.makedirs(slurm_commands_folder, exist_ok=True)
for stale_script in glob.glob(os.path.join(slurm_commands_folder, "*.sh")):
    if os.path.isfile(stale_script):
        os.remove(stale_script)

slurm_shell_files = []
for command_index, command in enumerate(all_commands):
    # End the script with a newline so it is a well-formed text file.
    shell_commands = slurm_prefix + command + "\n"
    filename = os.path.join(slurm_commands_folder, f"{command_index}.sh")
    with open(filename, "w") as file:
        file.write(shell_commands)
    slurm_shell_files.append(filename)

print(f"Saved {len(slurm_shell_files)} shell files")
