import re
from typing import Any
import random
import os
import time
from openai import OpenAI
from typing import Optional
from mathruler.grader import grade_answer
from concurrent.futures import ThreadPoolExecutor, as_completed # 导入并行处理模块

# Connection settings for the OpenAI-compatible judge endpoint.
# Each value can be overridden through an environment variable; the defaults
# reproduce the configuration that used to be hard-coded here.
# NOTE(review): a literal API key is committed in this file. Prefer supplying
# it through JUDGE_API_KEY and rotating the committed value.
openai_api_key = os.environ.get("JUDGE_API_KEY", "uKkw9zzvcRbHq0cOtoO8F1dVGQrj8kGK")
# Alternative endpoints that were previously toggled by hand:
#   api_base = "https://antchat.alipay.com/v1"  (with model = "Qwen3-32B")
#   api_base = "http://172.23.114.3:18901/v1"
api_base = os.environ.get("JUDGE_API_BASE", "http://localhost:18901/v1")
model = os.environ.get("JUDGE_MODEL", "judge")
# Shared client used by compute_score_general and compute_score_math.
client = OpenAI(
    api_key=openai_api_key,
    base_url=api_base,
)



def get_prompt(_question, _answer, _pred):
    """Build the (system prompt, user prompt) pair for the general judge.

    The system prompt is read from ``verify_prompt.md`` on every call; the
    path is relative to the current working directory. The user prompt embeds
    the question, the reference answer and the model's answer.

    Args:
        _question: Question text shown to the judge.
        _answer: Reference answer.
        _pred: Model answer to be judged.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple of strings.
    """
    user_template = """
    [问题]:{question}
    [参考答案]:{answer}
    [模型回答]:{prediction}
    """

    # The file is assumed to exist; a missing file raises from open().
    with open('./examples/reward_function/verify_prompt.md', 'r', encoding='utf-8') as prompt_file:
        system_text = prompt_file.read()

    user_text = user_template.format(question=_question, answer=_answer, prediction=_pred)
    return system_text, user_text


# Reward-function metadata.
# NOTE(review): presumably read by the framework that loads this module —
# confirm against the loader. "batch" is consistent with compute_score in this
# file, which receives a list of reward inputs and returns a list of scores.
REWARD_NAME = "perceptual"
REWARD_TYPE = "batch"

# Ordered (before, after) substring replacements that normalize_final_answer
# applies with str.replace, one after another, before any other clean-up.
# Order matters: later entries see the output of earlier ones (for example,
# every space is already gone by the time the "\\text{and}" entries run).
# These are plain substring matches, so "an " / "a " also match inside longer
# text such as "than 3".
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),  # drop a period that sits directly before a "$"
    ("\\$", ""),  # remove a backslash-escaped dollar sign
    (r"\ ", ""),  # remove backslash-space
    (" ", ""),  # remove every remaining space
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]

# Substrings that normalize_final_answer deletes (replaces with "") after
# SUBSTITUTIONS has been applied. They are removed in list order with
# str.replace, so each entry also matches inside longer words.
# The "\\text{\ns}" and "\\text{\n}" entries contain a real newline character
# because those literals are not raw strings.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """Reduce an answer string to a canonical form for string comparison.

    Steps, in order: keep only the text after the last ``=``; apply
    ``SUBSTITUTIONS``; delete every ``REMOVED_EXPRESSIONS`` entry; keep only
    the first ``$...$`` span if one exists; unwrap ``\\text``, ``\\textbf``,
    ``\\overline`` and ``\\boxed``; expand shorthand ``\\frac`` / ``\\sqrt``;
    drop ``$``; strip thousands separators from pure integers.

    Args:
        final_answer: The answer string to normalize.

    Returns:
        The normalized answer with surrounding whitespace stripped.
    """
    # Everything up to and including the last "=" is discarded.
    text = final_answer.rsplit("=", 1)[-1]

    for old, new in SUBSTITUTIONS:
        text = text.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        text = text.replace(unwanted, "")

    # Regex rewrites, applied strictly in this order.
    rewrites = (
        # Keep only the contents of the first $...$ span.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        # Unwrap LaTeX wrappers, keeping their contents.
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Shorthand TeX:
        #   \fracab  -> \frac{a}{b}
        #   \fracabc -> \frac{a}{b}c
        #   \sqrta   -> \sqrt{a}
        #   \sqrtab  -> \sqrt{a}b
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)

    text = text.replace("$", "")

    # "1,234" -> "1234", but only when the comma-free text is all digits.
    without_commas = text.replace(",", "")
    if without_commas.isdigit():
        text = without_commas

    return text.strip()


def check_format_match(response: str) -> bool:
    """Tell whether ``response`` contains an ``<answer>...</answer>`` span.

    The span may cover several lines because ``.`` also matches newlines here.

    Args:
        response: Raw model output.

    Returns:
        True if at least one opening tag is followed by a closing tag.
    """
    found = re.search(r"<answer>.*?</answer>", response, flags=re.DOTALL)
    return found is not None

def string_match(response: str, ground_truth: str) -> bool:
    """Compare the last ``Answer: ...`` line of ``response`` with the reference.

    The text after the last case-insensitive ``Answer:`` marker (up to the end
    of that line) is taken as the candidate. If no marker is present the
    placeholder ``"[INVALID]"`` is used instead. Both candidate and reference
    are passed through ``normalize_final_answer`` before comparison.

    Args:
        response: Text to search for an ``Answer:`` marker.
        ground_truth: Reference answer.

    Returns:
        True when the normalized candidate equals the normalized reference.
    """
    candidates = re.findall(r"(?i)Answer\s*:\s*([^\n]+)", response)
    answer = candidates[-1] if candidates else "[INVALID]"
    return normalize_final_answer(answer) == normalize_final_answer(ground_truth)

def _score_entry(accuracy_score: float, format_score: float, format_weight: float) -> dict[str, float]:
    """Combine an accuracy and a format score into one per-sample result dict."""
    return {
        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
        "format": format_score,
        "accuracy": accuracy_score,
    }


def compute_score(reward_inputs: list[dict[str, Any]], format_weight: float = 0.0) -> list[dict[str, float]]:
    """Score a batch of responses with rule-based checks and a judge model.

    For each input the answer is taken from the first ``<answer>...</answer>``
    span (the whole response if there is none) and scored as follows:

    1. An empty answer scores 0 for overall, format and accuracy.
    2. An answer of 300 or more characters gets accuracy 0.
    3. An answer accepted by ``string_match`` or ``grade_answer`` gets
       accuracy 1.
    4. Anything else is sent to the judge model: ``compute_score_math`` when
       ``data_source`` is ``"reason"``, ``compute_score_general`` otherwise.
       These calls run concurrently in a thread pool.

    Args:
        reward_inputs: One dict per sample with the keys ``"response"``,
            ``"ground_truth"``, ``"data_source"`` and ``"extra_info"`` (which
            must contain ``"question"``).
        format_weight: Weight of the format score in ``"overall"``; accuracy
            gets ``1 - format_weight``.

    Returns:
        A list, in input order, of dicts with the keys ``"overall"``,
        ``"format"`` and ``"accuracy"``.
    """
    # Pre-sized so results can be written back by index, keeping input order
    # even though judge calls finish in arbitrary order.
    final_scores_list: list = [None] * len(reward_inputs)

    # Pending judge calls: (function, argument tuple, original index, format score).
    tasks_to_run = []

    for idx, reward_input in enumerate(reward_inputs):
        response = reward_input["response"]
        format_score = 1.0 if check_format_match(response) else 0.0
        # DOTALL keeps extraction consistent with check_format_match, so an
        # answer spanning several lines is extracted instead of falling back
        # to the whole response.
        content_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        answer_text = (content_match.group(1) if content_match else response).strip()
        question = reward_input["extra_info"]['question']

        # Scores that can be decided without calling the judge model.
        if not answer_text:
            final_scores_list[idx] = {
                "overall": 0.0,
                "format": 0.0,
                "accuracy": 0.0,
            }
            continue

        if len(answer_text) >= 300:
            final_scores_list[idx] = _score_entry(0.0, format_score, format_weight)
            continue

        ground_truth = reward_input["ground_truth"]
        if string_match(answer_text, ground_truth) or grade_answer(answer_text, ground_truth):
            final_scores_list[idx] = _score_entry(1.0, format_score, format_weight)
            continue

        # Everything else needs the judge model.
        if reward_input["data_source"] == "reason":
            tasks_to_run.append((compute_score_math, (question, ground_truth, answer_text), idx, format_score))
        else:  # "general"
            tasks_to_run.append((compute_score_general, (answer_text, ground_truth, question), idx, format_score))

    if not tasks_to_run:
        # Nothing to judge, so no thread pool is needed.
        return final_scores_list

    # max_workers can be tuned to the judge server's capacity and API limits.
    with ThreadPoolExecutor(max_workers=32) as executor:
        future_to_task_info = {
            executor.submit(func, *args): (original_idx, format_s)
            for func, args, original_idx, format_s in tasks_to_run
        }

        for future in as_completed(future_to_task_info):
            original_idx, format_s = future_to_task_info[future]
            try:
                accuracy_score = future.result()
            except Exception as exc:
                print(f'Generated an exception for task {original_idx}: {exc}')
                accuracy_score = 0.0  # a failed judge call counts as incorrect

            # Store the result at the sample's original position.
            final_scores_list[original_idx] = _score_entry(accuracy_score, format_s, format_weight)
            print(f"original_idx:{original_idx}. accuracy:{accuracy_score}. format:{format_s}")

    return final_scores_list


def compute_score_general(predict_str: str, ground_truth: str, question_text) -> float:
    """Ask the judge model whether a general answer matches the reference.

    The prompt pair comes from ``get_prompt``. The request is attempted up to
    five times, sleeping 5 seconds between failed attempts. The verdict is
    then narrowed to the text inside ``<最终结果>...</最终结果>`` (if present)
    and inside ``boxed{...}`` (if present); the answer counts as correct when
    the remaining text contains ``Yes``.

    Args:
        predict_str: Model answer to be judged.
        ground_truth: Reference answer.
        question_text: Question shown to the judge.

    Returns:
        1.0 if the judge's verdict contains ``Yes``; otherwise 0.0, including
        when every request attempt failed.
    """
    system_prompt, full_prompt = get_prompt(question_text, ground_truth, predict_str)

    max_attempts = 5
    # Start from an empty reply so that exhausting all attempts yields a score
    # of 0.0 instead of an unbound-variable error further down.
    response = ""
    for attempt in range(max_attempts):
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": full_prompt},
                ],
                seed=random.randint(0, 1000000),
                temperature=0.3,
                max_tokens=8192,
            )
            response = chat_response.choices[0].message.content.strip()
            break
        except Exception as e:
            print(f' [ERROR general] generative_verify error: {e}')
            # Back off before retrying; no point sleeping after the last attempt.
            if attempt < max_attempts - 1:
                time.sleep(5)

    f_response = response
    # The judge is expected to wrap its verdict in <最终结果> and/or boxed{}.
    if '<最终结果>' in f_response:
        match = re.search(r'<最终结果>(.*?)</最终结果>', f_response, re.DOTALL)
        if match:
            f_response = match.group(1).strip()
        else:
            # No closing tag: take everything after the last opening tag.
            f_response = f_response.split('<最终结果>')[-1].strip()

    if 'boxed' in f_response:
        match = re.search(r'boxed\{(.*?)\}', f_response, re.DOTALL)
        if match:
            f_response = match.group(1).strip()
        else:
            # No closing brace: take everything after the last "boxed{".
            f_response = f_response.split('boxed{')[-1].strip()

    _score = 1 if 'Yes' in f_response else 0
    acc_reward = 1.0 if _score else 0.0
    # Log the parsed verdict and the raw judge reply.
    print(f'DEBUG JUDGE GENERAL {f_response=} {_score=} response: {response}')
    return acc_reward

# Prompt template for the math judge. compute_score_math fills the {query},
# {gold_ans} and {pred_ans} placeholders with str.format, then looks for
# "true" / "false" after the "## Equivalence Judgement" heading in the reply.
# The text is sent to the model verbatim, so edits here change judge behavior.
MATH_VERIFY_PROMPT = """# CONTEXT #
I am a teacher, and I have some high-level math problems. I am tasked with evaluating the correctness of a student's answer.
Below, I am provided with a problem and a reference answer. Additionally, a student's answer is provided. My job is to assess whether the student's answer captures the same meaning as the reference answer, even when expressed with different wording or format.

# OBJECTIVE #
I need you to judge whether the student's answer is correct given the ground truth answer.

Your tasks include:
1. Identify Mathematical or Notational Equivalence: Pay special attention to any LaTeX expressions in both answers. Confirm that the mathematical relationships, variables, and operations conveyed are equivalent.

# TONE #
Professional, scientific.

# RESPONSE: MARKDOWN REPORT #
## Equivalence Judgement
[Whether the student's answer share the same meaning with the reference answer. (TRUE or FALSE)]

# ATTENTION #
 - The reference answer is ALWAYS correct. You should carefully judge whether the student gives the same answer as reference answer.
 - The Equivalence Judgement is only TRUE or FALSE. The answer is FALSE even if the student's final answer almost correct with a minor mistakes.
 - Don't give extra explanation.

**Question**:
{query}

**Reference Answer**
{gold_ans}

## Student Final Answer
{pred_ans}"""

def compute_score_math(query: str, ground_truth: str, model_answer: str) -> float:
    """Ask the judge model whether a math answer is equivalent to the reference.

    ``MATH_VERIFY_PROMPT`` is filled in and sent as a single user message. The
    request is attempted up to five times, sleeping 5 seconds between failed
    attempts. The text after the last ``## Equivalence Judgement`` heading in
    the reply is lower-cased and inspected for ``true`` / ``false``.

    Args:
        query: The question.
        ground_truth: Reference answer.
        model_answer: Answer to be judged.

    Returns:
        1.0 when the judgement contains ``true`` and not ``false``; 0.0 in
        every other case (explicit ``false``, ambiguous or empty reply, or all
        request attempts failed).
    """
    full_prompt = MATH_VERIFY_PROMPT.format(
        query=query,
        gold_ans=ground_truth,
        pred_ans=model_answer,
    )

    max_attempts = 5
    response = ""
    for attempt in range(max_attempts):
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "user", "content": full_prompt},
                ],
                seed=random.randint(0, 1000000),
                temperature=0.5,
            )
            response = chat_response.choices[0].message.content.strip()
            break
        except Exception as e:
            print(f' [ERROR math] generative_verify error: {e}')
            # Back off before retrying; no point sleeping after the last attempt.
            if attempt < max_attempts - 1:
                time.sleep(5)

    judgement = response.split('## Equivalence Judgement')[-1].lower()
    says_true = 'true' in judgement
    says_false = 'false' in judgement

    if says_true and not says_false:
        print(f'DEBUG JUDGE MATH True: {response}')
        return 1.0
    if says_false and not says_true:
        print(f'DEBUG JUDGE MATH False: {response}')
        return 0.0
    # Both or neither keyword present: treat as an unusable verdict.
    print(f' [ERROR math] verify bug output: {response}')
    return 0.0
