from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from concurrent.futures import ThreadPoolExecutor # 引入多线程
import argparse
import re, json, torch
from tqdm import tqdm
import time
import os

# Configure the PyTorch CUDA allocator through its environment variable.
# NOTE(review): torch (and vllm) are already imported above; this presumably
# still takes effect because the allocator reads the variable lazily when
# CUDA memory is first used — confirm against the PyTorch documentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch._dynamo

# Tell TorchDynamo to suppress its errors instead of raising them.
# NOTE(review): presumably so a failed graph compilation falls back to eager
# execution rather than aborting the run — confirm.
torch._dynamo.config.suppress_errors = True

# Command-line interface of the evaluation script.
# Each entry is (flag, keyword arguments for ``add_argument``); the options
# are registered in the order listed, then parsed once at import time into
# the module-level ``args`` namespace.
parser = argparse.ArgumentParser()
for _flag, _spec in (
    ("--benchmark", {"default": 'mathverse', "type": str}),
    ("--model_path", {"default": 'Qwen/Qwen2.5-VL-7B-Instruct', "type": str}),
    ("--model", {"default": 'qwen2.5vl-7b', "type": str}),
    ("--out_dir", {"default": 'model_answer', "type": str}),
    ("--temperature", {"default": 0.7, "type": float}),
    ("--gpus", {"default": 1, "type": int}),
    ("--seed", {"default": 42, "type": int}),
    (
        "--batch_size",
        {"default": 1024, "type": int, "help": "每批处理的样本数，根据内存调整"},
    ),
):
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Directory prepended to an image path that does not exist as given.
# Kept as the default so existing callers behave exactly as before.
DEFAULT_IMAGE_ROOT = '/mnt/nas/yanlong/datasets/MM-Eval/MM-Eval/'


def preprocess_item(item, processor, benchmark, *, image_root=DEFAULT_IMAGE_ROOT):
    """Build one generation request (prompt + image inputs) from a record.

    Intended to be called from worker threads: any exception raised while
    preparing the sample is printed and turned into ``None`` so that a single
    bad record does not abort the whole batch.

    Args:
        item: Benchmark record. Reads ``item['query']`` and the first entry
            of ``item['images']`` (or of ``item['crop_images']`` when
            ``benchmark`` is ``'p-bench-crop'``).
        processor: Object exposing ``apply_chat_template``; used to render
            the chat messages into a prompt string.
        benchmark: Benchmark name. Only ``'p-bench-crop'`` changes which
            image field is read.
        image_root: Directory joined in front of the image path when the
            path as given does not exist on disk.

    Returns:
        A dict with keys ``"prompt"``, ``"multi_modal_data"`` and
        ``"raw_item"`` (the untouched input record), or ``None`` if
        preprocessing failed.
    """
    try:
        image_key = 'crop_images' if benchmark == 'p-bench-crop' else 'images'
        img_path = item[image_key][0]
        if not os.path.exists(img_path):
            # Fall back to resolving the path against the dataset root.
            img_path = os.path.join(image_root, img_path)

        # The image is passed as a separate content part, so the inline
        # placeholder is stripped from the question text.
        inst = item['query'].replace('<image>', '')
        messages = [
            {"role": "system", "content": "You are a helpful assistant.\n The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The answer are enclosed within <answer> </answer> tags, respectively, i.e., reasoning process here <answer> answer here </answer>"},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path, "min_pixels": 65536, "max_pixels": 16777216},
                    {"type": "text", "text": inst},
                ],
            },
        ]

        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        # Only attach the modalities that were actually produced.
        mm_data = {}
        if image_inputs is not None:
            mm_data["image"] = image_inputs
        if video_inputs is not None:
            mm_data["video"] = video_inputs

        return {"prompt": prompt, "multi_modal_data": mm_data, "raw_item": item}
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller
        # filter out the ``None`` result.
        print(f"预处理错误: {e}")
        return None




# Annotation file for every supported ``--benchmark`` value.
BENCHMARK_FILES = {
    'countqa': 'test.json',
    'hrbench-4k': 'hr_bench_4k.json',
    'hrbench-8k': 'hr_bench_8k.json',
    'vstar': 'vstar.json',
    'cvbench-2d': 'test_2d.json',
    'cvbench-3d': 'test_3d.json',
    'mmstar': 'MMStar.json',
    'babyvision': 'babyvision.json',
    'mme-realworld': 'MME_RealWorld.json',
    'mme-realworld-cn': 'MME_RealWorld_CN.json',
    'perception_bench_1': 'hallusion.json',
    'colorbench': 'test.json',
    'p-bench': 'pbench.json',
    'p-bench-crop': 'pbench.json',
    'fakeclue': 'fakeclue.json',
    'forensicsbench': 'forensicsbench.json',
    'loki': 'loki_image.json',
}


def resolve_benchmark_path(benchmark):
    """Return the annotation file registered for *benchmark*.

    Raises:
        ValueError: If *benchmark* has no entry in ``BENCHMARK_FILES``.
    """
    try:
        return BENCHMARK_FILES[benchmark]
    except KeyError:
        known = ', '.join(sorted(BENCHMARK_FILES))
        raise ValueError(
            f"Unknown benchmark {benchmark!r}; expected one of: {known}"
        ) from None


def sample_identifier(item):
    """Return the resume key of a record: first image path + query text.

    Raises:
        KeyError, IndexError, TypeError: If the record lacks usable
            ``images`` / ``query`` fields.
    """
    return item['images'][0] + item['query']


def load_done_sample_ids(path):
    """Collect the resume keys of all records already written to *path*.

    The file is read line by line as JSON Lines. Lines that are not valid
    JSON, or whose record has no usable key fields, are skipped. A missing
    file yields an empty set.
    """
    done = set()
    if not os.path.exists(path):
        return done
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                done.add(sample_identifier(json.loads(line)))
            except (ValueError, KeyError, IndexError, TypeError):
                # Corrupt or partial line (e.g. from an interrupted run).
                continue
    return done


if __name__ == '__main__':

    MODEL_PATH = args.model_path
    try:
        BENCHMARK_PATH = resolve_benchmark_path(args.benchmark)
    except ValueError as err:
        # Fail with a readable message before any expensive work starts.
        raise SystemExit(str(err))

    tensor_parallel_size = args.gpus
    gpu_memory_utilization = 0.8

    # 1. Output location. The file holds one JSON object per line (JSON
    #    Lines) so results can be appended; the ".json" name is kept so that
    #    answer files from earlier runs are still picked up for resuming.
    out_dir = os.path.join(args.out_dir, args.benchmark)
    os.makedirs(out_dir, exist_ok=True)
    OUT_PATH = f'{out_dir}/{args.model}_seed{args.seed}_answer.json'

    # 2. Resume support: find which samples were already answered. Records
    #    are identified by their first image path concatenated with the query.
    done_sample_ids = load_done_sample_ids(OUT_PATH)
    if os.path.exists(OUT_PATH):
        print(f"检测到断点，已完成 {len(done_sample_ids)} 个样本。")

    with open(BENCHMARK_PATH, 'r', encoding='utf-8') as f:
        total_data = json.load(f)

    # 3. Drop the samples that already have an answer.
    todo_data = [
        item for item in total_data
        if sample_identifier(item) not in done_sample_ids
    ]

    if not todo_data:
        print("所有样本已处理完毕。")
        raise SystemExit(0)

    print(f"剩余待处理样本: {len(todo_data)}")

    # 4. Model, sampling configuration and chat-template processor.
    llm = LLM(
        model=MODEL_PATH,
        limit_mm_per_prompt={"image": 1, "video": 1},
        dtype=torch.bfloat16,
        gpu_memory_utilization=gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=True
    )

    sampling_params = SamplingParams(
        max_tokens=8192,
        temperature=args.temperature,
        top_p=0.8,
        top_k=20,
        presence_penalty=1.5,
        repetition_penalty=1.0,
        seed=args.seed,
    )

    processor = AutoProcessor.from_pretrained(MODEL_PATH, max_pixels=16777216)

    total_latency = 0
    total_output_tokens = 0
    processed_count = 0

    # 5. Batched inference loop.
    batch_size = args.batch_size
    for start in range(0, len(todo_data), batch_size):
        chunk = todo_data[start : start + batch_size]

        print(f"正在并行预处理图片... (Batch {start//batch_size})")
        # max_workers should be tuned to the number of CPU cores available.
        with ThreadPoolExecutor(max_workers=32) as executor:
            results = list(executor.map(lambda x: preprocess_item(x, processor, args.benchmark), chunk))

        # Samples whose preprocessing failed come back as None; skip them.
        valid_results = [r for r in results if r is not None]
        prompt_list = [{"prompt": r["prompt"], "multi_modal_data": r["multi_modal_data"]} for r in valid_results]
        valid_chunk_items = [r["raw_item"] for r in valid_results]

        if not prompt_list:
            continue

        # Time only the generation call.
        gen_start = time.time()
        outputs = llm.generate(prompt_list, sampling_params=sampling_params)
        gen_end = time.time()

        batch_latency = gen_end - gen_start
        total_latency += batch_latency

        # Append the results of this batch and accumulate token statistics.
        # Outputs are paired with the input records by position.
        with open(OUT_PATH, 'a', encoding='utf-8') as f_out:
            for raw_item, output in zip(valid_chunk_items, outputs):
                completion = output.outputs[0]
                out_tokens = len(completion.token_ids)
                total_output_tokens += out_tokens

                result_item = raw_item.copy()
                result_item['model_answer'] = completion.text

                # The whole batch is generated in one call, so per-sample
                # timing is only approximate: gen_time_sec is the latency of
                # the full batch, and tokens_per_sec divides the sample's
                # tokens by the batch latency averaged over its samples.
                result_item['gen_time_sec'] = round(batch_latency, 3)
                result_item['output_tokens'] = out_tokens
                result_item['tokens_per_sec'] = round(out_tokens / (batch_latency / len(outputs)), 2) if batch_latency > 0 else 0

                f_out.write(json.dumps(result_item, ensure_ascii=False) + '\n')

        processed_count += len(prompt_list)
        # Guard against a zero elapsed time, as the final report does.
        avg_speed = total_output_tokens / total_latency if total_latency > 0 else 0.0
        print(f"已完成: {processed_count} / {len(todo_data)} | "
              f"Batch耗时: {batch_latency:.2f}s | "
              f"平均速度: {avg_speed:.2f} tokens/s")

    # 6. Final summary report.
    print("\n" + "="*50)
    print("推理任务完成统计:")
    print(f"处理总样本数: {processed_count}")
    print(f"总推理耗时: {total_latency:.2f} 秒")
    print(f"生成总 Token 数: {total_output_tokens}")
    if total_latency > 0:
        print(f"系统吞吐量 (Throughput): {total_output_tokens / total_latency:.2f} tokens/s")
        print(f"单样本平均推理耗时: {total_latency / processed_count:.3f} 秒")
    print(f"结果保存至: {OUT_PATH}")
    print("="*50)