import os
import shlex
import shutil


def get_filenames_in_a_folder(folder: str) -> list:
    """
    Return the paths of all entries directly inside ``folder``.

    Each path is ``folder`` (with any trailing slashes removed) followed by
    "/" and the entry name. Entries are whatever ``os.listdir`` reports, so
    subdirectories are included alongside regular files; nothing is filtered
    and the listing is not recursive.

    Args:
        folder: Path of the directory to list, with or without a trailing "/".

    Returns:
        A list of path strings, sorted by entry name so the result is the
        same on every run (``os.listdir`` returns entries in arbitrary order).

    Raises:
        ValueError: If ``folder`` is an empty string.
        OSError: Propagated from ``os.listdir``, e.g. when the folder is missing.
    """
    if not folder:
        raise ValueError("folder must be a non-empty path")

    # For the root directory "/", rstrip leaves "" and the prefix becomes "/",
    # so the root can be listed too. The untouched `folder` is what gets listed.
    prefix = folder.rstrip("/") + "/"
    return [prefix + name for name in sorted(os.listdir(folder))]

# One path per entry found in ./nesim_configs; a command is generated for each.
nesim_configs = get_filenames_in_a_folder(folder="./nesim_configs")
commands_folder = "./train_commands"

# Start from an empty output folder so scripts from earlier runs do not linger.
# This uses the standard library instead of shelling out to rm/mkdir, so a
# failure to create the folder raises instead of being silently ignored.
shutil.rmtree(commands_folder, ignore_errors=True)  # a missing folder is fine
os.makedirs(commands_folder, exist_ok=True)

"""
Reference:
XXXX
"""
# Values passed to train.py as command-line flags by the command-building
# code further down in this script.
effective_batch_size = 512
checkpoint_every_n_steps = 1
num_warmup_steps = 3000
batch_size = 8
# Raised explicitly rather than asserted: assert statements are removed under
# `python -O`, and the integer division below would then silently round down.
if effective_batch_size % batch_size != 0:
    raise ValueError(
        f"Expected effective_batch_size ({effective_batch_size}) to be divisible by batch size ({batch_size}). "
        "This is required for gradient accumulation to work properly"
    )
context_length = 128
# Number of batches of `batch_size` that add up to `effective_batch_size`.
gradient_accumulation_steps = effective_batch_size // batch_size
apply_nesim_every_n_steps = 1
dataset_name = "openwebtext"
num_train_epochs = 5
learning_rate = 1e-4


# Build one `accelerate launch train.py ...` command line per config path.
# Every flag except --nesim-config has the same value for all runs.
all_commands = []
for nesim_config in nesim_configs:
    argv = [
        "accelerate", "launch", "train.py",
        "--num-train-epochs", num_train_epochs,
        "--learning-rate", learning_rate,
        "--nesim-config", nesim_config,
        "--checkpoint-every-n-steps", checkpoint_every_n_steps,
        "--num-warmup-steps", num_warmup_steps,
        "--batch-size", batch_size,
        "--context-length", context_length,
        "--gradient-accumulation-steps", gradient_accumulation_steps,
        "--dataset-name", dataset_name,
        "--apply-nesim-every-n-steps", apply_nesim_every_n_steps,
    ]
    # The config path comes from a directory listing, so it may contain spaces
    # or quote characters. shlex.quote escapes each argument only when needed,
    # which keeps the generated shell line valid for any file name.
    all_commands.append(" ".join(shlex.quote(str(arg)) for arg in argv))

# Write each command to <commands_folder>/<config name without ".json">.sh.
# No removal of old *.sh files is needed here: commands_folder is deleted and
# recreated empty earlier in this script, so an `rm <folder>/*.sh` at this
# point matches nothing and only prints an error.
shell_files = []
json_suffix = ".json"
for nesim_config, command in zip(nesim_configs, all_commands):
    run_name = os.path.basename(nesim_config)
    # Drop only a trailing ".json"; str.replace would also cut the text out of
    # the middle of a name such as "a.json.v2.json".
    if run_name.endswith(json_suffix):
        run_name = run_name[: -len(json_suffix)]

    filename = os.path.join(commands_folder, f"{run_name}.sh")
    with open(filename, "w", encoding="utf-8") as file:
        file.write(command)
    shell_files.append(filename)

print(f"Saved {len(shell_files)} shell files")