#!/bin/bash -l

# Launch Megatron GPT pretraining through torchrun on the current node.
#
# Usage: <this script> <rdzv_host> <rdzv_port>
#   $1, $2 - host and port of the c10d rendezvous endpoint.
# Reads SLURM_SUBMIT_DIR, SLURM_NNODES, SLURM_GPUS_ON_NODE, SLURM_JOB_ID and
# SLURM_NODEID from the SLURM environment. Log files are named per
# SLURM_NODEID, so this is presumably launched once per node (e.g. via srun)
# — confirm against the submitting batch script.

# Work from the submit directory so ./conf.sh and logs/ resolve against it.
# Fail loudly if it is unset instead of letting a bare `cd` jump to $HOME.
cd "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR is not set (not running under SLURM?)}" || exit 1

# Environment whose bin/activate is sourced before launch, and the Megatron
# source tree that provides pretrain_gpt.py (also exported as PYTHONPATH).
ENV_DIR=/env
MLM_DIR=/ML-Megatron

# torchrun launcher flags. $1:$2 (host:port) is the c10d rendezvous endpoint
# and the SLURM job ID is reused as the rendezvous ID. Abort with a usage
# message if either positional argument is missing, rather than handing
# torchrun a malformed ":" endpoint.
: "${1:?usage: $0 <rdzv_host> <rdzv_port>}"
: "${2:?usage: $0 <rdzv_host> <rdzv_port>}"

# Every flag line ends with a backslash (a line continuation inside double
# quotes), so no literal newlines are embedded between the flags.
DISTRIBUTED_ARGS="
    --nnodes $SLURM_NNODES \
    --nproc_per_node $SLURM_GPUS_ON_NODE \
    --rdzv_endpoint $1:$2 \
    --rdzv_id $SLURM_JOB_ID \
    --rdzv-backend c10d
"

# Load the experiment configuration. conf.sh is expected to define the
# *_ARGS strings spliced into the command below. Any it leaves unset expand
# to nothing (TODO confirm all are defined there). Check that the file
# exists: a failed `source` only prints an error, and the script would
# otherwise go on to launch torchrun with no training arguments at all.
if [[ ! -f ./conf.sh ]]; then
    echo "conf.sh not found in $PWD" >&2
    exit 1
fi
source ./conf.sh

# Build the full launch command as one string. It is logged and then run via
# eval, which re-parses it, so any quoting inside the conf.sh strings is
# honored at launch time.
cmd="torchrun $DISTRIBUTED_ARGS $MLM_DIR/pretrain_gpt.py \
              $TRAINING_ARGS \
              $DATA_ARGS \
              $NETWORK_ARGS \
              $MOE_ARGS \
              $MODEL_PARALLEL_ARGS \
              $PERFORMANCE_ARGS \
              $MIXED_PRECISION_ARGS \
              $LEARNING_RATE_ARGS \
              $INITIALIZATION_ARGS \
              $REGULARIZATION_ARGS \
              $CHECKPOINTING_ARGS \
              $LOGGING_ARGS
    "

# Activate the Python environment. Abort if it is missing rather than
# launching torchrun from whatever interpreter happens to be on PATH.
if [[ ! -f "$ENV_DIR/bin/activate" ]]; then
    echo "no activate script at $ENV_DIR/bin/activate" >&2
    exit 1
fi
source "$ENV_DIR/bin/activate"

# Make the Megatron tree importable (replaces any inherited PYTHONPATH).
export PYTHONPATH=$MLM_DIR
# NOTE(review): presumably required by Megatron for compute/communication
# overlap — confirm against the Megatron-LM version in use.
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Per-node logs go under logs/ in the submit directory. Create it up front:
# a redirection into a missing directory fails and the command never runs.
mkdir -p logs || exit 1

echo "Running torchrun"
# $cmd is deliberately unquoted. Word splitting folds any newlines embedded
# in the *_ARGS strings into plain separators (quoted, eval would treat them
# as command terminators), and eval then re-parses the words.
# shellcheck disable=SC2086
echo $cmd &> "logs/pretrain_cmd_$SLURM_NODEID.log"
# Last command: the script exits with torchrun's status.
# shellcheck disable=SC2086
eval $cmd &> "logs/pretrain_$SLURM_NODEID.log"

