#!/bin/bash
# Run dataset inference on one or more GPUs.
#
# Usage: ./dataset_inf.sh [input_jsonl] [output_jsonl] [model_path]
#   input_jsonl   defaults to DEFAULT_INPUT_PATH for the selected benchmark
#   output_jsonl  defaults to DEFAULT_OUTPUT_PATH for the selected benchmark
#   model_path    defaults to DEFAULT_MODEL_PATH
#
# Environment:
#   CUDA_VISIBLE_DEVICES        (required) comma-separated GPU ids; a single id
#                               runs one python process, several ids launch
#                               torchrun with one process per GPU
#   MAX_NEW_TOKENS              (optional, default 1024)
#   TEXT_TEMPERATURE            (optional, default 0.0)
#   CUDA_DEVICE_MAX_CONNECTIONS (optional, default 16)

# pipefail: the inference commands are piped into `tee`; without it a failed
# run would report tee's (successful) exit status and the script would exit 0.
set -eo pipefail

export CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-16}

DEFAULT_MODEL_PATH=YOUR_MODEL_PATH
Bench=MMAU

DEFAULT_MAX_TOKENS=1024
DEFAULT_TEMPERATURE=0.0

# Default output files are named after the last path component of the model.
model_name="${DEFAULT_MODEL_PATH##*/}"
case "${Bench}" in
    MMAU)
        DEFAULT_INPUT_PATH="YOUR_INPUT_PATH"
        DEFAULT_OUTPUT_PATH="YOUR_OUTPUT_PATH/${model_name}.jsonl"
        ;;
    *)
        DEFAULT_INPUT_PATH="YOUR_DEFAULT_PATH"
        # No default output location for other benchmarks; pass one explicitly.
        DEFAULT_OUTPUT_PATH=""
        ;;
esac

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# The Python entry point lives one directory above this script's directory.
PYTHON_SCRIPT="$(dirname "$SCRIPT_DIR")/dataset_inf.py"

# Print the number of GPU ids listed in CUDA_VISIBLE_DEVICES.
# Whitespace is ignored and empty entries (e.g. from a trailing or doubled
# comma such as "0,1,") are not counted. Prints 0 when the variable is unset,
# empty, or contains no ids at all.
get_visible_gpu_count() {
    local devices="${CUDA_VISIBLE_DEVICES//[[:space:]]/}"
    local -a ids=()
    local id
    local count=0

    if [ -z "$devices" ]; then
        echo "0"
        return
    fi

    IFS=',' read -r -a ids <<< "$devices"
    for id in "${ids[@]}"; do
        if [ -n "$id" ]; then
            count=$((count + 1))
        fi
    done
    echo "$count"
}

# Positional arguments override the per-benchmark defaults.
INPUT_FILE="${1:-$DEFAULT_INPUT_PATH}"
OUTPUT_FILE="${2:-$DEFAULT_OUTPUT_PATH}"
MODEL_PATH="${3:-$DEFAULT_MODEL_PATH}"

# Generation settings can be overridden through the environment.
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-$DEFAULT_MAX_TOKENS}"
TEXT_TEMPERATURE="${TEXT_TEMPERATURE:-$DEFAULT_TEMPERATURE}"

# Diagnostics go to stderr so they are not mixed into captured stdout.
if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
    echo "Error: please set CUDA_VISIBLE_DEVICES." >&2
    echo "Example: CUDA_VISIBLE_DEVICES=0,1,2,3 $0 input.jsonl output.jsonl" >&2
    exit 1
fi

# Fail fast on bad paths instead of launching the (possibly multi-process)
# Python job with them.
if [ -z "$INPUT_FILE" ] || [ ! -f "$INPUT_FILE" ]; then
    echo "Error: input file not found: $INPUT_FILE" >&2
    exit 1
fi
if [ -z "$OUTPUT_FILE" ]; then
    echo "Error: no output file given and no default for benchmark ${Bench}." >&2
    exit 1
fi

# One visible GPU -> plain python process; several -> one torchrun worker per GPU.
GPU_COUNT=$(get_visible_gpu_count)
if [ "$GPU_COUNT" -eq 1 ]; then
    MODE="single"
    NPROC=1
elif [ "$GPU_COUNT" -gt 1 ]; then
    MODE="multi-gpu"
    NPROC="$GPU_COUNT"
else
    echo "Error: invalid CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" >&2
    exit 1
fi

# Ensure the directory for the predictions exists; the run log is written
# next to them.
OUTPUT_DIR=$(dirname "$OUTPUT_FILE")
if [[ ! -d "$OUTPUT_DIR" ]]; then
    echo "Creating output directory: $OUTPUT_DIR"
    mkdir -p "$OUTPUT_DIR"
fi

# One log per launch, tagged with the run mode and a second-resolution timestamp.
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="${OUTPUT_DIR}/inference_${MODE}_${TIMESTAMP}.log"

# Print the effective configuration before starting.
banner_rule="=================================="
printf '%s\n' \
    "$banner_rule" \
    "Dataset Inference Configuration" \
    "$banner_rule" \
    "Mode:            $MODE" \
    "GPU count:       $GPU_COUNT" \
    "Visible GPUs:    $CUDA_VISIBLE_DEVICES" \
    "Input file:      $INPUT_FILE" \
    "Output file:     $OUTPUT_FILE" \
    "Model path:      $MODEL_PATH" \
    "Max tokens:      $MAX_NEW_TOKENS" \
    "Text temperature:$TEXT_TEMPERATURE"
if [[ "$MODE" == "multi-gpu" ]]; then
    printf '%s\n' "Processes:       $NPROC"
fi
printf '%s\n' \
    "Log file:        $LOG_FILE" \
    "$banner_rule"

# Best-effort probe of each visible device; a failed query is reported but
# does not abort the run.
echo "Checking GPU availability..."
if ! command -v nvidia-smi &> /dev/null; then
    echo "Warning: nvidia-smi not found; cannot query GPU status."
else
    # Turn the comma-separated list into whitespace-separated words.
    for device in $(tr ',' ' ' <<< "$CUDA_VISIBLE_DEVICES"); do
        nvidia-smi --query-gpu=index,name,memory.total,memory.free \
            --format=csv,noheader,nounits -i "$device" 2>/dev/null \
            || echo "GPU $device: failed to query"
    done
fi

printf '\n%s\n' "Starting dataset inference..."
printf '%s\n' \
    "Logs will be saved to: $LOG_FILE" \
    "Current CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"

# Arguments passed to dataset_inf.py in both launch modes.
launch_args=(
    --input_file "$INPUT_FILE"
    --output_file "$OUTPUT_FILE"
    --model_path "$MODEL_PATH"
    --max_new_tokens "$MAX_NEW_TOKENS"
    --text_temperature "$TEXT_TEMPERATURE"
    --resume
)

# Pick the launcher: a plain python process for one GPU, torchrun with one
# worker per GPU otherwise (results gathered on rank 0).
case "$MODE" in
    "single")
        launcher=(python "$PYTHON_SCRIPT")
        ;;
    "multi-gpu")
        launcher=(torchrun --nproc_per_node="$NPROC" "$PYTHON_SCRIPT")
        launch_args+=(--gather_on_rank0)
        ;;
    *)
        echo "Error: unknown mode: $MODE" >&2
        exit 1
        ;;
esac

# A bare `cmd | tee` reports tee's exit status, so a crashed run would look
# successful. Temporarily disable errexit, record every stage's status from
# PIPESTATUS, and propagate failures explicitly.
set +e
CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES" "${launcher[@]}" "${launch_args[@]}" \
    2>&1 | tee "$LOG_FILE"
pipe_status=("${PIPESTATUS[@]}")
set -e

if [ "${pipe_status[0]}" -ne 0 ]; then
    echo "Error: inference exited with status ${pipe_status[0]}; see $LOG_FILE" >&2
    exit "${pipe_status[0]}"
fi
if [ "${pipe_status[1]}" -ne 0 ]; then
    echo "Error: failed to write log file $LOG_FILE" >&2
    exit "${pipe_status[1]}"
fi
