import re
from typing import Any
import random
import os
from openai import OpenAI
from typing import Optional


# Connection settings for the OpenAI-compatible judge server. Each value can be
# overridden through the environment; the defaults are the previously hard-coded
# values, so behaviour is unchanged when the variables are unset.
openai_api_key = os.environ.get("JUDGE_API_KEY", "EMPTY")
api_base = os.environ.get("JUDGE_API_BASE", "http://localhost:18901/v1")
# Name of the general-purpose judge model served at ``api_base``.
model = os.environ.get("JUDGE_MODEL", "judge")

# Shared client used by the scoring functions below for every judge request.
client = OpenAI(
    api_key=openai_api_key,
    base_url=api_base,
)


# Default location of the general judge's system prompt.
VERIFY_PROMPT_PATH = '/ossfs/workspace/EasyR1/examples/reward_function/verify_prompt.md'


def get_prompt(_question, _answer, _pred, prompt_path: Optional[str] = None):
    """Build the ``(system_prompt, user_prompt)`` pair for the general judge.

    Args:
        _question: The question shown to the judge.
        _answer: The reference answer.
        _pred: The model's prediction to be judged.
        prompt_path: UTF-8 file whose full contents become the system prompt.
            ``None`` (the default) uses ``VERIFY_PROMPT_PATH``.

    Returns:
        A tuple ``(system_prompt, user_prompt)``. The user prompt lists the three
        inputs under the labels ``[问题]`` (question), ``[参考答案]`` (reference
        answer) and ``[模型回答]`` (model answer).

    Raises:
        OSError: If the prompt file cannot be opened.
    """
    path = prompt_path if prompt_path is not None else VERIFY_PROMPT_PATH
    with open(path, 'r', encoding='utf-8') as file:
        judge_system_prompt = file.read()
    # NOTE: the indentation inside this literal is part of the prompt the judge
    # sees; it is kept exactly as before so judge behaviour does not shift.
    judge_user_prompt = """
    [问题]:{question}
    [参考答案]:{answer}
    [模型回答]:{prediction}
    """

    full_prompt = judge_user_prompt.format(
        question=_question,
        answer=_answer,
        prediction=_pred,
    )
    return judge_system_prompt, full_prompt


# Reward metadata (presumably read by the reward loader — verify against caller).
# "batch" matches compute_score, which scores a whole list of inputs per call.
REWARD_NAME: str = "perceptual"
REWARD_TYPE: str = "batch"


def check_format_match(response: str) -> bool:
    """Return True if ``response`` contains an ``<answer>...</answer>`` span anywhere.

    The span may cross newlines and may be empty. Text before or after it, such as
    a ``<think>`` section, is allowed: the pattern is searched for rather than
    required to cover the whole response.
    """
    return re.search(r"<answer>.*?</answer>", response, re.DOTALL) is not None

# Matches the answer span; DOTALL so multi-line answers are captured.
_ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
# Extracted answers at least this long are treated as degenerate and earn nothing.
_MAX_ANSWER_CHARS = 300


def compute_score(reward_inputs: list[dict[str, Any]], format_weight: float = 0.2) -> list[dict[str, float]]:
    """Score a batch of responses with a format reward plus a judge-based accuracy reward.

    Each input dict is read for ``"response"``, ``"ground_truth"``,
    ``"data_source"`` and ``["extra_info"]["question"]``.

    * The answer is the stripped text inside the first ``<answer>...</answer>``
      span. If there is no span, it is the whole stripped response.
    * ``format`` is 1.0 when such a span exists, its content is non-empty, and the
      answer is shorter than 300 characters. Otherwise it is 0.0.
    * ``accuracy`` is 0.0 for empty or 300+ character answers, and no judge is
      called for them. Otherwise it comes from ``compute_score_math`` when
      ``data_source == "reason"`` and from ``compute_score_general`` for every
      other source.
    * ``overall`` is ``(1 - format_weight) * accuracy + format_weight * format``.

    Returns:
        One ``{"overall", "format", "accuracy"}`` dict per input, in input order.
    """
    scores = []
    for reward_input in reward_inputs:
        response = reward_input["response"]
        ground_truth = reward_input["ground_truth"]
        question = reward_input["extra_info"]['question']

        # A single DOTALL search drives both the format check and the extraction,
        # so the two can never disagree about multi-line answers.
        content_match = _ANSWER_PATTERN.search(response)
        answer_text = content_match.group(1).strip() if content_match else response.strip()
        format_match = content_match is not None and bool(answer_text)

        if not answer_text or len(answer_text) >= _MAX_ANSWER_CHARS:
            # Both scores are set explicitly: previously format_score was left
            # unassigned here (UnboundLocalError / stale value from the last item).
            format_score = 0.0
            accuracy_score = 0.0
        else:
            format_score = 1.0 if format_match else 0.0
            if reward_input["data_source"] == "reason":
                # Argument order follows compute_score_math(query, ground_truth, model_answer).
                accuracy_score = compute_score_math(question, ground_truth, answer_text)
            else:
                accuracy_score = compute_score_general(answer_text, ground_truth, question)

        scores.append(
            {
                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
                "format": format_score,
                "accuracy": accuracy_score,
            }
        )

    return scores


def compute_score_general(predict_str: str, ground_truth: str, question_text) -> float:
    """Ask the general judge whether ``predict_str`` matches ``ground_truth``.

    Args:
        predict_str: The policy's extracted answer.
        ground_truth: The reference answer.
        question_text: The original question.

    The prompt is built by ``get_prompt`` and sent to the module-level ``model`` on
    ``client``. Failed requests are retried up to 8 times, mirroring
    ``compute_score_math``. If every attempt fails, the sample scores 0.0 rather
    than crashing the batch.

    Verdict parsing:
      1. If the reply contains ``<最终结果>`` (final result), keep the text after
         its last occurrence, cut at the closing tag.
      2. If that text contains ``boxed``, keep what is inside ``boxed{...}``.
      3. The score is 1.0 iff the remaining text contains ``"Yes"``.

    Returns:
        1.0 or 0.0.
    """
    system_prompt, full_prompt = get_prompt(question_text, ground_truth, predict_str)

    max_attempts = 8
    response = ""
    for attempt in range(max_attempts):
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": full_prompt},
                ],
                seed=random.randint(0, 1000000),
                temperature=0.3,
                max_tokens=8192,
            )
            # ``content`` may be None; treat that as an empty (rejecting) reply.
            response = (chat_response.choices[0].message.content or "").strip()
            break
        except Exception as e:
            # Best-effort boundary: a judge outage must not abort the whole batch.
            print(f' [ERROR general] judge request failed (attempt {attempt + 1}/{max_attempts}): {e}')

    f_response = response
    if '<最终结果>' in f_response:
        f_response = f_response.split('<最终结果>')[-1]
        # Cut at the proper closing tag, and also at the backslash variant the
        # previous code looked for (which was an invalid escape sequence and could
        # never match a real "</...>" closing tag).
        for closing_tag in ('</最终结果>', '<\\最终结果>'):
            f_response = f_response.split(closing_tag)[0]
        f_response = f_response.strip()
    if 'boxed' in f_response:
        f_response = f_response.split('boxed{')[-1].strip().split('}')[0].strip()
    _score = 1 if 'Yes' in f_response else 0
    print(f'DEBUG JUDGE {f_response=} {_score=}')
    return float(_score)

# Prompt template for compute_score_math's judge, filled via str.format with
# {query} (problem), {gold_ans} (reference answer) and {pred_ans} (student answer).
# Any literal braces added to this text must be doubled ("{{" / "}}") because of
# str.format. The judge is asked for an "## Equivalence Judgement" section holding
# TRUE or FALSE, which compute_score_math parses. Whitespace inside the literal is
# part of the prompt.
MATH_VERIFY_PROMPT = """# CONTEXT #
I am a teacher, and I have some high-level math problems. I am tasked with evaluating the correctness of a student's answer. 
Below, I am provided with a problem and a reference answer. Additionally, a student's answer is provided. My job is to assess whether the student's answer captures the same meaning as the reference answer, even when expressed with different wording or format.

# OBJECTIVE #
I need you to judge whether the student's answer is correct given the ground truth answer.

Your tasks include:
1. Identify Mathematical or Notational Equivalence: Pay special attention to any LaTeX expressions in both answers. Confirm that the mathematical relationships, variables, and operations conveyed are equivalent.

# TONE #
Professional, scientific.

# RESPONSE: MARKDOWN REPORT #
## Equivalence Judgement
[Whether the student's answer share the same meaning with the reference answer. (TRUE or FALSE)]

# ATTENTION #
 - The reference answer is ALWAYS correct. You should carefully judge whether the student gives the same answer as reference answer.
 - The Equivalence Judgement is only TRUE or FALSE. The answer is FALSE even if the student's final answer almost correct with a minor mistakes.
 - Don't give extra explanation.

**Question**:
{query}

**Reference Answer**
{gold_ans}

## Student Final Answer
{pred_ans}"""

def compute_score_math(query, ground_truth, model_answer):
    """Grade a math answer with an LLM judge: 1.0 if judged equivalent, else 0.0.

    Args:
        query: The problem statement (``{query}`` in ``MATH_VERIFY_PROMPT``).
        ground_truth: The reference answer (``{gold_ans}``).
        model_answer: The policy's extracted answer (``{pred_ans}``).

    The request goes to the ``"Qwen3"`` model on the module-level ``client`` and is
    retried up to 8 times until a call succeeds. The verdict is read from the text
    after the last ``## Equivalence Judgement`` header:
      * 1.0 when that text mentions only "true";
      * 0.0 when it mentions only "false".
    Ambiguous or missing verdicts, including the case where every request failed,
    are logged and scored 0.0.
    """
    full_prompt = MATH_VERIFY_PROMPT.format(
        query=query,
        gold_ans=ground_truth,
        pred_ans=model_answer,
    )

    max_attempts = 8
    response = ""
    for _attempt in range(max_attempts):
        try:
            chat_response = client.chat.completions.create(
                model="Qwen3",
                messages=[
                    {"role": "user", "content": full_prompt},
                ],
                seed=random.randint(0, 1000000),
                temperature=0.5,
            )
            # A None ``content`` raises AttributeError here and is retried like
            # any other failed request.
            response = chat_response.choices[0].message.content.strip()
            break
        except Exception as e:
            # Best-effort boundary: a judge outage must not abort the whole batch.
            print(f' [ERROR math] generative_verify error: {e}')

    if not response:
        print(f' [ERROR math] empty judge response after up to {max_attempts} attempts')
        return 0.0

    # Ignore any reasoning emitted before the report. This is presumably a
    # <think>...</think> section when the served model runs in thinking mode
    # (TODO confirm). Its mentions of "true"/"false" would otherwise make the
    # verdict ambiguous.
    report = response.split('</think>')[-1]
    judgement = report.split('## Equivalence Judgement')[-1].lower()
    if 'true' in judgement and 'false' not in judgement:
        return 1.0
    if 'false' in judgement and 'true' not in judgement:
        return 0.0
    print(f' [ERROR math] verify bug output: {response!r}')
    return 0.0


if __name__ == "__main__":
    # Manual smoke test: judges a counting answer ("8") against the reference ("6").
    # Needs the judge server and the general-judge prompt file to be reachable.
    predict_str = '''<think>Here's a breakdown of the 8 cars in the image:
Far left: A dark-colored car, partially visible, parked on the left edge.
Second from left: A light-colored (possibly silver or white) sedan, fully visible, connected to a charging station.
Third from left: A dark-colored sedan, fully visible, connected to a charging station.
Fourth from left (center-left): A dark-colored sedan, fully visible, connected to a charging station.
Center-right (further back): A dark-colored SUV or sedan, partially visible in the background, further down the charging lane.
Third from right: A light-colored (possibly silver or white) sedan, fully visible, connected to a charging station.
Second from right: A dark-colored car, partially visible, parked on the right edge.
Far back (center): A dark-colored sedan or SUV, fully visible in the background, parked further into the facility.</think><answer>8</answer>'''
    ground_truth = "6"
    extra_info = {
        'answer': '6',
        'id': 0,
        'image': '',
        'question': 'how many cars in the image?',
    }
    # compute_score takes a batch (list) of reward-input dicts, not positional
    # (response, ground_truth, extra_info) arguments.
    reward_inputs = [
        {
            "response": predict_str,
            "ground_truth": ground_truth,
            "extra_info": extra_info,
            # Any data_source other than "reason" is routed to the general judge.
            "data_source": "general",
        }
    ]
    score = compute_score(reward_inputs)
    print(f"Score: {score}")