import json
import sys
from collections import defaultdict
import os

def normalize_conclusion(conclusion):
    """Normalize a raw conclusion/prediction value into a list of lowercase labels.

    Args:
        conclusion: Either a list of labels (e.g. ``["A", "B"]``), a list
            holding one separator-joined string (e.g. ``["A,B"]``), or a plain
            string (e.g. ``"A，B"``).

    Returns:
        A list of lowercase strings. For string input (including the
        single-element packed list form) every ASCII letter in the text
        becomes one label; for a regular list each string item is
        lowercased as a whole. Any other input type yields ``[]``.
    """
    # Kept as a function-scope import: the module's top-level import block
    # does not bring ``re`` into scope.
    import re

    if isinstance(conclusion, list):
        # A one-item list such as ["A,B"] is really a packed string. Accept
        # the full-width separators as well as the ASCII comma so that
        # ["A，B"] is unpacked the same way as ["A,B"].
        is_packed = (
            len(conclusion) == 1
            and isinstance(conclusion[0], str)
            and any(sep in conclusion[0] for sep in (",", "，", "～"))
        )
        if not is_packed:
            # Non-string items (e.g. JSON null) cannot be labels; skip them
            # instead of failing on ``.lower()``.
            return [item.lower() for item in conclusion if isinstance(item, str)]
        conclusion = conclusion[0]

    if isinstance(conclusion, str):
        # Separators need no explicit handling: only ASCII letters are kept.
        return [letter.lower() for letter in re.findall(r"[A-Za-z]", conclusion)]

    return []

def load_jsonl(path):
    """Read a UTF-8 JSON Lines file and return its decoded records as a list.

    Lines that are empty or contain only whitespace are skipped.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records

def group_by_id_and_template(data):
    """Index the records as ``{id: {template_name: record}}``.

    When several records share the same id and template name, the last one
    in *data* wins. The result is a ``defaultdict(dict)``.
    """
    by_id = defaultdict(dict)
    for record in data:
        per_template = by_id[record["id"]]
        per_template[record["template_name"]] = record
    return by_id

def clean_entry(entry):
    """Return a shallow copy of *entry* without its redundant fields.

    The dropped fields are id, template_name, content_type and input.
    """
    redundant_fields = ("id", "template_name", "content_type", "input")
    return {
        field: payload
        for field, payload in entry.items()
        if field not in redundant_fields
    }

def generate_output_path(input_path):
    """Derive the diff output path from the input path.

    Every ``data/eval/`` in the path is replaced with ``data/diff/`` and
    ``_diff`` is inserted in front of the file extension, e.g.
    ``data/eval/m/run.jsonl`` -> ``data/diff/m/run_diff.jsonl``.

    The suffix is added for any extension (or none), so the result always
    differs from *input_path*; otherwise a caller that writes to the returned
    path would overwrite its own input.

    Args:
        input_path: Path of the evaluation file.

    Returns:
        The path the diff file should be written to.
    """
    output_path = input_path.replace("data/eval/", "data/diff/")

    jsonl_ext = ".jsonl"
    if output_path.endswith(jsonl_ext):
        # Handled explicitly so a bare ".jsonl" file name keeps producing
        # "_diff.jsonl" (os.path.splitext treats it as a name with no
        # extension).
        stem, ext = output_path[: -len(jsonl_ext)], jsonl_ext
    else:
        stem, ext = os.path.splitext(output_path)

    return f"{stem}_diff{ext}"

def compare_and_collect(grouped):
    """Collect the samples on which a template and the baseline disagree.

    A prediction counts as "right" when it shares at least one label with the
    ground truth, and as "wrong" otherwise. The ground truth is taken from
    the baseline entry and normalized with the same routine as the
    predictions so both sides are compared in the same form.

    Args:
        grouped: Mapping ``{id: {template_name: entry}}``. Ids that have no
            ``"baseline"`` entry are skipped.

    Returns:
        A list with one dict per id that has at least one disagreement, with
        the keys ``id``, ``input``, ``baseline``,
        ``partial_correct_templates`` (templates that are right while the
        baseline is wrong) and ``partial_wrong_templates`` (templates that
        are wrong while the baseline is right). Both template keys are always
        present; for any one sample at most one of them is non-empty.
    """
    results = []
    for id_, templates in grouped.items():
        if "baseline" not in templates:
            continue

        baseline_entry = templates["baseline"]
        gt = set(normalize_conclusion(baseline_entry.get("gt", [])))
        baseline_pred = set(normalize_conclusion(baseline_entry.get("pred", [])))
        baseline_right = bool(baseline_pred & gt)

        partial_correct_templates = {}
        partial_wrong_templates = {}
        for tname, tentry in templates.items():
            if tname == "baseline":
                continue
            t_pred = set(normalize_conclusion(tentry.get("pred", [])))
            template_right = bool(t_pred & gt)
            if template_right == baseline_right:
                # Same outcome as the baseline: not a difference.
                continue
            if template_right:
                partial_correct_templates[tname] = clean_entry(tentry)
            else:
                partial_wrong_templates[tname] = clean_entry(tentry)

        # Emit the sample only when some template disagrees with the baseline.
        if partial_correct_templates or partial_wrong_templates:
            results.append(
                {
                    "id": id_,
                    "input": baseline_entry.get("input", ""),
                    "baseline": clean_entry(baseline_entry),
                    "partial_correct_templates": partial_correct_templates,
                    "partial_wrong_templates": partial_wrong_templates,
                }
            )

    return results

def main():
    """Script entry point: load an eval JSONL file and write the diff file.

    Paths come from the command line (``input.jsonl [output.jsonl]``) or,
    when no argument is given, from the defaults configured below. At most
    ``MAX_OUTPUT_RECORDS`` diff records are written, one compact JSON object
    per line, and a short summary is printed afterwards.
    """
    # ==========================================================================
    # Configuration - adjust every tunable parameter here
    # ==========================================================================

    # Maximum number of diff records written to the output file.
    MAX_OUTPUT_RECORDS = 50
    # Input path used when no command-line argument is given.
    DEFAULT_INPUT_PATH = "data/eval/jianfei/qwen2.5-7b-instruct/template_jianfei_text_None_seed-30_qwen2.5-7b-instruct_20250720_0154.jsonl"
    # Output path used together with DEFAULT_INPUT_PATH. None means "derive it
    # from the input path" (a _diff suffix under data/diff); set it to a
    # string to choose the location manually.
    DEFAULT_OUTPUT_PATH = None

    # ==========================================================================
    # Everything below runs automatically and normally needs no changes
    # ==========================================================================

    # Paths may be given on the command line or configured above.
    if len(sys.argv) >= 2:
        input_path = sys.argv[1]
        output_path = sys.argv[2] if len(sys.argv) >= 3 else generate_output_path(input_path)
        print(f"使用命令行参数: 输入文件={input_path}, 输出文件={output_path}")
    elif len(sys.argv) == 1:
        input_path = DEFAULT_INPUT_PATH
        output_path = DEFAULT_OUTPUT_PATH or generate_output_path(input_path)
        print(f"使用代码中设置的路径: 输入文件={input_path}, 输出文件={output_path}")
    else:
        # Only reachable when sys.argv is empty.
        print("用法: python compare_templates_with_baseline.py input.jsonl [output.jsonl]")
        print("或者直接在代码中修改DEFAULT_INPUT_PATH变量")
        return

    # The output is opened with "w"; refuse to run when that would truncate
    # the input file.
    if os.path.abspath(input_path) == os.path.abspath(output_path):
        print(f"错误: 输出文件与输入文件相同({output_path})，已取消以避免覆盖输入数据")
        return

    data = load_jsonl(input_path)
    grouped = group_by_id_and_template(data)
    diff_results = compare_and_collect(grouped)

    # Cap the number of records written.
    if len(diff_results) > MAX_OUTPUT_RECORDS:
        diff_results = diff_results[:MAX_OUTPUT_RECORDS]
        print(f"警告: 结果数量超过限制({MAX_OUTPUT_RECORDS})，已截取前{MAX_OUTPUT_RECORDS}条")

    # Create the output directory if needed.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for obj in diff_results:
            f.write(json.dumps(obj, ensure_ascii=False, separators=(",", ":")) + "\n")

    # Summary. Both keys exist on every record, so count the records whose
    # dict is non-empty rather than testing for key presence.
    correct_count = sum(1 for obj in diff_results if obj.get("partial_correct_templates"))
    wrong_count = sum(1 for obj in diff_results if obj.get("partial_wrong_templates"))

    print(f"已输出 {len(diff_results)} 条差异样本到 {output_path}")
    print(f"其中 {correct_count} 条包含template做对baseline做错的情况")
    print(f"其中 {wrong_count} 条包含template做错baseline做对的情况")

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()      
