import getpass
import os
import shlex
import subprocess
import time
from argparse import ArgumentParser, Namespace
from pathlib import Path

from dotenv import load_dotenv

# Merge variables from a .env file into os.environ so the cluster-specific
# settings read below can be configured per machine without code changes.
load_dotenv()

# Cluster settings taken from the environment; each is None when unset.
SLURM_ACCOUNT = os.environ.get('SLURM_ACCOUNT')
USER_ID = os.environ.get('USER_ID')
PARTITION = os.environ.get('PARTITION')

# Environment-module names loaded with `module load` in the generated script.
PYTHON_VERSION = 'python/3.11.7'
CUDA_VERSION = 'cuda/12.2'

def build_parser() -> ArgumentParser:
    '''Build the command-line parser for the SLURM submission helper.

    Options fall into two groups:
      * script options -- where the generated ``.sh`` file is written and
        which Python entry point the job runs;
      * SLURM options -- values rendered into the ``#SBATCH`` directives and
        the virtual environment activated inside the job.

    Flags not declared here are not rejected when the caller uses
    ``parse_known_args``; they come back as "unknown" arguments so they can
    be forwarded to the Python script.

    Returns:
        A configured :class:`argparse.ArgumentParser`.
    '''
    parser = ArgumentParser(description='SLURM Job Submission')

    ####################
    # Script arguments #
    ####################

    parser.add_argument(
        '--sh-dir', type=str, default='./scripts/auto_generated',
        help='Directory to save the generated script.'
    )

    parser.add_argument(
        '--sh-name', type=str, default='slurm_train.sh',
        help='File name for the script.'
    )

    # Previously a copy-paste of the --sh-name help text; this flag is the
    # Python program the job executes, not the generated shell script.
    parser.add_argument(
        '--py-path', type=str, required=True,
        help='Path to the Python script to run inside the SLURM job.'
    )

    ###################
    # SLURM arguments #
    ###################

    parser.add_argument(
        '--job-name', type=str, default='Injectivity',
        help='Job name for the SLURM cluster.'
    )

    parser.add_argument(
        '--output', type=str, default='default',
        help='Output file name for SLURM to redirect stdout/stderr.'
    )

    parser.add_argument(
        '--nodes', type=int, default=1,
        help='Number of nodes for SLURM job.'
    )

    parser.add_argument(
        '--gpus-per-node', type=int, default=1,
        help='GPUs per node for the SLURM job.'
    )

    parser.add_argument(
        '--cpus-per-task', type=int, default=2,
        help='CPUs per task for the SLURM job.'
    )

    parser.add_argument(
        '--time', type=str, default='24:00:00',
        help='Maximum wall time for the SLURM job.'
    )

    # Default comes from the PARTITION environment variable (may be None).
    parser.add_argument(
        '--partition', type=str, default=PARTITION,
        help='Partition to use for the SLURM job.'
    )

    parser.add_argument(
        '--qos', type=str, default='normal',
        help='Quality of service to use for the SLURM job.'
    )

    parser.add_argument(
        '--env-path', type=str, default='../../envs/.sipit_py3_11_7',
        help='Path to Python virtual environment to be loaded.'
    )

    return parser

def construct_sbatch_command(
    args: Namespace,
    unknown_args: list[str],
    py_to_run: str
) -> tuple[str, str]:
    '''Render the SBATCH script for one job and resolve its log path.

    Side effect: creates ``./out`` (relative to the current working
    directory) if it does not exist, since the job log is written there.

    Args:
        args: Known arguments parsed by ``build_parser``; supplies the
            ``#SBATCH`` resource values, the log base name (``output``) and
            the virtual-environment path (``env_path``).
        unknown_args: Extra command-line tokens the parser did not
            recognise; forwarded to the Python script.
        py_to_run: Path of the Python script executed under ``srun``.

    Returns:
        ``(script_text, out_file)``: the complete batch script and the
        absolute path of the ``.out`` file given to ``#SBATCH --output``.
    '''
    os.makedirs(os.path.abspath('out'), exist_ok=True)
    out_file = os.path.abspath(f'out/{args.output}.out')

    # Account and partition come from the environment and may be unset.
    # Emitting them as the literal string "None" makes sbatch reject the job,
    # so omit the directive and let SLURM apply its own default instead.
    directives = []
    if SLURM_ACCOUNT:
        directives.append(f'#SBATCH -A {SLURM_ACCOUNT}')
    directives += [
        f'#SBATCH --job-name={args.job_name}',
        f'#SBATCH --output={out_file}',
        f'#SBATCH --time={args.time}',
    ]
    if args.partition:
        directives.append(f'#SBATCH --partition={args.partition}')
    directives += [
        f'#SBATCH --nodes={args.nodes}',
        f'#SBATCH --gres=gpu:{args.gpus_per_node}',
        # One task per GPU on each node.
        f'#SBATCH --ntasks-per-node={args.gpus_per_node}',
        f'#SBATCH --cpus-per-task={args.cpus_per_task}',
        f'#SBATCH --qos={args.qos}',
    ]
    header = '\n'.join(directives)

    # Quote the activate path so a venv directory containing spaces or shell
    # metacharacters is still sourced correctly.
    activate = shlex.quote(
        os.path.join(os.path.abspath(args.env_path), 'bin', 'activate')
    )

    # The job runs Hugging Face libraries in offline mode against a cache
    # under $SCRATCH; endpoint overrides are cleared.
    boilerplate = f'''\
#!/bin/bash -l

{header}

module load {PYTHON_VERSION}
module load {CUDA_VERSION}

nvidia-smi

export HF_HOME=$SCRATCH/hf_cache
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

unset TRANSFORMERS_CACHE
unset HF_ENDPOINT
unset HUGGINGFACE_CO_RESOLVE_ENDPOINT

source {activate}

'''
    # shlex.join quotes each token, so forwarded arguments containing spaces
    # or metacharacters (e.g. --prompt "a b") reach the script as one argv
    # entry. Tokens made of safe characters are emitted unchanged.
    command = shlex.join(
        ['srun', '--unbuffered', 'python3.11', '-u', py_to_run, *unknown_args]
    )

    return boilerplate + f'\n{command}\n', out_file

def write_and_submit_sbatch_script(sbatch_script: str, script_path: Path, out_file: str):
    '''Write the SBATCH script, submit it, and follow the job log.

    Steps:
      1. write ``sbatch_script`` to ``script_path`` and mark it executable;
      2. create or truncate ``out_file``;
      3. clear the terminal and submit the script with ``sbatch``;
      4. after a one-second pause, list the user's jobs with ``squeue``;
      5. stream ``out_file`` with ``tail -f`` until interrupted (Ctrl-C).

    External commands are run with argument lists rather than a shell
    string, so paths containing spaces or shell metacharacters are passed
    through intact.

    Args:
        sbatch_script: Full text of the batch script.
        script_path: Destination of the script file. A ``str`` is also accepted.
        out_file: Log file the job writes to (its directory must exist).

    Raises:
        subprocess.CalledProcessError: ``sbatch`` exited non-zero. The queue
            listing and log tail are skipped because no job was submitted.
        FileNotFoundError: a required command (e.g. ``sbatch``) is not on PATH.
    '''
    script_path = Path(script_path)
    script_path.write_text(sbatch_script, encoding='utf-8')
    os.chmod(script_path, 0o755)

    # Start from an empty log so `tail -f` doesn't replay a previous run.
    Path(out_file).write_text('', encoding='utf-8')

    # Constant command with no interpolated data, so the shell is harmless here.
    os.system('clear')
    subprocess.run(['sbatch', str(script_path)], check=True)

    # Give the scheduler a moment to register the job before listing the queue.
    time.sleep(1)
    # USER_ID may be unset; fall back to the login name instead of "None".
    subprocess.run(['squeue', '-u', USER_ID or getpass.getuser()])

    try:
        subprocess.run(['tail', '-f', out_file])
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop following the log; exit quietly.
        pass

def main():
    '''Parse CLI flags, render the SBATCH script, and hand it off for submission.

    Flags unknown to the parser are passed along to the generated job
    command. The script directory is created if it is missing.
    '''
    known, passthrough = build_parser().parse_known_args()

    script_text, log_path = construct_sbatch_command(known, passthrough, known.py_path)

    # Location of the generated SLURM script.
    target_dir = Path(known.sh_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    write_and_submit_sbatch_script(script_text, target_dir / known.sh_name, log_path)

if __name__ == '__main__':
    main()
