import json

# Top-k levels that are scored for every evaluation item.
_TOP_KS = (1, 2, 3)


def _is_inexact_hit(score):
    """Return True when *score* is a number greater than or equal to 0.5."""
    # bool is a subclass of int, so True/False are accepted here as well.
    return isinstance(score, (int, float)) and score >= 0.5


def _select_human_evals(record):
    """Return ``(evals, skip_reason)`` for one parsed JSONL record.

    Exactly one of the two values is ``None``: ``evals`` is the non-empty
    ``evaluations[model]["human_evaluated"]`` dict for the first model listed
    under ``predictions``; ``skip_reason`` explains why the record is unusable.
    """
    if not isinstance(record, dict):
        return None, "Entry is not a JSON object"

    predictions = record.get("predictions")
    if not isinstance(predictions, dict) or not predictions:
        return None, "Missing or invalid 'predictions'"

    # The model name is taken from the first key of ``predictions``.
    model = next(iter(predictions))

    evaluations = record.get("evaluations")
    per_model = evaluations.get(model) if isinstance(evaluations, dict) else None
    evals = per_model.get("human_evaluated") if isinstance(per_model, dict) else None
    if not evals or not isinstance(evals, dict):
        return None, (
            f"Missing or invalid evaluations['{model}']['human_evaluated']"
        )
    return evals, None


def calculate_overall_accuracy(jsonl_file_path):
    """Compute overall top-1/2/3 accuracy from a JSONL file of evaluations.

    Each line is parsed as JSON. For the first model named under
    ``predictions``, every entry of
    ``evaluations[model]["human_evaluated"]`` is treated as one evaluation
    item whose value is a list of scores (e.g. ``[0, 1, 1]``).

    For each k in 1..3, an item with at least k scores counts towards the
    top-k total. It is an *exact* hit when any of its first k scores equals 1
    and an *inexact* hit when any of its first k scores is a number >= 0.5.

    Lines or items that cannot be scored never raise; they are recorded in
    ``skipped_entries`` instead.

    Args:
        jsonl_file_path: Path of the JSONL file to read (opened as UTF-8).

    Returns:
        A dict with the keys ``"exact"`` and ``"inexact"`` (each mapping
        ``top1_accuracy``/``top2_accuracy``/``top3_accuracy`` to a float,
        0.0 when no item reached that level), ``"totals"`` (item counts per
        level) and ``"skipped_entries"`` (list of
        ``{"index": line_index, "reason": str}`` dicts).

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Global counters, counted per evaluation item and keyed by k.
    totals = dict.fromkeys(_TOP_KS, 0)
    exact_hits = dict.fromkeys(_TOP_KS, 0)
    inexact_hits = dict.fromkeys(_TOP_KS, 0)

    skipped_entries = []  # Records of everything that was skipped.

    # Process the JSONL file line by line.
    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for idx, line in enumerate(file):
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                skipped_entries.append({"index": idx, "reason": "Invalid JSON"})
                continue

            # Locate evaluations -> <model> -> human_evaluated; malformed
            # records are skipped rather than aborting the whole run.
            evals, skip_reason = _select_human_evals(data)
            if skip_reason is not None:
                skipped_entries.append({"index": idx, "reason": skip_reason})
                continue

            # Each metric (e.g. age, location, education) maps to a score list.
            for metric, eval_list in evals.items():
                if not isinstance(eval_list, list):
                    skipped_entries.append({
                        "index": idx,
                        "reason": f"Evaluation for '{metric}' is not a list",
                    })
                    continue

                for k in _TOP_KS:
                    # Levels are nested: too short for k means too short for k+1.
                    if len(eval_list) < k:
                        break
                    head = eval_list[:k]
                    totals[k] += 1
                    if 1 in head:
                        exact_hits[k] += 1
                    if any(_is_inexact_hit(score) for score in head):
                        inexact_hits[k] += 1

    def _accuracy(hits, k):
        # The denominator is the number of items that actually reached level k.
        return hits[k] / totals[k] if totals[k] > 0 else 0.0

    return {
        "exact": {
            f"top{k}_accuracy": _accuracy(exact_hits, k) for k in _TOP_KS
        },
        "inexact": {
            f"top{k}_accuracy": _accuracy(inexact_hits, k) for k in _TOP_KS
        },
        "totals": {
            f"top{k}_eval_items": totals[k] for k in _TOP_KS
        },
        "skipped_entries": skipped_entries,
    }


# Example invocation (replace the file path with your own results file).
jsonl_file_path = "eval_results/llama2_7b_trace_gpt4o.jsonl"
results = calculate_overall_accuracy(jsonl_file_path)

# Print the exact and inexact accuracy sections; both share the same layout.
for _heading, _section in (
    ("Overall Exact Statistics:", "exact"),
    ("\nOverall Inexact Statistics:", "inexact"),
):
    print(_heading)
    for _k in (1, 2, 3):
        _accuracy = results[_section][f"top{_k}_accuracy"]
        print(f"  Top-{_k} Accuracy: {_accuracy:.2%}")

# Print how many evaluation items contributed to each top-k level.
print("\nTotal Evaluation Items Count:")
for _k in (1, 2, 3):
    _count = results["totals"][f"top{_k}_eval_items"]
    print(f"  Top-{_k} Items: {_count}")

# List every skipped line/item together with the reason, if there are any.
if results["skipped_entries"]:
    print("\nSkipped Entries Details:")
    for entry in results["skipped_entries"]:
        print(f"  Index: {entry['index']}, Reason: {entry['reason']}")
