import os
import json
import random
from concurrent.futures import ProcessPoolExecutor

from tqdm import tqdm
from vllm import SamplingParams, LLM
from transformers import AutoTokenizer


# Seed Python's built-in random module with a fixed value.
random.seed(42)
# Model identifier handed to AutoTokenizer.from_pretrained() and vllm.LLM()
# in the script entry point.
model = "Qwen/Qwen3-32B"

def score(predictions, references):
    r"""Grade predicted answers against reference answers with Math-Verify.

    Args:
        predictions: Iterable of answer strings. Each one is parsed with a
            LaTeX extraction config in which ``\boxed{...}`` matches are
            tried first and anchor-less extraction is disabled.
        references: Iterable of gold answers, paired positionally with
            ``predictions``. Each one is wrapped in ``$...$`` before parsing.

    Returns:
        A list with exactly one dict per (prediction, reference) pair, in
        input order. Each dict has the keys ``'pred'`` (stringified parsed
        prediction), ``'answer'`` (stringified parsed reference) and
        ``'correct'`` (bool). A pair whose reference cannot be parsed is
        reported with ``'correct': False`` instead of being dropped, so the
        result can safely be zipped with the inputs.

    Raises:
        ImportError: If ``math_verify`` or ``latex2sympy2_extended`` is not
            installed.
    """
    try:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import (ExprExtractionConfig,
                                 LatexExtractionConfig, parse, verify)
    except ImportError as err:
        raise ImportError('Failed to import required modules. Please '
                          'install the necessary packages: '
                          'pip install math_verify latex2sympy2_extended') from err

    # The extraction configs do not depend on the pair being graded, so they
    # are built once instead of on every iteration.
    gold_extraction_config = [
        LatexExtractionConfig(),
        ExprExtractionConfig(),
    ]
    # We require the answer to be provided in correct
    # latex (no malformed operators)
    answer_extraction_config = [
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed='all',
                units=True,
            ),
            # Ensures that boxed is tried first
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ]

    details = []
    for prediction, reference in zip(predictions, references):
        gold_parsed = parse(
            f'${reference}$',
            extraction_mode='first_match',
            extraction_config=gold_extraction_config,
        )
        answer_parsed = parse(
            prediction,
            extraction_config=answer_extraction_config,
            extraction_mode='first_match',
        )

        # An unparseable reference can never be matched; record the pair as
        # incorrect rather than skipping it, which would misalign the output
        # with the inputs.
        # verify() is not symmetric: the gold expression goes first, the
        # candidate answer second.
        is_correct = bool(gold_parsed) and bool(
            verify(gold_parsed, answer_parsed))

        details.append({
            'pred': str(answer_parsed),
            'answer': str(gold_parsed),
            'correct': is_correct,
        })
    return details


if __name__ == "__main__":
    # Script entry point: for every question in the input JSON, append a
    # forced "\boxed{" answer prefix to each summarized trajectory, let the
    # model complete the answer, grade it with score(), and write the
    # augmented data to a new JSON file.
    tokenizer = AutoTokenizer.from_pretrained(model)
    # NOTE: from here on ``model`` is the vLLM engine, no longer the name string.
    model = LLM(model, trust_remote_code=True, tensor_parallel_size=2, gpu_memory_utilization=0.7)
    # Only the content of the already-opened \boxed{ is generated: at most 20
    # tokens at temperature 0, with "}" as the stop string.
    sampling_params = SamplingParams(n=1, max_tokens=20, stop=["}"], seed=42, temperature=0.0)

    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    answer_prefix = "\n\n**Final Answer**\n\\boxed{"
    input_path = "qwen3_32b/cache/aime24_summarized_trajectories/summarized_trajectories_x4.json"
    output_path = "qwen3_32b/cache/aime24_summarized_trajectories/summarized_trajectories_x4_infer.json"

    # Explicit UTF-8 so reading/writing does not depend on the platform locale
    # (the output is dumped with ensure_ascii=False).
    with open(input_path, "r", encoding="utf-8") as rfile:
        data = json.load(rfile)

    for item in tqdm(data.values()):
        messages = [
            {"role": "user", "content": prompt_template.format(question=item["question"])},
        ]
        # The rendered chat prefix is identical for every trajectory of this
        # question, so render it once per question.
        chat_prefix = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        trajectories = item["summarized_trajectory"]
        prompts = [chat_prefix + trajectory + answer_prefix for trajectory in trajectories]

        request_outputs = model.generate(prompts, sampling_params)
        responses = [
            completion.text
            for request_output in request_outputs
            for completion in request_output.outputs
        ]

        # Re-wrap each completion in \boxed{...} before grading; every
        # response of this question is graded against the same gold answer.
        details = score(
            ["\\boxed{" + response + "}" for response in responses],
            [item["answer"]] * len(responses),
        )

        item.update({
            "responses": [
                {
                    "response": response,
                    # Token count of the summarized trajectory (not of the response).
                    "num_token": len(tokenizer(trajectory, add_special_tokens=False)["input_ids"]),
                    "correct": detail["correct"],
                } for response, detail, trajectory in zip(responses, details, trajectories)
            ]
        })

    with open(output_path, "w", encoding="utf-8") as wfile:
        json.dump(data, wfile, indent=4, ensure_ascii=False)