#!/bin/bash -l

# Root of the Python environment; its bin/activate script is sourced before
# the launch.
ENV_DIR=/env
# Source tree that provides tasks/main.py; also exported as PYTHONPATH.
MLM_DIR=/ML-Megatron

# Launcher options handed to torchrun.
#   $1, $2  - joined as "$1:$2" for --rdzv_endpoint (torchrun expects
#             host:port there)
#   SLURM_NNODES, SLURM_GPUS_ON_NODE, SLURM_JOB_ID - read from the environment
if [[ -z "${1:-}" || -z "${2:-}" ]]; then
    # Non-fatal on purpose: only report, so existing invocations keep working.
    printf '%s\n' 'warning: rendezvous host ($1) and port ($2) should both be given' >&2
fi

# Assembled piecewise so the value is a single line. A multi-line quoted
# string silently gains embedded newlines wherever a trailing backslash is
# forgotten, which would split the launch into several commands if the final
# command string were ever expanded inside quotes.
DISTRIBUTED_ARGS="--nnodes $SLURM_NNODES"
DISTRIBUTED_ARGS+=" --nproc_per_node $SLURM_GPUS_ON_NODE"
DISTRIBUTED_ARGS+=" --rdzv_endpoint $1:$2"
DISTRIBUTED_ARGS+=" --rdzv_id $SLURM_JOB_ID"
DISTRIBUTED_ARGS+=" --rdzv-backend c10d"

# Load the run configuration from the parent of the current working directory.
# The *_ARGS option groups used in the command line are not set in this
# script; presumably this file provides them -- TODO confirm.
# Abort early when it is missing: otherwise bash only prints an error and the
# launch proceeds with every option group empty. Readability is tested
# (rather than the return status of `source`) so that a config whose last
# command happens to return non-zero is not rejected.
if [[ ! -r ../conf.sh ]]; then
    printf '%s\n' 'error: cannot read ../conf.sh' >&2
    exit 1
fi
source ../conf.sh

# Build the evaluation command line as one string, piece by piece.
# It is later word-split and re-parsed by `eval`, so only the sequence of
# words matters, not the spacing between them.
cmd="torchrun $DISTRIBUTED_ARGS $MLM_DIR/tasks/main.py"

# Option groups that are not set in this script (presumably supplied by the
# sourced ../conf.sh -- TODO confirm).
cmd+=" $DATA_ARGS $NETWORK_ARGS $MOE_ARGS"
cmd+=" $MODEL_PARALLEL_ARGS $PERFORMANCE_ARGS $MIXED_PRECISION_ARGS"
cmd+=" $LEARNING_RATE_ARGS $INITIALIZATION_ARGS $REGULARIZATION_ARGS"

# Logging cadence, task selection and the validation files to evaluate.
cmd+=" --log-interval 50 --task ALL-QA"
cmd+=" --valid-data piqa_validation.jsonl hellaswag_validation.jsonl arc_e_validation.jsonl"

# Batch sizes.
cmd+=" --micro-batch-size 40 --global-batch-size 40"

# Model implementation switch and checkpoint-loading flags.
cmd+=" --use-mcore-models --load ../checkpoint --no-load-optim --no-load-rng"

# Activate the Python environment under $ENV_DIR. Fail fast when the
# activation script is missing instead of launching with whatever
# interpreter happens to be on PATH.
if [[ ! -r "$ENV_DIR/bin/activate" ]]; then
    printf 'error: cannot read %s\n' "$ENV_DIR/bin/activate" >&2
    exit 1
fi
source "$ENV_DIR/bin/activate"

export PYTHONPATH=$MLM_DIR
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Both redirections below fail (and the command never runs) when the log
# directory is absent. `mkdir -p` succeeds if it already exists, so several
# nodes may run this concurrently.
mkdir -p logs || exit 1

# $cmd is intentionally left unquoted in both places: word splitting folds
# any embedded newlines into single spaces, so the command is logged on one
# line and evaluated as ONE command. Quoting it would make `eval` treat each
# newline as a command separator.
# shellcheck disable=SC2086
echo $cmd &> "logs/evaluation_cmd_$SLURM_NODEID.log"
# Kept as the last command so the script's exit status is that of the launch.
# shellcheck disable=SC2086
eval $cmd &> "logs/evaluation_$SLURM_NODEID.log"
