#!/bin/bash
#
#SBATCH --mem=200G
#SBATCH -N 1
#SBATCH -t 1-00:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err

# Launch LoRA reward-model training (reward_uf.py) on
# mistralai/Mistral-7B-Instruct-v0.2 via `accelerate launch` with a
# DeepSpeed ZeRO-2 config.
#
# Environment variables read (an empty value is treated the same as unset):
#   GPU_COUNT        required; selects accelerate_configs/deepspeed_zero2-$GPU_COUNT.yaml
#   ACCELERATE_PORT  --main_process_port   (default 29500)
#   EPOCHS           --num_train_epochs    (default 2)
#   EPS              --eps                 (default 0)
#   LR               --learning_rate       (default 1e-5)
#   SEED             --seed                (flag omitted when unset)
#   DIST_FN          --dist_fn             (flag omitted when unset)
#   REMOVE_SUBSET    --subset_to_remove    (flag omitted when unset; split on whitespace)

source ./env.sh

ACCELERATE_PORT=${ACCELERATE_PORT:-29500}
EPOCHS=${EPOCHS:-2}
EPS=${EPS:-0}
LR=${LR:-1e-5}

# Fail fast: an unset GPU_COUNT would otherwise silently request
# accelerate_configs/deepspeed_zero2-.yaml.
: "${GPU_COUNT:?GPU_COUNT must be set to select accelerate_configs/deepspeed_zero2-<N>.yaml}"

# Optional flags live in arrays so they expand to nothing when unset and
# their values are never glob-expanded.
seed_args=()
if [[ -n "${SEED:-}" ]]; then
    seed_args=(--seed "$SEED")
fi

dist_fn_args=()
if [[ -n "${DIST_FN:-}" ]]; then
    dist_fn_args=(--dist_fn "$DIST_FN")
fi

remove_subset_args=()
if [[ -n "${REMOVE_SUBSET:-}" ]]; then
    # Split on whitespace so several subsets can be given in one variable,
    # matching the previous unquoted expansion.
    # NOTE(review): confirm reward_uf.py's --subset_to_remove accepts multiple values.
    read -r -a remove_subsets <<<"$REMOVE_SUBSET"
    remove_subset_args=(--subset_to_remove "${remove_subsets[@]}")
fi

accelerate launch --config_file "accelerate_configs/deepspeed_zero2-${GPU_COUNT}.yaml" \
    --main_process_port "$ACCELERATE_PORT" \
    reward_uf.py "${seed_args[@]}" \
    --eps "$EPS" \
    "${dist_fn_args[@]}" \
    --dataset_version 400k \
    "${remove_subset_args[@]}" \
    --dataset_num_proc 8 \
    --output_dir None \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
    --torch_dtype bfloat16 \
    --bf16 True \
    --bf16_full_eval True \
    --center_rewards_coefficient 0 \
    --attn_implementation flash_attention_2 \
    --per_device_train_batch_size 16 \
    --lr_scheduler_type cosine \
    --warmup_ratio 0.03 \
    --per_device_eval_batch_size 8 \
    --num_train_epochs "$EPOCHS" \
    --gradient_checkpointing True \
    --load_best_model_at_end True \
    --metric_for_best_model "eval_accuracy" \
    --learning_rate "$LR" \
    --weight_decay 0 \
    --report_to wandb \
    --remove_unused_columns False \
    --optim adamw_hf \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 200 \
    --save_steps 200 \
    --save_total_limit 3 \
    --logging_first_step True \
    --eval_on_start True \
    --max_length 1024 \
    --use_peft True \
    --lora_task_type SEQ_CLS \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05
