#!/bin/bash


# Name of the checkpoint directory to evaluate (first positional argument),
# e.g. "ming1.8b-4x1-topk-openmath01-r32" or "ming1.8b-molora-4x1-topk-openmath01".
CKPT="${1:-}"
if [[ -z "$CKPT" ]]; then
    # Without a name the wait loop below would poll a non-existent path forever.
    echo "Usage: $0 <checkpoint-name>" >&2
    exit 2
fi

# Checkpoint names containing ming{SIZE}b encode the base-model size
# (e.g. 1.8 or 7). SIZE stays empty when the pattern is absent.
SIZE=""
if [[ "$CKPT" =~ ming([0-9.]+)b ]]; then
    SIZE="${BASH_REMATCH[1]}"
fi

# Pick the base model and conversation template from the checkpoint name.
if [[ "$CKPT" == *"llama7b"* ]]; then
    MODEL_BASE="/mnt/petrelfs/usr/models/llama2_7b_chat"
    conv_mode="llama2"
elif [[ "$CKPT" == *"llama8b"* ]]; then
    MODEL_BASE="/mnt/petrelfs/usr/models/Meta-Llama-3-8B-Instruct"
    conv_mode="llama3"
else
    # The Qwen base path is built from SIZE; an empty SIZE would silently
    # yield the bogus path ".../Qwen1.5-B-Chat", so fail early instead.
    if [[ -z "$SIZE" ]]; then
        echo "Error: cannot determine model size from checkpoint name '${CKPT}'" \
             "(expected it to contain 'ming<SIZE>b', 'llama7b' or 'llama8b')." >&2
        exit 1
    fi
    MODEL_BASE="/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-${SIZE}B-Chat"
    conv_mode="qwen"
fi
LOGS_BASE_PATH="./logs/diverse"

MODEL_PATH="/mnt/petrelfs/usr/checkpoints/${CKPT}"

# Block until the adapter weights exist, polling once a minute.
while [[ ! -f "${MODEL_PATH}/adapter_model.safetensors" ]]; do
    echo "Waiting for ${MODEL_PATH}/adapter_model.safetensors to appear..."
    sleep 60
done


# Benchmarks to run. "bbh" and "math" are delegated to dedicated scripts;
# every other domain goes through the generic inference + exact-match pipeline.
domains=("logiqa_en" "commonsense_qa" "svamp" "mmlu" "mmedbench_en" "bbh" "math")
# Suffix appended to every log/answers file name produced by this script.
version="-woffn"

# PID of the background job for each domain, index-aligned with `domains`.
pids=()

for domain in "${domains[@]}"; do
    # Stagger job submissions by one second.
    sleep 1
    (
        echo "Processing $domain"
        mkdir -p "${LOGS_BASE_PATH}/${domain}"

        # Delegated domains: run the dedicated script and leave the subshell
        # with its exit status. (`continue` is not meaningful inside a
        # subshell; `exit` ends this job and propagates the status.)
        case "$domain" in
            bbh)
                bash scripts/v1/eval/bbh_lora_woffn.sh "$CKPT"
                exit
                ;;
            math)
                bash scripts/v1/eval/math_lora_woffn.sh "$CKPT"
                exit
                ;;
        esac

        # Common prefix of the local log and answers files for this domain.
        out_prefix="${LOGS_BASE_PATH}/${domain}/${CKPT}${version}"

        if ! srun -p partition --gres=gpu:1 --quotatype=auto --output="${out_prefix}.infer.log" python -m ming.eval.model_diverse_gen \
            --model-path "${MODEL_PATH}" \
            --model-base "${MODEL_BASE}" \
            --question-file "s3://bucket/datasets/diverse_domain/test/${domain}.json" \
            --answers-file "${out_prefix}.jsonl" \
            --s3-answers-file "s3://bucket/logs/diverse/${domain}/${CKPT}${version}.jsonl" \
            --temperature 0 \
            --max-tokens 1024 \
            --keep-local \
            --conv-mode "${conv_mode}" \
            --infer-answer \
            --unload-ffn \
            --use-logit-bias \
            --resume
        then
            # Do not score a partial or stale answers file.
            echo "Inference failed for $domain; skipping evaluation." >&2
            exit 1
        fi

        echo "Evaluating $domain"

        srun -p partition --output="${out_prefix}.eval.log" python -m ming.eval.eval_em \
            --input_file "${out_prefix}.jsonl"
    ) &
    pids+=("$!")
done

# Wait for every job individually so failures are reported instead of lost.
failed=()
for i in "${!pids[@]}"; do
    if ! wait "${pids[$i]}"; then
        failed+=("${domains[$i]}")
    fi
done

echo "All processes are done."
if (( ${#failed[@]} > 0 )); then
    echo "Failed domains: ${failed[*]}" >&2
    exit 1
fi