from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from concurrent.futures import ThreadPoolExecutor # 引入多线程
import argparse
import re, json, torch
from tqdm import tqdm
import os

# Configure the PyTorch CUDA caching allocator to use expandable segments
# (see PyTorch CUDA memory-management docs for PYTORCH_CUDA_ALLOC_CONF).
# NOTE(review): torch is already imported above; this relies on the allocator
# reading the variable lazily at first CUDA use rather than at import — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch._dynamo

# If TorchDynamo fails to compile a graph, fall back to eager execution instead
# of raising. Presumably meant to tolerate torch.compile failures inside the
# vLLM model code — verify this is still needed.
torch._dynamo.config.suppress_errors = True

# Command-line options, parsed once at import time into the module-level `args`.
#   --benchmark   benchmark name; the main block maps it to an annotation file
#   --model_path  HF hub id or local path passed to vLLM and AutoProcessor
#   --model       short label used only in the output file name
#   --out_dir     root directory; results go to <out_dir>/<benchmark>/
#   --gpus        forwarded to vLLM as tensor_parallel_size
#   --seed        sampling seed; also part of the output file name
parser = argparse.ArgumentParser()
for _flag, _default, _kind in (
    ("--benchmark", 'mathverse', str),
    ("--model_path", 'Qwen/Qwen2.5-VL-7B-Instruct', str),
    ("--model", 'qwen2.5vl-7b', str),
    ("--out_dir", 'model_answer', str),
    ("--temperature", 0.6, float),
    ("--gpus", 1, int),
    ("--seed", 42, int),
):
    parser.add_argument(_flag, default=_default, type=_kind)
# Samples per inference batch; lower it if host/GPU memory is tight.
parser.add_argument(
    "--batch_size",
    default=1024,
    type=int,
    help="每批处理的样本数，根据内存调整",
)
args = parser.parse_args()

def preprocess_item(item, processor, benchmark):
    """Build one vLLM request (prompt text + multimodal data) from a benchmark item.

    Called concurrently from a thread pool with a single shared ``processor``,
    so this function only reads from ``processor`` and never mutates it.

    Args:
        item: Benchmark record. Reads ``item['query']`` and the first image
            path from ``item['crop_images']`` when ``benchmark`` is
            ``'p-bench-crop'``, otherwise from ``item['images']``.
        processor: Processor exposing ``apply_chat_template`` (used read-only).
        benchmark: Benchmark name; only ``'p-bench-crop'`` changes behavior.

    Returns:
        ``{"prompt": str, "multi_modal_data": dict, "raw_item": item}`` on
        success, or ``None`` if any step fails. The error is printed and the
        caller is expected to skip the item.
    """
    image_key = 'crop_images' if benchmark == 'p-bench-crop' else 'images'
    try:
        img_path = item[image_key][0]

        # The image is supplied as a separate content entry below, so the
        # textual placeholder is removed from the query (presumably the chat
        # template inserts its own vision tokens for that entry).
        inst = item['query'].replace('<image>', '')
        messages = [
            {"role": "system", "content": "You are a helpful assistant.\n The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The answer are enclosed within <answer> </answer> tags, respectively, i.e., reasoning process here <answer> answer here </answer>"},
            {
                "role": "user",
                "content": [
                    # Pixel-count bounds: 256*256 .. 4096*4096.
                    {"type": "image", "image": img_path, "min_pixels": 65536, "max_pixels": 16777216},
                    {"type": "text", "text": inst},
                ],
            },
        ]

        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # process_vision_info works on `messages` alone. The processor's image
        # size limit is configured once by the caller, not here, to avoid
        # racy writes to shared state from worker threads.
        image_inputs, video_inputs = process_vision_info(messages)

        mm_data = {}
        if image_inputs is not None:
            mm_data["image"] = image_inputs
        if video_inputs is not None:
            mm_data["video"] = video_inputs

        return {"prompt": prompt, "multi_modal_data": mm_data, "raw_item": item}
    except Exception as e:
        # Best-effort: one bad sample must not abort the whole batch.
        print(f"预处理错误: {e}")
        return None




# Benchmark name -> annotation file (relative to the working directory).
BENCHMARK_PATHS = {
    'countqa': 'test.json',
    'hrbench-4k': 'hr_bench_4k.json',
    'hrbench-8k': 'hr_bench_8k.json',
    'vstar': 'vstar.json',
    'cvbench-2d': 'test_2d.json',
    'cvbench-3d': 'test_3d.json',
    'mmstar': 'MMStar.json',
    'babyvision': 'babyvision.json',
    'mme-realworld': 'MME_RealWorld.json',
    'mme-realworld-cn': 'MME_RealWorld_CN.json',
    'colorbench': 'test.json',
    'p-bench': 'pbench.json',
    'p-bench-crop': 'pbench.json',
    'fakeclue': 'fakeclue.json',
    'forensicsbench': 'forensicsbench.json',
    'loki': 'loki_image.json',
}


def _sample_key(item):
    """Return the resume identifier of a record: first image path + query.

    Records have no unique id, so the (images[0], query) pair stands in for
    one. Returns ``None`` when the record is not a mapping with a usable
    ``images`` list and ``query`` string; such records are never considered
    already done.
    """
    try:
        return item['images'][0] + item['query']
    except (KeyError, IndexError, TypeError):
        return None


if __name__ == '__main__':

    MODEL_PATH = args.model_path
    BENCHMARK_PATH = BENCHMARK_PATHS.get(args.benchmark)
    if BENCHMARK_PATH is None:
        # Fail fast with a clear message. Note that the CLI default
        # 'mathverse' has no entry in BENCHMARK_PATHS.
        raise SystemExit(
            f"Unknown benchmark {args.benchmark!r}; expected one of: "
            + ", ".join(sorted(BENCHMARK_PATHS))
        )

    tensor_parallel_size = args.gpus
    gpu_memory_utilization = 0.8

    # 1. Output path. Records are written as JSON Lines (one object per line)
    #    so each batch can be appended. The .json suffix is kept so earlier
    #    runs are still found when resuming.
    out_dir = os.path.join(args.out_dir, args.benchmark)
    os.makedirs(out_dir, exist_ok=True)
    OUT_PATH = f'{out_dir}/{args.model}_seed{args.seed}_answer.json'

    # 2. Resume: collect identifiers of samples already in the output file.
    done_sample_ids = set()
    if os.path.exists(OUT_PATH):
        ends_with_newline = True
        with open(OUT_PATH, 'r', encoding='utf-8') as f:
            for line in f:
                ends_with_newline = line.endswith('\n')
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Blank or truncated line (e.g. from an interrupted run).
                    continue
                key = _sample_key(data)
                if key is not None:
                    done_sample_ids.add(key)
        if not ends_with_newline:
            # Terminate a truncated last record so the next append starts on
            # its own line instead of being merged into the partial one.
            with open(OUT_PATH, 'a', encoding='utf-8') as f:
                f.write('\n')
        print(f"检测到断点，已完成 {len(done_sample_ids)} 个样本。")

    with open(BENCHMARK_PATH, 'r', encoding='utf-8') as f:
        total_data = json.load(f)

    # Skip samples answered in a previous run (records without a key are
    # always kept, since _sample_key never yields None into done_sample_ids).
    todo_data = [item for item in total_data if _sample_key(item) not in done_sample_ids]

    if not todo_data:
        print("所有样本已处理完毕。")
        raise SystemExit

    print(f"剩余待处理样本: {len(todo_data)}")

    # 3. Model initialization.
    llm = LLM(
        model=MODEL_PATH,
        limit_mm_per_prompt={"image": 1, "video": 1},
        dtype=torch.bfloat16,
        gpu_memory_utilization=gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=True,
        mm_processor_kwargs={"max_pixels": 4096*4096},
    )

    sampling_params = SamplingParams(
        max_tokens=8192,
        temperature=args.temperature,
        top_p=0.8,
        top_k=20,
        presence_penalty=1.5,
        repetition_penalty=1.0,
        seed=args.seed,
    )

    # HF processor used only to render chat-template prompts. Its size cap is
    # set once here, to 16777216 (= 4096*4096, the same value as max_pixels above).
    processor = AutoProcessor.from_pretrained(MODEL_PATH, max_pixels=16777216)
    processor.image_processor.size["longest_edge"] = 16777216

    # 4. Batched inference loop.
    batch_size = args.batch_size
    for i in range(0, len(todo_data), batch_size):
        chunk = todo_data[i : i + batch_size]

        print(f"正在并行预处理图片... (Batch {i//batch_size})")
        # Tune max_workers to the number of available CPU cores.
        with ThreadPoolExecutor(max_workers=32) as executor:
            results = list(executor.map(lambda x: preprocess_item(x, processor, args.benchmark), chunk))

        # Drop samples whose preprocessing failed (preprocess_item returned None).
        valid_results = [r for r in results if r is not None]
        prompt_list = [{"prompt": r["prompt"], "multi_modal_data": r["multi_modal_data"]} for r in valid_results]
        valid_chunk_items = [r["raw_item"] for r in valid_results]

        if not prompt_list:
            continue

        # vLLM returns outputs in the same order as the input prompts.
        outputs = llm.generate(prompt_list, sampling_params=sampling_params)

        # Append this batch's results, one JSON record per line.
        with open(OUT_PATH, 'a', encoding='utf-8') as f_out:
            for raw_item, output in zip(valid_chunk_items, outputs):
                result_item = raw_item.copy()
                result_item['model_answer'] = output.outputs[0].text
                f_out.write(json.dumps(result_item, ensure_ascii=False) + '\n')

        # Progress counts positions in todo_data, so samples that failed
        # preprocessing count as visited too.
        print(f"已完成: {i + len(chunk)} / {len(todo_data)}")

    print(f"全部推理完成，结果保存至: {OUT_PATH}")