#!/usr/bin/env python3
"""
Score calculation script for LVBench evaluation.
Computes accuracy metrics from model predictions.
"""

import argparse
import glob
import json
import os
import re
from collections import defaultdict

import numpy as np


def load_predictions(jsonl_file):
    """Load predictions from a JSONL file.

    Args:
        jsonl_file: Path to a file containing one JSON value per line.

    Returns:
        A list of the decoded values, in file order. Blank or
        whitespace-only lines (e.g. a trailing empty line) are skipped
        instead of being fed to the JSON decoder.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the file path and 1-based line number.
    """
    predictions = []
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            text = line.strip()
            if not text:
                continue
            try:
                predictions.append(json.loads(text))
            except json.JSONDecodeError as err:
                # Keep the same exception type so callers' handlers still work,
                # but say which file/line is broken.
                raise json.JSONDecodeError(
                    f"{jsonl_file}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return predictions


# A choice letter at the very start of the (upper-cased) answer, optionally
# preceded by "(" or "[", and followed by end-of-text, whitespace, or a
# closing bracket/punctuation. Matches "A", "A.", "(B)", "C) text", "[D]",
# "B: text" -- but not words that merely start with A-D ("APPLE", "BECAUSE").
_LEADING_CHOICE_RE = re.compile(r"^[(\[]?([ABCD])(?=$|[\s)\].:,;])")

# An explicit answer phrase anywhere in the text, e.g. "ANSWER: C",
# "THE ANSWER IS (B).". The trailing \b rejects "ANSWER: CAT".
_ANSWER_PHRASE_RE = re.compile(r"\bANSWER(?:\s+IS)?\s*[:\-]?\s*[(\[]?([ABCD])\b")


def normalize_answer(answer):
    """Normalize an answer for comparison.

    The value is converted to ``str``, stripped and upper-cased. If it
    contains a recognizable multiple-choice option (A-D), only that letter
    is returned. A letter is recognized either as a standalone token at the
    start ("B", "(B)", "B. text") or in an explicit "ANSWER: X" /
    "ANSWER IS X" phrase.

    Args:
        answer: Raw answer value; any falsy value (``None``, ``""``, ...)
            normalizes to ``""``.

    Returns:
        The option letter, or the cleaned upper-case text when no option
        letter is found.
    """
    if not answer:
        return ""

    text = str(answer).strip().upper()

    # Only a standalone leading letter counts; checking text[0] alone would
    # turn "ANSWER: C" into "A" and "BECAUSE ..." into "B".
    match = _LEADING_CHOICE_RE.match(text) or _ANSWER_PHRASE_RE.search(text)
    if match:
        return match.group(1)

    return text


def _summarize(stats):
    """Build an accuracy/correct/total record from a ``{'correct', 'total'}`` count dict."""
    total = stats['total']
    return {
        'accuracy': stats['correct'] / total if total > 0 else 0,
        'correct': stats['correct'],
        'total': total,
    }


def calculate_accuracy(predictions):
    """Calculate overall and per-category accuracy.

    Each prediction is a dict. ``'model_answer'`` and ``'answer'`` are
    compared after ``normalize_answer``. The breakdowns are keyed by:

    * ``'video_type'``; the key defaults to ``'unknown'`` when absent.
    * ``'question_type'``; this may be a list (each distinct entry is
      credited once) or a single truthy scalar, which is keyed by ``str()``.

    Args:
        predictions: Sequence of prediction dicts.

    Returns:
        A dict with keys ``'overall'``, ``'by_video_type'`` and
        ``'by_question_type'``. Every record inside it is a dict of
        ``accuracy`` / ``correct`` / ``total``. Accuracy is ``0`` for an
        empty bucket.
    """
    by_video_type = defaultdict(lambda: {'correct': 0, 'total': 0})
    by_question_type = defaultdict(lambda: {'correct': 0, 'total': 0})
    correct = 0

    for pred in predictions:
        model_answer = normalize_answer(pred.get('model_answer', ''))
        correct_answer = normalize_answer(pred.get('answer', ''))

        # An empty/missing gold answer can never be matched; otherwise a
        # record lacking both fields would be scored correct ("" == "").
        hit = int(bool(correct_answer) and model_answer == correct_answer)
        correct += hit

        # Track by video type
        video_stats = by_video_type[pred.get('video_type', 'unknown')]
        video_stats['total'] += 1
        video_stats['correct'] += hit

        # Track by question type
        question_types = pred.get('question_type', [])
        if isinstance(question_types, list):
            # De-duplicate (order-preserving) so a type listed twice on one
            # question is not double-counted.
            q_keys = list(dict.fromkeys(question_types))
        elif question_types:
            q_keys = [str(question_types)]
        else:
            q_keys = []
        for q_type in q_keys:
            by_question_type[q_type]['total'] += 1
            by_question_type[q_type]['correct'] += hit

    return {
        'overall': _summarize({'correct': correct, 'total': len(predictions)}),
        'by_video_type': {k: _summarize(v) for k, v in by_video_type.items()},
        'by_question_type': {k: _summarize(v) for k, v in by_question_type.items()},
    }


def main():
    """Command-line entry point: score all matching prediction files.

    Reads every file matching ``--pattern`` under ``--output_path``,
    computes accuracy with ``calculate_accuracy``, writes the scores plus
    metadata as JSON to ``--score_path``, and prints a summary.
    """
    parser = argparse.ArgumentParser(description="Calculate LVBench evaluation scores")
    parser.add_argument('--output_path', required=True, help="Directory containing prediction files")
    parser.add_argument('--score_path', required=True, help="Path to save score JSON")
    parser.add_argument('--pattern', default="*.jsonl", help="Pattern to match prediction files")

    args = parser.parse_args()

    # glob returns entries in arbitrary filesystem order; sort so loading
    # order and the recorded file list are reproducible across runs.
    prediction_files = sorted(glob.glob(os.path.join(args.output_path, args.pattern)))

    if not prediction_files:
        print(f"No prediction files found in {args.output_path} with pattern {args.pattern}")
        return

    print(f"Found {len(prediction_files)} prediction files")

    # Load all predictions
    all_predictions = []
    for file_path in prediction_files:
        print(f"Loading predictions from {file_path}")
        all_predictions.extend(load_predictions(file_path))

    print(f"Total predictions loaded: {len(all_predictions)}")

    # Calculate scores
    scores = calculate_accuracy(all_predictions)

    # Add metadata
    result = {
        'scores': scores,
        'metadata': {
            'total_predictions': len(all_predictions),
            'prediction_files': prediction_files,
            'evaluation_type': 'LVBench'
        }
    }

    # Save scores. A bare filename has dirname '' and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    score_dir = os.path.dirname(args.score_path)
    if score_dir:
        os.makedirs(score_dir, exist_ok=True)
    with open(args.score_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)

    overall = scores['overall']
    print("\n=== LVBench Evaluation Results ===")
    print(f"Overall Accuracy: {overall['accuracy']:.4f} ({overall['correct']}/{overall['total']})")

    print("\n--- By Video Type ---")
    for v_type, stats in scores['by_video_type'].items():
        print(f"{v_type}: {stats['accuracy']:.4f} ({stats['correct']}/{stats['total']})")

    print("\n--- By Question Type ---")
    for q_type, stats in scores['by_question_type'].items():
        print(f"{q_type}: {stats['accuracy']:.4f} ({stats['correct']}/{stats['total']})")

    print(f"\nDetailed scores saved to: {args.score_path}")


if __name__ == "__main__":
    main()