# import torch
# import torchvision.transforms as T
# from PIL import Image
# from torchvision.transforms.functional import InterpolationMode
# from transformers import AutoModel, AutoTokenizer

# # --- 图像预处理部分 ---

# # ImageNet 默认的均值和标准差
# IMAGENET_MEAN = (0.485, 0.456, 0.406)
# IMAGENET_STD = (0.229, 0.224, 0.225)

# def build_transform(input_size):
#     """构建图像预处理流程"""
#     transform = T.Compose([
#         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
#         T.ToTensor(),
#         T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
#     ])
#     return transform

# def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
#     """找到最接近的目标长宽比"""
#     best_ratio_diff = float('inf')
#     best_ratio = (1, 1)
#     area = width * height
#     for ratio in target_ratios:
#         target_aspect_ratio = ratio[0] / ratio[1]
#         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
#         if ratio_diff < best_ratio_diff:
#             best_ratio_diff = ratio_diff
#             best_ratio = ratio
#         elif ratio_diff == best_ratio_diff:
#             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
#                 best_ratio = ratio
#     return best_ratio

# def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
#     """
#     动态预处理，将高分辨率图像切分成多个小块以适应模型。
#     这是模型处理高清图的自带能力，并非模型切分到多卡。
#     """
#     orig_width, orig_height = image.size
#     aspect_ratio = orig_width / orig_height

#     target_ratios = set(
#         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
#         i * j <= max_num and i * j >= min_num)
#     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

#     target_aspect_ratio = find_closest_aspect_ratio(
#         aspect_ratio, target_ratios, orig_width, orig_height, image_size)

#     target_width = image_size * target_aspect_ratio[0]
#     target_height = image_size * target_aspect_ratio[1]
#     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

#     resized_img = image.resize((target_width, target_height))
#     processed_images = []
#     for i in range(blocks):
#         box = (
#             (i % (target_width // image_size)) * image_size,
#             (i // (target_width // image_size)) * image_size,
#             ((i % (target_width // image_size)) + 1) * image_size,
#             ((i // (target_width // image_size)) + 1) * image_size
#         )
#         split_img = resized_img.crop(box)
#         processed_images.append(split_img)
#     assert len(processed_images) == blocks
#     if use_thumbnail and len(processed_images) != 1:
#         thumbnail_img = image.resize((image_size, image_size))
#         processed_images.append(thumbnail_img)
#     return processed_images

# def load_image(image_file, input_size=448, max_num=12):
#     """加载并处理单张图片，返回模型所需的 pixel_values"""
#     image = Image.open(image_file).convert('RGB')
#     transform = build_transform(input_size=input_size)
#     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
#     pixel_values = [transform(image) for image in images]
#     pixel_values = torch.stack(pixel_values)
#     return pixel_values

# # --- 主程序：单轮 VQA ---

# # 1. 设置模型路径和设备
# path = 'OpenGVLab/InternVL-Chat-V1-5'
# # 如果没有CUDA设备，可以改为 'cpu'，但推理速度会非常慢
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# # 2. 加载模型和 Tokenizer
# # 该模型需要在单张GPU上运行，请确保有足够的显存（例如 40GB+）
# model = AutoModel.from_pretrained(
#     path,
#     torch_dtype=torch.bfloat16,
#     low_cpu_mem_usage=True,
#     trust_remote_code=True).eval().to(device)
# tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# # 3. 准备图像和问题
# # 请将 './examples/image1.jpg' 替换为你的图片路径
# image_path = './examples/image1.jpg'
# # 你的问题，必须包含 <image> 占位符
# question = '<image>\nPlease describe the image shortly.'

# # 4. 加载并预处理图像
# pixel_values = load_image(image_path, max_num=12).to(torch.bfloat16).to(device)

# # 5. 设置生成参数
# generation_config = dict(
#     max_new_tokens=1024,
#     do_sample=False, # 设置为False以获得更具确定性的回答
#     num_beams=3,
# )

# # 6. 执行单轮VQA对话
# # 对于单轮对话，我们不需要 `history` 参数
# response = model.chat(
#     tokenizer=tokenizer,
#     pixel_values=pixel_values,
#     question=question,
#     generation_config=generation_config
# )

# # 7. 打印结果
# print(f"User: {question}\n")
# print(f"Assistant: {response}")

import os
import json
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

# --- 1. InternVL-specific image preprocessing (copied from the reference example) ---
# Per-channel ImageNet statistics used by T.Normalize in build_transform.
# Channel order is R, G, B because build_transform converts images to RGB first.
IMAGENET_MEAN = (
    0.485,  # R
    0.456,  # G
    0.406,  # B
)
IMAGENET_STD = (
    0.229,  # R
    0.224,  # G
    0.225,  # B
)

def build_transform(input_size):
    """Return the torchvision pipeline that turns a PIL image into a model input.

    Steps, in order: force RGB mode, bicubic-resize to ``input_size`` x
    ``input_size``, convert to a tensor, then normalize with the ImageNet
    mean/std defined at module level.
    """
    def _ensure_rgb(img):
        # Images already in RGB mode pass through unchanged.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the tiling grid whose shape best matches the image's aspect ratio.

    Args:
        aspect_ratio: image width divided by image height.
        target_ratios: candidate (cols, rows) grids, scanned in the given order.
        width: original image width in pixels.
        height: original image height in pixels.
        image_size: side length of one square tile in pixels.

    Returns:
        The candidate minimising |aspect_ratio - cols / rows|. On an exact tie a
        later candidate replaces the current pick only when the image area is
        more than half of that candidate's grid area in pixels. Returns (1, 1)
        when ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for ratio in target_ratios:
        cols, rows = ratio[0], ratio[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = ratio
            continue
        # Exact tie: prefer this candidate only if the image is large enough
        # to fill more than half of its grid.
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            chosen = ratio
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Cut an image into a grid of square ``image_size`` tiles.

    Among all (cols, rows) grids whose tile count lies in [min_num, max_num],
    the one returned by ``find_closest_aspect_ratio`` is used. The image is
    resized to exactly cols*image_size x rows*image_size and cut into tiles row
    by row, left to right. If ``use_thumbnail`` is true and more than one tile
    was produced, a copy of the whole image resized to image_size x image_size
    is appended last.

    Returns:
        list of PIL images (the tiles, optionally followed by the thumbnail).
    """
    width, height = image.size
    image_aspect = width / height

    # Candidate grids ordered by tile count, so smaller grids are scanned first.
    candidate_grids = sorted(
        {
            (cols, rows)
            for n in range(min_num, max_num + 1)
            for cols in range(1, n + 1)
            for rows in range(1, n + 1)
            if min_num <= cols * rows <= max_num
        },
        key=lambda grid: grid[0] * grid[1],
    )
    cols, rows = find_closest_aspect_ratio(
        image_aspect, candidate_grids, width, height, image_size)

    grid_width = image_size * cols
    grid_height = image_size * rows
    tile_count = cols * rows
    canvas = image.resize((grid_width, grid_height))

    tiles_per_row = grid_width // image_size
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image_for_internvl(image_file, input_size=448, max_num=12):
    """Load one image file and return the stacked, normalized tiles for InternVL.

    Args:
        image_file: path passed to ``PIL.Image.open``.
        input_size: side length in pixels of each square tile.
        max_num: maximum number of grid tiles, forwarded to ``dynamic_preprocess``.

    Returns:
        A tensor built with ``torch.stack`` over the per-tile tensors, so its
        first dimension indexes the tiles (plus a trailing thumbnail when more
        than one grid tile was produced).
    """
    # Open inside a context manager so the file handle is released
    # deterministically (Pillow can keep it open, e.g. for multi-frame images);
    # convert() returns a fully loaded copy that outlives the handle.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

# --- 2. Load the model and tokenizer ---
model_id = 'OpenGVLab/InternVL-Chat-V1-5'
# Use the GPU automatically when one is visible; CPU works but is very slow.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"正在加载模型: {model_id}...")
# The whole model is moved onto one device with .to(device) rather than being
# sharded with device_map='auto'; the script assumes InternVL runs on a single GPU.
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,  # allow the checkpoint's custom modelling code
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
model = model.eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print("模型加载完成。")

# Generation settings passed to model.chat(): beam search with 3 beams and no
# sampling (for more deterministic answers), capped at 1024 new tokens.
generation_config = {
    "num_beams": 3,
    "max_new_tokens": 1024,
    "do_sample": False,
}

# --- 3. Paths and run configuration (kept in line with the Qwen script) ---
base_data_dir = "."
# Each entry is both a sub-directory of base_data_dir and a key of
# VQA_PROMPT_TEMPLATES / NATIVE_LANGUAGE_MAP, so the spelling (including
# "Vitnamese") must match the folder names on disk.
folders_to_process = [
    "Indonesian",
    "Korea",
    "Mongolia",
    "Vitnamese",
    "Singapore",
    "China",
]
output_dir = "Output_Json_InternVL_VQA"
# Create the output directory up front; an existing directory is reused.
os.makedirs(output_dir, exist_ok=True)

# --- 4. VQA prompt templates, keyed by folder name and then by prompt language ---
# Every template contains a {question} and an {options} placeholder that the
# processing loop fills with str.format. The InternVL "<image>" tag is NOT part
# of the template; the loop prepends it to each formatted prompt.
# NOTE(review): unlike the other folders, the Singapore "Malay" and "Chinese"
# templates ask the model to explain its choice in English — confirm intended.
VQA_PROMPT_TEMPLATES = {
    # One entry per folder in folders_to_process; inner keys must match the
    # language names produced from NATIVE_LANGUAGE_MAP, plus "English".
    "China": {
        "Chinese": "请基于图像回答以下与中国文化相关的问题：\n{question}\n{options}\n这是一个多选题，请先返回所有可能的选项字母，再用中文解释你的选择。",
        "English": "Based on the image, please answer the following question related to Chinese Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Indonesian": {
        "Indonesian": "Berdasarkan gambar, silakan jawab pertanyaan berikut terkait Budaya Indonesia.\n{question}\n{options}\nIni adalah pertanyaan pilihan ganda. Harap kembalikan semua kemungkinan huruf opsi terlebih dahulu, lalu jelaskan pilihan Anda dalam Bahasa Indonesia.",
        "English": "Based on the image, please answer the following question related to Indonesian Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Korea": {
        "Korean": "이미지를 바탕으로 다음 한국 문화와 관련된 질문에 답변해 주세요.\n{question}\n{options}\n이것은 객관식 문제입니다. 먼저 가능한 모든 옵션 문자를 반환한 다음, 한국어로 당신의 선택을 설명해 주세요.",
        "English": "Based on the image, please answer the following question related to Korean Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Mongolia": {
        "Mongolian": "Зурагт үндэслэн Монголын соёлтой холбоотой дараах асуултад хариулна уу.\n{question}\n{options}\nЭнэ бол олон сонголттой асуулт юм. Эхлээд боломжит бүх сонголтын үсгийг буцааж, дараа нь сонголтоо монгол хэлээр тайлбарлана уу.",
        "English": "Based on the image, please answer the following question related to Mongolian Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    # Singapore has two native-language templates; the loop picks one by
    # matching the language name in the input filename.
    "Singapore": {
        "English": "Based on the image, please answer the following question related to Singaporean Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English.",
        "Malay": "Berdasarkan imej, sila jawab soalan berikut yang berkaitan dengan Budaya Singapura.\n{question}\n{options}\nIni adalah soalan pilihan berganda. Sila kembalikan semua huruf pilihan yang mungkin terlebih dahulu, kemudian jelaskan pilihan anda dalam Bahasa Inggeris.",
        "Chinese": "请基于图像回答以下与新加坡文化相关的问题。\n{question}\n{options}\n这是一个多选题，请先返回所有可能的选项字母，再用英文解释你的选择。",
    },
    # Key spelling ("Vitnamese") must match the entry in folders_to_process.
    "Vitnamese": {
        "Vietnamese": "Dựa vào hình ảnh, vui lòng trả lời câu hỏi sau đây liên quan đến Văn hóa Việt Nam.\n{question}\n{options}\nĐây là một câu hỏi trắc nghiệm. Vui lòng trả về tất cả các chữ cái tùy chọn có thể có trước, sau đó giải thích lựa chọn của bạn bằng tiếng Việt.",
        "English": "Based on the image, please answer the following question related to Vietnamese Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    }
}
# Folder name -> native prompt language(s). Several languages are written as a
# single comma-separated string; the processing loop splits on "," and picks
# the one named in the input filename.
NATIVE_LANGUAGE_MAP = dict(
    Indonesian="Indonesian",
    Korea="Korean",
    Mongolia="Mongolian",
    Singapore="Malay,Chinese",
    Vitnamese="Vietnamese",
    China="Chinese",
)


# --- 5. Walk the dataset folders and answer every image-based VQA file ---

# Field added to each item: the model's answer, or an "Error: ..." message.
_ANSWER_FIELD = "internvl_1_5_answer"


def _select_prompt_template(folder_name, filename):
    """Return the prompt template to use for ``filename`` inside ``folder_name``.

    Filenames containing "English" get the English template. Otherwise the
    folder's native language from NATIVE_LANGUAGE_MAP is used (English if the
    folder is not listed). When that entry lists several comma-separated
    languages, the first one that appears in the filename wins, falling back to
    the first listed language.

    Raises:
        KeyError: if the folder or the selected language has no template.
    """
    if "English" in filename:
        language = "English"
    else:
        language = NATIVE_LANGUAGE_MAP.get(folder_name, "English")
    candidates = language.split(",")
    templates = VQA_PROMPT_TEMPLATES[folder_name]
    if len(candidates) > 1:
        for candidate in candidates:
            if candidate in filename:
                return templates[candidate]
    return templates[candidates[0]]


def _build_question(item, prompt_template):
    """Fill ``prompt_template`` with the item's question and its A-D options.

    Missing fields become empty strings, and the question and every option are
    coerced with str() before stripping, so non-string values do not fail.
    The result starts with the "<image>" placeholder InternVL's chat() expects.
    """
    question_text = str(item.get("Question", "")).strip()
    option_lines = []
    for number, letter in enumerate("ABCD", start=1):
        option_text = str(item.get(f"Option{number}", "")).strip()
        option_lines.append(f"{letter}. {option_text}")
    formatted_prompt = prompt_template.format(
        question=question_text,
        options="\n".join(option_lines),
    )
    return f"<image>\n{formatted_prompt}"


def _answer_item(item, folder_path, prompt_template):
    """Run InternVL on one VQA item and return the stripped answer text.

    Returns an "Error: ..." string instead when the item has no usable image
    path. Exceptions from image loading or generation propagate to the caller.
    """
    image_rel_path = item.get("Image_path")
    if not image_rel_path:
        return "Error: Image_path is missing or empty."

    full_image_path = os.path.join(folder_path, image_rel_path)
    if not os.path.exists(full_image_path):
        return f"Error: Image file not found at {full_image_path}"

    # Cast tiles to bfloat16 to match the dtype the model was loaded with.
    pixel_values = load_image_for_internvl(full_image_path, max_num=12).to(torch.bfloat16).to(device)
    question = _build_question(item, prompt_template)

    with torch.no_grad():
        response = model.chat(
            tokenizer=tokenizer,
            pixel_values=pixel_values,
            question=question,
            generation_config=generation_config,
        )
    # chat() returns the decoded answer as a plain string.
    return response.strip()


def _process_vqa_file(folder_name, folder_path, filename):
    """Answer every item of one VQA JSON file and write the result to output_dir.

    An unreadable or non-list file is reported and skipped; a failed write is
    reported without stopping the run. Per-item failures are recorded on the
    item itself.
    """
    input_path = os.path.join(folder_path, filename)
    print(f"  ➡️  正在处理VQA文件: {filename}")

    try:
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"    ❌ 读取文件失败: {input_path}, 错误: {e}")
        return
    if not isinstance(data, list):
        # Items are annotated in place, which requires a list of objects.
        print(f"    ❌ 文件内容不是列表，已跳过: {input_path}")
        return

    prompt_template = _select_prompt_template(folder_name, filename)

    for item in tqdm(data, desc=f"  Processing items in {filename}", leave=False):
        if not isinstance(item, dict):
            # Nothing to attach an answer to; leave the entry as it is.
            continue
        try:
            item[_ANSWER_FIELD] = _answer_item(item, folder_path, prompt_template)
        except Exception as e:
            # Best effort: record the failure on the item and move on so one
            # bad sample does not abort the whole file.
            item[_ANSWER_FIELD] = f"Error: {str(e)}"

    base_filename = os.path.splitext(filename)[0]
    output_path = os.path.join(output_dir, f"{base_filename}_internvl_1.5_answered.json")
    # NOTE(review): the output name depends only on the input filename, so
    # equally named files in different folders overwrite each other — confirm
    # filenames are unique across folders.
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    except OSError as e:
        print(f"    ❌ 保存文件失败: {output_path}, 错误: {e}")
        return

    print(f"    ✅ VQA处理完成，结果已保存至: {output_path}")


for folder_name in folders_to_process:
    current_folder_path = os.path.join(base_data_dir, folder_name)

    if not os.path.isdir(current_folder_path):
        print(f"⚠️  警告: 文件夹 '{current_folder_path}' 不存在，已跳过。")
        continue

    print(f"\n📁 开始处理文件夹: {current_folder_path}")

    # Sorted so every run processes files in the same order (os.listdir order
    # is arbitrary).
    for filename in sorted(os.listdir(current_folder_path)):
        # Only JSON files not marked "Text_Only" (presumably handled elsewhere).
        if "Text_Only" in filename or not filename.endswith(".json"):
            continue
        _process_vqa_file(folder_name, current_folder_path, filename)

print(f"\n🎉 所有文件夹的VQA任务处理完毕！所有输出文件已保存到 '{output_dir}' 文件夹中。")