#!/usr/bin/env python
# -*- coding:utf-8 -*-
from typing import Any, Dict, Union, List
import re
import json


# Number of characters after each "answer" token that are searched for a choice.
_ANSWER_WINDOW = 300

# Locates each occurrence of the word "answer", optionally wrapped in
# markdown / quote characters such as ** or ".
_ANSWER_TOKEN_PATTERN = r'[*"_\']*\banswer\b[*"_\']*'

# Explicit "answer ... uncertainty" statements (backticks/quotes allowed).
_ANSWER_UNCERTAINTY_PATTERNS = (
    r'[*"_\']*\banswer\b[*"_\']*\s*[:：]\s*["`\']?uncertainty["`\']?\b',
    r'\banswer\b[^\n]*?\\boxed\{\s*\\text\{\s*uncertainty\s*\}\s*\}',
    r'\banswer\b[^\n]*?\\boxed\{\s*uncertainty\s*\}',
)

# Explicit "answer ... <letter>" statements, tried in this order; the first
# pattern that matches inside a window decides that window's choice.
_ANSWER_CHOICE_PATTERNS = (
    # answer: C / "answer": "C" / **answer**: `C`
    r'[*"_\']*\banswer\b[*"_\']*\s*[:：]\s*["`\']?([A-D])["`\']?\b',
    # **Answer:** C. / Answer: C)  etc.
    r'\banswer\b[^\n]*?[:：]\s*["`\']?([A-D])["`\']?\b',
    # answer \boxed{\text{C}}
    r'\banswer\b[^\n]*?\\boxed\{\s*\\text\{\s*([A-D])\s*\}\s*\}',
    # answer \boxed{C}
    r'\banswer\b[^\n]*?\\boxed\{\s*([A-D])\s*\}',
    # Answer C.  (no colon). The letter itself is matched case-sensitively
    # so the English article "a" ("To answer this we need a ...") is not
    # mistaken for choice A.
    r'\banswer\b[^\n]*?\b((?-i:[A-D]))\b',
)

# Boxed answers anywhere in the text, used when no explicit "answer" line
# yields a choice. Patterns without a capture group denote "uncertainty".
_BOXED_FALLBACK_PATTERNS = (
    r'\\boxed\{\s*\\text\{\s*uncertainty\s*\}\s*\}',
    r'\\boxed\{\s*uncertainty\s*\}',
    r'\\boxed\{\s*\\text\{\s*([A-D])\s*\}\s*\}',
    r'\\boxed\{\s*([A-D])\s*[\.\)]\s*[^}]*\}',
    r'\\boxed\{\s*([A-D])\s*\}',
)


def _normalize_latex(text):
    """Unwrap simple \\text{...} groups and shorten \\boxed{\\text{X ...}} to \\boxed{\\text{X}}."""
    # \text{A. foo} -> "A. foo" (only groups that contain no braces).
    cleaned = re.sub(r'\\text\{\s*([^{}]+?)\s*\}', r'\1', text)
    # Any \boxed{\text{...}} that survived the step above (its \text group
    # contained braces) and starts with a choice letter is reduced to
    # \boxed{\text{X}} so the boxed fallback can still recognise it.
    return re.sub(
        r'\\boxed\{\s*\\text\{\s*([A-D])\s*[\.\)]?[^}]*\}\s*\}',
        lambda m: r'\boxed{\text{' + m.group(1).upper() + r'}}',
        cleaned,
        flags=re.IGNORECASE,
    )


def _answer_section_candidates(text):
    """Return one choice per "answer" occurrence that states one, in text order."""
    candidates = []
    for token in re.finditer(_ANSWER_TOKEN_PATTERN, text, flags=re.IGNORECASE):
        segment = text[token.start():token.end() + _ANSWER_WINDOW]

        # An explicit uncertainty statement takes precedence within a window.
        if any(re.search(pat, segment, flags=re.IGNORECASE)
               for pat in _ANSWER_UNCERTAINTY_PATTERNS):
            candidates.append('uncertainty')
            continue

        for pat in _ANSWER_CHOICE_PATTERNS:
            m_choice = re.search(pat, segment, flags=re.IGNORECASE)
            if m_choice:
                candidates.append(m_choice.group(1).upper())
                break
    return candidates


def _last_boxed_choice(text):
    """Return the positionally last boxed choice/uncertainty in *text*, or None."""
    hits = []
    for pat in _BOXED_FALLBACK_PATTERNS:
        for m_box in re.finditer(pat, text, flags=re.IGNORECASE):
            if m_box.lastindex is None:
                value = 'uncertainty'
            else:
                value = m_box.group(m_box.lastindex).upper()
            hits.append((m_box.start(), value))
    if not hits:
        return None
    # Compare by position in the text, not by pattern order.
    return max(hits, key=lambda hit: hit[0])[1]


def extract_answer_mc(model_response: str):
    """
    Parse a multiple-choice answer out of a model response.

    Lookup order:
      0. The whole (stripped) response is just a letter, e.g. "A", "b.", "C)".
      1. Windows starting at each occurrence of "answer"; the choice stated in
         the last such window wins.
      2. Fallback: the positionally last \\boxed{...} letter or uncertainty.

    Args:
        model_response: raw model output; non-str input yields None.

    Returns:
        'A'/'B'/'C'/'D', 'uncertainty', or None if nothing could be parsed.
    """
    if not isinstance(model_response, str):
        return None

    text = model_response.strip()

    m_simple = re.fullmatch(r'([A-D])[\.\)]?\s*', text, flags=re.IGNORECASE)
    if m_simple:
        return m_simple.group(1).upper()

    text_clean = _normalize_latex(text)

    candidates = _answer_section_candidates(text_clean)
    if candidates:
        return candidates[-1]

    return _last_boxed_choice(text_clean)



def think_boxed_reward_fn_mc(model_response, gt, fast=False):
    """
    Grade a (possibly <think>-prefixed) multiple-choice response.

    Only the text after the last "</think>" is parsed, so intermediate
    "answer: X" lines inside the reasoning cannot override the final answer.

    Args:
        model_response: raw model output (non-str yields an unformatted result).
        gt: ground-truth choice letter; compared case-insensitively after
            stripping whitespace when it is a str.
        fast: accepted for interface compatibility; not used.

    Returns:
        (info, reward) where info = {"formatted": bool, "ans": str | None} and
        reward is 1.0 (correct), 0.0 (wrong), None (model answered
        'uncertainty') or False (nothing could be parsed).
    """
    answer_text = model_response
    if isinstance(model_response, str) and "</think>" in model_response:
        answer_text = model_response.split("</think>")[-1]

    model_answer = extract_answer_mc(answer_text)
    if model_answer is None:
        return {"formatted": False, 'ans': None}, False  # Cannot even parse anything.

    if model_answer == 'uncertainty':
        return {"formatted": True, 'ans': 'uncertainty'}, None

    if len(model_answer) == 1:
        expected = gt.strip().upper() if isinstance(gt, str) else gt
        # Formatted but wrong answers get 0.0: no format reward, to avoid hacking.
        reward = 1.0 if model_answer == expected else 0.0
        return {"formatted": True, 'ans': model_answer}, reward

    # Defensive: extract_answer_mc only returns single letters or 'uncertainty'.
    return {"formatted": False, 'ans': None}, 0.0



def _evaluate_single(solution: str, ground_truth: str) -> Dict[str, object]:
    """
    Score one model response against its ground-truth choice.

    Returns a dict with keys "score", "acc" and "isuc":
    - response not formatted / unparseable: score -1, acc False, isuc False
    - response declares uncertainty (acc is None): score 0.8, acc None, isuc True
    - otherwise: score 1 if correct else 0, acc = the grader's reward, isuc False
    """
    fmt, reward = think_boxed_reward_fn_mc(solution, ground_truth)

    if fmt.get("formatted", False):
        if reward is None:
            # Placeholder score; re-assigned later during advantage computation.
            return {"score": 0.8, "acc": None, "isuc": True}
        return {"score": 1 if reward else 0, "acc": reward, "isuc": False}

    return {"score": -1, "acc": False, "isuc": False}


def compute_score(
    data_source: Union[str, List[str]],
    solution_str: Union[str, List[str]],
    ground_truth: Union[str, List[str]],
    extra_info: Union[Dict, List[Dict]],
) -> Union[Dict[str, object], List[Dict[str, object]]]:
    """
    Score a single example, or a batch when `data_source` is a list.

    In batch mode, `solution_str` and `ground_truth` are indexed in step with
    `data_source`: one result per `data_source` entry, in order.

    Note:
    - `data_source` and `extra_info` are kept for API compatibility; their
      contents are not used (only the length/type of `data_source` matters).
    """
    if not isinstance(data_source, list):
        return _evaluate_single(solution_str, ground_truth)

    return [
        _evaluate_single(solution_str[idx], ground_truth[idx])
        for idx, _source in enumerate(data_source)
    ]
