#!/bin/bash
#
#SBATCH --mem=200G
#SBATCH -N 1
#SBATCH -t 1-00:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err

# Load the local environment setup from the current working directory.
# Abort early when the file is missing: otherwise bash only prints a warning
# and the job would go on to launch training in an unconfigured shell.
# The file is checked up front (instead of testing `source`'s return code)
# because that code is merely the status of the last command inside env.sh.
if [[ ! -r ./env.sh ]]; then
    printf 'error: ./env.sh not found or not readable (cwd: %s)\n' "$PWD" >&2
    exit 1
fi
# shellcheck source=/dev/null
source ./env.sh

# Tunable settings. Each may be preset in the environment (or by ./env.sh);
# an unset or empty value falls back to the default given here.
# `${VAR:=default}` assigns only when VAR is unset or empty, and unlike an
# unquoted `[ -z $VAR ]` it is safe for values containing spaces, globs or
# test operators.
: "${ACCELERATE_PORT:=29500}"   # forwarded as --main_process_port
: "${EPS:=0}"                   # forwarded as --eps
: "${LR:=1e-5}"                 # forwarded as --learning_rate
: "${KL:=0.1}"                  # forwarded as --beta

# Optional flags: turn a bare value into the complete "--flag value" string,
# or into the empty string when no value was supplied, so that the launch
# command can splice the variable in directly.
# `${VAR:+text}` expands to `text` only when VAR is set and non-empty.
REMOVE_SUBSET="${REMOVE_SUBSET:+--subset_to_remove $REMOVE_SUBSET}"
SEED="${SEED:+--seed $SEED}"

# Launch LoRA DPO training (dpo.py) of google/gemma-2b-it through
# `accelerate launch` with a DeepSpeed ZeRO-2 config.
#
# Reads:
#   GPU_COUNT        required; selects accelerate_configs/deepspeed_zero2-<GPU_COUNT>.yaml
#                    (presumably exported by ./env.sh or the submitter - confirm)
#   ACCELERATE_PORT  -> --main_process_port
#   EPS, KL, LR      -> --eps, --beta, --learning_rate
#   SEED             either "" or a ready-made "--seed <n>" string
#   REMOVE_SUBSET    either "" or a ready-made "--subset_to_remove <name>" string
#
# The script's exit status is that of `accelerate launch`.
#
# NOTE(review): `--output_dir None` passes the literal string "None"; confirm
# dpo.py treats it specially, otherwise outputs are written to ./None.
# Column pruning is disabled once via `--remove_unused_columns False`; the
# equivalent `--no_remove_unused_columns` alias is not repeated.

# Fail fast instead of silently building the path "deepspeed_zero2-.yaml".
: "${GPU_COUNT:?GPU_COUNT must be set; it selects accelerate_configs/deepspeed_zero2-<GPU_COUNT>.yaml}"
config_file="accelerate_configs/deepspeed_zero2-${GPU_COUNT}.yaml"
if [[ ! -f "$config_file" ]]; then
    printf 'error: accelerate config not found: %s\n' "$config_file" >&2
    exit 1
fi

# Split the optional "--flag value" strings into words. An empty string
# yields an empty array and therefore contributes no argument at all, and
# unlike a bare unquoted expansion no glob expansion is performed.
read -r -a seed_args <<< "$SEED"
read -r -a remove_subset_args <<< "$REMOVE_SUBSET"

accelerate launch --config_file "$config_file" \
    --main_process_port "$ACCELERATE_PORT" \
    dpo.py "${seed_args[@]}" \
    --eps "$EPS" \
    "${remove_subset_args[@]}" \
    --dataset_num_proc 20 \
    --dataset_version 400k \
    --beta "$KL" \
    --max_length 1024 \
    --max_prompt_length 128 \
    --output_dir None \
    --model_name_or_path google/gemma-2b-it \
    --torch_dtype bfloat16 \
    --bf16 True \
    --bf16_full_eval True \
    --attn_implementation flash_attention_2 \
    --optim adamw_hf \
    --lr_scheduler_type cosine \
    --warmup_ratio 0.03 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --num_train_epochs 1 \
    --learning_rate "$LR" \
    --gradient_checkpointing True \
    --weight_decay 0 \
    --report_to wandb \
    --remove_unused_columns False \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 150 \
    --save_steps 150 \
    --save_total_limit 2 \
    --logging_first_step True \
    --eval_on_start True \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05

    # Disabled: the flags below do not work with DPOTrainer.
    # --save_only_model False \
    # --load_best_model_at_end True \
    # --metric_for_best_model "eval_rewards/accuracies" \
