# export HF_ENDPOINT=https://hf-mirror.com
import os
import json
import re
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import torch
from transformers import AutoModelForCausalLM
from types import SimpleNamespace
from collections import defaultdict

from math_verify import parse, verify
from oat_math_grader import boxed_reward_fn as oat_evaluate

THOUGHT_DELIMITER_START = "<think>"
THOUGHT_DELIMITER_END = "</think>"

# ---- Multiple-choice answer extraction ----
# Extraction patterns ordered from highest to lowest priority. Each pattern
# captures a single option letter (A-E, either case) in group 1.
#
# The `(?![A-Za-z0-9])` lookahead right after the letter prevents ordinary
# words from being read as an option: with the case-insensitive flag,
# "the answer is about 5" or "choose between ..." would otherwise yield
# "A" / "B".
# `[\s*]*` between the colon and the letter tolerates markdown emphasis on
# both sides, e.g. "**Final Answer:** **A)**".
MULTICHOICE_PATTERNS = [
    # \boxed{A}
    r"\\boxed\{\s*\(?([A-Ea-e])\)?\s*\}",
    # **ANSWER: D** / ANSWER: D / answer: A / **Answer**: A
    r"(?i)\b(?:ANSWER)\*{0,2}\s*[:：][\s*]*\(?([A-Ea-e])(?![A-Za-z0-9])\)?\s*\*{0,2}",
    # **Final Answer:** A / **Final Answer:** **A)** / Final Answer: B
    r"(?i)(?:Final\s+Answer)\*{0,2}\s*[:：][\s*]*\(?([A-Ea-e])(?![A-Za-z0-9])\)?\s*\*{0,2}",
    # The answer is A / the answer is (B) / the correct answer is C
    r"(?i)(?:the\s+)?(?:correct\s+)?answer\s+is\s*[:：]?\s*\(?([A-Ea-e])(?![A-Za-z0-9])\)?",
    # option A is correct
    r"(?i)option\s+\(?([A-Ea-e])\)?\s+is\s+correct",
    # choose A / select B
    r"(?i)(?:choose|select)\s+\(?([A-Ea-e])(?![A-Za-z0-9])\)?",
    # "(A)" or "A)" at the start of a line (optionally bold), followed by
    # ".", whitespace or "*". The closing parenthesis is required.
    r"(?:^|\n)\s*\*{0,2}\(?([A-Ea-e])\)[\.\s\*]",
]


def extract_multichoice_answer(text: str):
    """
    Extract a multiple-choice answer letter from a model output.

    The patterns in MULTICHOICE_PATTERNS are tried in order; the first
    pattern that matches anywhere in ``text`` decides the result.

    Returns the matched letter upper-cased, or None when ``text`` is empty
    or no pattern matches.
    """
    if not text:
        return None
    per_pattern_hits = (re.findall(regex, text)
                        for regex in MULTICHOICE_PATTERNS)
    winning_hits = next((hits for hits in per_pattern_hits if hits), None)
    if winning_hits is None:
        return None
    # Use the last occurrence: options mentioned during reasoning come
    # earlier, while the final answer is usually stated last.
    return winning_hits[-1].upper()


def labeling_multichoice(response: str, golden_answer: str) -> bool:
    """
    Decide whether a multiple-choice response is correct.

    ``golden_answer`` is expected to be a single letter such as "A"; it is
    stripped and upper-cased before being compared with the letter
    extracted from ``response``. Returns False when no letter can be
    extracted.
    """
    predicted = extract_multichoice_answer(response)
    # Short-circuit keeps golden_answer untouched when nothing was extracted.
    return predicted is not None and predicted == golden_answer.strip().upper()


def timeout(timeout_seconds: int = 10):
    """Return a decorator that aborts the wrapped call after a time limit.

    On POSIX the limit is enforced with ``SIGALRM``: a handler raising
    ``TimeoutError("verify timed out!")`` is installed for the duration of
    the call, and the previous handler is restored afterwards. Because it
    relies on ``signal.signal``, the decorated function must be called from
    the main thread.

    On platforms without ``SIGALRM`` (``os.name != "posix"``) the decorator
    is a no-op that returns the function unchanged. Previously this branch
    returned ``None``, which made ``@timeout(...)`` raise ``TypeError`` at
    import time.

    Args:
        timeout_seconds: Whole seconds before the call is interrupted.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    import functools

    if os.name != "posix":
        def passthrough(func):
            # No alarm signal available: run without a time limit.
            return func
        return passthrough

    import signal

    def decorator(func):
        def handler(signum, frame):
            raise TimeoutError("verify timed out!")

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            old_handler = signal.getsignal(signal.SIGALRM)
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always cancel the pending alarm and restore the previous
                # handler, even when func raised or timed out.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
        return wrapper
    return decorator


@timeout(timeout_seconds=10)
def labeling_responses(responses: list[str], golden_answer: str):
    """Check each response against the golden answer with math_verify.

    Every response is passed through ``parse``; the golden answer is wrapped
    in ``$...$`` and parsed once per response; ``verify(gold, prediction)``
    then yields one label per response. The whole call is subject to the
    10-second limit imposed by the ``timeout`` decorator.

    Returns:
        A list of verification results, in the order of ``responses``.
    """
    parsed_predictions = [parse(candidate) for candidate in responses]
    gold_expression = "$" + golden_answer + "$"
    parsed_golds = [parse(gold_expression) for _ in responses]
    return [
        verify(gold, prediction)
        for gold, prediction in zip(parsed_golds, parsed_predictions)
    ]


def make_conv_zero(question):
    """Build the plain-text "User/Assistant" prompt for a math question.

    The question is followed by an instruction to give the answer in
    ``\\boxed{}`` and embedded in a fixed preamble describing the
    ``<think>``/``<answer>`` output format. No tokenizer is involved.
    """
    preamble = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: "
    answer_instruction = "\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
    return preamble + (question + answer_instruction) + ". Assistant:"


def make_conv_zero_code(question):
    """Build the plain-text "User/Assistant" prompt for a coding question.

    Same preamble as the math variant, but the question is followed by an
    instruction to finish with a fenced Python code block.
    """
    preamble = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: "
    code_instruction = "\n\nWrite Python code to solve the problem. Present the code in \n```python\nYour code\n```\nat the end."
    return preamble + (question + code_instruction) + ". Assistant:"


def make_conv_prime_sft(question, tokenizer):
    """Render a single-turn chat prompt asking for a boxed math answer.

    The question plus the ``\\boxed{}`` instruction becomes the only user
    message, which is rendered to text through the tokenizer's chat
    template with the generation prompt appended.

    For code problems the instruction would instead be:
    "\\n\\nWrite Python code to solve the problem. Present the code in
    \\n```python\\nYour code\\n```\\nat the end." (not enabled here).
    """
    user_turn = {
        "role": "user",
        "content": question
        + "\n\nPresent the answer in LaTex format: \\boxed{Your answer}",
    }
    return tokenizer.apply_chat_template(
        [user_turn], tokenize=False, add_generation_prompt=True)


def apply_qwen_math_template(question: str):
    """Wrap ``question`` in a hand-written ChatML prompt.

    The system turn asks for step-by-step reasoning with the final answer
    in ``\\boxed{}``; the prompt ends with an open assistant turn.
    """
    system_and_user_header = "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\n"
    assistant_header = "<|im_end|>\n<|im_start|>assistant\n"
    return system_and_user_header + question + assistant_header


def qwen3_template(question: str, tokenizer):
    """Render a single user turn asking for step-by-step, boxed output.

    The instruction is appended to the question (separated by two spaces)
    and the message is rendered through the tokenizer's chat template with
    the generation prompt appended.
    """
    instruction = "  Let's think step by step and output the final answer within \\boxed{}."
    conversation = [{"role": "user", "content": question + instruction}]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)


def simplerl_template(question: str):
    """Wrap ``question`` in a hand-written ChatML prompt (SimpleRL style).

    The user turn is the question followed by an instruction to reason
    step by step and put the final answer within ``\\boxed{}``; the prompt
    ends with an open assistant turn.

    The instruction used to end in ``within\\boxed{{}}``. The doubled
    braces are ``str.format`` escapes, but this string is never formatted,
    so the model literally saw ``\\boxed{{}}``. The literal now emits
    ``\\boxed{}`` with a separating space, matching the other templates in
    this file.
    """
    return (
        '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n'
        + question
        + '\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n'
    )


def l1_template(question: str, budget, tokenizer):
    """Render a user turn that also states a thinking-token budget.

    The user message is the question, the step-by-step / ``\\boxed{}``
    instruction, and " Think for {budget} tokens."; it is rendered through
    the tokenizer's chat template with the generation prompt appended.
    """
    pieces = [
        question,
        " Let's think step by step and output the final answer within \\boxed{}.",
        f" Think for {budget} tokens.",
    ]
    user_text = pieces[0] + pieces[1] + pieces[2]
    conversation = [{"role": "user", "content": user_text}]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)


def l1_ours_template(question: str, budget, tokenizer):
    """Render a system + user conversation with a thinking-token budget.

    The system message is the fixed "Thought and Solution" instruction
    followed by " Think for {budget} tokens."; the user message is the
    question unchanged. The conversation is rendered through the
    tokenizer's chat template with the generation prompt appended.
    """
    base_instruction = 'Your task is to follow a systematic, thorough reasoning process before providing the final solution. This involves analyzing, summarizing, exploring, reassessing, and refining your thought process through multiple iterations. Structure your response into two sections: Thought and Solution. In the Thought section, present your reasoning using the format: "<think>\n {thoughts} </think>\n". Each thought should include detailed analysis, brainstorming, verification, and refinement of ideas. After "</think>\n," in the Solution section, provide the final, logical, and accurate answer, clearly derived from the exploration in the Thought section. If applicable, include the answer in \\boxed{} for closed-form results like multiple choices or mathematical solutions.'
    budget_hint = f" Think for {budget} tokens."

    conversation = [
        {"role": "system", "content": base_instruction + budget_hint},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)


def luffy_template(question: str, tokenizer):
    """Render a system + user conversation with the "Thought and Solution" prompt.

    The system message is the fixed instruction describing the
    ``<think> ... </think>`` format; the user message is the question
    unchanged. The conversation is rendered through the tokenizer's chat
    template with the generation prompt appended.
    """
    instruction = 'Your task is to follow a systematic, thorough reasoning process before providing the final solution. This involves analyzing, summarizing, exploring, reassessing, and refining your thought process through multiple iterations. Structure your response into two sections: Thought and Solution. In the Thought section, present your reasoning using the format: "<think>\n {thoughts} </think>\n". Each thought should include detailed analysis, brainstorming, verification, and refinement of ideas. After "</think>\n," in the Solution section, provide the final, logical, and accurate answer, clearly derived from the exploration in the Thought section. If applicable, include the answer in \\boxed{} for closed-form results like multiple choices or mathematical solutions.'

    conversation = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)


def _collect_stop_token_ids(tokenizer):
    """Collect token ids that should terminate generation.

    Sources, merged into one sorted list without duplicates:
      * ``tokenizer.eos_token_id`` (a single int, or a list/tuple of ids
        where ``None`` entries are skipped);
      * the end-of-turn markers ``<|im_end|>``, ``<|eot_id|>`` and ``</s>``,
        looked up both via ``convert_tokens_to_ids`` and by encoding the
        marker and keeping it only when it maps to exactly one token.

    Ids equal to the tokenizer's ``unk_token_id`` (or negative ids) are
    rejected on both lookup paths. Previously the encoding path had no such
    guard, so a marker unknown to the vocabulary that encoded to the single
    unk token would have been registered as a stop token.

    Returns:
        Sorted list of stop token ids (possibly empty).
    """
    stop_token_ids = set()
    eos_token_id = tokenizer.eos_token_id
    if isinstance(eos_token_id, int):
        stop_token_ids.add(eos_token_id)
    elif isinstance(eos_token_id, (list, tuple)):
        stop_token_ids.update(int(tok_id)
                              for tok_id in eos_token_id if tok_id is not None)

    unk_token_id = getattr(tokenizer, "unk_token_id", None)

    def is_known_token(token_id):
        # A usable stop id is a non-negative int that is not the unk id.
        return (isinstance(token_id, int)
                and token_id >= 0
                and token_id != unk_token_id)

    for token in ("<|im_end|>", "<|eot_id|>", "</s>"):
        token_id = tokenizer.convert_tokens_to_ids(token)
        if is_known_token(token_id):
            stop_token_ids.add(token_id)
        try:
            token_ids = tokenizer(
                token, add_special_tokens=False)["input_ids"]
        except TypeError:
            # Tokenizer callables that do not accept add_special_tokens.
            token_ids = tokenizer(token)["input_ids"]
        if len(token_ids) == 1 and is_known_token(int(token_ids[0])):
            stop_token_ids.add(int(token_ids[0]))

    return sorted(stop_token_ids)


def _wrap_single_output(prompt, text, token_ids):
    """Package one completion in an object exposing ``.prompt`` and ``.outputs``.

    The result has a ``prompt`` attribute and an ``outputs`` list holding a
    single entry with ``text`` and ``token_ids`` attributes.
    """
    completion = SimpleNamespace(text=text, token_ids=token_ids)
    return SimpleNamespace(prompt=prompt, outputs=[completion])


def _infer_llm_max_model_len(llm, fallback_len):
    """Look up the maximum model length on an LLM object.

    Several attribute chains are probed in order; the first one that fully
    resolves to a positive ``int`` is returned. If none does,
    ``fallback_len`` is returned instead.
    """
    attribute_chains = (
        ("llm_engine", "model_config", "max_model_len"),
        ("engine", "model_config", "max_model_len"),
        ("llm_engine", "scheduler_config", "max_model_len"),
    )
    for chain in attribute_chains:
        node = llm
        for name in chain:
            if not hasattr(node, name):
                break
            node = getattr(node, name)
        else:
            # Whole chain resolved; accept it only if it is a positive int.
            if isinstance(node, int) and node > 0:
                return node
    return fallback_len


def _budget_force_single_generation(
    llm,
    tokenizer,
    prompt,
    stop_token_ids,
    model_max_len,
    temperature,
    top_p,
    top_k,
    thinking_budget,
    num_ignore,
    ignore_str,
    final_answer_prefix,
    final_max_tokens,
):
    """Generate one response with budget forcing for a single prompt.

    Phase 1 ("thinking") runs up to ``num_ignore + 1`` generation rounds.
    Each round may produce at most the remaining thinking budget; after
    every round except the last, and only while budget remains,
    ``ignore_str`` is appended to the prompt so the next round continues
    from it. Phase 2 appends ``final_answer_prefix`` (if non-empty) and
    generates the final answer within the remaining context.

    NOTE(review): this single-prompt variant does not appear to be called
    from the visible file (generate_vllm uses the batch variant), and
    unlike the batch variant it does not clip ``thinking_budget`` to the
    model's context length.

    Args:
        llm: Object exposing ``generate(prompt, sampling_params=...)`` whose
            result is indexed as ``result[0].outputs[0]`` with ``text`` and
            ``token_ids`` attributes.
        tokenizer: Callable returning a mapping with ``"input_ids"``.
        prompt: Fully rendered prompt text.
        stop_token_ids: Token ids passed to ``SamplingParams``.
        model_max_len: Context length used to size the final generation.
        temperature, top_p, top_k: Sampling settings used in both phases.
        thinking_budget: Maximum number of generated thinking tokens.
        num_ignore: Number of times ``ignore_str`` may be injected.
        ignore_str: Text appended between thinking rounds.
        final_answer_prefix: Text appended before the final generation.
        final_max_tokens: Optional cap on the final generation length.

    Returns:
        ``(text, token_ids)``: the concatenated generated text (including
        injected ``ignore_str`` / ``final_answer_prefix``) and the matching
        token ids.
    """
    current_prompt = prompt
    output_text_parts = []
    output_token_ids = []
    remaining_budget = thinking_budget
    # Token ids of the injected strings, recorded alongside generated ids so
    # the returned id list stays aligned with the returned text.
    ignore_token_ids = tokenizer(
        ignore_str, add_special_tokens=False)["input_ids"] if ignore_str else []
    final_answer_prefix_token_ids = tokenizer(
        final_answer_prefix, add_special_tokens=False)["input_ids"] if final_answer_prefix else []

    for ignore_idx in range(num_ignore + 1):
        if remaining_budget <= 0:
            break

        thinking_sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=remaining_budget,
            # After an injected ignore_str, require at least one new token.
            min_tokens=0 if ignore_idx == 0 else 1,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=False,
        )

        response = llm.generate(
            current_prompt,
            sampling_params=thinking_sampling_params,
        )
        response = response[0].outputs[0]

        current_prompt += response.text
        output_text_parts.append(response.text)
        output_token_ids.extend(response.token_ids)
        # Only generated tokens are charged to the budget; the injected
        # ignore_str tokens below are not deducted.
        remaining_budget -= len(response.token_ids)

        if ignore_idx < num_ignore and remaining_budget > 0:
            current_prompt += ignore_str
            output_text_parts.append(ignore_str)
            output_token_ids.extend(ignore_token_ids)

    if final_answer_prefix:
        current_prompt += final_answer_prefix
        output_text_parts.append(final_answer_prefix)
        output_token_ids.extend(final_answer_prefix_token_ids)

    # Size the final generation from the re-tokenized accumulated prompt.
    prompt_token_ids = tokenizer(
        current_prompt, add_special_tokens=False)["input_ids"]
    remaining_context = model_max_len - len(prompt_token_ids)
    if final_max_tokens is not None:
        remaining_context = min(remaining_context, final_max_tokens)
    if remaining_context <= 0:
        # No room for a final answer: return what was produced so far.
        print(
            f"Warning: no context left for final answer. "
            f"prompt_tokens={len(prompt_token_ids)}, max_model_len={model_max_len}"
        )
        return "".join(output_text_parts), output_token_ids

    final_sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=remaining_context,
        min_tokens=0,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=False,
    )

    final_response = llm.generate(
        current_prompt,
        sampling_params=final_sampling_params,
    )
    final_response = final_response[0].outputs[0]
    output_text_parts.append(final_response.text)
    output_token_ids.extend(final_response.token_ids)

    return "".join(output_text_parts), output_token_ids


def _build_sampling_params(
    temperature,
    top_p,
    top_k,
    max_tokens,
    stop_token_ids,
    min_tokens=0,
):
    """Create the ``SamplingParams`` used by the budget-forcing generators.

    All arguments are forwarded unchanged; ``skip_special_tokens`` is always
    set to ``False``.
    """
    options = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_tokens": max_tokens,
        "min_tokens": min_tokens,
        "stop_token_ids": stop_token_ids,
        "skip_special_tokens": False,
    }
    return SamplingParams(**options)


def _budget_force_batch_generation(
    llm,
    tokenizer,
    prompts,
    stop_token_ids,
    model_max_len,
    temperature,
    top_p,
    top_k,
    requested_thinking_budget,
    num_ignore,
    ignore_str,
    final_answer_prefix,
    final_max_tokens,
):
    """Run budget forcing over a list of prompts, batching generate calls.

    Per prompt, the requested thinking budget is first clipped to
    ``model_max_len - prompt_tokens - 1`` (never below 0). Phase 1 runs up
    to ``num_ignore + 1`` thinking rounds; in each round, prompts that
    still have budget are grouped by their remaining budget so every group
    shares one ``SamplingParams``. Between rounds ``ignore_str`` is
    appended to every prompt that still has budget. Phase 2 appends
    ``final_answer_prefix`` (if non-empty) and generates the final answer,
    again grouped by the per-prompt token allowance.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``
            returning one result per prompt, each with ``outputs[0].text``
            and ``outputs[0].token_ids``.
        tokenizer: Callable returning a mapping with ``"input_ids"``.
        prompts: Fully rendered prompt texts.
        stop_token_ids: Token ids passed to the sampling parameters.
        model_max_len: Context length used for clipping and final sizing.
        temperature, top_p, top_k: Sampling settings used in both phases.
        requested_thinking_budget: Desired thinking-token budget per prompt.
        num_ignore: Number of times ``ignore_str`` may be injected.
        ignore_str: Text appended between thinking rounds.
        final_answer_prefix: Text appended before the final generation.
        final_max_tokens: Optional cap on the final generation length.

    Returns:
        A list, in the order of ``prompts``, of ``(text, token_ids)``
        tuples covering generated text plus injected strings.
    """
    # Token ids of the injected strings, recorded alongside generated ids so
    # each returned id list stays aligned with its returned text.
    ignore_token_ids = tokenizer(
        ignore_str, add_special_tokens=False)["input_ids"] if ignore_str else []
    final_answer_prefix_token_ids = tokenizer(
        final_answer_prefix, add_special_tokens=False)["input_ids"] if final_answer_prefix else []
    states = []
    for prompt in prompts:
        prompt_token_ids = tokenizer(
            prompt, add_special_tokens=False)["input_ids"]
        # Keep one token of headroom below the context limit.
        available_thinking_budget = model_max_len - len(prompt_token_ids) - 1
        thinking_budget = max(0, min(
            requested_thinking_budget, available_thinking_budget))
        if thinking_budget < requested_thinking_budget:
            print(
                f"Warning: requested thinking budget {requested_thinking_budget} "
                f"was clipped to {thinking_budget} because "
                f"prompt_tokens={len(prompt_token_ids)}, max_model_len={model_max_len}."
            )
        # Mutable per-prompt state carried through both phases.
        states.append({
            "prompt": prompt,
            "current_prompt": prompt,
            "output_text_parts": [],
            "output_token_ids": [],
            "remaining_budget": thinking_budget,
        })

    for ignore_idx in range(num_ignore + 1):
        # Group prompt indices by remaining budget: max_tokens is a
        # per-call setting, so each distinct budget needs its own call.
        budget_groups = defaultdict(list)
        for idx, state in enumerate(states):
            if state["remaining_budget"] > 0:
                budget_groups[state["remaining_budget"]].append(idx)

        if not budget_groups:
            break

        for current_budget, indices in sorted(budget_groups.items()):
            sampling_params = _build_sampling_params(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_tokens=current_budget,
                stop_token_ids=stop_token_ids,
                # After an injected ignore_str, require at least one new token.
                min_tokens=0 if ignore_idx == 0 else 1,
            )
            batch_prompts = [states[idx]["current_prompt"] for idx in indices]
            responses = llm.generate(batch_prompts, sampling_params)

            for state_idx, response in zip(indices, responses):
                output = response.outputs[0]
                state = states[state_idx]
                state["current_prompt"] += output.text
                state["output_text_parts"].append(output.text)
                state["output_token_ids"].extend(output.token_ids)
                # Only generated tokens are charged to the budget; the
                # injected ignore_str tokens below are not deducted.
                state["remaining_budget"] -= len(output.token_ids)

        if ignore_idx < num_ignore:
            for state in states:
                if state["remaining_budget"] > 0:
                    state["current_prompt"] += ignore_str
                    state["output_text_parts"].append(ignore_str)
                    state["output_token_ids"].extend(ignore_token_ids)

    # Phase 2: final answer. Group by the number of tokens each prompt may
    # still generate.
    final_groups = defaultdict(list)
    for idx, state in enumerate(states):
        if final_answer_prefix:
            state["current_prompt"] += final_answer_prefix
            state["output_text_parts"].append(final_answer_prefix)
            state["output_token_ids"].extend(final_answer_prefix_token_ids)

        prompt_token_ids = tokenizer(
            state["current_prompt"], add_special_tokens=False)["input_ids"]
        remaining_context = model_max_len - len(prompt_token_ids)
        if final_max_tokens is not None:
            remaining_context = min(remaining_context, final_max_tokens)
        if remaining_context <= 0:
            # No room for a final answer: this prompt keeps only its
            # thinking output.
            print(
                f"Warning: no context left for final answer. "
                f"prompt_tokens={len(prompt_token_ids)}, max_model_len={model_max_len}"
            )
            continue
        final_groups[remaining_context].append(idx)

    for current_budget, indices in sorted(final_groups.items()):
        sampling_params = _build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=current_budget,
            stop_token_ids=stop_token_ids,
            min_tokens=0,
        )
        batch_prompts = [states[idx]["current_prompt"] for idx in indices]
        responses = llm.generate(batch_prompts, sampling_params)

        for state_idx, response in zip(indices, responses):
            output = response.outputs[0]
            state = states[state_idx]
            state["output_text_parts"].append(output.text)
            state["output_token_ids"].extend(output.token_ids)

    return [
        ("".join(state["output_text_parts"]), state["output_token_ids"])
        for state in states
    ]


def main(input_file, output_file, model_path, debug=False, remove_system=True, template='own', temperature=0.6, top_p=1.0, max_tokens=8192, n=1, force_generate=True, add_think_before_answer=False, add_oat_evaluate=False, any_true=False, skip_scoring=False, output_eval=None, no_split_think=False, budget=1024, enigne="vllm", enable_thinking=None, generation_mode='original', length_budget=None, num_ignore=1, ignore_str="Wait", final_answer_prefix="", final_max_tokens=None):
    """Generate responses for a parquet benchmark, then score and save them.

    Steps:
      1. Read ``input_file`` (parquet with ``prompt``, ``reward_model`` and
         ``data_source`` columns).
      2. Generate ``n`` responses per prompt and write them to
         ``<output_file without .jsonl>.decoded.jsonl``. When
         ``force_generate`` is false and that file exists, it is loaded
         instead and lengths are recomputed with the model's tokenizer.
      3. Unless ``skip_scoring`` is set, label every response with
         math_verify, falling back to multiple-choice letter matching, and
         optionally combine with / replace by the OAT grader.
      4. Print overall and per-``data_source`` accuracy and length
         statistics, and write one scored JSON line per response to
         ``output_file``.

    Args:
        input_file: Path to the parquet dataset.
        output_file: Path of the scored ``.jsonl`` output.
        model_path: Model / tokenizer path or hub id.
        debug: Unused.
        remove_system: Drop the leading system message when the first
            prompt starts with one.
        template: Prompt template name, forwarded to the generator.
        temperature, top_p, max_tokens, n: Sampling settings.
        force_generate: Regenerate even if a decoded file already exists.
        add_think_before_answer: Always treat responses as starting with
            ``<think>`` (the tag is prepended before splitting).
        add_oat_evaluate: Also grade with the OAT grader.
        any_true: With ``add_oat_evaluate``, accept if either grader
            accepts; otherwise the OAT label replaces the first one.
        skip_scoring: Stop after generation.
        output_eval: Unused.
        no_split_think: Never split off the thinking part before scoring.
        budget: Thinking budget forwarded to templates that use it.
        enigne: Generation backend, ``'vllm'`` or ``'hf'``. The misspelled
            name is kept for backward compatibility with existing callers.
        enable_thinking, generation_mode, length_budget, num_ignore,
        ignore_str, final_answer_prefix, final_max_tokens: Forwarded to
            ``generate_vllm``.

    Raises:
        ValueError: For an unknown ``enigne``, a non-``'original'``
            ``generation_mode`` with the ``'hf'`` backend, or when there is
            nothing to generate or score.
    """
    from tqdm import tqdm

    df = pd.read_parquet(input_file)
    dec_output_path = output_file.replace('.jsonl', '') + '.decoded.jsonl'
    if force_generate or (not os.path.exists(dec_output_path)):
        # Data preparation.
        messages = df['prompt'].tolist()
        if len(messages) == 0:
            raise ValueError(f"No prompts found in {input_file}.")

        if messages[0][0]['role'] == 'system':
            if remove_system:
                print('remove system')
                messages = [message[1:] for message in messages]
            else:
                print('not remove system')
        else:
            print('not remove system!!!')

        answers = [item['ground_truth'] for item in df['reward_model'].tolist()]
        assert len(messages) == len(answers)

        print(messages[0])
        print(
            f"temperature: {temperature}, top_p: {top_p}, max_tokens: {max_tokens}, n: {n}")
        if enigne == 'hf':
            if generation_mode != 'original':
                raise ValueError(
                    "generation_mode != 'original' is only supported for vllm now.")
            outputs = generate_hf(messages, model_path, template=template, temperature=temperature,
                                  top_p=top_p, max_tokens=max_tokens, n=n, budget=budget)
        elif enigne == 'vllm':
            outputs = generate_vllm(messages, model_path, template=template, temperature=temperature,
                                    top_p=top_p, max_tokens=max_tokens, n=n, budget=budget, enable_thinking=enable_thinking, generation_mode=generation_mode, length_budget=length_budget, num_ignore=num_ignore, ignore_str=ignore_str, final_answer_prefix=final_answer_prefix, final_max_tokens=final_max_tokens)
        else:
            # Previously an unknown backend fell through and crashed later
            # with a NameError on `outputs`.
            raise ValueError(
                f"Invalid enigne: {enigne}. Expected 'hf' or 'vllm'.")

        # Check the output shape before anything is written to disk.
        assert len(outputs[0].outputs) == n

        source_column = df['data_source'].tolist()
        data_sources = [source_column[i]
                        for i in range(len(df)) for _ in range(n)]

        # Persist the raw generations first so scoring can be redone later.
        with open(dec_output_path, 'w', encoding='utf-8') as fo:
            for i, output in enumerate(outputs):
                for j in range(n):
                    item = {
                        'prompt': output.prompt,
                        'generated_text': output.outputs[j].text,
                        'answer': answers[i],
                        'data_source': source_column[i]
                    }
                    fo.write(json.dumps(item) + '\n')

        # Flatten to one entry per (prompt, sample) pair.
        prompts = [out.prompt for out in outputs for _ in range(n)]
        answers = [answers[i] for i in range(len(outputs)) for _ in range(n)]
        lengths = [len(out.outputs[j].token_ids)
                   for out in outputs for j in range(n)]
        outputs = [out.outputs[j].text for out in outputs for j in range(n)]

    else:
        print('Found already decoded file, skip decoding...')
        with open(dec_output_path, 'r', encoding='utf-8') as f:
            jss = [json.loads(line) for line in f]
        if not jss:
            raise ValueError(f"Decoded file {dec_output_path} is empty.")

        outputs = [item['generated_text'] for item in jss]
        prompts = [item['prompt'] for item in jss]
        answers = [item['answer'] for item in jss]
        # data_source: prefer the decoded jsonl; fall back to the parquet
        # for files written in the old format.
        if 'data_source' in jss[0]:
            data_sources = [item['data_source'] for item in jss]
        else:
            print('警告: decoded jsonl 中无 data_source 字段，从 parquet 回退读取')
            source_column = df['data_source'].tolist()
            data_sources = [source_column[i // n] for i in range(len(jss))]
        print('Recomputing output lengths from decoded outputs...')
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        lengths = [len(tokenizer(text, add_special_tokens=False)["input_ids"])
                   for text in outputs]

    if skip_scoring:
        return

    print('Scoring...')
    rets = defaultdict(list)
    save_data = []
    avg = 0
    diff_cnt = 0
    for i in tqdm(range(len(outputs)), total=len(outputs)):
        generated_text = outputs[i]
        prompt = prompts[i]
        answer = answers[i]

        think_format = (prompt.endswith(THOUGHT_DELIMITER_START + '\n')
                        or add_think_before_answer is True)
        if think_format:
            generated_text = THOUGHT_DELIMITER_START + '\n' + generated_text
        if no_split_think:
            think_format = False

        labels = None
        if think_format:
            # Score only the text after the thinking section. A response
            # that never closes its thinking block counts as wrong.
            # NOTE(review): with several "</think>" tags this keeps the text
            # between the first and second tag, not the tail - confirm that
            # this is intended (e.g. for budget forcing).
            segments = generated_text.split(THOUGHT_DELIMITER_END)
            if len(segments) > 1:
                generated_text = segments[1]
            else:
                labels = [False]

        if labels is None:
            try:
                labels = labeling_responses([generated_text], answer)
            except Exception:
                # Best effort: parser errors and verification timeouts both
                # count as an incorrect answer.
                labels = [False]

        # If math_verify says False, retry as a multiple-choice question
        # (for arc_c / gpqa / mmlu style data). Only string ground truths
        # can be letter-matched.
        if not labels[0] and isinstance(answer, str):
            if labeling_multichoice(generated_text, answer):
                labels = [True]

        if add_oat_evaluate:
            new_label = oat_evaluate(generated_text, answer, fast=False)
            new_label = bool(new_label[1] == 1.0)
            if any_true is True:
                if not labels[0] and new_label:
                    diff_cnt += 1
                labels = [labels[0] or new_label]
            else:
                labels = [new_label]

        # Plain bool so the record is always JSON-serialisable.
        correct = bool(labels[0])
        rets[data_sources[i]].append(correct)

        save_data.append({
            'prompt': prompt,
            'generated_text': generated_text,
            'answer': answer,
            'correctness': correct,
            'data_source': data_sources[i]
        })
        if correct:
            avg += 1

    print('accuracy: ', avg / len(outputs))
    print('diff_cnt: ', diff_cnt)

    accs = []
    for data_source, source_labels in rets.items():
        acc = np.array(source_labels).mean()
        print(f'{data_source}: {acc}')
        accs.append(acc)

    print('avg acc: ', np.array(accs).mean())
    if lengths:
        print("average length: ", np.array(lengths).mean())
        print('max length:', max(np.array(lengths)))

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in save_data:
                f.write(json.dumps(item) + '\n')
    except (OSError, TypeError, ValueError) as e:
        # Metrics were already printed; report the failure without
        # discarding them.
        print(f'Error: {e}')
        print(f'Output file: {output_file}')


def generate_hf(messages, model_path, template='own', temperature=0.6, top_p=0.95, max_tokens=8192, n=1, budget=1024, batch_size=32):
    """Generate responses with Hugging Face ``transformers``.

    Prompts are rendered with the tokenizer's chat template (passing
    ``thinking_budget=budget``), left-padded per batch, and completed with
    ``model.generate``.

    Fixes over the previous implementation:
      * the result list is no longer overwritten by the generated tensor
        (``.append`` on that tensor raised);
      * the attention mask is passed, which left padding requires;
      * ``do_sample`` is enabled when ``temperature > 0`` so that
        ``temperature`` / ``top_p`` take effect;
      * only newly generated tokens are decoded, and the return value has
        one object per prompt exposing ``.prompt`` and ``.outputs`` (a list
        of ``n`` entries with ``.text`` and ``.token_ids``).

    Args:
        messages: One chat message list per prompt.
        model_path: Model / tokenizer path or hub id.
        template: Unused; the tokenizer's own chat template is always used.
        temperature: Sampling temperature; ``0`` or ``None`` means greedy.
        top_p: Nucleus sampling threshold (only used when sampling).
        max_tokens: Maximum number of new tokens per response.
        n: Number of responses per prompt.
        budget: Value passed to the chat template as ``thinking_budget``.
        batch_size: Number of prompts per ``generate`` call.

    Returns:
        A list with one ``SimpleNamespace`` per prompt.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        # Batched padding needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    # You may want to use bfloat16 and/or move to GPU here
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    gen_prompts = []
    for i, cur_message in enumerate(messages):
        prompt = tokenizer.apply_chat_template(
            cur_message,
            tokenize=False,
            add_generation_prompt=True,
            thinking_budget=budget,
        )
        gen_prompts.append(prompt)
        if i == 0:
            print('Example input: ', prompt)

    do_sample = temperature is not None and temperature > 0
    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "num_return_sequences": n,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        generate_kwargs["temperature"] = temperature
        generate_kwargs["top_p"] = top_p

    pad_token_id = tokenizer.pad_token_id
    outputs = []
    for start in range(0, len(gen_prompts), batch_size):
        batch_prompts = gen_prompts[start: start + batch_size]
        encoded = tokenizer(batch_prompts, return_tensors="pt", padding=True)
        input_ids = encoded.input_ids.to(model.device)
        attention_mask = encoded.attention_mask.to(model.device)
        with torch.no_grad():
            sequences = model.generate(
                input_ids, attention_mask=attention_mask, **generate_kwargs)

        # With left padding every row shares the same prompt length, so the
        # completion is everything after that many tokens.
        completions = sequences[:, input_ids.shape[1]:]
        for row, prompt in enumerate(batch_prompts):
            candidates = []
            for j in range(n):
                # The n samples of prompt `row` are stored contiguously.
                token_ids = completions[row * n + j].tolist()
                # Drop the padding added after shorter samples finished.
                while token_ids and token_ids[-1] == pad_token_id:
                    token_ids.pop()
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                candidates.append(
                    SimpleNamespace(text=text, token_ids=token_ids))
            outputs.append(SimpleNamespace(prompt=prompt, outputs=candidates))

    return outputs


def generate_vllm(messages, model_path, template='own', temperature=0.6, top_p=0.95, max_tokens=8192, n=1, budget=1024, top_k=-1, type='original', enable_thinking=None, generation_mode=None, length_budget=None, num_ignore=1, ignore_str="Wait", final_answer_prefix="", final_max_tokens=None):
    """Generate responses with vLLM.

    Each entry of ``messages`` is rendered to a prompt according to
    ``template`` and completed either in a single pass (``'original'``) or
    with budget forcing (``'budget_forcing'``). ``generation_mode`` falls
    back to ``type`` when it is ``None``.

    Returns:
        A list with one result per prompt exposing ``.prompt`` and
        ``.outputs`` (entries with ``.text`` and ``.token_ids``).

    Raises:
        ValueError: For an invalid generation mode or template, for
            ``n != 1`` in budget-forcing mode, or when no stop token ids
            can be inferred in that mode.
    """
    # Load the tokenizer before the vLLM engine.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # max_tokens is the maximum length for generation.
    mode = type if generation_mode is None else generation_mode
    if mode not in ('original', 'budget_forcing'):
        raise ValueError(
            f"Invalid generation_mode: {mode}. Expected 'original' or 'budget_forcing'.")
    use_budget_forcing = mode != 'original'

    if use_budget_forcing:
        if n != 1:
            raise ValueError("budget_forcing mode only supports n=1 for now.")
        stop_token_ids = _collect_stop_token_ids(tokenizer)
        if not stop_token_ids:
            raise ValueError(
                "Failed to infer stop_token_ids for budget_forcing mode.")
    else:
        sampling_params = SamplingParams(
            temperature=temperature, top_p=top_p, max_tokens=max_tokens, n=n, top_k=top_k)

    gpu_count = torch.cuda.device_count()
    print(gpu_count)
    # No max_model_len is passed: vLLM infers the context length from the
    # model's own configuration.
    llm = LLM(model=model_path, tensor_parallel_size=gpu_count,
              gpu_memory_utilization=0.85)
    llm_max_model_len = _infer_llm_max_model_len(
        llm, max(max_tokens, length_budget or 0, budget))

    def last_content(message):
        # Most templates only use the text of the final message.
        return message[-1]['content']

    def render_own(message):
        # Forward enable_thinking only when the caller set it.
        extra = {} if enable_thinking is None else {
            "enable_thinking": enable_thinking}
        return tokenizer.apply_chat_template(
            message, tokenize=False, add_generation_prompt=True, **extra)

    def render_seed(message):
        return tokenizer.apply_chat_template(
            message, tokenize=False, add_generation_prompt=True,
            thinking_budget=budget)

    renderers = {
        'own': render_own,
        'seed': render_seed,
        'luffy': lambda m: luffy_template(last_content(m), tokenizer),
        'simplerl': lambda m: simplerl_template(last_content(m)),
        'qwen': lambda m: apply_qwen_math_template(last_content(m)),
        'qwen3': lambda m: qwen3_template(last_content(m), tokenizer),
        'prime': lambda m: make_conv_zero(last_content(m)),
        'prime_sft': lambda m: make_conv_prime_sft(last_content(m), tokenizer),
        'prime_code': lambda m: make_conv_zero_code(last_content(m)),
        'l1': lambda m: l1_template(
            last_content(m), budget=budget, tokenizer=tokenizer),
        'l1-ours': lambda m: l1_ours_template(
            last_content(m), budget=budget, tokenizer=tokenizer),
        'no': last_content,
    }
    render = renderers.get(template)

    gen_prompts = []
    for position, cur_message in enumerate(messages):
        if render is None:
            raise ValueError(f'Invalid template: {template}')
        gen_prompt = render(cur_message)
        gen_prompts.append(gen_prompt)
        if position == 0:
            print('Example input: ', gen_prompt)

    if not use_budget_forcing:
        print("original!!", sampling_params)
        return llm.generate(gen_prompts, sampling_params)

    # Generate a thinking segment under a fixed budget first, then append
    # the final answer, so results can be compared at equal budgets.
    requested_thinking_budget = max_tokens if length_budget is None else length_budget
    print(
        f"budget_forcing!! requested_thinking_budget={requested_thinking_budget}, "
        f"num_ignore={num_ignore}, ignore_str={ignore_str!r}, "
        f"temperature={temperature}, max_model_len={llm_max_model_len}, "
        f"final_max_tokens={final_max_tokens}"
    )
    batch_outputs = _budget_force_batch_generation(
        llm=llm,
        tokenizer=tokenizer,
        prompts=gen_prompts,
        stop_token_ids=stop_token_ids,
        model_max_len=llm_max_model_len,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        requested_thinking_budget=requested_thinking_budget,
        num_ignore=num_ignore,
        ignore_str=ignore_str,
        final_answer_prefix=final_answer_prefix,
        final_max_tokens=final_max_tokens,
    )
    outputs = []
    for prompt, (output_text, output_token_ids) in zip(gen_prompts, batch_outputs):
        print("With budget forcing:")
        print(output_text)
        outputs.append(_wrap_single_output(
            prompt, output_text, output_token_ids))

    return outputs


if __name__ == "__main__":
    # Script entry point: hand main() to python-fire, which builds the
    # command-line interface from it. fire is imported here so that
    # importing this module does not require it.
    import fire
    fire.Fire(main)
