import argparse
import itertools
import os

from nesim.utils.folder import get_filenames_in_a_folder

# Command-line interface. The description previously read "Example script
# with --slurm argument" (a placeholder); it now says what the script does.
parser = argparse.ArgumentParser(
    description=(
        "Generate one train.py command per hyperparameter combination and "
        "write each as a slurm job script into ./slurm"
    )
)

# --slurm is a boolean flag: store_true makes it default to False.
# NOTE(review): args.slurm is not read anywhere in this script as visible;
# the job scripts are written unconditionally. Confirm whether the flag
# should gate that step.
parser.add_argument(
    "--slurm",
    action="store_true",
    help="Set this flag to generate slurm commands and put them into the ./slurm directory",
)
args = parser.parse_args()
# Folder that receives the generated per-job shell scripts.
slurm_shell_files_folder = "./slurm"

# Header prepended to every generated job script: SBATCH resource requests,
# then conda environment activation. The interpreter is bash rather than
# /bin/sh because `source` is a bash builtin, not POSIX sh; on systems where
# /bin/sh is dash the activation line fails with "source: not found".
slurm_prefix = """#!/bin/bash
#SBATCH -n 32
#SBATCH --ntasks-per-node=32
#SBATCH --job-name=pyramid-loss-param-sweep
#SBATCH -t 48:00:0
#SBATCH --mem=32G
#SBATCH --gres=gpu:a100:1
source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate nesim
"""

# One sweep axis per file in ./nesim_configs. The helper's ordering is not
# visible from here, so sort the result to keep the numbered job scripts
# (0.sh, 1.sh, ...) mapped to the same configs across runs.
nesim_configs = sorted(get_filenames_in_a_folder(folder="./nesim_configs"))

# Values passed as --nesim-apply-after-n-steps for every non-baseline config.
apply_every_n_steps_values = [1, 10, 30, 50]

# Training length shared by every generated command.
num_epochs = 4
num_epochs_arg = "--num-epochs " + str(num_epochs)

# The remaining sweep axes are pinned to a single value each:
#   1. we are training from scratch (--no-pretrained)
#   2. not applying bimt
#   3. not applying sorted weights init
#   4. no cross layer correlation loss
# A None entry means the corresponding CLI argument is omitted.
from_pretrained_values = [False]
bimt_scale_values = [None]
sorted_weights_init_values = [None]
cross_layer_correlation_loss_configs = [None]

# --nesim-apply-after-n-steps value used for "baseline" configs. It is
# hard-coded intentionally: baselines are not swept over
# apply_every_n_steps_values.
baseline_apply_every_n_steps = 30


def _optional_arg(flag, value):
    """Return ``"<flag> <value>"``, or ``""`` when *value* is None."""
    return f"{flag} {value}" if value is not None else ""


def _build_train_command(nesim_config, apply_every_n_steps, trailing_args):
    """Assemble one ``train.py`` invocation, dropping empty optional args.

    Empty strings in *trailing_args* (from omitted options) are skipped so the
    command does not contain runs of blank separators.
    """
    parts = [
        "python3 train.py",
        f"--nesim-config {nesim_config}",
        f"--nesim-apply-after-n-steps {apply_every_n_steps}",
        *trailing_args,
    ]
    return " ".join(part for part in parts if part)


all_commands = []

# itertools.product varies the last iterable fastest, so this visits the
# combinations in the same order as the equivalent nested for-loops.
for (
    sorted_weights_init_file,
    cross_layer_correlation_loss_config,
    nesim_config,
    from_pretrained,
    bimt_scale,
) in itertools.product(
    sorted_weights_init_values,
    cross_layer_correlation_loss_configs,
    nesim_configs,
    from_pretrained_values,
    bimt_scale_values,
):
    trailing_args = [
        _optional_arg("--bimt-scale", bimt_scale),
        "--pretrained" if from_pretrained else "--no-pretrained",
        _optional_arg(
            "--cross-layer-correlation-loss-config",
            cross_layer_correlation_loss_config,
        ),
        _optional_arg(
            "--apply-sorted-weights-init-filename", sorted_weights_init_file
        ),
        num_epochs_arg,
        "--wandb-log",
    ]
    # Non-baseline configs try every apply-every-n-steps value; baselines run once.
    if "baseline" in nesim_config:
        step_values = [baseline_apply_every_n_steps]
    else:
        step_values = apply_every_n_steps_values
    all_commands.extend(
        _build_train_command(nesim_config, n_steps, trailing_args)
        for n_steps in step_values
    )


# Write one job script per command into slurm_shell_files_folder, replacing
# any *.sh scripts left over from a previous run.
# NOTE(review): args.slurm is parsed but not consulted here; the scripts are
# written unconditionally. Confirm whether that is intended.
slurm_shell_files = []

# Create the folder first; open() below raises FileNotFoundError otherwise.
# Stale scripts are removed in Python rather than via `rm <folder>/*.sh`,
# which spawns a shell and prints an error when nothing matches.
os.makedirs(slurm_shell_files_folder, exist_ok=True)
for stale_name in os.listdir(slurm_shell_files_folder):
    stale_path = os.path.join(slurm_shell_files_folder, stale_name)
    if stale_name.endswith(".sh") and os.path.isfile(stale_path):
        os.remove(stale_path)

for command_index, command in enumerate(all_commands):
    filename = os.path.join(slurm_shell_files_folder, f"{command_index}.sh")
    with open(filename, "w", encoding="utf-8") as file:
        # Terminate the final line so the script is a well-formed text file.
        file.write(slurm_prefix + command + "\n")
    slurm_shell_files.append(filename)

print(f"Will run {len(all_commands)} commands via slurm letsgo")
