#!/bin/bash
#
#SBATCH --mem=24G
#SBATCH -N 1
#SBATCH -t 0-01:00
#SBATCH -o ./log/%j.out
#SBATCH -e ./log/%j.err

# Load the shared environment setup before running any Python.
# NOTE(review): presumably activates the Python environment / sets paths used
# by the eval scripts below — confirm against env.sh.
# ./env.sh is resolved relative to the job's working directory (by default the
# directory sbatch was invoked from), as are the ./log/ paths in the #SBATCH
# header; that log directory must already exist because sbatch does not create it.
# Fail fast if the file is missing: otherwise every eval below would run with
# the wrong environment and waste the allocation.
if [[ ! -r ./env.sh ]]; then
    printf 'error: cannot read ./env.sh (working directory: %s)\n' "$PWD" >&2
    exit 1
fi
source ./env.sh

# Turn the optional run-selection variables into fragments of the model
# directory name. NOTE(review): these are presumably exported by the submitter
# (e.g. `sbatch --export=ALL,EPS=...`) — confirm. Mapping (unset/empty -> ""):
#   EPS           -> "_dr_eps<EPS>"
#   BEST          -> checkpoint sub-path (see below)
#   REMOVE_SUBSET -> "_no-<REMOVE_SUBSET>"
#   DIST_FN       -> "_<DIST_FN>dist"
#   SEED          -> "_seed<SEED>"

#######################################
# Print "<prefix><value><suffix>" when value is non-empty; print nothing when
# it is empty or unset.
# Arguments: $1 - value, $2 - prefix, $3 - optional suffix
# Outputs:   the fragment on stdout (no trailing newline)
# Returns:   0
#######################################
name_fragment() {
    if [[ -n "$1" ]]; then
        printf '%s%s%s' "$2" "$1" "${3:-}"
    fi
}

EPS=$(name_fragment "$EPS" "_dr_eps")

# BEST has an inverted sense: when it is unset/empty, a checkpoint sub-path
# chosen per REMOVE_SUBSET is appended; when it is set to anything, no
# checkpoint suffix is appended at all.
# This must run BEFORE REMOVE_SUBSET is rewritten below, because it matches on
# the raw subset name.
if [[ -z "$BEST" ]]; then
    case "$REMOVE_SUBSET" in
        hh-rlhf)                     BEST="/checkpoint-10426" ;;
        chatbot_arena_conversations) BEST="/checkpoint-12540" ;;
        *)                           BEST="/checkpoint-12886" ;;
    esac
else
    BEST=""
fi

REMOVE_SUBSET=$(name_fragment "$REMOVE_SUBSET" "_no-")
DIST_FN=$(name_fragment "$DIST_FN" "_" "dist")
SEED=$(name_fragment "$SEED" "_seed")

# Assemble the adapter path: fixed literal name components interleaved with the
# optional fragments (each possibly empty) derived earlier in this script.
printf -v MODEL_NAME 'models/reward_uf_400k%s_lr1e-05%s%s_google_gemma-2b-it%s%s' \
    "$REMOVE_SUBSET" "$EPS" "$DIST_FN" "$SEED" "$BEST"

# Evaluate the adapter at $MODEL_NAME on top of google/gemma-2b-it: first with
# eval_reward_bench.py, then with eval_yang.py on each of its tasks.
# Every evaluation is attempted even if an earlier one fails, but failures are
# counted so the job exits non-zero (and SLURM marks it FAILED) instead of
# reporting only the exit status of the last command.
eval_failures=0

python3 eval_reward_bench.py --torch_dtype bfloat16 --attn_implementation flash_attention_2 \
    --batch_size 16 --not_quantized --model google/gemma-2b-it --peft_name "$MODEL_NAME" \
    || eval_failures=$((eval_failures + 1))

for eval_task in unified hhh mtbench; do
    python3 eval_yang.py --peft_name "$MODEL_NAME" --task "$eval_task" \
        || eval_failures=$((eval_failures + 1))
done

if (( eval_failures > 0 )); then
    printf 'error: %d evaluation(s) failed for %s\n' "$eval_failures" "$MODEL_NAME" >&2
    exit 1
fi
