# Prompt handed to the patch filter (through --filter_kwargs "--prompt ..."),
# keyed by dataset name. The {..._name} placeholders are stored verbatim;
# presumably they are substituted downstream by the filter -- confirm there.
declare -A filter_prompts=(
    ["animals"]="Does this patch show part of a {animal_name}? Answer with only Yes or No."
    ["imagenet_animals"]="Does this patch show part of a {animal_name}? Answer with only Yes or No."
    ["food_v2"]="Does this patch show part of a {food_name}? Answer with only Yes or No."
    ["landmarks"]="Does this patch show part of the {landmark_name}? Answer with only Yes or No."
)

# Defaults used when the corresponding command-line flag is not given.

# Datasets to process: every key of filter_prompts.
datasets_to_run=("${!filter_prompts[@]}")

# Number of filter ranks launched per dataset / patch ratio.
world_size=1

# Value placed after "gpu:" in each filter job's srun --gres request.
gpu=4

# Filter implementation and the model it is pointed at.
filter_model_name_or_path='/mnt/lustrenew/mllm_safety-shared/models/huggingface/Qwen/Qwen2.5-VL-72B-Instruct'
filter_class_name='Qwen_VL_Instruct_Filter'

# Parse command line arguments
#######################################
# Apply command-line overrides on top of the default configuration.
# Globals:
#   filter_prompts (read; its keys are the valid dataset names)
#   filter_class_name, filter_model_name_or_path, gpu, world_size,
#   datasets_to_run (written)
# Arguments:
#   --filter_class_name NAME
#   --filter_model_name_or_path PATH
#   --gpu VALUE
#   --world_size N
#   --datasets NAME[,NAME...]
# Outputs:
#   Error messages on stderr.
# Returns:
#   Exits the script with status 1 on an unknown flag, a flag given without
#   a value, an empty dataset list, or an unknown dataset name.
#######################################
parse_args() {
    local flag value dataset
    local -a requested

    while [[ "$#" -gt 0 ]]; do
        flag="$1"

        # Every supported flag takes exactly one value; refuse to run with a
        # dangling flag instead of silently storing an empty string.
        case "$flag" in
            --filter_class_name | --filter_model_name_or_path | --gpu | --world_size | --datasets)
                if [[ "$#" -lt 2 ]]; then
                    echo "Error: ${flag} requires a value." >&2
                    exit 1
                fi
                value="$2"
                shift 2
                ;;
            *)
                echo "Unknown parameter passed: ${flag}" >&2
                exit 1
                ;;
        esac

        case "$flag" in
            --filter_class_name) filter_class_name="$value" ;;
            --filter_model_name_or_path) filter_model_name_or_path="$value" ;;
            --gpu) gpu="$value" ;;
            --world_size) world_size="$value" ;;
            --datasets)
                requested=()
                IFS=',' read -ra requested <<< "$value"
                if [[ "${#requested[@]}" -eq 0 ]]; then
                    echo "Error: --datasets requires at least one dataset name." >&2
                    exit 1
                fi
                # Validate by exact key lookup in filter_prompts rather than
                # against datasets_to_run: the latter may already have been
                # narrowed by an earlier --datasets, and a substring match on
                # the joined list can accept names that are not real keys.
                for dataset in "${requested[@]}"; do
                    if [[ -z "$dataset" || -z "${filter_prompts[$dataset]+set}" ]]; then
                        echo "Error: Unknown dataset '$dataset'. Valid options are:" >&2
                        printf '  %s\n' "${!filter_prompts[@]}" >&2
                        exit 1
                    fi
                done
                datasets_to_run=("${requested[@]}")
                ;;
        esac
    done
}

parse_args "$@"

#######################################
# Run the patch filter for one dataset at one patch ratio, then stitch.
# Starts one patches_filter.py srun job per rank (1..world_size) in parallel,
# waits for all of them, and runs patches_stitch.py only if every rank
# exited successfully, so stitching never runs on partial filter output.
# Globals:
#   gpu, world_size, filter_class_name, filter_model_name_or_path (read)
# Arguments:
#   $1 - dataset name (used in the tmp/data/... and data/... paths)
#   $2 - patch ratio N; patches are read from the NxN directory
#   $3 - prompt forwarded to the filter via --filter_kwargs
# Outputs:
#   An error message on stderr when a filter rank fails.
# Returns:
#   1 if any filter rank failed, otherwise the stitch job's exit status.
#######################################
filter_and_stitch() {
    local dataset="$1"
    local patch_ratio="$2"
    local filter_prompt="$3"
    local patches_dir="tmp/data/${dataset}/files/${patch_ratio}x${patch_ratio}"
    local rank pid
    local failed_ranks=0
    local -a rank_pids=()

    # NOTE(review): ranks are 1-based, as in the original script; confirm that
    # patches_filter.py expects 1..world_size rather than 0..world_size-1.
    for ((rank = 1; rank <= world_size; rank++)); do
        PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres="gpu:${gpu}" --cpus-per-task=8 --time=30000 \
            python src/tools/patches_filter.py \
            --rank "${rank}" --world_size "${world_size}" \
            --src_patches_dir "${patches_dir}" \
            --tgt_patches_dir "${patches_dir}_unrecognizable" \
            --data_config_path "data/${dataset}/config_image.yaml" \
            --filter_class_name "${filter_class_name}" --filter_kwargs "--prompt '${filter_prompt}' --model_name_or_path ${filter_model_name_or_path}" &
        rank_pids+=("$!")
        sleep 1 # stagger job submissions (delay kept from the original)
    done

    # Wait on each PID individually: a bare `wait` always returns 0 and would
    # hide a failed rank.
    for pid in "${rank_pids[@]}"; do
        wait "$pid" || failed_ranks=$((failed_ranks + 1))
    done
    if ((failed_ranks > 0)); then
        echo "Error: ${failed_ranks} filter rank(s) failed for ${dataset} ${patch_ratio}x${patch_ratio}; skipping stitch." >&2
        return 1
    fi

    PYTHONPATH=. srun -p p-cpu-new --quotatype=reserved --cpus-per-task=8 --time=30000 \
        python src/tools/patches_stitch.py \
        --src_patches_dir "${patches_dir}_unrecognizable" \
        --tgt_images_dir "${patches_dir}_unrecognizable_stitched" \
        --patch_ratio "${patch_ratio}"
}

# data filtering and stitching
# One background pipeline per (dataset, patch ratio); all of them run
# concurrently and are collected at the end.
pipeline_pids=()
for dataset in "${datasets_to_run[@]}"; do

    filter_prompt="${filter_prompts[$dataset]}"

    for patch_ratio in 4 8; do
        filter_and_stitch "${dataset}" "${patch_ratio}" "${filter_prompt}" &
        pipeline_pids+=("$!")
        sleep 10 # stagger pipeline launches (delay kept from the original)
    done

done

# Block until every pipeline has finished so the script's exit status
# reflects the real outcome instead of returning while jobs are still running.
failed_pipelines=0
for pid in "${pipeline_pids[@]}"; do
    wait "$pid" || failed_pipelines=$((failed_pipelines + 1))
done
if ((failed_pipelines > 0)); then
    echo "Error: ${failed_pipelines} filter/stitch pipeline(s) failed." >&2
    exit 1
fi
