import re
import argparse
import json
import os
from collections import defaultdict
def extract_and_choose_answer(pattern, model_answer):
    """Pick the option letter that occurs most often in ``model_answer``.

    Every match of ``pattern`` is upper-cased and tallied. If ``pattern``
    yields no match at all, the case-insensitive pattern ``[A-Fa-f]`` is
    tried once as a fallback.

    Args:
        pattern: Regular expression whose matches are candidate option letters.
        model_answer: Text to search.

    Returns:
        The upper-cased match with the highest count; ties go to the match
        that appeared first in the text. ``None`` when neither ``pattern``
        nor the fallback pattern matches anything.
    """
    fallback_pattern = r'[A-Fa-f]'
    # Try the caller's pattern first, then the fallback (only once when the
    # caller already passed the fallback pattern itself).
    attempts = [pattern]
    if pattern != fallback_pattern:
        attempts.append(fallback_pattern)

    for regex in attempts:
        tally = {}
        for hit in re.findall(regex, model_answer):
            letter = hit.upper()
            tally[letter] = tally.get(letter, 0) + 1
        if tally:
            # max() keeps the first key among equal counts, and dicts iterate
            # in insertion order, so the earliest-seen letter wins a tie.
            return max(tally, key=tally.get)
    return None
def _load_json_lines(result_path, suffix):
    """Read every file in ``result_path`` whose name ends with ``suffix`` as JSON Lines.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are reported on stdout and skipped.

    Returns:
        A list of the decoded objects (dicts), in file-name then line order.
    """
    records = []
    # os.listdir returns names in arbitrary order; sorting keeps the record
    # order (and therefore the written output files) reproducible.
    for filename in sorted(os.listdir(result_path)):
        if not filename.endswith(suffix):
            continue
        print(f"Processing file: {filename}")
        file_path = os.path.join(result_path, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in enumerate(file, 1):
                line = line.strip()
                if not line:  # skip blank lines
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {line_num} in {filename}: {e}")
                    continue
                if not isinstance(data, dict):
                    # The scoring loop reads fields with .get(); anything that
                    # is not an object cannot be scored.
                    print(f"Skipping line {line_num} in {filename}: expected a JSON object")
                    continue
                records.append(data)
    return records


def _as_text(value):
    """Return ``value`` as a stripped string, mapping ``None`` to ``""``."""
    if value is None:
        return ""
    return str(value).strip()


def generate_score(result_path, score_path):
    """Score multiple-choice answers stored as JSON Lines and write the results.

    Every ``.jsonl`` file in ``result_path`` is read line by line. If that
    yields no records, ``.json`` files are read the same way (one JSON object
    per line). Each record is read through the keys ``duration``,
    ``model_answer``, ``answer``, ``video_id`` and ``question``.

    A record counts as correct when the option letter extracted from
    ``model_answer`` by ``extract_and_choose_answer`` equals the upper-cased
    ``answer``. When no letter can be extracted, it counts as correct if the
    non-empty ``answer`` occurs in the upper-cased ``model_answer``.

    Two files are written:
      * ``score_path``: per-duration accuracies, their unweighted mean
        (``overall_average``), and the short/medium/long breakdown.
      * ``incorrect_items.json`` next to ``score_path``: the records that
        were scored as incorrect.

    Args:
        result_path: Directory containing the result files.
        score_path: Path of the JSON score file to write.

    Returns:
        None. When no record could be loaded, a message is printed and
        nothing is written.
    """
    json_objects = _load_json_lines(result_path, '.jsonl')
    # Fall back to .json files, still assuming one JSON object per line.
    if not json_objects:
        print("No .jsonl files found, trying .json files...")
        json_objects = _load_json_lines(result_path, '.json')
    if not json_objects:
        print("No valid JSON/JSONL data found!")
        return

    # Per-duration counters, keyed by whatever "duration" value the data holds.
    total_counts = defaultdict(int)
    correct_counts = defaultdict(int)
    incorrect_items = []
    # Counters restricted to the three known duration classes.
    duration_stats = {
        'short': {'total': 0, 'correct': 0},
        'medium': {'total': 0, 'correct': 0},
        'long': {'total': 0, 'correct': 0}
    }
    print(f'****Total items: {len(json_objects)}****')

    pattern = r'[A-F]'
    for item in json_objects:
        duration = item.get("duration", "unknown")
        # Values may be null or non-string in the data; normalise to text.
        model_answer = _as_text(item.get("model_answer"))
        correct_answer = _as_text(item.get("answer")).upper()

        total_counts[duration] += 1
        if duration in duration_stats:
            duration_stats[duration]['total'] += 1

        extracted_answer = extract_and_choose_answer(pattern, model_answer)
        if extracted_answer:
            is_correct = extracted_answer == correct_answer
        else:
            # No option letter found: accept a literal occurrence of the
            # reference answer. An empty reference must not match, because
            # "" is a substring of every string.
            is_correct = bool(correct_answer) and correct_answer in model_answer.upper()

        if is_correct:
            correct_counts[duration] += 1
            if duration in duration_stats:
                duration_stats[duration]['correct'] += 1
        else:
            incorrect_items.append({
                "video_id": item.get("video_id"),
                "question": item.get("question"),
                "model_answer": model_answer,
                "correct_answer": correct_answer,
                "duration": duration
            })

    print(f'Total counts: {dict(total_counts)}')
    print(f'Correct counts: {dict(correct_counts)}')

    # Accuracy per duration value; every key in total_counts has count >= 1.
    accuracy_dict = {
        key: correct_counts[key] / count
        for key, count in total_counts.items()
    }
    # Unweighted mean over categories, not over individual items.
    overall_average = (
        sum(accuracy_dict.values()) / len(accuracy_dict) if accuracy_dict else 0.0
    )

    # Accuracy for the three known duration classes (0.0 when a class is empty).
    duration_accuracies = {
        duration: (stats['correct'] / stats['total'] if stats['total'] > 0 else 0.0)
        for duration, stats in duration_stats.items()
    }

    results = {
        "individual_accuracies": accuracy_dict,
        "overall_average": overall_average,
        "duration_breakdown": duration_accuracies,
        "duration_stats": duration_stats
    }
    with open(score_path, "w", encoding="utf8") as f:
        json.dump(results, f, indent=4)
    print(f'***********Score results saved to {score_path}*************')

    print("=====================================")
    print("Duration Breakdown")
    print("=====================================")
    for duration, accuracy in duration_accuracies.items():
        stats = duration_stats[duration]
        print(f"{duration.capitalize()}: {accuracy*100:.1f}% ({stats['correct']}/{stats['total']})")
    print(f"Overall: {overall_average*100:.1f}%")

    # The incorrect records go next to the score file.
    incorrect_path = os.path.join(os.path.dirname(score_path), "incorrect_items.json")
    with open(incorrect_path, "w", encoding="utf8") as f:
        json.dump(incorrect_items, f, indent=4)
    print(f'***********Incorrect items saved to {incorrect_path}*************')
if __name__ == "__main__":
    # Command-line entry point: read the result directory and the score file
    # location from the arguments, then run the scoring.
    cli = argparse.ArgumentParser()
    for long_flag, short_flag, description in (
        ("--output_path", '-a', "path to the output data"),
        ("--score_path", '-o', "path to the score"),
    ):
        cli.add_argument(long_flag, short_flag, type=str, help=description)
    options = cli.parse_args()
    generate_score(options.output_path, options.score_path)