from collections import OrderedDict

import sys

import os

def group_ids(ids, max_array_size):
    """Group ids by their quotient from division by ``max_array_size``.

    Each id is split with ``divmod(id, max_array_size)`` into a quotient and a
    remainder. Remainders are collected per quotient, preserving the order in
    which quotients are first seen.

    Args:
        ids: iterable of integer task ids.
        max_array_size: divisor used to split each id.

    Returns:
        OrderedDict mapping quotient -> list of remainders (in input order).
    """
    groups = OrderedDict()
    for task_id in ids:
        quotient, remainder = divmod(task_id, max_array_size)
        groups.setdefault(quotient, []).append(remainder)
    return groups


def compress_ranges(values):
    """Collapse integers into sorted range strings, e.g. [5, 1, 2, 3] -> ["1-3", "5"].

    Duplicates are removed first so a repeated id cannot produce an overlapping
    spec such as "3,3-4". An empty input yields an empty list.

    Args:
        values: iterable of integers.

    Returns:
        List of strings, each either a single value "N" or an inclusive range "A-B".
    """

    def span(first, last):
        return str(first) if first == last else str(first) + "-" + str(last)

    ranges = []
    start = end = None
    for value in sorted(set(values)):
        if end is not None and value == end + 1:
            # Still consecutive: extend the current run.
            end = value
            continue
        if start is not None:
            ranges.append(span(start, end))
        start = end = value
    if start is not None:
        ranges.append(span(start, end))
    return ranges


def build_sbatch_script(array_spec, offset):
    """Return the text of the SLURM batch script for one job array.

    Args:
        array_spec: value for ``#SBATCH --array=`` (comma-separated indices/ranges).
        offset: added to ``SLURM_ARRAY_TASK_ID`` in the shell so run.py receives
            the original (un-split) id.

    Returns:
        The script as a single newline-terminated string.
    """
    lines = [
        "#!/bin/bash",
        "#SBATCH -J run",
        "#SBATCH -p standard",
        "#SBATCH -c 16",
        "#SBATCH --mem=128G",
        "#SBATCH --gres=gpu:1",
        "#SBATCH --tmp=100G",
        "#SBATCH --array=" + array_spec,
        "export APPTAINERENV_CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}",
        "srun apptainer run --nv docker://nvcr.io/nvidia/pyg:24.09-py3 python3 run.py --slurm_id=$((${SLURM_ARRAY_TASK_ID} + "
        + str(offset)
        + "))",
    ]
    return "\n".join(lines) + "\n"


def main():
    """Submit one SLURM job array per quotient group of the ids in remaining.txt.

    Reads integer ids (one per line, blank lines ignored). Each id is split into
    quotient/remainder by ``max_array_size``, so array indices stay below that
    bound. One temporary ``run.sh`` is written and passed to ``sbatch`` per
    quotient group. The temporary script is removed at the end, but only if it
    was actually created.
    """
    save_path = "../Data/Supervised_Clustering/"
    # NOTE(review): assumes the cluster's SLURM MaxArraySize is >= this value so
    # remainders 0..max_array_size-1 are valid array indices — confirm.
    max_array_size = 10000
    script_path = "run.sh"

    with open(os.path.join(save_path, "remaining.txt"), "r") as f:
        # int() tolerates surrounding whitespace; skip blank lines instead of crashing.
        remaining_ids = [int(line) for line in f if line.strip()]

    groups = group_ids(remaining_ids, max_array_size)

    script_written = False
    try:
        for quotient, remainders in groups.items():
            array_spec = ",".join(compress_ranges(remainders))
            script = build_sbatch_script(array_spec, quotient * max_array_size)

            with open(script_path, "w") as f:
                f.write(script)
            script_written = True

            print("\n", script)

            # sbatch copies the script at submission time, so reusing the same
            # filename for the next group is safe once this call returns.
            status = os.system("sbatch " + script_path)
            if status != 0:
                print(
                    "warning: sbatch returned status "
                    + str(status)
                    + " for quotient "
                    + str(quotient),
                    file=sys.stderr,
                )
    finally:
        # Only clean up if we created the file; with no ids it never exists.
        if script_written:
            os.remove(script_path)


if __name__ == "__main__":
    main()
