from argparse import ArgumentParser

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import numpy as np


def score(predictions, references):
    """Grade model predictions against reference answers using ``math_verify``.

    Each reference is wrapped in ``$...$`` and parsed as the gold answer; each
    prediction is parsed with a LaTeX extraction config that tries ``\\boxed``
    content first and does not extract without an anchor.

    Args:
        predictions: Iterable of model response strings.
        references: Iterable of reference answer strings, aligned with
            ``predictions``.

    Returns:
        A list with exactly one dict per (prediction, reference) pair, in
        input order, with keys ``'pred'`` (str of the parsed prediction),
        ``'answer'`` (str of the parsed reference) and ``'correct'`` (bool).
        A pair whose reference cannot be parsed is reported as incorrect
        rather than dropped, so callers can group results positionally
        (e.g. reshape to ``(num_questions, num_samples)``).

    Raises:
        ImportError: If ``math_verify`` or ``latex2sympy2_extended`` is not
            installed.
    """
    try:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import (ExprExtractionConfig,
                                 LatexExtractionConfig, parse, verify)
    except ImportError as err:
        raise ImportError('Failed to import required modules. Please '
                          'install the necessary packages: '
                          'pip install math_verify latex2sympy2_extended'
                          ) from err

    # The extraction configs do not depend on the sample, so build them once.
    gold_extraction = [LatexExtractionConfig(), ExprExtractionConfig()]
    # We require the answer to be provided in correct latex
    # (no malformed operators).
    prediction_extraction = [
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed='all',
                units=True,
            ),
            # Ensures that boxed is tried first
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ]

    details = []
    for prediction, reference in zip(predictions, references):
        gold_parsed = parse(
            f'${reference}$',
            extraction_mode='first_match',
            extraction_config=gold_extraction,
        )
        answer_parsed = parse(
            prediction,
            extraction_config=prediction_extraction,
            extraction_mode='first_match',
        )

        # An unparseable reference cannot be verified: count it as incorrect
        # instead of silently dropping the sample from the results.
        # math_verify's verify() is order-sensitive: gold first, then answer.
        is_correct = bool(gold_parsed) and bool(
            verify(gold_parsed, answer_parsed))

        details.append({
            'pred': str(answer_parsed),
            'answer': str(gold_parsed),
            'correct': is_correct,
        })
    return details


if __name__ == "__main__":
    # Evaluate Qwen3-32B on a math benchmark under increasing generation
    # budgets. For every budget, NUM_SAMPLES completions are drawn per
    # question; completions cut off by the budget are forced to emit a boxed
    # final answer, and Avg@N / Pass@N are printed.
    parser = ArgumentParser(
        description="Budget-forced math evaluation with vLLM.")
    parser.add_argument(
        "--data",
        type=str,
        default="livemathbench",
        help="Benchmark to run: 'math500', 'aime24'; any other value "
             "selects LiveMathBench (v202412_hard_en).",
    )
    args = parser.parse_args()

    MODEL_PATH = "Qwen/Qwen3-32B"
    # Completions sampled per question; also the row width used for metrics.
    NUM_SAMPLES = 16

    model = LLM(MODEL_PATH, tensor_parallel_size=2, gpu_memory_utilization=0.8)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    if args.data == 'math500':
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        question_key = "problem"
    elif args.data == 'aime24':
        dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
        question_key = "problem"
    else:
        dataset = load_dataset("opencompass/LiveMathBench", "v202412_hard_en", split="test")
        question_key = "question"

    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    # Text appended to a truncated completion to force a final boxed answer.
    force_answer_suffix = "\n\n**Final Answer**\n\\boxed{"

    inputs = []
    answers = []
    questions = []
    for example in dataset:
        # Keep only level-5 problems; examples without a "level" field are
        # always kept.
        if example.get("level", 5) != 5:
            continue
        question = example[question_key]
        questions.append(question)
        messages = [
            {"role": "user", "content": prompt_template.format(question=question)}
        ]
        inputs.append(tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        ))
        answers.append(example["answer"])

    for max_token in [1024, 2048, 4096, 8192, 16384, 32768]:
        sampling_params = SamplingParams(
            max_tokens=max_token,
            temperature=0.6,
            top_p=0.95,
            min_p=0.0,
            top_k=40,
            n=NUM_SAMPLES,
            skip_special_tokens=False,
        )

        outputs = model.generate(inputs, sampling_params, use_tqdm=True)

        # Flat layout: slot i * NUM_SAMPLES + j holds sample j of question i.
        responses = [None] * (len(inputs) * NUM_SAMPLES)
        references = []
        # (flat slot, full prompt to continue from, response text so far)
        unanswered_responses = []
        for i, request_output in enumerate(outputs):
            for j, completion_output in enumerate(request_output.outputs):
                slot = i * NUM_SAMPLES + j
                if completion_output.finish_reason == "length":
                    # Ran out of budget: force the model to write an answer.
                    partial = completion_output.text + force_answer_suffix
                    unanswered_responses.append((slot, inputs[i] + partial, partial))
                else:
                    responses[slot] = completion_output.text
                references.append(answers[i])

        # Only run the forcing pass when something was actually truncated.
        if unanswered_responses:
            # NOTE(review): stopping at the first "}" cuts off answers that
            # contain nested braces (e.g. \frac{1}{2}); such forced answers
            # end up malformed. Left unchanged to keep results comparable.
            answer_sampling_params = SamplingParams(
                max_tokens=max_token,
                temperature=0.0,
                top_p=0.95,
                min_p=0.0,
                top_k=40,
                n=1,
                seed=42,
                skip_special_tokens=False,
                stop=["}"]
            )
            unanswered_outputs = model.generate(
                [prompt for _, prompt, _ in unanswered_responses],
                answer_sampling_params,
                use_tqdm=True,
            )
            for (slot, _, partial), request_output in zip(unanswered_responses, unanswered_outputs):
                # The stop string is not part of the output, so close the box.
                forced_response = partial + request_output.outputs[0].text + "}"
                responses[slot] = forced_response
                print(forced_response)
                print("==" * 80)
                print("==" * 80)

        correct_lst = np.array(
            [item["correct"] for item in score(responses, references)])
        # One row per question, one column per sampled completion.
        correct_lst = np.reshape(correct_lst, (-1, NUM_SAMPLES))
        print(f"{max_token} tokens: Avg Avg@{NUM_SAMPLES}: {correct_lst.mean(0).mean(0)}")
        print(f"{max_token} tokens: Avg Pass@{NUM_SAMPLES}: {(correct_lst.sum(1) > 0).mean(0)}")