from symeval import EvaluatorMathBatch
from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed

# Module-level evaluator instance, constructed once at import time and
# reused by eval_math_symeval for every comparison.
evaluator = EvaluatorMathBatch()



def compute_score(solution_str, ground_truth) -> float:
    """Score a model solution against the reference answer.

    The content of the last ``\\boxed{...}`` expression found in
    ``solution_str`` is extracted and compared with ``ground_truth`` via
    ``eval_math_symeval``.

    Args:
        solution_str: Full text of the solution to be scored.
        ground_truth: Reference answer to compare against.

    Returns:
        1.0 when the extracted answer is judged equal to the reference;
        0.0 when no boxed expression is found, when the answers differ,
        or when any exception is raised while extracting or comparing
        (the exception is printed, not propagated).
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is None:
            # Nothing to grade without a boxed answer.
            return 0.0
        prediction = remove_boxed(boxed)
        return 1.0 if eval_math_symeval(ground_truth, prediction) else 0.0
    except Exception as e:
        # Best effort: any failure is reported and scored as incorrect.
        print(e)
    return 0.0



def eval_math_symeval(gt_answer, pred_answer):
    """Compare one predicted answer with one reference answer.

    Both answers are wrapped in single-element lists and passed to the
    module-level ``evaluator``'s ``batch_eq``; the first (and only) entry
    of the result is returned.

    Args:
        gt_answer: Reference answer.
        pred_answer: Predicted answer to check.

    Returns:
        The first element of ``batch_eq``'s result — presumably truthy
        when the answers are considered equal (confirm against symeval).
    """
    results = evaluator.batch_eq(
        ref_answers=[gt_answer],
        pred_answers=[pred_answer],
    )
    return results[0]

import unittest

class TestComputeScore(unittest.TestCase):
    """Unit tests for compute_score on boxed, unboxed and empty inputs."""

    def _assert_scores(self, expected, cases):
        """Run each (solution, ground_truth, message) case in order.

        Stops at the first failing assertion, like sequential asserts.
        """
        for solution, truth, message in cases:
            self.assertEqual(compute_score(solution, truth), expected, message)

    def test_identical_options(self):
        # Matching boxed answer and reference should score 1.0.
        self._assert_scores(
            1.0,
            [
                # Test case 1: Both options are "A"
                ("\\boxed{A}", "A", "Should return 1.0 for identical options A and A"),
                # Test case 2: Both options are "B"
                ("\\boxed{B}", "B", "Should return 1.0 for identical options B and B"),
            ],
        )

    def test_different_options(self):
        # Mismatching boxed answer and reference should score 0.0.
        self._assert_scores(
            0.0,
            [
                # Test case 3: Options are "A" and "B"
                ("\\boxed{A}", "B", "Should return 0.0 for different options A and B"),
                # Test case 4: Options are "1" and "2"
                ("\\boxed{1}", "2", "Should return 0.0 for different options 1 and 2"),
            ],
        )

    def test_empty_input(self):
        # Test case 5: Empty solution and empty reference
        result = compute_score("", "")
        self.assertEqual(result, 0.0, "Should return 0.0 for empty inputs")

    def test_invalid_input(self):
        # Test case 6: Solution text without any boxed expression
        result = compute_score("A", "A")
        self.assertEqual(result, 0.0, "Should return 0.0 for invalid input format")

# Run the unit tests above when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()