import os
import shlex
import shutil


def get_filenames_in_a_folder(folder: str) -> list:
    """
    returns the list of paths to all the entries in a given folder

    Each returned path is `folder` + "/" + entry name. Entries are whatever
    `os.listdir` reports, so sub-directories are included alongside regular
    files, and the order is the (arbitrary) order `os.listdir` yields.

    A single trailing "/" on `folder` is dropped so the result does not
    contain a doubled separator. The filesystem root "/" is kept as-is.

    Raises:
        OSError: propagated from `os.listdir` when `folder` cannot be listed
            (missing, not a directory, or an empty string).
    """

    # endswith() is safe on an empty string (folder[-1] is not), and the
    # length guard keeps "/" from being reduced to "".
    if len(folder) > 1 and folder.endswith("/"):
        folder = folder[:-1]

    # Only the root still ends with a separator at this point.
    prefix = folder if folder.endswith("/") else f"{folder}/"
    return [prefix + name for name in os.listdir(folder)]

# Every entry found in ./nesim_configs gets its own generated command.
nesim_configs = get_filenames_in_a_folder(folder="./nesim_configs")
commands_folder = "./train_commands"

# Start from an empty output folder on every run. Done through the standard
# library instead of `rm -rf` / `mkdir -p` so no shell is spawned, and a
# failure to create the folder raises here instead of being silently ignored.
shutil.rmtree(commands_folder, ignore_errors=True)
os.makedirs(commands_folder, exist_ok=True)

"""
Reference:
XXXX
"""
# Settings shared by every generated command are read from a single JSON file
# in the current working directory.
from nesim.utils.json_stuff import load_json_as_dict

training_config = load_json_as_dict(filename="training_config.json")

# Values read from the training config; each one is passed to train.py as a
# command-line flag (batch sizes also determine gradient accumulation below).
effective_batch_size = training_config["effective_batch_size"]
checkpoint_every_n_steps = training_config["checkpoint_every_n_steps"]
num_warmup_steps = training_config["num_warmup_steps"]
batch_size = training_config["batch_size"]
apply_nesim_every_n_steps = training_config["apply_nesim_every_n_steps"]
dataset_name = training_config["dataset_name"]
num_train_epochs = training_config["num_train_epochs"]
learning_rate = training_config["learning_rate"]
context_length = training_config["context_length"]

# Validate with an explicit raise: `assert` is stripped under `python -O`, and
# a zero batch size would otherwise surface as a bare ZeroDivisionError.
if batch_size <= 0 or effective_batch_size % batch_size != 0:
    raise ValueError(
        f"Expected effective_batch_size ({effective_batch_size}) to be divisible by "
        f"batch_size ({batch_size}). This is required for gradient accumulation to work properly"
    )

# Number of per-device batches accumulated to reach the effective batch size.
gradient_accumulation_steps = effective_batch_size // batch_size


# Build one `train.py` invocation per nesim config. Only --nesim-config differs
# between commands; every other flag carries the shared training settings.
#
# Values are shell-quoted because the commands end up in .sh files: a config
# path or dataset name containing spaces or shell metacharacters would
# otherwise be split into several arguments. Values made only of ordinary
# characters (letters, digits, "./-_" ...) are emitted unchanged.
all_commands = []
for nesim_config in nesim_configs:
    arguments = [
        ("--num-train-epochs", num_train_epochs),
        ("--learning-rate", learning_rate),
        ("--nesim-config", nesim_config),
        ("--checkpoint-every-n-steps", checkpoint_every_n_steps),
        ("--num-warmup-steps", num_warmup_steps),
        ("--batch-size", batch_size),
        ("--context-length", context_length),
        ("--gradient-accumulation-steps", gradient_accumulation_steps),
        ("--dataset-name", dataset_name),
        ("--apply-nesim-every-n-steps", apply_nesim_every_n_steps),
    ]
    command = "python3 train.py " + " ".join(
        f"{flag} {shlex.quote(str(value))}" for flag, value in arguments
    )
    all_commands.append(command)

# Write each command into its own shell script inside `commands_folder`,
# named after the nesim config the command uses.
shell_files = []

# Remove leftover scripts without spawning a shell. Unlike `rm <folder>/*.sh`,
# this prints no error when there is nothing to delete.
for entry in os.listdir(commands_folder):
    stale_path = os.path.join(commands_folder, entry)
    if entry.endswith(".sh") and os.path.isfile(stale_path):
        os.remove(stale_path)

json_suffix = ".json"
for nesim_config, command in zip(nesim_configs, all_commands):

    run_name = os.path.basename(nesim_config)
    # Strip only a trailing ".json"; str.replace would also remove the
    # substring from the middle of a name and could make two configs collide.
    if run_name.endswith(json_suffix):
        run_name = run_name[: -len(json_suffix)]

    filename = os.path.join(commands_folder, f"{run_name}.sh")
    # Explicit encoding so the script bytes do not depend on the locale.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(command)
    shell_files.append(filename)

print(f"Saved {len(shell_files)} shell files")
