import json
import argparse

def calculate_f1_score(prediction, reference):
    """
    Compute a vocabulary-overlap F1 score between two pieces of text.

    Both texts are lower-cased and split on whitespace, then reduced to
    their sets of unique tokens, so a repeated word counts only once.
    Precision is the share of predicted tokens found in the reference,
    recall is the share of reference tokens found in the prediction, and
    F1 is their harmonic mean. Any ratio with an empty denominator is 0.0.

    Returns a dict with the keys 'precision', 'recall' and 'f1'.
    """
    # Normalise case and tokenise, dropping duplicate words.
    predicted_vocab = set(prediction.lower().split())
    reference_vocab = set(reference.lower().split())

    # Number of distinct tokens the two texts share.
    overlap = len(predicted_vocab & reference_vocab)

    precision = overlap / len(predicted_vocab) if predicted_vocab else 0.0
    recall = overlap / len(reference_vocab) if reference_vocab else 0.0

    # Guard the harmonic mean against a zero denominator.
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return {'precision': precision, 'recall': recall, 'f1': f1}


def load_jsonl_multiline(path: str) -> list:
    """Read a JSONL file whose records may each span several lines.

    Blank lines are skipped. Every other line is stripped and appended to a
    buffer, and the buffer is parsed after each line; as soon as it forms a
    complete JSON value, that value is stored and the buffer is reset.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        The parsed JSON values, in file order.

    Warns:
        UserWarning: if the file ends while the buffer still holds text that
            never parsed. Such text (and everything after the first
            malformed record) is dropped, so the warning makes the data loss
            visible instead of silently returning a truncated list.
    """
    import warnings  # local import keeps this function self-contained

    records = []
    buffer = ""
    buffer_start = 0  # 1-based line number where the pending record began
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            if not buffer:
                buffer_start = line_no
            buffer += line
            try:
                record = json.loads(buffer)
            except json.JSONDecodeError:
                # Not a complete JSON value yet; keep accumulating lines.
                continue
            records.append(record)
            buffer = ""  # reset and wait for the next object

    if buffer:
        warnings.warn(
            f"{path}: content starting at line {buffer_start} never formed "
            f"valid JSON and was dropped ({len(records)} record(s) loaded)",
            stacklevel=2,
        )
    return records


def remove_repeat_examples(records):
    """Keep the first record per task_id among those with a prediction.

    Records whose "prediction" is None are skipped and do not reserve their
    task_id, so a later record with the same task_id and a real prediction
    is still kept. The input order of the surviving records is preserved.
    """
    seen_ids = set()
    unique = []
    for record in records:
        if record["prediction"] is None:
            continue
        task_id = record["task_id"]
        if task_id not in seen_ids:
            seen_ids.add(task_id)
            unique.append(record)
    return unique


def parse_args():
    """Parse the command line for the two result-file paths to compare.

    Both --file-path-1 and --file-path-2 are optional strings that default
    to the empty string.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--file-path-1", "--file-path-2"):
        parser.add_argument(flag, type=str, default="")
    return parser.parse_args()

if __name__ == "__main__":
    # Compare two result files on the examples they share: for every record
    # of file 1, find the first record of file 2 with the same task_id where
    # both predictions are present, then accumulate each side's F1 against
    # its own true answer together with its total token count.
    args = parse_args()
    res1 = load_jsonl_multiline(args.file_path_1)
    res2 = load_jsonl_multiline(args.file_path_2)

    # Index file 2 by task_id once instead of rescanning it for every record
    # of file 1. setdefault keeps the FIRST usable record per task_id, which
    # is the one the original nested scan would have stopped at.
    # NOTE(review): assumes task_id values are hashable (str/int) — confirm.
    res2_by_task = {}
    for r2 in res2:
        if r2.get("prediction") is not None:
            res2_by_task.setdefault(r2["task_id"], r2)

    cor1_F1 = 0
    cor2_F1 = 0
    r1_total_tokens = 0
    r2_total_tokens = 0
    matched = 0  # number of (file 1, file 2) pairs that were scored
    for r1 in res1:
        if r1.get("prediction") is None:
            continue
        r2 = res2_by_task.get(r1["task_id"])
        if r2 is None:
            continue
        cor1_F1 += calculate_f1_score(r1["prediction"], r1["true_answer"])['f1']
        cor2_F1 += calculate_f1_score(r2["prediction"], r2["true_answer"])['f1']
        r1_total_tokens += r1["token_counts"]["total_token_count"]
        r2_total_tokens += r2["token_counts"]["total_token_count"]
        matched += 1

    # Without this guard the averages below raise ZeroDivisionError.
    if matched == 0:
        raise SystemExit(
            "No task_id with a non-null prediction is present in both files; "
            "nothing to compare."
        )

    # The value labelled "Accuracy" is the mean F1 over the matched pairs;
    # the label is kept as-is so existing output consumers keep working.
    print(f"File1 F1: {cor1_F1}/{matched}, Accuracy: {cor1_F1/matched}")
    print(f"File2 F1: {cor2_F1}/{matched}, Accuracy: {cor2_F1/matched}")
    print(f"File1 Total Tokens: {r1_total_tokens}, (Avg): {r1_total_tokens/matched}")
    print(f"File2 Total Tokens: {r2_total_tokens}, (Avg): {r2_total_tokens/matched}")