#!/usr/bin/env python3
"""
Score calculation script for LongVideoBench evaluation.
Computes accuracy metrics from model predictions.
"""
import argparse
import json
import os
import glob
from collections import defaultdict
import numpy as np
def load_predictions(jsonl_file):
    """Load predictions from a JSONL file.

    Args:
        jsonl_file: Path to a file holding one JSON document per line.

    Returns:
        A list with the decoded value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    predictions = []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank or whitespace-only lines carry no record; passing them to
            # json.loads would raise and abort the whole load.
            if not line:
                continue
            predictions.append(json.loads(line))
    return predictions
def normalize_answer(answer):
    """Normalize an answer for comparison, handling numeric and letter formats.

    Args:
        answer: The raw answer; typically a choice index (int or digit
            string) or text containing a choice letter.

    Returns:
        "" when the answer is missing (None, empty string, or another falsy
        non-integer value). For an all-digit value n, the character
        chr(65 + n), i.e. 0 -> "A", 1 -> "B", and so on. Otherwise the first
        character among A-H found in the upper-cased, stripped text, or that
        text itself when it contains none of those letters.
    """
    # The integer 0 is falsy but is a real choice index (the first option),
    # so it must not be mistaken for a missing answer. bool is excluded
    # because it is an int subclass and False is not a choice index.
    is_int_index = isinstance(answer, int) and not isinstance(answer, bool)
    if not answer and not is_int_index:
        return ""

    # Convert to string and clean
    answer = str(answer).strip().upper()

    # Handle numeric answers (convert to letter: 0->A, 1->B, 2->C, ...)
    if answer.isdigit():
        try:
            return chr(65 + int(answer))  # 65 is the code point of 'A'
        except (ValueError, OverflowError):
            # int() rejects some characters str.isdigit() accepts (e.g.
            # superscript digits), and chr() rejects values outside the
            # Unicode range; fall back to the cleaned text in both cases.
            return answer

    # Return the first choice letter anywhere in the text; this covers both a
    # leading letter ("B", "B. foo") and wrapped forms such as "(B)".
    choice_letters = "ABCDEFGH"
    for char in answer:
        if char in choice_letters:
            return char
    return answer
def _duration_bucket(duration):
    """Return the reporting bucket label for a duration value.

    Numeric values are compared against the 60 / 300 / 600 thresholds. Values
    that are not numbers and cannot be converted with float() (e.g. None)
    are reported as "unknown" instead of raising.
    """
    if not isinstance(duration, (int, float)):
        try:
            duration = float(duration)
        except (TypeError, ValueError):
            return "unknown"
    if duration < 60:
        return "0-60s"
    if duration < 300:
        return "60-300s"
    if duration < 600:
        return "300-600s"
    return "600s+"


def _summarize(groups):
    """Turn {key: {'correct', 'total'}} tallies into accuracy records."""
    return {
        key: {
            'accuracy': stats['correct'] / stats['total'] if stats['total'] > 0 else 0,
            'correct': stats['correct'],
            'total': stats['total'],
        }
        for key, stats in groups.items()
    }


def calculate_accuracy(predictions):
    """Calculate overall and per-group accuracy metrics.

    Args:
        predictions: A sequence of dicts. The keys read from each record are
            'model_answer', 'correct_choice', 'question_category',
            'topic_category', 'level' and 'duration'; missing grouping keys
            fall back to 'unknown' and a missing duration to 0.

    Returns:
        A dict with the keys 'overall', 'by_question_category',
        'by_topic_category', 'by_level' and 'by_duration'. 'overall' is one
        record and every other entry maps a group name to a record; each
        record holds 'accuracy', 'correct' and 'total'. Accuracy is 0 when
        there is nothing to score.
    """
    def new_tally():
        return defaultdict(lambda: {'correct': 0, 'total': 0})

    # One tally per record field that is grouped on verbatim.
    by_field = {
        'question_category': new_tally(),
        'topic_category': new_tally(),
        'level': new_tally(),
    }
    # Durations are grouped by range rather than by exact value.
    by_duration = new_tally()

    correct = 0
    for pred in predictions:
        model_answer = normalize_answer(pred.get('model_answer', ''))
        correct_answer = normalize_answer(pred.get('correct_choice', ''))
        # NOTE(review): two answers that both normalize to "" compare equal
        # and are therefore counted as correct.
        hit = 1 if model_answer == correct_answer else 0
        correct += hit

        buckets = [
            tally[pred.get(field, 'unknown')]
            for field, tally in by_field.items()
        ]
        buckets.append(by_duration[_duration_bucket(pred.get('duration', 0))])
        for bucket in buckets:
            bucket['total'] += 1
            bucket['correct'] += hit

    total = len(predictions)
    return {
        'overall': {
            'accuracy': correct / total if total > 0 else 0,
            'correct': correct,
            'total': total,
        },
        'by_question_category': _summarize(by_field['question_category']),
        'by_topic_category': _summarize(by_field['topic_category']),
        'by_level': _summarize(by_field['level']),
        'by_duration': _summarize(by_duration),
    }
def main():
    """Command-line entry point: score prediction files and save the result.

    Reads every file in --output_path matching --pattern, computes the
    accuracy metrics over all loaded predictions, writes them together with
    run metadata as JSON to --score_path, and prints a summary. Prints a
    message and returns without writing anything when no file matches.
    """
    parser = argparse.ArgumentParser(description="Calculate LongVideoBench evaluation scores")
    parser.add_argument('--output_path', required=True, help="Directory containing prediction files")
    parser.add_argument('--score_path', required=True, help="Path to save score JSON")
    parser.add_argument('--pattern', default="*.jsonl", help="Pattern to match prediction files")
    args = parser.parse_args()

    # Find all prediction files. glob returns matches in arbitrary order, so
    # sort them to keep the recorded file list and the order in which groups
    # first appear reproducible between runs.
    prediction_files = sorted(glob.glob(os.path.join(args.output_path, args.pattern)))
    if not prediction_files:
        print(f"No prediction files found in {args.output_path} with pattern {args.pattern}")
        return
    print(f"Found {len(prediction_files)} prediction files")

    # Load all predictions
    all_predictions = []
    for file_path in prediction_files:
        print(f"Loading predictions from {file_path}")
        all_predictions.extend(load_predictions(file_path))
    print(f"Total predictions loaded: {len(all_predictions)}")

    # Calculate scores
    scores = calculate_accuracy(all_predictions)

    # Add metadata
    result = {
        'scores': scores,
        'metadata': {
            'total_predictions': len(all_predictions),
            'prediction_files': prediction_files,
            'evaluation_type': 'LongVideoBench'
        }
    }

    # Save scores. A bare filename has an empty dirname, and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    score_dir = os.path.dirname(args.score_path)
    if score_dir:
        os.makedirs(score_dir, exist_ok=True)
    with open(args.score_path, 'w') as f:
        json.dump(result, f, indent=2)

    print("\n=== LongVideoBench Evaluation Results ===")
    overall = scores['overall']
    print(f"Overall Accuracy: {overall['accuracy']:.4f} ({overall['correct']}/{overall['total']})")

    # Every breakdown has the same record layout, so print them uniformly.
    sections = (
        ("By Question Category", 'by_question_category'),
        ("By Topic Category", 'by_topic_category'),
        ("By Level", 'by_level'),
        ("By Duration Range", 'by_duration'),
    )
    for title, key in sections:
        print(f"\n--- {title} ---")
        for name, stats in scores[key].items():
            print(f"{name}: {stats['accuracy']:.4f} ({stats['correct']}/{stats['total']})")

    print(f"\nDetailed scores saved to: {args.score_path}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()