#!/bin/bash
# MultiEditRewardBench Evaluation Script (Full Parameter Checkpoint)
# Loads all pair files at once for efficient vllm inference

# Resolve the directory containing this script and run from there, so the
# relative paths used below (data, results/, logs/, torch_offline.py,
# calculate_accuracy.py) resolve against it. Abort if either step fails
# instead of silently continuing from the caller's working directory.
SHELL_FOLDER=$(cd "$(dirname -- "$0")" && pwd) || {
    echo "Cannot resolve script directory from: $0" >&2
    exit 1
}
cd "$SHELL_FOLDER" || exit 1


# Positional arguments:
#   $1 - checkpoint path (required); used as the --model value
#   $2 - sampling temperature (optional, defaults to 0.7)
CHECKPOINT_PATH="$1"
TEMPERATURE="${2:-0.7}"

# Without a checkpoint there is nothing to evaluate: print usage and stop.
[[ -n "$CHECKPOINT_PATH" ]] || {
    printf '%s\n' \
        "Usage: bash run.sh <checkpoint_path> [temperature]" \
        "Example: bash run.sh /path/to/checkpoint 0.7"
    exit 1
}

# Score aggregation method, forwarded as --score_aggregation.
SCORE_AGGREGATION="weighted_power"
# Space-separated values. They are only forwarded (one argument per value,
# via --weighted_power_params) when SCORE_AGGREGATION is "weighted_power".
WEIGHTED_POWER_PARAMS="0.6 0.4 0.5 0.5 0.8"

# Format temperature for output naming by dropping the decimal point,
# e.g. 0.7 -> "t07".
temp_str="${TEMPERATURE/./}"
temp_suffix="t${temp_str}"

# Checkpoint name (last path component) used in the output and log names.
# Quoted so a path containing spaces or glob characters is treated literally.
CKPT_NAME=$(basename -- "${CHECKPOINT_PATH}")
OUTPUT_DIR="results/multiedit-${CKPT_NAME}_${temp_suffix}"

# Inference configuration (forwarded to torch_offline.py).
# NOTE(review): TENSOR_PARALLEL_SIZE presumably must match the number of
# GPUs listed in CUDA_VISIBLE_DEVICES for the inference command — confirm.
TENSOR_PARALLEL_SIZE=8
MAX_NUM_SEQS=256
MAX_MODEL_LEN=6144
MAX_NUM_BATCHED_TOKENS=32768
GPU_MEMORY_UTILIZATION=0.85
BATCH_SIZE=512
SCORE_RANGE=25
# Either a single flag or empty; when empty, no flag is passed.
INTERLEAVED="--interleaved"

# Pixel bounds forwarded as --max_pixels / --min_pixels
# (768*768 and 56*56 respectively).
MAX_PIXELS=589824
MIN_PIXELS=3136

# Create the results and log directories up front. Abort now, rather than
# after a long inference run, if the location cannot be created.
if ! mkdir -p -- "${OUTPUT_DIR}" logs; then
    echo "❌ Failed to create ${OUTPUT_DIR} or logs/" >&2
    exit 1
fi

# Print a summary of the run configuration.
echo "=============================================="
echo "MultiEditRewardBench Evaluation (Full Checkpoint)"
echo "Checkpoint: ${CHECKPOINT_PATH}"
echo "Output: ${OUTPUT_DIR}"
echo "Temperature: ${TEMPERATURE}"
echo "Score Aggregation: ${SCORE_AGGREGATION}"
if [[ -n "$INTERLEAVED" ]]; then
    echo "Interleaved reasoning enabled"
fi
echo "=============================================="

# Run inference on ALL pair types at once (pass data directory).
echo ""
echo "[Running inference on all pair types...]"

# Build the command as an array instead of a string passed to eval, so a
# checkpoint/output path containing spaces or shell metacharacters reaches
# python as a single literal argument.
INFER_ARGS=(
    --data_path data
    --output_path "${OUTPUT_DIR}/all_results.json"
    --model "${CHECKPOINT_PATH}"
    --tensor_parallel_size "${TENSOR_PARALLEL_SIZE}"
    --max_num_seqs "${MAX_NUM_SEQS}"
    --max_model_len "${MAX_MODEL_LEN}"
    --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}"
    --gpu_memory_utilization "${GPU_MEMORY_UTILIZATION}"
    --batch_size "${BATCH_SIZE}"
    --score_range "${SCORE_RANGE}"
    --score_aggregation "${SCORE_AGGREGATION}"
    --dataset_type multiedit
    --temperature "${TEMPERATURE}"
    --top_p 0.9
    --top_k 20
    --max_tokens 4096
    --num_workers 16
    --max_pixels "${MAX_PIXELS}"
    --min_pixels "${MIN_PIXELS}"
    --with_region
)

# INTERLEAVED is either a single flag or empty; only pass it when set.
if [[ -n "${INTERLEAVED}" ]]; then
    INFER_ARGS+=("${INTERLEAVED}")
fi

if [[ "${SCORE_AGGREGATION}" == "weighted_power" ]]; then
    # WEIGHTED_POWER_PARAMS is space-separated; pass one argument per value.
    read -r -a WP_PARAMS <<< "${WEIGHTED_POWER_PARAMS}"
    INFER_ARGS+=(--weighted_power_params "${WP_PARAMS[@]}")
fi

# After a pipeline, $? is tee's status, which is almost always 0. Read
# PIPESTATUS[0] immediately to get python's exit status instead.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python torch_offline.py "${INFER_ARGS[@]}" 2>&1 \
    | tee -a "logs/${CKPT_NAME}_${temp_suffix}.log"
infer_status=${PIPESTATUS[0]}

if (( infer_status != 0 )); then
    echo "❌ Inference failed!"
    exit 1
fi

echo ""
echo "=============================================="
echo "Calculating Accuracy..."
echo "=============================================="

# As with inference, $? after "| tee" would be tee's status. Read
# PIPESTATUS[0] right after the pipeline so a failing
# calculate_accuracy.py is actually reported.
python calculate_accuracy.py \
    --result_file "${OUTPUT_DIR}/all_results.json" \
    --output "${OUTPUT_DIR}/accuracy_report.json" 2>&1 \
    | tee -a "logs/${CKPT_NAME}_${temp_suffix}_accuracy.log"
acc_status=${PIPESTATUS[0]}

if (( acc_status != 0 )); then
    echo "❌ Accuracy calculation failed!"
    exit 1
fi

# Final report: where the results live and which files the run produced.
printf '%s\n' \
    "" \
    "==============================================" \
    "✅ Pipeline completed successfully!" \
    "==============================================" \
    "📁 Results saved to: ${OUTPUT_DIR}/" \
    "📄 Files created:" \
    "   - all_results.json (all pair types)" \
    "   - accuracy_report.json" \
    "=============================================="
