#!/bin/bash
# Launch script for a pretrain_gpt.py run: sets up the Ascend/HCCL environment,
# assembles the command-line option strings and starts the job.
# Positional arguments (both optional): $1 = run mode ("debug" or anything else,
# default "run"), $2 = device count.

# Force-kill every python3 process before starting. This intentionally runs
# before `set -e`, so a non-zero pkill status does not abort the script.
pkill -9 python3
# From here on: abort on the first failing command (-e) and trace commands (-x).
set -ex
# Load the Ascend toolkit environment into this shell.
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Prepend the driver libraries and the custom op_api libraries to the loader path.
# ASCEND_OPP_PATH is not set in this script; presumably it comes from set_env.sh.
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$ASCEND_OPP_PATH/vendors/customize/op_api/lib/:$LD_LIBRARY_PATH
# HCCL / runtime tuning knobs exported for the training processes.
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1
export HCCL_DETERMINISTIC=True   
export USE_DET_REDUCE_SCATTER_PIPELINE_MT=true
export HCCL_RDMA_TC=100
export HCCL_RDMA_SL=3
export HCCL_ASYNC_ERROR_HANDLING=1
export HCCL_IF_BASE_PORT=65100
export HCCL_ALGO="allgather=level0:NA;level1:pipeline/allreduce=level0:NA;level1:NHR/reducescatter=level0:NA;level1:NHR"
export HCCL_OP_EXPANSION_MODE="AIV"
export HCCL_EXEC_TIMEOUT=6000
export CUDA_DEVICE_MAX_CONNECTIONS=1
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is exported twice; the second export
# replaces the first, so only "expandable_segments:True" is in effect and
# "max_split_size_mb:128" is discarded. Confirm which setting is intended.
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export DETERMINISTIC_MODE=True
export TORCH_DIST_INIT_BARRIER=0
export HCCL_INTRA_PCIE_ENABLE=0 
export HCCL_INTRA_ROCE_ENABLE=1
export A2A_EP_HCCL_BUFF_SIZE=1000
export HCCL_BUFFSIZE=1000

# Force-reinstall the pinned tracking_utils release on every launch.
# NOTE(review): the index is reached over plain HTTP with the host marked as
# trusted, i.e. without TLS verification.
pip3 install tracking_utils==3.6.8 -i http://pypi.example.com/simple/ --trusted-host pypi.example.com --force-reinstall

# Remember the caller's working directory in DIR, then switch to the directory
# that contains this script so relative paths (e.g. submit/distributed_launch.py)
# resolve against it.
# Every expansion is quoted so that paths containing spaces or glob characters
# are handled correctly; `--` stops option parsing for paths starting with "-".
DIR=$(pwd)
script_path=$(realpath -- "$0")
script_dir=$(dirname -- "$script_path")
cd -- "$script_dir" || exit 1

# Dataset / tokenizer locations and the total training budget.
# Underscores in the *_TOKENS values are digit separators only.
TOKENIZER_TYPE=BloomTokenizer
VOCAB_PATH="/path/to/tokenizer/tokenizer.json"
DATA_PATH="/path/to/training/data/data.index"
TRAIN_TOKENS=30_000_000_000

# Learning-rate schedule settings.
LR_DECAY_STYLE="wsd"
INIT_LR=1e-4
MIN_LR=1.0e-5
LR_WARMUP_TOKENS=0
LR_DECAY_TOKENS=1_000_000

# Strip the "_" separators so the values are plain integers for bc.
LR_DECAY_TOKENS_CLEANED=${LR_DECAY_TOKENS//_/}
TRAIN_TOKENS_CLEANED=${TRAIN_TOKENS//_/}

# DECAY_TOKENS       = TRAIN_TOKENS - LR_DECAY_TOKENS
# DECAY_FRAC_FOR_WSD = DECAY_TOKENS / LR_DECAY_TOKENS (10 decimal places)
# NOTE(review): with the values above this ratio is far greater than 1;
# confirm that is what the consumer of --decay-frac-for-wsd expects.
DECAY_TOKENS=$(bc <<< "$TRAIN_TOKENS_CLEANED - $LR_DECAY_TOKENS_CLEANED")
DECAY_FRAC_FOR_WSD=$(echo "scale=10; $DECAY_TOKENS / $LR_DECAY_TOKENS_CLEANED" | bc -l)

# Print the derived schedule values for the job log.
echo "Decay Status param: TRAIN_TOKENS=${TRAIN_TOKENS_CLEANED}; LR_DECAY_TOKENS=${LR_DECAY_TOKENS_CLEANED};"
echo "DECAY_TOKENS=${DECAY_TOKENS}; DECAY_FRAC_FOR_WSD=${DECAY_FRAC_FOR_WSD}"

# ---- Batch and parallelism --------------------------------------------------
SEQ=8192
MICRO_BATCH=1
GLOBAL_BATCH=1600
TP=4
PP=9
# Batch ramp-up settings. RAMPUP_START / RAMPUP_INCR are not referenced by any
# option string in this script.
BATCH_RAMPUP_TOKENS=100_000_000_000
RAMPUP_START=320
RAMPUP_INCR=320

# ---- Model architecture -----------------------------------------------------
DTYPE="bf16"          # emitted verbatim as the "--<dtype>" flag
HF_TYPE="llama"
LAYERS=34
HIDDEN=3072
INTERMEDIATE=8192
NHEADS=24
MQA_KV_HEADS=8        # multi-query flags are added when this is below NHEADS
ROPE_THETA=1000000.0

# ---- Mixture of Experts -----------------------------------------------------
# Space-separated list alternating 1 and 64; presumably one entry per layer
# (TODO confirm against the --num-experts parser).
NUM_EXPERTS="1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64 1 64"
TOP_K=3
CAPACITY_FACTOR=64
ROUTER_TYPE="fast_token_choose"
USE_LOAD_BALANCE_LOSS=true
LOSS_TYPE="both"
MOE_LOSS_COEFF=0.001
ONLY_Z_LOSS_COEFF=0.1

LOG_EXPERTS_TOKENS=true
LOG_EXPERTS_TOKENS_INTERVAL=10

# ---- MuP (Maximal Update Parametrization) -----------------------------------
USE_MUP=true
USE_MUP_ADD=1         # selects which MuP flag set is emitted (1, 2 or 3)
BASE_HIDDEN=768
CUSTOM_INIT_STD=0.02

# ---- Checkpoints and experiment output --------------------------------------
CHECKPOINT_PATH="/path/to/checkpoint/save"
LOAD_CHECKPOINT_PATH="/path/to/checkpoint/load"
SAVE_INTERVAL=500
# Launch timestamp, formatted YYYYmmddHHMMSS.
CURRENT_TIME=$(date +%Y%m%d%H%M%S)

EXP_DIR="/path/to/experiment/logs"

# Generate the log directory name from the configuration and create it.
#
# MODEL_NAME is not assigned anywhere in this script, so it is taken from the
# environment and defaults (explicitly) to the empty string; export MODEL_NAME
# before launching to get a named prefix.
# NOTE: the name is built before the debug-mode overrides of LAYERS/HIDDEN/TP/PP
# further down, so in debug mode it reflects the non-debug values. TRAIN_TOKENS
# still contains its "_" separators at this point.
MODEL_NAME=${MODEL_NAME:-}
if [[ "$NUM_EXPERTS" != "1" && "$NUM_EXPERTS" != "0" ]]; then
  # MoE run: encode the routing / auxiliary-loss settings in the name.
  LOG_DIR="${EXP_DIR}/${MODEL_NAME}_tp${TP}_pp${PP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_LR_${INIT_LR}_${MIN_LR}_${DTYPE}_exp_iter_cf${CAPACITY_FACTOR}_moetype${LOSS_TYPE}_moeloss${MOE_LOSS_COEFF}_onlyzlosscoeff${ONLY_Z_LOSS_COEFF}_top${TOP_K}_router_${ROUTER_TYPE}_tokens_${TRAIN_TOKENS}"
else
  # Dense run (NUM_EXPERTS is exactly "0" or "1").
  LOG_DIR="${EXP_DIR}/${MODEL_NAME}_tp${TP}_pp${PP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_LR_${INIT_LR}_${MIN_LR}_${DTYPE}_exp${NUM_EXPERTS}"
fi
# Quoted so an EXP_DIR / MODEL_NAME containing whitespace yields one directory.
mkdir -p -- "$LOG_DIR"

# Run-mode configuration.
#   $1  "debug" selects the small single-node setup; anything else (default
#       "run") keeps the model/parallelism values configured above.
#   $2  device count (default 4 in debug mode, 96 otherwise).
DEBUG_MODE=${1:-'run'}
if [[ "$DEBUG_MODE" == 'debug' ]]; then
  GPUS=${2:-4}
  # Debug mode overrides the model shape and the parallel layout.
  LAYERS=32
  HIDDEN=4096
  INTERMEDIATE=11008
  SEQ=8192
  NHEADS=32
  TIMING_LOG_LEVEL=0
  TP=4
  PP=1
else
  GPUS=${2:-96}
  TIMING_LOG_LEVEL=1
fi

# Settings shared by both modes.
EXIT_INTERVAL=10
LOG_INTERVAL=1
TB_LOG_INTERVAL=1

# Data-parallel size by integer division; any remainder is dropped, so
# WORLD_SIZE can be smaller than GPUS.
DP=$((GPUS / TP / PP))
if ((GPUS % (TP * PP) != 0)); then
  # Previously this truncation was silent (e.g. 96 / 4 / 9 -> DP=2, 72 ranks).
  echo "warning: GPUS=${GPUS} is not divisible by TP*PP=$((TP * PP)); DP truncated to ${DP}, WORLD_SIZE=$((TP * PP * DP))" >&2
fi

WORLD_SIZE=$((TP * PP * DP))
SEED=48

# Convert token budgets into sample counts.
# Each *_TOKENS value first has its "_" digit separators removed in place and
# is then integer-divided by the sequence length SEQ.
TRAIN_TOKENS=${TRAIN_TOKENS//_/}
TRAIN_SAMPLES=$(( ${TRAIN_TOKENS} / ${SEQ} ))

LR_DECAY_TOKENS=${LR_DECAY_TOKENS//_/}
LR_DECAY_SAMPLES=$(( ${LR_DECAY_TOKENS} / ${SEQ} ))

LR_WARMUP_TOKENS=${LR_WARMUP_TOKENS//_/}
LR_WARMUP_SAMPLES=$(( ${LR_WARMUP_TOKENS} / ${SEQ} ))

BATCH_RAMPUP_TOKENS=${BATCH_RAMPUP_TOKENS//_/}
BATCH_RAMPUP_SAMPLES=$(( ${BATCH_RAMPUP_TOKENS} / ${SEQ} ))

# Configure MuP options.
# Outputs: mup_options (flag string), INIT_STD and, depending on the variant,
# BASE_HIDDEN.
#   USE_MUP != true       -> no MuP flags; BASE_HIDDEN=HIDDEN, INIT_STD=CUSTOM_INIT_STD
#   USE_MUP_ADD = 1       -> --mup-lr-scale-cond, INIT_STD=0.02, BASE_HIDDEN unchanged
#   USE_MUP_ADD = 2       -> --attn-mult/--emb-mult, custom init std, BASE_HIDDEN=HIDDEN
#   USE_MUP_ADD = 3       -> --o-zero-init/--q-zero-init, custom init std, BASE_HIDDEN=HIDDEN
#   any other USE_MUP_ADD -> no MuP flags (a warning is printed to stderr)
# mup_options is always initialised here so a stale value from the environment
# can never leak into the final command line.
mup_options=""
if [[ "$USE_MUP" == 'true' ]]; then
  INIT_STD=0.02
  ATTN_MULT=4.0
  EMB_MULT=10.0

  case "$USE_MUP_ADD" in
    1)
      mup_options=" \
      --mup-lr-scale-cond \
    "
      ;;
    2)
      mup_options=" \
      --attn-mult $ATTN_MULT \
      --emb-mult $EMB_MULT \
    "
      INIT_STD=$CUSTOM_INIT_STD
      BASE_HIDDEN=$HIDDEN
      ;;
    3)
      mup_options=" \
      --o-zero-init \
      --q-zero-init \
    "
      INIT_STD=$CUSTOM_INIT_STD
      BASE_HIDDEN=$HIDDEN
      ;;
    *)
      echo "warning: USE_MUP_ADD='${USE_MUP_ADD}' is not 1, 2 or 3; no MuP options added" >&2
      ;;
  esac
else
  BASE_HIDDEN=$HIDDEN
  INIT_STD=$CUSTOM_INIT_STD
fi

# Model configuration options.
# Builds the architecture-related flags as one whitespace-separated string.
# The string is later expanded unquoted, so the embedded newlines and runs of
# spaces only act as word separators.
model_options=" \
    --num-layers $LAYERS \
    --ffn-hidden-size $INTERMEDIATE \
    --hidden-size $HIDDEN \
    --base-size $BASE_HIDDEN \
    --rope-theta $ROPE_THETA \
    --num-attention-heads $NHEADS \
    --seq-length $SEQ \
    --max-position-embeddings $SEQ \
    --disable-bias-linear \
    --tokenizer-type $TOKENIZER_TYPE \
    --untie-embeddings-and-output-weights \
    --output-logits-chunk-sum-cnt 32 \
    --make-vocab-size-divisible-by 1 \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --hf-type $HF_TYPE \
    --swiglu
"

# Multi-query attention configuration.
# The multi-query flags are appended only when the KV head count is strictly
# smaller than the number of attention heads.
if [ $MQA_KV_HEADS -lt $NHEADS ]; then
    model_options+="\
    --multi-query-attention \
    --multi-query-group-num $MQA_KV_HEADS \
    "
fi

# Mixture-of-Experts flags.
# Starts from the routing basics and then appends the optional load-balance
# loss and per-expert token logging flags according to the switches above.
moe_options=" \
    --num-experts $NUM_EXPERTS \
    --capacity-factor $CAPACITY_FACTOR \
    --moe-topk $TOP_K \
    --expert-interval 1 \
    --router $ROUTER_TYPE \
"

# Auxiliary load-balance loss (enabled only when the switch is exactly "true").
if [[ "$USE_LOAD_BALANCE_LOSS" != 'true' ]]; then
    echo "MoE does not use load_balance_loss!"
else
    moe_options+=" \
        --use-load-balance-loss \
        --load-balance-loss-type $LOSS_TYPE \
        --moe-loss-coeff $MOE_LOSS_COEFF \
        --only-z-loss-coeff $ONLY_Z_LOSS_COEFF \
    "
fi

# Per-expert token counters in TensorBoard (enabled only when exactly "true").
if [[ "$LOG_EXPERTS_TOKENS" != 'true' ]]; then
    echo "MoE does not log the number of tokens passed to every expert!"
else
    moe_options+=" \
        --log-experts-passed-token-to-tensorboard \
        --log-experts-passed-token-interval $LOG_EXPERTS_TOKENS_INTERVAL \
    "
fi

# Optimizer configuration.
# Adam hyper-parameters plus the learning-rate schedule. The decay / warmup
# lengths are passed in samples (the *_SAMPLES values derived from the token
# budgets), and DECAY_FRAC_FOR_WSD is the ratio computed with bc near the top
# of the script.
optimizer_options=" \
    --optimizer adam \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --adam-eps 1e-8 \
    --lr $INIT_LR \
    --min-lr $MIN_LR \
    --lr-decay-style $LR_DECAY_STYLE \
    --lr-decay-samples $LR_DECAY_SAMPLES \
    --lr-warmup-samples $LR_WARMUP_SAMPLES \
    --clip-grad 1.0 \
    --weight-decay 1e-1 \
    --decay-frac-for-wsd $DECAY_FRAC_FOR_WSD
"

# Training configuration.
# Parallel layout, data loading, checkpoint paths and feature switches.
# Notes:
#   * --max-position-embeddings is also present in model_options, and
#     --distributed-checkpointing / --async-checkpointing are repeated in
#     save_options, so these flags appear twice on the final command line.
#   * "--${DTYPE}" expands to the precision flag named by DTYPE.
#   * NOTE(review): "--grouped-gemm-init-same-with-separate-expserts" is spelled
#     "expserts"; confirm it matches the flag name the training script defines.
train_options=" \
    --override-opt_param-scheduler \
    --use-distributed-optimizer \
    --tensor-model-parallel-size $TP \
    --pipeline-model-parallel-size $PP \
    --max-position-embeddings $SEQ \
    --dataloader-type lamp \
    --micro-batch-size $MICRO_BATCH \
    --split 100 \
    --train-samples $TRAIN_SAMPLES \
    --save-interval $SAVE_INTERVAL \
    --data-path ${DATA_PATH} \
    --data-impl mmap \
    --vocab-file ${VOCAB_PATH} \
    --init-method-std $INIT_STD \
    --${DTYPE} \
    --attention-softmax-in-fp32 \
    --seed ${SEED} \
    --recompute-activations \
    --recompute-granularity selective \
    --save ${CHECKPOINT_PATH} \
    --load ${LOAD_CHECKPOINT_PATH} \
    --dataset-type pretrain_processed \
    --distributed-timeout-minutes 30 \
    --load-inequal-bucket-checkpoint \
    --bucket-size 500000000 \
    --eod-mask-loss \
    --use-flash-attn \
    --legacy-version 0.30 \
    --varlen-attention \
    --use-grouped-gemm \
    --grouped-gemm-init-same-with-separate-expserts \
    --sequence-parallel \
    --f1b2 \
    --num-workers 2 \
    --prefetch-factor 32 \
    --async-checkpointing \
    --deterministic-mode \
    --distributed-checkpointing \
    --no-load-rng
"

# Fixed global batch size. The appended text ends with a space, which later
# appends to train_options rely on as a word separator.
train_options+="\
    --global-batch-size $GLOBAL_BATCH \
"

# Training options that depend on the expert configuration.
# MoE runs (NUM_EXPERTS is neither "0" nor "1") get --use-gigaflops-pre-token;
# dense runs get virtual pipeline stages with overlapped p2p communication.
# Each appended chunk starts with whitespace so it stays a separate word no
# matter how the preceding content of train_options ends.
if [[ "$NUM_EXPERTS" != "1" && "$NUM_EXPERTS" != "0" ]]; then
    train_options+=" --use-gigaflops-pre-token "
else
    train_options+=" \
      --num-layers-per-virtual-pipeline-stage 2 \
      --overlap-p2p-communication \
    "
fi

# Enable --balance-embedding-stage only when LAYERS + 2 is a multiple of PP.
# (The +2 presumably accounts for the embedding and output stages; confirm.)
# The fallback message reports the quantity that is actually tested, LAYERS + 2,
# rather than LAYERS.
if (( (LAYERS + 2) % PP == 0 )); then
    train_options+=" \
        --balance-embedding-stage \
    "
else
    echo "num_layers + 2 = $((LAYERS + 2)) cannot be divided by pp $PP, balance embedding stage is disabled"
fi

# Logging configuration.
# TensorBoard output goes to LOG_DIR; the log intervals and the timing log
# level come from the run-mode section (debug vs. run) above.
log_options="\
    --tensorboard-dir $LOG_DIR \
    --tensorboard-queue-size 5 \
    --log-timers-to-tensorboard \
    --log-batch-size-to-tensorboard \
    --log-memory-to-tensorboard \
    --log-interval $LOG_INTERVAL \
    --tensorboard-log-interval $TB_LOG_INTERVAL \
    --simultaneous-writing-native-tensorboard \
    --timing-log-level $TIMING_LOG_LEVEL \
    --timing-log-option all \
    --tracking_utils.board True \
    --param_report 50 \
"

# Validation configuration.
# Evaluation data location plus the flags controlling how and how often the
# validation set is consumed.
EVAL_DATA_PATH='/path/to/validation/data'

# Every continued line ends with " \" so each flag stays a separate word even
# if the indentation of the following line is ever removed.
valid_options="\
   --eval-data-impl mmap \
   --eval-all-data \
   --eval-dataloader-type single \
   --eval-dataset-type chat \
   --valid-data-path ${EVAL_DATA_PATH} \
   --eval-interval 500 \
"

# Save configuration.
# NOTE: --distributed-checkpointing and --async-checkpointing are also part of
# train_options, so both flags end up on the command line twice.
save_options="\
  --distributed-checkpointing \
  --async-checkpointing \
  --save-inner-interval 100 \
"

# Assemble the final argument string from the per-topic option groups.
options="
  ${mup_options} \
  ${model_options} \
  ${optimizer_options} \
  ${train_options} \
  ${log_options} \
  ${valid_options} \
  ${save_options} \
"

# The MoE flags are appended for every NUM_EXPERTS value except exactly
# "0" or "1".
case "$NUM_EXPERTS" in
    0 | 1) ;;
    *) options+=" ${moe_options}" ;;
esac

# Execute training based on debug mode.
#
# The expansions of run_cmd / options below are deliberately left unquoted: the
# option strings contain literal newlines and runs of spaces, and word splitting
# is what collapses them into single separators. Quoting them (for example
# eval "${run_cmd}") would hand the embedded newlines to the shell as command
# separators and break the launch.
#
# extra_torch_launch_options is not assigned anywhere in this script; it expands
# to nothing unless it is supplied through the environment.
if [[ $DEBUG_MODE == 'debug' ]]; then
  # Debug: start WORLD_SIZE processes on this node with torch.distributed.launch.
  # pretrain_gpt.py is resolved against DIR (the directory the script was
  # invoked from), not against the script's own directory.
  run_cmd="python3 -m torch.distributed.launch ${extra_torch_launch_options} --nproc_per_node=$WORLD_SIZE ${DIR}/pretrain_gpt.py ${options}"
  echo ${run_cmd}
  eval ${run_cmd}
else
  # Run: the launcher command line is whatever submit/distributed_launch.py
  # prints on stdout; pretrain_gpt.py is passed as a relative path (the current
  # directory is the script's directory at this point).
  run_cmd="pretrain_gpt.py ${options}"
  launcher=$(python3 submit/distributed_launch.py)
  echo $launcher ${extra_torch_launch_options} ${run_cmd}
  ${launcher} ${extra_torch_launch_options} ${run_cmd}
fi