#!/home/ec2-user/miniconda3/bin/python3

import argparse
import os
import shlex
import stat
import subprocess
from pathlib import Path
from typing import Any, Optional

# TODO ensure git is committed before job gets scheduled
# TODO add starting interactive job. command: `srun --partition=compute-g4dn --nodes=1 --gres=gpu:1 --pty bash`

# Address written into the `#SBATCH --mail-user` directive of every generated job script.
EMAIL = "francesco.montagna997@gmail.com"
# Root directory for job logs; the generated script points STDOUT to
# `<LOG_PATH>/std/` and STDERR to `<LOG_PATH>/err/`.
LOG_PATH = "/efs/slurm_logs"
# Environment prefix passed to `conda run -p` when launching the Python script.
CONDA_ENV = "/home/ec2-user/causal-benchmark/venv"


class SlurmJob:
    """Build and submit a single-node SLURM batch job that runs a Python script.

    Calling an instance writes a temporary ``./<name>.sh`` batch script,
    submits it with ``sbatch``, removes the script again and returns the
    job id reported by ``sbatch``.
    """

    def __init__(
        self,
        script: str,
        name: str,
        time: Any,
        gpu: bool,
        ngpus: Optional[int],
        afterok: Optional[str],
        ntasks_per_node: int,
        script_args: str,
        partition: str,
        scratch_and_run_prepend: str = "",
        set_mem: bool = False,
        num_cpus: int = 1
    ):
        """
        script : str
            Path of the Python script executed by the job.
        name : str
            Job name; also used for the temporary batch file and the log file names.
        time : Any
            Value written to ``#SBATCH --time`` (runtime in D-HH:MM).
        gpu : bool
            If True and ``partition == "gpu"``, one GPU is considered available.
        ngpus : int or None
            Number of GPUs requested via ``--gres``. If None, falls back to the
            number of GPUs considered available for the partition.
        afterok : str or None
            Comma separated job ids this job must wait for (``afterok``
            dependency). None or an empty string means no dependency.
        ntasks_per_node : int
            Only used in the commented-out multi-GPU directive of the script.
        script_args : str
            Command line arguments appended after the script path.
        partition : str
            ``"gpu"`` selects the gpu partition; any other value selects ``"cpu"``.
        scratch_and_run_prepend : str
            Text placed in front of the ``conda run`` command
            (include the "--" separator).
        set_mem : bool
            If True request 32GB of memory per CPU. If False, leave the cluster
            default (4GB according to the original author's note).
        num_cpus : int
            Value written to ``#SBATCH --cpus-per-task``.
        """
        self.script = script
        self.name = str(name)
        self.email = EMAIL
        self.time = time
        self.gpu = gpu
        self.afterok = afterok
        self.ntasks_per_node = ntasks_per_node
        self.script_args = script_args
        self.scratch_and_run_prepend = scratch_and_run_prepend  # include the "--" separator
        self.set_mem = set_mem
        self.num_cpus = num_cpus

        # The batch script writes to <LOG_PATH>/std and <LOG_PATH>/err, so both
        # directories (and any missing parents) must exist before submission.
        for sub_dir in ("std", "err"):
            os.makedirs(os.path.join(LOG_PATH, sub_dir), exist_ok=True)

        if partition == "gpu":  # T4
            self.partition = "gpu"
            self.avail_num_cpus = 16
            self.avail_ngpus = 1 if gpu else 0
        else:
            self.partition = "cpu"
            self.avail_num_cpus = 36
            self.avail_ngpus = 0

        self.ngpus = ngpus if ngpus is not None else self.avail_ngpus

    def __call__(self):
        """Write the batch script, submit it with ``sbatch`` and return the job id.

        Raises
        ------
        subprocess.CalledProcessError
            If ``sbatch`` exits with a non-zero status.
        ValueError
            If no job id can be found in the ``sbatch`` output.
        """
        slurm_job_bash_file = f"./{self.name}.sh"
        slurm_job_bash_file_content = "#!/bin/bash \n" + "\n".join(
            [
                self.resource_config_string,
                # Dependencies on seeds: gather results over all seeds to get metrics
                self.dependency_afterok_config_string,
                self.job_info_cmd,
                self.set_free_port_env_cmd,
                # Random delay so that jobs submitted together do not start in lockstep.
                "sleep $(shuf -i 1-60 -n 1)",
                self.run_cmd,
            ]
        )

        try:
            with open(slurm_job_bash_file, "w") as f:
                f.write(slurm_job_bash_file_content)
            os.chmod(slurm_job_bash_file, stat.S_IRWXU)

            # Argument list instead of a shell string: job names containing
            # spaces or shell metacharacters cannot break the command.
            sbatch_output = subprocess.check_output(["sbatch", slurm_job_bash_file])
            print(sbatch_output.decode().strip())
            job_id = self._parse_job_id(sbatch_output)
        finally:
            # The script is only needed for submission; never leave it behind.
            if os.path.exists(slurm_job_bash_file):
                os.remove(slurm_job_bash_file)

        return job_id

    @staticmethod
    def _parse_job_id(sbatch_output):
        """Return the first all-digit token of the ``sbatch`` output as an int.

        Accepts ``bytes`` or ``str``. Raises ``ValueError`` if no numeric
        token is present.
        """
        if isinstance(sbatch_output, bytes):
            text = sbatch_output.decode()
        else:
            text = str(sbatch_output)
        for token in text.split():
            if token.isdigit():
                return int(token)
        raise ValueError(f"Could not find a job id in sbatch output: {text!r}")

    @property
    def mem_per_cpu(self):
        """Memory directive, or a neutral comment line when ``set_mem`` is False."""
        if self.set_mem:
            return "#SBATCH --mem-per-cpu=32GB"
        return "##########"

    @property
    def resource_config_string(self):
        """``#SBATCH`` header describing the resources requested by the job."""
        # Only used by the commented-out multi-GPU directive below.
        cpus_per_gpus = 0 if self.avail_ngpus == 0 else self.avail_num_cpus // self.avail_ngpus
        return f"""
#SBATCH --job-name={self.name}                   # Name of the job
##########
### if using pytorch lightning multi gpus, use this
### TODO this limits the number of cpus available when using default pytorch model training
# #SBATCH --ntasks-per-node={self.ntasks_per_node}                         # #SBATCH --ntasks=1                          # Number of tasks  self.avail_ngpus // self.ngpus
# #SBATCH --cpus-per-gpu={cpus_per_gpus}          # Number of CPU cores per task
### if using cpus only, use this  # TODO add --cpu argument to get cpu only machines
#SBATCH --ntasks-per-node=1                         # #SBATCH --ntasks=1                          # Number of tasks, i.e. how many instances of your command are executed
#SBATCH --cpus-per-task={self.num_cpus}          # Number of CPU cores per task
{self.mem_per_cpu}   # memory per cpu-core
##########
#SBATCH --nodes=1                           # Ensure that all cores are on one machine
#SBATCH --time={self.time}                       # Runtime in D-HH:MM
#SBATCH --partition={self.partition}        # Partition to submit to
#SBATCH --gres=gpu:{self.ngpus}               # Number of requested GPUs
#SBATCH --output={LOG_PATH}/std/%j-{self.name}.log              # File to which STDOUT will be written
#SBATCH --error={LOG_PATH}/err/%j-{self.name}.log               # File to which STDERR will be written
#SBATCH --open-mode=append                          # Do not overwrite output files in case job gets rescheduled
#SBATCH --mail-type=fail                     # Type of email notification- BEGIN,END,FAIL,ALL
#SBATCH --mail-user={self.email}                 # Email to which notifications will be sent
        """

    @property
    def dependency_afterok_config_string(self):
        """``--dependency`` directive built from ``afterok``, or "" if there is none."""
        if not self.afterok:
            return ""
        # Ignore empty entries such as those produced by a trailing comma.
        job_ids = [job_id.strip() for job_id in self.afterok.split(',') if job_id.strip()]
        if not job_ids:
            return ""
        afterok_str = ",".join(f"afterok:{job_id}" for job_id in job_ids)
        return f"""
#SBATCH --dependency={afterok_str}
            """

    @property
    def job_info_cmd(self):
        """Shell commands that print basic information about the running job."""
        return """
scontrol show job $SLURM_JOB_ID  # print some info
date
hostname
pwd
        """

    @property
    def cp2scratch_cmd(self):
        """Shell commands copying ``self.datadir`` to /scratch, or "" if unset.

        ``datadir`` is not set by ``__init__``; it is read defensively so the
        property returns "" instead of raising when the attribute is absent.
        """
        datadir = getattr(self, "datadir", None)
        if not datadir:
            return ""
        cmd = f"cp -r --dereference {datadir} /scratch \n"
        return f"""
echo {shlex.quote(cmd)}\n
{cmd}
        """

    @property
    def set_free_port_env_cmd(self):
        """Shell snippet that picks a random free port and stores it in ``$port``."""
        return """
echo "Searching random free port and setting $port variable for use by job to be started"
read -r lower_port upper_port < /proc/sys/net/ipv4/ip_local_port_range
while true; do
    port=$(shuf -i "$lower_port"-"$upper_port" -n 1)
    ss -tulpn | grep :"$port" > /dev/null || break
done
echo "$port"
        """

    @property
    def run_cmd(self):
        """Shell commands that echo and then ``srun`` the Python script via conda."""
        cmd = f"{self.scratch_and_run_prepend} conda run --no-capture-output -p {CONDA_ENV} python {self.script} {self.script_args}"
        # shlex.quote keeps the echo intact even if the command contains quotes.
        return f"""
echo {shlex.quote(cmd)}\n
srun {cmd}
        """



if __name__ == "__main__":
    # Command line entry point: submit `--njobs` identical SLURM jobs running
    # the given Python script.
    parser = argparse.ArgumentParser(description="Running jobs on SLURM cluster")
    parser.add_argument("python_script", type=str)
    parser.add_argument(
        "--name",
        dest="name",
        action="store",
        default="noname",
        type=str,
        help="",
    )
    parser.add_argument(
        "--time",
        dest="time",
        action="store",
        default="7-00:00",
        type=str,
        help="",
    )
    # GPU type label. SlurmJob only needs a boolean, so any non-empty value
    # means "a GPU is available on the gpu partition".
    parser.add_argument(
        "--gpu",
        dest="gpu",
        action="store",
        default="T4",
        type=str,
        help="",
    )
    parser.add_argument(
        "--ngpus",
        dest="ngpus",
        action="store",
        default=None,
        type=int,
        help="",
    )
    parser.add_argument(
        "--njobs",
        dest="njobs",
        action="store",
        default=1,
        type=int,
        help="",
    )
    parser.add_argument(
        "--afterok",
        dest="afterok",
        action="store",
        default=None,
        type=str,
        help="Start job after the job specified by --afterok=jobid1,jobid2,... is finished",
    )
    # Accepted for backward compatibility only: SlurmJob has no GPU memory
    # parameter, so the value is not used.
    parser.add_argument(
        "--gpumem",
        dest="gpu_mem",
        action="store",
        default=16,
        type=int,
        help="GPU memory in GB",
    )
    parser.add_argument(
        "--ntasks_per_node",
        action='store',
        default=1,
        type=int,
        help="For pytorch lightning, set to number of GPUs to use; if pytorch script uses spawn, set to 1."
    )
    parser.add_argument(
        "--partition",
        dest="partition",
        action="store",
        default="gpu",
        choices=("gpu", "cpu"),
        type=str,
        help="SLURM partition to submit to.",
    )
    parser.add_argument(
        "--script_args",
        dest="script_args",
        action="store",
        default="",
        type=str,
        help="Arguments forwarded to the python script, as one string, e.g. --script_args='--seed 1'",
    )
    args = parser.parse_args()

    for _ in range(args.njobs):
        # Keyword arguments keep the call in sync with SlurmJob.__init__; the
        # previous positional call passed values into the wrong parameters
        # and read the non-existent `args.datadir`.
        job = SlurmJob(
            script=args.python_script,
            name=args.name,
            time=args.time,
            gpu=bool(args.gpu),
            ngpus=args.ngpus,
            afterok=args.afterok,
            ntasks_per_node=args.ntasks_per_node,
            script_args=args.script_args,
            partition=args.partition,
        )
        job()
