#!/bin/bash -x
# Slurm batch script: starts one `accelerate launch` job step per allocated
# node to run train_model.py across all nodes and GPUs of the allocation.
# Arguments given to this script are forwarded to the training script.
# The -x on the shebang traces every executed command to the job's stderr file.
#
# Resources: 2 nodes, 4 GPUs on each node, one task per node (that task is the
# per-node launcher), one hardware thread per core, 1 hour wall-clock limit.
# NOTE(review): stdout/stderr go to slurm/out.<jobid> and slurm/err.<jobid>,
# relative to the submission directory; presumably the slurm/ directory has to
# exist before submitting — confirm.
#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=1
#SBATCH --job-name=anGPT
#SBATCH --threads-per-core=1
#SBATCH --time=1:00:00
#SBATCH --output=slurm/out.%j
#SBATCH --error=slurm/err.%j


# Project root: the generated accelerate config, the virtualenv and the
# training script are all resolved relative to this directory.
# NOTE(review): the value looks like a placeholder path — confirm it points at
# the real checkout before submitting.
CODE_DIR=".../SM_anGPT"

# Hand the batch job's cpus-per-task value to the srun steps started later.
# NOTE(review): no --cpus-per-task directive is visible in this script, so
# SLURM_CPUS_PER_TASK may be unset and this exports an empty value — confirm.
export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"
# GPU indices 0-3, matching the four GPUs requested per node.
export CUDA_VISIBLE_DEVICES="0,1,2,3"

# The first hostname of the allocation is used as the main-process address.
# The node list is quoted because it is a bracketed range expression
# (e.g. node[01-02]) that the shell would otherwise treat as a glob pattern.
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
export MASTER_PORT=20073
export NUM_NODES="$SLURM_JOB_NUM_NODES"
# Both names carry the same value; only NUM_GPUS_PER_NODE is used below.
export GPUS_PER_NODE=4
export NUM_GPUS_PER_NODE=4
# Total number of processes across the allocation (one per GPU).
export NUM_GPUS=$((NUM_GPUS_PER_NODE * SLURM_NNODES))

# Write the accelerate launcher configuration for this allocation. The file is
# regenerated on every run from the values computed above (main-process
# address/port, node count, total process count).
# NOTE(review): the path is fixed inside CODE_DIR, so concurrent jobs sharing
# the same checkout overwrite each other's config — confirm this is acceptable.
ACCELERATE_CONFIG_FILE="${CODE_DIR}/scripts/accelerate.yaml"
# Abort early if the file cannot be written: launching with a missing or stale
# config would only fail later, and less clearly.
if ! cat > "$ACCELERATE_CONFIG_FILE" << EOT
# WARNING: do not edit this file as this is an slurm-auto-generated file
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT
then
    echo "error: could not write accelerate config to ${ACCELERATE_CONFIG_FILE}" >&2
    exit 1
fi

# Activate the project's virtualenv so `accelerate` and the training
# dependencies resolve from it; stop if the environment is missing.
if ! source "${CODE_DIR}/venv/bin/activate"; then
    echo "error: could not activate virtualenv at ${CODE_DIR}/venv" >&2
    exit 1
fi
PYTHON_SCRIPT="${CODE_DIR}/train_model.py"

# Start one job step per node. Step i is placed on the i-th node of the
# allocation (-r i) and runs `accelerate launch` with machine rank i.
# The command is kept as an array rather than an eval'd string so that the
# script's own arguments ("$@") reach the training script exactly as given,
# without a second round of word splitting and evaluation.
step_pids=()
for (( i = 0; i < NUM_NODES; i++ )); do
    launch_cmd=(
        srun -lN1 -n1 -r "$i"
        accelerate launch
        --config_file "$ACCELERATE_CONFIG_FILE"
        --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT}"
        --main_process_ip "$MASTER_ADDR"
        --main_process_port "$MASTER_PORT"
        --machine_rank "$i"
        "$PYTHON_SCRIPT" -c default_config.yaml "$@"
    )
    "${launch_cmd[@]}" &
    step_pids+=("$!")
done

# Wait for every step, not only the last one started: otherwise the batch
# script could end while other ranks are still running, and a failure on any
# of those ranks would go unnoticed.
launch_status=0
for step_pid in "${step_pids[@]}"; do
    wait "$step_pid" || launch_status=$?
done
# Report a failed step through the job's exit code.
(( launch_status == 0 )) || exit "$launch_status"

