#! /bin/bash

# Single-node torchrun rendezvous settings.
MASTER_ADDR=localhost
MASTER_PORT=2025
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=2

# torchrun launcher flags. Kept as one plain space-separated string so it can
# be word-split into separate torchrun arguments when the command is built.
DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE}"
DISTRIBUTED_ARGS+=" --nnodes ${NNODES}"
DISTRIBUTED_ARGS+=" --node_rank ${NODE_RANK}"
DISTRIBUTED_ARGS+=" --master_addr ${MASTER_ADDR}"
DISTRIBUTED_ARGS+=" --master_port ${MASTER_PORT}"

# --- locations ---------------------------------------------------------------
# Repository root: the first positional argument, or /home/MiniLLM when no
# argument is supplied at all (an explicitly empty "" argument is kept as-is).
BASE_PATH=${1-/home/MiniLLM}
# Scratch prefix shared by the checkpoint and output paths below.
_results="${BASE_PATH}/results/opt/train"

# Student checkpoint directory and the name it is registered under.
CKPT_NAME="1.3B-init"
CKPT="${_results}/minillm_init/opt-1.3B"
# Teacher checkpoint directory (note: trailing slash, unlike CKPT) and name.
TEACHER_CKPT_NAME="6.7B-sft"
TEACHER_CKPT="${_results}/sft/teacher_opt/"
# Model-parallel degree.
MP_SIZE=2

# Training data.
# NOTE(review): prompt data lives under "processed_data1" whereas the LM data
# uses "processed_data" — confirm the "1" suffix is intentional.
PROMPT_DATA_DIR="${BASE_PATH}/processed_data1/dolly/prompt/opt/"
LM_DATA_DIR="${BASE_PATH}/processed_data/roberta/opt/512/20M/"

# Output directory for this run.
SAVE_PATH="${_results}/projector/seed2/"
unset _results

# --- hyper-parameters --------------------------------------------------------
BATCH_SIZE=4
GRAD_ACC=2
CHUNK_SIZE=8


# Command-line options for train_minillm.py. They are accumulated as one
# space-separated string that is word-split into arguments at launch time, so
# no value may contain whitespace.
OPTS=""
# model
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --model-path ${CKPT}"
OPTS+=" --teacher-model-path ${TEACHER_CKPT}"
OPTS+=" --ckpt-name ${CKPT_NAME}"
OPTS+=" --teacher-ckpt-name ${TEACHER_CKPT_NAME}"
OPTS+=" --n-gpu ${GPUS_PER_NODE}"
OPTS+=" --n-nodes ${NNODES}"
OPTS+=" --model-type opt"
OPTS+=" --teacher-model-fp16"
OPTS+=" --gradient-checkpointing"
#OPTS+=" --resume ${CKPT}"
OPTS+=" --model-parallel"
OPTS+=" --model-parallel-size ${MP_SIZE}"
# data
OPTS+=" --prompt-data-dir ${PROMPT_DATA_DIR}"
OPTS+=" --lm-data-dir ${LM_DATA_DIR}"
OPTS+=" --dev-num 1000"
OPTS+=" --num-workers 0"
# hp
OPTS+=" --epochs 10"
OPTS+=" --total-iters 5000"
OPTS+=" --kd-ratio 0.0"
OPTS+=" --lm-coef 0.0"
OPTS+=" --rl-coef 1.0"
OPTS+=" --batch-size ${BATCH_SIZE}"
# lr == lr-min: the learning rate is effectively held constant.
OPTS+=" --lr 5e-6"
OPTS+=" --lr-min 5e-6"
OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}"
OPTS+=" --max-length 512"
OPTS+=" --max-prompt-length 256"
#OPTS+=" --warmup-iters 100"
# runtime
OPTS+=" --save ${SAVE_PATH}"
OPTS+=" --seed 20"
OPTS+=" --seed-ppo 42"
OPTS+=" --seed-lm 7"
OPTS+=" --save-interval 500"
OPTS+=" --eval-interval 500"
OPTS+=" --log-interval 16"
OPTS+=" --mid-log-num 1"
# "[-1]" is also a shell glob bracket expression: the string must be split
# WITHOUT pathname expansion, or a file named "1" or "-" in the working
# directory would be passed in its place.
OPTS+=" --student-layer-indices [-1]"
OPTS+=" --teacher-layer-indices [-1]"
OPTS+=" --student-mlp-size 3072"
OPTS+=" --student-hidd-size 2048"
OPTS+=" --teacher-hidd-size 4096"
OPTS+=" --alpha-corr 0.0"
OPTS+=" --alpha-cos 0.00"
OPTS+=" --alpha-mse 1.0"
OPTS+=" --alpha-cka 0.0"
#OPTS+=" --mlp"
OPTS+=" --hidden"
OPTS+=" --do-projector"
# ppo (each option appended exactly once)
OPTS+=" --type minillm"
OPTS+=" --ppo-epochs 4"
OPTS+=" --num-rollouts 256"
OPTS+=" --chunk-size ${CHUNK_SIZE}"

# minillm
OPTS+=" --length-norm"
OPTS+=" --single-step-reg"
OPTS+=" --teacher-mixed-alpha 0.2"
# reward
OPTS+=" --reward-scaling 0.5"
OPTS+=" --cliprange-reward 100"
# gen
OPTS+=" --do-sample"
OPTS+=" --top-k 0"
OPTS+=" --top-p 1.0"
OPTS+=" --temperature 1.0"
# deepspeed
OPTS+=" --deepspeed"
OPTS+=" --deepspeed_config ${BASE_PATH}/configs/deepspeed/ds_config_zero1_fp16_1.json"

# Runtime environment for the training processes.
export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH="${BASE_PATH}"

# Build the launch command as an array. `read -a` splits the space-separated
# option strings on whitespace only, with NO pathname expansion, so glob-like
# values such as "[-1]" reach the trainer verbatim instead of being replaced
# by a matching file name ("1" or "-") in the working directory.
read -r -a cmd <<<"torchrun ${DISTRIBUTED_ARGS} ${BASE_PATH}/train_minillm.py ${OPTS}"
# Forward every script argument unchanged (embedded spaces preserved).
# NOTE(review): "$@" still contains $1, which this script also uses as
# BASE_PATH; confirm train_minillm.py tolerates that extra positional arg.
cmd+=("$@")
# Flattened copy of the command, kept for logging.
CMD="${cmd[*]}"

echo "${CMD}"
echo "PYTHONPATH=${PYTHONPATH}"
if ! mkdir -p -- "${SAVE_PATH}"; then
  echo "error: cannot create save directory '${SAVE_PATH}'" >&2
  exit 1
fi
"${cmd[@]}"
