"""Execute the close-to-linearity condition computation locally or submit to SLURM."""

import subprocess
import sys
from itertools import product
from os import makedirs, path

from submitit import AutoExecutor
from submitit.helpers import CommandFunction

from cases import CASES, CASES_VARY_DATA

if __name__ == "__main__":
    # Toggle between submitting the jobs to SLURM (True) and running them one
    # after another on the local machine (False).
    USE_SLURM = True
    # USE_SLURM = False

    HERE = path.abspath(__file__)
    HEREDIR = path.dirname(HERE)
    LOGDIR = path.join(HEREDIR, "slurm_log")
    # Resolve the experiment script next to this launcher so the commands work
    # regardless of the directory the launcher is started from.
    SCRIPT = path.join(HEREDIR, "experiment_linearity_condition.py")

    if USE_SLURM:
        makedirs(LOGDIR, exist_ok=True)

    # NOTE Submitting both case lists at once exceeds SLURM's maximum job array
    # size, so only one of them is run per invocation.
    # RUNS = CASES
    RUNS = CASES_VARY_DATA

    # One command per (case, width, model seed, perturbation seed). Different
    # cases may produce identical commands; keep only the first occurrence while
    # preserving order (set lookup instead of a linear scan over the list).
    commands = []
    seen = set()

    for case in RUNS:
        data_name = case["data_name"]
        model_name = case["model_name"]
        widths = [int(w) for w in case["widths"]]
        num_initializations = case["num_initializations"]
        num_perturbations = case["num_perturbations"]

        for width, model_seed, perturbation_seed in product(
            widths, range(num_initializations), range(num_perturbations)
        ):
            cmd = [
                # Use the interpreter running this launcher rather than whatever
                # `python` happens to be first on PATH.
                sys.executable,
                SCRIPT,
                f"--data_name={data_name}",
                f"--model_name={model_name}",
                f"--width={width}",
                f"--model_seed={model_seed}",
                f"--perturbation_seed={perturbation_seed}",
                "--skip_exists",
            ]
            key = tuple(cmd)
            if key not in seen:
                seen.add(key)
                commands.append(cmd)

    if USE_SLURM:
        executor = AutoExecutor(folder=LOGDIR)
        executor.update_parameters(
            slurm_gres="gpu:1",
            slurm_time="96:00:00",
            slurm_array_parallelism=12,  # maximum number of jobs from the array running at a time
            slurm_job_name="linearity_condition",
            cpus_per_task=4,
            slurm_account="deadline",
            slurm_qos="deadline",
        )
        # Inside `batch()` the submissions are collected and sent to SLURM as a
        # single job array when the context exits.
        with executor.batch():
            for cmd in commands:
                executor.submit(CommandFunction(cmd, verbose=True))
                print(f"Submitting to SLURM: {' '.join(cmd)}")
    else:
        for cmd in commands:
            print(f"Running command: {' '.join(cmd)}")
            # check=True aborts the sweep on the first failing command.
            subprocess.run(cmd, check=True)
