from fractions import Fraction

from hallucinations.features.math_eval import (
    clean_latex,
    extract_answer,
    extract_boxed_answer,
    extract_final_answer,
    extract_final_answer_numeric,
    is_equiv,
    is_fraction_equiv,
    is_numeric_equiv,
    normalize_answer,
    parse_fraction,
    parse_number,
)


class TestExtractBoxedAnswer:
    """Checks for ``extract_boxed_answer`` on boxed, fbox and missing input."""

    def test_simple_boxed(self) -> None:
        """A single boxed value embedded in a sentence is returned."""
        extracted = extract_boxed_answer(r"The answer is \boxed{42}")
        assert extracted == "42"

    def test_boxed_with_spaces(self) -> None:
        """Spaces before the brace and around the content are dropped."""
        extracted = extract_boxed_answer(r"\boxed { 42 }")
        assert extracted == "42"

    def test_boxed_in_dollar(self) -> None:
        """A boxed value wrapped in dollar delimiters is still found."""
        extracted = extract_boxed_answer(r"$\boxed{42}$")
        assert extracted == "42"

    def test_nested_braces(self) -> None:
        """Braces nested inside the boxed argument are kept intact."""
        expected = r"\frac{1}{2}"
        extracted = extract_boxed_answer(r"\boxed{\frac{1}{2}}")
        assert extracted == expected

    def test_multiple_boxed_returns_last(self) -> None:
        """With two boxed values, the last one is the result."""
        extracted = extract_boxed_answer(r"\boxed{1} and \boxed{2}")
        assert extracted == "2"

    def test_fbox(self) -> None:
        """The fbox command is accepted in place of boxed."""
        extracted = extract_boxed_answer(r"\fbox{42}")
        assert extracted == "42"

    def test_no_boxed_returns_none(self) -> None:
        """Text without any boxed command gives ``None``."""
        extracted = extract_boxed_answer("no boxed answer here")
        assert extracted is None

    def test_none_input(self) -> None:
        """Passing ``None`` instead of a string gives ``None``."""
        extracted = extract_boxed_answer(None)  # type: ignore[arg-type]
        assert extracted is None


class TestExtractFinalAnswer:
    """Checks for ``extract_final_answer`` on "final answer is" phrasing."""

    def test_simple(self) -> None:
        """The value after the phrase is returned without the trailing period."""
        found = extract_final_answer("The final answer is 42.")
        assert found == "42"

    def test_with_colon(self) -> None:
        """A colon after "is" does not end up in the extracted value."""
        found = extract_final_answer("The final answer is: 42.")
        assert found == "42"

    def test_case_insensitive(self) -> None:
        """An all-uppercase phrase is matched as well."""
        found = extract_final_answer("THE FINAL ANSWER IS 42")
        assert found == "42"

    def test_without_the(self) -> None:
        """The leading "The" is optional."""
        found = extract_final_answer("final answer is 42")
        assert found == "42"

    def test_no_match(self) -> None:
        """A sentence lacking the word "final" yields ``None``."""
        found = extract_final_answer("the answer is 42")
        assert found is None


class TestExtractFinalAnswerNumeric:
    """Checks for ``extract_final_answer_numeric`` on formatted numbers."""

    def test_simple_number(self) -> None:
        """A plain integer is returned as its digit string."""
        sentence = "The final answer is 42."
        assert extract_final_answer_numeric(sentence) == "42"

    def test_with_commas(self) -> None:
        """Thousands separators are removed from the result."""
        sentence = "The final answer is 1,234."
        assert extract_final_answer_numeric(sentence) == "1234"

    def test_with_dollar(self) -> None:
        """A leading dollar sign is removed from the result."""
        sentence = "The final answer is $50."
        assert extract_final_answer_numeric(sentence) == "50"

    def test_negative(self) -> None:
        """The minus sign of a negative number is kept."""
        sentence = "The final answer is -42."
        assert extract_final_answer_numeric(sentence) == "-42"

    def test_decimal(self) -> None:
        """A decimal keeps its fractional part; the sentence period is dropped."""
        sentence = "The final answer is 3.14."
        assert extract_final_answer_numeric(sentence) == "3.14"


class TestExtractAnswer:
    """Checks how ``extract_answer`` picks between boxed and final-answer text."""

    def test_prefer_boxed_true(self) -> None:
        """With both forms present and ``prefer_boxed=True``, boxed wins."""
        response = r"The final answer is 1. Therefore \boxed{2}"
        chosen = extract_answer(response, prefer_boxed=True)
        assert chosen == "2"

    def test_prefer_boxed_false(self) -> None:
        """With both forms present and ``prefer_boxed=False``, the phrase wins."""
        response = r"The final answer is 1. Therefore \boxed{2}"
        chosen = extract_answer(response, prefer_boxed=False)
        assert chosen == "1"

    def test_fallback_to_final_answer(self) -> None:
        """Preferring boxed still returns the phrase value when no box exists."""
        response = "The final answer is 42."
        chosen = extract_answer(response, prefer_boxed=True)
        assert chosen == "42"

    def test_fallback_to_boxed(self) -> None:
        """Not preferring boxed still returns the box when no phrase exists."""
        response = r"\boxed{42}"
        chosen = extract_answer(response, prefer_boxed=False)
        assert chosen == "42"


class TestCleanLatex:
    """Checks that ``clean_latex`` strips formatting-only LaTeX commands."""

    def test_text_command(self) -> None:
        """The text wrapper is removed, leaving its content."""
        cleaned = clean_latex(r"\text{hello}")
        assert cleaned == "hello"

    def test_textbf_command(self) -> None:
        """The textbf wrapper is removed, leaving its content."""
        cleaned = clean_latex(r"\textbf{bold}")
        assert cleaned == "bold"

    def test_left_right(self) -> None:
        """The left/right sizing commands are removed; delimiters remain."""
        cleaned = clean_latex(r"\left(\right)")
        assert cleaned == "()"

    def test_spacing(self) -> None:
        """Thin and thick space commands are removed."""
        cleaned = clean_latex(r"a\,b\;c")
        assert cleaned == "abc"


class TestIsEquiv:
    """Checks for the top-level ``is_equiv`` comparison."""

    def test_exact_match(self) -> None:
        """Identical strings are equivalent."""
        verdict = is_equiv("42", "42")
        assert verdict is True

    def test_numeric_equiv(self) -> None:
        """A float-formatted value matches the same integer."""
        verdict = is_equiv("42.0", "42")
        assert verdict is True

    def test_fraction_equiv(self) -> None:
        """A slash fraction matches its decimal value."""
        verdict = is_equiv("1/2", "0.5")
        assert verdict is True

    def test_normalized_match(self) -> None:
        """Surrounding whitespace on the prediction is ignored."""
        verdict = is_equiv("  42  ", "42")
        assert verdict is True

    def test_not_equiv(self) -> None:
        """Different integers are not equivalent."""
        verdict = is_equiv("42", "43")
        assert verdict is False

    def test_none_pred(self) -> None:
        """A ``None`` prediction is never equivalent."""
        verdict = is_equiv(None, "42")
        assert verdict is False

    def test_none_gold(self) -> None:
        """A ``None`` gold answer is never equivalent."""
        verdict = is_equiv("42", None)
        assert verdict is False


class TestIsNumericEquiv:
    """Checks for ``is_numeric_equiv`` across number formats."""

    def test_integers(self) -> None:
        """Equal integer strings compare equal."""
        same = is_numeric_equiv("42", "42")
        assert same is True

    def test_floats(self) -> None:
        """Equal float strings compare equal."""
        same = is_numeric_equiv("3.14159", "3.14159")
        assert same is True

    def test_int_float(self) -> None:
        """An integer string equals the same value written as a float."""
        same = is_numeric_equiv("42", "42.0")
        assert same is True

    def test_with_commas(self) -> None:
        """A thousands separator does not affect the comparison."""
        same = is_numeric_equiv("1,000", "1000")
        assert same is True

    def test_with_dollar(self) -> None:
        """A leading dollar sign does not affect the comparison."""
        same = is_numeric_equiv("$100", "100")
        assert same is True

    def test_percent(self) -> None:
        """A percentage string equals its decimal form."""
        same = is_numeric_equiv("50%", "0.5")
        assert same is True

    def test_not_equiv(self) -> None:
        """Different numbers compare unequal."""
        same = is_numeric_equiv("42", "43")
        assert same is False

    def test_non_numeric(self) -> None:
        """A non-numeric string never equals a number."""
        same = is_numeric_equiv("abc", "42")
        assert same is False


class TestIsFractionEquiv:
    """Checks for ``is_fraction_equiv`` on slash, LaTeX and decimal forms."""

    def test_same_fraction(self) -> None:
        """A fraction is equivalent to itself."""
        matches = is_fraction_equiv("1/2", "1/2")
        assert matches is True

    def test_equivalent_fractions(self) -> None:
        """An unreduced fraction matches its reduced form."""
        matches = is_fraction_equiv("2/4", "1/2")
        assert matches is True

    def test_fraction_and_decimal(self) -> None:
        """A slash fraction matches the equal decimal."""
        matches = is_fraction_equiv("1/2", "0.5")
        assert matches is True

    def test_latex_fraction(self) -> None:
        """A LaTeX frac expression matches the equal decimal."""
        matches = is_fraction_equiv(r"\frac{1}{2}", "0.5")
        assert matches is True

    def test_not_equiv(self) -> None:
        """Fractions with different values do not match."""
        matches = is_fraction_equiv("1/2", "1/3")
        assert matches is False


class TestParseNumber:
    """Checks for ``parse_number`` on plain and decorated numeric strings."""

    def test_integer(self) -> None:
        """An integer string parses to the matching float."""
        value = parse_number("42")
        assert value == 42.0

    def test_float(self) -> None:
        """A decimal string parses to the matching float."""
        value = parse_number("3.14")
        assert value == 3.14

    def test_negative(self) -> None:
        """A leading minus sign produces a negative value."""
        value = parse_number("-42")
        assert value == -42.0

    def test_with_commas(self) -> None:
        """Thousands separators are ignored while parsing."""
        value = parse_number("1,234,567")
        assert value == 1234567.0

    def test_with_dollar(self) -> None:
        """A leading dollar sign is ignored while parsing."""
        value = parse_number("$100")
        assert value == 100.0

    def test_percent(self) -> None:
        """A percent suffix scales the value to a fraction of one."""
        value = parse_number("50%")
        assert value == 0.5

    def test_invalid(self) -> None:
        """Non-numeric text parses to ``None``."""
        value = parse_number("abc")
        assert value is None


class TestParseFraction:
    """Checks for ``parse_fraction`` on slash, LaTeX, decimal and bad input.

    ``Fraction`` comes from the module-level ``from fractions import Fraction``
    import rather than being re-imported inside each test method.
    """

    def test_simple_fraction(self) -> None:
        """A slash fraction parses to the matching ``Fraction``."""
        assert parse_fraction("1/2") == Fraction(1, 2)

    def test_negative_fraction(self) -> None:
        """A leading minus sign yields a negative ``Fraction``."""
        assert parse_fraction("-1/2") == Fraction(-1, 2)

    def test_latex_frac(self) -> None:
        """A LaTeX frac expression parses to the matching ``Fraction``."""
        assert parse_fraction(r"\frac{3}{4}") == Fraction(3, 4)

    def test_decimal_to_fraction(self) -> None:
        """A decimal string parses to the equal ``Fraction``."""
        assert parse_fraction("0.5") == Fraction(1, 2)

    def test_invalid(self) -> None:
        """Non-numeric text parses to ``None``."""
        assert parse_fraction("abc") is None

    def test_zero_denominator(self) -> None:
        """A zero denominator parses to ``None``."""
        assert parse_fraction("1/0") is None


class TestNormalizeAnswer:
    """Checks for the string clean-up performed by ``normalize_answer``."""

    def test_lowercase(self) -> None:
        """Uppercase letters are folded to lowercase."""
        normalized = normalize_answer("ABC")
        assert normalized == "abc"

    def test_remove_whitespace(self) -> None:
        """Spaces between characters are removed."""
        normalized = normalize_answer("a b c")
        assert normalized == "abc"

    def test_remove_punctuation(self) -> None:
        """A trailing period is removed."""
        normalized = normalize_answer("42.")
        assert normalized == "42"

    def test_remove_dollar(self) -> None:
        """Surrounding dollar signs are removed."""
        normalized = normalize_answer("$42$")
        assert normalized == "42"


class TestGSM8KCompatibility:
    """Tests demonstrating GSM8K-style answer extraction and comparison."""

    def test_gsm8k_final_answer_pattern(self) -> None:
        """The final-answer value is taken, not the earlier arithmetic."""
        response = "So we have 5 + 3 = 8. The final answer is 8."
        extracted = extract_final_answer_numeric(response)
        assert extracted == "8"

    def test_gsm8k_with_dollar_sign(self) -> None:
        """A dollar amount in the final answer is returned without the sign."""
        response = "The total cost is $150. The final answer is $150."
        extracted = extract_final_answer_numeric(response)
        assert extracted == "150"

    def test_gsm8k_with_commas(self) -> None:
        """A comma-grouped final answer is returned without the commas."""
        response = "The population is 1,234,567. The final answer is 1,234,567."
        extracted = extract_final_answer_numeric(response)
        assert extracted == "1234567"

    def test_gsm8k_equivalence(self) -> None:
        """Identical prediction and gold strings are equivalent."""
        # In GSM8K the gold answer is the text that follows "####".
        reference = "72"
        candidate = "72"
        assert is_equiv(candidate, reference) is True

    def test_gsm8k_equiv_with_formatting(self) -> None:
        """A float-formatted prediction matches an integer gold answer."""
        reference = "72"
        candidate = "72.0"
        assert is_equiv(candidate, reference) is True

    def test_gsm8k_extract_and_compare(self) -> None:
        """Extraction followed by comparison succeeds on a multi-step response."""
        response = "Step 1: 5 * 10 = 50. Step 2: 50 + 22 = 72. The final answer is 72."
        reference = "72"
        candidate = extract_answer(response, prefer_boxed=False)
        assert is_equiv(candidate, reference) is True


class TestOMEGA500Compatibility:
    """Tests demonstrating OMEGA-500-style answer extraction and comparison."""

    def test_omega_boxed_integer(self) -> None:
        """A boxed positive integer is extracted."""
        response = r"Therefore, the answer is \boxed{28}"
        extracted = extract_boxed_answer(response)
        assert extracted == "28"

    def test_omega_boxed_negative(self) -> None:
        """A boxed negative integer keeps its sign."""
        response = r"The result is \boxed{-777}"
        extracted = extract_boxed_answer(response)
        assert extracted == "-777"

    def test_omega_boxed_fraction(self) -> None:
        """A boxed negative slash fraction is extracted verbatim."""
        response = r"The probability is \boxed{-93/223}"
        extracted = extract_boxed_answer(response)
        assert extracted == "-93/223"

    def test_omega_equivalence_fraction(self) -> None:
        """Identical negative fractions are equivalent."""
        reference = "-93/223"
        candidate = "-93/223"
        assert is_equiv(candidate, reference) is True

    def test_omega_extract_and_compare(self) -> None:
        """Boxed extraction followed by comparison succeeds."""
        response = r"After calculation, \boxed{48859}"
        reference = "48859"
        candidate = extract_answer(response, prefer_boxed=True)
        assert is_equiv(candidate, reference) is True
