"""
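Generate a Slurm batch script that trains on a CRBench task and then evaluates
the resulting checkpoint, and submit it with ``sbatch``.

Usage (shell commands):
$ uv run scripts/sbatch_crbench.py --crbench_task AdditionTaskSmall --base configs/cr/ltm-0.5b-base.yaml
$ tail -f .slurm_cache/lltm-additiontasksmall-{job_id}.out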
"""

import os 
import argparse 
import subprocess 
import importlib 
from omegaconf import OmegaConf 


def generate_sbatch (
base :str ,
env_path :str ,
work_dir :str ,
output_path :str ,
logdir :str ,
partition :str ,
job_name :str ,
resume :str |None =None ,
extra :str ="",
):
    return f"""#!/bin/bash
#SBATCH --job-name=lltm-{job_name }
#SBATCH --partition={partition }
#SBATCH --nodes=1
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=8
#SBATCH --output={output_path }/%x-%j.out


export MASTER_ADDR=$(scontrol show hostname "$SLURM_JOB_NODELIST" | head -n1)
export MASTER_PORT=$((10000 + ($SLURM_JOBID % 50000)))


export NUM_GPU_PER_NODE=8
NODE_TYPE="H100"


NUM_NODES=$SLURM_JOB_NUM_NODES
NUM_GPUS=$(($NUM_NODES * $NUM_GPU_PER_NODE))

#### Efficient multi-node training ####
UDS_PATH="/run/tcpx-$SLURM_JOB_ID"

# Only use TCPX for multi-node jobs.
[[ "$SLURM_JOB_NUM_NODES" -gt 1 ]] && export USE_TCPX=yes || export USE_TCPX=no

# Only use TCPX for multi-node jobs.
if [[ $USE_TCPX = "yes" ]]; then
    # Set up NCCL Environment variables
    export NCCL_NET=GPUDirectTCPX_v7
    export NCCL_SOCKET_IFNAME=enp0s12
    export NCCL_GPUDIRECTTCPX_CTRL_DEV=enp0s12
    export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=enp6s0,enp12s0,enp134s0,enp140s0
    export NCCL_CROSS_NIC=0
    export NCCL_ALGO=Ring
    export NCCL_PROTO=Simple
    export NCCL_NSOCKS_PERTHREAD=4
    export NCCL_SOCKET_NTHREADS=1
    export NCCL_MAX_NCHANNELS=12
    export NCCL_MIN_NCHANNELS=12
    export NCCL_DYNAMIC_CHUNK_SIZE=524288
    export NCCL_P2P_NET_CHUNKSIZE=524288
    export NCCL_P2P_PCI_CHUNKSIZE=524288
    export NCCL_P2P_NVL_CHUNKSIZE=1048576
    export NCCL_BUFFSIZE=4194304
    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
    export NCCL_NET_GDR_LEVEL=PIX
    export NCCL_P2P_PXN_LEVEL=0
    export NCCL_GPUDIRECTTCPX_UNIX_CLIENT_PREFIX=$UDS_PATH
    export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
    export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
    export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000

    export LD_LIBRARY_PATH=/var/lib/tcpx/lib64:$LD_LIBRARY_PATH
else
    unset NCCL_NET
fi

source {env_path .rstrip ("/")}/bin/activate

cd {work_dir }
# Training
srun python main.py --base {base } --name {job_name } --logdir {logdir } {resume } {extra }
# Evaluation
CKPT={logdir }/{job_name }/checkpoints/last.ckpt
python scripts/run_crbench.py --ckpt $CKPT --num_gpus=8 --num_examples=100
"""


def get_obj_from_str(string):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Everything before the last dot is imported as a module; the part after
    it is then fetched from that module as an attribute.

    Raises:
        ValueError: If ``string`` contains no dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the final name.
    """
    module_path, attr_name = string.rsplit(".", 1)
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)


def main(
    crbench_task: str = "AdditionTaskSmall",
    name: str | None = None,
    base: str = "configs/cr/ltm-0.5b-base.yaml",
    logdir: str = "crbench_logs",
    cache_dir: str = ".slurm_cache",
    partition: str = "a3",
    resume: str | None = None,
    extra: str = "",
) -> None:
    """Write a task-specific config and sbatch script, then submit the job.

    Steps:
      1. Check that the Python environment directory exists (``$ENV_PATH``,
         falling back to ``<work_dir>/.venv``).
      2. Check that ``crbench.<crbench_task>`` can be resolved.
      3. Load ``base``, point ``data.params.crbench_task_config.target`` at
         the task, and save the result as
         ``<work_dir>/configs/cr/<name>-ltm-0.5b.yaml``.
      4. Write ``<cache_dir>/<name>.sh`` and submit it with ``sbatch``.

    Args:
        crbench_task: Attribute path below the ``crbench`` package.
        name: Run name; defaults to the lower-cased last dotted component of
            ``crbench_task``.
        base: Base OmegaConf YAML file to derive the run config from.
        logdir: Log directory passed through to the training command.
        cache_dir: Directory for the generated script and the Slurm log.
        partition: Slurm partition to submit to.
        resume: Optional value for the training command's ``--resume`` flag.
        extra: Additional raw arguments appended to the training command.

    Raises:
        FileNotFoundError: If the environment directory does not exist.
        ImportError: If the CRBench task cannot be resolved.
        subprocess.CalledProcessError: If ``sbatch`` exits with an error.
    """
    work_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    env_path = os.environ.get("ENV_PATH", os.path.join(work_dir, ".venv"))
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
    if not os.path.exists(env_path):
        raise FileNotFoundError(f"Python environment {env_path} does not exist")

    crbench_string = f"crbench.{crbench_task}"
    try:
        get_obj_from_str(crbench_string)
    except Exception as err:
        # ``Exception`` (not a bare except) keeps Ctrl-C / SystemExit working;
        # chaining preserves the real reason the lookup failed.
        raise ImportError(f"CRBench task {crbench_task} not found") from err

    if name is None:
        name = crbench_task.split(".")[-1].lower()

    resume_args = f"--resume {resume}" if resume is not None else ""

    config = OmegaConf.load(base)
    config.data.params.crbench_task_config.target = crbench_string

    # The job script does ``cd work_dir`` before reading this relative path,
    # so the file has to be written below work_dir rather than below the
    # current working directory.
    config_path = f"configs/cr/{name}-ltm-0.5b.yaml"
    config_file = os.path.join(work_dir, config_path)
    os.makedirs(os.path.dirname(config_file), exist_ok=True)
    OmegaConf.save(config, config_file)

    sbatch_script = os.path.join(cache_dir, f"{name}.sh")
    print(f"Writing sbatch script to {sbatch_script}...")
    print("#" * 100)
    print(f"partition: {partition}")
    print(f"cache_dir: {cache_dir}")
    print(f"env_path: {env_path}")
    print(f"base: {config_path}")
    print(f"logdir: {logdir}")
    os.makedirs(cache_dir, exist_ok=True)
    script_text = generate_sbatch(
        base=config_path,
        env_path=env_path,
        work_dir=work_dir,
        output_path=cache_dir,
        logdir=logdir,
        partition=partition,
        job_name=name,
        resume=resume_args,
        extra=extra,
    )
    with open(sbatch_script, "w", encoding="utf-8") as fw:
        fw.write(script_text)

    print("Submitting the job...")
    # Pass an argument list so a script path containing whitespace is not
    # split into several arguments.
    submit_command = ["sbatch", sbatch_script]
    print(" ".join(submit_command))
    try:
        result = subprocess.run(
            submit_command, capture_output=True, text=True, check=True
        )
    except subprocess.CalledProcessError as e:
        print(e.stdout)
        # Captured stderr would otherwise be lost; show it alongside stdout.
        print(e.stderr)
        raise
    print(result.stdout)


if __name__ == "__main__":
    # Command-line entry point: every flag maps one-to-one onto a keyword
    # argument of main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--crbench_task", type=str, required=True)

    # Optional string flags and their defaults, in help-output order.
    # NOTE(review): the --base default here is the "-h200" config, whereas
    # main()'s own default is "configs/cr/ltm-0.5b-base.yaml" — confirm which
    # one is intended.
    optional_flags = (
        ("--name", None),
        ("--base", "configs/cr/ltm-0.5b-base-h200.yaml"),
        ("--resume", None),
        ("--logdir", "crbench_logs"),
        ("--cache_dir", ".slurm_cache"),
        ("--partition", "a3"),
        ("--extra", ""),
    )
    for flag, default_value in optional_flags:
        cli.add_argument(flag, type=str, required=False, default=default_value)

    parsed = cli.parse_args()
    main(**vars(parsed))
