import os
import json
import shutil
from tqdm import tqdm

# Benchmark JSON files to read; the "img_name" of every entry is collected.
files = [
    "/root/project/benchmark_data/Visual_Misinterpretation_Hallucination/close-ended/mimic_cxr_close_pairs.json",
    "/root/project/benchmark_data/Visual_Misinterpretation_Hallucination/open-ended/mimic_cxr_open_pairs.json",
    "/root/project/benchmark_data/Knowledge_Deficiency_Hallucination/close-ended/mimic_cxr_close_pairs.json",
    "/root/project/benchmark_data/Context_Misalignment_Hallucination/MIMIC-CXR_pairs.json"
]
# Root directories used when joining image paths further down the script.
src = "/root/project/datasets/mimic_cxr_jpg/files"
dst = "/root/project/datasets/mimic_cxr_jpg_used/files"

# Distinct "img_name" values found across all files listed above.
images = set()
for json_path in files:
    with open(json_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)
    # .get() yields None for entries without "img_name"; it is added as-is.
    images.update(record.get("img_name") for record in records)

# for image in images:
#     # Copy the report
#     src_img = os.path.join(src, image)
#     dst_img = os.path.join(dst, image)
#     src_report = "/".join(src_img.split("/")[:-1])+".txt"
#     dst_report = "/".join(dst_img.split("/")[:-1])+".txt"
#     target_dir = os.path.dirname(dst_report)
#     if not os.path.exists(target_dir):
#         os.makedirs(target_dir)
#     shutil.copy(src_report, dst_report)
    
#     # Copy the image
#     # src_img = os.path.join(src, image)
#     # dst_img = os.path.join(dst, image)
#     # target_dir = os.path.dirname(dst_img)
#     # if not os.path.exists(target_dir):
#     #     os.makedirs(target_dir)
#     # shutil.copy(src_img, dst_img)

# Second set of JSON files. Each entry is expected to carry its image path(s)
# under either "image_path" (a collection of paths) or "image" (one path).
files2 = [
    "/root/project/benchmark_data/1.json",
    "/root/project/benchmark_data/2.json",
    "/root/project/benchmark_data/3.json",
    "/root/project/benchmark_data/4.json",
]
# Destination root for the images collected from files2.
dst2 = "/root/project/datasets/mimic_cxr_jpg_used2/files"
# Distinct image paths from files2 that do not contain the substring "CXR".
images2 = set()
for input_file in files2:
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    for item in data:
        if "image_path" in item:
            img_paths = item["image_path"]
            # A bare string would otherwise be iterated character by
            # character, adding single characters as "paths"; treat it as
            # one path instead.
            if isinstance(img_paths, str):
                img_paths = [img_paths]
        elif "image" in item:
            img_paths = [item["image"]]
        else:
            # Entry has neither key: report it and move on to the next one.
            print("Error!")
            continue
        # Paths containing "CXR" are skipped.
        images2.update(path for path in img_paths if "CXR" not in path)

# Copy every path in images2 from under `src` to the same relative location
# under `dst2`, showing progress with tqdm.
for image in tqdm(images2):
    # # Copy the report
    # src_img = os.path.join(src, image)
    # dst_img = os.path.join(dst, image)
    # src_report = "/".join(src_img.split("/")[:-1])+".txt"
    # dst_report = "/".join(dst_img.split("/")[:-1])+".txt"
    # target_dir = os.path.dirname(dst_report)
    # if not os.path.exists(target_dir):
    #     os.makedirs(target_dir)
    # shutil.copy(src_report, dst_report)

    # Copy the image
    if image in images:
        # Informational only: the image also appears in the first set
        # (`images`). It is still copied below.
        # NOTE(review): confirm whether such images were meant to be skipped.
        print("already exist")
    src_img = os.path.join(src, image)
    dst_img = os.path.join(dst2, image)
    target_dir = os.path.dirname(dst_img)
    # exist_ok=True creates missing parent directories and tolerates an
    # existing one, without a separate check-then-create step that can race.
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(src_img, dst_img)