#!/bin/bash
# Directory handed to DeepSpeed as the NVMe offload path. An exported,
# non-empty NVME_PATH takes precedence over the built-in default.
NVME_PATH="${NVME_PATH:-/mnt/mdinf}"

# Directory the script was launched from; the data, the generated config and
# the training entry point are all addressed relative to it.
DIR=$(pwd)
# Launch timestamp in the form date_YY-MM-DD_time_HH-MM-SS.
DATETIME=$(date +'date_%y-%m-%d_time_%H-%M-%S')
#mkdir -p $DIR/logs
#mkdir -p /tmp/logs


#DATASET_1="<PATH TO THE FIRST DATASET>"
#DATASET_2="<PATH TO THE SECOND DATASET>"
#DATASET_3="<PATH TO THE THIRD DATASET>"
#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

# Input locations, all relative to the launch directory (${DIR}).
BASE_DATA_PATH="${DIR}/data"
# Tokenizer model passed to the training script as --vocab-file.
DUMMY_VOCAB_PATH="${BASE_DATA_PATH}/tokenizer.model"
# Dataset prefix passed to the training script as --data-path.
DATASET="${BASE_DATA_PATH}/c4_llama3_00000_text_document"


# Absolute path of this script and of the directory that contains it.
# The expansions are quoted so that an invocation path containing whitespace
# is handed to realpath/dirname as a single argument instead of being split.
script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
#CONFIG_JSON="$script_dir/ds_config.json"
# The DeepSpeed config is generated into the launch directory, not next to
# the script (see the commented alternative above).
CONFIG_JSON="${DIR}/ds_config.json"

# ---- Run mode --------------------------------------------------------------
# Both values can be changed from the command line
# (--no-deepspeed, -z/--zero-stage).
ZERO_STAGE="3"
USE_DEEPSPEED="1"


# Debug
#TP=4
#PP=4
#LAYERS=8
#HIDDEN=512
#SEQ=1024
#GLOBAL_BATCH=128
#WORKER_STR="-i worker-0"

# ---- Parallelism -----------------------------------------------------------
# TP is forwarded as --tensor-model-parallel-size; within this script PP only
# appears in the tensorboard directory name.
TP="2"
PP="1"

# llama3 debug model
# HIDDEN=256
# LAYERS=8
# HEADS=16
# MULTIPLE_OF=256
# FFN_DIM_MULTIPLIER=1
# SEQ=128
# GLOBAL_BATCH=1
# WORKER_STR=""
# MICRO_BATCH=1
# ROPE_THETA=500000

# ---- Model / batch shape ---------------------------------------------------
# Granite fp32 training (because torchtitan only supports bf16 when FSDP is
# enabled).
LAYERS="36"
HIDDEN="4096"
HEADS="32"
N_KV_HEADS="8"
FFN_DIM_MULTIPLIER="1.4"
MULTIPLE_OF="2048"
ROPE_THETA="10000"
# SEQ feeds both --seq-length and --max-position-embeddings.
SEQ="4096"
MICRO_BATCH_SIZE="2"
GLOBAL_BATCH_SIZE="64"
# Placeholder; the launcher host/slot string is assigned right before launch.
WORKER_STR=""

# Print the supported command-line flags to stderr.
usage() {
  printf 'Usage: %s [--no-deepspeed] [-z|--zero-stage STAGE]\n' "$0" >&2
}

#######################################
# Apply command-line overrides.
# Globals:   USE_DEEPSPEED (written), ZERO_STAGE (written)
# Arguments: the script's command-line arguments
# Returns:   0 when every argument was recognised; otherwise prints the
#            usage line and exits the script with status 1
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --no-deepspeed)
        USE_DEEPSPEED=0
        shift
        ;;
      -z|--zero-stage)
        # The flag needs a value; without this check "shift 2" would fail and
        # the loop would never advance.
        if [[ $# -lt 2 ]]; then
          echo "Missing value for $1" >&2
          usage
          exit 1
        fi
        ZERO_STAGE=$2
        # Drop the flag AND its value, so the value is not re-read as a flag.
        shift 2
        ;;
      *)
        echo "Unknown argument(s): $1" >&2
        usage
        exit 1
        ;;
    esac
  done
}

parse_args "$@"
# "$@" is expanded again when the launch command is assembled further down.
# Clear it so the flags handled above are not also forwarded to the training
# script.
set --

# Tensorboard output directory, named after the run configuration. It is built
# after the overrides are applied so that a -z/--zero-stage value is reflected
# in the name.
# NOTE(review): DTYPE is not assigned anywhere in this script; an exported
# DTYPE is honoured, otherwise "fp32" is used because the generated DeepSpeed
# config disables both bf16 and fp16 — confirm this is the intended label.
LOG_DIR="/tmp/tensorboard/tp${TP}_pp${PP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH_SIZE}_mbsz${MICRO_BATCH_SIZE}_z${ZERO_STAGE}_${DTYPE:-fp32}_fix3"
mkdir -p "$LOG_DIR"


# Flags for the training script, one flag (plus its value) per line.
# They are flattened into the single string `options`, which is later expanded
# unquoted when the command is launched, so only the word sequence matters,
# not the exact whitespace between words.
train_flags=(
  --tensor-model-parallel-size "$TP"
  --num-layers "$LAYERS"
  --hidden-size "$HIDDEN"
  --num-attention-heads "$HEADS"
  --num-kv-heads "$N_KV_HEADS"
  --rope-theta "$ROPE_THETA"
  --multiple-of "$MULTIPLE_OF"
  --ffn-dim-multiplier "$FFN_DIM_MULTIPLIER"
  --seq-length "$SEQ"
  --loss-scale 12
  --max-position-embeddings "$SEQ"
  --micro-batch-size "$MICRO_BATCH_SIZE"
  --global-batch-size "$GLOBAL_BATCH_SIZE"
  --train-iters 4
  --lr 6.0e-5
  --min-lr 6.0e-6
  --lr-decay-style cosine
  --log-interval 1
  --eval-iters 40
  --eval-interval 1000
  --data-path "${DATASET}"
  --vocab-file "${DUMMY_VOCAB_PATH}"
  --save-interval 10000
  --split 98,2,0
  --clip-grad 1.0
  --weight-decay 0.1
  --adam-beta1 0.9
  --adam-beta2 0.95
  --init-method-std 0.006
  --exit-interval 10000
  --tensorboard-dir "$LOG_DIR"
  --cpu-optimizer
)
# Join with single spaces; the padding on both sides keeps later appends
# separated from the last flag.
options=" ${train_flags[*]} "


# With DeepSpeed enabled, extend the launch options with the DeepSpeed switch,
# the path of the generated config file and the selected ZeRO stage.
if (( USE_DEEPSPEED == 1 )); then
  echo "Using DeepSpeed"
  options+=" --deepspeed --deepspeed_config=${CONFIG_JSON} --zero-stage=${ZERO_STAGE} "
fi


#######################################
# Print the DeepSpeed JSON configuration for this run on stdout.
# Globals:   GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, ZERO_STAGE, NVME_PATH (read)
# Outputs:   the JSON document; shell variables are substituted into it
# NOTE(review): "stage3_max_reuse_distance" and "stage3_prefetch_bucket_size"
# are emitted at the top level of the document; the DeepSpeed configuration
# reference lists them under "zero_optimization" — confirm they take effect
# where they are.
#######################################
render_ds_config() {
  cat <<EOT
{
  "train_batch_size" : $GLOBAL_BATCH_SIZE,
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "steps_per_print": 1,
  "stage3_max_reuse_distance": 0,
  "stage3_prefetch_bucket_size": 0,
  "zero_optimization": {
    "stage": $ZERO_STAGE,
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "$NVME_PATH",
      "buffer_count": 4,
      "pipeline_read": false,
      "pipeline_write": false,
      "ratio": 1,
      "pin_memory": true
    },
    "offload_param": {
      "device": "nvme",
      "nvme_path": "$NVME_PATH",
      "pin_memory": true,
      "buffer_count": 18,
      "buffer_size": 300000000,
      "max_in_cpu": 0
    }
  },

  "bf16": {
    "enabled": false
  },

  "fp16": {
    "enabled": false,
    "loss_scale": 0,
    "loss_scale_window": 500,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 12
  },

  "wall_clock_breakdown" : true
}
EOT
}

# (Re)generate the config file on every run. The target is quoted so that a
# path containing whitespace does not turn into an ambiguous redirect.
render_ds_config > "$CONFIG_JSON"

# Launcher include string.
# NOTE(review): presumably selects slots 0 and 1 on localhost via the
# deepspeed -i/--include option — confirm against the launcher documentation.
WORKER_STR="-i localhost:0,1"
#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"

# Assemble the command as an argument vector rather than a string fed to eval,
# so the script path is passed verbatim even if DIR contains whitespace or
# shell metacharacters. WORKER_STR and options are expanded unquoted on
# purpose: each holds several whitespace-separated words.
# shellcheck disable=SC2206
launch_cmd=(deepspeed $WORKER_STR "${DIR}/pretrain_llama.py" "$@" $options)
run_cmd="${launch_cmd[*]}"


echo "${run_cmd}"
# Kept as the final command so that the script's exit status is the
# launcher's; nothing after it may overwrite a training failure with 0.
"${launch_cmd[@]}"
