# ---------------------------------------------------------------------------
# Configuration for the test run. Every value below is read by the loop at
# the bottom of this script.
# ---------------------------------------------------------------------------

# Exported to child processes and also passed to test.py as
# --pretrained_model_name_or_path.
MODEL_NAME="runwayml/stable-diffusion-v1-5"
export MODEL_NAME

# Exported dataset-name prefix; each entry of data_list is appended to it to
# form the value given to --dataset.
DATASET_NAME=""
export DATASET_NAME

# First positional argument of the script and a seed value.
# NOTE(review): neither is referenced anywhere else in this script - confirm
# whether they are still needed.
input_arg="$1"
seed=0

# --- Paths -----------------------------------------------------------------

# Directory that holds test.py.
_basedir="BASE_DIR"

# Root under which "<model>/checkpoint-<N>" directories are looked up.
# ex: ./experiments/
base_output_dir="PATH_TO_SAVE_RESULT_FILES"

# Passed to test.py as --output-dir.
test_output_dir="PATH_FOR_RESULT"

# Prompt file path without suffix and without the ".json" extension.
# ex : dataset/qas
base_prompt_path="${_basedir}/PROMPT_PATH"

# --- Run options -----------------------------------------------------------

rtype="pickscore"        # passed as --reward_type
overwrite="0"            # passed as --overwrite
num_imgs_per_prompt="4"  # passed as --num_imgs_per_prompt
gpu_ids="4,5"            # assigned to CUDA_VISIBLE_DEVICES for the launch
port_num="25601"         # passed to accelerate as --main_process_port

# --- What to test ----------------------------------------------------------

# Models to test.
# Refer to the name of the model used in the training script.
model_list=("MODEL_NAME")

# Checkpoints of the diffusion models to test.
checkpoint_list=("100")

# Test files to run; each entry is a suffix appended to base_prompt_path.
# E.g.: ("_test_filtered" "_parti_test")
data_list=("TEST_PROMPT_FILE")





# Run the test script for every test file, model and checkpoint combination

#######################################
# Launch a single run of test.py through `accelerate launch`.
#
# The argument list is built once in an array so that both the "with
# checkpoint" and the "without checkpoint" case pass exactly the same options
# (previously a missing line continuation dropped --num_imgs_per_prompt and
# --overwrite when the checkpoint was empty).
#
# Globals (read):
#   _basedir, base_output_dir, test_output_dir, gpu_ids, port_num, rtype,
#   MODEL_NAME, num_imgs_per_prompt, overwrite
# Arguments:
#   $1 - model name (sub-directory of base_output_dir)
#   $2 - checkpoint number; when empty, no --model-path is passed
#        (presumably test.py then uses only the pretrained model - confirm)
#   $3 - path of the prompt JSON file
#   $4 - dataset name
# Returns:
#   Exit status of `accelerate launch`.
#######################################
_launch_eval() {
    local model_type=$1
    local checkpoint=$2
    local prompt_path=$3
    local dataset=$4
    local pyfile="${_basedir}/test.py"

    local -a launch_args=(
        --mixed_precision=fp16
        "--main_process_port=${port_num}"
        "${pyfile}"
        --prompts_path "${prompt_path}"
        --version "${model_type}_${checkpoint}"
    )

    # Only point at a fine-tuned checkpoint directory when one was requested.
    if [[ -n "${checkpoint}" ]]; then
        launch_args+=(
            --model-path "${base_output_dir}/${model_type}/checkpoint-${checkpoint}"
        )
    fi

    launch_args+=(
        --dataset "${dataset}"
        "--reward_type=${rtype}"
        "--pretrained_model_name_or_path=${MODEL_NAME}"
        "--output-dir=${test_output_dir}"
        "--num_imgs_per_prompt=${num_imgs_per_prompt}"
        --overwrite "${overwrite}"
    )

    CUDA_VISIBLE_DEVICES="${gpu_ids}" accelerate launch "${launch_args[@]}"
}

#######################################
# Run the test for every combination of test file, model and checkpoint.
#
# A failing run does not stop the remaining combinations.
#
# Globals (read):
#   data_list, model_list, checkpoint_list, base_prompt_path, DATASET_NAME
#   (plus everything read by _launch_eval)
# Outputs:
#   One "Running model ..." line per run on stdout, followed by whatever
#   the launched command prints.
# Returns:
#   Exit status of the last launched run (0 if nothing was run).
#######################################
run_evaluations() {
    local data model_type checkpoint prompt_path dataset

    for data in "${data_list[@]}"; do
        # An empty suffix needs no special handling: appending "" yields
        # "<base_prompt_path>.json" and the bare DATASET_NAME.
        prompt_path="${base_prompt_path}${data}.json"
        dataset="${DATASET_NAME}${data}"

        for model_type in "${model_list[@]}"; do
            for checkpoint in "${checkpoint_list[@]}"; do
                echo "Running model with prompt_path=${prompt_path} and dataset=${dataset}"
                _launch_eval "${model_type}" "${checkpoint}" \
                    "${prompt_path}" "${dataset}"
            done
        done
    done
}

run_evaluations


