import os

slurm_prefix = """#!/bin/sh
#SBATCH --ntasks-per-node=32
#SBATCH --job-name=cifar10-resnet18
#SBATCH -t 24:00:0
#SBATCH --mem=32G
#SBATCH --gres=gpu:a100:1
source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate nesim
"""

slurm_shell_files_folder = "./slurm"

# ---------------------------------------------------------------------------
# Sweep definition. Every combination of the values below becomes one
# `train.py` command line in `all_commands`.
# ---------------------------------------------------------------------------

# Config files passed to train.py via --nesim-config.
nesim_configs = [
    "nesim_configs/apply_laplacian_pyramid_loss.json",
    "nesim_configs/baseline.json",
]

# True -> "--pretrained", anything else -> "--no-pretrained".
from_pretrained_values = [False]

# Values swept for --nesim-apply-after-n-steps on non-baseline configs.
apply_every_n_steps_values = [1, 10, 30, 50, 70, 90]

# None means the --bimt-scale flag is left off the command line entirely.
bimt_scale_values = [None]

# The baseline config is not swept over the step values; it always gets this one.
baseline_apply_every_n_steps = 10

all_commands = []
for nesim_config in nesim_configs:
    # Decide up front which step values this config is expanded over, so a
    # single command template below serves both the baseline and the sweep.
    if nesim_config == "nesim_configs/baseline.json":
        # apply_every_n_steps is hardcoded intentionally for the baseline
        step_values = [baseline_apply_every_n_steps]
    else:
        step_values = apply_every_n_steps_values

    for from_pretrained in from_pretrained_values:
        if from_pretrained is True:
            pretrained_arg = "--pretrained"
        else:
            pretrained_arg = "--no-pretrained"

        for bimt_scale in bimt_scale_values:
            # An empty argument leaves two consecutive spaces in the command,
            # which the shell treats the same as one.
            if bimt_scale is None:
                bimt_scale_arg = ""
            else:
                bimt_scale_arg = f"--bimt-scale {bimt_scale}"

            for apply_every_n_steps in step_values:
                all_commands.append(
                    f"python3 train.py --nesim-config {nesim_config} --nesim-apply-after-n-steps {apply_every_n_steps} {bimt_scale_arg} {pretrained_arg} --wandb-log"
                )

print(f"Will run {len(all_commands)} commands letsgo")

# Create the output folder if it is missing; without this the open() below
# raises FileNotFoundError on a fresh checkout.
os.makedirs(slurm_shell_files_folder, exist_ok=True)

# Remove job scripts left over from a previous run so the folder only holds
# scripts for the current command list. Done in-process instead of shelling
# out to `rm <folder>/*.sh`: no dependence on the shell, no error output when
# nothing matches, and no trouble with unusual characters in the folder path.
# Hidden files are skipped to mirror what the `*.sh` shell glob matched.
for stale_name in os.listdir(slurm_shell_files_folder):
    if stale_name.startswith(".") or not stale_name.endswith(".sh"):
        continue
    stale_path = os.path.join(slurm_shell_files_folder, stale_name)
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)

# Write one script per command, named by its index in all_commands, and
# remember the paths in slurm_shell_files.
slurm_shell_files = []
for command_index, command in enumerate(all_commands):
    shell_commands = slurm_prefix + command
    filename = os.path.join(slurm_shell_files_folder, f"{command_index}.sh")
    with open(filename, "w") as file:
        file.write(shell_commands)
    slurm_shell_files.append(filename)
