import argparse
import json
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import torch
import os
import re
import copy
import requests
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
import httpx
from searcho1_prompts import get_webpage_to_reasonchain_instruction

# Short API model names accepted by --model_path, mapped to the concrete
# model identifier sent to the OpenAI-compatible endpoint.
MODEL_MAP = {
    "4.1mini": "gpt-4.1-mini-2025-04-14",
    "4omini": "gpt-4o-mini-2024-07-18",
}

# Short names that are served through the API rather than loaded with vLLM.
API_MODELS = list(MODEL_MAP)

# OpenAI-compatible endpoints keyed by --key_source. Each API key is read from
# the environment and falls back to "" when the variable is unset.
API_CONFIGS = {
    source: {"api_key": os.getenv(env_var, ""), "base_url": endpoint}
    for source, env_var, endpoint in (
        ("77", "OPENAI_API_KEY_77", "https://api.key77qiqi.com/v1"),
        ("xty", "OPENAI_API_KEY_XTY", "https://svip.xty.app/v1"),
    )
}

def get_default_topk(model_path):
    """Return the default number of documents retrieved per search query.

    API models (names listed in ``API_MODELS``) get 10 documents; any other
    ``model_path`` (a local model) gets 5.
    """
    is_api_model = model_path in API_MODELS
    return 10 if is_api_model else 5
        
def get_default_max_tokens(model_path):
    """Return the default max-new-token budget for ``model_path``.

    API and local models currently share the same budget (1536); the branch
    on ``API_MODELS`` is kept so the two can diverge without touching callers.
    """
    api_budget = local_budget = 1536
    return api_budget if model_path in API_MODELS else local_budget
    
def parse_args():
    """Parse the command-line options of the evaluation script.

    ``--topk`` and ``--max_tokens`` default to ``None`` so the caller can fill
    in model-dependent defaults. ``--model_path`` is mandatory.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="使用 search-o1 方法评估模型在 MedQA 数据集上的表现")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--start_sample", dict(type=int, default=-1, help="开始样本索引")),
        ("--end_sample", dict(type=int, default=100000, help="结束样本索引")),
        ("--src_file", dict(type=str, default="/eval/dataset/medqa/medqa_test_v0.jsonl", help="输入数据集文件路径")),
        ("--gpu_id", dict(type=str, default="0,1,2,3", help="GPU ID 列表")),
        ("--model_path", dict(type=str, required=True, help="本地模型路径或 API 模型名称（4.1mini, 4omini）")),
        ("--gpu_memory_rate", dict(type=float, default=0.95, help="GPU 内存利用率")),
        ("--port", dict(type=str, default="5005", help="检索服务端口")),
        ("--temp", dict(type=float, default=0.3, help="生成温度")),
        ("--top_p", dict(type=float, default=0.5, help="Top-p 采样值")),
        ("--top_k", dict(type=int, default=20, help="Top-k 采样值")),
        ("--topk", dict(type=int, default=None, help="检索返回的文档数")),
        ("--max_tokens", dict(type=int, default=None, help="最大生成 token 数")),
        ("--max_rounds", dict(type=int, default=10, help="最大检索轮数")),
        ("--key_source", dict(type=str, default="77", help="API key 来源（77 或 xty）")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def get_cot_prompt(question, options, tokenizer, args, is_api=False):
    """Build the initial Search-o1 chain-of-thought prompt for one question.

    Args:
        question: question text.
        options: either a sequence of option texts (labelled A, B, C, ... in
            order, at most 10 are rendered) or a mapping of label -> option
            text (e.g. MedQA's ``{"A": "...", "B": "..."}``), rendered with its
            own labels in iteration order.
        tokenizer: HF tokenizer used to apply the chat template in local mode;
            unused when ``is_api`` is True.
        args: parsed CLI args; ``args.max_rounds`` is quoted in the prompt as
            the search-attempt limit.
        is_api: when True, return the raw system+user text (the API client
            wraps it in a chat message itself).

    Returns:
        The prompt string (chat-templated with a generation prompt in local mode).
    """
    option_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
    if isinstance(options, dict):
        # Iterating a dict yields only its keys, which would render
        # "A. A\nB. B"; use the stored labels together with their texts.
        labelled_options = list(options.items())
    else:
        labelled_options = list(zip(option_letters, options))
    option_str = "\n".join([f"{l}. {o}" for l, o in labelled_options])

    sys_prompt = (
        "You are a reasoning assistant with the ability to perform web searches to help "
        "you answer the user's question accurately. You have special tools:\n\n"
        "- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\n"
        "Then, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\n"
        f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to {args.max_rounds}.\n\n"
        "Once you have all the information you need, continue your reasoning.\n\n"
        "Example:\n"
        "Question: \"Which class of drug is most appropriate for managing diabetic nephropathy in hypertensive patients?\"\n"
        "Options:\n"
        "A. ACE inhibitors\n"
        "B. Beta blockers\n"
        "C. Diuretics\n"
        "D. Calcium channel blockers\n\n"
        "Assistant thinking steps:\n"
        "- I should check current guidelines or studies about drugs used for diabetic nephropathy in hypertensive patients.\n\n"
        "Assistant:\n"
        "<|begin_search_query|>first-line treatment for diabetic nephropathy with hypertension<|end_search_query|>\n\n"
        "(System returns processed information from relevant medical sources)\n\n"
        "Assistant continues reasoning with the new information...\n\n"
        "Remember:\n"
        "- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n"
        "- When done searching, continue your reasoning.\n\n"
    )

    user_prompt = (
        'Please answer the following question. You should think step by step to solve it.\n\n'
        'Provide your final answer in the format: the correct answer is: A, B, C, D, etc.\n\n'
        f'Question:\n{question}\n\nOptions:\n{option_str}\n\n'
    )

    if is_api:
        return sys_prompt + user_prompt

    # Local model: everything goes into one user turn, formatted with the
    # tokenizer's chat template and ending with the generation prompt.
    messages_chat = [{"role": "user", "content": sys_prompt + user_prompt}]
    return tokenizer.apply_chat_template(messages_chat, tokenize=False, add_generation_prompt=True)

# Search markup shared by all models: the tokens that delimit a search query
# and the tokens that wrap injected search results.
SEARCH_FORMATS = dict(
    default=dict(
        begin_token="<|begin_search_query|>",
        end_token="<|end_search_query|>",
        doc_begin="<|begin_search_result|>",
        doc_end="<|end_search_result|>",
    ),
)

def match_search_format(generated_text, stop_reason):
    """Detect whether a generation ended on a search query.

    Returns:
        ``(begin_token, end_token)`` when ``generated_text`` contains the
        begin-query token and ``stop_reason`` equals the end-query token;
        otherwise ``None``.
    """
    markup = SEARCH_FORMATS["default"]
    opener = markup["begin_token"]
    closer = markup["end_token"]
    if opener not in generated_text or stop_reason != closer:
        return None
    return opener, closer

# Matches "the correct answer is[:] X". The phrase is case-insensitive (scoped
# inline flag), but X must be an UPPERCASE option letter A-J standing as a
# whole word, optionally wrapped in markdown emphasis or brackets ("**C**",
# "(B)", "[D]"). Keeping the letter case-sensitive stops prose such as "the
# correct answer is a beta blocker" / "... is definitely ..." from being read
# as option A / D.
_ANSWER_PATTERN = re.compile(r"(?i:the correct answer is)\s*:?\s*[*(\[]*\s*([A-J])\b")


def extract_answer_math(s):
    """Return the option letter from the last "the correct answer is: X" in ``s``.

    Args:
        s: generated text to scan.

    Returns:
        The option letter (``"A"``-``"J"``) of the last answer statement, or
        ``''`` when no answer statement is found.
    """
    matches = _ANSWER_PATTERN.findall(s)
    return matches[-1] if matches else ''

def initialize_openai_client(key_source: str) -> OpenAI:
    """Create an OpenAI client for the endpoint registered under ``key_source``.

    Args:
        key_source: a key of ``API_CONFIGS`` (e.g. ``"77"`` or ``"xty"``).

    Returns:
        An ``OpenAI`` client whose httpx transport targets the configured
        base URL and follows redirects.

    Raises:
        ValueError: if ``key_source`` is unknown, or its API key is empty
            (the corresponding environment variable was not set).
    """
    if key_source not in API_CONFIGS:
        raise ValueError(f"无效的 key_source: {key_source}. 必须是 {list(API_CONFIGS.keys())} 之一")
    config = API_CONFIGS[key_source]
    if not config["api_key"]:
        # Fail fast: with an empty key every request would fail authentication,
        # those failures are only printed by the generation helper, and the run
        # would quietly report empty answers for every sample.
        raise ValueError(f"key_source={key_source} 的 API key 为空，请先设置对应的环境变量")
    return OpenAI(
        base_url=config["base_url"],
        api_key=config["api_key"],
        http_client=httpx.Client(base_url=config["base_url"], follow_redirects=True),
    )

def generate_local(llm, prompts, sampling_params):
    """Run one batched generation on the local vLLM engine.

    Returns whatever ``llm.generate(prompts, sampling_params)`` returns,
    unchanged.
    """
    return llm.generate(prompts, sampling_params)

def get_stop_reason(text, stop_tokens):
    """Infer a stop reason for API output, which carries no stop metadata here.

    Returns:
        The first token in ``stop_tokens`` that ``text`` ends with. Failing
        that, ``"<|end_search_query|>"`` if ``text`` has exactly one more
        begin-query marker than end-query markers (an unclosed query).
        Otherwise ``''``.
    """
    trailing = next((tok for tok in stop_tokens if text.endswith(tok)), None)
    if trailing is not None:
        return trailing

    opened = text.count("<|begin_search_query|>")
    closed = text.count("<|end_search_query|>")
    # One unmatched opener means the output stopped inside a search query.
    return "<|end_search_query|>" if opened - closed == 1 else ''

def get_doc_insertion_text(analysis):
    """Build the text spliced in after a search query.

    The end-query token goes first (generation stopped on it, and it is not
    part of the generated text), followed by ``analysis`` wrapped in the
    search-result tokens.
    """
    markup = SEARCH_FORMATS["default"]
    pieces = [
        f"{markup['end_token']}", "\n\n",
        f"{markup['doc_begin']}", "\n",
        f"{analysis}", "\n",
        f"{markup['doc_end']}", "\n\n",
    ]
    return "".join(pieces)

def generate_api_single(client, model, prompt, max_tokens, temperature, top_p, stop=None,
                        continuation_mode=False, max_attempts=3, retry_delay=1.0):
    """Run one chat-completion request and return the generated text.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: model identifier passed to the API.
        prompt: full prompt text.
        max_tokens, temperature, top_p, stop: forwarded to the API.
        continuation_mode: when True, wrap ``prompt`` in a system/user pair
            asking the model to continue the text rather than answer it.
        max_attempts: total number of tries before giving up (values < 1 are
            treated as 1).
        retry_delay: base back-off in seconds; doubles after each failed try.

    Returns:
        The completion text, or ``""`` if every attempt raised or the API
        returned no message content.
    """
    if continuation_mode:
        # Continuation mode: explicit instruction to resume the reasoning chain.
        messages = [
            {
                "role": "system",
                "content": "You are completing a search-o1 reasoning chain. Continue exactly from where the text ends, following the established pattern. Do not add any meta-commentary or repeat previous content."
            },
            {
                "role": "user",
                "content": f"Continue writing from where this text ends:\n\n{prompt}\n\nContinue:"
            }
        ]
    else:
        messages = [{"role": "user", "content": prompt}]

    attempts = max(1, int(max_attempts))
    for attempt in range(1, attempts + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                stop=stop
            )
            # ``content`` can be None (e.g. an empty or filtered completion);
            # callers run string methods on the result, so normalise to "".
            return response.choices[0].message.content or ""
        except Exception as e:
            print(f"API 调用失败: {e}")
            if attempt < attempts:
                # Exponential back-off helps with transient rate limiting when
                # many requests are issued concurrently.
                time.sleep(retry_delay * 2 ** (attempt - 1))
    return ""

def generate_api(client, model, prompts, max_tokens, temperature, top_p, stop,
                 continuation_modes=None, max_workers=128):
    """Generate completions for ``prompts`` concurrently through the API.

    Args:
        client, model, max_tokens, temperature, top_p, stop: forwarded to
            ``generate_api_single`` for every prompt.
        prompts: prompt strings.
        continuation_modes: per-prompt booleans; ``True`` sends that prompt in
            continuation mode. Defaults to all ``False``.
        max_workers: maximum number of concurrent API requests.

    Returns:
        List of generated strings aligned index-for-index with ``prompts``.

    Raises:
        ValueError: if ``continuation_modes`` and ``prompts`` differ in length.
            ``zip`` would otherwise silently drop prompts and misalign the
            results with the caller's samples.
    """
    prompts = list(prompts)
    if continuation_modes is None:
        continuation_modes = [False] * len(prompts)
    else:
        continuation_modes = list(continuation_modes)
        if len(continuation_modes) != len(prompts):
            raise ValueError(
                f"continuation_modes 长度 ({len(continuation_modes)}) 与 prompts 长度 ({len(prompts)}) 不一致"
            )
    if not prompts:
        return []

    def generate_single_wrapper(prompt, continuation_mode):
        return generate_api_single(client, model, prompt, max_tokens, temperature, top_p, stop, continuation_mode)

    # executor.map yields results in submission order, preserving alignment.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(generate_single_wrapper, prompts, continuation_modes))

def should_use_continuation_mode(prompt, gen_text_store):
    """Decide whether an API sample should be resumed in continuation mode.

    Args:
        prompt: the sample's current chat prompt (not used by the decision).
        gen_text_store: everything generated for the sample so far.

    Returns:
        True when something has been generated and its stripped text neither
        ends on a search boundary (end of a search result or search query)
        nor contains "the correct answer is:" within its last 100 characters.
        False otherwise.
    """
    if not gen_text_store:
        return False
    tail = gen_text_store.strip()
    at_search_boundary = tail.endswith(("<|end_search_result|>", "<|end_search_query|>"))
    has_answer = "the correct answer is:" in tail[-100:]
    return not (at_search_boundary or has_answer)

def _finished_record(question, options, answer, generated_text, stop_reason_final, pred_ans, round_count):
    """Build the record kept for a sample that has left the reasoning loop."""
    return {
        "question": question,
        "options": options,
        "answer": answer,
        "generated_text": generated_text,
        "stop_reason_final": stop_reason_final,
        "pred_ans": pred_ans,
        "round_count": round_count,
    }


def _resolve_result_path(args, model_short_name):
    """Choose the per-run results JSONL path from dataset and model, creating its directory."""
    t = time.localtime()
    if 'medqa' in args.src_file and 'mmlupro' not in args.src_file:
        output_dir = f'/eval/outputs/medqa/searcho1/{model_short_name}_medqa.rag'
    elif 'mmlupro_med' in args.src_file:
        output_dir = f'/eval/outputs/medqa/searcho1/{model_short_name}_mmlupro-med.rag'
    else:
        # Default output directory
        output_dir = f'/eval/outputs/medqa/searcho1/{model_short_name}.rag'
    result_json_name = f'{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.jsonl'
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, result_json_name)


def _retrieve_documents(url, queries, topk):
    """POST ``queries`` to the retrieval service, retrying until it answers HTTP 200.

    Returns the decoded JSON body. Network errors and non-200 responses are
    printed and retried after a 1 s pause, with no upper bound on attempts.
    """
    data_req = {"queries": queries, "k": topk}
    while True:
        try:
            response = requests.post(url, json=data_req, timeout=10)
        except requests.RequestException as e:
            print(f"检索错误: {e}")
            time.sleep(1)
            continue
        if response.status_code == 200:
            return response.json()
        print(f"检索错误: HTTP {response.status_code}")
        time.sleep(1)


def _normalize_answer_label(answer):
    """Map a gold answer to an option letter: digit strings 0-25 -> 'A'-'Z', others upper-cased."""
    answer = str(answer).strip()
    if answer.isdigit():
        answer_idx = int(answer)
        if 0 <= answer_idx < 26:
            answer = chr(ord('A') + answer_idx)
    else:
        answer = answer.upper()
    return answer


def _compute_accuracy(records):
    """Return ``(correct, total, accuracy)`` over finished-sample records."""
    correct = sum(
        1
        for record in records
        if record.get("pred_ans", "").strip().upper() == _normalize_answer_label(record.get("answer", ""))
    )
    total = len(records)
    return correct, total, (correct / total if total > 0 else 0)


def main():
    """Run the Search-o1 multi-round retrieval-augmented evaluation.

    Each round (at most ``args.max_rounds``):
      1. generate a continuation for every still-active sample;
      2. classify each output. A sample is *finished* when an answer was
         extracted, the last round was reached, its query was empty, or it
         stopped without a query. It *needs retrieval* when it ended on a
         search query. In API mode only, it *needs a plain continuation* when
         it was cut off mid-reasoning;
      3. for retrieval samples, query the local retrieval service, have the
         model condense the documents, and splice the result into that
         sample's prompt;
      4. append this round's finished samples to the results JSONL.
    Finally, accuracy over all finished samples is printed and saved next to
    the results file as ``*_eval.json``.
    """
    print("=Begin="*10)
    args = parse_args()
    if args.max_tokens is None:
        args.max_tokens = get_default_max_tokens(args.model_path)
    if args.topk is None:
        args.topk = get_default_topk(args.model_path)
        print(f"根据模型类型自动设置 topk = {args.topk}")
    # NOTE(review): torch/vllm are imported at module load, before this is set;
    # this relies on CUDA being initialised lazily afterwards — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # Short model name used in output file paths.
    model_short_name = args.model_path.split('/')[-1].lower() if args.model_path not in API_MODELS else args.model_path

    # Load the [start_sample, end_sample) slice of the dataset.
    data = []
    with open(args.src_file, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if args.start_sample <= i < args.end_sample:
                item = json.loads(line)
                # Some datasets store the gold label under "answer_idx".
                if "answer_idx" in item:
                    item["answer"] = item["answer_idx"]
                data.append(item)
            if i >= args.end_sample - 1:
                break
    print("数据总数: ", len(data))

    # Initialise either the API client or the local vLLM engine.
    client = model = llm = tokenizer = None
    if args.model_path in API_MODELS:
        client = initialize_openai_client(args.key_source)
        model = MODEL_MAP[args.model_path]
        is_api = True
        print(f"使用 API 模式评估模型: {model}")
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        llm = LLM(
            model=args.model_path,
            tensor_parallel_size=torch.cuda.device_count(),
            gpu_memory_utilization=args.gpu_memory_rate,
            trust_remote_code=True
        )
        is_api = False

    # Per-sample state carried between rounds.
    continued_answer = []
    for item in data:
        prompt = get_cot_prompt(item["question"], item["options"], tokenizer, args, is_api)
        continued_answer.append({
            "initial_prompt": prompt,  # never changes; base for API continuation prompts
            "chat_prompt": prompt,     # grows with generated text and search results
            "question": item["question"],
            "options": item["options"],
            "answer": item["answer"],
            "gen_text_store": "",      # everything generated/inserted so far
            "round_count": 0,
        })

    stop_tokens = ["<|end_search_query|>"]

    # Sampling parameters for the local model's reasoning generation.
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
        stop=stop_tokens
    )

    # Resolve the output file up front so result_path is always defined, even
    # when no round runs (max_rounds <= 0) or nothing ever finishes.
    result_path = _resolve_result_path(args, model_short_name)

    finished_all_list = []
    for k in range(args.max_rounds):
        if not continued_answer:
            break

        # Build this round's prompts; API samples cut off mid-reasoning are
        # resent in continuation mode with their full generation history.
        prompts = []
        continuation_modes = []
        for item in continued_answer:
            if is_api and should_use_continuation_mode(item["chat_prompt"], item["gen_text_store"]):
                prompts.append(item["initial_prompt"] + item["gen_text_store"])
                continuation_modes.append(True)
            else:
                prompts.append(item["chat_prompt"])
                continuation_modes.append(False)

        if is_api:
            generated_texts = generate_api(client, model, prompts, args.max_tokens, args.temp, args.top_p, stop_tokens, continuation_modes)
            outputs = [{'text': text, 'stop_reason': get_stop_reason(text, stop_tokens)} for text in generated_texts]
        else:
            raw_outputs = generate_local(llm, prompts, sampling_params)
            outputs = [{'text': o.outputs[0].text, 'stop_reason': o.outputs[0].stop_reason} for o in raw_outputs]

        finished_texts = []
        continued_texts = []
        query_list = []
        prev_reasonings = []
        # Positions in continued_texts of samples that issued a query. API
        # continuation samples are also appended to continued_texts, so search
        # results must be routed through these indices, not list order.
        search_indices = []

        for item, output in zip(continued_answer, outputs):
            prompt = item["chat_prompt"]
            question = item["question"]
            options = item["options"]
            answer = item["answer"]
            generated_text = output['text']
            stop_reason = output['stop_reason']
            round_count = item["round_count"] + 1
            # Full reasoning chain: earlier rounds plus this round's output.
            current_full_text = item["gen_text_store"] + generated_text

            if k == args.max_rounds - 1:
                # Last allowed round: finish regardless, keeping any answer found.
                finished_texts.append(_finished_record(
                    question, options, answer, current_full_text, "max_rounds_reached",
                    extract_answer_math(current_full_text) or "I don't know.", round_count))
                continue

            pred_ans = extract_answer_math(current_full_text)
            matched = match_search_format(generated_text, stop_reason)

            if pred_ans:
                finished_texts.append(_finished_record(
                    question, options, answer, current_full_text, "finished", pred_ans, round_count))
            elif matched:
                begin_token, end_token = matched
                query = generated_text.split(begin_token)[-1].split(end_token)[0].strip()
                if not query:
                    finished_texts.append(_finished_record(
                        question, options, answer, current_full_text, "empty_query",
                        "I don't know.", round_count))
                    continue
                query_list.append(query)
                # Reasoning preceding the *latest* query (including earlier
                # rounds' search results), not just the text before the first one.
                prev_reasonings.append(current_full_text.rsplit(begin_token, 1)[0])
                search_indices.append(len(continued_texts))
                continued_texts.append({
                    "initial_prompt": item["initial_prompt"],
                    "chat_prompt": prompt + generated_text.strip(),
                    "options": options,
                    "answer": answer,
                    "question": question,
                    "stop_reason": stop_reason,
                    "gen_text_store": current_full_text.strip(),
                    "round_count": round_count,
                })
            elif is_api and len(generated_text.strip()) > 10:
                # API output cut off mid-reasoning: continue next round.
                continued_texts.append({
                    "initial_prompt": item["initial_prompt"],
                    "chat_prompt": prompt + generated_text.strip(),
                    "options": options,
                    "answer": answer,
                    "question": question,
                    "gen_text_store": current_full_text.strip(),
                    "round_count": round_count,
                })
            else:
                # Stopped without an answer or query (or produced too little).
                finished_texts.append(_finished_record(
                    question, options, answer, current_full_text, "incomplete",
                    "I don't know.", round_count))

        print("=="*80)

        if query_list:
            url_wiki = f"http://localhost:{args.port}/queries"
            try:
                result = _retrieve_documents(url_wiki, query_list, args.topk)
                retrieved_queries = result["queries"]
                answers = result["answers"]
                analysis_prompts = []
                for query, docs, prev_reasoning in zip(retrieved_queries, answers, prev_reasonings):
                    formatted_docs = "\n".join([f"文档 {j+1}: {doc}" for j, doc in enumerate(docs)])
                    analysis_prompt = get_webpage_to_reasonchain_instruction(
                        prev_reasoning=prev_reasoning,
                        search_query=query,
                        document=formatted_docs
                    )
                    if is_api:
                        analysis_prompts.append(analysis_prompt)
                    else:
                        messages = [{'role': 'user', 'content': analysis_prompt}]
                        analysis_prompts.append(
                            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

                # Condense the retrieved documents into reasoning-chain notes.
                if is_api:
                    analyses = generate_api(client, model, analysis_prompts, 512, 0.7, 0.8, stop=None)
                else:
                    analysis_sampling_params = SamplingParams(
                        temperature=0.7,
                        top_p=0.8,
                        top_k=args.top_k,
                        max_tokens=512
                    )
                    analysis_outputs = llm.generate(analysis_prompts, analysis_sampling_params)
                    analyses = [o.outputs[0].text for o in analysis_outputs]

                if len(analyses) != len(search_indices):
                    raise ValueError(f"检索结果数量不匹配: {len(analyses)} != {len(search_indices)}")
            except Exception as e:
                # Parsing or analysis failed: finish the searching samples as
                # retrieve_error rather than continuing them without results.
                print(f"检索错误: {e}")
                failed = set(search_indices)
                for idx in search_indices:
                    failed_item = continued_texts[idx]
                    finished_texts.append(_finished_record(
                        failed_item["question"], failed_item["options"], failed_item["answer"],
                        failed_item["gen_text_store"], "retrieve_error", "I don't know.",
                        failed_item["round_count"]))
                continued_texts = [t for j, t in enumerate(continued_texts) if j not in failed]
            else:
                for idx, analysis in zip(search_indices, analyses):
                    # Closes the query with <|end_search_query|> and wraps the analysis.
                    doc_insertion = get_doc_insertion_text(analysis)
                    continued_texts[idx]["chat_prompt"] += doc_insertion
                    continued_texts[idx]["gen_text_store"] += doc_insertion

        finished_all_list.extend(finished_texts)
        print("=="*80)
        print(f"轮次: {k}, 新完成: {len(finished_texts)}, 总完成: {len(finished_all_list)}, 继续: {len(continued_texts)}")
        print("开始写入轮次: ", k)
        print("=="*80)

        if finished_texts:
            with open(result_path, "a", encoding="utf-8") as f:
                for text in finished_texts:
                    output_data = {
                        "question": text["question"],
                        "options": text["options"],
                        "answer": text["answer"],
                        "generated_text": text["generated_text"],  # full reasoning chain
                        "pred_ans": text["pred_ans"],
                        "stop_reason_final": text["stop_reason_final"],
                        "round_count": text.get("round_count", 0)
                    }
                    f.write(json.dumps(output_data, ensure_ascii=False) + "\n")

        if not continued_texts:
            break
        continued_answer = copy.deepcopy(continued_texts)

    # Final evaluation: accuracy over this run's finished samples.
    print("\n" + "="*80)
    print("开始计算准确率...")

    # Score the in-memory results of *this* run. The results file is opened in
    # append mode with minute-resolution names, so re-reading it could include
    # another run's records (and it does not exist if nothing finished).
    correct, total, accuracy = _compute_accuracy(finished_all_list)

    print(f"总题数: {total}")
    print(f"正确数: {correct}")
    print(f"准确率: {accuracy:.2%}")

    eval_result = {
        "total": total,
        "correct": correct,
        "accuracy": f"{accuracy:.2%}",
        "model": args.model_path,
        "dataset": args.src_file
    }

    eval_path = result_path.replace(".jsonl", "_eval.json")
    with open(eval_path, "w", encoding="utf-8") as f:
        json.dump(eval_result, f, indent=2, ensure_ascii=False)

    print(f"评估结果已保存到: {eval_path}")

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()