import os
import json
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info # <--- 关键：需要 qwen_omni_utils.py 文件
from tqdm import tqdm
import traceback # <-- 引入 traceback 模块以获取详细的错误堆栈

# --- 1. Model and processor loading ---
# Hub id of the checkpoint; both the model and its processor are loaded from it.
model_id = "Qwen/Qwen2.5-Omni-7B"

# Keyword arguments for the model load, kept in one place so they are easy to tweak.
# NOTE(review): "flash_attention_2" presumably needs the flash-attn package in the
# environment — confirm before running on a new machine.
_model_load_kwargs = dict(
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)

print(f"正在加载模型: {model_id}...")
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(model_id, **_model_load_kwargs)
processor = Qwen2_5OmniProcessor.from_pretrained(model_id, trust_remote_code=True)
print("模型加载完成。")

# --- 2. Paths and configuration ---
# Root directory that holds one sub-folder per entry of folders_to_process.
base_data_dir = "."
# Sub-folder names; they double as keys of VQA_PROMPT_TEMPLATES / NATIVE_LANGUAGE_MAP.
# "Vitnamese" is spelled the same way in those dicts (presumably also on disk) —
# change all occurrences together if it is ever renamed.
folders_to_process = ["Indonesian", "Korea", "Mongolia", "Vitnamese", "Singapore"]
# Every answered JSON file (and error_log.json, if any) is written here.
output_dir = "Output_Json_Qwen_Omni_VQA"
os.makedirs(output_dir, exist_ok=True)  # reuse the folder if it already exists

# --- 3. VQA prompt templates, per data folder and per prompt language ---
# VQA_PROMPT_TEMPLATES[folder_name][language] is a str.format template with exactly two
# placeholders, {question} and {options}, which are filled in per item.
# Folder keys must match the names in folders_to_process (including the "Vitnamese"
# spelling). The "China" entry is only used if "China" is added to folders_to_process.
VQA_PROMPT_TEMPLATES = {
    # Every folder provides an "English" template plus its native-language template(s).
    "China": {
        "Chinese": "请基于图像回答以下与中国文化相关的问题：\n{question}\n{options}\n这是一个多选题，请先返回所有可能的选项字母，再用中文解释你的选择。",
        "English": "Based on the image, please answer the following question related to Chinese Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Indonesian": {
        "Indonesian": "Berdasarkan gambar, silakan jawab pertanyaan berikut terkait Budaya Indonesia.\n{question}\n{options}\nIni adalah pertanyaan pilihan ganda. Harap kembalikan semua kemungkinan huruf opsi terlebih dahulu, lalu jelaskan pilihan Anda dalam Bahasa Indonesia.",
        "English": "Based on the image, please answer the following question related to Indonesian Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Korea": {
        "Korean": "이미지를 바탕으로 다음 한국 문화와 관련된 질문에 답변해 주세요.\n{question}\n{options}\n이것은 객관식 문제입니다. 먼저 가능한 모든 옵션 문자를 반환한 다음, 한국어로 당신의 선택을 설명해 주세요.",
        "English": "Based on the image, please answer the following question related to Korean Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    "Mongolia": {
        "Mongolian": "Зурагт үндэслэн Монголын соёлтой холбоотой дараах асуултад хариулна уу.\n{question}\n{options}\nЭнэ бол олон сонголттой асуулт юм. Эхлээд боломжит бүх сонголтын үсгийг буцааж, дараа нь сонголтоо монгол хэлээр тайлбарлана уу.",
        "English": "Based on the image, please answer the following question related to Mongolian Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    },
    # NOTE(review): the Malay and Chinese Singapore prompts ask for the explanation in
    # English ("Bahasa Inggeris" / "英文"), whereas every other native-language prompt
    # asks for it in its own language — confirm this is intended.
    "Singapore": {
        "English": "Based on the image, please answer the following question related to Singaporean Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English.",
        "Malay": "Berdasarkan imej, sila jawab soalan berikut yang berkaitan dengan Budaya Singapura.\n{question}\n{options}\nIni adalah soalan pilihan berganda. Sila kembalikan semua huruf pilihan yang mungkin terlebih dahulu, kemudian jelaskan pilihan anda dalam Bahasa Inggeris.",
        "Chinese": "请基于图像回答以下与新加坡文化相关的问题。\n{question}\n{options}\n这是一个多选题，请先返回所有可能的选项字母，再用英文解释你的选择。",
    },
    "Vitnamese": {
        "Vietnamese": "Dựa vào hình ảnh, vui lòng trả lời câu hỏi sau đây liên quan đến Văn hóa Việt Nam.\n{question}\n{options}\nĐây là một câu hỏi trắc nghiệm. Vui lòng trả về tất cả các chữ cái tùy chọn có thể có trước, sau đó giải thích lựa chọn của bạn bằng tiếng Việt.",
        "English": "Based on the image, please answer the following question related to Vietnamese Culture.\n{question}\n{options}\nThis is a multiple-choice question. Please first return all possible option letters, then explain your choice in English."
    }
}
# Folder name -> native-language key(s) into VQA_PROMPT_TEMPLATES[folder_name].
# Several languages are comma-separated without spaces ("Malay,Chinese"). In that case
# the language whose name appears in the data file name is used. Files with "English"
# in their name use the English template instead, and folders missing from this map
# fall back to "English".
NATIVE_LANGUAGE_MAP = {
    "Indonesian": "Indonesian",
    "Korea": "Korean",
    "Mongolia": "Mongolian",
    "Singapore": "Malay,Chinese",
    "Vitnamese": "Vietnamese"
}

# --- Error records collected during the run (dumped to error_log.json at the end) ---
error_log = []


def _select_prompt_template(folder_name, filename):
    """Return the VQA prompt template for one data file, or None if none matches.

    Files whose name contains "English" use the folder's English template. All
    other files use the folder's native language(s) from NATIVE_LANGUAGE_MAP,
    falling back to "English" for folders missing from the map. When a folder
    lists several comma-separated languages, the first language whose name
    appears in the filename is chosen.

    An unknown folder or language yields None rather than raising KeyError, so
    the caller can skip just this file instead of aborting the whole run.
    """
    if "English" in filename:
        language = "English"
    else:
        language = NATIVE_LANGUAGE_MAP.get(folder_name, "English")

    folder_templates = VQA_PROMPT_TEMPLATES.get(folder_name, {})
    candidates = language.split(",")
    if len(candidates) == 1:
        return folder_templates.get(candidates[0])
    for candidate in candidates:
        if candidate in filename:
            return folder_templates.get(candidate)
    return None


def _format_options(item):
    """Render Option1..Option4 of an item as lines "A. ..." through "D. ...".

    Missing options are rendered as empty strings.
    """
    option_keys = ("Option1", "Option2", "Option3", "Option4")
    return "\n".join(
        f"{letter}. {str(item.get(key, '')).strip()}"
        for letter, key in zip("ABCD", option_keys)
    )


def _answer_item(item, image_dir, prompt_template):
    """Ask the model one image question and return only the newly generated text.

    Args:
        item: One record of the input JSON. Reads "Image_path" (relative to
            image_dir), "Question" and "Option1".."Option4".
        image_dir: Folder the item's "Image_path" is resolved against.
        prompt_template: str.format template with {question} and {options}.

    Returns:
        The decoded model answer, stripped of surrounding whitespace.

    Raises:
        TypeError: if item is not a JSON object (dict).
        ValueError: if "Image_path" or "Question" is missing or empty.
        FileNotFoundError: if the image file does not exist.
        Exceptions from the processor or the model propagate unchanged.
    """
    if not isinstance(item, dict):
        raise TypeError(f"Expected a JSON object for each item, got {type(item).__name__}.")

    image_path = item.get("Image_path")
    if not image_path:
        raise ValueError("Image_path is missing or empty.")
    full_image_path = os.path.join(image_dir, image_path)
    if not os.path.exists(full_image_path):
        raise FileNotFoundError(f"Image file not found at {full_image_path}")

    # `or ""` also covers an explicit null Question, which would crash on .strip().
    question = str(item.get("Question") or "").strip()
    if not question:
        raise ValueError("Question is empty.")

    text_prompt = prompt_template.format(question=question, options=_format_options(item))

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": full_image_path},
                {"type": "text", "text": text_prompt},
            ],
        },
    ]

    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    _audios, images, _videos = process_mm_info(conversation, use_audio_in_video=False)
    inputs = processor(text=text, images=images, return_tensors="pt").to(model.device).to(model.dtype)

    with torch.no_grad():
        # Only the text answer is stored; the audio output was previously generated
        # and discarded, so request text only.
        output = model.generate(**inputs, use_audio_in_video=False, return_audio=False)
    # Tolerate versions that still return a (text_ids, audio) tuple.
    text_ids = output[0] if isinstance(output, tuple) else output

    # The generated ids start with the prompt tokens; decode only the continuation
    # so the stored answer does not contain the chat-formatted prompt.
    prompt_length = inputs["input_ids"].shape[1]
    return processor.batch_decode(
        text_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()


# --- 4. Walk every folder / VQA file and answer each item ---
for folder_name in folders_to_process:
    current_folder_path = os.path.join(base_data_dir, folder_name)

    if not os.path.isdir(current_folder_path):
        print(f"⚠️  警告: 文件夹 '{current_folder_path}' 不存在，已跳过。")
        continue

    print(f"\n📁 开始处理文件夹: {current_folder_path}")

    # os.listdir order is arbitrary; sorting makes runs and logs reproducible.
    for filename in sorted(os.listdir(current_folder_path)):
        # Only image-based VQA files: skip non-JSON files and the text-only variants.
        if "Text_Only" in filename or not filename.endswith(".json"):
            continue

        input_path = os.path.join(current_folder_path, filename)
        print(f"  ➡️  正在处理VQA文件: {filename}")

        try:
            with open(input_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"    ❌ 读取文件失败: {input_path}, 错误: {e}")
            error_log.append({
                "file": input_path,
                "error_type": "FileReadError",
                "error_message": str(e)
            })
            continue

        # Each file is expected to be a JSON list of item objects.
        if not isinstance(data, list):
            message = f"Expected a JSON list of items, got {type(data).__name__}."
            print(f"    ❌ 文件格式错误: {input_path}, 错误: {message}")
            error_log.append({
                "file": input_path,
                "error_type": "InvalidFormatError",
                "error_message": message,
            })
            continue

        prompt_template = _select_prompt_template(folder_name, filename)
        if prompt_template is None:
            # Without a template every item would fail; skip the file and record why.
            message = f"No prompt template for folder '{folder_name}' and file '{filename}'."
            print(f"    ❌ 未找到匹配的提示词模板，已跳过: {input_path}")
            error_log.append({
                "file": input_path,
                "error_type": "PromptTemplateNotFound",
                "error_message": message,
            })
            continue

        # --- Answer items one by one; a failing item is logged and does not stop the file ---
        for index, item in enumerate(tqdm(data, desc=f"  Processing items in {filename}", leave=False)):
            try:
                item["qwen_omni_answer"] = _answer_item(item, current_folder_path, prompt_template)
            except Exception as e:
                error_type = type(e).__name__
                error_message = str(e)
                error_traceback = traceback.format_exc()  # full stack for the log file
                is_record = isinstance(item, dict)
                image_label = item.get("Image_path", "N/A") if is_record else "N/A"

                # Short summary on the console.
                print(f"\n    ❌ 错误! 文件: {filename}, 数据索引: {index}, 图片: {image_label}")
                print(f"    错误类型: {error_type}")
                print(f"    错误信息: {error_message}")

                error_log.append({
                    "file": input_path,
                    "item_index": index,
                    # Snapshot taken before the error marker is written into the item below.
                    "problem_item": dict(item) if is_record else item,
                    "error_type": error_type,
                    "error_message": error_message,
                    "traceback": error_traceback,
                })

                # Also mark the failure in the output data itself.
                if is_record:
                    item["qwen_omni_answer"] = f"Error: {error_type} - {error_message}"

        # --- Save the updated data ---
        base_filename = os.path.splitext(filename)[0]
        output_filename = f"{base_filename}_qwen_omni_7B_answered.json"
        output_path = os.path.join(output_dir, output_filename)

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
        except OSError as e:
            # Don't abort the run (and lose error_log) because one output file failed.
            print(f"    ❌ 保存结果失败: {output_path}, 错误: {e}")
            error_log.append({
                "file": output_path,
                "error_type": "FileWriteError",
                "error_message": str(e),
            })
        else:
            print(f"    ✅ VQA处理完成，结果已保存至: {output_path}")

        # Release cached CUDA memory between files.
        torch.cuda.empty_cache()


# --- After the run: report the outcome and persist error_log when it is non-empty ---
if not error_log:
    print("\n✅ 处理过程中未发生致命错误。")
else:
    # The path is announced first, then the collected records are dumped as UTF-8 JSON.
    error_log_path = os.path.join(output_dir, "error_log.json")
    print(f"\n⚠️  处理过程中遇到错误，详细信息已保存至: {error_log_path}")
    with open(error_log_path, "w", encoding="utf-8") as log_fh:
        json.dump(error_log, log_fh, indent=2, ensure_ascii=False)

print(f"\n🎉 所有文件夹的Qwen-Omni VQA任务处理完毕！所有输出文件已保存到 '{output_dir}' 文件夹中。")
