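"""Split reasoning traces into spans of roughly 80-250 tokens.

Respects paragraph boundaries and discourse markers to produce spans
suitable for primitive classification. The size bounds are configurable,
and span sizes are measured with a fast character-based token estimate.
"""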
from __future__ import annotations

import re
from dataclasses import dataclass

# ---------------------------------------------------------------------------
# Reasoning extraction
# ---------------------------------------------------------------------------

_REASONING_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL)


def extract_reasoning(response: str) -> str | None:
    """Return the stripped body of the first <reasoning>...</reasoning> block.

    The match is non-greedy and may span newlines. Returns None when
    *response* contains no complete opening/closing tag pair.
    """
    match = _REASONING_RE.search(response)
    if match is None:
        return None
    body = match.group(1)
    return body.strip()


# ---------------------------------------------------------------------------
# Discourse markers that trigger span boundaries
# ---------------------------------------------------------------------------

# Phrases that, when they open a line, start a new fragment in
# _split_on_discourse_markers. Matching is case-insensitive.
# Most alternatives require a trailing whitespace character (or "," / ":" /
# a digit), so e.g. "Hmm," or "Now that" do NOT match.
# NOTE: with re.MULTILINE, ``^`` already matches after every newline, but
# finditer scans left to right, so a match in the middle of a paragraph is
# reported as starting at the preceding "\n" character, not at the marker.
_DISCOURSE_MARKERS = re.compile(
    r"(?:^|\n)"  # start of line
    r"(?:"
    r"(?:Case|Step|Option)\s+\d"  # Case 1, Step 2, Option 3
    r"|(?:First|Second|Third|Next|Finally|Now|Then)\s*[,:]"  # word followed by "," or ":"
    r"|(?:Suppose|Assume|Let(?:'s| us)|Consider)\s"  # e.g. "Suppose ", "Let us "
    r"|(?:Check|Verify|Confirm|Let me check|Does this)\s"  # e.g. "Verify ", "Does this "
    r"|(?:Wait|Actually|Hmm|But wait|Hold on)\s"  # e.g. "Wait ", "Hold on "
    r"|(?:Instead|Alternatively|Let me try|Going back|Start(?:ing)? over)\s"  # e.g. "Starting over "
    r"|(?:Therefore|Hence|Thus|So the answer|In summary|We conclude)\s"  # e.g. "Thus "
    r"|(?:Contradiction|This (?:doesn't|does not) work|That's wrong)\s"  # e.g. "That's wrong "
    r")",
    re.IGNORECASE | re.MULTILINE,
)


# ---------------------------------------------------------------------------
# Span dataclass
# ---------------------------------------------------------------------------

@dataclass
class Span:
    """One segment of a reasoning trace, as produced by segment_trace."""
    # Sequential index of the span within a single segment_trace() result.
    span_id: int
    # Whitespace-stripped span content. Fragments are rejoined with newlines
    # (and sentence-split pieces with single spaces), so this is not
    # necessarily a verbatim slice of the source text.
    text: str
    # Character offset into the text given to segment_trace where the span begins.
    start_char: int
    # Character offset where the span ends (intended as an exclusive end).
    end_char: int
    # Token count for the span; segment_trace fills this with the
    # character-based estimate rather than a real tokenizer count.
    n_tokens: int


# ---------------------------------------------------------------------------
# Tokenizer helper
# ---------------------------------------------------------------------------

_tokenizer_cache = {}  # tokenizer_name -> loaded tokenizer object


def _get_tokenizer(tokenizer_name: str):
    """Return the HuggingFace tokenizer for *tokenizer_name*, loading it once.

    Loaded tokenizers are memoized in the module-level ``_tokenizer_cache``.
    ``transformers`` is imported only on a cache miss, so this module can be
    imported without it installed.
    """
    try:
        return _tokenizer_cache[tokenizer_name]
    except KeyError:
        pass

    from transformers import AutoTokenizer

    # NOTE(review): trust_remote_code=True lets the model repository execute its
    # own Python code while loading; only pass tokenizer names you trust.
    loaded = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    _tokenizer_cache[tokenizer_name] = loaded
    return loaded


def _estimate_tokens(text: str) -> int:
    """Cheaply approximate the token count of *text*.

    Assumes roughly four characters per token (a rule of thumb for English)
    and never returns less than 1. Used for all merge/split decisions because
    it avoids calling a real tokenizer.
    """
    estimate = len(text) // 4
    return estimate or 1


def _count_tokens(text: str, tokenizer) -> int:
    """Return the number of tokens in *text*.

    With a tokenizer, this is the length of
    ``tokenizer.encode(text, add_special_tokens=False)``. With ``None``, it
    falls back to the character-based ``_estimate_tokens`` heuristic.
    (segment_trace does not call this; it reports estimated counts.)
    """
    if tokenizer is not None:
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        return len(token_ids)
    return _estimate_tokens(text)


# ---------------------------------------------------------------------------
# Core segmentation
# ---------------------------------------------------------------------------

def _split_on_paragraphs(text: str) -> list[tuple[int, str]]:
    """Split *text* at blank lines into ``(start_char, paragraph)`` pairs.

    A blank line is a newline, optional whitespace, then another newline.
    Paragraphs are whitespace-stripped and empty ones are dropped; each
    ``start_char`` is the offset of the paragraph's first character in *text*.
    """
    paragraphs: list[tuple[int, str]] = []
    # (start, end) of every separator, plus a sentinel closing the final chunk.
    separators = [(m.start(), m.end()) for m in re.finditer(r"\n\s*\n", text)]
    separators.append((len(text), len(text)))

    chunk_begin = 0
    for sep_start, sep_end in separators:
        chunk = text[chunk_begin:sep_start]
        body = chunk.strip()
        if body:
            leading_ws = len(chunk) - len(chunk.lstrip())
            paragraphs.append((chunk_begin + leading_ws, body))
        chunk_begin = sep_end
    return paragraphs


def _split_on_discourse_markers(
    text: str,
    start_offset: int,
    *,
    markers: re.Pattern[str] | None = None,
) -> list[tuple[int, str]]:
    """Further split a paragraph fragment on discourse-marker boundaries.

    Args:
        text: Paragraph text (callers pass an already-stripped paragraph).
        start_offset: Offset of ``text[0]`` within the full trace.
        markers: Compiled pattern whose match starts mark new fragments.
            Defaults to the module-level ``_DISCOURSE_MARKERS``.

    Returns:
        ``(start, fragment)`` pairs in order. Each fragment is stripped, and
        ``start`` is the offset of the fragment's first character in the full
        trace. If ``markers`` never matches, ``[(start_offset, text)]`` is
        returned unchanged.
    """
    pattern = _DISCOURSE_MARKERS if markers is None else markers
    cuts = [m.start() for m in pattern.finditer(text)]
    if not cuts:
        return [(start_offset, text)]

    # A marker at position 0 is not a new boundary: the fragment already
    # starts there.
    bounds = [0, *(c for c in cuts if c > 0), len(text)]
    fragments: list[tuple[int, str]] = []
    for lo, hi in zip(bounds, bounds[1:]):
        piece = text[lo:hi]
        body = piece.strip()
        if body:
            # Skip the whitespace strip() removed. Marker matches usually begin
            # on the preceding newline, so without this the offset would point
            # at that newline instead of the fragment's first character.
            leading_ws = len(piece) - len(piece.lstrip())
            fragments.append((start_offset + lo + leading_ws, body))

    return fragments if fragments else [(start_offset, text)]


def _split_sentences(text: str) -> list[str]:
    """Break *text* after '.', '!' or '?' when followed by whitespace.

    Returns the stripped, non-empty pieces in order.
    """
    pieces = (piece.strip() for piece in re.split(r"(?<=[.!?])\s+", text))
    return [piece for piece in pieces if piece]


def _collect_fragments(text: str) -> list[tuple[int, str]]:
    """Split *text* on paragraphs, then discourse markers, with exact offsets.

    Returns ``(start, fragment)`` pairs in document order. Every fragment is a
    stripped substring of *text*. ``start`` is re-anchored with ``str.find`` so
    that ``text[start:start + len(fragment)] == fragment``, even if a splitter
    reports an offset pointing at whitespace before the fragment.
    """
    fragments: list[tuple[int, str]] = []
    for para_start, para_text in _split_on_paragraphs(text):
        for approx_start, frag in _split_on_discourse_markers(para_text, para_start):
            # Only whitespace can lie between approx_start and the real start,
            # and the fragment begins with a non-space character, so find()
            # cannot stop on an earlier copy of the same text.
            found = text.find(frag, approx_start)
            fragments.append((found if found != -1 else approx_start, frag))
    return fragments


def _append_span(
    spans: list[Span], span_text: str, start: int, end: int, n_tokens: int
) -> None:
    """Append a Span whose span_id is its index in *spans*."""
    spans.append(Span(
        span_id=len(spans),
        text=span_text,
        start_char=start,
        end_char=end,
        n_tokens=n_tokens,
    ))


def _emit_group(
    spans: list[Span],
    group: list[tuple[int, str]],
    group_tokens: int,
    max_tokens: int,
) -> None:
    """Emit one merged group of fragments, splitting on sentences if too long."""
    from bisect import bisect_right

    joined = "\n".join(frag for _, frag in group)
    group_start = group[0][0]
    group_end = group[-1][0] + len(group[-1][1])

    if group_tokens <= max_tokens:
        _append_span(spans, joined, group_start, group_end, group_tokens)
        return

    # Sentence-split the newline-joined buffer (with a trailing newline), so
    # span texts and token counts match the historical behaviour.
    buf = joined + "\n"
    sentences = _split_sentences(buf)
    if len(sentences) <= 1:
        # Can't split further: emit the oversized group as-is.
        _append_span(spans, joined, group_start, group_end, group_tokens)
        return

    # Offset of each fragment inside ``buf``, used to map buffer positions back
    # to positions in the original trace.
    buf_offsets: list[int] = []
    pos = 0
    for _, frag in group:
        buf_offsets.append(pos)
        pos += len(frag) + 1  # +1 for the joining newline

    def to_source(buf_pos: int) -> int:
        # buf_pos always points at a fragment character (sentences are
        # stripped), never at a joining newline.
        i = bisect_right(buf_offsets, buf_pos) - 1
        return group[i][0] + (buf_pos - buf_offsets[i])

    chunk: list[str] = []
    chunk_tokens = 0
    chunk_start = chunk_end = group_start
    cursor = 0
    for sent in sentences:
        # Sentences are stripped, in order, and separated only by whitespace,
        # so searching from the previous sentence's end finds this one.
        at = buf.find(sent, cursor)
        cursor = at + len(sent)
        sent_tokens = _estimate_tokens(sent)
        if chunk and chunk_tokens + sent_tokens > max_tokens:
            _append_span(spans, " ".join(chunk), chunk_start, chunk_end, chunk_tokens)
            chunk, chunk_tokens = [], 0
        if not chunk:
            chunk_start = to_source(at)
        chunk.append(sent)
        chunk_tokens += sent_tokens
        chunk_end = to_source(cursor - 1) + 1
    _append_span(spans, " ".join(chunk), chunk_start, chunk_end, chunk_tokens)


def segment_trace(
    text: str,
    tokenizer_name: str = "allenai/OLMo-3-7B-Instruct-SFT",
    min_tokens: int = 80,
    max_tokens: int = 250,
) -> list[Span]:
    """Segment a reasoning trace into spans of roughly min_tokens to max_tokens.

    Algorithm:
    1. Split on paragraph boundaries (blank lines).
    2. Further split each paragraph on discourse markers.
    3. Greedily merge consecutive fragments until a group reaches min_tokens.
    4. Split any group above max_tokens on sentence boundaries.
    5. Fold a very short final group (< min_tokens // 2) into the last span.

    All size decisions and the reported ``n_tokens`` use the character-based
    ``_estimate_tokens`` heuristic. ``tokenizer_name`` is accepted for API
    compatibility but is currently not used.

    Sizes are best-effort. A single sentence longer than max_tokens stays
    oversized, a sentence split or a short final remainder can fall below
    min_tokens, and folding the final remainder can push the last span above
    max_tokens.

    ``start_char`` / ``end_char`` are offsets into *text* (end exclusive)
    covering the span's content. ``Span.text`` rejoins fragments with newlines
    (and sentence pieces with spaces), so it need not be a verbatim slice.
    """
    if not text or not text.strip():
        return []

    fragments = _collect_fragments(text)
    if not fragments:
        return []

    spans: list[Span] = []
    group: list[tuple[int, str]] = []
    group_tokens = 0
    for start, frag in fragments:
        group.append((start, frag))
        group_tokens += _estimate_tokens(frag)
        # Close the group as soon as it reaches min_tokens.
        if group_tokens >= min_tokens:
            _emit_group(spans, group, group_tokens, max_tokens)
            group, group_tokens = [], 0

    if group:
        if spans and group_tokens < min_tokens // 2:
            # Very short remainder: merge into the previous span rather than
            # emitting a tiny span of its own.
            prev = spans[-1]
            last_start, last_frag = group[-1]
            spans[-1] = Span(
                span_id=prev.span_id,
                text=prev.text + "\n" + "\n".join(frag for _, frag in group),
                start_char=prev.start_char,
                end_char=last_start + len(last_frag),
                n_tokens=prev.n_tokens + group_tokens,
            )
        else:
            _emit_group(spans, group, group_tokens, max_tokens)

    return spans


def segment_response(
    response: str,
    tokenizer_name: str = "allenai/OLMo-3-7B-Instruct-SFT",
    min_tokens: int = 80,
    max_tokens: int = 250,
) -> list[Span]:
    """Segment a full model response into spans.

    If the response contains a <reasoning>...</reasoning> block, only its body
    is segmented. Otherwise the whole response is used. The remaining
    arguments are forwarded unchanged to segment_trace.
    """
    reasoning = extract_reasoning(response)
    trace = response if reasoning is None else reasoning
    return segment_trace(trace, tokenizer_name, min_tokens, max_tokens)
