from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import json
import re
import os
from tqdm import tqdm

if __name__ == "__main__":

    # Where the graded rollouts are written, and which checkpoint to evaluate.
    OUTPUT_PATH = "qwq32b_aime24_rollouts.json"
    MODEL_PATH = "Qwen/QwQ-32B"

    # vLLM engine with tensor_parallel_size=4 and gpu_memory_utilization=0.8;
    # the tokenizer is only used to render the chat template below.
    model = LLM(MODEL_PATH, tensor_parallel_size=4, gpu_memory_utilization=0.8)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
    # The doubled braces survive str.format as a literal "\boxed{}".
    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"

    def _render(problem):
        # Wrap one problem in a single user turn and render it as text, ending
        # with the assistant-turn header so generation starts the reply.
        chat = [{"role": "user", "content": prompt_template.format(question=problem)}]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # inputs[k] / answers[k] describe the k-th dataset row.
    inputs = [_render(row["problem"]) for row in dataset]
    answers = [row["answer"] for row in dataset]

    # 16 sampled completions per prompt, up to 32768 new tokens each; special
    # tokens are kept in the decoded text.
    sampling_params = SamplingParams(
        n=16,
        max_tokens=32768,
        temperature=0.6,
        top_p=0.95,
        top_k=40,
        min_p=0.0,
        skip_special_tokens=False,
    )
    outputs = model.generate(inputs, sampling_params, use_tqdm=True)

    def score(predictions, references):
        """Grade each prediction against its reference answer with math_verify.

        Args:
            predictions: Model completion strings.
            references: Gold answers, one per prediction. Each is wrapped in
                ``$...$`` before parsing.

        Returns:
            A list with exactly one dict per (prediction, reference) pair, in
            input order. Each dict has the keys ``pred`` and ``answer`` (string
            forms of the parsed values) and ``correct`` (bool). A pair whose
            reference cannot be parsed is marked incorrect rather than dropped,
            so the result stays index-aligned with ``predictions``.

        Raises:
            ImportError: If ``math_verify`` or ``latex2sympy2_extended`` is not
                installed.
        """
        try:
            from latex2sympy2_extended import NormalizationConfig
            from math_verify import (ExprExtractionConfig,
                                     LatexExtractionConfig, parse, verify)
        except ImportError as err:
            raise ImportError('Failed to import required modules. Please '
                              'install the necessary packages: '
                              'pip install math_verify latex2sympy2_extended') from err

        # The extraction configs do not depend on the loop, so build them once.
        gold_extraction = [
            LatexExtractionConfig(),
            ExprExtractionConfig(),
        ]
        # Predictions must be provided in correct LaTeX (no malformed-operator
        # repair), and a \boxed{} answer is tried first.
        answer_extraction = [
            LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed='all',
                    units=True,
                ),
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
        ]

        details = []
        for prediction, reference in zip(predictions, references):
            gold_parsed = parse(
                f'${reference}$',
                extraction_mode='first_match',
                extraction_config=gold_extraction,
            )
            answer_parsed = parse(
                prediction,
                extraction_config=answer_extraction,
                extraction_mode='first_match',
            )
            # math_verify's verify() is asymmetric: gold first, then target.
            # An empty gold parse counts as incorrect but is still recorded,
            # so callers can zip the result with their predictions.
            is_correct = bool(gold_parsed) and bool(verify(gold_parsed, answer_parsed))
            details.append({
                'pred': str(answer_parsed),
                'answer': str(gold_parsed),
                'correct': is_correct,
            })
        return details

    # Flatten the completions of every prompt into parallel lists so they can
    # be graded in one call. Pairing with `answers` by position assumes
    # generate() returns results in input-prompt order (as vLLM documents).
    predictions = []
    references = []
    for request_output, answer in zip(outputs, answers):
        for completion in request_output.outputs:
            predictions.append(completion.text)
            references.append(answer)

    # JSON Lines, matching the file extension: one JSON-encoded completion per line.
    with open("qwq32b_aime24_predictions.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(prediction) + "\n" for prediction in predictions)

    details = score(predictions, references)
    # Labels are matched to completions by position below. A length mismatch
    # would silently attach correctness to the wrong rollout, so fail loudly.
    if len(details) != len(predictions):
        raise RuntimeError(
            f"score() returned {len(details)} results for {len(predictions)} predictions"
        )

    # Regroup graded rollouts under their source problem. Each prompt's own
    # completion count is used, rather than a hard-coded samples-per-prompt.
    results = {}
    start = 0
    for example_idx, request_output in enumerate(outputs):
        stop = start + len(request_output.outputs)
        example = dataset[example_idx]
        results[example_idx] = {
            "question": example["problem"],
            "answer": example["answer"],
            "rollouts": [
                {"prediction": prediction, "correct": detail["correct"]}
                for prediction, detail in zip(predictions[start:stop], details[start:stop])
            ],
        }
        start = stop

    # json converts the integer example indices to string keys.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
            