# Refer to bash_test_sd15.sh for more details.

# ---------------------------------------------------------------------------
# Model / VAE / dataset identifiers, exported into the environment.
# ---------------------------------------------------------------------------
export \
    MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" \
    VAE="madebyollin/sdxl-vae-fp16-fix" \
    DATASET_NAME=""

# NOTE(review): input_arg (the first CLI argument) and seed are assigned here
# but never referenced by this script — confirm whether they are still needed.
input_arg="$1" seed=0

# ---------------------------------------------------------------------------
# Paths — replace the placeholders before running.
# ---------------------------------------------------------------------------
_basedir="BASE_DIR"                    # directory that contains test.py
base_output_dir="PATH_TO_SAVE_RESULT"  # root holding <model>/checkpoint-<N>
test_output_dir="PATH_FOR_RESULT"      # passed to test.py as --output-dir
base_prompt_path="BASE_PROMPT_PATH"    # prompt file prefix; "<data>.json" is appended

# Reward type forwarded to test.py via --reward_type.
rtype=pickscore

# ---------------------------------------------------------------------------
# Launch settings.
# ---------------------------------------------------------------------------
overwrite=0                # forwarded to test.py as --overwrite
gpu_ids="4,5"              # set as CUDA_VISIBLE_DEVICES for each run
port_num=25700             # accelerate --main_process_port
num_imgs_per_prompt=4      # forwarded to test.py as --num_imgs_per_prompt

# ---------------------------------------------------------------------------
# TEST sweep: each data_list x model_list x checkpoint_list combination runs.
# ---------------------------------------------------------------------------
# Models to test; use the model name from the training script.
model_list=( "MODEL_NAME" )
# Checkpoints of the diffusion models to test; an empty string "" evaluates
# the base MODEL_NAME instead of a fine-tuned checkpoint.
checkpoint_list=( "100" )
# Test prompt-file suffixes to run, e.g. ("_test_filtered" "_parti_test").
data_list=( "TEST_PROMPT_FILE" )







# Evaluate every (test prompt file x model x checkpoint) combination with
# ${_basedir}/test.py under `accelerate launch`.
#
# - An empty data_list entry selects "${base_prompt_path}.json" and the bare
#   ${DATASET_NAME}; a non-empty entry is appended as a suffix to both.
# - An empty checkpoint evaluates only ${MODEL_NAME}; otherwise the weights at
#   ${base_output_dir}/<model>/checkpoint-<checkpoint> are passed via
#   --model-path and the run is tagged --version "<model>_<checkpoint>".
# - Every expansion is quoted so an empty value (e.g. the default empty
#   DATASET_NAME) is still passed as its own argument instead of vanishing
#   and leaving "--dataset" without a value.
# - A failing run is reported on stderr and the sweep continues with the next
#   combination; the final status is non-zero if any run failed.
pyfile="${_basedir}/test.py"   # loop-invariant, so computed once
failed_runs=0

for data in "${data_list[@]}"; do
    # Appending an empty suffix yields the un-suffixed file/tag, so no
    # special case is needed for data="".
    prompt_path="${base_prompt_path}${data}.json"
    dataset="${DATASET_NAME}${data}"

    for model_type in "${model_list[@]}"; do
        base_model_path="${base_output_dir}/${model_type}/checkpoint"

        for checkpoint in "${checkpoint_list[@]}"; do
            echo "Running model with prompt_path=${prompt_path} and dataset=${dataset}"

            # Build test.py's arguments once; checkpoint-specific flags go
            # first to keep the original argument order.
            run_args=()
            if [[ -n "$checkpoint" ]]; then
                model_path="${base_model_path}-${checkpoint}"
                run_args+=(--version "${model_type}_${checkpoint}"
                           --model-path "${model_path}")
            fi
            run_args+=(
                --pretrained_model_name_or_path="${MODEL_NAME}"
                --prompts_path "${prompt_path}"
                --dataset "${dataset}"
                --reward_type="${rtype}"
                --output-dir="${test_output_dir}"
                --num_imgs_per_prompt="${num_imgs_per_prompt}"
                --overwrite "${overwrite}"
            )

            if ! CUDA_VISIBLE_DEVICES="${gpu_ids}" accelerate launch \
                    --main_process_port="${port_num}" "${pyfile}" \
                    "${run_args[@]}"; then
                echo "Run failed: model=${model_type} checkpoint=${checkpoint:-<base>} prompt_path=${prompt_path}" >&2
                failed_runs=$((failed_runs + 1))
            fi
        done
    done
done

if (( failed_runs > 0 )); then
    echo "${failed_runs} evaluation run(s) failed" >&2
fi
# Leave a non-zero status for the caller when any run failed.
(( failed_runs == 0 ))
