#!/bin/bash
#SBATCH -J evaluate_eagle_x4_8b_average-job
#SBATCH -o evaluate_eagle_x4_8b_average-%j.out
#SBATCH -e evaluate_eagle_x4_8b_average-%j.err


set -euo pipefail

# Configuration -- every setting can be overridden from the environment, e.g.
#   NUM_ENCODERS=3 NPROC=4 sbatch <this script>
MODEL=${MODEL:-Eagle-X4-8B-Plus}
NUM_ENCODERS=${NUM_ENCODERS:-4}    # encoders in the model; 2^N - 1 masks are swept
WORK_DIR=${WORK_DIR:-outputs}
PYTHON=${PYTHON:-python}
MODIFY_SCRIPT=${MODIFY_SCRIPT:-scripts/modify_eagle_mask_index.py}
NPROC=${NPROC:-8}                  # value for torchrun --nproc-per-node

#######################################
# Abort early on a NUM_ENCODERS value the mask arithmetic cannot use.
# (1 << N) must stay a positive 64-bit bash integer, and a leading zero would
# make bash read the value as octal, so only 1..62 is accepted.
# Globals:   NUM_ENCODERS (read)
# Outputs:   error message on stderr when invalid
# Returns:   0 when valid; exits the shell with status 1 otherwise
#######################################
validate_config() {
    if [[ ! "${NUM_ENCODERS}" =~ ^[1-9][0-9]?$ ]] || (( NUM_ENCODERS > 62 )); then
        echo "[ERROR] NUM_ENCODERS must be an integer between 1 and 62 (got '${NUM_ENCODERS}')" >&2
        exit 1
    fi
}
validate_config

# Evaluation function
#######################################
# Launch run.py under torchrun on the fixed benchmark suite for MODEL.
# Globals:   NPROC, MODEL, WORK_DIR (read)
# Arguments: $@ - optional extra arguments appended to the run.py command line
# Outputs:   a start message, then whatever torchrun/run.py print
# Returns:   torchrun's exit status
#######################################
run_eval() {
    local -a datasets=(
        MME MMBench_DEV_EN SEEDBench_IMG GQA_TestDev_Balanced
        ScienceQA_VAL MMMU_DEV_VAL MathVista_MINI AI2D_TEST
        ChartQA_TEST OCRBench TextVQA_VAL DocVQA_VAL
        MMVP RealWorldQA CV-Bench-2D CV-Bench-3D
    )

    echo "[INFO] Starting evaluation..."
    torchrun --nproc-per-node="${NPROC}" run.py \
        --model "${MODEL}" \
        --data "${datasets[@]}" \
        --judge qwen-plus \
        --work-dir "${WORK_DIR}" \
        --reuse \
        --verbose \
        "$@"
}

# Per-model output directory. The find-based helpers below search it, so it is
# created up front; "--" stops a WORK_DIR beginning with "-" from being parsed
# as an mkdir option.
MODEL_OUT_DIR="${WORK_DIR}/${MODEL}"
mkdir -p -- "${MODEL_OUT_DIR}"

# Find the newest timestamp directory
#######################################
# Print the path of the most recently modified "T20*" directory directly
# inside MODEL_OUT_DIR. -mindepth 1 keeps MODEL_OUT_DIR itself out of the
# candidates even if its own name starts with "T20"; otherwise the caller
# could try to rename the directory into itself.
# Globals:   MODEL_OUT_DIR (read)
# Outputs:   the path on stdout; nothing when there is no match or the
#            directory is missing
# Returns:   always 0, so `x=$(find_timestamp_dir)` is safe under set -e
#######################################
find_timestamp_dir() {
    # LC_ALL=C: %T@ always prints a "." decimal point, so parse it that way.
    find "${MODEL_OUT_DIR}" -mindepth 1 -maxdepth 1 -type d -name "T20*" \
        -printf '%T@ %p\n' 2>/dev/null \
        | LC_ALL=C sort -nr \
        | head -n 1 \
        | cut -d' ' -f2- \
        || true
}

# Clean and rename output directory
#######################################
# Tidy MODEL_OUT_DIR after a run: delete the regular files sitting directly
# in it (subdirectories are kept), then prefix the newest "T20*" directory
# reported by find_timestamp_dir with the given tag.
# Globals:   MODEL_OUT_DIR (read)
# Arguments: $1 - prefix to prepend, joined with "_" (e.g. "0_1_0_1")
# Outputs:   an INFO line describing the rename, or a WARN line if no
#            timestamp directory exists
#######################################
clean_and_rename() {
    local prefix="$1"
    local newest base target

    # Best effort: a failed delete must not stop the sweep.
    find "${MODEL_OUT_DIR}" -mindepth 1 -maxdepth 1 -type f -delete 2>/dev/null || true

    newest=$(find_timestamp_dir)
    if [[ -z "${newest}" || ! -d "${newest}" ]]; then
        echo "[WARN] No timestamp directory found to rename"
        return
    fi

    base=${newest##*/}
    target="${MODEL_OUT_DIR}/${prefix}_${base}"
    echo "[INFO] Renaming ${base} -> ${prefix}_${base}"
    mv "${newest}" "${target}"
}

# Apply masking configuration
#######################################
# Run MODIFY_SCRIPT with "-m" followed by the encoder indices to mask; with
# no arguments it is called with a bare "-m" (no encoder masked).
# PYTHON is expanded unquoted, so it may hold several words (e.g. "python -u").
# Globals:   PYTHON, MODIFY_SCRIPT (read)
# Arguments: $@ - zero or more encoder indices
# Returns:   the exit status of the Python call
#######################################
apply_mask() {
    if (( $# > 0 )); then
        echo "[INFO] Masking encoders: $*"
        ${PYTHON} "${MODIFY_SCRIPT}" -m "$@"
        return
    fi

    echo "[INFO] Applying NO masking"
    ${PYTHON} "${MODIFY_SCRIPT}" -m
}

#######################################
# Generate bit pattern string (e.g., "0_1_0_1"): one digit per encoder,
# encoder 0 first; a 1 marks an encoder whose mask bit is set.
# Globals:   NUM_ENCODERS (read, default width)
# Arguments: $1 - mask integer (bit i set = encoder i masked)
#            $2 - optional number of digits (default: NUM_ENCODERS)
# Outputs:   the underscore-joined digits on stdout
#######################################
generate_bit_pattern() {
    local mask_int=$1
    local width=${2:-${NUM_ENCODERS}}
    local i
    local bits=()

    for ((i = 0; i < width; i++)); do
        bits+=("$(( (mask_int >> i) & 1 ))")
    done

    # Function-scoped IFS (restored on return) joins with "_" without a subshell.
    local IFS='_'
    echo "${bits[*]}"
}

#######################################
# Convert mask integer to the list of encoder indices whose bit is set.
# Globals:   NUM_ENCODERS (read, default width)
# Arguments: $1 - mask integer (bit i set = encoder i masked)
#            $2 - optional number of bits to inspect (default: NUM_ENCODERS)
# Outputs:   space-separated indices in ascending order, or an empty line
#            when no bit is set
#######################################
mask_int_to_indices() {
    local mask_int=$1
    local width=${2:-${NUM_ENCODERS}}
    local i
    local indices=()

    for ((i = 0; i < width; i++)); do
        if (( (mask_int >> i) & 1 )); then
            indices+=("$i")
        fi
    done

    # Separate empty branch: expanding an empty array under `set -u` is an
    # error on older bash versions.
    if (( ${#indices[@]} == 0 )); then
        echo ""
    else
        echo "${indices[@]}"
    fi
}

#######################################
# Report whether a result directory named "<bit_pattern>_T20*" already sits
# directly inside MODEL_OUT_DIR.
# Globals:   MODEL_OUT_DIR (read)
# Arguments: $1 - bit pattern such as "0_1_0_1"
# Outputs:   when found, an INFO line followed by the first matching path
# Returns:   0 if a matching directory exists, 1 otherwise
#######################################
check_existing_evaluation() {
    local bit_pattern="$1"
    local first_match

    # -print -quit: stop at the first hit; only that one is reported.
    first_match=$(find "${MODEL_OUT_DIR}" -maxdepth 1 -type d -name "${bit_pattern}_T20*" -print -quit 2>/dev/null)

    if [[ -z "${first_match}" ]]; then
        return 1  # No evaluation found
    fi

    echo "[INFO] Found existing evaluation for pattern ${bit_pattern}:"
    echo "${first_match}"
    return 0
}


# Main execution
#######################################
# Sweep every encoder-masking combination except "all encoders masked".
# For each mask value (bit i set = encoder i masked):
#   * skip it when a "<bit_pattern>_T20*" directory already exists;
#   * otherwise apply the mask, run the benchmarks, and rename the newest
#     timestamped run directory to "<bit_pattern>_<name>" on success or
#     "FAILED_<bit_pattern>_<name>" on failure.
# Failed runs get the FAILED_ prefix because check_existing_evaluation only
# matches names starting with "<bit_pattern>_T20". Without the prefix a
# failed run would be reported complete and skipped on every rerun.
# Globals:   MODEL, NUM_ENCODERS, MODEL_OUT_DIR, MODIFY_SCRIPT (read)
# Arguments: none used
# Outputs:   progress log on stdout
#######################################
main() {
    # Mask value with every encoder bit set. Masks 0..all_masked-1 keep at
    # least one encoder active, which is exactly the set that is swept.
    local all_masked=$(( (1 << NUM_ENCODERS) - 1 ))
    local total_combinations=${all_masked}
    local mask pattern bit_pattern indices_string status
    local -a indices_array=()
    local existing_count=0 combo=0 skipped=0 evaluated=0 failed=0

    echo "[INFO] ======================================"
    echo "[INFO] Eagle Vision Encoder Masking Evaluation"
    echo "[INFO] ======================================"
    echo "[INFO] Model: ${MODEL}"
    echo "[INFO] Number of encoders: ${NUM_ENCODERS}"
    echo "[INFO] Output directory: ${MODEL_OUT_DIR}"
    echo "[INFO] Modify script: ${MODIFY_SCRIPT}"
    echo "[INFO] ======================================"

    echo "[INFO] Total combinations to test: ${total_combinations}"

    # Pre-scan so the log shows how much work remains before the sweep starts.
    echo "[INFO] Checking for existing evaluations..."
    for ((mask = 0; mask < all_masked; mask++)); do
        pattern=$(generate_bit_pattern "${mask}")
        if check_existing_evaluation "${pattern}"; then
            existing_count=$((existing_count + 1))
        fi
    done
    echo "[INFO] Found ${existing_count} existing evaluations"
    echo "[INFO] Will evaluate $((total_combinations - existing_count)) remaining combinations"
    echo "[INFO] ======================================"

    for ((mask = 0; mask < all_masked; mask++)); do
        combo=$((combo + 1))
        bit_pattern=$(generate_bit_pattern "${mask}")

        if check_existing_evaluation "${bit_pattern}"; then
            skipped=$((skipped + 1))
            echo
            echo "======================================"
            echo "[INFO] Combination ${combo}/${total_combinations} - SKIPPING"
            echo "[INFO] Mask integer: ${mask}"
            echo "[INFO] Bit pattern: ${bit_pattern}"
            echo "[INFO] ⏭️  Evaluation already exists - skipping"
            echo "======================================"
            continue
        fi

        # Split the space-separated index list into an array (empty for mask 0).
        indices_string=$(mask_int_to_indices "${mask}")
        indices_array=()
        if [[ -n "${indices_string}" ]]; then
            read -r -a indices_array <<< "${indices_string}"
        fi

        evaluated=$((evaluated + 1))
        echo
        echo "======================================"
        echo "[INFO] Combination ${combo}/${total_combinations} - EVALUATING"
        echo "[INFO] Mask integer: ${mask}"
        echo "[INFO] Bit pattern: ${bit_pattern}"
        if (( ${#indices_array[@]} == 0 )); then
            echo "[INFO] Indices to mask: (none)"
        else
            echo "[INFO] Indices to mask: ${indices_array[*]}"
        fi
        echo "[INFO] 🚀 Starting new evaluation..."
        echo "======================================"

        # Step 1: Apply masking configuration. The empty case is a separate
        # call because expanding an empty array under `set -u` errors on
        # older bash versions.
        if (( ${#indices_array[@]} == 0 )); then
            if ! apply_mask; then
                echo "[ERROR] Failed to apply mask configuration"
                failed=$((failed + 1))
                continue
            fi
        elif ! apply_mask "${indices_array[@]}"; then
            echo "[ERROR] Failed to apply mask configuration"
            failed=$((failed + 1))
            continue
        fi

        # Steps 2 and 3: run the evaluation, then move its run directory aside.
        if run_eval; then
            echo "[INFO] ✓ Evaluation completed successfully"
            clean_and_rename "${bit_pattern}"
        else
            echo "[ERROR] ✗ Evaluation failed"
            failed=$((failed + 1))
            # Still rename, so the run directory stays tied to this mask and no
            # longer matches "T20*". The FAILED_ prefix keeps it from being
            # reported complete, so a rerun retries this mask.
            # NOTE(review): leaving it as T20* might also let run.py --reuse
            # pick up its partial predictions for the next mask -- confirm.
            clean_and_rename "FAILED_${bit_pattern}"
        fi

        # Brief pause between evaluations
        echo "[INFO] Waiting 5 seconds before next combination..."
        sleep 5
    done
    # The all-masked value is the last one in range; it is never evaluated
    # because at least one encoder must stay active.
    echo "[INFO] Skipping all-masked configuration"

    echo
    echo "[INFO] ======================================"
    echo "[INFO] All masking evaluations completed!"
    echo "[INFO] Total combinations: ${total_combinations}"
    echo "[INFO] Already existed: ${skipped}"
    echo "[INFO] Newly evaluated: ${evaluated}"
    echo "[INFO] Failed: ${failed}"
    echo "[INFO] Results saved in: ${MODEL_OUT_DIR}"
    echo "[INFO] ======================================"

    # List all result directories (FAILED_ runs included).
    echo "[INFO] All result directories:"
    find "${MODEL_OUT_DIR}" -maxdepth 1 -type d -name "*_T20*" | sort || true

    echo
    echo "[INFO] Summary by mask pattern:"
    for ((mask = 0; mask < all_masked; mask++)); do
        pattern=$(generate_bit_pattern "${mask}")
        status="❌ Missing"
        if check_existing_evaluation "${pattern}" >/dev/null 2>&1; then
            status="✅ Complete"
        fi
        echo "[INFO] ${pattern}: ${status}"
    done
}

# Run main function only when this file is executed (e.g. by sbatch or bash),
# not when it is sourced, so the helper functions can be loaded into a shell
# for inspection or testing without starting the sweep.
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
    main "$@"
fi