import json

def _as_dict(value):
    """Return ``value`` when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _meets_threshold(score, threshold=0.5):
    """Return True when ``score`` is a number that is at least ``threshold``."""
    return isinstance(score, (int, float)) and score >= threshold


def calculate_top_k_accuracy(jsonl_file_path):
    """Compute top-1/2/3 accuracy from a JSONL file of evaluation results.

    Each line is parsed as a JSON object. For every object the function reads:

    * the first key of ``reviews.synth``,
    * the first key of ``predictions`` (used as the model name),
    * the list at ``evaluations[<model name>].synth[<review key>]``.

    Two statistics are collected over the first k scores of that list:

    * exact: a hit when any of them equals ``1``;
    * inexact: a hit when any of them is a number ``>= 0.5``.

    An entry holding fewer than k scores is never counted as a top-k hit,
    but it still counts in ``total_entries``, which is the denominator of
    every accuracy.
    NOTE(review): because of this, top-2/top-3 accuracy can be lower than
    top-1 when score lists are short; confirm this is intended.

    Lines that cannot be used (invalid JSON, not a JSON object, missing or
    malformed ``reviews.synth`` / ``predictions`` / ``evaluations``) are not
    counted; they are recorded in ``skipped_entries`` with their 0-based line
    index and a reason. Non-numeric scores are treated as misses.

    Args:
        jsonl_file_path: Path of the JSONL file, read as UTF-8.

    Returns:
        A dict with keys ``"exact"`` and ``"inexact"`` (each mapping
        ``"top1_accuracy"``, ``"top2_accuracy"``, ``"top3_accuracy"`` to a
        float, ``0.0`` when no entry was counted), ``"total_entries"`` (int)
        and ``"skipped_entries"`` (list of ``{"index", "reason"}`` dicts).

    Raises:
        OSError: If the file cannot be opened.
    """
    max_k = 3
    exact_hits = [0] * max_k    # exact_hits[k - 1] -> number of exact top-k hits
    inexact_hits = [0] * max_k  # same layout for the inexact statistic

    total_entries = 0
    skipped_entries = []  # one {"index", "reason"} record per unusable line

    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for idx, line in enumerate(file):
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                skipped_entries.append({"index": idx, "reason": "Invalid JSON"})
                continue

            # A valid JSON value that is not an object has none of the fields below.
            if not isinstance(data, dict):
                skipped_entries.append({"index": idx, "reason": "Entry is not a JSON object"})
                continue

            # The first key under reviews.synth selects which evaluation list to read.
            synth_reviews = _as_dict(_as_dict(data.get("reviews")).get("synth"))
            if not synth_reviews:
                skipped_entries.append({"index": idx, "reason": "Missing 'reviews.synth' key"})
                continue
            key = next(iter(synth_reviews))

            # The first key of predictions is used as the model name; skip the
            # line instead of crashing when it is missing, empty or malformed.
            predictions = data.get("predictions")
            if not isinstance(predictions, dict) or not predictions:
                skipped_entries.append({"index": idx, "reason": "Missing or invalid 'predictions' data"})
                continue
            model_name = next(iter(predictions))

            model_evaluations = _as_dict(_as_dict(data.get("evaluations")).get(model_name))
            evaluations = _as_dict(model_evaluations.get("synth")).get(key)
            if not evaluations or not isinstance(evaluations, list):
                skipped_entries.append({"index": idx, "reason": "Missing or invalid 'evaluations' data"})
                continue

            for k in range(1, max_k + 1):
                # Entries with fewer than k scores never count as a top-k hit.
                if len(evaluations) < k:
                    break
                top_k = evaluations[:k]
                if 1 in top_k:
                    exact_hits[k - 1] += 1
                if any(_meets_threshold(score) for score in top_k):
                    inexact_hits[k - 1] += 1

            total_entries += 1

    def _rate(hits):
        # Guard against division by zero when every line was skipped.
        return hits / total_entries if total_entries > 0 else 0.0

    return {
        "exact": {
            f"top{k}_accuracy": _rate(hits)
            for k, hits in enumerate(exact_hits, start=1)
        },
        "inexact": {
            f"top{k}_accuracy": _rate(hits)
            for k, hits in enumerate(inexact_hits, start=1)
        },
        "total_entries": total_entries,
        "skipped_entries": skipped_entries,
    }

# Example invocation
jsonl_file_path = "eval_results/llama2_7b_nodefense.jsonl"  # replace with the path to your JSONL file
results = calculate_top_k_accuracy(jsonl_file_path)

# Print both accuracy tables; the second heading carries the blank separator line
for _heading, _section in (("Exact Statistics:", "exact"), ("\nInexact Statistics:", "inexact")):
    print(_heading)
    _scores = results[_section]
    for _k in (1, 2, 3):
        _accuracy = _scores[f"top{_k}_accuracy"]
        print(f"Top-{_k} Accuracy: {_accuracy:.2%}")

# Summary counts
_skipped = results['skipped_entries']
print(f"\nTotal Entries Processed: {results['total_entries']}")
print(f"Skipped Entries: {len(_skipped)}")

# List every skipped line together with the reason it was skipped
if _skipped:
    print("\nDetails of Skipped Entries:")
    for entry in _skipped:
        print(f"Index: {entry['index']}, Reason: {entry['reason']}")