#!/bin/bash

# VNBench Evaluation Script for RWKV-Qwen Hybrid Model
# Usage: ./eval_vnbench_rwkv_qwen.sh [model_path] [video_dir] [output_dir] [num_chunks]
#   All arguments are optional; their defaults are defined below. When
#   num_chunks is omitted, one chunk is launched per GPU reported by nvidia-smi.
# Environment:
#   EVAL_DIR  directory containing model_vnbench_qa_rwkv_qwen.py and
#             calculate_score.py (default: /path/to/your/eval/eval_vnbench)

# --- CONFIGURATION: Replace these paths with your own setup ---
# Run from the evaluation script directory: the Python scripts are invoked by
# relative path, so abort immediately if this directory is unreachable.
EVAL_DIR=${EVAL_DIR:-"/path/to/your/eval/eval_vnbench"}
cd "$EVAL_DIR" || { echo "Error: cannot cd to evaluation directory: $EVAL_DIR" >&2; exit 1; }

# --- DEFAULT PATHS (override with command-line arguments) ---
MODEL_PATH=${1:-"/path/to/your/model/checkpoint"}
VIDEO_DIR=${2:-"/path/to/your/videos"}
OUTPUT_DIR=${3:-"${MODEL_PATH}/eval_output_vnbench"}

# --- AUTO-DETECT AVAILABLE GPUS ---
# GPULIST holds the GPU indices reported by nvidia-smi, one element per GPU.
# It stays empty when nvidia-smi is missing or fails, instead of capturing the
# tool's error text as if it were GPU ids.
GPULIST=()
if gpu_query=$(nvidia-smi --query-gpu=index --format=csv,noheader 2>/dev/null); then
    read -r -a GPULIST <<< "${gpu_query//$'\n'/ }"
fi
if (( ${#GPULIST[@]} == 0 )); then
    echo "Warning: no GPUs detected via nvidia-smi" >&2
fi

# --- NUMBER OF CHUNKS ---
# An explicitly supplied 4th argument is honored as-is (including 1). When it
# is omitted, use one chunk per detected GPU, falling back to 1 if none were
# found.
if [[ -n "${4:-}" ]]; then
    NUM_CHUNKS=$4
    if [[ ! "$NUM_CHUNKS" =~ ^[1-9][0-9]*$ ]]; then
        echo "Error: num_chunks must be a positive integer, got '$NUM_CHUNKS'" >&2
        exit 1
    fi
elif (( ${#GPULIST[@]} > 0 )); then
    NUM_CHUNKS=${#GPULIST[@]}
else
    NUM_CHUNKS=1
fi

echo "=== VNBench Evaluation with RWKV-Qwen Hybrid Model ==="
echo "Model Path: $MODEL_PATH"
echo "Video Directory: $VIDEO_DIR"
echo "Output Directory: $OUTPUT_DIR"
echo "Number of Chunks: $NUM_CHUNKS"
# [*] joins the GPU ids into a single word for display (avoids SC2145).
echo "Available GPUs: ${GPULIST[*]}"

# Create output directory. Every later step writes into it, so stop here if it
# cannot be created. Quoted so paths containing spaces stay one argument.
mkdir -p -- "$OUTPUT_DIR" || { echo "Error: cannot create output directory: $OUTPUT_DIR" >&2; exit 1; }

# Set ground truth annotation file
GT_FILE="/path/to/your/annotations/VNBench-main-4try.json"

# run_chunk GPU_ID CHUNK_IDX
# Runs model_vnbench_qa_rwkv_qwen.py on chunk CHUNK_IDX (of NUM_CHUNKS) with
# only GPU_ID visible. Reads MODEL_PATH, VIDEO_DIR, GT_FILE, OUTPUT_DIR and
# NUM_CHUNKS from the surrounding script. Returns the Python exit status.
run_chunk() {
    local gpu_id=$1 chunk_idx=$2
    CUDA_VISIBLE_DEVICES="$gpu_id" python model_vnbench_qa_rwkv_qwen.py \
        --model-path "$MODEL_PATH" \
        --video_dir "$VIDEO_DIR" \
        --gt_file "$GT_FILE" \
        --output_dir "$OUTPUT_DIR" \
        --output_name pred \
        --num-chunks "$NUM_CHUNKS" \
        --chunk-idx "$chunk_idx"
}

echo "Evaluating VNBench..."
num_gpus=${#GPULIST[@]}
if (( num_gpus == 0 )); then
    echo "Error: no GPUs available; cannot launch evaluation" >&2
    exit 1
fi
if (( NUM_CHUNKS > num_gpus )); then
    echo "Warning: $NUM_CHUNKS chunks but only $num_gpus GPUs; extra chunks will run sequentially" >&2
fi

# Launch one background worker per GPU. Worker `slot` runs chunks slot,
# slot+num_gpus, slot+2*num_gpus, ... one after another. Every chunk is
# evaluated (none are skipped), and each GPU runs at most one job at a time.
worker_pids=()
for (( slot = 0; slot < num_gpus && slot < NUM_CHUNKS; slot++ )); do
    (
        status=0
        for (( idx = slot; idx < NUM_CHUNKS; idx += num_gpus )); do
            if ! run_chunk "${GPULIST[slot]}" "$idx"; then
                echo "Error: chunk $idx failed on GPU ${GPULIST[slot]}" >&2
                status=1
            fi
        done
        exit "$status"
    ) &
    worker_pids+=("$!")
done

# Wait for every worker and propagate failures. Scoring partial output would
# report a misleading accuracy.
failed=0
for pid in "${worker_pids[@]}"; do
    wait "$pid" || failed=1
done
if (( failed )); then
    echo "Error: one or more evaluation chunks failed; not computing scores" >&2
    exit 1
fi

echo "=== Evaluation Complete ==="
echo "Results saved to: $OUTPUT_DIR"

# report_prediction_files DIR
# Lists the *.jsonl prediction files in DIR and prints the total number of
# records they contain. When there are no such files, prints a placeholder and
# a count of 0.
report_prediction_files() {
    local dir=$1 f
    local -a pred_files=()
    for f in "$dir"/*.jsonl; do
        # An unmatched glob stays literal; keep only paths that exist.
        [[ -e "$f" ]] || continue
        pred_files+=("$f")
    done
    if (( ${#pred_files[@]} == 0 )); then
        echo "Output files: No .jsonl files yet"
        echo "Total questions processed: 0"
        return 0
    fi
    echo "Output files: $(ls -la -- "${pred_files[@]}")"
    # awk also counts a final record that lacks a trailing newline, which
    # `wc -l` would miss.
    echo "Total questions processed: $(awk 'END { print NR }' "${pred_files[@]}")"
}

report_prediction_files "$OUTPUT_DIR"

echo ""
echo "=== Calculating Final Scores ==="

# Generate scores for all chunks. Abort on failure: otherwise a score.json left
# in OUTPUT_DIR by an earlier run could be reported as this run's result.
if ! python calculate_score.py \
    --output_path "$OUTPUT_DIR" \
    --score_path "$OUTPUT_DIR/score.json"; then
    echo "Error: calculate_score.py failed" >&2
    exit 1
fi

echo "=== Final Results ==="
echo "Overall accuracy report saved to: $OUTPUT_DIR/score.json"

# read_overall_accuracy SCORE_FILE
# Prints scores.overall.accuracy from SCORE_FILE with 4 decimals, or an
# "Error reading score: ..." line if it cannot be read. The path is passed as
# a program argument rather than spliced into the Python source, so paths
# containing quotes cannot break (or inject into) the script.
read_overall_accuracy() {
    python3 - "$1" <<'PY'
import json
import sys

try:
    with open(sys.argv[1]) as f:
        data = json.load(f)
    print(f'{data["scores"]["overall"]["accuracy"]:.4f}')
except Exception as e:
    print('Error reading score:', e)
PY
}

# Display overall accuracy if score file exists
if [[ -f "$OUTPUT_DIR/score.json" ]]; then
    ACCURACY=$(read_overall_accuracy "$OUTPUT_DIR/score.json")
    echo "Overall accuracy: $ACCURACY"
fi

# print_accuracy_rows SCORE_FILE SECTION ERROR_LABEL
# Prints one "  name: acc (correct/total)" line per entry of a score table in
# SCORE_FILE (JSON). When SECTION is empty, the table is the top-level
# "scores" object minus its aggregate entries ('overall', 'try_counts',
# 'length_groups'). Otherwise it is scores[SECTION]. On any read or parse
# error, prints ERROR_LABEL followed by the exception. The path is passed as an
# argument, never spliced into the Python source.
print_accuracy_rows() {
    python3 - "$1" "$2" "$3" <<'PY'
import json
import sys

score_file, section, error_label = sys.argv[1:4]
try:
    with open(score_file, 'r') as f:
        data = json.load(f)
    scores = data.get('scores', {})
    if section:
        rows = scores.get(section, {})
    else:
        # Per-type entries only; the aggregate sections are reported separately.
        rows = {key: value for key, value in scores.items()
                if key not in ['overall', 'try_counts', 'length_groups']}
    for key, value in rows.items():
        print(f'  {key}: {value["accuracy"]:.4f} ({value["correct"]}/{value["total"]})')
except Exception as e:
    print(error_label, e)
PY
}

# Display detailed breakdown
echo ""
echo "=== Detailed Breakdown ==="
if [[ -f "$OUTPUT_DIR/score.json" ]]; then
    echo "Type-specific accuracies:"
    print_accuracy_rows "$OUTPUT_DIR/score.json" "" "Error parsing score breakdown:"

    echo ""
    echo "Try count accuracies:"
    print_accuracy_rows "$OUTPUT_DIR/score.json" try_counts "Error parsing try count breakdown:"

    echo ""
    echo "Length-based accuracies:"
    print_accuracy_rows "$OUTPUT_DIR/score.json" length_groups "Error parsing length group breakdown:"
fi

# Closing banner, preceded by a blank separator line.
printf '\n%s\n' "=== VNBench Evaluation Complete ==="