from slurmpy import Slurm

from scripts.project_location import ADDITIONAL_MOUNTS, BASE_PATH_PROJECT, BASE_PROJECT_IMAGE


# Wall-clock limit (Slurm "D-HH:MM:SS" format) for each partition's time suffix,
# i.e. the part of the partition name after the "-" (e.g. "2d" in "gpu-2d").
_PARTITION_TIME_LIMITS = {
    "test": "00-00:15:00",
    "9m": "00-00:09:00",
    "2h": "00-02:00:00",
    "5h": "00-05:00:00",
    "2d": "02-00:00:00",
    "7d": "07-00:00:00",
}

# Node feature constraint applied to GPU jobs on every GPU partition except "gpu-test".
_GPU_CONSTRAINT = "'[80gb|40gb|h100]'"

# Nodes that GPU jobs are never scheduled on.
_EXCLUDED_GPU_NODES = "head075,head074,head073"


def _apptainer_command(job_cmd, container_path):
    """Return a shell command that runs ``job_cmd`` inside the Apptainer image.

    ``BASE_PATH_PROJECT`` and every entry of ``ADDITIONAL_MOUNTS`` are
    bind-mounted, and ``job_cmd`` is run after ``cd $PATH_TO_REPO``.
    """
    # Every mount needs its own "-B"; joining on " -B " would leave the first
    # path flagless, and apptainer would then take it as the image argument.
    bind_flags = " ".join(f"-B {path}" for path in (BASE_PATH_PROJECT, *ADDITIONAL_MOUNTS))
    # $PATH_TO_REPO is inside double quotes, so it is expanded by the shell that
    # runs the job script (presumably it is exported in the job environment).
    # NOTE(review): a job_cmd containing double quotes will break this quoting.
    return (
        f"apptainer run --nv --writable-tmpfs {bind_flags} {container_path} "
        f'/bin/bash -c "cd $PATH_TO_REPO && {job_cmd}"'
    )


def _partition_options(partition):
    """Return the Slurm options implied by a "<device>-<time>" partition name.

    Raises:
        ValueError: if ``partition`` does not contain exactly one "-".
        KeyError: if the time suffix is not a key of ``_PARTITION_TIME_LIMITS``.
    """
    parts = partition.split("-")
    if len(parts) != 2:
        raise ValueError(f"partition must look like '<device>-<time>', got {partition!r}")
    device, time_suffix = parts

    try:
        options = {"time": _PARTITION_TIME_LIMITS[time_suffix]}
    except KeyError:
        raise KeyError(
            f"Unknown time suffix {time_suffix!r} in partition {partition!r}; "
            f"expected one of {sorted(_PARTITION_TIME_LIMITS)}"
        ) from None

    if device == "gpu":
        options["gres"] = "gpu:1"
        if partition != "gpu-test":
            options["constraint"] = _GPU_CONSTRAINT
        options["exclude"] = _EXCLUDED_GPU_NODES
    return options


def run_job(
    job_name,
    job_cmd,
    num_cpus=4,
    partition="gpu-2d",
    apptainer=True,
    container_path=BASE_PROJECT_IMAGE,
    slurm_args=None,
    log_dir="./logs",
    num_jobs_in_array=1,
    mem=32,
):
    """Submit ``job_cmd`` to Slurm as a job array via slurmpy.

    Args:
        job_name: Name passed to ``Slurm``.
        job_cmd: Shell command run by each array task.
        num_cpus: Value for ``--cpus-per-task``.
        partition: Partition name of the form "<device>-<time>" (e.g. "gpu-2d").
            The time suffix selects the ``--time`` limit; a "gpu" device also
            requests one GPU, a node constraint (except on "gpu-test") and a
            node exclusion list.
        apptainer: If True, run ``job_cmd`` inside the Apptainer image at
            ``container_path``; otherwise run it as-is.
        container_path: Apptainer image used when ``apptainer`` is True.
        slurm_args: Optional mapping of extra Slurm options. Applied last, so
            it overrides every option set by this function.
        log_dir: Directory for stdout/stderr files, written as
            ``run_<array job id>/<task index>.out`` / ``.err``.
        num_jobs_in_array: Number of array tasks (indices 0..n-1); values
            below 2 submit a single task with index 0.
        mem: Memory request in gigabytes, passed as ``--mem=<mem>G``; must be
            a positive int.

    Returns:
        Whatever ``Slurm.run`` returns (slurmpy returns the submitted job id).

    Raises:
        TypeError: if ``mem`` is not an int.
        ValueError: if ``mem`` is not positive or ``partition`` is malformed.
        KeyError: if the partition's time suffix is unknown.
    """
    # bool is a subclass of int, but "TrueG" is not a valid memory request.
    if isinstance(mem, bool) or not isinstance(mem, int):
        raise TypeError("The variable mem needs to be a (positive) integer.")
    if mem <= 0:
        raise ValueError(f"Argument mem must be a positive int, got {mem}")

    submit_cmd = _apptainer_command(job_cmd, container_path) if apptainer else job_cmd

    slurm_options = {
        "partition": partition,
        "cpus-per-task": num_cpus,
        "nodes": 1,
        "chdir": "./",
        # %A = array master job id, %a = array task index.
        "output": f"{log_dir}/run_%A/%a.out",
        "error": f"{log_dir}/run_%A/%a.err",
        "array": f"0-{num_jobs_in_array - 1}" if num_jobs_in_array > 1 else "0",
        "mem": f"{mem}G",
    }
    slurm_options.update(_partition_options(partition))
    # Caller-supplied options go last so they can override any default above,
    # including the partition-derived gres/constraint/exclude/time.
    if slurm_args is not None:
        slurm_options.update(slurm_args)

    return Slurm(job_name, slurm_options).run(submit_cmd)
