#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Dict, List, Union
from recipe.ucpo.reward_fn.math_grader import think_boxed_reward_fn


def _evaluate_single(solution: str, ground_truth: str) -> Dict[str, object]:
    """
    Grade one model output against its reference answer.

    ``think_boxed_reward_fn`` returns a format report and a correctness
    verdict. They are mapped to a result dict as follows:

    - Output not well-formed (``format_info["formatted"]`` missing or falsy):
      score -1, acc False, isuc False.
    - Well-formed, verdict ``None`` (grader undecided):
      score 0.8, acc None, isuc True.
    - Well-formed, verdict truthy -> score 1; verdict falsy -> score 0.
      In both cases ``acc`` is the verdict unchanged and ``isuc`` is False.
    """
    format_info, verdict = think_boxed_reward_fn(solution, ground_truth)

    well_formed = format_info.get("formatted", False)
    if not well_formed:
        return {"score": -1, "acc": False, "isuc": False}

    if verdict is None:
        # Placeholder reward for undecided cases; per the original note it is
        # re-assigned later, during advantage computation.
        return {"score": 0.8, "acc": None, "isuc": True}

    reward = 1 if verdict else 0
    return {"score": reward, "acc": verdict, "isuc": False}


def compute_score(
    data_source: Union[str, List[str]],
    solution_str: Union[str, List[str]],
    ground_truth: Union[str, List[str]],
    extra_info: Union[Dict, List[Dict]],
) -> Union[Dict[str, object], List[Dict[str, object]]]:
    """
    Compute scores for either a single example or a batch.

    Batch mode is selected when ``data_source`` is a list. ``solution_str``
    and ``ground_truth`` must then have the same length as ``data_source``,
    and they are scored pairwise, in order. Otherwise the single
    (``solution_str``, ``ground_truth``) pair is scored.

    Note:
    - ``data_source`` (beyond its type and length) and ``extra_info`` are
      kept for API compatibility but are not used.

    Returns:
        A result dict with keys "score", "acc" and "isuc" for a single
        example, or a list of such dicts (one per example) in batch mode.

    Raises:
        ValueError: In batch mode, if ``solution_str`` or ``ground_truth``
            does not have the same length as ``data_source``.
    """
    if isinstance(data_source, list):
        expected = len(data_source)
        # Validate up front: previously a shorter list raised an opaque
        # IndexError mid-batch and a longer one was silently truncated.
        if len(solution_str) != expected or len(ground_truth) != expected:
            raise ValueError(
                "compute_score batch inputs must have equal lengths: "
                f"data_source={expected}, solution_str={len(solution_str)}, "
                f"ground_truth={len(ground_truth)}"
            )
        return [
            _evaluate_single(solution, truth)
            for solution, truth in zip(solution_str, ground_truth)
        ]

    return _evaluate_single(solution_str, ground_truth)
