"""Reward functions for GRPO training."""

import math
import re
from typing import List, Union
import numpy as np

from math_verify import parse, verify

# Reward values returned by accuracy_reward for a correct / incorrect answer.
ACCURACY_REWARD_CORRECT=1.0
ACCURACY_REWARD_WRONG=0.0
# Disabled format-reward constants; nothing in the visible code references them.
# DEEPSEEKFORMAT_REWARD_CORRECT=0.25
# DEEPSEEKFORMAT_REWARD_WRONG=0.0


def accuracy_reward(solution, ground_truth, **kwargs):
    """Score a completion by whether its final answer matches the ground truth.

    Args:
        solution: Completion text. When it contains ``</think>``, only the
            text after the last ``</think>`` tag is scored.
        ground_truth: The acceptable answer(s). A bare ``str``/``int``/``float``
            is treated as a single answer; otherwise an iterable of answers.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``ACCURACY_REWARD_CORRECT`` when the extracted answer equals one of
        the ground truths (exact string match, or ``math_verify.verify``
        succeeding); ``ACCURACY_REWARD_WRONG`` otherwise, including when the
        completion cannot be parsed.
    """
    # A non-string completion has no answer to extract.
    if not isinstance(solution, str):
        return ACCURACY_REWARD_WRONG

    # Wrap a scalar ground truth so that the membership test below is an
    # equality check rather than a substring search ("1" in "11"), and the
    # verification loop does not iterate over individual characters.
    if isinstance(ground_truth, (str, float, int)):
        ground_truth = [ground_truth]

    # Always score the final answer: keep only what follows the LAST closing
    # think tag (the string is left untouched when the tag is absent).
    solution = solution.rsplit("</think>", 1)[-1]

    # 0 in case parsing cannot be completed
    try:
        math_verify_parsed = parse(solution, parsing_timeout=5)
    except Exception:
        return ACCURACY_REWARD_WRONG

    # 0 if parsing is problematic: the code below needs at least two items
    # and uses item 1 as the extracted answer string.
    if len(math_verify_parsed) < 2:
        return ACCURACY_REWARD_WRONG

    # Cheap exact string match first
    if math_verify_parsed[1] in ground_truth:
        return ACCURACY_REWARD_CORRECT

    # Fall back to semantic verification against each candidate answer
    for gt in ground_truth:
        try:
            gold = parse(f"\\boxed{{{gt}}}", parsing_timeout=5)
            if verify(gold, math_verify_parsed, timeout_seconds=5):
                return ACCURACY_REWARD_CORRECT
        except Exception:
            # A candidate that fails to parse/verify is simply not a match.
            continue

    # Very unlikely to be correct after the above matches
    return ACCURACY_REWARD_WRONG




def rlcot_reward_fn(prompt_str: str, solution_str: str, ground_truth: Union[str, List[str]], enable_llm=False):
    """Compute the reward for one completion together with a per-component breakdown.

    Args:
        prompt_str: The prompt text. Not used by this function.
        solution_str: The completion to score.
        ground_truth: A single answer (``str``/``int``/``float``), a list of
            answers, or a list whose first element is a NumPy array of answers.
        enable_llm: Not used by this function.

    Returns:
        A ``(reward, reward_details)`` tuple, where ``reward`` is the value
        returned by ``accuracy_reward`` and ``reward_details`` is
        ``{"accuracy": reward}``.
    """
    if isinstance(ground_truth, (str, float, int)):
        # A scalar answer becomes a one-element list of candidates.
        ground_truth = [ground_truth]
    elif (
        isinstance(ground_truth, list)
        and ground_truth  # guard: an empty list has no first element to inspect
        and isinstance(ground_truth[0], np.ndarray)
    ):
        # Unwrap a list holding a NumPy array into a plain Python list.
        ground_truth = ground_truth[0].tolist()

    reward = accuracy_reward(solution=solution_str, ground_truth=ground_truth)
    reward_details = {"accuracy": reward}

    return reward, reward_details

