import argparse
import json
import os
import re
from collections import defaultdict

import numpy as np


def load_predictions(output_dir):
    """Load all prediction records from the ``.jsonl`` files in a directory.

    Only regular files directly inside ``output_dir`` whose names end in
    ``.jsonl`` are read. Files are processed in sorted filename order so the
    returned list is deterministic (``os.listdir`` guarantees no ordering).
    Blank or whitespace-only lines, such as a trailing empty line, are skipped
    instead of being handed to ``json.loads``, which would reject them.

    Args:
        output_dir: Directory containing the prediction files.

    Returns:
        A list with one decoded JSON value per non-blank line, ordered by
        file name and then by line.

    Raises:
        OSError: If ``output_dir`` cannot be listed or a file cannot be read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    predictions = []

    for file_name in sorted(os.listdir(output_dir)):
        if not file_name.endswith('.jsonl'):
            continue
        file_path = os.path.join(output_dir, file_name)
        # Skip e.g. a sub-directory that happens to be named "*.jsonl".
        if not os.path.isfile(file_path):
            continue
        # JSON Lines files are UTF-8 by specification; don't rely on the locale.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                predictions.append(json.loads(line))

    return predictions


# Answer-letter extraction patterns, tried in order against the stripped,
# upper-cased model answer. Each requires the option letter to stand alone as a
# token, so letters inside ordinary words ("ANSWER", "BECAUSE", "CAT") are never
# mistaken for the chosen option.
_CHOICE_PATTERNS = (
    # Letter at the very start, optionally parenthesized: "B", "B.", "(B) ...".
    re.compile(r'^\(?([A-E])\b'),
    # Parenthesized letter anywhere: "THE ANSWER IS (C).".
    re.compile(r'\(([A-E])\)'),
    # First standalone letter anywhere: "ANSWER: C", "BECAUSE IT IS D".
    re.compile(r'\b([A-E])\b'),
)


def _extract_choice_letter(model_answer):
    """Return the option letter (A-E) found in a free-form model answer.

    ``None`` is treated as an empty answer. If no standalone option letter is
    found, the stripped, upper-cased text is returned unchanged, so it will not
    equal a single reference letter.
    """
    text = '' if model_answer is None else str(model_answer).strip().upper()
    for pattern in _CHOICE_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1)
    return text


def _summarize(scores):
    """Return an ``accuracy``/``total``/``correct`` dict for a list of booleans."""
    total = len(scores)
    correct = sum(scores)
    return {
        'accuracy': correct / total if total else 0.0,
        'total': total,
        'correct': correct,
    }


def calculate_accuracy(predictions):
    """Calculate accuracy overall, per question type, and per source JSON file.

    Each prediction is a dict read with these keys (all optional):
    ``model_answer`` (free-form text), ``correct_answer_letter``,
    ``question_type`` and ``source_json``. A missing or empty grouping key is
    grouped under ``'unknown'``. A prediction whose reference letter is missing
    or empty is always counted as incorrect.

    Args:
        predictions: Iterable of prediction dicts.

    Returns:
        A flat dict mapping ``'overall'``, every question type, and every
        source name to ``{'accuracy': float, 'total': int, 'correct': int}``.
        Type and source names share one namespace: a source name equal to a
        type name (including the shared ``'unknown'`` default) overwrites that
        type's entry.
    """
    type_scores = defaultdict(list)
    source_scores = defaultdict(list)
    overall_scores = []

    for pred in predictions:
        predicted_letter = _extract_choice_letter(pred.get('model_answer'))
        correct_letter = str(pred.get('correct_answer_letter') or '').strip().upper()
        # An empty reference letter can never be matched, so the item counts as wrong.
        is_correct = bool(correct_letter) and predicted_letter == correct_letter

        overall_scores.append(is_correct)
        # str() keeps keys usable by callers that test them with str methods.
        type_scores[str(pred.get('question_type') or 'unknown')].append(is_correct)
        source_scores[str(pred.get('source_json') or 'unknown')].append(is_correct)

    results = {'overall': _summarize(overall_scores)}
    # Types first, then sources, matching the original insertion order.
    for grouped in (type_scores, source_scores):
        for name, scores in grouped.items():
            results[name] = _summarize(scores)

    return results


def main():
    """Command-line entry point: score prediction files and write a JSON report.

    Reads every ``.jsonl`` file in ``--output_path`` with ``load_predictions``,
    computes accuracies with ``calculate_accuracy``, and writes the result to
    ``--score_path``, creating its parent directory if needed. It then prints
    a summary. Groups whose name starts with ``1_`` ... ``7_`` are listed as
    MLVU task accuracies. Every other non-``overall`` group is listed as a
    type-specific accuracy.
    """
    # MLVU source names carry a numeric task prefix such as "1_" ... "7_".
    mlvu_task_prefixes = tuple(f'{i}_' for i in range(1, 8))

    parser = argparse.ArgumentParser(description="Calculate MLVU evaluation scores")
    parser.add_argument('--output_path', type=str, required=True, help="Path to the output directory containing prediction files")
    parser.add_argument('--score_path', type=str, required=True, help="Path to save the score file")
    args = parser.parse_args()

    predictions = load_predictions(args.output_path)
    print(f"Loaded {len(predictions)} predictions")

    scores = calculate_accuracy(predictions)
    overall = scores['overall']

    results = {
        'scores': scores,
        'total_predictions': len(predictions),
        'summary': {
            'overall_accuracy': overall['accuracy'],
            'total_questions': overall['total'],
            'correct_answers': overall['correct'],
        },
    }

    # open(..., 'w') fails if the target directory does not exist yet.
    score_dir = os.path.dirname(args.score_path)
    if score_dir:
        os.makedirs(score_dir, exist_ok=True)
    with open(args.score_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"Scores saved to: {args.score_path}")
    print(f"Overall accuracy: {overall['accuracy']:.4f} ({overall['correct']}/{overall['total']})")

    # str() guards against non-string group keys before the prefix test.
    task_groups = [
        (name, data) for name, data in scores.items()
        if name != 'overall' and str(name).startswith(mlvu_task_prefixes)
    ]
    type_groups = [
        (name, data) for name, data in scores.items()
        if name != 'overall' and not str(name).startswith(mlvu_task_prefixes)
    ]

    print("\nType-specific accuracies:")
    for q_type, score_data in type_groups:
        print(f"  {q_type}: {score_data['accuracy']:.4f} ({score_data['correct']}/{score_data['total']})")

    print("\nMLVU task accuracies:")
    for source, score_data in task_groups:
        print(f"  {source}: {score_data['accuracy']:.4f} ({score_data['correct']}/{score_data['total']})")


if __name__ == "__main__":
    # Run the scoring CLI only when this file is executed as a script,
    # not when it is imported as a module.
    main()