import shlex

from nesim.utils.folder import get_filenames_in_a_folder
from nesim.utils.tmux import launch_job_in_tmux

# Directory passed to train_imagenet.py as --logging.folder for every run launched below.
CHECKPOINTS_FOLDER = "/research/XXXX-1/toponets_resnet50_imagenet_checkpoints"

def save_list_to_file(strings_list, filename):
    """
    Save a list of strings to a file, with each string on a new line.

    The file is overwritten if it already exists and is written as UTF-8.
    File-system failures (any ``OSError``: missing directory, permission
    denied, path is a directory, ...) are reported on stdout instead of being
    raised, so the caller keeps going; any other exception (e.g. a
    non-iterable ``strings_list``) propagates as a normal error.

    Parameters:
    strings_list (iterable): The items to save; each is formatted via an
        f-string and followed by a newline.
    filename (str or path-like): The name of the file to save the strings to.
    """
    try:
        # Explicit encoding so the output does not depend on the host locale.
        with open(filename, "w", encoding="utf-8") as file:
            file.writelines(f"{string}\n" for string in strings_list)
    except OSError as e:
        # Best-effort: report I/O problems but do not abort the caller.
        print(f"An error occurred while saving the file: {e}")
    else:
        print(f"File '{filename}' saved successfully.")

def divide_list(lst, n):
    """
    Split a sequence into ``n`` contiguous, order-preserving parts.

    Part lengths differ by at most one: the first ``len(lst) % n`` parts each
    hold one extra element. When ``n`` exceeds ``len(lst)`` the trailing
    parts are empty.

    Parameters:
    lst (sequence): The sequence to split; it must support ``len`` and slicing.
    n (int): The number of parts to produce.

    Returns:
    list: ``n`` slices of ``lst``, in order.

    Raises:
    ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("The number of parts must be a positive integer.")

    size, leftover = divmod(len(lst), n)

    def boundary(k):
        # Start index of part k (equivalently, end index of part k - 1):
        # k full-size parts plus one extra element for each of the first
        # `leftover` parts that precede it.
        return k * size + min(k, leftover)

    return [lst[boundary(k):boundary(k + 1)] for k in range(n)]


# Config files to train with, sorted so the command order (and therefore which
# GPU shard each config lands in) is deterministic for a given folder content.
nesim_configs = sorted(get_filenames_in_a_folder(folder="./nesim_configs"))

def get_train_command(nesim_config_filename, *, checkpoints_folder=None):
    """
    Build the single-line shell command that runs train_imagenet.py with one NeSim config.

    Every argument is passed through ``shlex.quote``, so paths containing
    spaces or quote characters (which broke the previous hand-written
    ``'...'`` quoting) reach the training script intact.

    Parameters:
    nesim_config_filename (str): Path of the NeSim config, forwarded as
        ``--nesim.config_filename``.
    checkpoints_folder (str, optional): Value for ``--logging.folder``.
        Defaults to the module-level ``CHECKPOINTS_FOLDER`` (looked up at call
        time).

    Returns:
    str: The shell command, with no embedded newlines.
    """
    if checkpoints_folder is None:
        checkpoints_folder = CHECKPOINTS_FOLDER

    args = [
        "python", "train_imagenet.py",
        "--config-file", "rn50_configs/rn50_88_epochs.yaml",
        "--data.train_dataset=/research/datasets/imagenet_ffcv/train_500_0.50_90.ffcv",
        "--data.val_dataset=/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
        "--data.num_workers=32",
        "--data.in_memory=1",
        f"--logging.folder={checkpoints_folder}",
        f"--nesim.config_filename={nesim_config_filename}",
    ]
    # Quote each argument for POSIX shells; safe tokens are left unchanged.
    return " ".join(shlex.quote(arg) for arg in args)

# One training command per NeSim config.
commands = [
    get_train_command(nesim_config_filename=n)
    for n in nesim_configs
]

gpus = [0, 1]
conda_env = "ffcv"

# Split the commands into one contiguous shard per GPU.
commands_divided = divide_list(
    lst=commands,
    n=len(gpus),
)

for count, (commands_for_single_gpu, gpu) in enumerate(zip(commands_divided, gpus)):
    if not commands_for_single_gpu:
        # More GPUs than configs: don't write an empty script and launch a
        # session that has nothing to run.
        continue

    # Each line activates the conda env and restricts the run to this GPU via
    # CUDA_VISIBLE_DEVICES; sourcing the script executes the lines in order.
    single_worker_commands = [
        f"conda activate {conda_env} && CUDA_VISIBLE_DEVICES={gpu} {c}"
        for c in commands_for_single_gpu
    ]

    save_list_to_file(
        strings_list=single_worker_commands,
        filename=f"{count}.sh",
    )

    launch_job_in_tmux(
        command=f"source {count}.sh",
        session_name=f"rn50-{count}",
    )