import re
from typing import Any

import torch

from hallucinations.features.answer_tokens import (
    FINAL_ANSWER_REGEX,
    build_char_to_token_map,
    clean_numeric_answer,
    find_answer_token_indices,
)


class MockTokenizer:
    """Minimal tokenizer stand-in backed by a fixed id-to-text table.

    Only the pieces these tests rely on are provided: a ``pad_token_id``
    attribute and a ``decode`` method.
    """

    def __init__(self, token_map: dict[int, str], pad_token_id: int = 0):
        self.pad_token_id = pad_token_id
        self._token_map = token_map

    def decode(self, token_ids: list[int], skip_special_tokens: bool = True) -> str:  # noqa: ARG002
        """Concatenate the text of each id; ids missing from the table add nothing.

        ``skip_special_tokens`` is accepted but ignored.
        """
        lookup = self._token_map.get
        pieces = [lookup(token_id, "") for token_id in token_ids]
        return "".join(pieces)


def mock_tokenizer(token_map: dict[int, str], pad_token_id: int = 0) -> Any:
    """Build a MockTokenizer, typed as Any so type checkers accept it as a tokenizer argument."""
    tokenizer: Any = MockTokenizer(token_map, pad_token_id)
    return tokenizer


class TestCleanNumericAnswer:
    """Checks for the string normalization done by clean_numeric_answer."""

    def test_removes_commas(self) -> None:
        """Thousands separators are dropped."""
        cleaned = clean_numeric_answer("1,234")
        assert cleaned == "1234"

    def test_removes_dollar_sign(self) -> None:
        """A leading dollar sign is dropped."""
        cleaned = clean_numeric_answer("$100")
        assert cleaned == "100"

    def test_removes_asterisks(self) -> None:
        """Surrounding asterisks (markdown bold markers) are dropped."""
        cleaned = clean_numeric_answer("**42**")
        assert cleaned == "42"

    def test_removes_trailing_period(self) -> None:
        """A sentence-ending period after the number is dropped."""
        cleaned = clean_numeric_answer("42.")
        assert cleaned == "42"

    def test_preserves_decimal(self) -> None:
        """A decimal point inside the number is kept."""
        cleaned = clean_numeric_answer("3.14")
        assert cleaned == "3.14"

    def test_combined_cleanup(self) -> None:
        """Dollar sign and comma are removed while the decimal part survives."""
        cleaned = clean_numeric_answer("$1,234.00")
        assert cleaned == "1234.00"


class TestBuildCharToTokenMap:
    """Checks for the per-token entries returned by build_char_to_token_map.

    Each expected entry is a 3-tuple; judging by the values used here it is
    presumably (start_char, end_char, token_text) -- confirm against the
    implementation.
    """

    def test_simple_tokens(self) -> None:
        """Word tokens with leading spaces produce back-to-back entries."""
        pieces = {1: "The", 2: " answer", 3: " is", 4: " 42"}
        token_ids = torch.tensor([1, 2, 3, 4])

        positions = build_char_to_token_map(token_ids, mock_tokenizer(pieces))

        expected = [
            (0, 3, "The"),
            (3, 10, " answer"),
            (10, 13, " is"),
            (13, 16, " 42"),
        ]
        assert len(positions) == 4
        for idx, entry in enumerate(expected):
            assert positions[idx] == entry

    def test_number_split_tokens(self) -> None:
        """A decimal number split across three tokens yields three entries."""
        pieces = {1: "10", 2: ".", 3: "00"}
        token_ids = torch.tensor([1, 2, 3])

        positions = build_char_to_token_map(token_ids, mock_tokenizer(pieces))

        expected = [
            (0, 2, "10"),
            (2, 3, "."),
            (3, 5, "00"),
        ]
        assert len(positions) == 3
        for idx, entry in enumerate(expected):
            assert positions[idx] == entry


class TestFindAnswerTokenIndices:
    """Checks for find_answer_token_indices on hand-built token sequences.

    The expected ``token_indices`` in these tests are positions in the full
    sequence passed as ``full_sequence_ids`` (prompt tokens included).
    """

    @staticmethod
    def _vocab(*pieces: str) -> dict[int, str]:
        """Assign token ids 1..N to the given token texts, in order."""
        return dict(enumerate(pieces, start=1))

    @staticmethod
    def _locate(
        vocab: dict[int, str],
        sequence_ids: list[int],
        prediction: str,
        input_length: int = 0,
    ) -> Any:
        """Call find_answer_token_indices with a mock tokenizer built from *vocab*."""
        return find_answer_token_indices(
            prediction=prediction,
            full_sequence_ids=torch.tensor(sequence_ids),
            input_length=input_length,
            tokenizer=mock_tokenizer(vocab),
        )

    def test_simple_integer(self) -> None:
        """A one-token integer answer is found after a two-token prompt."""
        vocab = self._vocab("The", " final", " answer", " is", " 42")

        # Ids 100 and 101 play the role of the two input tokens.
        result = self._locate(
            vocab,
            [100, 101, 1, 2, 3, 4, 5],
            "The final answer is 42",
            input_length=2,
        )

        assert result.match_success is True
        assert result.raw_answer == "42"
        assert result.cleaned_answer == "42"
        assert result.token_indices == [6]  # fifth generated token, offset by input_length=2
        assert result.token_texts == [" 42"]

    def test_excludes_trailing_period(self) -> None:
        """Trailing period should be excluded from token indices."""
        vocab = self._vocab("The", " final", " answer", " is", " $", "18", ".")

        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5, 6, 7],
            "The final answer is $18.",
        )

        assert result.match_success is True
        assert result.raw_answer == "$18."
        assert result.cleaned_answer == "18"
        # The "." token (index 6) must be left out.
        assert result.token_indices == [4, 5]
        assert result.token_texts == [" $", "18"]

    def test_decimal_split(self) -> None:
        """A decimal answer spread over three tokens returns all three."""
        vocab = self._vocab("The", " final", " answer", " is", " 10", ".", "00")

        # Id 100 plays the role of the single input token.
        result = self._locate(
            vocab,
            [100, 1, 2, 3, 4, 5, 6, 7],
            "The final answer is 10.00",
            input_length=1,
        )

        assert result.match_success is True
        assert result.raw_answer == "10.00"
        assert result.cleaned_answer == "10.00"
        assert result.token_indices == [5, 6, 7]  # positions in the full sequence
        assert result.token_texts == [" 10", ".", "00"]

    def test_currency_answer(self) -> None:
        """The dollar-sign token is part of the answer span."""
        vocab = self._vocab("The", " final", " answer", " is", " $", "100")

        # No input tokens precede the generation here.
        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5, 6],
            "The final answer is $100",
        )

        assert result.match_success is True
        assert result.raw_answer == "$100"
        assert result.cleaned_answer == "100"
        assert result.token_indices == [4, 5]
        assert result.token_texts == [" $", "100"]

    def test_negative_number(self) -> None:
        """The minus-sign token is part of the answer span."""
        vocab = self._vocab("The", " final", " answer", " is", " -", "5")

        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5, 6],
            "The final answer is -5",
        )

        assert result.match_success is True
        assert result.raw_answer == "-5"
        assert result.token_indices == [4, 5]

    def test_no_match(self) -> None:
        """Text without the answer phrase gives an unsuccessful, empty result."""
        vocab = self._vocab("Hello", " world")

        result = self._locate(vocab, [1, 2], "Hello world")

        assert result.match_success is False
        assert result.raw_answer is None
        assert result.cleaned_answer is None
        assert result.token_indices == []
        assert result.token_texts == []

    def test_with_padding(self) -> None:
        """Pad tokens appended to the sequence do not shift the answer index."""
        # Id 0 is the pad token and decodes to the empty string.
        vocab = {0: "", **self._vocab("The", " final", " answer", " is", " 42")}
        tokenizer = mock_tokenizer(vocab, pad_token_id=0)
        # One input token (id 100), five generated tokens, then two pads.
        padded_ids = torch.tensor([100, 1, 2, 3, 4, 5, 0, 0])

        result = find_answer_token_indices(
            prediction="The final answer is 42",
            full_sequence_ids=padded_ids,
            input_length=1,
            tokenizer=tokenizer,
            pad_token_id=0,
        )

        assert result.match_success is True
        assert result.token_indices == [5]  # same index as without padding
        assert result.token_texts == [" 42"]

    def test_comma_thousands(self) -> None:
        """A comma-separated number covers the digit and comma tokens."""
        vocab = self._vocab("The", " final", " answer", " is", " 1", ",", "000")

        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5, 6, 7],
            "The final answer is 1,000",
        )

        assert result.match_success is True
        assert result.raw_answer == "1,000"
        assert result.cleaned_answer == "1000"
        assert result.token_indices == [4, 5, 6]

    def test_case_insensitive(self) -> None:
        """An upper-case answer phrase still matches."""
        vocab = self._vocab("THE", " FINAL", " ANSWER", " IS", " 42")

        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5],
            "THE FINAL ANSWER IS 42",
        )

        assert result.match_success is True
        assert result.token_indices == [4]

    def test_with_colon(self) -> None:
        """A colon between the phrase and the number is tolerated."""
        vocab = self._vocab("The", " final", " answer", " is", ":", " 42")

        result = self._locate(
            vocab,
            [1, 2, 3, 4, 5, 6],
            "The final answer is: 42",
        )

        assert result.match_success is True
        assert result.raw_answer == "42"

    def test_multiple_occurrences_uses_last(self) -> None:
        """When the answer pattern appears multiple times, use the last one."""
        vocab = self._vocab(
            "First",
            " the",
            " final",
            " answer",
            " is",
            " 10",
            ".",
            " But",
            " wait",
            ",",
            " the",
            " final",
            " answer",
            " is",
            " 42",
        )

        result = self._locate(
            vocab,
            list(range(1, 16)),
            "First the final answer is 10. But wait, the final answer is 42",
        )

        assert result.match_success is True
        assert result.raw_answer == "42"
        assert result.token_indices == [14]
        assert result.token_texts == [" 42"]


class TestFinalAnswerRegex:
    """Tests for the FINAL_ANSWER_REGEX pattern.

    Assertions use substring containment rather than equality: the first
    capture group may include the sentence-ending period, which
    clean_numeric_answer strips afterwards.
    """

    @staticmethod
    def _captured(text: str) -> str:
        """Search *text* with FINAL_ANSWER_REGEX and return capture group 1.

        Fails the calling test if the pattern does not match.
        """
        match = re.search(FINAL_ANSWER_REGEX, text)
        assert match is not None
        return match.group(1)

    def test_regex_matches_simple(self) -> None:
        """An integer answer followed by a period is captured."""
        assert "42" in self._captured("The final answer is 42.")

    def test_regex_matches_decimal(self) -> None:
        """A decimal answer followed by a period is captured."""
        assert "3.14" in self._captured("The final answer is 3.14.")

    def test_regex_matches_currency(self) -> None:
        """The dollar sign is part of the captured answer."""
        assert "$100" in self._captured("The final answer is $100.")

    def test_regex_matches_negative(self) -> None:
        """The minus sign is part of the captured answer."""
        assert "-42" in self._captured("The final answer is -42.")
