import os
from nesim.utils.chunking import divide_list
from nesim.utils.tmux import launch_job_in_tmux

def save_list_to_file(strings_list, filename):
    """
    Write the entries of a list to a text file, one entry per line.

    The file is opened in write mode, so any existing content is replaced.
    Each entry is formatted with an f-string and terminated with a newline.
    A confirmation message is printed on success. Any exception raised while
    opening or writing is caught and reported on stdout instead of being
    propagated; the function returns None either way.

    Parameters:
    strings_list (list of str): The entries to write.
    filename (str): The path of the file to write the entries to.
    """
    try:
        with open(filename, 'w') as out:
            # One newline-terminated line per entry, emitted in list order.
            out.writelines(f"{entry}\n" for entry in strings_list)
        print(f"File '{filename}' saved successfully.")
    except Exception as err:
        # Best-effort: report the failure and let the caller carry on.
        print(f"An error occurred while saving the file: {err}")

# Names of the models to evaluate. The prefix of each name ("all", "end" or
# "baseline") selects which `--layers` values are used when the commands are
# built below.
possible_model_names = [
    ## end_topo (currently disabled)
    # "end_topo_scale_1.0_shrink_factor_3.0",
    # "end_topo_scale_5.0_shrink_factor_3.0",
    # "end_topo_scale_10.0_shrink_factor_3.0",
    # "end_topo_scale_50.0_shrink_factor_3.0",
    ## all_topo
    "all_topo_scale_0.5_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_20.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0",
    ## baseline
    "baseline_scale_None_shrink_factor_3.0",
]

# `--layers` values used for baseline models: one command is generated per
# entry. Models whose name starts with "all" get a single `--layers all`
# command and models whose name starts with "end" get a single `--layers end`
# command.
# NOTE(review): the note that used to sit here said "Apply sparsity to both all
# and end layers for all topo models / only end layers for end topo models".
# The loop below issues only `--layers all` for all_topo models; confirm which
# behaviour is intended.
layers_to_sparsify = [
    "end", "all"
]

commands = []
for name in possible_model_names:
    if name.startswith("all"):
        layer_choices = ["all"]
    elif name.startswith("end"):
        layer_choices = ["end"]
    elif name.startswith("baseline"):
        layer_choices = layers_to_sparsify
    else:
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let an unknown name fall through silently.
        raise ValueError(
            f"Unrecognised model name {name!r}: expected a name starting "
            "with 'all', 'end' or 'baseline'"
        )
    for layers in layer_choices:
        commands.append(
            f"python3 eval.py --model-name {name} --layers {layers}"
        )

# GPU ids to spread the evaluation commands over; each id is exported as
# CUDA_VISIBLE_DEVICES for the commands assigned to it.
gpus = [0, 1]
# Conda environment activated in front of every command.
conda_env = "ffcv"
print(f"total: {len(commands)} commands to be run over {len(gpus)} workers")
# Presumably splits `commands` into `len(gpus)` chunks; the exact chunking is
# defined by nesim.utils.chunking.divide_list (not visible here).
commands_divided = divide_list(
    lst = commands,
    n = len(gpus)
)

# `count` is the position of the (chunk, gpu) pair and names the generated
# script "<count>.sh". It advances for empty chunks too, so a script's number
# always matches the position of its GPU in `gpus`.
for count, (commands_for_single_gpu, gpu) in enumerate(zip(commands_divided, gpus)):
    if len(commands_for_single_gpu) == 0:
        # Nothing assigned to this GPU: no script is written for it.
        continue

    single_worker_commands = [
        f"conda activate {conda_env} && CUDA_VISIBLE_DEVICES={gpu} {c}"
        for c in commands_for_single_gpu
    ]

    save_list_to_file(
        strings_list=single_worker_commands,
        filename=f"{count}.sh"
    )

    # Launching is currently disabled; the scripts are only written to disk.
    # launch_job_in_tmux(
    #     command = f"source {count}.sh",
    #     session_name=f"eff-{count}"
    # )