#!/bin/bash
#
# Launch DeepSpeed training of main.py (EleutherAI/pythia-2.8b on the
# full-hh-rlhf dataset) on the local GPUs listed in DEV.
#
# Usage: <this script> [OUTPUT_DIR]
#   OUTPUT_DIR  directory passed to main.py as --output_dir; created if missing.
#               Default (when missing or empty): ./output/pythia-2.8b/full-hh-rlhf
#
# Combined stdout/stderr of the run is also written to OUTPUT_DIR/training.log.

# -e/-u: abort on unhandled errors and unset variables.
# pipefail: make the script's exit status reflect a deepspeed failure instead
# of the (almost always successful) exit status of `tee`.
set -euo pipefail
# Trace every command so the exact launch line ends up in the console output.
set -x

# export OMP_NUM_THREADS=4
# export MKL_NUM_THREADS=4
# export NUMEXPR_NUM_THREADS=4
# export OPENBLAS_NUM_THREADS=4
# export RAYON_NUM_THREADS=20
# export TOKENIZERS_PARALLELISM=False

readonly DEV=0,1,2,3,4,5,6,7          # GPU indices for --include localhost:...
readonly PORT=1236                    # DeepSpeed --master_port
readonly ZERO_STAGE=2
readonly DATA_PATH="/efs/shicheng/remax/dataset/full-hh-rlhf"
readonly DATA_OUTPUT_PATH="/tmp/data_files/pythia"
readonly MODEL_NAME="EleutherAI/pythia-2.8b"
readonly SEED=1234

# First positional argument overrides the output directory; a missing or empty
# argument falls back to the default.
OUTPUT=${1:-./output/pythia-2.8b/full-hh-rlhf}
readonly OUTPUT
mkdir -p -- "$OUTPUT"

# Arguments forwarded to main.py. Kept in an array so every value is passed as
# exactly one word, even if a path contains spaces.
train_args=(
  --data_path "$DATA_PATH"
  --data_output_path "$DATA_OUTPUT_PATH"
  --data_split 2,4,4
  --model_name_or_path "$MODEL_NAME"
  --per_device_train_batch_size 16
  --per_device_eval_batch_size 16
  --max_seq_len 512
  --learning_rate 1e-5
  --weight_decay 0.1
  --num_padding_at_beginning 1
  --num_train_epochs 4
  --gradient_accumulation_steps 1
  --lr_scheduler_type cosine
  --num_warmup_steps 0
  --seed "$SEED"
  --gradient_checkpointing
  --zero_stage "$ZERO_STAGE"
  --deepspeed
  --output_dir "$OUTPUT"
  --enable_tensorboard
  --print_loss
)

deepspeed --include "localhost:$DEV" --master_port "$PORT" \
  main.py "${train_args[@]}" 2>&1 \
  | tee -- "$OUTPUT/training.log"