"""Execute the Gram matrix computation locally or submit to SLURM."""

import subprocess
from itertools import product
from os import makedirs, path

from submitit import AutoExecutor
from submitit.helpers import CommandFunction

from cases import CASES, CASES_VARY_DATA

def build_commands(cases):
    """Build the deduplicated list of experiment command lines for ``cases``.

    Each case is a dict with keys ``"data_name"``, ``"model_name"``,
    ``"widths"`` and ``"num_initializations"``, plus optional ``"epsilon"``
    and ``"tol"`` (both default to ``0.0``). One command is produced per
    ``(width, model_seed)`` pair, with seeds ranging over
    ``range(num_initializations)``.

    Args:
        cases: Iterable of case dicts.

    Returns:
        List of argument lists suitable for ``subprocess.run``. Commands that
        appear more than once (e.g. from overlapping cases) are kept only once,
        in first-seen order.
    """
    commands = []
    # Lists are unhashable, so track already-emitted commands as tuples to get
    # O(1) duplicate checks instead of scanning ``commands`` every time.
    seen = set()

    for case in cases:
        data_name = case["data_name"]
        model_name = case["model_name"]
        widths = [int(w) for w in case["widths"]]
        num_initializations = case["num_initializations"]
        epsilon = case.get("epsilon", 0.0)
        tol = case.get("tol", 0.0)

        for width, model_seed in product(widths, range(num_initializations)):
            cmd = [
                "python",
                "experiment_gram_condition.py",
                f"--data_name={data_name}",
                f"--model_name={model_name}",
                f"--width={width}",
                f"--model_seed={model_seed}",
                f"--epsilon={epsilon}",
                f"--tol={tol}",
                "--skip_exists",
            ]
            key = tuple(cmd)
            if key not in seen:
                seen.add(key)
                commands.append(cmd)

    return commands


def submit_to_slurm(commands, logdir):
    """Submit every command as one job of a single SLURM job array.

    Args:
        commands: Argument lists, e.g. as returned by ``build_commands``.
        logdir: Folder in which submitit stores job scripts and logs.
    """
    executor = AutoExecutor(folder=logdir)
    executor.update_parameters(
        slurm_gres="gpu:1",
        slurm_time="96:00:00",
        slurm_array_parallelism=12,  # maximum number of jobs from the array running at a time
        slurm_job_name="gram_condition",
        cpus_per_task=4,
        slurm_account="deadline",
        slurm_qos="deadline",
    )
    # Submissions made inside ``batch()`` are grouped by submitit into one job
    # array; the actual submission happens when the context exits.
    with executor.batch():
        for cmd in commands:
            executor.submit(CommandFunction(cmd, verbose=True))
            print(f"Submitting to SLURM: {' '.join(cmd)}")


def run_locally(commands):
    """Run every command sequentially in the current process environment.

    Raises:
        subprocess.CalledProcessError: On the first command that exits with a
            non-zero status; remaining commands are not run.
    """
    for cmd in commands:
        print(f"Running command: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Toggle between SLURM submission and sequential local execution.
    USE_SLURM = True
    # USE_SLURM = False

    HERE = path.abspath(__file__)
    HEREDIR = path.dirname(HERE)
    LOGDIR = path.join(HEREDIR, "slurm_log")

    # NOTE: ``experiment_gram_condition.py`` is a relative path and is resolved
    # against the working directory the command runs in (no ``cwd`` is passed).
    commands = build_commands(CASES + CASES_VARY_DATA)

    if USE_SLURM:
        makedirs(LOGDIR, exist_ok=True)
        submit_to_slurm(commands, LOGDIR)
    else:
        run_locally(commands)
