#!/bin/bash
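# PPO (RLHF) fine-tuning of google/gemma-2b-it with LoRA via accelerate + DeepSpeed ZeRO-2.
# Create ./log before running sbatch: SLURM does not create the -o/-e output directories.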
#SBATCH --mem=200G
#SBATCH -N 1
#SBATCH -t 0-12:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err

source ./env.sh

# Tunable hyperparameters. Each one can be overridden from the job environment,
# e.g. `LR=5e-6 KL=0.05 sbatch <this script>`. The value after `:-` is used
# when the variable is unset or empty.
ACCELERATE_PORT=${ACCELERATE_PORT:-29500}
EPS=${EPS:-0}
LR=${LR:-1e-5}
KL=${KL:-0}
LOSS_TYPE=${LOSS_TYPE:-reward}

# Required: reward model passed to ppo.py. If it were empty, --reward_model_path
# would have no argument and the job would fail late, so abort here instead.
: "${REW_MODEL:?REW_MODEL must be set (e.g. in env.sh or the job environment)}"

# The accelerate/DeepSpeed config is chosen by GPU_COUNT (presumably exported by
# env.sh or the caller; TODO confirm). Check that it exists before launching.
accelerate_config="accelerate_configs/deepspeed_zero2-${GPU_COUNT}rlhf.yaml"
if [[ ! -f "$accelerate_config" ]]; then
    printf 'error: accelerate config not found: %s\n' "$accelerate_config" >&2
    exit 1
fi

# Forward --seed only when SEED is set. The array keeps "--seed" and its value
# as separate words without relying on unquoted word-splitting.
seed_args=()
if [[ -n "${SEED:-}" ]]; then
    seed_args=(--seed "$SEED")
fi

accelerate launch --config_file "$accelerate_config" \
    --main_process_port "$ACCELERATE_PORT" \
    ppo.py "${seed_args[@]}" \
    --eps "$EPS" \
    --loss_type "$LOSS_TYPE" \
    --dataset_version 5k \
    --output_dir None \
    --dataset_num_proc 20 \
    --model_name_or_path google/gemma-2b-it \
    --sft_model_path google/gemma-2b-it \
    --reward_model_path "$REW_MODEL" \
    --torch_dtype bfloat16 \
    --bf16 True \
    --bf16_full_eval True \
    --attn_implementation flash_attention_2 \
    --optim adamw_hf \
    --lr_scheduler_type cosine \
    --batch_size 64 \
    --mini_batch_size 1 \
    --warmup_ratio 0.03 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --kl_coef "$KL" \
    --num_ppo_epochs 1 \
    --num_train_epochs 1 \
    --learning_rate "$LR" \
    --lam 0.95 \
    --gamma 1 \
    --cliprange 0.2 \
    --cliprange_value 0.2 \
    --local_rollout_forward_batch_size 4 \
    --missing_eos_penalty 1.0 \
    --num_sample_generations 10 \
    --gradient_checkpointing True \
    --save_only_model True \
    --load_best_model_at_end True \
    --metric_for_best_model "objective/rlhf_reward" \
    --weight_decay 0 \
    --report_to wandb \
    --remove_unused_columns False \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 250 \
    --save_steps 250 \
    --save_total_limit 1 \
    --logging_first_step True \
    --eval_on_start True \
    --use_peft True \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05
# Disabled option, kept for reference. To enable it, add a trailing "\" to the
# --lora_dropout line above and append:  --lora_task_type SEQ_CLS
