"""This is the entry point of the code when running using submitit"""

import os
import sys
from pathlib import Path
import glob

import submitit
from codes.grad_student.checkpointable_grad_student import CheckpointableGradStudent
from codes.utils.argument_parser import argument_parser
from codes.utils.util import timing


class SubmititRunner(submitit.helpers.Checkpointable):
    """Entry point of the code when it is launched through submitit.

    An instance is pickled into the Slurm job and called with a ``config_id``.
    When submitit asks the job to checkpoint (e.g. ahead of a timeout or
    preemption), :meth:`checkpoint` saves the model and returns a resubmission
    of a fresh runner for the same ``config_id``.
    """

    def __init__(self):
        print("create")
        # Populated by __call__; stays None until the job has started running.
        self.grad_student = None

    def __call__(self, config_id: str) -> None:
        """Build the grad student for ``config_id`` and run it."""
        self.grad_student = CheckpointableGradStudent(config_id=config_id)
        self.grad_student.logbook.write_message_logs("call")
        self.grad_student.run()

    def checkpoint(self, config_id: str) -> submitit.helpers.DelayedSubmission:
        """Save the current model (if any) and requeue the job for ``config_id``.

        submitit invokes this with the same arguments as ``__call__``. A new,
        empty runner is resubmitted rather than ``self`` so the requeued job
        does not carry the in-memory grad student.
        """
        # The signal can arrive before __call__ has finished constructing the
        # grad student; there is nothing to save then, but we must still
        # return a resubmission so the job gets requeued.
        if self.grad_student is not None:
            self.grad_student.logbook.write_message_logs("checkpoint")
            self.grad_student.experiment.save_model()
        return submitit.helpers.DelayedSubmission(SubmititRunner(), config_id)


def resolve_config_ids(config_path: str, pattern: str) -> list:
    """Return the sorted config ids whose YAML file under *config_path* matches *pattern*.

    ``pattern`` is a glob relative to ``config_path`` without the ``.yaml``
    extension (e.g. ``"multitask/signature_learn_1_hyp_*"``). Each returned id
    is the matched file's path relative to ``config_path``, using ``/``
    separators and with ``.yaml`` removed. This is the same form the
    single-job branch passes straight through from the command line.
    Directories and non-YAML files are ignored. Sorting keeps job-array
    indices deterministic across launches.
    """
    suffix = ".yaml"
    matches = glob.glob(os.path.join(config_path, pattern + suffix))
    return sorted(
        Path(os.path.relpath(path, config_path)).as_posix()[: -len(suffix)]
        for path in matches
        if os.path.isfile(path)
    )


def _record_jobs(entries) -> None:
    """Append one ``config_id,job_id`` row per entry to the shared job ledger."""
    with open(os.path.expanduser("~/logs/lgw/all_jobs.csv"), "a") as fp:
        fp.writelines("{},{}\n".format(cid, jid) for cid, jid in entries)


def main() -> None:
    """Parse the config id from the command line and submit the Slurm job(s).

    A config id ending in ``*`` is treated as a glob over the ``config``
    directory and launched as a job array (one task per matching config).
    Any other id is launched as a single job.
    """
    config_id = argument_parser()

    # We are using Slurm Executor as it allows us to set the gres flag
    executor = submitit.SlurmExecutor(folder="~/logs/lgw/%j")
    # Slurm settings shared by the array and single-job launches.
    common = dict(
        partition="dev",
        job_name=config_id,
        num_gpus=1,
        cpus_per_task=2,
        signal_delay_s=60,
    )

    if config_id.endswith("*"):
        # launch parallel jobs
        executor.update_parameters(
            comment="ICML",
            time=4000,
            array_parallelism=8,
            **common,
        )
        base_path = os.path.dirname(os.path.realpath(__file__)).split("/codes")[0]
        config_ids = resolve_config_ids(os.path.join(base_path, "config"), config_id)
        if not config_ids:
            sys.exit("No config files match {!r}".format(config_id))
        jobs = executor.map_array(SubmititRunner(), config_ids)
        _record_jobs((cid, job.job_id) for cid, job in zip(config_ids, jobs))
    else:
        # launch single job
        executor.update_parameters(time=360, **common)
        job = executor.submit(SubmititRunner(), config_id)
        print(job.job_id)  # ID of your job
        _record_jobs([(config_id, job.job_id)])


if __name__ == "__main__":
    main()
