import os
from tqdm import tqdm
from nesim.utils.tmux import launch_job_in_tmux

def divide_list(lst, n):
    """
    Split ``lst`` into ``n`` contiguous slices whose sizes differ by at most one.

    The first ``len(lst) % n`` slices each get one extra element. When ``n`` is
    larger than ``len(lst)``, the trailing slices are empty.

    Parameters:
    lst (list): The sequence to split (anything supporting len() and slicing).
    n (int): How many slices to produce.

    Returns:
    list of lists: Exactly ``n`` slices of ``lst``, in their original order.

    Raises:
    ValueError: If ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError("The number of parts must be a positive integer.")

    base_size, extra = divmod(len(lst), n)

    chunks = []
    cursor = 0
    for index in range(n):
        # Hand the leftover elements out one at a time to the leading slices.
        size = base_size + 1 if index < extra else base_size
        chunks.append(lst[cursor:cursor + size])
        cursor += size
    return chunks

def save_list_to_file(strings_list, filename):
    """
    Save a list of strings to a file, with each string on a new line.

    Each item is formatted with an f-string (so non-str items are converted
    with ``str()``) and followed by ``"\\n"``. An existing file is overwritten.

    Parameters:
    strings_list (list of str): The list of strings to save.
    filename (str): The name of the file to save the strings to.

    Returns:
    bool: True if the file was written, False if an OS-level error (e.g. a
    missing directory or a permission problem) prevented it. The error is
    reported on stdout instead of being raised, keeping the best-effort
    behavior, but callers can now check the result before using the file.
    """
    try:
        # Explicit UTF-8 so the output does not depend on the platform locale.
        with open(filename, "w", encoding="utf-8") as file:
            file.writelines(f"{string}\n" for string in strings_list)
    except OSError as e:
        print(f"An error occurred while saving the file: {e}")
        return False
    print(f"File '{filename}' saved successfully.")
    return True


# Categories to generate maps for, and the map modes to run for each one.
target_categories = [
    "politics",
    "science",
    "history",
    "technology",
]
modes = ["dprime"]
# One worker (tmux session) per entry; GPU ids repeat so that several
# workers share the same device.
gpus = [0, 1, 2, 0, 1, 2]

# Every (mode, category) pair becomes one generate_map.py invocation.
commands = [
    f"python3 generate_map.py --dataset-filename dataset.json "
    f"--target-class {c} --mode {mode}"
    for mode in modes
    for c in target_categories
]

# Spread the commands over the workers. When there are fewer commands than
# workers, divide_list returns empty slices for the trailing workers.
commands_divided = divide_list(
    lst=commands,
    n=len(gpus),
)

# Per-worker command lists that were actually written and launched.
commands_divided_per_gpu = []

for count, (commands_for_single_gpu, gpu) in enumerate(zip(commands_divided, gpus)):
    # Skip workers that received no commands rather than writing an empty
    # script and starting an idle tmux session for them.
    if not commands_for_single_gpu:
        continue

    single_worker_commands = [
        f"conda activate nesim && CUDA_VISIBLE_DEVICES={gpu} {c}"
        for c in commands_for_single_gpu
    ]
    commands_divided_per_gpu.append(single_worker_commands)

    save_list_to_file(
        strings_list=single_worker_commands,
        filename=f"{count}.sh",
    )

    # NOTE(review): "{count}.sh" is a relative path, so this assumes the tmux
    # session starts in the current working directory. Confirm against
    # launch_job_in_tmux.
    launch_job_in_tmux(
        command=f"source {count}.sh",
        session_name=f"nesim-maps-{count}",
    )