from qwen_agent.agents import Assistant
from qwen_agent.utils.output_beautify import multimodal_typewriter_print
import argparse
import json
import os
import time
from typing import Dict, Any
from tqdm import tqdm
import statistics

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``type=bool`` cannot be used for this: any non-empty string is truthy, so
    ``bool('False')`` is ``True`` and ``--use_tools False`` would silently keep
    the tools enabled.

    Raises:
        argparse.ArgumentTypeError: for any unrecognised string, so argparse
            reports a usage error instead of guessing.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options; parsed once at import time into the module-level `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", default='p-bench', type=str)
parser.add_argument("--model", default='qwen3-4b', type=str)
parser.add_argument("--out_dir", default='model_answer', type=str)
parser.add_argument("--temperature", default=0.7, type=float)
parser.add_argument("--api_base", default='http://localhost:18901/v1', type=str, help="vLLM服务地址")
parser.add_argument("--api_key", default='EMPTY', type=str, help="API密钥")
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--max_samples", default=None, type=int, help="最多处理样本数，None表示全部")
# `--use_tools`, `--use_tools true` and `--use_tools false` all behave as written.
parser.add_argument("--use_tools", default=True, type=_str2bool, nargs='?', const=True,
                    help="是否使用agent tools")
args = parser.parse_args()


def get_benchmark_path(benchmark_name: str) -> str:
    """Look up the annotation file registered for a benchmark.

    Args:
        benchmark_name: Benchmark key, e.g. ``'p-bench'`` or ``'hrbench-4k'``.

    Returns:
        The JSON file name (a path relative to the current working directory)
        for the benchmark, or an empty string if the name is not registered.
    """
    # NOTE(review): 'countqa' and 'colorbench' both resolve to 'test.json';
    # confirm this sharing is intended.
    registry = dict((
        ('countqa', 'test.json'),
        ('hrbench-4k', 'hr_bench_4k.json'),
        ('hrbench-8k', 'hr_bench_8k.json'),
        ('vstar', 'vstar.json'),
        ('cvbench-2d', 'test_2d.json'),
        ('cvbench-3d', 'test_3d.json'),
        ('mmstar', 'MMStar.json'),
        ('babyvision', 'babyvision.json'),
        ('mme-realworld', 'MME_RealWorld.json'),
        ('mme-realworld-cn', 'MME_RealWorld_CN.json'),
        ('perception_bench_1', 'hallusion.json'),
        ('colorbench', 'test.json'),
        ('p-bench', 'pbench.json'),
        ('p-bench-crop', 'pbench.json'),
        ('fakeclue', 'fakeclue.json'),
        ('forensicsbench', 'forensicsbench.json'),
        ('loki', 'loki_image.json'),
    ))
    try:
        return registry[benchmark_name]
    except KeyError:
        return ''


def initialize_agent(
    api_base: str,
    api_key: str,
    temperature: float,
    model,
    use_tools: bool = False
) -> Assistant:
    """Build the qwen-agent ``Assistant`` used for evaluation.

    The interface is kept consistent with "script B".

    Args:
        api_base: URL of the model server, passed as ``model_server``.
        api_key: API key forwarded to the server.
        temperature: Sampling temperature placed in ``generate_cfg``.
        model: Model name requested from the server.
        use_tools: When True, register ``image_zoom_in_tool``; otherwise the
            agent is created with an empty tool list.

    Returns:
        An ``Assistant`` configured with the research-assistant system prompt.
    """
    # Decoding settings are fixed here; only the temperature comes from the caller.
    sampling = {
        "top_p": 0.8,
        "top_k": 20,
        "temperature": temperature,
        "repetition_penalty": 1.0,
        "presence_penalty": 1.5,
        "max_tokens": 8192,
    }
    llm_config = {
        'model_type': 'qwenvl_oai',
        'model': model,
        'model_server': api_base,
        'api_key': api_key,
        'generate_cfg': sampling,
    }

    system_prompt = """Your role is that of a research assistant specializing in visual information. Answer questions about images by looking at them closely and then using research tools. Please follow this structured thinking process and show your work.

Start an iterative loop for each question:

- **First, look closely:** Begin with a detailed description of the image, paying attention to the user's question. List what you can tell just by looking, and what you'll need to look up.
- **Next, find information:** Use a tool to research the things you need to find out.
- **Then, review the findings:** Carefully analyze what the tool tells you and decide on your next action.

Continue this loop until your research is complete.

To finish, bring everything together in a clear, synthesized answer that fully responds to the user's question."""

    if use_tools:
        function_list = ['image_zoom_in_tool']
    else:
        function_list = []

    return Assistant(
        llm=llm_config,
        function_list=function_list,
        system_message=system_prompt,
    )


def _sample_key(sample):
    """Return the resume identifier of a benchmark or answer record.

    The identifier is the first entry of ``sample['images']`` concatenated with
    ``sample['query']``. It is used both when scanning the existing answer file
    and when filtering the benchmark, so the two sides always agree.

    Raises:
        TypeError, IndexError, KeyError or AttributeError: if ``sample`` is not
            a dict holding a non-empty ``images`` sequence and a string ``query``.
    """
    return sample.get('images')[0] + sample.get('query')


if __name__ == '__main__':
    import sys

    BENCHMARK_PATH = get_benchmark_path(args.benchmark)
    if not BENCHMARK_PATH:
        print(f"Benchmark '{args.benchmark}' not found.")
        # Non-zero status so batch drivers notice the failure.
        sys.exit(1)

    # Output path (same naming format as script A). A '/' in the model name
    # (e.g. 'org/model' as served by vLLM) would point into a sub-directory
    # that is never created, making every write fail; flatten it for file names.
    out_dir = os.path.join(args.out_dir, args.benchmark)
    os.makedirs(out_dir, exist_ok=True)
    model_tag = args.model.replace('/', '_')
    OUT_PATH = f'{out_dir}/{model_tag}_seed{args.seed}_answer.json'

    # Resume support: collect identifiers of samples already written to
    # OUT_PATH (one JSON record per line).
    done_sample_ids = set()
    if os.path.exists(OUT_PATH):
        with open(OUT_PATH, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    done_sample_ids.add(_sample_key(json.loads(line)))
                except (ValueError, TypeError, IndexError, KeyError, AttributeError):
                    # Blank/truncated lines (e.g. from an interrupted run) or
                    # records lacking images/query are skipped, so those
                    # samples are simply re-run.
                    continue
        print(f"检测到断点，已完成 {len(done_sample_ids)} 个样本。")

    # Load the benchmark annotations (expected to be a list of sample dicts).
    with open(BENCHMARK_PATH, 'r', encoding='utf-8') as f:
        total_data = json.load(f)

    # Only None means "all samples"; an explicit 0 really processes nothing.
    if args.max_samples is not None:
        total_data = total_data[:args.max_samples]

    # Drop samples that were already answered in a previous run.
    todo_data = [item for item in total_data if _sample_key(item) not in done_sample_ids]

    if not todo_data:
        print("所有样本已处理完毕。")
        sys.exit(0)

    print(f"剩余待处理样本: {len(todo_data)}")

    agent = initialize_agent(
        api_base=args.api_base,
        api_key=args.api_key,
        temperature=args.temperature,
        model=args.model,
        use_tools=args.use_tools,
    )

    # Fallback root tried for image paths that do not exist as given.
    dataset_root = '/mnt/nas/yanlong/datasets/MM-Eval/MM-Eval/'
    # p-bench-crop reads its image paths from 'crop_images' instead of 'images'.
    image_field = 'crop_images' if args.benchmark == 'p-bench-crop' else 'images'

    # Per-sample wall-clock latency (seconds) of samples answered in THIS run.
    inference_times = []

    for idx, item in enumerate(tqdm(todo_data, desc=f"Processing {args.benchmark}")):
        try:
            img_path = item[image_field][0]
            if not os.path.exists(img_path):
                img_path = os.path.join(dataset_root, img_path)

            # The '<image>' placeholder is dropped; the image is sent separately.
            inst = item['query'].replace('<image>', '').strip()

            if not os.path.exists(img_path):
                print(f"跳过: 图片不存在 {img_path}")
                continue

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": img_path},
                        {"text": inst},
                    ],
                }
            ]

            # perf_counter is monotonic, unlike time.time(), so it suits intervals.
            start_time = time.perf_counter()
            response_plain_text = ''
            for ret_messages in agent.run(messages):
                response_plain_text = multimodal_typewriter_print(ret_messages, response_plain_text)
            elapsed_time = time.perf_counter() - start_time

            # Save the result (same record format as script A).
            result_item = item.copy()
            result_item['model_answer'] = response_plain_text
            with open(OUT_PATH, 'a', encoding='utf-8') as f_out:
                f_out.write(json.dumps(result_item, ensure_ascii=False) + '\n')

            # Record latency only after the answer is persisted, so the stats
            # never count a sample that is missing from OUT_PATH.
            inference_times.append(elapsed_time)

            # Progress report every 10 samples.
            if (idx + 1) % 10 == 0:
                current_avg = statistics.mean(inference_times)
                print(f"\n已完成: {idx + 1} / {len(todo_data)}, 当前平均时延: {current_avg:.2f}s")

        except Exception as e:
            # Top-level boundary: report the failure and keep evaluating.
            print(f"处理失败 (idx: {idx}): {e}")
            continue

    # Final summary. Latency figures cover only samples answered in this run;
    # samples resumed from OUT_PATH are not included.
    print(f"\n{'='*60}")
    print(f"推理完成 - Benchmark: {args.benchmark}")
    print(f"{'='*60}")
    print(f"总样本数: {len(total_data)}")
    print(f"成功处理: {len(inference_times)}")
    print(f"结果保存至: {OUT_PATH}")

    if inference_times:
        avg_latency = statistics.mean(inference_times)
        min_latency = min(inference_times)
        max_latency = max(inference_times)
        # stdev needs at least two data points.
        std_latency = statistics.stdev(inference_times) if len(inference_times) > 1 else 0

        print(f"\n时延统计:")
        print(f"  平均时延 (Mean):    {avg_latency:.2f}s")
        print(f"  最小时延 (Min):     {min_latency:.2f}s")
        print(f"  最大时延 (Max):     {max_latency:.2f}s")
        print(f"  标准差 (Std):       {std_latency:.2f}s")

        # Persist the statistics next to the answer file (overwritten each run).
        stats_path = f'{out_dir}/{model_tag}_seed{args.seed}_stats.json'
        stats = {
            'benchmark': args.benchmark,
            'model': args.model,
            'total_samples': len(total_data),
            'processed_samples': len(inference_times),
            'avg_latency': avg_latency,
            'min_latency': min_latency,
            'max_latency': max_latency,
            'std_latency': std_latency,
            'inference_times': inference_times,
        }
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)
        print(f"  统计信息保存至: {stats_path}")
    else:
        print("没有成功处理任何样本。")