#!/bin/bash
# Embed the ViQuAE test queries with a LLaVA checkpoint, sharding the work
# across the GPUs listed in CUDA_VISIBLE_DEVICES (one chunk per GPU id),
# then merge the per-chunk embeddings.
#
# Usage: <this script> CKPT [BASE]
#   CKPT  checkpoint directory name under ${BASE_PATH}/checkpoints
#   BASE  base model directory name under ${BASE_PATH}/checkpoints
#         (default: llava-next-interleave-qwen-0.5b)
#
# Environment:
#   CUDA_VISIBLE_DEVICES  comma-separated GPU ids (default: 0); the number of
#                         ids is the number of chunks.

set -euo pipefail

readonly DEFAULT_BASE=llava-next-interleave-qwen-0.5b

gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
IFS=',' read -ra GPULIST <<< "$gpu_list"

CHUNKS=${#GPULIST[@]}

# Default to empty so `set -u` does not abort on a missing argument.
CKPT=${1:-}
# An empty or missing second argument falls back to the default base model.
BASE=${2:-$DEFAULT_BASE}

# Pass --doc_model_init only when BASE is the default base model.
DOC_INIT_ARG=""
if [[ "$BASE" == "$DEFAULT_BASE" ]]; then
  DOC_INIT_ARG="--doc_model_init"
fi

# Root directory for checkpoints/ and dataset/. When empty, the paths below
# resolve from the filesystem root (e.g. /checkpoints/...).
BASE_PATH=''

# Without a checkpoint name the paths below would collapse to
# .../checkpoints/ and .../query_embeds/, so refuse to run.
if [[ -z "${CKPT:-}" ]]; then
  printf 'usage: %s CKPT [BASE]\n' "${0##*/}" >&2
  exit 2
fi

QUERY_EMBED_DIR="${BASE_PATH}/dataset/viquae/query_embeds/${CKPT}"

# Launch one background embedding process per GPU; process IDX handles
# chunk IDX of CHUNKS. Record each PID so its exit status can be checked.
pids=()
for (( IDX = 0; IDX < CHUNKS; IDX++ )); do
  # ${DOC_INIT_ARG:+...} expands to no argument at all when the flag is empty.
  CUDA_VISIBLE_DEVICES="${GPULIST[$IDX]}" python llava/eval/viquae/embed_query_viquae.py \
    --model-path "${BASE_PATH}/checkpoints/${CKPT}" \
    --model-base "${BASE_PATH}/checkpoints/${BASE}" \
    --save_path "$QUERY_EMBED_DIR" \
    --query_path "${BASE_PATH}/dataset/viquae/test_clean.csv" \
    --is_multimodal \
    --num-chunks "$CHUNKS" \
    --chunk-idx "$IDX" \
    --conv qwen_1_5 \
    --batch_size 8 \
    ${DOC_INIT_ARG:+"$DOC_INIT_ARG"} &
  pids+=("$!")
done

# A bare `wait` always returns 0, which would let the merge run on partial
# output. Wait on every PID and remember whether any chunk failed.
failed=0
for pid in "${pids[@]}"; do
  wait "$pid" || failed=1
done
if (( failed )); then
  printf 'error: one or more embedding chunks failed; skipping merge\n' >&2
  exit 1
fi

python llava/eval/merge_embeds.py \
  --embed_path "${QUERY_EMBED_DIR}/query_embed"
