#!/bin/bash
# Generation settings forwarded to AskPrompt.py (--temperature, --num_test_samples).
TEMPERATURE=1.0
NUM_TEST_SAMPLES=1000

# Dataset directory read by the script (--data_path) and root for results (--save_base).
DATA_PATH="./coco/train2017"
SAVE_BASE="./result/pe"

# Device string forwarded as --device.
DEVICE="cuda:0"

printf '%s\n' "DATA_PATH: $DATA_PATH" "SAVE_BASE: $SAVE_BASE"

# Model paths to evaluate. The sweep at the bottom iterates this list, so
# nothing runs until it is filled in.
declare -a model_list=()

#######################################
# Run ./main/AskPrompt.py once for a single model and problem type.
# Globals (read): DEVICE, DATA_PATH, TEMPERATURE, NUM_TEST_SAMPLES, SAVE_BASE
# Arguments:
#   $1 - model path; its name selects the conversation templates
#   $2 - problem type, forwarded as --problem_type
# Outputs:   a settings banner on stdout; a failure note on stderr
# Returns:   the exit status of AskPrompt.py
#######################################
run_experiment() {
    local MODEL_PATH=$1
    local q_type=$2

    # Template choice is local so it does not leak into the caller's scope.
    # Default: Llama-2 style templates (llava and any unrecognised model).
    local ASK_CONV_MODE="conv_llama_v2_vqa"
    local ANS_CONV_MODE="llava_llama_2"
    case "$MODEL_PATH" in
        *llava*)
            # Checked first so llava keeps the Llama-2 templates even when
            # the path also contains "Qwen".
            ;;
        *Qwen*)
            ASK_CONV_MODE="chatml_direct_q"
            ANS_CONV_MODE="chatml_direct"
            ;;
    esac

    echo "================================================================"
    echo "model path: $MODEL_PATH"
    echo "problem type: $q_type"
    echo "device: $DEVICE"
    echo "ask model: $ASK_CONV_MODE"
    echo "ans model: $ANS_CONV_MODE"
    echo "================================================================"

    local rc=0
    python ./main/AskPrompt.py \
        --model_path "$MODEL_PATH" \
        --device "$DEVICE" \
        --data_path "$DATA_PATH" \
        --temperature "$TEMPERATURE" \
        --num_test_samples "$NUM_TEST_SAMPLES" \
        --ask_conv_mode "$ASK_CONV_MODE" \
        --ans_conv_mode "$ANS_CONV_MODE" \
        --problem_type "$q_type" \
        --save_base "$SAVE_BASE" || rc=$?

    # Make failures visible; in a long sweep they are otherwise buried in logs.
    if (( rc != 0 )); then
        echo "run failed (exit $rc): model=$MODEL_PATH problem_type=$q_type" >&2
    fi
    return "$rc"
}

# Sweep every configured model over both problem types. A failed run does not
# stop the sweep, but it is counted so the script's exit status reflects it.
if (( ${#model_list[@]} == 0 )); then
    echo "model_list is empty; add model paths before running." >&2
    exit 1
fi

failed_runs=0
for MODEL_PATH in "${model_list[@]}"; do
    for q_type in "answerable" "unanswerable"; do
        run_experiment "$MODEL_PATH" "$q_type" || failed_runs=$((failed_runs + 1))
    done
done

echo
if (( failed_runs > 0 )); then
    echo "$failed_runs run(s) failed" >&2
    exit 1
fi
