#!/bin/bash

# Clean the node before (re)launching the task: force-kill processes matching
# "sglang", stop Ray, then force-kill any remaining ray/python processes.
force_kill_ray_and_python() {
  pkill -9 ray
  pkill -9 python
}

pkill -9 sglang
sleep 3
ray stop --force
force_kill_ray_and_python
# Repeat the kill after a short pause (presumably to catch stragglers).
sleep 3
force_kill_ray_and_python

# Trace every command from here on. Note that the trace prints fully expanded
# command lines, including any secret passed as an argument.
set -x

# Timestamp used to give this run's outputs unique names.
timestamp=$(date +%Y%m%d_%H%M%S)

#######################################
# Count NVLink connections on this node.
# Reads the GPU topology matrix (`nvidia-smi topo -m`) and counts the NV<n>
# cells, which mark GPU pairs connected through NVLink. The plain `nvidia-smi`
# summary does not normally contain the text "NVLink", so grepping it for
# that word always yields 0.
# Outputs: the number of NV<n> entries on stdout (0 if nvidia-smi is missing
#          or fails).
#######################################
count_nvlink_links() {
  nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l
}

NVLINK_COUNT=$(count_nvlink_links)
# HAS_NVLINK is a 0/1 flag, later passed on as NCCL_NVLS_ENABLE.
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"


# Load the Qwen3-30B-A3B model definition. The path is relative, so the script
# must be started from the directory that contains ./slime.
# NOTE(review): presumably this file defines MODEL_ARGS used at launch — confirm.
source "./slime/scripts/models/qwen3-30B-A3B.sh"

# Run-specific names, both derived from ${timestamp}.
WANDB_NAME=${timestamp}
SAVE_PATH=YOUR_SAVE_PATH_${timestamp}

# Checkpoint flags. --load and --save both point at this run's SAVE_PATH.
CKPT_ARGS=()
CKPT_ARGS+=(--hf-checkpoint YOUR_MODEL_PATH)
CKPT_ARGS+=(--ref-load YOUR_REF_LOAD_PATH)
CKPT_ARGS+=(--load ${SAVE_PATH} --save ${SAVE_PATH})
CKPT_ARGS+=(--save-interval 2000)

# SFT flags, assembled in two groups: dataset/batching, then loss/mode.
SFT_ARGS=()

# Dataset (JSONL read through the "messages" key), shuffling, epochs and
# batch sizes.
SFT_ARGS+=(
  --prompt-data ./ReForm/sft_data/sft_data.jsonl
  --input-key messages
  --rollout-shuffle
  --num-epoch 3
  --rollout-batch-size 512
  --global-batch-size 512
)

# Loss selection and training-mode switches.
SFT_ARGS+=(
  --loss-type sft_loss
  --loss-mask-type qwen3
  --calculate-per-token-loss
  --disable-compute-advantages-and-returns
  --debug-train-only
)

# Performance flags, assembled group by group.
PERF_ARGS=()

# Parallel layout: tensor-parallel size 8 with sequence parallelism; the
# pipeline, context, expert and expert-tensor parallel sizes are all 1.
PERF_ARGS+=(
  --tensor-model-parallel-size 8
  --sequence-parallel
  --pipeline-model-parallel-size 1
  --context-parallel-size 1
  --expert-model-parallel-size 1
  --expert-tensor-parallel-size 1
)

# Recompute settings: full granularity, uniform method, 1 layer.
PERF_ARGS+=(
  --recompute-granularity full
  --recompute-method uniform
  --recompute-num-layers 1
)

# Dynamic batch size with a per-GPU token limit. The fixed micro-batch-size
# alternative is left disabled:
#   --micro-batch-size 1
PERF_ARGS+=(
  --use-dynamic-batch-size
  --max-tokens-per-gpu 9216
)

# Optimizer flags: Adam with a cosine learning-rate schedule.
OPTIMIZER_ARGS=(--optimizer adam)

# Learning-rate schedule: 1e-5 decaying to 1e-6, warmup fraction 0.03.
OPTIMIZER_ARGS+=(--lr 1e-5 --lr-decay-style cosine --min-lr 1e-6)
OPTIMIZER_ARGS+=(--lr-warmup-fraction 0.03)

# Weight decay and Adam betas.
OPTIMIZER_ARGS+=(--weight-decay 0.1)
OPTIMIZER_ARGS+=(--adam-beta1 0.9 --adam-beta2 0.95)

# Weights & Biases logging flags.
#
# --wandb-key is appended only when WANDB_API_KEY is non-empty. With an unset
# or empty key, a bare "--wandb-key" would be left without a value, and the
# flag that follows it on the final command line would be consumed as the key
# (or rejected by the argument parser).
#
# NOTE: the key is passed as a command-line argument, so it is visible in
# `set -x` trace output and in process listings.
WANDB_ARGS=(
  --use-wandb
  --wandb-project YOUR_PROJECT_NAME
  --wandb-group "${WANDB_NAME}"
)
if [ -n "${WANDB_API_KEY:-}" ]; then
  WANDB_ARGS+=(--wandb-key "${WANDB_API_KEY}")
fi

# Miscellaneous model/runtime flags.
MISC_ARGS=()

# Megatron's default dropout is 0.1; turn both dropouts off.
MISC_ARGS+=(--attention-dropout 0.0 --hidden-dropout 0.0)

# FP32 gradient all-reduce accumulation and attention softmax
# (should be good for model performance).
MISC_ARGS+=(--accumulate-allreduce-grads-in-fp32)
MISC_ARGS+=(--attention-softmax-in-fp32)

# Flash attention backend; comment this out when using a model with MLA.
MISC_ARGS+=(--attention-backend flash)

# Launch. The node's role is chosen by RANK:
#   RANK == 0  -> start the Ray head on this node and submit the training job
#   otherwise  -> wait, then join the head's Ray cluster as a blocking worker
#
# Environment read here: RANK, MASTER_ADDR, NPROC_PER_NODE, WORLD_SIZE.
# RANK is quoted and defaults to 0: with an unset RANK the unquoted test
# `[ $RANK -eq 0 ]` is a syntax error for `[`, which silently sent the node
# down the worker branch.
if [ "${RANK:-0}" -eq 0 ]; then
    mkdir -p "${SAVE_PATH}"

    # launch the master node of ray in container
    export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
    export no_proxy="127.0.0.1,${MASTER_ADDR}"
    ray start --head \
        --node-ip-address "${MASTER_ADDR}" \
        --num-gpus "${NPROC_PER_NODE}" \
        --port=6379 \
        --disable-usage-stats \
        --dashboard-host=0.0.0.0 \
        --dashboard-port=8265

    # Build the runtime environment JSON with proper variable substitution
    RUNTIME_ENV_JSON="{
    \"env_vars\": {
        \"PYTHONPATH\": \"/root/Megatron-LM/\",
        \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
        \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
         \"no_proxy\": \"${no_proxy}\",
         \"MASTER_ADDR\": \"${MASTER_ADDR}\"
    }
    }"

    # Without pipefail the pipeline below reports tee's status, so a failed
    # job would still look like success to whatever launched this script.
    set -o pipefail

    # Every array is expanded quoted so each element reaches the job as exactly
    # one argument (no re-splitting, no glob expansion). DISTRIBUTED_ARGS and
    # EVAL_ARGS are not assigned in this script; unless the sourced model file
    # defines them they expand to nothing.
    ray job submit --address="http://127.0.0.1:8265" \
        --runtime-env-json="${RUNTIME_ENV_JSON}" \
        -- python3 ./slime/train_async.py \
        --actor-num-nodes "${WORLD_SIZE}" \
        --actor-num-gpus-per-node "${NPROC_PER_NODE}" \
        "${MODEL_ARGS[@]}" \
        "${CKPT_ARGS[@]}" \
        "${SFT_ARGS[@]}" \
        "${OPTIMIZER_ARGS[@]}" \
        "${DISTRIBUTED_ARGS[@]}" \
        "${WANDB_ARGS[@]}" \
        "${PERF_ARGS[@]}" \
        "${EVAL_ARGS[@]}" \
        "${MISC_ARGS[@]}" 2>&1 | tee "${SAVE_PATH}/log_${timestamp}.log"

else
    # Fixed delay before joining (presumably to give the head node time to
    # come up).
    sleep 60

    # Clear leftovers once more, then attach to the head's Ray port and block
    # for the lifetime of the worker.
    pkill -9 sglang
    ray stop --force
    pkill -9 python
    ray start --block \
        --address="${MASTER_ADDR}:6379" \
        --num-gpus "${NPROC_PER_NODE}" \
        --disable-usage-stats
fi
