import base64
import concurrent.futures  # thread pool used for parallel API calls
import json
import mimetypes
import os

from openai import OpenAI
from tqdm import tqdm

# --- 1. Configuration ---

# a) [REQUIRED] API endpoint and credentials.
# Both values can be overridden through environment variables so that real
# credentials do not have to be edited into (and committed with) this file.
# Script-specific variable names are used on purpose: reusing the generic
# OPENAI_API_KEY would risk sending an unrelated key to this custom endpoint.
# The literals are fallbacks that keep the script's previous behaviour; an
# unset or empty variable falls back to them.
# NOTE(review): the fallback key is stored in plain text in source. It should
# be rotated and removed once NAV_EVAL_API_KEY is provided by the environment.
API_BASE_URL = os.environ.get("NAV_EVAL_API_BASE_URL") or "http://35.220.164.252:3888/v1/"
API_KEY = os.environ.get("NAV_EVAL_API_KEY") or "sk-s1gWkIlgSgpHXSgVDs8GP8eLgpCHOrGzeouBWpOSKJa9SaB9"

# b) [REQUIRED] Input dataset and output file paths.
INPUT_JSON_PATH = r"D:\ecnu\2025\nvi\json\base_train_task2.json"
OUTPUT_JSON_PATH = "dataset_with_predictions_fast_train_gpt4o.json"

# c) [OPTIONAL] Number of items to process.
# NOTE: main() currently processes the whole dataset; the slice that applies
# this limit is commented out there.
MAX_ITEMS_TO_PROCESS = 100

# d) [OPTIONAL] Model to query.
# MODEL_NAME = "gemini-2.5-flash-lite-thinking"
MODEL_NAME = "gpt-4o"
# e) [PERFORMANCE] Number of worker threads used for parallel requests. Tune
# it to what the network and the API server can sustain; 5-10 is a reasonable
# starting range.
MAX_WORKERS = 8


# --- 2. Core functionality ---

# Create the OpenAI client once at module level (outside main) so that all
# worker threads share the same instance.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

def encode_image_to_base64(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    This function does not raise on read failures. Instead it returns a
    message string that starts with ``"ERROR:"`` (never ``None``), which the
    caller detects with ``startswith("ERROR:")``:

    * missing file -> ``"ERROR: File not found at <path>"``
    * any other failure -> ``"ERROR: Could not read file <path>, Reason: <exc>"``
    """
    try:
        with open(image_path, "rb") as source:
            raw_bytes = source.read()
        encoded = base64.b64encode(raw_bytes)
    except FileNotFoundError:
        # Checked first so a missing file gets its own, shorter message.
        return f"ERROR: File not found at {image_path}"
    except Exception as exc:
        return f"ERROR: Could not read file {image_path}, Reason: {exc}"
    else:
        return encoded.decode('utf-8')

def build_prompt_text(core_instruction):
    """Build the text part of the prompt for one navigation instruction.

    The result has three newline-separated parts: a fixed task description,
    the line carrying *core_instruction*, and the fixed list of answer
    options, which ends with the ``<image>`` placeholder.
    """
    task_description = "You are following a multi-step navigation instruction. Based on the current images, think step by step to determine which step you are currently at. Then, choose the best next-action that matches the next instruction step."
    instruction_line = f" The instruction is : {core_instruction}"
    answer_options = "The action options are: A = move forward, B = turn left, C = turn right, D = stop. Output only: <A/B/C/D>\n The current-view is : <image>"
    return "\n".join((task_description, instruction_line, answer_options))

def _guess_image_mime(image_path):
    """Return the image MIME type implied by the extension of *image_path*.

    Falls back to ``image/png`` (the label this script used for every file
    before) when the path is not a string or its extension is missing,
    unknown, or not an image type.
    """
    if isinstance(image_path, str):
        guessed, _ = mimetypes.guess_type(image_path)
        if guessed and guessed.startswith("image/"):
            return guessed
    return "image/png"


def process_single_item(item):
    """Process one dataset entry: encode its image, query the model, return the answer.

    Designed to be called concurrently from worker threads. It never raises
    for per-item problems; failures are reported in the returned string so a
    single bad entry cannot abort the whole run.

    Args:
        item: A dataset entry. It is expected to provide ``item["instruction"]``
            (the instruction text) and ``item["images"]["img_hist0"]`` (the
            path of the image to send).

    Returns:
        A ``(item, prediction)`` tuple where ``item`` is the object passed in
        and ``prediction`` is either the model's reply or an error string
        starting with ``"DATA_ERROR:"``, ``"ERROR:"`` (image could not be
        read) or ``"API_ERROR:"``.
    """
    try:
        core_instruction = item["instruction"]
        image_path = item["images"]["img_hist0"]
    except KeyError as e:
        return item, f"DATA_ERROR: Missing key {e}"
    except TypeError as e:
        # The entry (or its "images" field) is not a mapping, e.g. None or a
        # list. Report it like any other data problem instead of letting the
        # exception escape the worker thread.
        return item, f"DATA_ERROR: Malformed item: {e}"

    full_prompt_text = build_prompt_text(core_instruction)
    img_base64 = encode_image_to_base64(image_path)

    if img_base64.startswith("ERROR:"):
        # Image encoding failed: pass the error message through unchanged.
        return item, img_base64

    # Label the payload with the file's real image type rather than always
    # claiming PNG.
    image_url = f"data:{_guess_image_mime(image_path)};base64,{img_base64}"

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": full_prompt_text},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                }
            ],
        )
        return item, response.choices[0].message.content
    except Exception as e:
        # Best effort by design: any request/response failure becomes a
        # per-item error string.
        return item, f"API_ERROR: {e}"

def main():
    """Run the whole pipeline: load the dataset, query the model in parallel, save results.

    Reads the JSON dataset at ``INPUT_JSON_PATH``, submits every entry to
    ``process_single_item`` on a pool of ``MAX_WORKERS`` threads, and writes a
    copy of each entry extended with a ``model_prediction`` field to
    ``OUTPUT_JSON_PATH``. Output entries are in the same order as the input,
    whatever order the requests finish in.
    """
    print(f"正在从文件加载JSON数据: {INPUT_JSON_PATH}")
    with open(INPUT_JSON_PATH, "r", encoding="utf-8") as f:
        full_data = json.load(f)

    # The whole dataset is processed. To limit the run, slice here instead:
    # data_to_process = full_data[:MAX_ITEMS_TO_PROCESS]
    data_to_process = full_data
    print(f"将使用最多 {MAX_WORKERS} 个线程并行处理 {len(data_to_process)} 条数据。")

    # One slot per input entry, filled by index, so the output order matches
    # the input order even though tasks complete out of order.
    results_with_predictions = [None] * len(data_to_process)

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit everything up front; remember each task's position and entry.
        future_to_slot = {
            executor.submit(process_single_item, item): (index, item)
            for index, item in enumerate(data_to_process)
        }

        # as_completed drives the progress bar as tasks finish.
        for future in tqdm(concurrent.futures.as_completed(future_to_slot), total=len(data_to_process), desc="并行处理中"):
            index, submitted_item = future_to_slot[future]
            try:
                original_item, prediction = future.result()
            except Exception as e:
                # An unexpected failure in one worker must not discard the
                # results already collected for the other entries.
                original_item, prediction = submitted_item, f"WORKER_ERROR: {e}"

            if isinstance(original_item, dict):
                new_item = original_item.copy()
            else:
                # Malformed (non-object) entry: keep it under its own key so
                # the error can still be recorded next to it.
                new_item = {"original_item": original_item}
            new_item['model_prediction'] = prediction
            results_with_predictions[index] = new_item

    print("\n所有条目处理完毕，正在将结果写入新文件...")
    with open(OUTPUT_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(results_with_predictions, f, indent=2, ensure_ascii=False)
    print(f"✨ 成功！结果已保存到: {os.path.abspath(OUTPUT_JSON_PATH)}")

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()