import os
import json
import torch
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

# --- 1. Load the model and tokenizer ---
# Hugging Face model ID. It must be an InternVL Chat checkpoint whose custom
# code provides the .chat() method used later in this script.
model_id = "OpenGVLab/InternVL-Chat-V1-5"

print(f"正在加载模型: {model_id}...")
# device_map="auto" leaves device placement to the loader, which keeps
# multi-GPU loading simple.
model = AutoModel.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,  # required so the checkpoint's custom .chat() is loaded
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
# Switch to evaluation mode; this script only runs inference.
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
print("模型加载完成。")

# --- 2. Paths and configuration ---
# Directory that contains the per-country input folders.
base_data_dir = "."

# Input folders to scan, in processing order. The names must match the
# directories on disk exactly.
# NOTE(review): "Vitnamese" looks like a misspelling of "Vietnamese" -- confirm
# against the real folder name before changing it.
folders_to_process = [
    "Indonesian",
    "Korea",
    "Mongolia",
    "Vitnamese",
    "Singapore",
]

# Every result file is written into this directory; create it if missing.
output_dir = "Output_Json_InternVL"
os.makedirs(output_dir, exist_ok=True)

# --- 3. Plain English prompt template ---
# Placeholders filled via str.format: {country}, {question}, {options}.
PROMPT_TEMPLATE = (
    "Please answer the following question related to {country} Culture.\n"
    "{question}\n"
    "{options}\n"
    "This is a multiple-choice question. "
    "Please first return all possible option letters, "
    "then explain your choice in English."
)

# --- 4. Walk the folders and answer every matching question file ---

# Decoding settings shared by every request. Greedy decoding
# (do_sample=False) gives deterministic answers. `temperature` is deliberately
# omitted: it is a sampling-only setting, and passing it together with
# do_sample=False only triggers a transformers warning on each call.
generation_config = dict(
    max_new_tokens=1024,
    do_sample=False,
)


def _build_prompt(item, country):
    """Render the multiple-choice prompt for one question record.

    Args:
        item: Dict holding "Rephrased_Question" and "Option1".."Option4".
            Missing keys fall back to empty strings.
        country: Value substituted for the {country} placeholder.

    Returns:
        The fully formatted prompt text.

    Raises:
        AttributeError: If "Rephrased_Question" is present but not a string
            (e.g. JSON null); the caller records this as a per-item error.
    """
    question = item.get("Rephrased_Question", "").strip()
    option_keys = ("Option1", "Option2", "Option3", "Option4")
    options = [
        f"{letter}. {str(item.get(key, '')).strip()}"
        for letter, key in zip("ABCD", option_keys)
    ]
    return PROMPT_TEMPLATE.format(
        country=country,
        question=question,
        options="\n".join(options),
    )


# Output paths already written during this run. All folders share one output
# directory, so identically named input files would otherwise overwrite each
# other silently.
written_outputs = set()

for folder_name in folders_to_process:
    current_folder_path = os.path.join(base_data_dir, folder_name)

    if not os.path.isdir(current_folder_path):
        print(f"⚠️  警告: 文件夹 '{current_folder_path}' 不存在，已跳过。")
        continue

    print(f"\n📁 开始处理文件夹: {current_folder_path}")

    # sorted() makes the processing order reproducible; os.listdir() returns
    # entries in arbitrary order.
    for filename in sorted(os.listdir(current_folder_path)):
        # Only English, non-"Text_Only" JSON files are processed.
        if (
            "Text_Only" in filename
            or "English" not in filename
            or not filename.endswith(".json")
        ):
            continue

        input_path = os.path.join(current_folder_path, filename)

        print(f"  ➡️  正在处理文件: {filename}")

        try:
            with open(input_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and UTF-8 decode errors.
            print(f"    ❌ 读取文件失败: {input_path}, 错误: {e}")
            continue

        if not isinstance(data, list):
            # The loop below expects a list of question dicts; any other root
            # type cannot receive per-item answers.
            print(f"    ❌ 文件内容不是列表，已跳过: {input_path}")
            continue

        failed_items = 0
        for item in tqdm(data, desc=f"  Processing items in {filename}", leave=False):
            if not isinstance(item, dict):
                # A non-dict entry cannot store an answer; keep it unchanged
                # and count it as failed instead of crashing the whole run.
                failed_items += 1
                continue

            try:
                # NOTE(review): the folder name is used verbatim as the country
                # in the prompt (e.g. "Vitnamese Culture") -- confirm this is
                # the intended wording.
                text_prompt = _build_prompt(item, folder_name)

                # Use the checkpoint's .chat() helper rather than .generate();
                # it returns the already-decoded answer text.
                response, _ = model.chat(
                    tokenizer=tokenizer,
                    pixel_values=None,  # text-only task, no image input
                    question=text_prompt,
                    # Pass a fresh copy in case the remote chat() code mutates
                    # the dict (NOTE(review): not verifiable from this file).
                    generation_config=dict(generation_config),
                    history=None,  # every question is independent
                    return_history=True,
                )

                # Store the answer on the original record.
                item["internvl_answer"] = response.strip()

            except Exception as e:
                # Best effort: record the failure on the item and keep going so
                # one bad record does not abort a long batch.
                failed_items += 1
                item["internvl_answer"] = f"Error: {str(e)}"

        # --- Save the updated records ---
        base_filename = os.path.splitext(filename)[0]
        output_filename = f"{base_filename}_internvl_answered.json"
        output_path = os.path.join(output_dir, output_filename)

        if output_path in written_outputs:
            # Same file name already produced by another folder in this run:
            # prefix the folder name instead of overwriting earlier results.
            output_filename = f"{folder_name}_{output_filename}"
            output_path = os.path.join(output_dir, output_filename)
            print(f"    ⚠️  警告: 输出文件名冲突，已改存为: {output_path}")
        written_outputs.add(output_path)

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        if failed_items:
            print(f"    ⚠️  警告: 有 {failed_items} 条数据处理失败。")
        print(f"    ✅ 处理完成，结果已保存至: {output_path}")

print(f"\n🎉 所有文件夹处理完毕！所有输出文件已保存到 '{output_dir}' 文件夹中。")