#!/bin/bash
# Launch Megatron-LM GPT/MoE pre-training (Megatron-LM/pretrain_gpt.py) via
# torchrun with the grouter router options.
#
# Fail fast on command errors, unset variables and failed pipeline stages.
set -euo pipefail

# NOTE(review): Megatron-LM expects CUDA_DEVICE_MAX_CONNECTIONS=1 so that
# communication/computation overlap is scheduled in order — confirm for the
# Megatron version in use.
export CUDA_DEVICE_MAX_CONNECTIONS=1
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous.
# It is a debugging aid and can slow training substantially — drop it for
# production runs unless it is needed.
export CUDA_LAUNCH_BLOCKING=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Append the grouter package (path relative to the launch directory). The ':'
# separator is only added when PYTHONPATH is already non-empty; otherwise the
# result would start with an empty entry, which Python treats as the current
# working directory.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}grouter_ep_optimizer"

# Alternative source checkpoint, kept for reference:
#CHECKPOINT_PATH="model_home/qwen3-235b-converted"
CHECKPOINT_PATH="checkpoints/3b_grt"
TOKENIZER_MODEL="model_home/qwen3-30b-a3b"
TOKENIZER_TYPE="HuggingFaceTokenizer"
C4_HOME="/workspace/Megatron-LM-router/qwen3_dataset"

# Blended-dataset spec: a flat, space-separated list of "<weight> <prefix>"
# pairs. Shards 40..199 are zero-padded to four digits (0040 .. 0199) and each
# gets weight 0.01. The string begins with a space; it is word-split
# downstream, so that is harmless.
DATA_PATH=""
for (( shard = 40; shard <= 199; shard++ )); do
    printf -v shard_id '%04d' "$shard"
    DATA_PATH+=" 0.01 ${C4_HOME}/qwen3-c4-${shard_id}_text_document"
done


# Extra router flags for the grouter integration, kept as ONE
# whitespace-separated string (the backslash-newlines inside the double quotes
# are line continuations). It is expanded unquoted inside MOE_ARGS, so it
# relies on word splitting: none of the values may contain spaces.
#   --moe-router-load-balancing-type none   turns off the stock load-balancing
#                                           loss (the commented-out aux_loss
#                                           options in MOE_ARGS are the
#                                           alternative).
# NOTE(review): the --*grouter* flags are not stock Megatron-LM options; they
# must be registered by the modified Megatron-LM tree being launched — confirm.
# The .pth/.json paths are relative to the launch directory.
ROUTER_ARGS="\
    --use-grouter-weight \
    --moe-router-load-balancing-type none \
    --moe-use-grouter \
    --use-single-grouter \
    --grouter-checkpoint-path grouter_ep_optimizer/grouter/grouter_ft_128.pth \
    --grouter-config-path grouter_ep_optimizer/grouter/config_128.json \
"

# Distributed training settings.
GPUS_PER_NODE=8
# Multi-node runs take MASTER_ADDR / MASTER_PORT / WORLD_SIZE / RANK from the
# environment; the defaults describe a single local node.
# NOTE(review): WORLD_SIZE is used as the *node* count here. Some launchers set
# WORLD_SIZE to the total number of processes instead — confirm for the
# cluster this runs on.
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6002"}
NNODES=${WORLD_SIZE:-"1"}
NODE_RANK=${RANK:-"0"}

# torchrun launcher arguments. Values are quoted so each one stays exactly one
# argv element whatever the environment supplies.
DISTRIBUTED_ARGS=(
    --nproc_per_node "$GPUS_PER_NODE"
    --nnodes "$NNODES"
    --node_rank "$NODE_RANK"
    --master_addr "$MASTER_ADDR"
    --master_port "$MASTER_PORT"
)


# Transformer backbone (Megatron-core GPT model).
MODEL_ARGS=(
    --use-mcore-models
    --disable-bias-linear
    # Training sequence length; the RoPE position limit is set higher.
    --seq-length 4096
    --max-position-embeddings 40960
    --num-layers 16
    --hidden-size 1024
    # Dense FFN width (expert width is given by --moe-ffn-hidden-size).
    --ffn-hidden-size 3072
    # Grouped-query attention: 16 query heads sharing 4 KV groups; kv-channels
    # sets the per-head projection size (128).
    --num-attention-heads 16
    --group-query-attention
    --num-query-groups 4
    --kv-channels 128
    --init-method-std 0.02
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --normalization RMSNorm
    --norm-epsilon 1e-6
    --position-embedding-type rope
    --rotary-base 1000000
    --rotary-percent 1.0
    --swiglu
    --untie-embeddings-and-output-weights
    --no-masked-softmax-fusion
    --use-flash-attn
    # Must cover the tokenizer's vocabulary. Previously passed twice with the
    # same value; argparse keeps only the last occurrence, so one is enough.
    --vocab-size 151936
    --qk-layernorm
)

# Mixture-of-experts settings, assembled group by group.
# 128 experts per MoE layer, top-8 routing.
MOE_ARGS=(--num-experts 128 --moe-router-topk 8)
# Grouped-GEMM expert kernels and fused token permutation.
MOE_ARGS+=(--moe-grouped-gemm --moe-permute-fusion)
# Overlap of parameter all-gather / gradient reduce with compute. These are
# distributed-optimizer options rather than MoE options; kept here as before.
MOE_ARGS+=(--overlap-param-gather --overlap-grad-reduce)
MOE_ARGS+=(--moe-token-dispatcher-type alltoall)
# Per-expert FFN width, router output scaling, and MoE layer frequency
# (presumably every layer is MoE with a frequency of 1).
MOE_ARGS+=(
    --moe-ffn-hidden-size 384
    --moe-router-topk-scaling-factor 1.0
    --moe-layer-freq 1
)
# Stock load-balancing options, currently disabled in favour of grouter:
#   --moe-router-load-balancing-type aux_loss
#   --moe-aux-loss-coeff 0.001
#   --moe-router-score-function softmax
#   --moe-router-enable-expert-bias
#   --moe-router-bias-update-rate 1e-3
# ROUTER_ARGS is a whitespace-separated flag string; it is intentionally left
# unquoted so it splits into one array element per flag/value.
# shellcheck disable=SC2206
MOE_ARGS+=(${ROUTER_ARGS})

# Tokenizer and dataset arguments.
# DATA_PATH is one whitespace-separated "weight prefix weight prefix ..."
# string. Split it into separate words (read -a does not glob) so that
# --data-path gets one argv element per token, however DATA_ARGS is expanded
# later. Previously the whole list was stored as a single element.
read -r -a DATA_PATH_ITEMS <<< "$DATA_PATH"
DATA_ARGS=(
    --tokenizer-model "${TOKENIZER_MODEL}"
    --tokenizer-type "${TOKENIZER_TYPE}"
    --data-path "${DATA_PATH_ITEMS[@]}"
    --split 990,8,2
)

# Optimisation, recomputation and learning-rate schedule.
TRAINING_ARGS=(
    # Per-GPU micro-batch and sequences per optimizer step.
    --micro-batch-size 8 --global-batch-size 256
    # Full activation recomputation, one layer per recompute chunk.
    --recompute-granularity full --recompute-method uniform --recompute-num-layers 1
    # Peak LR 1e-4 over 50000 iterations: 5000 warm-up iterations, then cosine
    # decay to 1e-5 over 45000 iterations.
    # NOTE(review): the last 5000 iterations presumably run at --min-lr — confirm.
    --lr 1e-4 --train-iters 50000 --lr-decay-iters 45000 --lr-decay-style cosine
    --min-lr 1.0e-5 --lr-warmup-iters 5000
    --clip-grad 1.0
    --bf16
)

# Parallel layout: no tensor or pipeline parallelism; experts are sharded over
# 8 expert-parallel ranks (presumably 128 / 8 = 16 experts per rank).
# NOTE(review): --sequence-parallel only matters with tensor parallelism; with
# a tensor-parallel size of 1 Megatron is expected to ignore or disable it.
MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1
    --expert-model-parallel-size 8
    --use-distributed-optimizer
    --sequence-parallel
)

# Logging, evaluation and checkpointing.
# NOTE(review): --save-interval 10 writes a checkpoint every 10 iterations of a
# 50000-iteration run; that looks like a debugging value and may exhaust disk.
# NOTE(review): --load reads the same directory --save writes, but
# --no-load-optim / --no-load-rng mean a resumed run restarts without optimizer
# and RNG state. That suits the first load of a converted checkpoint; confirm it
# is intended when resuming an interrupted run.
LOGGING_ARGS=(
    --log-interval 1
    --log-throughput
    --save-interval 10
    --eval-interval 1000
    --eval-iters 10
    # Quoted so the path is always exactly one argument.
    --save "$CHECKPOINT_PATH"
    --load "$CHECKPOINT_PATH"
    --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard"
    --no-load-optim
    --no-load-rng
)

# Launch training. Only stdout is redirected to logs/3b_grt.log; stderr still
# reaches the terminal. The shell cannot open the redirect target unless the
# directory exists, so create it first.
mkdir -p logs
# Arrays are expanded quoted so every element stays exactly one argument.
# DATA_ARGS is deliberately left unquoted: its --data-path value may be held as
# a single space-separated "weight prefix ..." string that must be word-split.
# shellcheck disable=SC2068
torchrun "${DISTRIBUTED_ARGS[@]}" Megatron-LM/pretrain_gpt.py \
    "${MODEL_ARGS[@]}" \
    "${MOE_ARGS[@]}" \
    ${DATA_ARGS[@]} \
    "${TRAINING_ARGS[@]}" \
    "${MODEL_PARALLEL_ARGS[@]}" \
    "${LOGGING_ARGS[@]}" > logs/3b_grt.log
