
### Model architecture configs
## Size label; it is suffixed with "B" and embedded in the job name.
model_size=0.11

## Transformer depth and width (--num-layers / --hidden-size).
hidden_size=576
num_layers=24

## Attention layout: values passed as --num-attention-heads and
## --num-key-value-heads respectively.
num_attn_heads=9
num_kv_nums=3

## Value passed as --init-method-std.
init_std=0.02

###############################################################################
### Learning-rate schedule configs
## Decay shape and decay horizon (--lr-decay-style / --lr-decay-iters).
lr_decay_style='cosine'
lr_decay_iters=153000
## Warmup length (--lr-warmup-iters); 1530 is 1% of the decay horizon above.
lr_warmup_iters=1530
###############################################################################
### Parallelism configs
## Cluster shape: node count, GPUs per node, and total number of GPUs.
num_node=1
num_gpus_pernode=1
num_gpus=1

## Model (tensor) parallel degree; 1 means no model parallelism.
mp_size=1

## Pipeline parallel degree. When no_pp is "true" the launcher additionally
## receives --no-pipeline-parallel and the job name gets a "-nopp" suffix.
pp_size=1
no_pp='true'

## ZeRO stage
zero_stage=1

## Data parallel degree: GPUs remaining after pipeline and model parallelism.
dp_size=$(( num_gpus / pp_size / mp_size ))

## Micro batch size per GPU (batch_size) is NOT assigned in this section; the
## derivation below is intentionally left disabled, so the value has to be
## supplied from outside this section.
## Keep batch_size <= global_batch_size*pp_size*mp_size/num_gpus. The disabled
## formula assumes no gradient accumulation; choose a smaller value if training
## runs out of memory.
# batch_size=$(( global_batch_size / dp_size ))
###############################################################################
### Misc configs
## Logging / evaluation cadence (--log-interval, --eval-iters, --eval-interval).
log_interval=10
eval_iters=1
eval_interval=200

## num_save controls how frequently checkpoints are saved: the run is divided
## into num_save equal parts and a checkpoint is written after each one
## (e.g. num_save=20 saves every 5% of training). Use a larger num_save for
## longer runs to save more often, and vice versa.
num_save=153

## train_iters is not assigned anywhere in this script, so it must be provided
## by the caller. Abort with a clear message instead of letting the arithmetic
## below fail with a syntax error and the script continue with an empty
## save_interval.
: "${train_iters:?train_iters must be set (total number of training iterations)}"

## Integer division: the result is truncated toward zero.
save_interval=$(( ${train_iters} / ${num_save} ))

## Activation checkpointing saves GPU memory, but reduces training speed
# activation_checkpoint="true"
activation_checkpoint="false"

## Whether or not to log optimizer states (norms, max abs values) to
## tensorboard. This is not required for training and might save GPU memory
## when turned off.
log_optimizer_state="false"
###############################################################################
### Output and data configs
## Timestamp used to make the log file and tensorboard directory unique per run.
current_time=$(date "+%Y.%m.%d-%H.%M.%S")

## Public Pile dataset; see prepare_pile_data.py in the same directory for how
## to download and preprocess the data.
jobname="gebert-pile"
## For internal use. Change data_home to your own training data path.
data_home="../pile_process_2048_all"

TRAIN_DATA_PATH="${data_home}/training"
## NOTE(review): validation points at the same directory as training --
## confirm this is intended.
VALID_DATA_PATH="${data_home}/training"

TOKENIZER_PATH=../models/pythia

## Encode the run's hyper-parameters into the job name so that checkpoints,
## logs and tensorboard runs of different configurations do not collide.
jobname+="-${model_size}B-iters-${train_iters}"
jobname+="-lr-${lr}-min-${min_lr}-wmup-${lr_warmup_iters}-dcy-${lr_decay_iters}-sty-${lr_decay_style}"
jobname+="-gbs-${global_batch_size}-mbs-${batch_size}-gpu-${num_gpus}-zero-${zero_stage}-mp-${mp_size}-pp-${pp_size}"
if [[ "${no_pp}" == "true" ]]; then
    jobname+="-nopp"
fi

output_home="../megatron_ckpt/gerbert_pretraining"
log_path="${output_home}/log/"
checkpoint_path="${output_home}/checkpoint/${jobname}"
## Per the original note: tensorboard is logged by the last rank, so it is
## better to keep this path on NFS rather than Blob storage.
tensorboard_dir="${output_home}/project/bert_with_pile/tensorboard/"
tensorboard_path="${tensorboard_dir}${jobname}_${current_time}"

## Quoted so that a job name containing whitespace or glob characters (it is
## built from caller-supplied values) cannot be split into several paths.
mkdir -p -- "${log_path}" "${checkpoint_path}" "${tensorboard_path}"
###############################################################################
## Dataset and tokenizer flags. Kept as one flat string because the launch
## command expands it unquoted and relies on word splitting.
data_options=" --train-data-path ${TRAIN_DATA_PATH}"
data_options+=" --valid-data-path ${VALID_DATA_PATH}"
data_options+=" --tokenizer-model ${TOKENIZER_PATH}"
data_options+=" --tokenizer-type HFTokenizer"
data_options+=" --data-impl mmap"

## Megatron training flags. Kept as one flat string because the launch command
## expands it unquoted and relies on word splitting.

# Optimizer and weight initialization.
megatron_options=" --override-opt_param-scheduler"
megatron_options+=" --adam-beta1 0.9"
megatron_options+=" --adam-beta2 0.95"
megatron_options+=" --init-method-std ${init_std}"

# Parallelism, schedule horizon and batch sizes.
megatron_options+=" --tensor-model-parallel-size ${mp_size}"
megatron_options+=" --lr-decay-iters ${lr_decay_iters}"
megatron_options+=" --lr-warmup-iters ${lr_warmup_iters}"
megatron_options+=" --micro-batch-size ${batch_size}"
megatron_options+=" --global-batch-size ${global_batch_size}"

# Model shape.
megatron_options+=" --num-layers ${num_layers}"
megatron_options+=" --hidden-size ${hidden_size}"
megatron_options+=" --num-attention-heads ${num_attn_heads}"
megatron_options+=" --num-key-value-heads ${num_kv_nums}"
megatron_options+=" --seq-length ${seq_len}"
megatron_options+=" --max-position-embeddings ${max_posi_emb}"

# Training length and learning rate.
megatron_options+=" --train-iters ${train_iters}"
megatron_options+=" --lr ${lr}"
megatron_options+=" --min-lr ${min_lr}"
megatron_options+=" --lr-decay-style ${lr_decay_style}"

# Logging, evaluation and checkpoint cadence.
megatron_options+=" --log-interval ${log_interval}"
megatron_options+=" --eval-interval ${eval_interval}"
megatron_options+=" --eval-iters ${eval_iters}"
megatron_options+=" --save-interval ${save_interval}"

# Regularization and numeric precision.
megatron_options+=" --weight-decay 1e-2"
megatron_options+=" --clip-grad 1.0"
megatron_options+=" --fp16"

# Checkpoints are loaded from and saved to the same directory.
megatron_options+=" --load ${checkpoint_path}"
megatron_options+=" --save ${checkpoint_path}"

# Tensorboard reporting.
megatron_options+=" --tensorboard-queue-size 1"
megatron_options+=" --log-timers-to-tensorboard"
megatron_options+=" --log-batch-size-to-tensorboard"
megatron_options+=" --log-validation-ppl-to-tensorboard"
megatron_options+=" --tensorboard-dir ${tensorboard_path}"

# Architecture switches, data loading and communication backend.
megatron_options+=" --swiglu"
megatron_options+=" --use-rotary-position-embeddings"
megatron_options+=" --rotary-percent 0.25"
megatron_options+=" --attention-softmax-in-fp32"
megatron_options+=" --num-workers 16"
megatron_options+=" --distributed-backend nccl"

# Currently disabled flags:
#   --use-flash-attn-v2
#   --normalization rmsnorm

# Optional flags driven by the switches in the misc configs.
if [[ "${activation_checkpoint}" == "true" ]]; then
    megatron_options+=" --checkpoint-activations"
fi

if [[ "${log_optimizer_state}" == "true" ]]; then
    megatron_options+=" --log-optimizer-states-to-tensorboard"
fi

## Render the DeepSpeed JSON config by substituting the placeholders found in
## the template. The output file name encodes the values that vary per run.
template_json="../ds_config/ds_config_bert_TEMPLATE.json"
config_json="../ds_config/ds_config_gebert_bsz${global_batch_size}_mbsz${batch_size}_log${log_interval}_zero${zero_stage}.json"

## PRESCALE_GRAD is the only substitution that depends on the ZeRO stage:
## "false" when ZeRO is enabled (stage > 0), "true" otherwise.
if [[ "${zero_stage}" -gt 0 ]]; then
    prescale_grad="false"
else
    prescale_grad="true"
fi

## A single sed invocation applies the expressions in order to every line,
## which is equivalent to chaining one sed process per substitution.
## NOTE(review): INSTIAL_SCALE_POWER looks like a misspelling of "INITIAL";
## it is kept as-is because it has to match the placeholder in the template --
## confirm against the template file.
sed \
    -e "s/CONFIG_BATCH_SIZE/${global_batch_size}/" \
    -e "s/CONFIG_MBSIZE/${batch_size}/" \
    -e "s/LOG_INTERVAL/${log_interval}/" \
    -e "s/ZERO_STAGE/${zero_stage}/" \
    -e "s/PRESCALE_GRAD/${prescale_grad}/" \
    -e "s/CONFIG_FP16_ENABLED/true/" \
    -e "s/CONFIG_BF16_ENABLED/false/" \
    -e "s/INSTIAL_SCALE_POWER/12/" \
    "${template_json}" > "${config_json}"

## DeepSpeed flags. Kept as one flat string because the launch command expands
## it unquoted and relies on word splitting.
deepspeed_options=" --deepspeed"
deepspeed_options+=" --deepspeed_config ${config_json}"
deepspeed_options+=" --zero-stage ${zero_stage}"
deepspeed_options+=" --pipeline-model-parallel-size ${pp_size}"

## Explicitly disable pipeline parallelism when requested.
if [[ "${no_pp}" == "true" ]]; then
    deepspeed_options+=" --no-pipeline-parallel"
fi

## Mirror the activation checkpointing switch on the DeepSpeed side.
if [[ "${activation_checkpoint}" == "true" ]]; then
    deepspeed_options+=" --deepspeed-activation-checkpointing"
fi


## Launch training and append stdout to a per-run log file.
## DISTRIBUTED_ARGS and the *_options strings are intentionally left unquoted:
## each one holds several flags and must be word-split into separate arguments.
## The redirection target is quoted so that a job name containing whitespace
## cannot produce an "ambiguous redirect" error.
torchrun $DISTRIBUTED_ARGS ../pretrain_gebert.py ${megatron_options} ${data_options} ${deepspeed_options} >> "${log_path}/${jobname}_${current_time}.log"