#!/bin/bash
#
# Generate AdvBench answers for one fine-tuned checkpoint (inference is run
# with --unload-ffn), then score them with ming.eval.eval_advbench.
#
# Usage: <script> CKPT
#   CKPT  checkpoint name under /mnt/petrelfs/usr/checkpoints/. The name also
#         selects the base model and the conversation template:
#           *llama7b*  -> llama2_7b_chat; template llama2_harm2 if the name
#                         contains "aoa" (checked first), llama2_harm if it
#                         contains "redteam", otherwise llama2
#           *llama3*   -> Meta-Llama-3-8B-Instruct; template llama3
#           otherwise  -> Qwen1.5-<SIZE>B-Chat; template qwen. SIZE is read
#                         from a "qwen<SIZE>b" substring (e.g. qwen1.8b, qwen7b)
set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

CKPT="${1:?usage: $0 CKPT_NAME}"

# Only meaningful for Qwen checkpoints; left empty for the Llama branches.
SIZE=""

if [[ $CKPT == *"llama7b"* ]]; then
    MODEL_BASE=/mnt/petrelfs/usr/models/llama2_7b_chat
    if [[ $CKPT == *"aoa"* ]]; then
        conv_mode="llama2_harm2"
    elif [[ $CKPT == *"redteam"* ]]; then
        conv_mode="llama2_harm"
    else
        conv_mode="llama2"
    fi
elif [[ $CKPT == *"llama3"* ]]; then
    MODEL_BASE=/mnt/petrelfs/usr/models/Meta-Llama-3-8B-Instruct
    conv_mode="llama3"
else
    # Take the first "qwen<SIZE>b" occurrence. Fail loudly rather than build a
    # nonexistent ".../Qwen1.5-B-Chat" base path when the size is missing.
    qwen_size_re='qwen([0-9.]+)b'
    [[ $CKPT =~ $qwen_size_re ]] \
        || die "cannot infer Qwen size from '$CKPT': expected a 'qwen<SIZE>b' substring"
    SIZE=${BASH_REMATCH[1]}
    MODEL_BASE=/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-${SIZE}B-Chat
    conv_mode="qwen"
fi
# Evaluation settings shared by the inference and scoring steps.
domain=advbench
LOGS_BASE_PATH="./logs/diverse"

MODEL_PATH=/mnt/petrelfs/usr/checkpoints/${CKPT}

# Suffix appended to every output file name. Presumably it tags the
# "without FFN" variant that matches the --unload-ffn inference flag (confirm).
version="-woffn"

printf 'Processing %s\n' "$domain"
# The inference job's --output log and answers file live in this directory.
# Create it up front and stop if that fails, instead of launching a job
# whose log path cannot be written.
mkdir -p -- "${LOGS_BASE_PATH}/${domain}" || exit 1


# Output locations for this checkpoint/variant. The S3 key uses the same
# "logs/diverse" layout as the local LOGS_BASE_PATH; keep the two in sync.
answers_file="${LOGS_BASE_PATH}/${domain}/${CKPT}${version}.jsonl"
infer_log="${LOGS_BASE_PATH}/${domain}/${CKPT}${version}.infer.log"
s3_answers_file="s3://bucket/logs/diverse/${domain}/${CKPT}${version}.jsonl"

# Step 1: generate answers on one GPU. The job is backgrounded, so we wait on
# its specific PID. A bare `wait` always returns 0 and would hide an inference
# failure, leaving step 2 to score a missing or partial answers file.
srun -p partition --gres=gpu:1 --quotatype=auto --output="$infer_log" \
    python -m ming.eval.model_advbench_gen \
    --model-base "$MODEL_BASE" \
    --model-path "$MODEL_PATH" \
    --answers-file "$answers_file" \
    --s3-answers-file "$s3_answers_file" \
    --question-file /mnt/petrelfs/usr/datasets/harmful_behaviors.csv \
    --temperature 0 \
    --max-tokens 1024 \
    --keep-local \
    --conv-mode "$conv_mode" \
    --infer-answer \
    --use-logit-bias \
    --unload-ffn \
    --resume &
infer_pid=$!

if ! wait "$infer_pid"; then
    printf 'Inference failed for %s; see %s\n' "$CKPT" "$infer_log" >&2
    exit 1
fi

printf 'Evaluating %s\n' "$domain"

# Step 2: score the generated answers. The script's exit status is this job's.
srun -p partition python -m ming.eval.eval_advbench --input_file "$answers_file"
