import argparse
import json
import os
import re
import time
import random
import requests
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Tuple
from collections import defaultdict
import torch

# Optional import of the OpenAI client; OPENAI_AVAILABLE records whether it
# succeeded so that API mode can be refused cleanly when the package is missing.
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    print("Warning: openai package not installed. API mode will not work.")

# Optional import of vLLM; VLLM_AVAILABLE records whether it succeeded so that
# local mode can be refused cleanly when the package is missing.
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    print("Warning: vllm package not installed. Local mode will not work.")

# Short model aliases (the values accepted by --model_name) mapped to full
# API model identifiers (migrated from the hotpotqa code).
MODEL_MAP = {
    "4.1mini": "gpt-4.1-mini-2025-04-14",
    "4omini": "gpt-4o-mini-2024-07-18",
}

# API endpoint configuration keyed by API key name (migrated from the hotpotqa
# code).
#
# The key is read from the environment only: credentials must never be
# committed to source control. Export OPENAI_API_KEY_77 before running in API
# mode; when it is unset the key is an empty string and API calls are expected
# to fail authentication.
# NOTE(review): a literal key used to be embedded here as the fallback value;
# treat that key as compromised and rotate it.
API_CONFIG = {
    "77": {
        "api_key": os.getenv("OPENAI_API_KEY_77", ""),
        "base_url": "https://api.key77qiqi.com/v1",
    },
}

# System-style preamble (adapted to the MMLUPRO multiple-choice setting).
# It is concatenated into the user prompt text by the draft and revision
# prompt builders rather than sent as a separate system message.
CHATGPT_SYSTEM_PROMPT = '''
You are an advanced AI assistant tasked with a multiple-choice question.
You excel at providing comprehensive, well-structured answers with multiple paragraphs.
Each paragraph you write contains multiple sentences that thoroughly explore the topic.
You always follow formatting instructions precisely.
'''


def parse_args():
    """Parse the command-line options of the evaluation script.

    The options are declared in a table and registered in declaration order,
    so the generated ``--help`` text lists them in that same order.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    option_table = (
        # Data selection.
        ("--subjects", {"nargs": '+', "default": [], "help": "要评估的科目列表"}),
        ("--start_sample", {"type": int, "default": 0}),
        ("--end_sample", {"type": int, "default": 100000}),
        ("--max_samples", {"type": int, "default": 0}),
        ("--subset_num", {"type": int, "default": -1, "help": "每个科目的样本数限制"}),
        ("--src_file", {"type": str, "default": "/eval/dataset/mmlupro/mmlupro_test.jsonl"}),
        # Backend selection.
        ("--eval_type", {
            "type": str,
            "default": "local",
            "choices": ["local", "api"],
            "help": "评估类型：local(本地模型) 或 api(API模型)",
        }),
        # Local-model options.
        ("--model_path", {"type": str, "help": "本地模型路径"}),
        ("--gpu_id", {"type": str, "default": "0"}),
        ("--gpu_memory_rate", {"type": float, "default": 0.95}),
        # API-model options.
        ("--model_name", {
            "type": str,
            "default": "4omini",
            "choices": ["4omini", "4.1mini"],
            "help": "API模型名称",
        }),
        ("--api_key_name", {"type": str, "default": "77", "help": "API密钥名称"}),
        # Options shared by both backends. Note that --top_k is the sampling
        # parameter while --topk is the number of documents to retrieve.
        ("--port", {"type": str, "default": "8000"}),
        ("--temp", {"type": float, "default": 0.0}),
        ("--top_p", {"type": float, "default": 0.5}),
        ("--top_k", {"type": int, "default": 10}),
        ("--topk", {"type": int, "default": 10}),
        ("--max_revisions", {"type": int, "default": 3}),
        ("--split_char", {"type": str, "default": "\n\n"}),
        ("--debug", {"action": "store_true", "help": "显示详细调试信息"}),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


# ============ 工具函数 ============
def clean_draft(draft):
    """Normalize paragraph separators in a generated draft.

    Stray backtick lines (single or triple backticks) that sit between
    paragraphs are replaced by a plain blank-line separator, runs of three or
    more newlines are collapsed to exactly two, and leading/trailing
    whitespace is stripped.

    Args:
        draft: Raw model output; may be ``None`` or empty.

    Returns:
        The cleaned text, or ``""`` when ``draft`` is falsy.
    """
    if not draft:
        return ""

    # Triple-backtick fences must be handled before the single-backtick
    # forms: otherwise the single-backtick rules each consume one backtick of
    # a fence and leave a stray backtick behind, and the triple-backtick
    # rules never match.
    separators = (
        '\n\n```\n\n',
        '```\n\n',
        '\n\n```',
        '\n\n`\n\n',
        '\n`\n',
        '`\n\n',
        '\n\n`',
    )
    for separator in separators:
        draft = draft.replace(separator, '\n\n')

    # Collapse any run of three or more newlines into one blank line.
    draft = re.sub(r'\n{3,}', '\n\n', draft)

    return draft.strip()


def apply_chat_template(tokenizer, model_short_name, prompt):
    """Wrap ``prompt`` in the tokenizer's chat template for instruction-tuned models.

    A model is treated as instruction-tuned when its short name contains
    ``instruct`` (case-insensitive) or has ``it`` as a standalone token, e.g.
    ``gemma-2-9b-it``. A plain substring test for ``it`` would also fire on
    unrelated names such as ``granite`` or ``...-with-...``, and a
    case-sensitive test would miss names such as ``Llama-3-8B-Instruct``.

    Args:
        tokenizer: Object exposing ``apply_chat_template(messages, tokenize=...,
            add_generation_prompt=...)``.
        model_short_name: Short model name used for the heuristic; a falsy
            value is treated as a base (non-chat) model.
        prompt: The user prompt text.

    Returns:
        The templated prompt string for instruction-tuned models, otherwise
        ``prompt`` unchanged.
    """
    name = (model_short_name or "").lower()
    has_it_token = re.search(r'(?<![a-z0-9])it(?![a-z0-9])', name) is not None
    if "instruct" not in name and not has_it_token:
        return prompt

    messages_chat = [
        {"role": "user", "content": prompt}
    ]
    return tokenizer.apply_chat_template(messages_chat, tokenize=False, add_generation_prompt=True)


# ============ 本地模型生成函数 ============
def generate_local_batch(llm, tokenizer, model_short_name, prompts, max_tokens, temperature, top_p, top_k):
    """Generate one completion per prompt with a vLLM engine.

    Each prompt is first passed through ``apply_chat_template``; the whole
    batch is then submitted to ``llm.generate`` in a single call.

    Returns:
        A list with one stripped completion string per generated output;
        an output without candidates contributes an empty string.
    """
    templated = [apply_chat_template(tokenizer, model_short_name, raw) for raw in prompts]

    decoding = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_tokens,
    )

    generations = llm.generate(templated, decoding)

    # Only the first candidate of each request is kept.
    return [
        generation.outputs[0].text.strip() if generation.outputs else ""
        for generation in generations
    ]


# ============ API模型生成函数 (从hotpotqa代码迁移) ============
def generate_api_single(client, model, prompt, max_tokens, temperature, top_p, stop=None):
    """Request a single chat completion, trying up to three times.

    The prompt is sent as one user message. An attempt that raises is logged
    and followed by a short random pause; an attempt that returns no usable
    content is logged and retried immediately.

    Returns:
        The content of the first choice, or ``""`` when all three attempts fail.
    """
    request = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stop": stop,
        "timeout": 30.0,
    }

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1
        try:
            reply = client.chat.completions.create(**request)
            choices = reply.choices
            message = choices[0].message if choices else None
            if message and message.content is not None:
                return message.content
            print(f"Warning: invalid response for prompt: {prompt[:50]}...")
        except Exception as e:
            print(f"API 调用失败: {e}")
            time.sleep(random.uniform(0.1, 1.0))

    return ""


def generate_api_batch(client, model, prompts, max_tokens, temperature, top_p, stop=None):
    """Fan the prompts out to the API through a pool of 128 worker threads.

    Every prompt is handled by ``generate_api_single``; completions are
    written back at the index of their prompt, so the returned list is in
    prompt order regardless of completion order.

    Returns:
        A list of completion strings aligned with ``prompts``; a request whose
        future raises contributes an empty string.
    """
    outputs = [None] * len(prompts)

    with ThreadPoolExecutor(max_workers=128) as pool:
        # Map each submitted future back to the position of its prompt.
        position_of = {}
        for position, text in enumerate(prompts):
            job = pool.submit(
                generate_api_single,
                client, model, text, max_tokens, temperature, top_p, stop,
            )
            position_of[job] = position

        # The progress bar advances as futures finish, in completion order.
        for job in tqdm(as_completed(position_of), total=len(prompts), desc="Processing"):
            position = position_of[job]
            try:
                outputs[position] = job.result()
            except Exception as e:
                print(f"\nError processing request {position}: {e}")
                outputs[position] = ""

    return outputs


# ============ 检索函数 ============
def batch_retrieve(queries, port, topk=10, timeout=None, max_retries=5):
    """Send all queries to the local retrieval service in one request.

    The service is expected at ``http://localhost:{port}/queries`` and to
    answer with a JSON object whose ``"answers"`` entry holds one list of
    documents per query.

    Args:
        queries: List of query strings; a falsy value short-circuits.
        port: Port of the retrieval service.
        topk: Number of documents requested per query (sent as ``"k"``).
        timeout: Timeout in seconds handed to ``requests.post``. ``None``
            (the default) waits indefinitely, as before; pass a value to
            avoid hanging on an unresponsive service.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        ``(results, scores)``. ``scores`` mirrors ``results`` with the
        reciprocal rank ``1 / (position + 1)`` of each document. ``([], [])``
        is returned for empty input; one empty list per query is returned
        when every attempt fails.
    """
    if not queries:
        return [], []

    url_wiki = f"http://localhost:{port}/queries"

    print(f"Batch retrieving {len(queries)} queries...")

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                url_wiki,
                json={"queries": queries, "k": topk},
                timeout=timeout,
            )
            if response.status_code == 200:
                results = response.json()["answers"]
                scores = [
                    [1.0 / (rank + 1) for rank in range(len(docs))]
                    for docs in results
                ]
                return results, scores
            # A non-200 reply used to be retried without any trace.
            print(f"Retrieval attempt {attempt} returned HTTP {response.status_code}")
            pause = random.uniform(0.1, 0.5)
        except Exception as e:
            # Best-effort retrieval: log and retry instead of aborting the run.
            print(f"Retrieval error attempt {attempt}: {e}")
            pause = random.uniform(0.5, 1.0)

        # No point in sleeping after the final attempt.
        if attempt < max_retries:
            time.sleep(pause)

    print(f"Retrieval failed after {max_retries} attempts; returning empty results.")
    return [[] for _ in queries], [[] for _ in queries]


# ============ MMLUPRO任务特定函数 ============
def format_question_with_options(question, options):
    """Render a question followed by its lettered options (MMLUPRO layout).

    Options are labelled A-J by position. Empty options are not printed but
    still occupy their letter, and anything beyond the tenth option is ignored.

    Returns:
        ``question`` unchanged when there are no options, otherwise the
        question, a blank line, ``Options:`` and one ``X. text`` line per
        non-empty option.
    """
    if not options:
        return question

    option_lines = [
        f"\n{letter}. {choice}"
        for letter, choice in zip("ABCDEFGHIJ", options)
        if choice  # skip empty options while keeping the letter alignment
    ]
    return question + "\n\nOptions:" + "".join(option_lines)


def extract_choice_answer(text, options):
    """Extract the chosen option letter (A-J) from a model answer.

    Resolution order:

    1. ``\\boxed{...}`` contents, scanned from the last box to the first. For
       each box: a bare letter (``C``, ``c.``, ``(C)``), then an exact match
       against an option text, then a letter prefix (``C. some text``), then
       a substring match against the option texts.
    2. Common answer phrases such as ``the answer is C`` or ``Answer: C``.
    3. The whole text being a bare letter.
    4. The last standalone *uppercase* letter A-J in the final 200 characters.

    Args:
        text: The model output; may be empty or ``None``.
        options: Option texts in letter order; may be empty or ``None``.

    Returns:
        The uppercase letter, or ``""`` when nothing can be extracted.
    """
    if not text:
        return ""

    # A box holding only the letter; must match the WHOLE content, otherwise
    # any answer text starting with a-j (e.g. "Coarctation ...") is misread
    # as a letter and the option-text matching below is never reached.
    bare_letter = r'\(?([A-J])\)?\.?'

    # Normalized (letter, text) pairs for the first ten non-empty options.
    candidates = [
        (chr(65 + i), option.strip().strip('.').strip().lower())
        for i, option in enumerate(options or [])
        if i < 10 and option
    ]

    # 1. Boxed answers; the last box is usually the final answer.
    box_contents = re.findall(r'\\boxed\s*\{([^}]+)\}', text, re.IGNORECASE | re.DOTALL)
    for box_content in reversed(box_contents):
        box_content = box_content.strip()

        letter_match = re.fullmatch(bare_letter, box_content, re.IGNORECASE)
        if letter_match:
            return letter_match.group(1).upper()

        cleaned_box = box_content.strip('.').strip().lower()

        # Exact option text, e.g. \boxed{Coarctation of the aorta}.
        if cleaned_box:
            for letter, option_text in candidates:
                if cleaned_box == option_text:
                    return letter

        # Letter followed by text, e.g. "C. Coarctation of the aorta".
        prefix_match = re.match(r'([A-J])\.?\s+(.+)$', box_content, re.IGNORECASE)
        if prefix_match:
            return prefix_match.group(1).upper()

        # Partial overlap with an option text. Empty strings are excluded
        # because "" is a substring of everything and would always pick the
        # first option.
        if cleaned_box:
            for letter, option_text in candidates:
                if option_text and (cleaned_box in option_text or option_text in cleaned_box):
                    return letter

    # 2. Answer phrases. The letter may be wrapped in markdown emphasis or an
    # opening parenthesis, and must end at a word boundary so that the first
    # letter of an ordinary word ("answer is definitely ...") is not captured.
    lead = r'[\s*(]*'
    fallback_patterns = [
        r'the correct answer is:?' + lead + r'([A-J])\b',
        r'answer is:?' + lead + r'([A-J])\b',
        r'Answer:' + lead + r'([A-J])\b',
        r'\[Answer\][\s:*(]*([A-J])\b',
        r'Final answer:?' + lead + r'([A-J])\b',
        r'Therefore,?\s+(?:the\s+)?answer\s+is' + lead + r'([A-J])\b',
    ]
    for pattern in fallback_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).upper()

    # 3. The entire output is just a letter (any case).
    whole_match = re.fullmatch(bare_letter, text.strip(), re.IGNORECASE)
    if whole_match:
        return whole_match.group(1).upper()

    # 4. Last resort: a standalone uppercase letter near the end. Lowercase
    # letters are ignored here because a standalone "a" is almost always the
    # English article, not option A.
    trailing_letters = re.findall(r'\b([A-J])\b', text[-200:])
    if trailing_letters:
        return trailing_letters[-1]

    return ""


def get_options_from_item(item):
    """Return the option texts of a dataset item in letter order.

    The ``options`` list is used when present and non-empty. Otherwise the
    options are read from the lettered keys ``A`` .. ``J``. A missing or
    empty letter in the middle is kept as an empty-string placeholder so
    that every remaining option stays at the index of its own letter; the
    rest of the pipeline already skips empty options when formatting and
    matching. Trailing gaps are dropped.

    Args:
        item: Mapping describing one question.

    Returns:
        A list of option texts; ``[]`` when the item has no options.
    """
    # Prefer the explicit options list.
    if 'options' in item and item['options']:
        return item['options']

    # Fall back to the lettered keys, preserving positions.
    options = [
        item[key] if key in item and item[key] else ""
        for key in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
    ]
    while options and not options[-1]:
        options.pop()

    return options


# ============ 统一的批量处理函数 ============
def batch_get_drafts(llm_or_client, model_info, questions_data, temp, top_p, top_k=None, eval_type="local"):
    """Generate the initial multi-paragraph draft for every question.

    Args:
        llm_or_client: Not used by this function; kept for a uniform signature.
        model_info: ``(llm, tokenizer, model_short_name)`` when ``eval_type``
            is ``"local"``, otherwise ``(client, model_name)``.
        questions_data: Items providing ``"question"`` plus their options.
        temp, top_p, top_k: Sampling settings (``top_k`` is only forwarded to
            the local backend).
        eval_type: ``"local"`` selects vLLM; any other value selects the API.

    Returns:
        One cleaned draft string per item, in input order.
    """
    draft_prompt = '''
IMPORTANT: Structure your response as follows:

1. Write a comprehensive answer with MULTIPLE PARAGRAPHS (3-6 paragraphs typically).

2. Each paragraph MUST contain AT LEAST 4 complete sentences. Single-sentence paragraphs are NOT acceptable.

3. Separate paragraphs with blank lines (press Enter twice).

4. At the very end, after all paragraphs, add your final answer in this format:
Only put the letter in the box, e.g. \\boxed{A}. There is only one correct answer.
'''

    def build_prompt(item):
        # Preamble, then the question with lettered options, then the
        # formatting instructions.
        choices = get_options_from_item(item)
        question_block = format_question_with_options(item["question"], choices)
        return f"{CHATGPT_SYSTEM_PROMPT}\n\n{question_block}\n{draft_prompt}"

    prompts = [build_prompt(item) for item in questions_data]

    print(f"Generating {len(prompts)} drafts in batch...")

    token_budget = 1536
    if eval_type != "local":
        client, model_name = model_info
        raw_drafts = generate_api_batch(
            client, model_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p
        )
    else:
        llm, tokenizer, model_short_name = model_info
        raw_drafts = generate_local_batch(
            llm, tokenizer, model_short_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p,
            top_k=top_k
        )

    # Normalize separators in every generated draft.
    return list(map(clean_draft, raw_drafts))


def batch_get_queries(llm_or_client, model_info, questions, answers, temp, top_p, top_k=None, eval_type="local"):
    """Generate one retrieval query per (question, current answer) pair.

    Args:
        llm_or_client: Not used by this function; kept for a uniform signature.
        model_info: ``(llm, tokenizer, model_short_name)`` when ``eval_type``
            is ``"local"``, otherwise ``(client, model_name)``.
        questions: Question texts.
        answers: Current answer texts, paired with ``questions`` by position.
        temp, top_p, top_k: Sampling settings (``top_k`` is only forwarded to
            the local backend).
        eval_type: ``"local"`` selects vLLM; any other value selects the API.

    Returns:
        One stripped query string per pair, in input order.
    """
    query_prompt = '''
Based on the question and the current answer content, generate a search query to verify or find additional information.

Please summarize the content with the corresponding question.
This summarization will be used as a query to search with Bing search engine.
The query should be short but need to be specific to promise Bing can find related knowledge or pages.
You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.
Try to make the query as relevant as possible to the last few sentences in the content.
**IMPORTANT**
Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''

    prompts = [
        f"##Question: {question}\n\n##Current Answer: {answer}\n\n##Instruction: {query_prompt}"
        for question, answer in zip(questions, answers)
    ]

    print(f"Generating {len(prompts)} queries in batch...")

    token_budget = 128
    if eval_type != "local":
        client, model_name = model_info
        raw_queries = generate_api_batch(
            client, model_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p
        )
    else:
        llm, tokenizer, model_short_name = model_info
        raw_queries = generate_local_batch(
            llm, tokenizer, model_short_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p,
            top_k=top_k
        )

    return [query.strip() for query in raw_queries]


def batch_revise_answers(llm_or_client, model_info, questions, answers, retrieved_docs_list, temp, top_p, top_k=None, eval_type="local"):
    """Revise each answer against its retrieved documents.

    Args:
        llm_or_client: Not used by this function; kept for a uniform signature.
        model_info: ``(llm, tokenizer, model_short_name)`` when ``eval_type``
            is ``"local"``, otherwise ``(client, model_name)``.
        questions: Question texts.
        answers: Current answers, paired with ``questions`` by position.
        retrieved_docs_list: For each question, the list of document strings
            to revise against (may be empty).
        temp, top_p, top_k: Sampling settings (``top_k`` is only forwarded to
            the local backend).
        eval_type: ``"local"`` selects vLLM; any other value selects the API.

    Returns:
        One cleaned revised answer per triple, in input order.
    """
    revise_prompt = '''
I want to revise the answer according to retrieved related text of the question.
You need to check whether the answer is correct.
If you find some errors in the answer, revise the answer to make it better.
If you find some necessary details are ignored, add it to make the answer more plausible according to the related text.

**IMPORTANT**
1. Keep the structure with multiple substantial paragraphs.
2. Use blank lines to separate paragraphs (press Enter twice).
3. If the original answer has \\boxed{...} at the end, you MUST keep it and update it if needed.
4. Only put the letter in the box, e.g. \\boxed{A}. There is only one correct answer.

Just output the revised paragraphs directly, including the \\boxed{} if present.
'''

    def render_context(docs):
        # Number the documents from 1, drop a leading numeric id from each,
        # and cap the rendered text at 1500 characters.
        rendered = '\n'.join(
            "Document {}:\n{}\n".format(number, re.sub(r'^\d+\s+', '', doc))
            for number, doc in enumerate(docs or (), start=1)
        )
        return rendered[:1500]

    prompts = []
    for question, answer, docs in zip(questions, answers, retrieved_docs_list):
        content_str = render_context(docs)
        prompts.append(
            f"{CHATGPT_SYSTEM_PROMPT}\n\n##Retrieved Text: {content_str}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}"
        )

    print(f"Revising {len(prompts)} answers in batch...")

    token_budget = 2048
    if eval_type != "local":
        client, model_name = model_info
        raw_revisions = generate_api_batch(
            client, model_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p
        )
    else:
        llm, tokenizer, model_short_name = model_info
        raw_revisions = generate_local_batch(
            llm, tokenizer, model_short_name, prompts,
            max_tokens=token_budget,
            temperature=temp,
            top_p=top_p,
            top_k=top_k
        )

    # Normalize separators in every revised answer.
    return list(map(clean_draft, raw_revisions))


def split_draft(draft, split_char='\n\n'):
    """Split a draft into paragraphs and pull out its final boxed answer.

    The draft is cleaned first. If it contains ``\\boxed{``, the last
    occurrence is extracted with brace matching (so nested braces are kept),
    and the text is cut at the FIRST occurrence so that no box remains in the
    paragraphs. The remaining text is split on ``split_char``; a paragraph
    shorter than 500 characters absorbs the following paragraphs, joined by a
    single space, until it reaches 500 characters or none are left.

    Returns:
        ``(paragraphs, box_answer)`` where ``box_answer`` is the full
        ``\\boxed{...}`` text, or ``""`` when there is no box or its braces
        are unbalanced.
    """
    marker = '\\boxed{'
    text = clean_draft(draft)

    # 1. Extract the last box and truncate the text at the first one.
    box_answer = ""
    box_start = text.rfind(marker)
    if box_start != -1:
        depth = 1
        for cursor in range(box_start + len(marker), len(text)):
            symbol = text[cursor]
            if symbol == '{':
                depth += 1
            elif symbol == '}':
                depth -= 1
                if depth == 0:
                    # Closing brace of the box itself.
                    box_answer = text[box_start:cursor + 1]
                    break
        text = text[:text.find(marker)].strip()

    # 2. Split into non-empty, stripped paragraphs.
    pieces = [chunk.strip() for chunk in text.split(split_char) if chunk.strip()]

    # 3. Merge short paragraphs forward.
    merged_paragraphs = []
    remaining = iter(pieces)
    for block in remaining:
        while len(block) < 500:
            follower = next(remaining, None)
            if follower is None:
                break
            block = block + " " + follower
        merged_paragraphs.append(block)

    return merged_paragraphs, box_answer

def preserve_box_answer(old_text, new_text):
    """Carry the boxed answer over when a revision dropped it.

    If ``new_text`` already contains a ``\\boxed{...}`` it is returned as is.
    Otherwise the last box found in ``old_text`` (if any) is appended to
    ``new_text``, preceded by a blank line unless ``new_text`` already ends
    with a newline.

    Returns:
        ``new_text``, possibly with the previous box appended.
    """
    box_re = re.compile(r'\\boxed\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL)

    if box_re.search(new_text):
        return new_text

    previous_boxes = box_re.findall(old_text)
    if not previous_boxes:
        return new_text

    separator = '' if new_text.endswith('\n') else '\n\n'
    return new_text + separator + previous_boxes[-1]


# ============ 主批量RAT函数 ============
def rat_batch(llm_or_client, model_info, data_items, args):
    """Run the batched RAT loop (draft -> query -> retrieve -> revise).

    All questions are drafted in one batch. Each draft is split into
    paragraphs, and the paragraphs are then processed position by position:
    for paragraph ``k`` every question that has such a paragraph gets a
    search query, a retrieval call, and ``args.max_revisions`` revision
    rounds (round ``r`` uses the ``r``-th retrieved document, or no document
    when fewer were returned). The answer text revised so far is carried
    into the next paragraph step.

    Args:
        llm_or_client: Forwarded unchanged to the batch helpers.
        model_info: Backend handle tuple forwarded to the batch helpers.
        data_items: Sequence of item dicts; ``"question"`` is required,
            ``"answer"``, ``"category"`` and ``"source"`` are read when present.
        args: Parsed options; this function reads ``eval_type``, ``temp``,
            ``top_p``, ``top_k``, ``split_char``, ``debug``, ``port``,
            ``topk`` and ``max_revisions``.

    Returns:
        One result dict per item (including ``final_answer``, ``full_output``
        and ``pred_ans``), or ``[]`` when there is nothing to process.
    """
    num_questions = len(data_items)

    print(f"\n{'='*60}")
    print(f"Starting RAT for {num_questions} questions")
    if args.eval_type == "local":
        print(f"Using local model with vLLM")
    else:
        print(f"Using API model with high concurrency mode (128 workers)")
    print(f"{'='*60}")

    # Guard against an empty workload.
    if num_questions == 0:
        print("ERROR: No questions to process!")
        return []

    # Step 1: generate the initial drafts in one batch.
    # top_k is only meaningful for the local backend.
    print(f"\n[Step 1] Generating initial drafts for all {num_questions} questions...")
    drafts = batch_get_drafts(llm_or_client, model_info, data_items, args.temp, args.top_p, 
                              args.top_k if args.eval_type == "local" else None, args.eval_type)

    # Initialize the per-question result records.
    results = []
    paragraph_counts = []  # number of paragraphs per question

    for i, (item, draft) in enumerate(zip(data_items, drafts)):
        paragraphs, box_answer = split_draft(draft, args.split_char)
        paragraph_counts.append(len(paragraphs))

        # Ground-truth answer.
        correct_answer = item.get("answer", "")

        # Category, falling back to "source" and then to "unknown".
        category = item.get("category") if item.get("category") else item.get("source", "unknown")

        results.append({
            "question": item["question"],
            "options": get_options_from_item(item),
            "answer": correct_answer,
            "category": category,  # category of the question
            "draft": draft,
            "current_answer": draft,
            "original_paragraphs": paragraphs.copy(),  # paragraphs of the initial draft
            "box_answer": box_answer,  # boxed answer, stored separately from the text
            "processed_content": "",  # text revised so far (without the current paragraph)
            "retrieval_history": [],
            "pred_ans": ""
        })

        # In debug mode, show the paragraph split of the first few questions.
        if hasattr(args, 'debug') and args.debug and i < 3:
            print(f"\n[DEBUG] Question {i+1} (Category: {category}): {item['question'][:60]}...")
            print(f"Original draft length: {len(draft)} chars")
            print(f"Number of paragraphs: {len(paragraphs)}")
            for j, para in enumerate(paragraphs):
                print(f"Paragraph {j+1} ({len(para)} chars): {para[:80]}...")
            if box_answer:
                print(f"✓ Found {box_answer} in draft")

    # Bail out if no record was created (e.g. no drafts came back).
    if not results:
        print("ERROR: No results generated!")
        return []

    # Report the distribution of paragraph counts.
    print(f"\nParagraph distribution:")
    from collections import Counter
    para_dist = Counter(paragraph_counts)
    for count, freq in sorted(para_dist.items()):
        print(f"  {count} paragraphs: {freq} questions ({freq/len(results)*100:.1f}%)")

    # The longest draft determines how many paragraph steps are run.
    max_paragraphs = max(paragraph_counts) if paragraph_counts else 0
    print(f"\nMaximum number of paragraphs: {max_paragraphs}")

    # Step 2: process the drafts paragraph position by paragraph position.
    for para_idx in range(max_paragraphs):
        print(f"\n{'='*40}")
        print(f"[Step {para_idx + 2}] Processing paragraph {para_idx + 1}/{max_paragraphs}")

        # Collect the questions that still have a paragraph at this position.
        questions_to_process = []
        answers_to_process = []
        indices_to_process = []

        for i, result in enumerate(results):
            if para_idx < len(result["original_paragraphs"]):
                # Current answer = revised content so far + this paragraph + boxed answer.
                if para_idx == 0:
                    # First paragraph: nothing has been processed yet.
                    current_answer = result["original_paragraphs"][0]
                else:
                    # Append this paragraph to the already revised content.
                    current_answer = result["processed_content"] + '\n\n' + result["original_paragraphs"][para_idx]

                # Re-attach the boxed answer, if there is one.
                if result["box_answer"]:
                    if not current_answer.endswith('\n'):
                        current_answer += '\n\n'
                    current_answer += result["box_answer"]

                result["current_answer"] = current_answer

                # Question text including its options.
                formatted_question = format_question_with_options(result["question"], result.get("options", []))

                questions_to_process.append(formatted_question)
                answers_to_process.append(current_answer)
                indices_to_process.append(i)

        if not questions_to_process:
            print(f"No questions have paragraph {para_idx + 1}, skipping...")
            continue

        print(f"Processing {len(questions_to_process)} questions for paragraph {para_idx + 1}...")

        # Generate the search queries in one batch.
        print(f"Generating {len(questions_to_process)} queries...")
        queries = batch_get_queries(llm_or_client, model_info, questions_to_process, answers_to_process, 
                                    args.temp, args.top_p, args.top_k if args.eval_type == "local" else None, 
                                    args.eval_type)

        # Retrieve documents for all queries in one request.
        print(f"Performing batch retrieval...")
        all_retrieved_docs, all_scores = batch_retrieve(queries, args.port, args.topk)

        # Record which query was issued for which paragraph.
        for idx, orig_idx in enumerate(indices_to_process):
            results[orig_idx]["retrieval_history"].append({
                "paragraph_idx": para_idx,
                "query": queries[idx] if idx < len(queries) else "",
            })

        # Run the revision rounds, each one as a batch over all active questions.
        for revision_round in range(args.max_revisions):
            print(f"\nRevision round {revision_round + 1}/{args.max_revisions}")

            # Round r revises against the r-th retrieved document only;
            # questions with fewer documents get an empty document list.
            docs_for_revision = []
            for idx, orig_idx in enumerate(indices_to_process):
                if idx < len(all_retrieved_docs) and revision_round < len(all_retrieved_docs[idx]):
                    docs_for_revision.append([all_retrieved_docs[idx][revision_round]])
                else:
                    docs_for_revision.append([])

            # Revise the current answers in one batch.
            current_answers = [results[idx]["current_answer"] for idx in indices_to_process]
            revised_answers = batch_revise_answers(
                llm_or_client, model_info,
                questions_to_process, 
                current_answers, 
                docs_for_revision,
                args.temp,
                args.top_p,
                args.top_k if args.eval_type == "local" else None,
                args.eval_type
            )

            # Adopt each non-empty revision; an empty one leaves the answer unchanged.
            for idx, orig_idx in enumerate(indices_to_process):
                if idx < len(revised_answers) and revised_answers[idx] and revised_answers[idx].strip():
                    # Normalize separators in the revised answer.
                    cleaned_answer = clean_draft(revised_answers[idx])

                    # Keep the boxed answer if the revision dropped it.
                    cleaned_answer = preserve_box_answer(results[orig_idx]["current_answer"], cleaned_answer)

                    results[orig_idx]["current_answer"] = cleaned_answer

                    # Track the boxed answer when the revision contains a single-letter box
                    # (the first such box in the text is used).
                    new_box_match = re.search(r'\\boxed\{([A-J])\}', cleaned_answer, re.IGNORECASE)
                    if new_box_match:
                        results[orig_idx]["box_answer"] = new_box_match.group(0)

        # After all rounds for this paragraph, store the revised content.
        for idx in indices_to_process:
            # Start from the full current answer (which may end with the box).
            current_full_answer = results[idx]["current_answer"]

            # Strip a trailing boxed answer to obtain the plain content.
            content_without_box = current_full_answer
            if results[idx]["box_answer"]:
                box_pattern = re.escape(results[idx]["box_answer"])
                content_without_box = re.sub(f'\\s*{box_pattern}\\s*$', '', content_without_box).strip()

            # This becomes the prefix for the next paragraph step.
            results[idx]["processed_content"] = content_without_box

    # Step 3: build the final answers and extract the predicted letters.
    print(f"\n[Final Step] Constructing final answers...")

    for result in results:
        # The current answer holds everything processed for this question.
        final_content = result["current_answer"]

        # Store the final answer.
        result["final_answer"] = final_content
        result["full_output"] = final_content  # duplicate field kept for the evaluation function

        # Extract the predicted option letter.
        result["pred_ans"] = extract_choice_answer(final_content, result.get("options", []))

        # In debug mode, show what was extracted.
        if hasattr(args, 'debug') and args.debug and result["pred_ans"]:
            print(f"\n[DEBUG] Category: {result['category']}, Extracted answer: {result['pred_ans']}")

    return results


def evaluate_merged_results(results, subjects):
    """Score the predictions per category (MMLUPRO style).

    A prediction counts as correct when both ``pred_ans`` and ``answer`` are
    non-empty and equal ignoring case. When ``subjects`` is non-empty only
    those categories are scored, and a listed subject without any result is
    reported with accuracy 0.0. The average is the unweighted mean of the
    per-subject accuracies, rounded to four decimals.

    Args:
        results: Result dicts carrying ``question``, ``category``,
            ``pred_ans``, ``answer`` and optionally ``options`` /
            ``final_answer``.
        subjects: Categories to report; an empty value reports every
            category found in ``results``.

    Returns:
        ``(metrics, all_results, avg_acc)``: accuracy per subject, the
        detailed rows per subject, and the mean accuracy.
    """
    rows_by_category = defaultdict(list)

    for record in results:
        category = record.get("category", "unknown")

        # Honour the subject filter, if one was given.
        if subjects and category not in subjects:
            continue

        pred = record.get("pred_ans", "")
        gt = record.get("answer", "")
        if pred and gt:
            is_correct = pred.upper() == gt.upper()
        else:
            is_correct = False

        rows_by_category[category].append({
            "question": record["question"],
            "options": record.get("options", []),
            "subject": category,
            "gt_answer": gt,
            "pred_answer": pred,
            "is_correct": is_correct,
            "full_output": record.get("final_answer", "")
        })

    # Per-subject accuracy, in the requested (or first-seen) order.
    metrics = {}
    all_results = {}
    report_order = subjects if subjects else rows_by_category.keys()
    for subject in report_order:
        rows = rows_by_category[subject]
        total_count = len(rows)
        correct_count = sum(1 for row in rows if row["is_correct"])
        acc = correct_count / total_count if total_count > 0 else 0.0
        metrics[subject] = acc
        all_results[subject] = rows
        print(f"[✓] {subject}: {acc:.2%} ({correct_count}/{total_count})")

    # Unweighted mean over the reported subjects.
    avg_acc = 0.0
    if metrics:
        avg_acc = round(sum(float(v) for v in metrics.values()) / len(metrics), 4)

    print(f"\n[✔] Average Accuracy: {avg_acc:.2%}")

    return metrics, all_results, avg_acc


def main() -> None:
    """Run the RAT evaluation end to end.

    Steps, in order: parse CLI arguments and echo them, verify the source
    file exists, initialise either a local vLLM model or an OpenAI-compatible
    API client (depending on ``args.eval_type``), load the JSONL questions in
    the ``[start_sample, end_sample)`` line range, filter by subject, apply the
    per-category and global sample limits, run ``rat_batch``, score the
    results with ``evaluate_merged_results`` and write a results JSONL file
    plus a metrics JSON file under the output directory.

    All failures are reported with ``print`` and cause an early ``return``;
    nothing is raised to the caller from the batch-processing stage.
    """
    print("=" * 80)
    print("RAT (Retrieval-Augmented Thoughts) System - MMLUPRO Version")
    print("支持本地模型和API模型评估")
    print("=" * 80)
    
    args = parse_args()
    
    # Echo the effective parameters to make runs easier to debug.
    print(f"\nParameters:")
    print(f"- Evaluation type: {args.eval_type}")
    print(f"- Source file: {args.src_file}")
    print(f"- Subjects: {args.subjects if args.subjects else 'All'}")
    if args.eval_type == "local":
        print(f"- Model path: {args.model_path}")
        print(f"- GPU ID: {args.gpu_id}")
        print(f"- GPU memory utilization: {args.gpu_memory_rate}")
    else:
        print(f"- API Model: {args.model_name}")
        print(f"- API Key Name: {args.api_key_name}")
    print(f"- Samples: {args.start_sample} to {args.end_sample}")
    print(f"- Max samples: {args.max_samples}")
    print(f"- Subset num: {args.subset_num}")
    print(f"- Max revisions: {args.max_revisions}")
    
    # Bail out early if the source file does not exist.
    if not os.path.exists(args.src_file):
        print(f"\nERROR: Source file not found: {args.src_file}")
        return
    
    # Initialise the model backend according to the evaluation type.
    if args.eval_type == "local":
        # Local mode needs vLLM to have imported successfully.
        if not VLLM_AVAILABLE:
            print("\nERROR: vLLM not installed. Cannot use local mode.")
            print("Please install vLLM: pip install vllm")
            return
            
        # Local mode needs a model path.
        if not args.model_path:
            print("\nERROR: --model_path required for local mode")
            return
            
        # Select the GPUs. This is set before torch.cuda.device_count() is
        # called below, presumably so the count reflects only these devices.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
        
        # Initialise the local model, sharded across all visible GPUs.
        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing local model...")
        print(f"- Number of GPUs: {num_gpus}")
        print(f"- Model path: {args.model_path}")
        
        llm = LLM(
            model=args.model_path, 
            tensor_parallel_size=num_gpus,
            gpu_memory_utilization=args.gpu_memory_rate, 
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        
        # The last path component (lower-cased) names the output directory.
        # NOTE(review): a model_path ending in '/' yields an empty name.
        model_short_name = args.model_path.split('/')[-1].lower()
        # Local mode: model_info is (llm, tokenizer, short name).
        model_info = (llm, tokenizer, model_short_name)
        llm_or_client = llm
        
    else:  # api
        # API mode needs the openai package to have imported successfully.
        if not OPENAI_AVAILABLE:
            print("\nERROR: openai package not installed. Cannot use API mode.")
            print("Please install openai: pip install openai")
            return
            
        # Initialise the API client.
        # NOTE(review): an api_key_name missing from API_CONFIG raises
        # KeyError here rather than printing an error like the checks above;
        # confirm parse_args restricts the accepted values.
        api_config = API_CONFIG[args.api_key_name]
        client = OpenAI(
            api_key=api_config["api_key"],
            base_url=api_config["base_url"]
        )
        
        # Map the short CLI model name to the full API model identifier.
        # NOTE(review): same unguarded-lookup caveat as API_CONFIG above.
        model_name = MODEL_MAP[args.model_name]
        model_short_name = args.model_name
        # API mode: model_info is (client, full model name) - a different
        # shape from the local-mode tuple.
        model_info = (client, model_name)
        llm_or_client = client
        
        print(f"\nUsing API model: {model_name}")
        print(f"API endpoint: {api_config['base_url']}")
        # NOTE(review): 128 is a literal in this message, not read from args;
        # verify it matches the worker count actually used for API calls.
        print(f"Max workers: 128")
    
    print(f"Retrieval port: {args.port}")
    
    # Load the data: one JSON object per line, keeping only lines whose
    # 0-based index falls in [start_sample, end_sample).
    print(f"\nLoading data from: {args.src_file}")
    data_ori = []
    try:
        with open(args.src_file, "r", encoding='utf-8') as f:
            for i, line in enumerate(f):
                if args.start_sample <= i < args.end_sample:
                    try:
                        obj_ori = json.loads(line.strip())
                        data_ori.append(obj_ori)
                    except json.JSONDecodeError as e:
                        # Malformed lines are skipped with a warning.
                        print(f"Warning: Failed to parse line {i+1}: {e}")
                # Stop reading once the last requested line has been seen.
                if i >= args.end_sample - 1:
                    break
    except Exception as e:
        print(f"ERROR: Failed to read file: {e}")
        return
    
    # Keep only the requested subjects. The category is taken from
    # "category", falling back to "source" when that is missing or empty.
    if args.subjects:
        filtered_data = []
        for item in data_ori:
            cat = item.get("category") if item.get("category") else item.get("source")
            if cat in args.subjects:
                filtered_data.append(item)
        data_ori = filtered_data
        print(f"Filtered to {len(data_ori)} questions from specified subjects")
    
    # Apply the per-category subset_num limit (if specified).
    if args.subset_num != -1 and args.subset_num > 0:
        # Group items by category.
        category_data = defaultdict(list)
        for item in data_ori:
            cat = item.get("category") if item.get("category") else item.get("source", "unknown")
            category_data[cat].append(item)
        
        # Take the first subset_num items of each category; the result is
        # ordered category by category rather than in file order.
        data_ori = []
        for cat, items in category_data.items():
            data_ori.extend(items[:args.subset_num])
        
        print(f"Applied subset_num={args.subset_num}, total questions: {len(data_ori)}")
    
    # Apply the global max_samples limit (0 or less means no limit).
    if args.max_samples > 0:
        data_ori = data_ori[:args.max_samples]
    
    print(f"Successfully loaded {len(data_ori)} questions")
    
    if len(data_ori) == 0:
        print("\nERROR: No data loaded! Check your file and parameters.")
        return
    
    # Print the category distribution.
    category_counts = defaultdict(int)
    for item in data_ori:
        cat = item.get("category") if item.get("category") else item.get("source", "unknown")
        category_counts[cat] += 1
    
    print("\nCategory distribution:")
    for cat, count in sorted(category_counts.items()):
        print(f"  {cat}: {count} questions")
    
    # Print the first few questions as a sanity check.
    print("\nSample questions:")
    for i, item in enumerate(data_ori[:3]):
        cat = item.get("category") if item.get("category") else item.get("source", "unknown")
        print(f"{i+1} [{cat}]: {item.get('question', 'NO QUESTION')[:80]}...")
        options = get_options_from_item(item)
        if options:
            print(f"   Options: {len(options)} choices")
    
    # Set up the output paths.
    t = time.localtime()
    output_dir = f'/eval/outputs/mmlupro/{model_short_name}.rat_rag'
    
    # NOTE(review): the name is "<month>.<day>,<hour>:<minute>.jsonl" with no
    # year, seconds or zero padding, and the file is opened with "w" below, so
    # two runs in the same minute overwrite each other. The ':' also makes the
    # name invalid on Windows.
    result_json_name = f'{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.jsonl'
    result_path = os.path.join(output_dir, result_json_name)
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"\nOutput will be saved to: {result_path}")
    
    # Batch processing.
    start_time = time.time()
    try:
        results = rat_batch(llm_or_client, model_info, data_ori, args)
        
        if not results:
            print("\nERROR: No results generated!")
            return
        
        # Evaluate the results. Without --subjects, every category seen in
        # the loaded data is evaluated.
        # NOTE(review): category_counts keys use the category/source fallback,
        # while evaluate_merged_results reads only result["category"]; verify
        # rat_batch fills "category" consistently.
        subjects_to_eval = args.subjects if args.subjects else list(category_counts.keys())
        metrics, all_results, avg_acc = evaluate_merged_results(results, subjects_to_eval)
        
        # Save the results, one JSON object per line.
        print(f"\nSaving results to {result_path}")
        with open(result_path, "w", encoding='utf-8') as f:
            for result in results:
                # Persist only the fields that are needed.
                save_result = {
                    "question": result["question"],
                    "options": result.get("options", []),
                    "answer": result["answer"],
                    "category": result.get("category", "unknown"),
                    "draft": result["draft"],
                    "final_answer": result["final_answer"],
                    "pred_ans": result["pred_ans"],
                    # Same case-insensitive comparison as evaluate_merged_results.
                    "is_correct": result["pred_ans"].upper() == result["answer"].upper() if result["pred_ans"] and result["answer"] else False,
                    "retrieval_history": result["retrieval_history"]
                }
                f.write(json.dumps(save_result, ensure_ascii=False) + "\n")
        
        # Run statistics.
        elapsed_time = time.time() - start_time
        avg_paragraphs = sum(len(r['original_paragraphs']) for r in results) / len(results)
        
        # Count how many final answers contain a "\boxed{" marker.
        answers_with_box = sum(1 for r in results if '\\boxed{' in r.get('final_answer', ''))
        
        print(f"\nStatistics:")
        print(f"- Total questions: {len(results)}")
        print(f"- Average paragraphs per answer: {avg_paragraphs:.2f}")
        print(f"- Answers with \\boxed{{}}: {answers_with_box}/{len(results)}")
        print(f"- Total retrievals: {sum(len(r['retrieval_history']) for r in results)}")
        print(f"- Processing time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")
        print(f"- Average time per question: {elapsed_time/len(results):.2f} seconds")
        
        # Save the evaluation metrics next to the results file. Accuracies
        # are stored as 4-decimal strings, not numbers.
        # NOTE(review): opened without an explicit encoding, unlike the
        # results file above.
        metrics_path = os.path.join(output_dir, f"{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.metrics.json")
        with open(metrics_path, "w") as f:
            save_metrics = {
                "average_accuracy": f"{avg_acc:.4f}",
                "per_subject_accuracy": {k: f"{v:.4f}" for k, v in metrics.items()},
                "total_questions": len(results),
                "processing_time_seconds": elapsed_time
            }
            json.dump(save_metrics, f, indent=2)
        
        print(f"\nMetrics saved to: {metrics_path}")
        
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and stop.
        # Everything from rat_batch through the metrics write is covered, so
        # a failure in evaluation or saving is reported the same way.
        print(f"\nError during batch processing: {e}")
        import traceback
        traceback.print_exc()
        return
    
    print(f"\nProcessing complete. Results saved to: {result_path}")


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()