#!/usr/bin/env python
import argparse
import datetime
import os
import shlex
import uuid
from pathlib import Path
import shlex
import colorama
import submitit

# Command-line interface. Exactly one of --script (plus trailing arguments that
# are forwarded to it) or --batch (a file with one command per line) selects
# what gets submitted; the remaining options describe the Slurm allocation.
parser = argparse.ArgumentParser(
    description="Submit a script or a batch of commands to Slurm via submitit."
)
parser.add_argument("--gpu_type", type=str, help="GPU type", default="any")
parser.add_argument("--n_gpus", type=int, help="Number of GPUs", default=1)
parser.add_argument("--partition", type=str, help="Partition")
parser.add_argument(
    "--exclusive",
    action=argparse.BooleanOptionalAction,
    help="Request an exclusive node allocation",
)
parser.add_argument(
    "--exclude",
    type=str,
    help="Exclude nodes",
    default="",
)
# Slurm interprets a --mem value without a unit suffix as megabytes, so the
# default carries an explicit "G" to mean 32 gigabytes.
parser.add_argument("--mem", type=str, help="Memory (include the G)", default="32G")
parser.add_argument(
    "--use_srun",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Launch the command through srun inside the allocation",
)
parser.add_argument("--script", type=str, help="Script to run")

parser.add_argument("--time", type=int, help="Time in hours", default=5)
parser.add_argument("--n_cpus", type=int, help="Number of CPUs", default=8)
parser.add_argument(
    "--cuda-launch-blocking",
    action=argparse.BooleanOptionalAction,
    help="Set CUDA_LAUNCH_BLOCKING=1 for the launched command",
)
parser.add_argument(
    "--batch", type=str, help="File listing commands to run, one per line"
)
parser.add_argument("args", nargs=argparse.REMAINDER, help="All command-line arguments")
args = parser.parse_args()


# The Slurm job is named after the script, or "batch" when a command file is used.
jobname = "batch" if args.script is None else args.script
print(f"Job name: {jobname}")

# submitit's working folder for this submission lives next to this file, under
# submitit/<YYYY-MM-DD>/<HH-MM-SS>/ (local, naive time).
base_dir = Path(__file__).parent.resolve()
today = datetime.datetime.now(None)  # noqa: DTZ005
day_stamp = today.strftime("%Y-%m-%d")
time_stamp = today.strftime("%H-%M-%S")
folder = base_dir.joinpath("submitit", day_stamp, time_stamp)
print(f"Folder: {folder}")
executor = submitit.SlurmExecutor(folder=folder)
# Extra sbatch options passed through submitit's `additional_parameters`.
sap: dict[str, str] = {}
if args.partition is not None:
    sap["partition"] = args.partition
# --exclude defaults to "", so only forward it when a node list was given;
# otherwise sbatch would receive an empty --exclude.
if args.exclude:
    sap["exclude"] = args.exclude

# Build the gres spec "gpu[:TYPE]:COUNT". When no GPUs are requested, pass no
# gres at all: a bare "gpu" would still make Slurm allocate one GPU.
gres = None
if args.n_gpus is not None and args.n_gpus > 0:
    gres = "gpu"
    if args.gpu_type is not None and args.gpu_type != "any":
        gres += f":{args.gpu_type}"
    gres += f":{args.n_gpus}"

executor.update_parameters(
    additional_parameters=sap,
    cpus_per_task=args.n_cpus,
    exclusive=args.exclusive,
    gres=gres,
    job_name=f"{jobname}",
    mem=args.mem,
    time=args.time * 60,  # --time is in hours; submitit expects minutes
    use_srun=args.use_srun,
    comment="script",
)

# Environment for every launched command: a copy of the submitting process's
# environment, overlaid with PyTorch debugging/logging switches, a single
# OpenMP thread, and tqdm progress bars disabled.
cmd_env = {
    **os.environ,
    "TORCH_SHOW_CPP_STACKTRACES": "1",
    "TORCH_CPP_LOG_LEVEL": "INFO",
    "OMP_NUM_THREADS": "1",
    "TQDM_DISABLE": "1",
}
# Optional: profile module import times of the launched command.
# cmd_env["PYTHONPROFILEIMPORTTIME"] = "1"
if not args.cuda_launch_blocking:
    pass
else:
    cmd_env["CUDA_LAUNCH_BLOCKING"] = "1"
    banner = colorama.Back.GREEN + "launch blocking" + colorama.Style.RESET_ALL
    print(banner)
# Collect every command first, as an argv list (run without a shell), so that
# input problems are reported before anything is handed to Slurm.
coms: list[list[str]] = []
if args.batch is None:
    if args.script is None:
        parser.error("either --script or --batch must be given")
    script_path = base_dir / args.script
    if not script_path.exists():
        parser.error(f"Script {script_path} not found")

    # args.args holds everything after the recognised options. One element may
    # carry several arguments (e.g. when passed quoted), so each is re-split
    # with shlex to honour spaces and quotes; bare "--" separators are dropped.
    parsed_args: list[str] = []
    for arg in args.args:
        if arg == "--":
            continue
        parsed_args.extend(shlex.split(arg))
    # "uv" and "run" must be separate argv entries: CommandFunction executes the
    # list without a shell, so "uv run" would be looked up as a single program.
    coms.append(["uv", "run", str(script_path), *parsed_args])
else:
    if args.args or args.script is not None:
        parser.error("--batch cannot be combined with --script or extra arguments")
    # One command per non-blank line; lines starting with "#" are comments.
    with open(args.batch, encoding="utf-8") as fd:
        for raw_line in fd:
            stripped = raw_line.strip()
            if stripped and not stripped.startswith("#"):
                coms.append(shlex.split(stripped))
    if not coms:
        parser.error(f"no commands found in {args.batch}")

# Submit all commands together as one submitit batch.
jobs: list[submitit.Job] = []
with executor.batch():
    for cmd in coms:
        function = submitit.helpers.CommandFunction(cmd, env=cmd_env)
        jobs.append(executor.submit(function))

print(jobs[0].job_id)
