# Usage: bash scripts/exps/main/sfilter.sh --models Qwen2-VL-7B --datasets food_v3
# Requires bash (arrays, associative arrays, [[ ]]); do not run it with plain `sh`.
source scripts/exps/model_configs.sh

# ---------------------------------------------------------------------------
# Defaults -- each of these can be overridden by the command-line flags below.
# ---------------------------------------------------------------------------

# Comma-separated list; forwarded to the filter step and iterated by stitching.
rank_thresholds="1,2,3"
per_device_eval_batch_size=8
# srun --quotatype used for the GPU filtering job.
quotatype="reserved"

# Datasets processed by default. This list is also the set of names that
# --datasets accepts. Disabled: "animals"
datasets_to_run=("imagenet_animals_v2" "food_v3" "landmarks")

# Models processed by default, grouped by family. Entries listed as
# "Disabled" are kept out of the default run; add them back to re-enable.
models_to_run=()
# Qwen2-VL -- Disabled: "Qwen2-VL-72B" "Qwen2-VL-72B-Instruct"
models_to_run+=("Qwen2-VL-2B" "Qwen2-VL-2B-Instruct" "Qwen2-VL-7B" "Qwen2-VL-7B-Instruct")
# Qwen2.5-VL -- Disabled: "Qwen2.5-VL-32B-Instruct" "Qwen2.5-VL-72B-Instruct"
models_to_run+=("Qwen2.5-VL-3B-Instruct" "Qwen2.5-VL-7B-Instruct")
# Gemma 3 -- Disabled: "gemma-3-27b-pt" "gemma-3-27b-it"
models_to_run+=("gemma-3-4b-pt" "gemma-3-4b-it" "gemma-3-12b-pt" "gemma-3-12b-it")
# Llama 3.2 Vision -- Disabled: "Llama-3.2-90B-Vision" "Llama-3.2-90B-Vision-Instruct"
models_to_run+=("Llama-3.2-11B-Vision" "Llama-3.2-11B-Vision-Instruct")
# LLaVA -- Disabled: "llava-v1.6-34b-hf"
models_to_run+=("llava-1.5-7b-hf" "llava-1.5-13b-hf")
models_to_run+=("llava-v1.6-vicuna-7b-hf" "llava-v1.6-vicuna-13b-hf")
# InternVL3
models_to_run+=("InternVL3-1B" "InternVL3-8B" "InternVL3-14B")

# Parse command line arguments
#######################################
# Apply command-line overrides to the run configuration.
# Globals:
#   rank_thresholds, per_device_eval_batch_size, quotatype (written)
#   models_to_run (written; every name must have an entry in model_configs)
#   datasets_to_run (read once as the list of valid names, then written)
#   model_configs (read)
# Arguments:
#   --rank_thresholds LIST  --per_device_eval_batch_size N  --quotatype TYPE
#   --models M1,M2,...      --datasets D1,D2,...
# Exits:
#   status 1 on an unknown flag, a flag with a missing/empty value, or an
#   unknown model/dataset name. Diagnostics go to stderr.
#######################################
parse_args() {
    # Snapshot the default datasets so they stay the set of valid names even
    # if --datasets is passed more than once.
    local -a valid_datasets=("${datasets_to_run[@]}")
    local -a requested=()
    local flag item candidate found

    while [[ "$#" -gt 0 ]]; do
        flag=$1
        case "$flag" in
            --rank_thresholds|--per_device_eval_batch_size|--quotatype|--models|--datasets) ;;
            *) echo "Unknown parameter passed: $flag" >&2; exit 1 ;;
        esac

        # Every recognized flag takes exactly one non-empty value.
        if [[ "$#" -lt 2 || -z "$2" ]]; then
            echo "Error: option '$flag' requires a value" >&2
            exit 1
        fi

        case "$flag" in
            --rank_thresholds) rank_thresholds=$2 ;;
            --per_device_eval_batch_size) per_device_eval_batch_size=$2 ;;
            --quotatype) quotatype=$2 ;;
            --models)
                IFS=',' read -ra requested <<< "$2"
                for item in "${requested[@]}"; do
                    # Check emptiness first: an empty key is an invalid subscript.
                    if [[ -z "$item" ]] || [[ -z "${model_configs[$item]}" ]]; then
                        echo "Error: Unknown model '$item'. Valid options are:" >&2
                        printf '  %s\n' "${!model_configs[@]}" >&2
                        exit 1
                    fi
                done
                models_to_run=("${requested[@]}")
                ;;
            --datasets)
                IFS=',' read -ra requested <<< "$2"
                for item in "${requested[@]}"; do
                    # Exact element match (not a substring/regex match).
                    found=0
                    for candidate in "${valid_datasets[@]}"; do
                        if [[ "$candidate" == "$item" ]]; then
                            found=1
                            break
                        fi
                    done
                    if (( ! found )); then
                        echo "Error: Unknown dataset '$item'. Valid options are:" >&2
                        printf '  %s\n' "${valid_datasets[@]}" >&2
                        exit 1
                    fi
                done
                datasets_to_run=("${requested[@]}")
                ;;
        esac
        shift 2
    done
}

parse_args "$@"

export PYTHONPATH=.

#######################################
# Launch the sfilter pipeline for every (dataset, model, patch ratio) combo.
# Each combination runs in a background subshell:
#   1. GPU filtering via patches_sfilter.py under srun;
#   2. only if that succeeded, one CPU stitching run (patches_stitch.py)
#      per rank threshold.
# Combinations whose unrecognizable/stats.json already exists are skipped.
# Globals:
#   datasets_to_run, models_to_run, model_configs, rank_thresholds,
#   per_device_eval_batch_size, quotatype (all read)
# Returns:
#   0 if every launched job succeeded, 1 if any failed. Waits for all
#   background jobs before returning.
#######################################
run_sfilter_jobs() {
    local dataset model_name patch_ratio threshold
    local model_name_or_path accelerate_config nodes gpu
    local model_base out_root filter_dir pid
    local failed=0
    local -a rank_threshold_list=() pids=()

    # Split the comma-separated thresholds once; every job reuses the list.
    IFS=',' read -ra rank_threshold_list <<< "${rank_thresholds}"

    for dataset in "${datasets_to_run[@]}"; do

        # Loop over all model configurations
        for model_name in "${models_to_run[@]}"; do

            # Remove any trailing comma from model name
            model_name=${model_name%,}

            # Verify model exists in config (empty check first: an empty key
            # is an invalid associative-array subscript).
            if [[ -z "$model_name" ]] || [[ -z "${model_configs[$model_name]}" ]]; then
                echo "Error: Unknown model '$model_name'" >&2
                continue
            fi

            # Config entries are '|'-separated fields, read in this order.
            IFS='|' read -r model_name_or_path accelerate_config nodes gpu \
                <<< "${model_configs[$model_name]}"

            # GPU count is derived from the node field; this overrides the gpu
            # field read from model_configs above.
            # NOTE(review): confirm that ignoring the configured gpu is intended.
            if [[ "$nodes" == "1" ]]; then
                gpu=1
            else
                gpu=8
            fi

            # Loop-invariant: output directories are keyed by the model's basename.
            model_base=$(basename -- "${model_name_or_path}")

            for patch_ratio in 4 8; do

                echo "===================================================================="
                echo "Filtering for dataset: \"$dataset\" with model: \"$model_name\""
                echo "Config:"
                echo "  gpu: ${gpu}"
                echo "  Quota type: ${quotatype}"
                echo "  Patch ratio: ${patch_ratio}"
                echo "===================================================================="

                out_root="tmp/data/${dataset}/files/others/sfilter/${patch_ratio}x${patch_ratio}/${model_base}"
                filter_dir="${out_root}/unrecognizable"

                # Skip if stats.json already exists
                [[ -f "${filter_dir}/stats.json" ]] && continue

                (
                    # Run the main filtering script; stitching needs its output.
                    if ! srun -p mllm_safety --quotatype="${quotatype}" --gres="gpu:${gpu}" \
                        --cpus-per-task=16 --time=30000 \
                        python src/tools/patches_sfilter.py \
                        --model_name_or_path "${model_name_or_path}" \
                        --data_config_path "data/${dataset}/test.yaml" \
                        --data_overwrite_args "data.eval[0].images_dirs[0]=tmp/data/${dataset}/files/${patch_ratio}x${patch_ratio}" \
                        --output_dir "${filter_dir}" \
                        --per_device_eval_batch_size "${per_device_eval_batch_size}" \
                        --rank_thresholds "${rank_thresholds}"; then
                        echo "Error: filtering failed for dataset '${dataset}', model '${model_name}', patch ratio ${patch_ratio}; skipping stitching" >&2
                        exit 1
                    fi

                    # For each threshold, run stitching; run them all, but
                    # report failure if any one fails.
                    stitch_status=0
                    for threshold in "${rank_threshold_list[@]}"; do
                        srun -p p-cpu-new --quotatype=reserved --cpus-per-task=8 --time=30000 \
                            python src/tools/patches_stitch.py \
                            --src_patches_dir "${filter_dir}/threshold${threshold}" \
                            --tgt_images_dir "${out_root}/unrecognizable_stitched/threshold${threshold}" \
                            --patch_ratio "${patch_ratio}" || stitch_status=1
                    done
                    exit "$stitch_status"
                ) &
                pids+=("$!")

                # Pause one second between job launches.
                sleep 1

            done

        done

    done

    # Barrier: wait for every background job and count failures.
    for pid in "${pids[@]}"; do
        wait "$pid" || failed=$((failed + 1))
    done
    if (( failed > 0 )); then
        echo "Error: ${failed} sfilter job(s) failed" >&2
        return 1
    fi
    return 0
}

run_sfilter_jobs