import json

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig


def score(predictions, references):
    """Grade model predictions against reference answers with Math-Verify.

    Each reference is wrapped in ``$...$`` and parsed as the gold answer;
    each prediction is parsed with a LaTeX extraction config that tries
    ``\\boxed{...}`` content first and does not extract without an anchor.

    Args:
        predictions: Iterable of raw model output strings.
        references: Iterable of gold answers, paired positionally with
            ``predictions``.

    Returns:
        A list with exactly one dict per (prediction, reference) pair, in
        input order, so callers can zip it back against their inputs. Each
        dict has the keys ``'pred'`` (string form of the parsed prediction),
        ``'answer'`` (string form of the parsed gold answer) and
        ``'correct'`` (bool). A pair whose gold answer cannot be parsed is
        reported with ``'correct': False`` instead of being dropped.

    Raises:
        ImportError: If ``math_verify`` or ``latex2sympy2_extended`` is not
            installed.
    """
    try:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import (ExprExtractionConfig,
                                 LatexExtractionConfig, parse, verify)
    except ImportError as err:
        raise ImportError('Failed to import required modules. Please '
                          'install the necessary packages: '
                          'pip install math_verify latex2sympy2_extended'
                          ) from err

    # The extraction configs do not depend on the sample, so build them once.
    gold_extraction = [
        LatexExtractionConfig(),
        ExprExtractionConfig(),
    ]
    # We require the answer to be provided in correct latex
    # (no malformed operators).
    answer_extraction = [
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed='all',
                units=True,
            ),
            # Ensures that boxed is tried first
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ]

    details = []
    for prediction, reference in zip(predictions, references):
        gold_parsed = parse(
            f'${reference}$',
            extraction_mode='first_match',
            extraction_config=gold_extraction,
        )
        answer_parsed = parse(
            prediction,
            extraction_config=answer_extraction,
            extraction_mode='first_match',
        )

        # An unparseable gold answer can never be matched. It is still
        # recorded (as incorrect) so the result stays index-aligned with
        # the inputs. Math-Verify documents the call as verify(gold, answer)
        # and notes that the argument order matters.
        is_correct = bool(gold_parsed) and bool(
            verify(gold_parsed, answer_parsed))

        details.append({
            'pred': str(answer_parsed),
            'answer': str(gold_parsed),
            'correct': is_correct,
        })
    return details


if __name__ == "__main__":
    # Script entry point: sample NUM_ROLLOUTS completions per AIME-2024
    # problem with LMDeploy, grade each one with score(), and dump the
    # per-problem rollouts to OUTPUT_PATH as JSON.
    import os

    OUTPUT_PATH = "src/test_time_scaling/cache/qwq32b_aime24_no_thinking_rollouts.json"
    MODEL_PATH = "Qwen/QwQ-32B"
    # Number of sampled completions per problem. Prompt replication, reference
    # replication and result grouping below must all use this same value.
    NUM_ROLLOUTS = 16

    # Create the output directory up front so that a missing directory does
    # not discard the results after the (expensive) generation has finished.
    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    inputs = []
    answers = []
    for example in dataset:
        messages = [
            {"role": "user", "content": prompt_template.format(question=example["problem"])}
        ]
        # The assistant turn is pre-filled with a concluding phrase so the
        # model continues from it.
        # NOTE(review): presumably this is what makes the run "no_thinking"
        # (see OUTPUT_PATH) - confirm against the chat template.
        inputs.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            ) + "Therefore, after all this, I believe the answer is"
        )
        answers.append(example["answer"])

    gen_cfg = GenerationConfig(
        n=1,
        max_new_tokens=32768,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        min_p=0.0,
        top_k=40,
        skip_special_tokens=False,
    )
    turbomind_cfg = TurbomindEngineConfig(tp=4, enable_prefix_caching=True)
    with pipeline(MODEL_PATH, turbomind_cfg) as pipe:
        # n=1 in gen_cfg, so every prompt is submitted NUM_ROLLOUTS times,
        # with the copies of one prompt adjacent to each other.
        responses = pipe(
            prompts=None,
            input_ids=[
                tokenizer.encode(text, add_special_tokens=False)
                for text in inputs
                for _ in range(NUM_ROLLOUTS)
            ],
            gen_config=gen_cfg,
            use_tqdm=True,
        )

    # Replicate the references in the same order as the replicated prompts.
    predictions = []
    references = []
    repeated_answers = [a for a in answers for _ in range(NUM_ROLLOUTS)]
    for response, answer in zip(responses, repeated_answers):
        predictions.append(response.text)
        references.append(answer)

    # Alternative vLLM backend, kept for reference:
    # sampling_params = SamplingParams(
    #     max_tokens=32768,
    #     temperature=0.6,
    #     top_p=0.95,
    #     min_p=0.0,
    #     top_k=40,
    #     n=NUM_ROLLOUTS,
    #     skip_special_tokens=False,
    # )
    # model = LLM(MODEL_PATH, tensor_parallel_size=4, gpu_memory_utilization=0.8, enable_prefix_caching=True)
    # outputs = model.generate(inputs, sampling_params, use_tqdm=True)

    # predictions = []
    # references = []
    # for output, answer in zip(outputs, answers):
    #     prompt = output.prompt
    #     for r in output.outputs:
    #         predictions.append(r.text)
    #         references.append(answer)

    # Group rollouts by problem. The positional grouping assumes score()
    # yields one detail per (prediction, reference) pair, in input order.
    results = {}
    details = score(predictions, references)
    for i, (detail, prediction) in enumerate(zip(details, predictions)):
        example_idx = i // NUM_ROLLOUTS
        if example_idx not in results:
            results[example_idx] = {
                "question": dataset[example_idx]["problem"],
                "answer": dataset[example_idx]["answer"],
                "rollouts": [],
            }
        results[example_idx]["rollouts"].append({
            "prediction": prediction,
            "correct": detail["correct"],
        })

    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
            