from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
# Checkpoint path / Hugging Face model id handed to both vLLM (LLM) and
# AutoProcessor in main(). Left empty here; set it before running.
MODEL_PATH: str = ""
from datasets import load_dataset
import re
from mathruler.grader import extract_boxed_content, grade_answer
import json
import os
import shutil
from PIL import Image
import requests
from io import BytesIO



# Instruction appended to every problem text before it is sent to the model.
# It asks for <think>...</think> reasoning and a \boxed{} answer, and allows the
# model to emit an empty <self-correction></self-correction> tag followed by a
# fresh attempt; the grading code splits responses on "</self-correction>" to
# score the first and the final attempt separately.
# NOTE(review): this text is runtime input to the model — presumably it must
# match the prompt used during training; confirm before rewording it.
prompt_suffix: str = """\n\nYou first think through your reasoning process as an internal monologue, enclosed within <think> </think> tags. Then, provide your final answer enclosed within \\boxed{}. If you believe the answer can be further enhanced, generate <self-correction> </self-correction> tags enclosed with no content, and regenerate a new reasoning process and a new answer from scratch after that. The new response should first think through your reasoning process as an internal monologue, enclosed within <think> </think> tags. Then, provide your final answer enclosed within \\boxed{}. All reasoning, answer steps must be included without omission."""

def judge_answer(response: str, ground_truth: str) -> bool:
    """Grade the boxed answer of a model response against ``ground_truth``.

    Whitespace adjacent to ``<``, ``>`` and ``/`` is stripped first so that
    spaced tags such as ``< /self-correction >`` are recognised. The graded
    text is the segment after the last ``</self-correction>`` when that
    segment contains a ``\\boxed{...}`` (on a single line); otherwise the
    whole normalised response is graded.

    Args:
        response: Raw generated text.
        ground_truth: Reference answer passed to mathruler's ``grade_answer``.

    Returns:
        The result of ``grade_answer`` on the extracted boxed content.
    """
    compact = re.sub(r"\s*(<|>|/)\s*", r"\1", response)
    # split() always yields at least one piece; without a tag, tail == compact.
    *_, tail = compact.split("</self-correction>")
    tail_has_box = re.search(r"\\boxed\{.*?\}", tail) is not None
    graded_text = tail if tail_has_box else compact
    return grade_answer(extract_boxed_content(graded_text), ground_truth)

def save_image_locally(image, idx, out_dir="eval_images"):
    """Write one dataset image into ``out_dir`` and return the saved file path.

    Args:
        image: A PIL image, a local file path, or an ``http(s)://`` URL.
        idx: Index of the example within its dataset; used as the file stem.
        out_dir: Destination directory (created if missing).

    Returns:
        Path of the saved copy.

    Raises:
        TypeError: ``image`` is none of the supported kinds.
        OSError / requests exceptions: reading, downloading or writing failed.
    """
    os.makedirs(out_dir, exist_ok=True)
    if isinstance(image, str) and not image.startswith(("http://", "https://")):
        # Local file: copy it verbatim and keep its original extension.
        ext = os.path.splitext(image)[1] or ".png"
        dst = os.path.join(out_dir, f"{idx}{ext}")
        shutil.copyfile(image, dst)
        return dst
    if isinstance(image, str):
        response = requests.get(image, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    if not isinstance(image, Image.Image):
        raise TypeError(f"Unsupported image type: {type(image).__name__}")
    dst = os.path.join(out_dir, f"{idx}.png")
    image.save(dst)
    return dst


def _split_attempts(text):
    """Split a response into ``(first_attempt, final_attempt)`` around ``</self-correction>``.

    Whitespace next to ``<``, ``>`` and ``/`` is collapsed first (the same
    normalisation ``judge_answer`` applies), so spaced variants of the tag are
    recognised here as well. The final attempt is the text after the last tag
    when it mentions ``boxed``; otherwise it falls back to the first attempt.
    Without any tag both attempts are the whole (normalised) response.
    """
    normalized = re.sub(r"\s*(<|>|/)\s*", r"\1", text)
    if "</self-correction>" not in normalized:
        return normalized, normalized
    parts = normalized.split("</self-correction>")
    first = parts[0]
    final = parts[-1] if "boxed" in parts[-1] else first
    return first, final


def _build_request(processor, item, idx, image_dir):
    """Build the vLLM request and the bookkeeping record for one dataset item.

    Returns:
        ``(request, record)``: ``request`` is the prompt dict for
        ``llm.generate``; ``record`` carries ``image_path`` (JSON-safe: a
        string or None) and ``correct_answer`` for the output JSONL.
    """
    images = item.get("images") or []
    image = images[0] if images else None

    image_path = None
    content = []
    if image is not None:
        try:
            image_path = save_image_locally(image, idx, out_dir=image_dir)
        except Exception as e:  # best effort: a failed copy must not abort the eval
            print(f"Warning: Failed to save image {idx}: {e}")
            # Keep the original reference only if it can be serialised to JSON.
            image_path = image if isinstance(image, str) else None
        content.append({"type": "image", "image": image})
    content.append(
        {"type": "text", "text": item["problem"].replace("<image>", "") + prompt_suffix}
    )
    messages = [{"role": "user", "content": content}]

    text_prompt = processor.apply_chat_template(
        messages,
        tokenize=False,  # vLLM tokenizes the prompt itself
        add_generation_prompt=True,
    )

    image_inputs, video_inputs = process_vision_info(messages)
    mm_data = {}
    if image_inputs is not None:
        mm_data['image'] = image_inputs
    if video_inputs is not None:
        mm_data['video'] = video_inputs

    request = {'prompt': text_prompt}
    if mm_data:  # text-only items: omit the key instead of sending an empty dict
        request['multi_modal_data'] = mm_data
    record = {"image_path": image_path, "correct_answer": item["answer"]}
    return request, record


def _run_batch(llm, sampling_params, pending, ds_name, output_path):
    """Generate for a batch of ``(request, record)`` pairs, grade, and append to JSONL.

    Returns:
        Number of items whose final attempt was graded correct.
    """
    outputs = llm.generate(
        [request for request, _ in pending],
        sampling_params=sampling_params,
    )
    n_correct = 0
    # One open per batch (not per item); append mode keeps earlier batches/datasets.
    with open(output_path, "a", encoding="utf-8") as f:
        for (request, record), output in zip(pending, outputs):
            text = output.outputs[0].text.strip()
            first_text, second_text = _split_attempts(text)
            first_correct = judge_answer(first_text, record["correct_answer"])
            second_correct = judge_answer(second_text, record["correct_answer"])
            if second_correct:
                n_correct += 1
            f.write(json.dumps({
                "image_path": record["image_path"],
                "problem": request["prompt"],
                "response": text,
                "first_correct": first_correct,
                "second_correct": second_correct,
                "dataset": ds_name,
                "correct_answer": record["correct_answer"],
            }, ensure_ascii=False) + "\n")
    return n_correct


def main(model_path=None, dataset_names=None, output_path="octopus_8b.jsonl",
         batch_size=1024, image_root="eval_images"):
    """Evaluate the model on the ``test`` split of each dataset and report accuracy.

    Every response is graded twice: the first attempt (before any
    ``</self-correction>`` tag) and the final attempt (after it). Accuracy is
    computed from the final attempt.

    Args:
        model_path: Checkpoint for vLLM and the processor; defaults to ``MODEL_PATH``.
        dataset_names: Hugging Face dataset ids; defaults to the list below.
        output_path: JSONL file that per-item results are appended to.
        batch_size: Number of prompts sent to ``llm.generate`` at once.
        image_root: Root directory for local copies of dataset images.
    """
    if model_path is None:
        model_path = MODEL_PATH
    if dataset_names is None:
        # Dataset ids to evaluate; left empty here, fill in before running.
        dataset_names = []

    llm = LLM(
        model=model_path,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        seed=1,
        max_model_len=8192 * 8,
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        max_pixels=1280*28*28,
        min_pixels=256*28*28
    )
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        # repetition_penalty=1.0,
        max_tokens=8192*2
    )

    for ds_name in dataset_names:
        print(f"\n===== Evaluating on {ds_name} =====")
        dataset = load_dataset(ds_name, split="test")
        # Per-dataset folder so example indices don't collide across datasets.
        image_dir = os.path.join(image_root, ds_name.replace("/", "__"))

        correct_count = 0
        count = 0
        pending = []
        for idx, item in enumerate(dataset):
            pending.append(_build_request(processor, item, idx, image_dir))
            if len(pending) == batch_size:
                correct_count += _run_batch(llm, sampling_params, pending, ds_name, output_path)
                count += len(pending)
                pending = []
        if pending:  # flush the final partial batch
            correct_count += _run_batch(llm, sampling_params, pending, ds_name, output_path)
            count += len(pending)

        accuracy = correct_count / count if count else 0.0
        print(f"{ds_name}: accuracy {correct_count}/{count} = {accuracy:.2%}")



if __name__ == '__main__':
    # Script entry point: main() takes its configuration (model path, dataset
    # list, sampling settings) from values defined in this file.
    main()
