#!/bin/bash
# MMRB2 Image Editing Benchmark Evaluation Script
#
# Usage:
#   bash run_lora.sh <checkpoint_path> [temperature]
# Example:
#   bash run_lora.sh /path/to/checkpoint 0.7

# Resolve the directory that contains this script and make it the working
# directory, so the relative paths used further down (./data, results/, logs/
# and the Python entry points) do not depend on where the script was invoked.
# - CDPATH is cleared for the lookup because a CDPATH hit makes `cd` print the
#   directory, which would end up in the captured value.
# - `&&` (instead of `;`) keeps `pwd` from reporting the caller's directory
#   when the `cd` itself failed.
SHELL_FOLDER=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) || {
    echo "Cannot resolve script directory from: $0" >&2
    exit 1
}
# Quoted so a script path containing spaces still works; abort rather than run
# the rest of the pipeline from an unexpected directory.
cd -- "${SHELL_FOLDER}" || {
    echo "Cannot cd to script directory: ${SHELL_FOLDER}" >&2
    exit 1
}

# Uncomment and modify to activate your conda environment
# source "$(dirname $(which conda))/../etc/profile.d/conda.sh"
# conda activate your_environment

# Positional arguments:
#   $1 - checkpoint path (required)
#   $2 - temperature (optional; 0.7 is used when it is unset or empty)
CHECKPOINT_PATH="$1"
TEMPERATURE="${2:-0.7}"

# Without a checkpoint path there is nothing to evaluate: print the usage
# hint and stop with a non-zero status.
if [[ -z "${CHECKPOINT_PATH}" ]]; then
    printf '%s\n' \
        "Usage: bash run_lora.sh <checkpoint_path> [temperature]" \
        "Example: bash run_lora.sh /path/to/checkpoint 0.7"
    exit 1
fi

# How the SC dimensions are combined into a single score. Supported values:
#   min            - minimum of the SC dimensions
#   mean           - average of the SC dimensions
#   weighted_power - weighted power formula (see WEIGHTED_POWER_PARAMS)
SCORE_AGGREGATION='min'

# Parameters for the weighted power formula; they are only passed on when
# SCORE_AGGREGATION is "weighted_power".
#   score = ((w1*s1+w2*s2)**a) * ((w3*s3+w4*s4)**(1-a))
# NOTE(review): the five values are presumably "w1 w2 w3 w4 a" in that order;
# confirm against torch_offline.py.
WEIGHTED_POWER_PARAMS='0.5 0.5 0.5 0.5 0.5'

# Tag used in output/log names: "t" followed by the temperature with its first
# "." removed (e.g. 0.7 -> t07).
temp_str=${TEMPERATURE/./}
temp_suffix="t${temp_str}"

# Configuration - modify these paths for your environment.
# Relative paths are resolved against the script directory, because the script
# changes into it before anything here is used.
BENCHMARK_FILE="./data/edit.json"       # Path to benchmark file
MODEL_PATH="Qwen/Qwen3-VL-8B-Instruct"  # Base model
LORA_PATH="${CHECKPOINT_PATH}"          # LoRA checkpoint path

# Name the output directory after the last path component of the checkpoint
# plus the temperature tag. The argument is quoted (and preceded by `--`)
# because an unquoted path containing a space would be split in two and
# `basename` would treat the second word as a suffix to strip.
CKPT_NAME=$(basename -- "${CHECKPOINT_PATH}")
OUTPUT_DIR="results/mmrb2-${CKPT_NAME}_${temp_suffix}"

# Inference settings. Each value below is passed to torch_offline.py through
# the command-line option with the same name in lower case.
TENSOR_PARALLEL_SIZE='8'
MAX_NUM_SEQS='256'
MAX_MODEL_LEN='12240'
MAX_NUM_BATCHED_TOKENS='65536'
GPU_MEMORY_UTILIZATION='0.6'
BATCH_SIZE='512'
SCORE_RANGE='25'

# Optional flag appended verbatim to the inference command; an empty value
# adds nothing.
INTERLEAVED='--interleaved'

# When "true", --add_timestamp is passed to inference and the results file is
# later looked up as results_*.json instead of results.json.
ADD_TIMESTAMP='true'

# When "true", --single_image_only is passed to inference; otherwise all tasks
# (single-image + multi-image fusion) are evaluated.
SINGLE_IMAGE_ONLY='false'

# Create the output and log directories up front. OUTPUT_DIR is quoted because
# it embeds the checkpoint name, which may contain spaces. Abort if either
# directory cannot be created, so inference is not started without a place to
# write its results and log.
if ! mkdir -p -- "${OUTPUT_DIR}" logs; then
    echo "Failed to create output directory '${OUTPUT_DIR}' or logs directory" >&2
    exit 1
fi

# Build the inference command as an array (one element per argument) instead
# of a string that is later passed to `eval`. With an array, paths containing
# spaces or shell metacharacters reach torch_offline.py unchanged and are never
# re-interpreted by the shell.
CMD=(
    python torch_offline.py
    --data_path "${BENCHMARK_FILE}"
    --output_path "${OUTPUT_DIR}/results.json"
    --model "${MODEL_PATH}"
)

# Add LoRA path only if it's not empty
if [[ -n "${LORA_PATH}" ]]; then
    CMD+=(--lora_path "${LORA_PATH}")
    echo "Using LoRA: ${LORA_PATH}"
else
    echo "Using base model without LoRA"
fi

CMD+=(
    --tensor_parallel_size "${TENSOR_PARALLEL_SIZE}"
    --max_num_seqs "${MAX_NUM_SEQS}"
    --max_model_len "${MAX_MODEL_LEN}"
    --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}"
    --gpu_memory_utilization "${GPU_MEMORY_UTILIZATION}"
    --batch_size "${BATCH_SIZE}"
    --score_range "${SCORE_RANGE}"
    --score_aggregation "${SCORE_AGGREGATION}"
    --temperature "${TEMPERATURE}"
    --top_p 0.9
    --top_k 20
    --max_tokens 4096
    --num_workers 16
    --with_region
)

# INTERLEAVED is an optional flag; skip it when empty so no empty argument is
# passed to the Python script.
if [[ -n "${INTERLEAVED}" ]]; then
    CMD+=("${INTERLEAVED}")
fi

# Add weighted_power_params if using weighted_power aggregation.
# The parameters are stored as one space-separated string, so split them into
# separate arguments here.
if [[ "${SCORE_AGGREGATION}" == "weighted_power" ]]; then
    read -r -a weighted_power_values <<< "${WEIGHTED_POWER_PARAMS}"
    CMD+=(--weighted_power_params "${weighted_power_values[@]}")
fi

# Add timestamp flag if enabled
if [[ "${ADD_TIMESTAMP}" == true ]]; then
    CMD+=(--add_timestamp)
fi

# Add single_image_only flag if enabled
if [[ "${SINGLE_IMAGE_ONLY}" == true ]]; then
    CMD+=(--single_image_only)
    echo "Only evaluating single-image editing tasks"
else
    echo "Evaluating all tasks (single-image + multi-image fusion)"
fi

echo "Starting MMRB2 inference..."
echo "Model: ${MODEL_PATH}"
echo "LoRA: ${LORA_PATH}"
echo "Benchmark: ${BENCHMARK_FILE}"
echo "Output: ${OUTPUT_DIR}"
echo "Temperature: ${TEMPERATURE}"
echo "=========================================="

# Run inference on GPUs 0-7, showing the output and appending it to the log.
inference_log="logs/mmrb2_${CKPT_NAME}_${temp_suffix}.log"
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 "${CMD[@]}" 2>&1 | tee -a "${inference_log}"
# After a pipeline `$?` is the status of the LAST stage (tee), which is almost
# always 0. PIPESTATUS[0] is the status of the Python process itself and must
# be read immediately, before any other command overwrites it.
inference_status=${PIPESTATUS[0]}

if (( inference_status != 0 )); then
    echo "Inference failed!"
    exit 1
fi

echo ""
echo "Inference completed successfully!"
echo "Calculating accuracy..."

#######################################
# Print the most recently modified results_*.json file in a directory.
# Files ending in _evaluation.json are skipped: the accuracy step writes its
# output next to the results file under that suffix, and it also matches the
# results_*.json pattern.
# Arguments: $1 - directory to search
# Outputs:   path of the newest matching file on stdout (nothing if none)
# Returns:   0 if a file was found, 1 otherwise
#######################################
latest_results_file() {
    local dir=$1
    local candidate
    local newest=""
    # Glob instead of parsing `ls`, so directory or file names containing
    # whitespace are handled correctly.
    for candidate in "${dir}"/results_*.json; do
        # An unmatched glob is left as the literal pattern; skip non-files.
        [[ -f "${candidate}" ]] || continue
        [[ "${candidate}" == *_evaluation.json ]] && continue
        if [[ -z "${newest}" || "${candidate}" -nt "${newest}" ]]; then
            newest=${candidate}
        fi
    done
    [[ -n "${newest}" ]] || return 1
    printf '%s\n' "${newest}"
}

# Find the results file: with timestamps enabled the exact name is not known
# here, so take the newest results_*.json in the output directory.
if [[ "${ADD_TIMESTAMP}" == true ]]; then
    RESULT_FILE=$(latest_results_file "${OUTPUT_DIR}")
    if [[ -z "${RESULT_FILE}" ]]; then
        echo "No timestamped results file found!"
        exit 1
    fi
else
    RESULT_FILE="${OUTPUT_DIR}/results.json"
fi

# Score the inference results against the benchmark. The evaluation is written
# next to the results file as <results>_evaluation.json, and the output is
# shown on the terminal and appended to the run log.
python calculate_accuracy.py \
    --result_file "${RESULT_FILE}" \
    --benchmark_file "${BENCHMARK_FILE}" \
    --output_file "${RESULT_FILE%.json}_evaluation.json" 2>&1 \
    | tee -a "logs/mmrb2_${CKPT_NAME}_${temp_suffix}.log"
# `$?` would report tee's status here; PIPESTATUS[0] is the Python script's
# own exit status and must be read right after the pipeline.
accuracy_status=${PIPESTATUS[0]}

if (( accuracy_status != 0 )); then
    echo "Accuracy calculation failed!"
    exit 1
fi

# Final summary: where the results were written and which files to expect.
summary_rule="=========================================="
printf '%s\n' \
    "" \
    "${summary_rule}" \
    "Pipeline completed successfully!" \
    "${summary_rule}" \
    "Results saved to: ${OUTPUT_DIR}/" \
    "Files created:" \
    "   - results*.json: Raw inference results" \
    "   - results*_evaluation.json: Accuracy metrics" \
    "${summary_rule}"
