#!/bin/bash
#
# Run topic-segmentation inference on the IU-Xray test split, one worker per
# GPU reported by nvidia-smi, then merge the per-GPU prediction shards and
# score them with llava/eval/rrg_eval/run.py.
#
# Usage:
#   bash <this script> [model_base] [model_path] [prediction_dir] [run_name]
#
#   $1 model_base      base LLM (default: lmsys/vicuna-7b-v1.5)
#   $2 model_path      trained checkpoint (no default; required)
#   $3 prediction_dir  output dir for prediction shards and eval results
#                      (default: results/topic_seg/llavarad_IUXRay)
#   $4 run_name        run name forwarded to run.py (default: llava_topic_seg)
#
# query_file, image_folder and mask_path have no defaults: set them below (or
# export them in the environment). The script aborts if any required input
# is empty.
#
# Must be run from the repository root: `python -m llava.eval...` and the
# llava/eval/rrg_eval directory are resolved relative to the CWD.

set -e
set -o pipefail

model_base=lmsys/vicuna-7b-v1.5
#model_path=microsoft/llava-rad
# TODO: point model_path at the trained checkpoint directory.
#model_path=/path/to/trained_llava_ta_checkpoints/

model_base="${1:-$model_base}"
model_path="${2:-$model_path}"
prediction_dir="${3:-results/topic_seg/llavarad_IUXRay}"
prediction_file="$prediction_dir/test"

run_name="${4:-llava_topic_seg}"


# query_file=/PATH_TO/physionet.org/files/llava-rad-mimic-cxr-annotation/1.0.0/chat_test_MIMIC_CXR_all_gpt4extract_rulebased_v1.json
# TODO: point query_file at the IU-Xray report JSON.
#query_file=/path/to/IU-Xray-report.json

# image_folder=/PATH_TO/physionet.org/files/mimic-cxr-jpg/2.0.0/files
# TODO: point image_folder at the IU-Xray image folder.
#image_folder=/path/to/IU-Xray-image-folder
# TODO: point mask_path at the IU-Xray segmented-image folder.
#mask_path=/path/to/IU-Xray-segmented-image-folder
loader="iuxray_test_topic_reason_findings"
conv_mode="v1"

# Fail fast on missing inputs. Otherwise an empty unquoted value vanishes
# from the argument list and the Python side fails with a confusing
# argparse error.
: "${model_path:?model_path is empty: pass it as the 2nd argument or set it in this script}"
: "${query_file:?query_file is empty: set it to the IU-Xray report JSON}"
: "${image_folder:?image_folder is empty: set it to the IU-Xray image folder}"
: "${mask_path:?mask_path is empty: set it to the IU-Xray segmented-image folder}"

CHUNKS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if (( CHUNKS == 0 )); then
    echo "error: nvidia-smi reported no GPUs" >&2
    exit 1
fi

mkdir -p -- "$prediction_dir"

# One background worker per GPU. Each worker handles chunk `idx` of
# `CHUNKS` and writes its own shard.
pids=()
shards=()
for (( idx = 0; idx < CHUNKS; idx++ )); do
    shard="${prediction_file}_${idx}.jsonl"
    shards+=("$shard")
    CUDA_VISIBLE_DEVICES=$idx python -m llava.eval.model_mimic_cxr_topicSeg \
        --query_file "$query_file" \
        --loader "$loader" \
        --image_folder "$image_folder" \
        --mask_path "$mask_path" \
        --conv_mode "$conv_mode" \
        --prediction_file "$shard" \
        --temperature 0 \
        --model_path "$model_path" \
        --model_base "$model_base" \
        --chunk_idx "$idx" \
        --num_chunks "$CHUNKS" \
        --batch_size 4 \
        --group_by_length &
    pids+=("$!")
done

# A bare `wait` always returns 0, and `set -e` does not see background-job
# failures. Wait on each PID so a crashed shard aborts the run instead of
# being silently dropped from the merged predictions.
failed=0
for pid in "${pids[@]}"; do
    wait "$pid" || failed=1
done
if (( failed )); then
    echo "error: one or more inference workers failed" >&2
    exit 1
fi

# Merge exactly the shards written by this run. A glob could also pick up
# stale shards left over from an earlier run that used more GPUs.
merged_preds="$PWD/mimic_cxr_preds.jsonl"
trap 'rm -f -- "$merged_preds"' EXIT
cat -- "${shards[@]}" > "$merged_preds"

# Resolve the eval output dir before changing directory, so an absolute
# prediction_dir is not mangled into ../../..//abs/path.
case "$prediction_dir" in
    /*) eval_dir="$prediction_dir/eval" ;;
    *)  eval_dir="$PWD/$prediction_dir/eval" ;;
esac

(
    cd llava/eval/rrg_eval
    WANDB_PROJECT="llava_topic_seg" WANDB_RUN_ID="llava-eval-$(date +%Y%m%d%H%M%S)" WANDB_RUN_GROUP=evaluate CUDA_VISIBLE_DEVICES=0 \
        python run.py "$merged_preds" --run_name "$run_name" --output_dir "$eval_dir"
)