import json

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset


def score(predictions, references):
    """Grade predictions against reference answers using ``math_verify``.

    Each reference is wrapped in ``$...$`` and parsed as the gold answer.
    The paired prediction is then parsed with a LaTeX extraction config
    that gives boxed answers top priority, and the two parsed results are
    compared with ``math_verify.verify``.

    Args:
        predictions: Iterable of generated answer texts.
        references: Iterable of reference answers, paired positionally
            with ``predictions``.

    Returns:
        A list with one dict per (prediction, reference) pair, in input
        order. Each dict has the keys ``'pred'`` (string form of the parsed
        prediction, or ``None`` when the prediction was not parsed),
        ``'answer'`` (string form of the parsed reference) and
        ``'correct'`` (bool). A pair whose reference cannot be parsed is
        reported as incorrect instead of being dropped, so the result
        always stays index-aligned with the inputs.

    Raises:
        ImportError: If ``math_verify`` or ``latex2sympy2_extended`` is
            not installed.
    """
    try:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import (ExprExtractionConfig,
                                 LatexExtractionConfig, parse, verify)
    except ImportError as exc:
        raise ImportError('Failed to import required modules. Please '
                          'install the necessary packages: '
                          'pip install math_verify latex2sympy2_extended'
                          ) from exc

    # The extraction configs do not depend on the pair being graded, so
    # they are built once instead of on every iteration.
    gold_extraction_config = [
        LatexExtractionConfig(),
        ExprExtractionConfig(),
    ]
    # We require the answer to be provided in correct
    # latex (no malformed operators)
    answer_extraction_config = [
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed='all',
                units=True,
            ),
            # Ensures that boxed is tried first
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ]

    details = []
    for prediction, reference in zip(predictions, references):
        gold_parsed = parse(
            f'${reference}$',
            extraction_mode='first_match',
            extraction_config=gold_extraction_config,
        )

        if len(gold_parsed) == 0:
            # The reference could not be parsed, so there is nothing to
            # verify against. Still emit an entry: callers pair results
            # with predictions by position, and a skipped entry would
            # shift every later result onto the wrong prediction.
            details.append({
                'pred': None,
                'answer': str(gold_parsed),
                'correct': False,
            })
            continue

        answer_parsed = parse(
            prediction,
            extraction_config=answer_extraction_config,
            extraction_mode='first_match',
        )

        # NOTE(review): arguments are passed as (prediction, gold); the
        # math_verify docs appear to describe verify(gold, target) --
        # confirm the intended order.
        is_correct = bool(verify(answer_parsed, gold_parsed))
        details.append({
            'pred': str(answer_parsed),
            'answer': str(gold_parsed),
            'correct': is_correct,
        })
    return details


if __name__ == "__main__":
    # Script entry point: sample several rollouts per question with vLLM,
    # grade each rollout with score(), and dump everything to a JSON file.
    import os  # local import: only needed when running as a script

    # NOTE(review): the file name says "math500" but the dataset loaded
    # below is AIME 2024 -- confirm the intended output name.
    OUTPUT_PATH = "qwen3_32b/cache/qwen3_32b_math500_rollouts.json"
    MODEL_PATH = "Qwen/Qwen3-32B"
    # Number of sampled completions per question.
    NUM_ROLLOUTS = 16

    # Create the output directory up front so a bad path fails before the
    # expensive generation instead of after it.
    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model = LLM(MODEL_PATH, tensor_parallel_size=4, gpu_memory_utilization=0.8)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    # Alternative datasets noted by the author (name[, config], split):
    #   "HuggingFaceH4/aime_2024", split "train"
    #   "math-ai/aime25", split "test"
    #   "opencompass/LiveMathBench", config "v202412_hard_en", split "test"
    dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    inputs = []
    answers = []
    questions = []
    for example in dataset:
        # Keep only level-5 problems; examples without a "level" field
        # default to 5 and are therefore kept.
        if example.get("level", 5) != 5:
            continue
        question = example["problem"]
        messages = [
            {"role": "user", "content": prompt_template.format(question=question)}
        ]
        inputs.append(tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        ))
        questions.append(question)
        answers.append(example["answer"])

    sampling_params = SamplingParams(
        max_tokens=32768,
        temperature=0.6,
        top_p=0.95,
        min_p=0.0,
        top_k=40,
        n=NUM_ROLLOUTS,
        skip_special_tokens=False,
    )
    generations = model.generate(inputs, sampling_params, use_tqdm=True)

    # Grade question by question so each rollout is tied to its own
    # question directly, rather than recovered from a flat index with a
    # hard-coded group size.
    results = {}
    for example_idx, (generation, question, answer) in enumerate(
            zip(generations, questions, answers)):
        rollout_texts = [completion.text for completion in generation.outputs]
        details = score(rollout_texts, [answer] * len(rollout_texts))
        if len(details) == len(rollout_texts):
            correct_flags = [detail["correct"] for detail in details]
        else:
            # score() returned fewer entries than rollouts, i.e. some
            # pairs were skipped; the results cannot be matched to
            # rollouts reliably, so count them all as incorrect.
            correct_flags = [False] * len(rollout_texts)
        results[example_idx] = {
            "question": question,
            "answer": answer,
            "rollouts": [
                {"prediction": text, "correct": is_correct}
                for text, is_correct in zip(rollout_texts, correct_flags)
            ],
        }

    with open(OUTPUT_PATH, "w") as f:
        json.dump(results, f, indent=4)
            