# Shell behaviour for this launcher: do not abort on a failing command (+e)
# and echo every command before it runs (-x).
set +e -x

# Environment exported to every process started below.
# NOTE(review): TRANSFORMERS_OFFLINE=1 presumably keeps the Hugging Face
# libraries from reaching the network, so the model must already be cached
# locally — confirm on the target machine.
CXX=g++
OMP_NUM_THREADS=20
TRANSFORMERS_OFFLINE=1
export CXX OMP_NUM_THREADS TRANSFORMERS_OFFLINE


# --- Random port ------------------------------------------------------------
# Picks a random port in [min_port, max_port]. Nothing in the visible part of
# this script consumes random_port; it is kept for callers/edits that do.
min_port=1024
max_port=65535
range=$((max_port - min_port + 1))
# $RANDOM only yields 15 bits (0..32767), fewer values than the port range
# holds, so two draws are combined into a 30-bit number before reducing it.
random_port=$(( ((RANDOM << 15) | RANDOM) % range + min_port ))

# --- Batch-size arithmetic --------------------------------------------------
NUM_GPUS=8
BATCH_SIZE_PER_GPU=8
TOTAL_BATCH_SIZE=512

# NNODES is taken from the environment when present. Default to a single node
# so the division below never sees an empty operand.
NNODES=${NNODES:-1}
case $NNODES in
  '' | *[!0-9]* | 0)
    echo "error: NNODES must be a positive integer, got '$NNODES'" >&2
    exit 1
    ;;
esac

# Samples consumed by one forward/backward pass across all GPUs of all nodes.
samples_per_step=$((NUM_GPUS * BATCH_SIZE_PER_GPU * NNODES))
# Accumulation steps needed to reach TOTAL_BATCH_SIZE per optimizer update.
GRADIENT_ACC_STEPS=$((TOTAL_BATCH_SIZE / samples_per_step))

if [ "$GRADIENT_ACC_STEPS" -lt 1 ]; then
  echo "error: TOTAL_BATCH_SIZE=$TOTAL_BATCH_SIZE is smaller than one step of $samples_per_step samples" >&2
  exit 1
fi
# Integer division truncates; tell the operator when the effective global
# batch differs from the requested one.
if [ $((GRADIENT_ACC_STEPS * samples_per_step)) -ne "$TOTAL_BATCH_SIZE" ]; then
  echo "warning: effective batch size is $((GRADIENT_ACC_STEPS * samples_per_step)), not $TOTAL_BATCH_SIZE" >&2
fi


# --- Model, data and hyper-parameters ----------------------------------------
# Base model passed to the trainer as --model_name_or_path.
PT_PATH="meta-llama/Meta-Llama-3-8B"

# topk is not referenced in the visible part of this script.
topk=10
epoch=1
lr=1e-5
ratio=0.85

# PROMPT_FORMAT is not referenced in the visible part of this script.
INST_TYPE="pure_no_inst"
PROMPT_FORMAT="question_only_sol"
DATA_NAME="train_data_sampled"
DATA_PATH="data/${DATA_NAME}"
MODEL_NAME="llama3-${DATA_NAME}"

# CACHE_PATH is expected from the environment. Without a fallback an unset
# value would put checkpoints under /results at the filesystem root, so
# default to the current directory instead.
CACHE_PATH=${CACHE_PATH:-.}
# The checkpoint directory name records the learning rate and stable ratio.
CKPT_PATH="${CACHE_PATH}/results/${MODEL_NAME}-${lr}-${ratio}-wsd-no_shuffle"
# --- Launch ------------------------------------------------------------------
# Starts train_pack.py through torchrun with one worker process per GPU.
# The process count and the instruction type are read from NUM_GPUS and
# INST_TYPE so they cannot drift from the values used to derive
# GRADIENT_ACC_STEPS and the run configuration above.
# Every expansion is quoted so paths containing spaces stay single arguments.
# NOTE(review): --use_wsd, --no_shuffle, --stable_ratio, --inst_type and
# --flash_attention look like options defined by train_pack.py itself —
# confirm against that script.
torchrun --nproc_per_node="$NUM_GPUS" \
    train_pack.py \
    --model_name_or_path "$PT_PATH" \
    --data_path "$DATA_PATH" \
    --output_dir "$CKPT_PATH" \
    --bf16 True \
    --num_train_epochs "$epoch" \
    --model_max_length 2048 \
    --per_device_train_batch_size "$BATCH_SIZE_PER_GPU" \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps "$GRADIENT_ACC_STEPS" \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 500 \
    --learning_rate "$lr" \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --logging_steps 2 \
    --deepspeed ./configs/stage_2.json \
    --tf32 True \
    --gradient_checkpointing True \
    --report_to none \
    --lr_scheduler_type "linear" \
    --inst_type "$INST_TYPE" \
    --flash_attention \
    --preprocessing_num_workers 96 \
    --save_on_each_node False \
    --use_wsd \
    --no_shuffle \
    --stable_ratio "$ratio"
