'''
Author: swchen
Date: 2025-09-19 00:56:37
LastEditors: swchen
LastEditTime: 2025-09-25 17:06:34
FilePath: /SupervisorAgent/smolagents/examples/smolagents_benchmark_drop/supervisor_compare.py
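Description: Compare two DROP benchmark result files (JSONL, records may span several lines) on the task_ids they both answered, reporting accuracy and average total token count for each.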

Copyright (c) 2025 by Shaowen Chen, All Rights Reserved. 
'''
import json

def load_jsonl_multiline(path: str):
    """Read a JSONL-style file whose records may span several lines.

    Non-blank lines are stripped and appended to a buffer until the buffer
    parses as one complete JSON value. That value is collected and the buffer
    is cleared. Blank lines are ignored.

    If content is left in the buffer at end of file, it is not returned. This
    happens when the last record is truncated, or when a malformed record
    earlier in the file made every following line accumulate into a buffer
    that never parses. A warning names the line where that content begins, so
    lost records are no longer dropped silently.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        list: The decoded JSON values, in file order.
    """
    import warnings  # local import keeps this block self-contained

    records = []
    buffer = ""
    buffer_start = 0  # 1-based line number where the pending buffer began
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            if not buffer:
                buffer_start = lineno
            buffer += line
            try:
                record = json.loads(buffer)
            except json.JSONDecodeError:
                # Not a complete JSON value yet; keep accumulating lines.
                continue
            records.append(record)
            buffer = ""  # start collecting the next object
    if buffer:
        # Unparsable leftovers: a truncated final record, or everything after
        # a malformed record. Surface it instead of silently losing data.
        warnings.warn(
            f"{path}: ignored unparsable content starting at line {buffer_start} "
            f"({len(records)} record(s) were loaded before it)",
            stacklevel=2,
        )
    return records

def remove_repeat_examples(records):
    """Keep the first record with a prediction for each task_id.

    Records whose ``prediction`` is ``None`` are dropped. Such a record does
    not claim its ``task_id``, so a later record for the same task that does
    have a prediction is still kept. Input order is preserved.

    Args:
        records: Iterable of dicts with ``prediction`` and ``task_id`` keys.

    Returns:
        list: The filtered records.
    """
    seen_ids = set()
    unique = []
    for rec in records:
        # `or` short-circuits, so task_id is only read when a prediction exists.
        if rec["prediction"] is None or rec["task_id"] in seen_ids:
            continue
        seen_ids.add(rec["task_id"])
        unique.append(rec)
    return unique

import datasets
def load_drop_dataset(run_set, level=None, dataset_path="./datasets/drop"):
    """Load the DROP evaluation data from local JSON-lines files.

    Every file matching ``{dataset_path}/*test.jsonl`` is registered as the
    ``"test"`` split, and ``run_set`` selects the split to return. Columns are
    then renamed to the names used by the benchmark code:
    ``context`` -> ``question``, ``completion`` -> ``true_answer``,
    ``id`` -> ``task_id``.

    Args:
        run_set: Split name forwarded to ``datasets.load_dataset``. Only
            ``"test"`` is declared in ``data_files``.
        level: Currently ignored; kept for signature compatibility.
        dataset_path: Directory containing the ``*test.jsonl`` files.

    Returns:
        The loaded split with renamed columns.
    """
    column_renames = (
        ("context", "question"),
        ("completion", "true_answer"),
        ("id", "task_id"),
    )
    eval_ds = datasets.load_dataset(
        "json",  # local JSON-lines files, not a hub dataset
        data_files={"test": f"{dataset_path}/*test.jsonl"},
        split=run_set,
    )
    for old_name, new_name in column_renames:
        eval_ds = eval_ds.rename_column(old_name, new_name)
    return eval_ds

import argparse
def parse_args():
    """Parse the command-line arguments of this comparison script.

    Both result-file paths are required. An empty path could only fail later,
    when the file is opened, so argparse now reports the missing option
    up front.

    Returns:
        argparse.Namespace with ``file_path_1`` and ``file_path_2`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Compare two result JSONL files on their shared task_ids."
    )
    parser.add_argument(
        "--file-path-1", type=str, required=True, help="Path to the first JSONL file."
    )
    parser.add_argument(
        "--file-path-2", type=str, required=True, help="Path to the second JSONL file."
    )
    return parser.parse_args()


def compare_records(records_1, records_2):
    """Aggregate accuracy and token usage over the task_ids both runs answered.

    Each record in ``records_1`` is paired with the first record in
    ``records_2`` that has the same ``task_id``. Records without a counterpart
    are ignored, so both runs are scored on the same set of tasks.

    Args:
        records_1: Result dicts from the first run. Each needs ``task_id``;
            matched ones also need ``is_correct`` and
            ``token_counts["total_token_count"]``.
        records_2: Result dicts from the second run, with the same fields.

    Returns:
        dict: ``matched`` (number of paired records), ``correct_1`` /
        ``correct_2`` (records with truthy ``is_correct``) and
        ``tokens_1`` / ``tokens_2`` (summed ``total_token_count``).
    """
    # Index the second run once (first occurrence wins, as a linear search
    # would) instead of rescanning it for every record of the first run.
    by_task_2 = {}
    for rec in records_2:
        by_task_2.setdefault(rec["task_id"], rec)

    stats = {"matched": 0, "correct_1": 0, "correct_2": 0, "tokens_1": 0, "tokens_2": 0}
    for r1 in records_1:
        r2 = by_task_2.get(r1["task_id"])
        if r2 is None:
            continue
        stats["tokens_1"] += r1["token_counts"]["total_token_count"]
        stats["tokens_2"] += r2["token_counts"]["total_token_count"]
        stats["correct_1"] += 1 if r1["is_correct"] else 0
        stats["correct_2"] += 1 if r2["is_correct"] else 0
        stats["matched"] += 1
    return stats


if __name__ == "__main__":
    args = parse_args()
    res1 = remove_repeat_examples(load_jsonl_multiline(args.file_path_1))
    res2 = remove_repeat_examples(load_jsonl_multiline(args.file_path_2))
    print(len(res1), len(res2))

    stats = compare_records(res1, res2)
    matched = stats["matched"]
    if matched == 0:
        # Avoid ZeroDivisionError when the two runs share no task_id.
        print("No overlapping task_ids between the two files; nothing to compare.")
    else:
        cor1, cor2 = stats["correct_1"], stats["correct_2"]
        print(f"File1 Correct: {cor1}/{matched}, Accuracy: {cor1/matched}")
        print(f"File2 Correct: {cor2}/{matched}, Accuracy: {cor2/matched}")
        print(f"File1 Total Tokens(Avg): {stats['tokens_1']/matched}")
        print(f"File2 Total Tokens(Avg): {stats['tokens_2']/matched}")