#!/bin/bash
# Strict mode: stop on a failing command, on use of an unset variable, and on
# a failure in any stage of a pipeline, instead of silently carrying on.
set -euo pipefail

# CUDA runtime settings; exported so the launched training processes inherit them.
export CUDA_DEVICE_MAX_CONNECTIONS=1
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, which
# is normally a debugging aid - confirm it is intended for a full training run.
export CUDA_LAUNCH_BLOCKING=1
# Eight devices are exposed, matching GPUS_PER_NODE=8 further down.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Append the grouter package directory to the module search path. The entry is
# relative, like the other paths in this script. ${PYTHONPATH:-} yields the
# same value as a bare $PYTHONPATH but does not abort under `set -u` when the
# variable is not set in the caller's environment.
export PYTHONPATH="${PYTHONPATH:-}:grouter_ep_optimizer"

# Checkpoint directory (passed to both --save and --load) and tokenizer selection.
CHECKPOINT_PATH="checkpoints/350m_grt_ft_score"
TOKENIZER_TYPE="HuggingFaceTokenizer"
TOKENIZER_MODEL="model_home/qwen3-30b-a3b"

# Dataset list: one entry per shard 0040..0599 under C4_HOME, each written as
# " 0.01 <prefix>" (the 0.01 is presumably the per-shard blend weight) and
# concatenated into a single space-separated string.
C4_HOME="/workspace/Megatron-LM-router/qwen3_dataset"
DATA_PATH=""
for ((shard = 40; shard <= 599; shard++)); do
    # Shard names use a zero-padded four-digit index.
    printf -v shard_id '%04d' "${shard}"
    DATA_PATH+=" 0.01 ${C4_HOME}/qwen3-c4-${shard_id}_text_document"
done

# Router flags, kept as one whitespace-separated string because MOE_ARGS
# word-splits it when the array is built. Each fragment carries four leading
# spaces and one trailing space so the resulting value is unchanged.
ROUTER_ARGS=""
ROUTER_ARGS+="    --use-grouter-weight "
ROUTER_ARGS+="    --moe-router-load-balancing-type none "
ROUTER_ARGS+="    --moe-use-grouter "
ROUTER_ARGS+="    --use-single-grouter "
# Checkpoint and config for the grouter, both relative paths.
ROUTER_ARGS+="    --grouter-checkpoint-path grouter_ep_optimizer/grouter/grouter_ft_32.pth "
ROUTER_ARGS+="    --grouter-config-path grouter_ep_optimizer/grouter/config_32.json "

# Distributed launch settings for torchrun.
# GPUS_PER_NODE is fixed; the remaining values are taken from the environment
# when present and otherwise fall back to a single-node run on localhost.
GPUS_PER_NODE=8
MASTER_ADDR="${MASTER_ADDR:-localhost}"
MASTER_PORT="${MASTER_PORT:-6003}"
# The node count is read from WORLD_SIZE and the node rank from RANK.
NNODES="${WORLD_SIZE:-1}"
NODE_RANK="${RANK:-0}"

# Every value is quoted so an environment-supplied string is stored as exactly
# one array element (no word splitting, no pathname expansion).
DISTRIBUTED_ARGS=(
    --nproc_per_node "${GPUS_PER_NODE}"
    --nnodes "${NNODES}"
    --node_rank "${NODE_RANK}"
    --master_addr "${MASTER_ADDR}"
    --master_port "${MASTER_PORT}"
)

# Model architecture flags, appended group by group so related options sit
# together. The order of appends is the order the flags appear on the
# command line.
MODEL_ARGS=()

# Model implementation switches and sequence limits.
MODEL_ARGS+=(--use-mcore-models --disable-bias-linear)
MODEL_ARGS+=(--seq-length 4096 --max-position-embeddings 131072)

# Layer count and layer widths.
MODEL_ARGS+=(--num-layers 8 --hidden-size 360 --ffn-hidden-size 720)

# Attention: 32 heads, grouped-query attention with 4 query groups.
MODEL_ARGS+=(--num-attention-heads 32 --group-query-attention --num-query-groups 4)
MODEL_ARGS+=(--kv-channels 32)

# Initialisation and dropout (both dropout rates are 0.0).
MODEL_ARGS+=(--init-method-std 0.02 --attention-dropout 0.0 --hidden-dropout 0.0)

# Normalisation.
MODEL_ARGS+=(--normalization RMSNorm --norm-epsilon 1e-5)

# Rotary position embeddings. "--rope-type yarn" is the disabled alternative.
MODEL_ARGS+=(--position-embedding-type rope --rope-type rope)
MODEL_ARGS+=(--rotary-base 150000 --rotary-percent 1.0 --rotary-scaling-factor 32.0)

# Activation, embedding tying and vocabulary size.
# Currently disabled: --masked-softmax-fusion, --bias-activation-fusion,
# --apply-rope-fusion.
MODEL_ARGS+=(--swiglu --untie-embeddings-and-output-weights --vocab-size 151936)

MODEL_ARGS+=(--mscale 1.0 --mscale-all-dim 1.0)

# Mixture-of-experts flags, followed by the router options from ROUTER_ARGS.
MOE_ARGS=()

# 32 experts, top-4 routing.
MOE_ARGS+=(--num-experts 32 --moe-router-topk 4)

# Expert compute and token dispatch.
MOE_ARGS+=(--moe-grouped-gemm --moe-permute-fusion)
MOE_ARGS+=(--moe-token-dispatcher-type alltoall)
MOE_ARGS+=(--moe-ffn-hidden-size 720)

# Currently disabled here: --moe-router-score-function softmax,
# --moe-router-load-balancing-type none (ROUTER_ARGS passes that flag instead),
# --moe-router-bias-update-rate 0.001.
MOE_ARGS+=(--moe-layer-freq 1)

# ROUTER_ARGS is a whitespace-separated string; it is expanded unquoted on
# purpose so each flag and value becomes its own element.
# shellcheck disable=SC2206
MOE_ARGS+=(${ROUTER_ARGS})

# Tokenizer and dataset arguments.
#
# DATA_PATH is a single whitespace-separated string of alternating "0.01" and
# dataset-prefix words. It is split into individual words here so that every
# word is its own array element and DATA_ARGS expands to the same argument list
# whether or not the expansion is quoted. `read -a` splits on IFS without
# performing pathname expansion.
read -r -a DATA_BLEND <<< "${DATA_PATH}"

DATA_ARGS=(
    --tokenizer-model "${TOKENIZER_MODEL}"
    --tokenizer-type "${TOKENIZER_TYPE}"
    --data-path "${DATA_BLEND[@]}"
    --split 990,8,2
)

# Optimisation and schedule flags, appended in command-line order.
TRAINING_ARGS=()

# Batch sizes.
# Activation recomputation is currently disabled (--recompute-granularity full,
# --recompute-method uniform, --recompute-num-layers 1 are commented out).
TRAINING_ARGS+=(--micro-batch-size 8 --global-batch-size 256)

# Learning rate, run length and decay settings.
TRAINING_ARGS+=(--lr 1e-4 --train-iters 30000)
TRAINING_ARGS+=(--lr-decay-iters 25000 --lr-decay-style cosine --min-lr 1.0e-5)

# Adam hyper-parameters and weight decay.
TRAINING_ARGS+=(--adam-beta1 0.9 --adam-beta2 0.999 --adam-eps 1e-8 --weight-decay 0.01)

# Warm-up, gradient clipping and bf16.
TRAINING_ARGS+=(--lr-warmup-iters 5000 --clip-grad 1.0)
TRAINING_ARGS+=(--bf16)


# Parallelism layout, appended in command-line order.
MODEL_PARALLEL_ARGS=()

# Tensor and pipeline parallel sizes are both 1; expert parallel size is 4.
MODEL_PARALLEL_ARGS+=(--tensor-model-parallel-size 1 --pipeline-model-parallel-size 1)
MODEL_PARALLEL_ARGS+=(--expert-model-parallel-size 4)

# NOTE(review): --sequence-parallel is passed together with a tensor parallel
# size of 1 - confirm it has any effect in that configuration.
MODEL_PARALLEL_ARGS+=(--use-distributed-optimizer --sequence-parallel)

# Logging, evaluation and checkpoint arguments.
# CHECKPOINT_PATH is used three times: as the save target, as the load source,
# and as the parent of the TensorBoard directory. Each use is quoted so the
# path is always stored as exactly one element after its flag.
LOGGING_ARGS=(
    --log-interval 1
    --log-throughput
    --save-interval 5000
    --eval-interval 10000
    --eval-iters 10
    --save "${CHECKPOINT_PATH}"
    --load "${CHECKPOINT_PATH}"
    --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard"
)

# Launch training.
# torchrun receives the distributed flags, then the Megatron entry point
# followed by each argument group. stdout and stderr both go to the log file.

# The redirection below fails before torchrun even starts if logs/ is missing.
mkdir -p logs

# Arrays are expanded quoted so every element reaches torchrun as exactly one
# argument. DATA_ARGS is the deliberate exception: its --data-path value may be
# held as one space-separated element that has to be word-split into the
# individual dataset words, so it stays unquoted.
# shellcheck disable=SC2068
torchrun "${DISTRIBUTED_ARGS[@]}" Megatron-LM/pretrain_gpt.py \
    "${MODEL_ARGS[@]}" \
    "${MOE_ARGS[@]}" \
    ${DATA_ARGS[@]} \
    "${TRAINING_ARGS[@]}" \
    "${MODEL_PARALLEL_ARGS[@]}" \
    "${LOGGING_ARGS[@]}" > logs/350m_grt_ft_score.log 2>&1
