import argparse
import json
import os
import re
from collections import defaultdict

import numpy as np


def load_predictions(output_dir):
    """Load every prediction record from the ``.jsonl`` files in a directory.

    Files are read in sorted name order so the returned list is reproducible
    across runs and platforms. Blank lines (for example a trailing newline at
    the end of a file) are skipped instead of being passed to the JSON parser.

    Args:
        output_dir: Directory containing ``.jsonl`` prediction files. Only
            direct children whose name ends in ``.jsonl`` are read.

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the offending file and line number.
        OSError: If the directory or one of its files cannot be read.
    """
    predictions = []

    # os.listdir() returns entries in arbitrary order; sort for determinism.
    for file_name in sorted(os.listdir(output_dir)):
        if not file_name.endswith('.jsonl'):
            continue

        file_path = os.path.join(output_dir, file_name)
        # JSON text is UTF-8 by convention; do not depend on the locale default.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    predictions.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Re-raise the same exception type with the location added.
                    raise json.JSONDecodeError(
                        f"{err.msg} ({file_path}, line {line_number})",
                        err.doc,
                        err.pos,
                    ) from err

    return predictions


# Matches an option letter that is not part of a longer word,
# e.g. "B", "(B)", "B." or the final letter of "ANSWER: B".
_OPTION_RE = re.compile(r'(?<![A-Z])[A-D](?![A-Z])')


def _extract_option(answer):
    """Reduce a raw model answer to an upper-case option letter where possible."""
    text = '' if answer is None else str(answer).strip().upper()
    if len(text) <= 1:
        return text

    # Prefer the first standalone letter. A plain substring test would turn
    # "THE ANSWER IS B" into 'A' because the word "ANSWER" contains an 'A'.
    match = _OPTION_RE.search(text)
    if match:
        return match.group(0)

    # Fallback for answers with no standalone letter: leading character
    # first, then the first of A, B, C, D found anywhere in the text.
    if text[0] in 'ABCD':
        return text[0]
    for letter in 'ABCD':
        if letter in text:
            return letter
    return text


def _summarize(scores):
    """Build the accuracy/total/correct summary for a list of booleans."""
    return {
        # float() keeps the value a plain Python number rather than a NumPy scalar.
        'accuracy': float(np.mean(scores)) if scores else 0,
        'total': len(scores),
        'correct': sum(scores),
    }


def calculate_accuracy(predictions):
    """Calculate overall accuracy plus breakdowns by type, try count and length.

    Each prediction's ``model_answer`` is normalised to an option letter and
    compared, case-insensitively, with its ``gt_option``.

    Args:
        predictions: Iterable of dicts. ``model_answer`` and ``gt_option`` are
            required keys; a ``model_answer`` of ``None`` is scored as
            incorrect. Optional keys are ``type`` (default ``'unknown'``),
            ``try`` (default ``0``) and ``length`` (default ``0``; ``None`` is
            treated as ``0``).

    Returns:
        A dict with an ``'overall'`` summary, one summary per question type
        keyed by the type value, and the nested dicts ``'try_counts'`` (keys
        ``'try_<n>'``) and ``'length_groups'`` (keys ``'short'``, ``'medium'``,
        ``'long'``; only groups that occur are present). Every summary has the
        keys ``'accuracy'``, ``'total'`` and ``'correct'``.

    Raises:
        KeyError: If a prediction lacks ``model_answer`` or ``gt_option``.
    """
    type_scores = defaultdict(list)
    try_counts = defaultdict(list)
    length_groups = defaultdict(list)  # short, medium, long videos
    overall_scores = []

    for pred in predictions:
        model_answer = _extract_option(pred['model_answer'])
        correct_answer = str(pred['gt_option']).strip().upper()
        is_correct = model_answer == correct_answer

        overall_scores.append(is_correct)
        type_scores[pred.get('type', 'unknown')].append(is_correct)
        try_counts[pred.get('try', 0)].append(is_correct)

        # Bucket thresholds are 10 and 30; the unit of 'length' is not
        # visible here (TODO confirm against the dataset).
        length = pred.get('length') or 0
        if length <= 10:
            length_groups['short'].append(is_correct)
        elif length <= 30:
            length_groups['medium'].append(is_correct)
        else:
            length_groups['long'].append(is_correct)

    results = {'overall': _summarize(overall_scores)}

    # NOTE: type summaries share the top-level namespace with 'overall',
    # 'try_counts' and 'length_groups'; a question type using one of those
    # names would be overwritten below.
    for q_type, scores in type_scores.items():
        results[q_type] = _summarize(scores)

    results['try_counts'] = {
        f'try_{try_count}': _summarize(scores)
        for try_count, scores in try_counts.items()
    }
    results['length_groups'] = {
        group: _summarize(scores) for group, scores in length_groups.items()
    }

    return results


def main():
    """Command-line entry point: score a directory of predictions.

    Reads the prediction files under ``--output_path``, computes the accuracy
    breakdowns, writes them as JSON to ``--score_path`` and prints a summary
    to standard output.
    """
    parser = argparse.ArgumentParser(description="Calculate VNBench evaluation scores")
    parser.add_argument('--output_path', type=str, required=True, help="Path to the output directory containing prediction files")
    parser.add_argument('--score_path', type=str, required=True, help="Path to save the score file")

    args = parser.parse_args()

    # Load predictions
    predictions = load_predictions(args.output_path)
    print(f"Loaded {len(predictions)} predictions")

    # Calculate scores
    scores = calculate_accuracy(predictions)
    overall = scores['overall']

    # Prepare results
    results = {
        'scores': scores,
        'total_predictions': len(predictions),
        'summary': {
            'overall_accuracy': overall['accuracy'],
            'total_questions': overall['total'],
            'correct_answers': overall['correct']
        }
    }

    # Create the destination directory first so that a missing folder does
    # not throw away the scores that were just computed.
    score_dir = os.path.dirname(args.score_path)
    if score_dir:
        os.makedirs(score_dir, exist_ok=True)

    # Save results
    with open(args.score_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    def format_score(score_data):
        """Render one summary as 'accuracy (correct/total)'."""
        return f"{score_data['accuracy']:.4f} ({score_data['correct']}/{score_data['total']})"

    print(f"Scores saved to: {args.score_path}")
    print(f"Overall accuracy: {format_score(overall)}")

    # Print type-specific accuracies; the three reserved keys are not
    # question types and are reported separately.
    print("\nType-specific accuracies:")
    for q_type, score_data in scores.items():
        if q_type not in ['overall', 'try_counts', 'length_groups']:
            print(f"  {q_type}: {format_score(score_data)}")

    # Print try count accuracies
    print("\nTry count accuracies:")
    for try_key, score_data in scores.get('try_counts', {}).items():
        print(f"  {try_key}: {format_score(score_data)}")

    # Print length-based accuracies
    print("\nLength-based accuracies:")
    for length_key, score_data in scores.get('length_groups', {}).items():
        print(f"  {length_key}: {format_score(score_data)}")


if __name__ == "__main__":
    main()