"""
For an interactive session instead of a batch job, run one of:
srun --partition a3 --nodes 1 --gpus-per-node=8 --ntasks-per-node=8 --pty bash
srun --partition=a3 --nodes=1 --gres=gpu:4 --ntasks-per-node=4 --cpus-per-task=8 --pty bash
"""

import os 
import datetime 
import subprocess 
import argparse 


def get_gpu_string(num_nodes: int = None, num_gpus: int = None):
    """Return the ``#SBATCH`` resource directives for the requested allocation.

    Exactly one of the two arguments must be given.

    Args:
        num_nodes: Number of nodes to request. Every node is requested with
            8 GPUs and 8 tasks.
        num_gpus: Total number of GPUs to request through ``--gpus``.

    Returns:
        One or more newline-terminated ``#SBATCH`` lines.

    Raises:
        ValueError: If both, or neither, of ``num_nodes`` and ``num_gpus``
            are given.
    """
    if num_nodes is not None and num_gpus is not None:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        raise ValueError("You cannot specify both num_nodes and num_gpus")
    if num_nodes is not None:
        return f"""#SBATCH --nodes {num_nodes}
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=8
"""
    if num_gpus is not None:
        return f"""#SBATCH --gpus {num_gpus}
"""
    raise ValueError("You must specify either num_nodes or num_gpus")


def get_device_string(num_nodes: int, num_gpus: int):
    """Build the ``--devices`` command-line fragment for the training script.

    When ``num_nodes`` is given it takes precedence: devices 0-7 are listed
    and ``--num_nodes`` is appended. Otherwise devices ``0 .. num_gpus - 1``
    are listed.

    Raises:
        ValueError: If both ``num_nodes`` and ``num_gpus`` are ``None``.
    """
    if num_nodes is None and num_gpus is None:
        raise ValueError("You must specify either num_nodes or num_gpus")
    if num_nodes is None:
        device_ids = ",".join(map(str, range(num_gpus)))
        return f"--devices {device_ids}"
    return f"--devices 0,1,2,3,4,5,6,7 --num_nodes {num_nodes}"


def generate_sbatch(
    base: str,
    env_path: str,
    work_dir: str,
    output_path: str,
    logdir: str,
    partition: str,
    num_nodes: int,
    num_gpus: int,
    job_name: str,
    projectname: str,
    resume: str,
    debug: str,
    extra: str = "",
):
    """Render the text of the sbatch script that launches ``main.py``.

    The script consists of the ``#SBATCH`` header, the MASTER_ADDR/MASTER_PORT
    exports, a block of NCCL/TCPX exports that is only active when the job
    spans more than one node, activation of the Python environment, and the
    final ``srun python main.py ...`` command.

    Args:
        base: Config path forwarded as ``--base``.
        env_path: Environment directory whose ``bin/activate`` is sourced;
            trailing slashes are stripped.
        work_dir: Directory the script changes into before launching.
        output_path: Directory for the Slurm ``%x-%j.out`` output file.
        logdir: Value forwarded as ``--logdir``.
        partition: Slurm partition name.
        num_nodes: Node count, or ``None`` when ``num_gpus`` is used.
        num_gpus: GPU count, or ``None`` when ``num_nodes`` is used.
        job_name: Slurm job name.
        projectname: Value forwarded as ``--projectname``.
        resume: Already-rendered resume fragment (may be empty), appended
            verbatim to the launch command.
        debug: Already-rendered debug fragment (may be empty), appended
            verbatim to the launch command.
        extra: Additional arguments appended verbatim to the launch command.

    Returns:
        The full script as a single string.
    """
    # Resolve the computed fragments first so the template below only
    # interpolates plain names.
    resource_directives = get_gpu_string(num_nodes=num_nodes, num_gpus=num_gpus)
    venv_root = env_path.rstrip("/")
    device_args = get_device_string(num_nodes, num_gpus)

    return f"""#!/bin/bash
#SBATCH --job-name={job_name}
###SBATCH --time=0:30:00
#SBATCH --partition={partition}
{resource_directives}
#SBATCH --output={output_path}/%x-%j.out
###SBATCH --error=outputs/%x-%j.out

# module load
# module load cuda/12.1
# module load cudnn/8.9.7
# module load nccl/2.20.5
# module load hpcx/2.19

export MASTER_ADDR=$(scontrol show hostname "$SLURM_JOB_NODELIST" | head -n1)
export MASTER_PORT=$((10000 + ($SLURM_JOBID % 50000)))


export NUM_GPU_PER_NODE=8
NODE_TYPE="H100"


NUM_NODES=$SLURM_JOB_NUM_NODES
NUM_GPUS=$(($NUM_NODES * $NUM_GPU_PER_NODE))

#### Efficient multi-node training ####
UDS_PATH="/run/tcpx-$SLURM_JOB_ID"

# Only use TCPX for multi-node jobs.
[[ "$SLURM_JOB_NUM_NODES" -gt 1 ]] && export USE_TCPX=yes || export USE_TCPX=no

# Only use TCPX for multi-node jobs.
if [[ $USE_TCPX = "yes" ]]; then
    # Set up NCCL Environment variables
    export NCCL_NET=GPUDirectTCPX_v7
    export NCCL_SOCKET_IFNAME=enp0s12
    export NCCL_GPUDIRECTTCPX_CTRL_DEV=enp0s12
    export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=enp6s0,enp12s0,enp134s0,enp140s0
    export NCCL_CROSS_NIC=0
    export NCCL_ALGO=Ring
    export NCCL_PROTO=Simple
    export NCCL_NSOCKS_PERTHREAD=4
    export NCCL_SOCKET_NTHREADS=1
    export NCCL_MAX_NCHANNELS=12
    export NCCL_MIN_NCHANNELS=12
    export NCCL_DYNAMIC_CHUNK_SIZE=524288
    export NCCL_P2P_NET_CHUNKSIZE=524288
    export NCCL_P2P_PCI_CHUNKSIZE=524288
    export NCCL_P2P_NVL_CHUNKSIZE=1048576
    export NCCL_BUFFSIZE=4194304
    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
    export NCCL_NET_GDR_LEVEL=PIX
    export NCCL_P2P_PXN_LEVEL=0
    export NCCL_GPUDIRECTTCPX_UNIX_CLIENT_PREFIX=$UDS_PATH
    export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
    export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
    export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000

    export LD_LIBRARY_PATH=/var/lib/tcpx/lib64:$LD_LIBRARY_PATH
else
    unset NCCL_NET
fi

# debugging flags (optional)
# export NCCL_DEBUG=INFO
# export PYTHONFAULTHANDLER=1
# 30 mins to 1 hour
export DEEPSPEED_TIMEOUT=60

export HF_HUB_ENABLE_HF_TRANSFER=1
source {venv_root}/bin/activate
# source  ~/miniconda3/etc/profile.d/conda.sh
# conda activate {venv_root}

cd {work_dir}
srun python main.py --base {base} {device_args} --projectname {projectname} --logdir {logdir} {resume} {debug} {extra}
"""


def _config_job_name(base: str) -> str:
    """Derive a dash-joined name from the part of ``base`` below ``configs``.

    Raises:
        ValueError: If no directory component of ``base`` is named ``configs``.
    """
    config_dir, config_file = os.path.split(base)
    dir_parts = config_dir.split(os.sep)
    if "configs" not in dir_parts:
        raise ValueError(
            f"base must point to a file below a 'configs' directory, got: {base}"
        )
    sub_dirs = dir_parts[dir_parts.index("configs") + 1:]
    stem = os.path.splitext(config_file)[0]
    # Join every part in one go so a config that sits directly in ``configs/``
    # does not end up with a leading "-".
    return "-".join([*sub_dirs, stem])


def main(
    base: str,
    env_path: str = None,
    projectname: str = "lltm",
    logdir: str = "logs",
    cache_dir: str = ".slurm_cache",
    partition: str = "a3",
    num_nodes: int = 1,
    num_gpus: int = None,
    resume: str = None,
    debug: bool = False,
    job_name: str = None,
    extra: str = "",
):
    """Write an sbatch script for ``base`` into ``cache_dir`` and submit it.

    Args:
        base: Path to the config file; it must live below a ``configs``
            directory, and the path below that directory names the job.
        env_path: Environment directory to activate in the job. The
            ``ENV_PATH`` environment variable, when set, takes precedence.
        projectname: Forwarded to the training script as ``--projectname``.
        logdir: Forwarded to the training script as ``--logdir``.
        cache_dir: Directory that receives the generated script and the
            Slurm output file; created if missing.
        partition: Slurm partition to submit to.
        num_nodes: Number of nodes to request; ignored when ``num_gpus``
            is given.
        num_gpus: Number of GPUs to request instead of whole nodes.
        resume: Optional value forwarded as ``--resume``.
        debug: When true, ``--debug`` is passed to the training script.
        job_name: Slurm job name; defaults to the name derived from ``base``.
        extra: Extra arguments appended verbatim to the launch command.

    Raises:
        ValueError: If no environment path is available, or ``base`` is not
            below a ``configs`` directory.
        subprocess.CalledProcessError: If ``sbatch`` exits with an error.
    """
    if num_gpus is not None:
        # An explicit GPU count replaces the (default) whole-node request.
        num_nodes = None

    env_path = os.environ.get("ENV_PATH", env_path)
    if not env_path:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        raise ValueError("You must set env_path!")

    # The job runs from the parent of the directory that contains this file.
    work_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    cache_dir = os.path.abspath(os.path.realpath(cache_dir))
    # exist_ok avoids the check-then-create race and still fails loudly when
    # the path exists but is not a directory.
    os.makedirs(cache_dir, exist_ok=True)

    resume_arg = f"--resume {resume}" if resume is not None else ""
    debug_arg = "--debug" if debug else ""

    cfg_name = _config_job_name(base)
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    sbatch_script = os.path.join(cache_dir, f"{now}-{cfg_name}.sh")
    print(f"Writing sbatch script to {sbatch_script}...")
    print("#" * 100)
    print(f"partition: {partition}")
    print(f"num_nodes: {num_nodes}")
    print(f"num_gpus: {num_gpus}")
    print(f"cache_dir: {cache_dir}")
    print(f"env_path: {env_path}")
    print(f"base: {base}")
    print(f"logdir: {logdir}")
    if resume_arg:
        print(f"resume: {resume_arg}")
    if debug_arg:
        print(f"debug: {debug_arg}")

    script_text = generate_sbatch(
        base=base,
        env_path=env_path,
        work_dir=work_dir,
        output_path=cache_dir,
        logdir=logdir,
        partition=partition,
        num_nodes=num_nodes,
        num_gpus=num_gpus,
        job_name=cfg_name if job_name is None else job_name,
        projectname=projectname,
        resume=resume_arg,
        debug=debug_arg,
        extra=extra,
    )
    with open(sbatch_script, "w", encoding="utf-8") as fw:
        fw.write(script_text)

    print("Submitting the job...")
    # Pass an argument list so a script path containing whitespace stays a
    # single argument.
    submit_command = ["sbatch", sbatch_script]
    print(" ".join(submit_command))
    try:
        completed = subprocess.run(
            submit_command, capture_output=True, text=True, check=True
        )
    except subprocess.CalledProcessError as e:
        print(e.stdout)
        # stderr is captured too; print it so a failure message written there
        # is not lost.
        print(e.stderr)
        raise
    print(completed.stdout)


if __name__ == "__main__":
    # Command-line entry point. Each entry is (flag, add_argument options);
    # every flag maps one-to-one onto a keyword argument of main(), and the
    # order here is the order shown in --help.
    cli_options = [
        ("--base", dict(type=str, required=True, help="Path to the config file")),
        ("--env_path", dict(type=str, help="Path to the conda environment")),
        ("--projectname", dict(type=str, default="lltm", help="Project name")),
        ("--logdir", dict(type=str, default="logs", help="Log directory")),
        ("--cache_dir", dict(type=str, default=".slurm_cache", help="Cache directory")),
        ("--partition", dict(type=str, default="a3", help="Partition")),
        ("--num_nodes", dict(type=int, default=1, help="Number of nodes")),
        ("--num_gpus", dict(type=int, default=None, help="Number of GPUs")),
        ("--resume", dict(type=str, default=None, help="Resume from the checkpoint")),
        ("--debug", dict(action="store_true", help="Debug mode")),
        ("--job_name", dict(type=str, default=None, help="Job name")),
        ("--extra", dict(type=str, default="", help="Extra arguments")),
    ]
    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    main(**vars(cli.parse_args()))
