import contextlib
from recipe.multilingual_grpo.src.dapo_multi_naive import compute_score as dapo_compute_score

# Math-Verify is imported under a guard: when the package is missing, only an
# installation hint is printed and the names below stay undefined, so any later
# call that uses them will fail at call time rather than at import time.
try:
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")


def verify_compute_score(model_output: str, ground_truth: str) -> float:
    """Score ``model_output`` against ``ground_truth`` using Math-Verify.

    A ``math_metric`` verifier is built with LaTeX extraction for the gold
    answer and expression + LaTeX extraction for the prediction. The ground
    truth is wrapped in a LaTeX ``boxed{...}`` command before verification.

    Args:
        model_output: Raw text produced by the model.
        ground_truth: Reference answer, without the surrounding boxed command.

    Returns:
        The score reported by the verifier, or ``0.0`` when the verifier
        raises any ``Exception`` (verification is deliberately best-effort).
    """
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    # Default used when verification fails for any reason.
    ret_score = 0.0

    # Wrap the ground truth in \boxed{} format for verification
    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    # Best-effort: a failure inside the verifier must not abort scoring, so the
    # default score is kept instead of propagating the exception.
    with contextlib.suppress(Exception):
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])

    return ret_score




def reward_func(data_source, solution_str, ground_truth, extra_info=None):
    """Score ``solution_str`` with the scorer selected by ``data_source``.

    Dispatch rules (checked in order):
      * ``'openai/gsm8k'`` is explicitly unsupported.
      * Sources whose lower-cased name starts with ``math``, ``aime`` or
        ``valid/gsm8k`` are scored with ``verify_compute_score``.
      * Sources starting with ``dapo_reward`` are scored with
        ``dapo_compute_score`` (which also receives ``extra_info``).
      * Sources starting with ``valid/dapo_reward`` are scored with
        ``verl.utils.reward_score.dapo_math.compute_score``.

    Args:
        data_source: Name of the dataset the sample comes from.
        solution_str: Model output to be scored.
        ground_truth: Reference answer.
        extra_info: Optional extra data forwarded to ``dapo_compute_score``.

    Returns:
        The scorer's result unchanged when it is a ``dict``; otherwise a
        ``float`` (a numeric result converted directly, or the first element
        of an indexable result).

    Raises:
        NotImplementedError: If ``data_source`` matches no supported scorer.
    """
    if data_source == 'openai/gsm8k':
        # The exception was previously constructed but never raised, which
        # led to an UnboundLocalError on ``res`` further down.
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    source_lower = data_source.lower()
    if source_lower.startswith(("math", "aime", "valid/gsm8k")):
        res = verify_compute_score(solution_str, ground_truth)
    elif data_source.startswith('dapo_reward'):
        res = dapo_compute_score(solution_str, ground_truth, extra_info)
    elif data_source.startswith("valid/dapo_reward"):
        # Imported lazily so the dependency is only needed for this branch.
        from verl.utils.reward_score import dapo_math
        res = dapo_math.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    # Normalise the scorer output: dicts pass through, scalars become float,
    # anything else is treated as a sequence whose first element is the score.
    if isinstance(res, dict):
        return res
    if isinstance(res, (int, float, bool)):
        return float(res)
    return float(res[0])





def reward_func_naive(data_source, solution_str, ground_truth, extra_info=None):
    """Score ``solution_str`` using the plain DAPO scorer for DAPO sources.

    Dispatch rules (checked in order):
      * ``'openai/gsm8k'`` is explicitly unsupported.
      * Sources whose lower-cased name starts with ``math``, ``aime``,
        ``valid/gsm8k`` or ``valid/sr`` are scored with
        ``verify_compute_score``.
      * Sources starting with ``dapo_reward`` or ``valid/dapo_reward`` are
        scored with ``verl.utils.reward_score.dapo_math.compute_score``.

    Args:
        data_source: Name of the dataset the sample comes from.
        solution_str: Model output to be scored.
        ground_truth: Reference answer.
        extra_info: Accepted for signature compatibility; not used here.

    Returns:
        The scorer's result unchanged when it is a ``dict``; otherwise a
        ``float`` (a numeric result converted directly, or the first element
        of an indexable result).

    Raises:
        NotImplementedError: If ``data_source`` matches no supported scorer.
    """
    if data_source == 'openai/gsm8k':
        # The exception was previously constructed but never raised, which
        # led to an UnboundLocalError on ``res`` further down.
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    source_lower = data_source.lower()
    if source_lower.startswith(("math", "aime", "valid/gsm8k", "valid/sr")):
        res = verify_compute_score(solution_str, ground_truth)
    elif data_source.startswith(('dapo_reward', "valid/dapo_reward")):
        # Both training and validation DAPO sources use the same scorer here.
        # Imported lazily so the dependency is only needed for this branch.
        from verl.utils.reward_score import dapo_math
        res = dapo_math.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    # Normalise the scorer output: dicts pass through, scalars become float,
    # anything else is treated as a sequence whose first element is the score.
    if isinstance(res, dict):
        return res
    if isinstance(res, (int, float, bool)):
        return float(res)
    return float(res[0])


def reward_func_math_ratio(data_source, solution_str, ground_truth, extra_info=None):
    """Score ``solution_str`` using the ratio DAPO scorer for training sources.

    Dispatch rules (checked in order):
      * ``'openai/gsm8k'`` is explicitly unsupported.
      * Sources whose lower-cased name starts with ``math``, ``aime``,
        ``valid/gsm8k`` or ``valid/sr`` are scored with
        ``verify_compute_score``.
      * Sources starting with ``dapo_reward`` are scored with
        ``verl.utils.reward_score.dapo_math_ratio.compute_score`` (which also
        receives ``extra_info``).
      * Sources starting with ``valid/dapo_reward`` are scored with
        ``verl.utils.reward_score.dapo_math.compute_score``.

    Args:
        data_source: Name of the dataset the sample comes from.
        solution_str: Model output to be scored.
        ground_truth: Reference answer.
        extra_info: Optional extra data forwarded to the ratio scorer.

    Returns:
        The scorer's result unchanged when it is a ``dict``; otherwise a
        ``float`` (a numeric result converted directly, or the first element
        of an indexable result).

    Raises:
        NotImplementedError: If ``data_source`` matches no supported scorer.
    """
    if data_source == 'openai/gsm8k':
        # The exception was previously constructed but never raised, which
        # led to an UnboundLocalError on ``res`` further down.
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    source_lower = data_source.lower()
    if source_lower.startswith(("math", "aime", "valid/gsm8k", "valid/sr")):
        res = verify_compute_score(solution_str, ground_truth)
    elif data_source.startswith('dapo_reward'):
        # Imported lazily so the dependency is only needed for this branch.
        from verl.utils.reward_score import dapo_math_ratio
        res = dapo_math_ratio.compute_score(solution_str, ground_truth, extra_info)
    elif data_source.startswith("valid/dapo_reward"):
        from verl.utils.reward_score import dapo_math
        res = dapo_math.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError(f"Data source {data_source} is not supported.")

    # Normalise the scorer output: dicts pass through, scalars become float,
    # anything else is treated as a sequence whose first element is the score.
    if isinstance(res, dict):
        return res
    if isinstance(res, (int, float, bool)):
        return float(res)
    return float(res[0])
