'''
Author: swchen
Date: 2025-09-19 00:56:37
LastEditors: swchen
LastEditTime: 2025-09-25 17:30:00
FilePath: /SupervisorAgent/smolagents/examples/smolagents_benchmark_gsmhard/supervisor_compare.py
Description: 

Copyright (c) 2025 by Shaowen Chen, All Rights Reserved. 
'''
import json

def load_jsonl_multiline(path: str):
    """Read a JSONL file whose records may each span several lines.

    Blank lines are skipped. Every other line is stripped of surrounding
    whitespace and appended to a buffer; as soon as the buffer parses as a
    complete JSON document it is stored as one record and the buffer is
    reset for the next one.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        A list with the decoded records, in file order.

    Warns:
        RuntimeWarning: If text is left in the buffer at end of file, i.e.
            it never formed a complete JSON document (for example a
            truncated last record). The records decoded before that point
            are still returned.
    """
    import warnings

    records = []
    buffer = ""
    buffer_start = 0  # line number where the pending record began
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            if not buffer:
                buffer_start = lineno
            buffer += line
            try:
                record = json.loads(buffer)
            except json.JSONDecodeError:
                # Not a complete JSON document yet; keep accumulating lines.
                continue
            records.append(record)
            buffer = ""  # start collecting the next record

    if buffer:
        # Leftover text would otherwise vanish silently and skew any
        # statistics computed from the returned records.
        warnings.warn(
            f"{path}: discarded {len(buffer)} characters of unparsable content "
            f"starting at line {buffer_start}; {len(records)} records were loaded",
            RuntimeWarning,
            stacklevel=2,
        )
    return records


import argparse
def parse_args():
    """Parse the command-line options naming the two JSONL files to compare.

    Returns:
        The argparse namespace with ``file_path_1`` and ``file_path_2``;
        each is an empty string when its option is not given.
    """
    cli = argparse.ArgumentParser()
    option_help = {
        "--file-path-1": "Path to the input JSONL file.",
        "--file-path-2": "Path to the second JSONL file.",
    }
    # Both options share the same type and default; only flag and help differ.
    for flag, description in option_help.items():
        cli.add_argument(flag, type=str, default="", help=description)
    namespace = cli.parse_args()
    return namespace


def _summarize_token_usage(res1, res2):
    """Total the token counts of the tasks that both runs answered.

    A record from ``res1`` is paired with the first record in ``res2`` that
    has the same ``task_id`` and a non-``None`` prediction. Records of
    ``res1`` whose own prediction is ``None``, or that have no such partner,
    are left out of the pair count AND of the token totals, so the totals
    and any average derived from them describe the same set of tasks.

    Args:
        res1: Records of the first run; each is a dict providing
            ``task_id``, ``prediction`` and
            ``token_counts["total_token_count"]``.
        res2: Records of the second run, same layout.

    Returns:
        A tuple ``(matched, tokens_1, tokens_2)``: the number of counted
        pairs and the summed ``total_token_count`` of each run over them.
    """
    # Index the second run once instead of rescanning it for every record
    # of the first run. setdefault keeps the first usable record per id.
    first_answered = {}
    for rec in res2:
        if rec["prediction"] is not None:
            first_answered.setdefault(rec["task_id"], rec)

    matched = 0
    tokens_1 = 0
    tokens_2 = 0
    for r1 in res1:
        if r1["prediction"] is None:
            continue
        r2 = first_answered.get(r1["task_id"])
        if r2 is None:
            continue
        tokens_1 += r1["token_counts"]["total_token_count"]
        tokens_2 += r2["token_counts"]["total_token_count"]
        matched += 1
    return matched, tokens_1, tokens_2


if __name__ == "__main__":
    args = parse_args()
    file_path_1 = args.file_path_1
    file_path_2 = args.file_path_2
    if not file_path_1 or not file_path_2:
        # An omitted option arrives as an empty string; stop with a clear
        # message instead of failing inside open("").
        raise SystemExit("Both --file-path-1 and --file-path-2 must be provided.")

    res1 = load_jsonl_multiline(file_path_1)
    res2 = load_jsonl_multiline(file_path_2)

    print(len(res1), len(res2))
    matched, r1_total_tokens, r2_total_tokens = _summarize_token_usage(res1, res2)
    # With no counted pair there is nothing to average; report NaN rather
    # than raising ZeroDivisionError.
    r1_avg = r1_total_tokens / matched if matched else float("nan")
    r2_avg = r2_total_tokens / matched if matched else float("nan")
    print(f"File1 Total Tokens: {r1_total_tokens}, (Avg): {r1_avg}")
    print(f"File2 Total Tokens: {r2_total_tokens}, (Avg): {r2_avg}")