import os
import json
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
# process_mm_info is still imported: the inference loop below calls it for every batch.
from qwen_omni_utils import process_mm_info
from tqdm import tqdm

# --- 1. Load the model and its processor ---
# The Qwen-Omni checkpoint is used here for text-only question answering.
model_id = "Qwen/Qwen2.5-Omni-7B" 

print(f"正在加载模型: {model_id}...")
# flash_attention_2 is requested for faster inference and lower memory use.
_model_load_options = dict(
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    model_id, **_model_load_options
)
# model.disable_talker() is intentionally left out; the original author notes
# it is usually not required for text-only tasks.

processor = Qwen2_5OmniProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)
print("模型加载完成。")

# --- 2. Paths and run configuration ---
# Number of QA items sent to the model in one generate() call.
BATCH_SIZE = 8

# Root directory that contains one sub-folder per entry in folders_to_process.
base_data_dir = "." 
folders_to_process = [
    "Indonesian",
    "Korea",
    "Mongolia",
    "Vitnamese",
    "Singapore",
]

# Dedicated output directory so the text-QA results stay apart from other runs.
output_dir = "Output_Json_Qwen_Omni_Text_QA"
os.makedirs(output_dir, exist_ok=True)

# --- 3. Text-only QA prompt templates ---
# Keyed first by data folder name, then by prompt language. Every template
# carries a {question} and an {options} placeholder and makes no reference to
# an image.

# The English wording is the same for every culture except for the culture name.
_ENGLISH_TEMPLATE = (
    "Please answer the following question related to {culture} Culture.\n"
    "{{question}}\n{{options}}\n"
    "This is a multiple-choice question. Please first return all possible option letters, "
    "then explain your choice in English."
)

TEXT_QA_PROMPT_TEMPLATES = {
    "China": {
        "Chinese": (
            "请回答以下与中国文化相关的问题：\n"
            "{question}\n{options}\n"
            "这是一个多选题，请先返回所有可能的选项字母，再用中文解释你的选择。"
        ),
        "English": _ENGLISH_TEMPLATE.format(culture="Chinese"),
    },
    "Indonesian": {
        "Indonesian": (
            "Silakan jawab pertanyaan berikut terkait Budaya Indonesia.\n"
            "{question}\n{options}\n"
            "Ini adalah pertanyaan pilihan ganda. Harap kembalikan semua kemungkinan huruf opsi terlebih dahulu, lalu jelaskan pilihan Anda dalam Bahasa Indonesia."
        ),
        "English": _ENGLISH_TEMPLATE.format(culture="Indonesian"),
    },
    "Korea": {
        "Korean": (
            "다음 한국 문화와 관련된 질문에 답변해 주세요.\n"
            "{question}\n{options}\n"
            "이것은 객관식 문제입니다. 먼저 가능한 모든 옵션 문자를 반환한 다음, 한국어로 당신의 선택을 설명해 주세요."
        ),
        "English": _ENGLISH_TEMPLATE.format(culture="Korean"),
    },
    "Mongolia": {
        "Mongolian": (
            "Монголын соёлтой холбоотой дараах асуултад хариулна уу.\n"
            "{question}\n{options}\n"
            "Энэ бол олон сонголттой асуулт юм. Эхлээд боломжит бүх сонголтын үсгийг буцааж, дараа нь сонголтоо монгол хэлээр тайлбарлана уу."
        ),
        "English": _ENGLISH_TEMPLATE.format(culture="Mongolian"),
    },
    # NOTE(review): the Malay and Chinese prompts below ask for the explanation
    # in English rather than in the prompt's own language - confirm intended.
    "Singapore": {
        "English": _ENGLISH_TEMPLATE.format(culture="Singaporean"),
        "Malay": (
            "Sila jawab soalan berikut yang berkaitan dengan Budaya Singapura.\n"
            "{question}\n{options}\n"
            "Ini adalah soalan pilihan berganda. Sila kembalikan semua huruf pilihan yang mungkin terlebih dahulu, kemudian jelaskan pilihan anda dalam Bahasa Inggeris."
        ),
        "Chinese": (
            "请回答以下与新加坡文化相关的问题。\n"
            "{question}\n{options}\n"
            "这是一个多选题，请先返回所有可能的选项字母，再用英文解释你的选择。"
        ),
    },
    "Vitnamese": {
        "Vietnamese": (
            "Vui lòng trả lời câu hỏi sau đây liên quan đến Văn hóa Việt Nam.\n"
            "{question}\n{options}\n"
            "Đây là một câu hỏi trắc nghiệm. Vui lòng trả về tất cả các chữ cái tùy chọn có thể có trước, sau đó giải thích lựa chọn của bạn bằng tiếng Việt."
        ),
        "English": _ENGLISH_TEMPLATE.format(culture="Vietnamese"),
    },
}

# Native prompt language(s) for each data folder name. A folder with more than
# one native language lists them comma-separated.
NATIVE_LANGUAGE_MAP = dict(
    Indonesian="Indonesian",
    Korea="Korean",
    Mongolia="Mongolian",
    Singapore="Malay,Chinese",
    Vitnamese="Vietnamese",
)

# --- 4. Walk the folders and answer every QA file in text-only mode ---

# System prompt sent with every conversation.
SYSTEM_PROMPT = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."

# Option letter shown to the model, paired with the item key holding its text.
OPTION_KEYS = (("A", "Option1"), ("B", "Option2"), ("C", "Option3"), ("D", "Option4"))


def select_prompt_template(folder_name, filename, templates=None, native_languages=None):
    """Return the prompt template for a QA file, or None when none applies.

    A filename containing "English" selects the English template. Otherwise
    the folder's native language is used; when a folder has several native
    languages (comma-separated), the first one whose name occurs in the
    filename wins.

    Args:
        folder_name: Data folder the file lives in (key into both tables).
        filename: Name of the QA file.
        templates: Folder -> language -> template mapping; defaults to
            TEXT_QA_PROMPT_TEMPLATES.
        native_languages: Folder -> native language(s) mapping; defaults to
            NATIVE_LANGUAGE_MAP.

    Returns:
        The template string, or None if the folder or language has no
        template, or no native language is named in the filename.
    """
    if templates is None:
        templates = TEXT_QA_PROMPT_TEMPLATES
    if native_languages is None:
        native_languages = NATIVE_LANGUAGE_MAP

    folder_templates = templates.get(folder_name, {})
    if "English" in filename:
        candidates = ["English"]
    else:
        candidates = native_languages.get(folder_name, "English").split(",")

    if len(candidates) == 1:
        return folder_templates.get(candidates[0])
    for language in candidates:
        if language in filename:
            return folder_templates.get(language)
    # Several native languages, none of them named in the filename.
    return None


def build_conversation(item, prompt_template):
    """Build the text-only chat conversation (system + user turn) for one item.

    Args:
        item: QA record; "Rephrased_Question" is preferred over "Question",
            and "Option1".."Option4" become the choices A-D.
        prompt_template: Template with {question} and {options} placeholders.

    Returns:
        A two-message conversation in the processor's chat format.
    """
    raw_question = item.get("Rephrased_Question", item.get("Question", ""))
    # A JSON null question would otherwise fail on .strip().
    question = str("" if raw_question is None else raw_question).strip()
    options = "\n".join(
        f"{letter}. {str(item.get(key, '')).strip()}" for letter, key in OPTION_KEYS
    )
    text_prompt = prompt_template.format(question=question, options=options)
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": text_prompt}]},
    ]


def generate_answers(conversations):
    """Run one batched generation and return the generated text per conversation.

    Only the newly generated tokens are decoded, so the returned strings do
    not repeat the system prompt, the question or the option list.

    Raises:
        Whatever the processor or the model raises (e.g. out-of-memory).
    """
    text = processor.apply_chat_template(
        conversations,
        add_generation_prompt=True,
        tokenize=False,
    )
    # The conversations hold text entries only; the media lists are still
    # collected and forwarded so the processor call keeps its usual shape.
    # NOTE(review): presumably these are all empty/None for text-only input.
    _audios, images, videos = process_mm_info(conversations, use_audio_in_video=False)

    inputs = processor(
        text=text,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
    ).to(model.device).to(model.dtype)

    with torch.no_grad():
        text_ids = model.generate(**inputs, max_new_tokens=1024, use_audio_in_video=False, return_audio=False)

    # generate() returns prompt + continuation; every row of the padded batch
    # has the same prompt width, so one slice removes the prompt from all rows.
    prompt_length = inputs["input_ids"].shape[1]
    return processor.batch_decode(
        text_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


def answer_batch(batch_data, prompt_template):
    """Answer one batch of items in place, storing text under "qwen_omni_answer".

    If generation fails, the error message is stored instead of an answer so
    one bad batch does not abort the whole run.
    """
    # Items that are not JSON objects cannot be prompted or annotated.
    valid_items = [item for item in batch_data if isinstance(item, dict)]
    if not valid_items:
        return
    conversations = [build_conversation(item, prompt_template) for item in valid_items]

    try:
        answers = generate_answers(conversations)
    except Exception as e:
        # Deliberate best-effort boundary: record the failure and keep going.
        error_message = f"Error during batch processing: {str(e)}"
        tqdm.write(f"    ❌ {error_message}")
        for item in valid_items:
            item["qwen_omni_answer"] = error_message
        return

    for item, answer in zip(valid_items, answers):
        item["qwen_omni_answer"] = answer


def process_file(folder_name, folder_path, filename):
    """Answer every item of one QA JSON file and write the result to output_dir."""
    input_path = os.path.join(folder_path, filename)
    print(f"  ➡️  正在处理QA文件: {filename}")

    try:
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and decoding errors
        print(f"    ❌ 读取文件失败: {input_path}, 错误: {e}")
        return

    if not isinstance(data, list):
        # Batching slices the data, which only works for a JSON array.
        print(f"    ⚠️  文件内容不是列表，已跳过: {input_path}")
        return

    prompt_template = select_prompt_template(folder_name, filename)
    if prompt_template is None:
        print(f"    ⚠️  未找到匹配的Prompt模板，已跳过: {filename}")
        return

    batch_starts = range(0, len(data), BATCH_SIZE)
    for start in tqdm(batch_starts, desc=f"  Processing batches in {filename}", leave=False):
        answer_batch(data[start:start + BATCH_SIZE], prompt_template)

    # NOTE(review): output names derive from the file name only, so files with
    # the same name in different folders would overwrite each other.
    base_filename = os.path.splitext(filename)[0]
    output_filename = f"{base_filename}_qwen_omni_7B_text_answered.json"
    output_path = os.path.join(output_dir, output_filename)

    try:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    except OSError as e:
        print(f"    ❌ 保存文件失败: {output_path}, 错误: {e}")
        return

    print(f"    ✅ QA处理完成，结果已保存至: {output_path}")


def main():
    """Process every configured folder and report when all of them are done."""
    for folder_name in folders_to_process:
        current_folder_path = os.path.join(base_data_dir, folder_name)

        if not os.path.isdir(current_folder_path):
            print(f"⚠️  警告: 文件夹 '{current_folder_path}' 不存在，已跳过。")
            continue

        print(f"\n📁 开始处理文件夹: {current_folder_path}")

        # Sorted so runs visit files in a reproducible order.
        for filename in sorted(os.listdir(current_folder_path)):
            # Files whose name contains "Text_Only" are deliberately excluded;
            # adjust this filter if the questions live in those files instead.
            if "Text_Only" in filename or not filename.endswith(".json"):
                continue
            process_file(folder_name, current_folder_path, filename)

    print(f"\n🎉 所有文件夹的Qwen-Omni Text QA任务处理完毕！所有输出文件已保存到 '{output_dir}' 文件夹中。")


if __name__ == "__main__":
    main()