"""Paragraph + length-normalization segmenter (no discourse markers).

Drops the discourse-marker step from analysis.exploration.segmentation that
caused tautological alignment with the heuristic classifier's trigger
phrases. Splits on blank lines (`\\n\\s*\\n`), merges short paragraphs,
splits long ones on
sentence boundaries — same length contract (80–250 tokens by default) so
comparisons against the v1 segmenter are apples-to-apples.

Intentionally lives under llm_validation/ rather than replacing
analysis/exploration/segmentation.py, so we can run side-by-side without
invalidating the existing parquet.
"""
from __future__ import annotations

import re

from analysis.exploration.segmentation import (
    Span,
    _estimate_tokens,
    _split_sentences,
    extract_reasoning,
)


def _split_on_paragraphs(text: str) -> list[tuple[int, str]]:
    fragments = []
    pos = 0
    for part in re.split(r"\n\s*\n", text):
        part = part.strip()
        if part:
            start = text.find(part, pos)
            if start == -1:
                start = pos
            fragments.append((start, part))
            pos = start + len(part)
    return fragments


def segment_trace_v2(
    text: str,
    min_tokens: int = 80,
    max_tokens: int = 250,
) -> list[Span]:
    """Segment a reasoning trace using paragraph + length normalization only.

    Algorithm:
        1. Split on paragraph boundaries (blank lines, ``\\n\\s*\\n``).
        2. Greedily merge consecutive paragraphs until the buffer reaches
           ``min_tokens``, then emit it.
        3. If an emitted buffer exceeds ``max_tokens``, split it on sentence
           boundaries into pieces of at most ``max_tokens``. A single
           sentence longer than that is kept whole, and a buffer that yields
           at most one sentence is emitted unsplit.
        4. Same trailing-buffer merge-back rule as v1: a leftover buffer
           below ``min_tokens // 2`` tokens is appended to the previous span.

    Args:
        text: The reasoning text to segment.
        min_tokens: Token count (per ``_estimate_tokens``) at which a buffer
            of merged paragraphs is emitted.
        max_tokens: Emitted buffers above this count are sentence-split.

    Returns:
        Spans in document order with ``span_id`` counting up from 0.
        ``start_char``/``end_char`` are offsets into *text*:
        - They are exact for spans built from whole paragraphs.
        - They are best-effort for sentence-split pieces. Sentences are
          located by forward search, and positions are approximated when a
          sentence cannot be found verbatim.
        ``span.text`` joins paragraphs with a single ``\\n``, so it need not
        equal ``text[start_char:end_char]``.
    """
    if not text or not text.strip():
        return []

    frag_tokens = [
        (start, frag, _estimate_tokens(frag))
        for start, frag in _split_on_paragraphs(text)
    ]
    if not frag_tokens:
        return []

    spans: list[Span] = []
    span_id = 0

    # Paragraph buffer. The text is joined with "\n" (as in v1). buf_start and
    # buf_end are the source offsets of its first character and one past its
    # last character.
    buf_text = ""
    buf_start = frag_tokens[0][0]
    buf_end = buf_start
    buf_tokens = 0

    def _emit_span(text_: str, start_: int, end_: int, n_tok: int) -> None:
        """Append a span with the next sequential id."""
        nonlocal span_id
        spans.append(Span(
            span_id=span_id,
            text=text_.strip(),
            start_char=start_,
            end_char=end_,
            n_tokens=n_tok,
        ))
        span_id += 1

    def _locate(piece: str, cursor: int, limit: int) -> tuple[int, int]:
        """Find *piece* in the source within [cursor, limit); return (start, end)."""
        needle = piece.strip()
        found = text.find(needle, cursor, limit)
        if found == -1:
            # The sentence was not found verbatim. This can happen when it
            # spans a paragraph boundary that the buffer joined with a single
            # "\n". Approximate its position from the cursor. The cursor then
            # never runs ahead of the true position, so later sentences are
            # still found.
            found = cursor
        return found, min(found + len(needle), limit)

    def _split_long_and_emit(
        block: str, start: int, end: int, total_tokens: int
    ) -> None:
        """Emit *block* as sentence groups of at most max_tokens each."""
        sentences = _split_sentences(block)
        if len(sentences) <= 1:
            _emit_span(block, start, end, total_tokens)
            return
        cur = ""
        cur_start = start
        cur_end = start
        cur_tok = 0
        cursor = start
        for sent in sentences:
            sent_tok = _estimate_tokens(sent)
            sent_start, sent_end = _locate(sent, cursor, end)
            cursor = sent_end
            if cur_tok + sent_tok > max_tokens and cur:
                _emit_span(cur, cur_start, cur_end, cur_tok)
                cur, cur_tok, cur_start = "", 0, sent_start
            cur += sent + " "
            cur_tok += sent_tok
            cur_end = sent_end
        if cur.strip():
            _emit_span(cur, cur_start, cur_end, cur_tok)

    def _flush_buffer() -> None:
        """Emit the buffer (sentence-splitting it if too long) and reset it."""
        nonlocal buf_text, buf_tokens
        if buf_text.strip():
            if buf_tokens > max_tokens:
                _split_long_and_emit(buf_text, buf_start, buf_end, buf_tokens)
            else:
                _emit_span(buf_text, buf_start, buf_end, buf_tokens)
        buf_text = ""
        buf_tokens = 0

    for start, frag, n_tok in frag_tokens:
        if not buf_text:
            # Fresh buffer: anchor it at this paragraph's real start. The
            # previous version skipped this when a lone paragraph reached
            # min_tokens, leaving start_char on the preceding separator.
            buf_start = start
        buf_text += frag + "\n"
        buf_end = start + len(frag)
        buf_tokens += n_tok
        if buf_tokens >= min_tokens:
            _flush_buffer()

    if buf_text.strip():
        if spans and buf_tokens < min_tokens // 2:
            # Too small to stand alone: fold into the previous span.
            prev = spans[-1]
            spans[-1] = Span(
                span_id=prev.span_id,
                text=prev.text + "\n" + buf_text.strip(),
                start_char=prev.start_char,
                end_char=buf_end,
                n_tokens=prev.n_tokens + buf_tokens,
            )
        else:
            _flush_buffer()

    return spans


def segment_response_v2(
    response: str,
    min_tokens: int = 80,
    max_tokens: int = 250,
) -> list[Span]:
    """Segment a full model response with ``segment_trace_v2``.

    Segments the text returned by ``extract_reasoning(response)``. Only when
    that call returns ``None`` is the whole response segmented instead.
    ``min_tokens`` and ``max_tokens`` are forwarded unchanged.
    """
    reasoning = extract_reasoning(response)
    source = response if reasoning is None else reasoning
    return segment_trace_v2(source, min_tokens=min_tokens, max_tokens=max_tokens)
