import re
from dataclasses import dataclass
from typing import cast

from torch import Tensor
from transformers import PreTrainedTokenizerBase

# Credit: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml
# Case-insensitive "The final answer is" followed by any run of colons/whitespace.
# Group 1 captures the answer: optional markdown asterisks and whitespace, an
# optional minus sign, then a run of `$`, digits, `.` and `,` containing at least
# one digit, then optional whitespace and asterisks.
FINAL_ANSWER_REGEX = r"(?i)The final answer is[:\s]*([*]*\s*-?[$0-9.,]*[0-9]+[$0-9.,]*\s*[*]*)"
# Regexes deleted from a captured answer, one after another in list order, by
# `clean_numeric_answer`. Order matters: asterisks must be removed BEFORE the
# trailing-period pattern runs. Otherwise, in "**42.**", the period is not yet
# at the end of the string and survives as "42.".
IGNORE_ANSWER_PATTERNS = [
    r",",  # thousands separators
    r"\$",  # currency symbol
    r"(?s).*#### ",  # everything up to and including a GSM8K "#### " marker
    r"\*",  # markdown emphasis such as "**42**"
    # Trailing period, plus whitespace exposed by the removals above ("42. **" -> "42. ").
    r"\.\s*$",
]


@dataclass
class AnswerTokenInfo:
    sample_idx: int
    raw_answer: str | None
    cleaned_answer: str | None
    token_indices: list[int]
    token_texts: list[str]
    match_success: bool


def clean_numeric_answer(raw_answer: str) -> str:
    """Normalize a captured answer string by deleting every ignorable pattern.

    The input is trimmed of surrounding whitespace. Each regex in
    ``IGNORE_ANSWER_PATTERNS`` is then removed in list order; later patterns
    see the output of earlier ones, so the order is significant. Finally the
    result is trimmed again.
    """
    result = raw_answer.strip()
    for ignored in IGNORE_ANSWER_PATTERNS:
        result = re.compile(ignored).sub("", result)
    return result.strip()


def build_char_to_token_map(
    token_ids: Tensor,
    tokenizer: PreTrainedTokenizerBase,
) -> list[tuple[int, int, str]]:
    """Build (start_char, end_char, token_text) for each token.

    For every prefix ``token_ids[:k]`` the prefix is decoded with
    ``skip_special_tokens=True``. Token ``k - 1`` is credited with the
    characters that decoding gained over the previous, one-shorter prefix.
    Decoding whole prefixes rather than single tokens handles
    tokenizer-specific spacing behavior. The cost is one decode call per
    token, each over the full prefix.

    NOTE(review): this assumes each longer prefix decodes to an extension of
    the shorter one. If a tokenizer rewrites earlier text (e.g. when partial
    multi-byte characters complete), ``start_char`` may exceed ``end_char``,
    and ``token_text`` is simply the decoded text from ``start_char`` onward.
    Confirm for the tokenizers in use.

    Returns:
        One tuple per token, in order, with offsets into the decoded text.
    """
    id_list = token_ids.tolist()
    spans: list[tuple[int, int, str]] = []
    consumed = 0  # length of the text decoded from the previous prefix
    for prefix_len in range(1, len(id_list) + 1):
        text = cast(
            str,
            tokenizer.decode(id_list[:prefix_len], skip_special_tokens=True),
        )
        spans.append((consumed, len(text), text[consumed:]))
        consumed = len(text)
    return spans


def find_answer_token_indices(
    prediction: str,
    full_sequence_ids: Tensor,
    input_length: int,
    tokenizer: PreTrainedTokenizerBase,
    pad_token_id: int | None = None,
) -> AnswerTokenInfo:
    """Find token indices corresponding to the numeric answer in a GSM8K prediction.

    Only the LAST match of ``FINAL_ANSWER_REGEX`` ("The final answer is X")
    in ``prediction`` is used. Its captured character span is shortened by
    trailing whitespace and then by at most one trailing period. That span is
    intersected with the per-token character spans produced by incrementally
    decoding the generated (non-pad) tokens.

    NOTE(review): the answer offsets come from ``prediction`` but are compared
    with offsets into ``tokenizer.decode(<generated non-pad ids>,
    skip_special_tokens=True)``. This presumes ``prediction`` is exactly that
    decoded text (not stripped or reformatted); confirm with the caller.

    Args:
        prediction: The decoded prediction text.
        full_sequence_ids: Full token sequence from shard (prompt + generated).
        input_length: Where generation starts in full_sequence_ids.
        tokenizer: Tokenizer used for decoding.
        pad_token_id: Padding token ID to filter out. If None, uses the
            tokenizer's pad_token_id; if that is also None, nothing is filtered.

    Returns:
        AnswerTokenInfo with ``sample_idx`` set to -1 and token indices into the
        full sequence (input_length offset included). Without a regex match,
        both answer strings are None, the lists are empty and
        ``match_success`` is False. Otherwise ``match_success`` is True iff at
        least one token overlaps the answer span.
    """
    matches = list(re.finditer(FINAL_ANSWER_REGEX, prediction))
    if not matches:
        return AnswerTokenInfo(
            sample_idx=-1,
            raw_answer=None,
            cleaned_answer=None,
            token_indices=[],
            token_texts=[],
            match_success=False,
        )

    final = matches[-1]
    raw_answer = final.group(1).strip()
    span_start, span_end = final.span(1)

    # Tighten the span: drop trailing whitespace, then a single trailing period.
    # raw_answer keeps the period for reference.
    span_end = span_start + len(prediction[span_start:span_end].rstrip())
    if prediction[span_start:span_end].endswith("."):
        span_end -= 1

    generated = full_sequence_ids[input_length:]
    pad_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
    if pad_id is None:
        # No pad id known: every generated position is kept.
        kept_positions = list(range(len(generated)))
    else:
        keep_mask = generated != pad_id
        generated = generated[keep_mask]
        # Pre-filter offsets of the surviving tokens, used to report positions
        # relative to full_sequence_ids.
        kept_positions = keep_mask.nonzero(as_tuple=True)[0].tolist()

    hits = [
        (input_length + kept_positions[pos], piece)
        for pos, (lo, hi, piece) in enumerate(
            build_char_to_token_map(generated, tokenizer)
        )
        # Keep tokens whose [lo, hi) range overlaps [span_start, span_end).
        if hi > span_start and lo < span_end
    ]
    token_indices = [index for index, _ in hits]
    token_texts = [piece for _, piece in hits]

    return AnswerTokenInfo(
        sample_idx=-1,
        raw_answer=raw_answer,
        cleaned_answer=clean_numeric_answer(raw_answer),
        token_indices=token_indices,
        token_texts=token_texts,
        match_success=bool(token_indices),
    )
