#! /bin/bash

# ---------------------------------------------------------------------------
# Filesystem layout: code and datasets live under a single root.
# ---------------------------------------------------------------------------
BASE_PATH="/data"
BASE_CODE_PATH="${BASE_PATH}/code"
BASE_DATA_PATH="${BASE_PATH}/dataset"

# ---------------------------------------------------------------------------
# Single-node distributed launch settings.
# Positional arguments (both optional; an empty string also selects the default):
#   $1  number of processes to start on this node (GPUS_PER_NODE), default 8
#   $2  rendezvous port on MASTER_ADDR (MASTER_PORT), default 2012
# ---------------------------------------------------------------------------
MASTER_ADDR=localhost
MASTER_PORT=${2:-2012}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${1:-8}

# Fail fast on malformed arguments instead of letting torchrun reject them
# later with a less obvious argument-parsing error.
if [[ ! "${GPUS_PER_NODE}" =~ ^[1-9][0-9]*$ ]]; then
  printf 'error: GPUS_PER_NODE ($1) must be a positive integer, got %q\n' \
    "${GPUS_PER_NODE}" >&2
  exit 1
fi
if [[ ! "${MASTER_PORT}" =~ ^[1-9][0-9]*$ ]]; then
  printf 'error: MASTER_PORT ($2) must be a positive integer, got %q\n' \
    "${MASTER_PORT}" >&2
  exit 1
fi

# Accelerate config file; generate it by running `accelerate config` and
# answering the prompts.
# NOTE(review): not referenced anywhere else in this script (the job is
# launched with torchrun) - confirm whether it is still needed.
CONFIG_FILE="${BASE_CODE_PATH}/configs/accelerate/default_config.yaml"

# torchrun launcher flags. Kept as one flat, space-separated string because
# the launch command word-splits it.
DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE}"
DISTRIBUTED_ARGS+=" --nnodes ${NNODES}"
DISTRIBUTED_ARGS+=" --node_rank ${NODE_RANK}"
DISTRIBUTED_ARGS+=" --master_addr ${MASTER_ADDR}"
DISTRIBUTED_ARGS+=" --master_port ${MASTER_PORT}"

# data
# Pre-processed OpenWebText (presumably GPT-2 tokenized, judging by the path).
DATA_DIR="${BASE_DATA_PATH}/processed_data/openwebtext/gpt2/512/10M/"

# runtime
SAVE_PATH="${BASE_CODE_PATH}/results/gpt2/train/pretrain/"
GLOBAL_SEED=17
# Model-parallel degree passed to pretrain.py; presumably it has to divide the
# total process count - TODO confirm.
MP_SIZE=4

# Command-line arguments for pretrain.py, accumulated into the flat string
# OPTS. It is appended with += so that a pre-existing OPTS value (e.g. one
# inherited from the environment) is kept as a prefix. Every value is quoted
# so an empty variable still occupies its slot in the joined string.
_opts_words=(
  # data
  --data_file "${DATA_DIR}"
  --validation_split_percentage 5  # only takes effect when no validation_file is given
  # model / tokenizer
  # Custom model settings; NOTE(review): gpt2-named config/tokenizer used with
  # --model_type llama below - confirm this pairing is intended.
  --config_name "${BASE_CODE_PATH}/configs/learngene/gpt2/gpt2-NEmbed_768_NHead_12_NLayer_14.json"
  --tokenizer_name "${BASE_DATA_PATH}/minillm_official/gpt2/train/minillm/base-init-xlarge-sft"
  # batching / sequence length
  --per_device_train_batch_size 1
  --block_size 1024
  --max_length 1024
  --per_device_eval_batch_size 1
  # optimisation schedule
  --learning_rate 1e-4
  --weight_decay 0.0
  --num_train_epochs 2
  --gradient_accumulation_steps 1
  --lr_scheduler_type linear
  --num_warmup_steps 0
  # output / bookkeeping
  --output_dir "${SAVE_PATH}"
  --num_workers 4
  --seed "${GLOBAL_SEED}"
  --model_type llama
  --no_keep_linebreaks
  --with_tracking
  --report_to all
  # memory / parallelism
  --low_cpu_mem_usage
  --gradient_checkpointing
  --model-parallel
  --model-parallel-size "${MP_SIZE}"
  --deepspeed
  --deepspeed_config "${BASE_CODE_PATH}/configs/deepspeed/ds_config.json"
)
# "${arr[*]}" joins with single spaces; the leading space reproduces the
# " <flag> <value>" shape of each appended piece.
OPTS+=" ${_opts_words[*]}"
unset _opts_words


# Runtime environment inherited by the training processes.
export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
# NOTE(review): presumably the Ascend NPU counterpart of CUDA_LAUNCH_BLOCKING
# (synchronous launches: handy for debugging, slow for real runs) - confirm it
# is wanted here.
export ASCEND_LAUNCH_BLOCKING=1
export PYTHONPATH=${BASE_CODE_PATH}

# Build the launch command as an array so each forwarded argument stays a
# single word even if it contains spaces.
# $1 and $2 are consumed earlier in the script as GPUS_PER_NODE / MASTER_PORT.
# Forwarding the full "$@" would also hand those two values to pretrain.py, so
# only arguments from $3 onward are passed through.
# DISTRIBUTED_ARGS and OPTS are flat, space-separated strings and are
# word-split on purpose. Globbing is disabled meanwhile, so characters such as
# '*' are passed through literally.
set -f
# shellcheck disable=SC2206
launch_cmd=(torchrun ${DISTRIBUTED_ARGS} "${BASE_CODE_PATH}/pretrain.py" ${OPTS} "${@:3}")
set +f
# Printable form of the command, kept under the original variable name.
CMD="${launch_cmd[*]}"

# Log the command in shell-quoted (copy-pasteable) form.
printf '%q ' "${launch_cmd[@]}"
printf '\n'
echo "PYTHONPATH=${PYTHONPATH}"
# Abort before launching if checkpoints could not be written anyway.
mkdir -p -- "${SAVE_PATH}" || exit 1
# Last command: the script's exit status is torchrun's.
"${launch_cmd[@]}"
