import json
import ast  # 用于安全解析字符串形式的列表
from collections import defaultdict

def extract_data_source_from_path(path: str) -> str:
    if "SEED-Bench-2" in path:
        return "SeedBench2"
    elif "GQA" in path:
        return "GQA"
    elif "MME" in path:
        return "MME"
    elif "MMBENCH" in path:
        return "MMBench"
    elif "MMBench" in path:
        return "MMBench"
    elif "OK-VQA" in path:
        return "OK-VQA"
    elif "OKVQA" in path:
        return "OK-VQA"
    elif "SEEDBENCH" in path:
        return "SeedBench2"
    elif "TallyQA" in path:
        return "TallyQA"
    else:
        return "Unknown"

def process_json(json_path, output_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for item in data.get("data", []):
        origin_data = item.get("origin_data", {})
        image_str = origin_data.get("image", "")

        try:
            image_list = ast.literal_eval(image_str)  # 把字符串变成真正的 list
        except Exception as e:
            print(f"[WARNING] 无法解析 image 字段: {image_str}")
            continue

        if image_list:
            data_source = extract_data_source_from_path(image_list[0])
            origin_data["data_source"] = data_source  # 更新 data_source 字段

    with open(output_path, 'w', encoding='utf-8') as fout:
        json.dump(data, fout, indent=2, ensure_ascii=False)

    print(f"✅ 数据源识别完成，输出保存至: {output_path}")

def merge_json_files(input_files, output_file):
    merged_data = []
    # 遍历每个 JSON 文件
    for input_file in input_files:
        with open(input_file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)  # 读取 JSON 文件
                if "data" in data:
                    merged_data.extend(data["data"])  # 合并 data 列表
                else:
                    print(f"[WARNING] 文件 {input_file} 不包含 'data' 字段")
                    print(f"[WARNING] 文件 {input_file} 的内容格式: {type(data)}")
                    merged_data.extend(data)  # 尝试直接合并数据
            except json.decoder.JSONDecodeError as e:
                print(f"JSON 解码错误: {e} 文件: {input_file}")
                continue

    merged_data2 = []
    for item in merged_data:
        extend_data = item.get("extend_data_2", {})
        answer = extend_data.get("program_answer", "")
        # print(image_str)
        if answer:
            # 如果 answer 不为空，则添加到 merged_data2
            merged_data2.append(item)
    # 创建合并后的 JSON
    merged_json = {"data": merged_data2}
    print(f"🔍 读取数据量{len(merged_data)}，有效数据量: {len(merged_data2)}")
    # 写入合并后的结果
    with open(output_file, 'w', encoding='utf-8') as fout:
        json.dump(merged_json, fout, indent=2, ensure_ascii=False)

    print(f"✅ 合并完成，输出保存至: {output_file}")

def compare_and_merge_json(input_file_1, input_file_2, output_file):
    # 读取第一个 JSON 文件
    with open(input_file_1, 'r', encoding='utf-8') as f1:
        data1 = json.load(f1)

    # 读取第二个 JSON 文件
    with open(input_file_2, 'r', encoding='utf-8') as f2:
        data2 = json.load(f2)

    # 获取 data 数量
    data1_length = len(data1.get("data", []))
    data2_length = len(data2.get("data", []))

    # 比较两个文件的 data 数量
    if data1_length == data2_length:
        print(f"数据数量相同: {data1_length} 条")

        # 遍历两个文件的数据，检查 source_id 和 question 是否匹配
        for item1, item2 in zip(data1["data"], data2["data"]):
            origin_data_1 = item1.get("origin_data", {})
            origin_data_2 = item2.get("origin_data", {})

            # 检查 source_id 是否匹配
            if origin_data_1.get("source_id") != origin_data_2.get("source_id"):
                print(f"[ERROR] source_id 不匹配: {origin_data_2.get('source_id')} != {origin_data_1.get('source_id')}")
                return  # 如果 source_id 不匹配，停止执行

            # 检查 question 是否匹配
            if origin_data_2.get("question") != item1.get("extend_data_1", {}).get("question"):
                print(f"[ERROR] source_id 不匹配: {origin_data_2.get('source_id')} != {origin_data_1.get('source_id')}")
                print(f"[ERROR] question 不匹配: {origin_data_2.get('question')} != {item1.get('extend_data_1', {}).get('question')}")
                return  # 如果 question 不匹配，停止执行

            # 如果匹配，合并 extend_data_1 到 extend_data_2
            item1["extend_data_2"] = item2.get("extend_data_1", {})

        # 创建合并后的 JSON
        merged_json = {"data": data1["data"]}

        # 写入合并后的结果
        with open(output_file, 'w', encoding='utf-8') as fout:
            json.dump(merged_json, fout, indent=2, ensure_ascii=False)

        print(f"✅ 合并完成，输出保存至: {output_file}")
    else:
        print(f"数据数量不同：文件1数据数量={data1_length}，文件2数据数量={data2_length}")

# process_json(output_file, output_file)
def modify_duplicate_source_ids(input_files, output_file):
    merged_data = []
    seen_source_ids = set()  # 用来记录已出现过的 source_id
    source_id_count = defaultdict(int)  # 统计每个 source_id 出现的次数
    mode_num = 0  # 用来记录修改的次数

    # 遍历每个 JSON 文件
    for input_file in input_files:
        with open(input_file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)  # 读取 JSON 文件
                if "data" in data:
                    # 遍历文件中的每一条数据
                    for item in data["data"]:
                        origin_data = item.get("origin_data", {})
                        source_id = origin_data.get("source_id")

                        # 统计 source_id 出现的次数
                        source_id_count[source_id] += 1

                        # 确保每个 source_id 唯一
                        if source_id_count[source_id] > 1:
                            # 如果 source_id 已经存在，则进行编号修改
                            new_source_id = f"{source_id}_{source_id_count[source_id]}"  # 格式：source_id_2
                            # 检查是否有重复的修改后的 source_id
                            while new_source_id in seen_source_ids:
                                source_id_count[source_id] += 1  # 增加计数
                                new_source_id = f"{source_id}_{source_id_count[source_id]}"  # 生成新的修改后 ID
                            origin_data["source_id"] = new_source_id
                            seen_source_ids.add(new_source_id)
                            mode_num += 1
                            print(f"修改 source_id: {source_id} -> {new_source_id}")
                        else:
                            # 如果 source_id 是新的，则加入到已见集合中
                            seen_source_ids.add(source_id)

                        merged_data.append(item)  # 将数据添加到合并列表
                else:
                    print(f"[WARNING] 文件 {input_file} 不包含 'data' 字段")
            except json.decoder.JSONDecodeError as e:
                print(f"JSON 解码错误: {e} 文件: {input_file}")
                continue

    # 打印重复的 source_id 及其出现次数
    print("\n重复的 source_id 及其出现次数:")
    for source_id, count in source_id_count.items():
        if count > 1:
            print(f"{source_id} 出现了 {count} 次")

    print(f"\n修改了 {mode_num} 个重复的 source_id")

    # 创建合并后的 JSON
    merged_json = {"data": merged_data}

    # 写入合并后的结果
    with open(output_file, 'w', encoding='utf-8') as fout:
        json.dump(merged_json, fout, indent=2, ensure_ascii=False)

    print(f"✅ 修改完成，输出保存至: {output_file}")
def check_duplicate_source_ids(input_file):
    # 读取 JSON 文件
    print(f"🔍 检查重复的 source_id: {input_file}")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 使用 defaultdict 来统计每个 source_id 的出现次数
    source_id_count = defaultdict(int)
    if "data" in data:
        data = data["data"]
    # 遍历数据并统计 source_id
    print(len(data))
    for item in data:
        origin_data = item.get("origin_data", {})
        source_id = origin_data.get("source_id")
        if source_id:
            source_id_count[source_id] += 1

    # 查找重复的 source_id
    duplicates = {source_id: count for source_id, count in source_id_count.items() if count > 1}

    # 打印重复的 source_id 及其出现次数
    if duplicates:
        print("发现重复的 source_id:")
        for source_id, count in duplicates.items():
            print(f"source_id: {source_id} 出现了 {count} 次")
    else:
        print("没有发现重复的 source_id")


target_paths = ['results_1.json', 'results_2.json']  # 主数据
output_file = "merged_result.json"  # 输出的合并后的 JSON 文件路径
merge_json_files(target_paths, output_file)  # 合并 JSON 文件

input_file_1 = "results_1.json"  # 第一个 JSON 文件路径
input_file_2 = "results_2.json"  # 第二个 JSON 文件路径
output_file = "merged_result.json"  # 输出合并后的 JSON 文件
compare_and_merge_json(input_file_1, input_file_2, output_file)

check_duplicate_source_ids(output_file)
modify_duplicate_source_ids([output_file], output_file)
# 用法示例
check_duplicate_source_ids(output_file)




