#!/bin/bash
# Launch Megatron-LM GPT pretraining (Megatron-LM/pretrain_gpt.py) for the
# 1b_dsv2_grt MoE run: one node with 8 GPUs by default, multi-node via the
# MASTER_ADDR / MASTER_PORT / WORLD_SIZE / RANK environment variables.
# Most paths below are relative, so run this from the directory that contains
# Megatron-LM/, grouter_ep_optimizer/, model_home/ and checkpoint/.

# Abort on the first failing command instead of continuing towards the launch
# in a half-configured state. (nounset is deliberately not enabled: some
# optional argument groups may be left undefined.)
set -eo pipefail

# Megatron-LM requires a single hardware work queue per device in some
# configurations (e.g. sequence parallelism / communication overlap) so that
# communication and compute kernels are issued in order.
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Synchronous kernel launches: accurate CUDA error locations, but a large
# slowdown for real training. Still on by default; export
# CUDA_LAUNCH_BLOCKING=0 before running this script to turn it off.
export CUDA_LAUNCH_BLOCKING=${CUDA_LAUNCH_BLOCKING:-1}
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Makes the grouter_ep_optimizer modules importable.
# NOTE(review): when PYTHONPATH is empty this produces ":grouter_ep_optimizer";
# the empty leading entry may make Python also search the current working
# directory, and imports may rely on that — confirm before changing it.
export PYTHONPATH=$PYTHONPATH:grouter_ep_optimizer

# Checkpoint, tokenizer and dataset locations.
#CHECKPOINT_PATH="model_home/deepseek-v2-lite-converted"
CHECKPOINT_PATH="checkpoint/1b_dsv2_grt"
TOKENIZER_MODEL="model_home/qwen3-30b-a3b"
TOKENIZER_TYPE="HuggingFaceTokenizer"
C4_HOME="/workspace/Megatron-LM-router/qwen3_dataset"

# Build the Megatron data blend " 0.01 <prefix> 0.01 <prefix> ..." over the
# zero-padded shards 0040..0599 (560 shards, all with weight 0.01).
# Each <prefix> is presumably an indexed-dataset prefix produced by Megatron's
# preprocessing. The value starts with a space; it is word-split before use,
# so that is harmless.
DATA_PATH=""
for (( shard = 40; shard <= 599; shard++ )); do
    printf -v shard_prefix '%s/qwen3-c4-%04d_text_document' "${C4_HOME}" "${shard}"
    DATA_PATH+=" 0.01 ${shard_prefix}"
done

# Options for the learned "grouter" MoE router. These are not upstream
# Megatron-LM flags as far as this script shows; presumably they are defined
# by the Megatron-LM-router fork / grouter_ep_optimizer code put on PYTHONPATH.
# --moe-router-load-balancing-type none turns off Megatron's built-in
# load-balancing loss while the grouter is in use.
# The "_64" checkpoint/config files presumably match the 64-expert setting —
# NOTE(review): confirm against --num-experts.
# Kept as a single string that is word-split (unquoted) when spliced into
# MOE_ARGS, so values here must not contain whitespace or glob characters.
ROUTER_ARGS="\
    --use-grouter-weight \
    --moe-router-load-balancing-type none \
    --moe-use-grouter \
    --use-single-grouter \
    --grouter-checkpoint-path grouter_ep_optimizer/grouter/grouter_ft_64.pth \
    --grouter-config-path grouter_ep_optimizer/grouter/config_64.json \
"

# Distributed training topology, passed to torchrun.
# The defaults describe a single node with GPUS_PER_NODE workers. For
# multi-node runs, MASTER_ADDR / MASTER_PORT / WORLD_SIZE / RANK are read from
# the environment (typically injected by the cluster scheduler).
# NOTE(review): WORLD_SIZE is used as the *number of nodes* and RANK as the
# *node index* here. Under plain torchrun semantics those names mean total
# processes / global rank, so confirm the scheduler's convention before
# running multi-node.
GPUS_PER_NODE=8
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6002"}
NNODES=${WORLD_SIZE:-"1"}
NODE_RANK=${RANK:-"0"}

# torchrun launcher options. Values are quoted so that each one stays a single
# array element (no word splitting or globbing of environment input).
DISTRIBUTED_ARGS=(
    --nproc_per_node "${GPUS_PER_NODE}"
    --nnodes "${NNODES}"
    --node_rank "${NODE_RANK}"
    --master_addr "${MASTER_ADDR}"
    --master_port "${MASTER_PORT}"
)

# Transformer backbone (Megatron-core model): 12 layers, hidden size 512,
# 8 heads with multi-latent attention, YaRN-scaled RoPE, SwiGLU MLPs, RMSNorm,
# and no bias terms in linear layers. Flag order is significant only for
# readability; each group below appends to the same argv array.
MODEL_ARGS=(--use-mcore-models --disable-bias-linear)
# Sequence length and the matching position-embedding table size.
MODEL_ARGS+=(--seq-length 4096 --max-position-embeddings 4096)
# Depth and width. --ffn-hidden-size presumably sizes the MLP of the non-MoE
# layer(s).
MODEL_ARGS+=(--num-layers 12 --hidden-size 512 --ffn-hidden-size 2736)
MODEL_ARGS+=(--num-attention-heads 8)
# Multi-latent attention: KV compression rank 256, 64-dim q/k and v heads,
# a 32-dim positional (rotary) part of q/k, and normalization on q/k.
MODEL_ARGS+=(--multi-latent-attention --kv-lora-rank 256)
MODEL_ARGS+=(--qk-head-dim 64 --qk-pos-emb-head-dim 32 --v-head-dim 64 --qk-layernorm)
# Weight init and dropout (both dropouts disabled).
MODEL_ARGS+=(--init-method-std 0.02 --attention-dropout 0.0 --hidden-dropout 0.0)
# Normalization layer type.
MODEL_ARGS+=(--normalization RMSNorm --norm-epsilon 1e-6)
# Rotary position embeddings with YaRN scaling.
MODEL_ARGS+=(--position-embedding-type rope --rope-type yarn --rotary-base 10000)
MODEL_ARGS+=(--rotary-scaling-factor 40 --mscale 0.707 --mscale-all-dim 0.707)
# MLP activation; separate input-embedding and output-projection weights.
MODEL_ARGS+=(--swiglu --untie-embeddings-and-output-weights)
# Kernel selection: unfused masked softmax and RoPE, FlashAttention enabled.
MODEL_ARGS+=(--no-masked-softmax-fusion --use-flash-attn --no-rope-fusion)
# Vocabulary size; presumably matches the Qwen3 tokenizer in TOKENIZER_MODEL.
MODEL_ARGS+=(--vocab-size 151936)

# Mixture-of-experts configuration: 64 routed experts with top-6 routing,
# per-expert FFN width 704 plus a shared expert of width 1408.
MOE_ARGS=(--num-experts 64 --moe-router-topk 6)
# Grouped-GEMM expert kernels and fused token permutation.
MOE_ARGS+=(--moe-grouped-gemm --moe-permute-fusion)
# These two overlap flags are data-parallel / distributed-optimizer options
# rather than MoE options; they simply travel with this group.
MOE_ARGS+=(--overlap-param-gather --overlap-grad-reduce)
MOE_ARGS+=(--moe-token-dispatcher-type alltoall)
MOE_ARGS+=(--moe-ffn-hidden-size 704 --moe-shared-expert-intermediate-size 1408)
MOE_ARGS+=(--moe-router-topk-scaling-factor 1.0)
# Layer pattern: 0 = dense layer, 1 = MoE layer. One dense layer followed by
# 11 MoE layers (12 entries in total). The commented pattern has 27 entries.
#MOE_ARGS+=(--moe-layer-freq "([0]+[1]*26)")
MOE_ARGS+=(--moe-layer-freq "([0]+[1]*11)")
# Grouter options: the ROUTER_ARGS string is word-split into separate flags.
# shellcheck disable=SC2206
MOE_ARGS+=(${ROUTER_ARGS})
# Disabled alternatives (stock Megatron router with load balancing):
#MOE_ARGS+=(--moe-router-load-balancing-type seq_aux_loss)
#MOE_ARGS+=(--moe-aux-loss-coeff 0.001)
#MOE_ARGS+=(--moe-router-score-function softmax)
#MOE_ARGS+=(--moe-router-enable-expert-bias)
#MOE_ARGS+=(--moe-router-bias-update-rate 1e-3)

# Tokenizer and dataset options.
# DATA_PATH is one whitespace-separated "weight prefix weight prefix ..."
# string. Megatron's --data-path takes each weight and prefix as a separate
# argument, so split it into words here. Storing the whole string as one
# element only works if the array is later expanded unquoted. `read -ra`
# splits on whitespace without pathname (glob) expansion.
read -ra data_path_words <<< "${DATA_PATH}"
DATA_ARGS=(
    --tokenizer-model "${TOKENIZER_MODEL}"
    --tokenizer-type "${TOKENIZER_TYPE}"
    --data-path "${data_path_words[@]}"
    # Train / validation / test proportions.
    --split 990,8,2
)

# Optimization schedule, batching and memory settings.
TRAINING_ARGS=(
    # Per-GPU micro-batch and global batch per optimizer step.
    # NOTE(review): accumulation steps = 256 / (8 x data-parallel size);
    # that is 4 if DP is 8 on this node — confirm.
    --micro-batch-size 8 --global-batch-size 256
    # Full activation recomputation, uniformly in chunks of one layer.
    --recompute-granularity full --recompute-method uniform --recompute-num-layers 1
    # Cosine LR schedule: 5000 warmup iterations, peak 1e-4, floor 1e-5.
    # NOTE(review): --lr-decay-iters (25000) < --train-iters (40000), so the
    # last 15000 iterations presumably run at --min-lr — confirm intended.
    --lr 1e-4 --train-iters 40000 --lr-decay-iters 25000 --lr-decay-style cosine
    --min-lr 1.0e-5 --lr-warmup-iters 5000
    # Gradient-norm clipping and bf16 training.
    --clip-grad 1.0 --bf16
)

# Parallel layout: no tensor or pipeline parallelism; experts are spread over
# 8 expert-parallel ranks; the distributed optimizer shards optimizer state
# across data-parallel ranks.
MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1
    --expert-model-parallel-size 8
    --use-distributed-optimizer
    # NOTE(review): sequence parallelism splits activations across the
    # tensor-parallel group; with TP=1 Megatron-LM is expected to disable it.
    # Kept so it takes effect if TP is raised.
    --sequence-parallel
)

# Logging, evaluation and checkpointing. Checkpoints are saved to and loaded
# from CHECKPOINT_PATH; TensorBoard events go to its tensorboard/ subdir.
# Expansions are quoted so each value stays exactly one array element, even
# when CHECKPOINT_PATH is empty or contains spaces (ShellCheck SC2206).
# NOTE(review): --no-load-optim / --no-load-rng mean that loading this run's
# own checkpoint does not restore optimizer or RNG state. That suits
# warm-starting from converted weights; confirm it is intended for restarts.
LOGGING_ARGS=(
    --log-interval 1
    --log-throughput
    --save-interval 5000
    --eval-interval 1000
    --eval-iters 10
    --save "${CHECKPOINT_PATH}"
    --load "${CHECKPOINT_PATH}"
    --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard"
    --no-load-optim
    --no-load-rng
)

# Launch training with torchrun and send stdout and stderr to LOG_FILE.
# LOG_FILE keeps its previous default and can be overridden from the
# environment. Its directory is created first: the shell opens the redirect
# target before running torchrun, so a missing logs/ directory would abort
# the launch immediately.
#
# Array expansions are quoted so that elements such as the "([0]+[1]*11)"
# layer pattern reach Megatron verbatim; unquoted, they undergo pathname
# (glob) expansion. Intentional exceptions:
#   - DATA_ARGS may hold the whole data blend as one space-separated string
#     that must still be split into separate --data-path values.
#   - DISTILLATION_ARGS is not defined in this script. It expands to nothing
#     unless the caller's environment supplies it as a space-separated string.
LOG_FILE=${LOG_FILE:-logs/1b_dsv2_grt.log}
mkdir -p -- "$(dirname -- "${LOG_FILE}")"
# shellcheck disable=SC2068,SC2086
torchrun "${DISTRIBUTED_ARGS[@]}" Megatron-LM/pretrain_gpt.py \
    "${MODEL_ARGS[@]}" \
    "${MOE_ARGS[@]}" \
    ${DATA_ARGS[@]} \
    "${TRAINING_ARGS[@]}" \
    "${MODEL_PARALLEL_ARGS[@]}" \
    ${DISTILLATION_ARGS[@]} \
    "${LOGGING_ARGS[@]}" > "${LOG_FILE}" 2>&1

