import re
from typing import List, Tuple, Dict, Optional, NamedTuple
from dataclasses import dataclass

try:
    import nltk
    from nltk.tokenize import sent_tokenize
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt', quiet=True)
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False

@dataclass
class Span:
    """A contiguous region of a text, located by character and (optionally) token offsets.

    Within this module the ``end_*`` fields are exclusive bounds: producers set
    ``end_char = start_char + len(text)`` and ``end_token = last_token + 1``.
    Token bounds stay ``None`` until something assigns them.
    """

    start_char: int                     # first character covered
    end_char: int                       # one past the last character covered
    start_token: Optional[int] = None   # first token index, if assigned
    end_token: Optional[int] = None     # one past the last token index, if assigned
    text: str = ""                      # the covered substring (may be left empty)

    @property
    def char_length(self) -> int:
        """Number of characters between ``start_char`` and ``end_char``."""
        return self.end_char - self.start_char

    @property
    def token_length(self) -> Optional[int]:
        """Number of tokens covered, or ``None`` while either token bound is unset."""
        if self.start_token is None or self.end_token is None:
            return None
        return self.end_token - self.start_token

def _nltk_sentence_spans(text: str) -> List[Span]:
    """Return spans for the sentences NLTK's ``sent_tokenize`` finds in ``text``."""
    found: List[Span] = []
    cursor = 0
    for sentence in sent_tokenize(text):
        # Only search forward from the previous sentence so spans stay ordered
        # and non-overlapping; a sentence that cannot be located is skipped.
        start = text.find(sentence, cursor)
        if start == -1:
            continue
        end = start + len(sentence)
        found.append(Span(start_char=start, end_char=end, text=sentence))
        cursor = end
    return found


def _regex_sentence_spans(text: str) -> List[Span]:
    """Split on whitespace following ``.``, ``!`` or ``?``; return [] if no split occurs."""
    parts = re.split(r'(?<=[.!?])\s+', text)
    if len(parts) <= 1:
        return []

    found: List[Span] = []
    cursor = 0
    for part in parts:
        if not part.strip():
            continue
        start = text.find(part, cursor)
        if start == -1:
            continue
        end = start + len(part)
        found.append(Span(start_char=start, end_char=end, text=part))
        cursor = end
    return found


def _fixed_block_spans(text: str, block_size: int) -> List[Span]:
    """Cut ``text`` into consecutive blocks of ``block_size`` characters (last may be shorter)."""
    return [
        Span(
            start_char=i,
            end_char=min(i + block_size, len(text)),
            text=text[i:i + block_size],
        )
        for i in range(0, len(text), block_size)
    ]


def sentence_split(text: str, fallback_block_size: int = 256) -> List[Span]:
    """Split ``text`` into sentence-like spans with character offsets.

    Strategies are tried in order and the first one that yields any spans wins:

    1. NLTK's ``sent_tokenize`` (only when NLTK imported successfully);
    2. a regex split on whitespace that follows ``.``, ``!`` or ``?``;
    3. fixed blocks of ``fallback_block_size`` characters.

    Each strategy builds its own list, so a partial result from a failed
    strategy never leaks into the result of a later one.

    Args:
        text: Input text.
        fallback_block_size: Block length in characters for the last-resort split.

    Returns:
        Spans ordered by position, each with ``start_char``, ``end_char`` and
        ``text`` set. Returns ``[]`` for empty or whitespace-only input.
    """
    if not text.strip():
        return []

    if NLTK_AVAILABLE:
        try:
            spans = _nltk_sentence_spans(text)
        except Exception:
            # Deliberately best-effort: any NLTK failure (e.g. missing tokenizer
            # data) falls back to the regex splitter instead of failing the caller.
            spans = []
        if spans:
            return spans

    spans = _regex_sentence_spans(text)
    if spans:
        return spans

    return _fixed_block_spans(text, fallback_block_size)

def get_token_char_mapping(
    text: str,
    tokenizer
) -> Tuple[List[Tuple[int, int]], List[int]]:
    """Tokenize ``text`` and return per-token character offsets and token ids.

    ``tokenizer`` is called as
    ``tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)``.
    It must return a mapping that supports ``.get`` with ``'offset_mapping'`` and
    ``'input_ids'`` keys. (Presumably a Hugging Face-style fast tokenizer; confirm
    against callers.) Missing keys are treated as empty lists.

    Returns:
        ``(char_positions, token_ids)``. ``char_positions[i]`` is token ``i``'s
        offset pair converted to a tuple; ``None`` offsets become ``(0, 0)``.
        ``token_ids`` is returned exactly as the tokenizer provided it.
    """
    encoded = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    raw_offsets = encoded.get('offset_mapping', [])
    ids = encoded.get('input_ids', [])

    char_positions = [
        (0, 0) if pair is None else tuple(pair)
        for pair in raw_offsets
    ]
    return char_positions, ids

def token_to_span_mapping(
    token_char_positions: List[Tuple[int, int]],
    spans: List[Span]
) -> Dict[int, int]:

    mapping = {}

    for token_idx, (start, end) in enumerate(token_char_positions):
        midpoint = (start + end) / 2

        for span_idx, span in enumerate(spans):
            if span.start_char <= midpoint < span.end_char:
                mapping[token_idx] = span_idx
                break
        else:
            if spans:
                min_dist = float('inf')
                closest_span = 0
                for span_idx, span in enumerate(spans):
                    span_mid = (span.start_char + span.end_char) / 2
                    dist = abs(midpoint - span_mid)
                    if dist < min_dist:
                        min_dist = dist
                        closest_span = span_idx
                mapping[token_idx] = closest_span

    return mapping

def assign_token_positions_to_spans(
    spans: List[Span],
    token_char_positions: List[Tuple[int, int]]
) -> List[Span]:
    """Fill ``start_token``/``end_token`` on each span from token character offsets.

    Tokens are first assigned to spans via ``token_to_span_mapping``. Each span
    then receives the smallest assigned token index as ``start_token`` and one
    past the largest as ``end_token``. Spans that receive no tokens are reset
    to ``None``/``None``.

    The spans are mutated in place; the same list is returned for convenience.
    """
    token_owner = token_to_span_mapping(token_char_positions, spans)

    # Accumulate (min token, max token + 1) per span before touching the spans.
    bounds: Dict[int, Tuple[int, int]] = {}
    for tok, owner in token_owner.items():
        lo, hi = bounds.get(owner, (tok, tok + 1))
        bounds[owner] = (min(lo, tok), max(hi, tok + 1))

    for span in spans:
        span.start_token = None
        span.end_token = None

    for idx, (lo, hi) in bounds.items():
        spans[idx].start_token = lo
        spans[idx].end_token = hi

    return spans

def aggregate_kappa_to_spans(
    kappa_dict: Dict[int, float],
    spans: List[Span],
    token_to_span: Dict[int, int],
    w_minus: float = 1.0,
    w_plus: float = 0.25,
    use_abs_kappa: bool = False
) -> Dict[int, float]:

    span_curvatures: Dict[int, List[float]] = {i: [] for i in range(len(spans))}

    for token_idx, kappa in kappa_dict.items():
        if token_idx in token_to_span:
            span_idx = token_to_span[token_idx]
            span_curvatures[span_idx].append(kappa)

    span_scores = {}
    for span_idx, curvatures in span_curvatures.items():
        if not curvatures:
            span_scores[span_idx] = 0.0
            continue

        if use_abs_kappa:
            scores = [abs(kappa) for kappa in curvatures]
        else:
            scores = []
            for kappa in curvatures:
                neg_part = w_minus * max(-kappa, 0.0)
                pos_part = w_plus * max(kappa, 0.0)
                scores.append(neg_part + pos_part)

        span_scores[span_idx] = sum(scores) / len(scores)

    return span_scores

def find_high_fanout_positions(
    kappa_dict: Dict[int, float],
    threshold_percentile: float = 80.0
) -> List[int]:
    """Return positions whose negative curvature is in the top percentile.

    Only strictly negative values are considered. The threshold is the
    ``threshold_percentile`` percentile (via ``numpy.percentile``) of their
    magnitudes ``-kappa``. A position is selected when its magnitude is at
    least that threshold.

    Returns:
        Selected positions in ascending order. Returns ``[]`` when there are
        no negative values.
    """
    magnitudes = [-value for value in kappa_dict.values() if value < 0]
    if not magnitudes:
        # Also covers an empty kappa_dict.
        return []

    import numpy as np
    cutoff = np.percentile(magnitudes, threshold_percentile)

    return sorted(
        pos for pos, value in kappa_dict.items()
        if value < 0 and -value >= cutoff
    )

def compute_fanout_statistic(
    kappa_dict: Dict[int, float]
) -> float:
    """Total magnitude of the negative curvature values in ``kappa_dict``.

    Positive values contribute nothing. For an empty dict the result is ``0``.
    """
    negative_excess = (max(-value, 0.0) for value in kappa_dict.values())
    return sum(negative_excess)

def find_curvature_pivots(
    positions: List[int],
    kappa_values: List[float],
    min_chunk_tokens: int = 128,
    max_chunk_tokens: int = 512,
    pivot_threshold: float = 0.5
) -> List[int]:
    """Choose chunk-boundary positions from a curvature signal.

    Candidate boundaries are indices where the sign of ``kappa_values`` changes,
    plus interior local maxima of ``|kappa|`` that reach ``pivot_threshold``.
    A first pass drops candidates closer than ``min_chunk_tokens`` (in
    ``positions`` units) to the previously accepted candidate.

    A second pass walks every position. Once the distance from the last boundary
    reaches ``max_chunk_tokens``, it forces a boundary at the index with the
    largest ``|kappa|`` that is at least ``min_chunk_tokens`` past that boundary.
    Otherwise a surviving candidate is kept only if it is at least
    ``min_chunk_tokens`` past the last boundary. This re-check is needed
    because a forced boundary can land just before a candidate.

    Args:
        positions: Token positions, index-aligned with ``kappa_values``. They
            are assumed ascending; the spacing logic relies on it.
        kappa_values: Curvature value for each entry of ``positions``.
        min_chunk_tokens: Minimum spacing between boundaries.
        max_chunk_tokens: Spacing at which a boundary is forced.
        pivot_threshold: Minimum ``|kappa|`` for a local maximum to be a candidate.

    Returns:
        Boundary positions as ints in ascending order. The first position is
        always included. Inputs with fewer than two positions are returned as
        a new list.
    """
    if len(positions) < 2:
        return [int(p) for p in positions]

    import numpy as np

    pos_arr = np.array(positions)
    kappa = np.array(kappa_values)

    # Sign flips of the curvature signal. A zero value (sign 0) produces a flip
    # on each side of it.
    sign_changes = np.where(np.diff(np.sign(kappa)) != 0)[0] + 1

    # Interior strict local maxima of |kappa| that are at least pivot_threshold.
    abs_kappa = np.abs(kappa)
    inner = abs_kappa[1:-1]
    is_peak = (inner > abs_kappa[:-2]) & (inner > abs_kappa[2:]) & (inner >= pivot_threshold)
    local_max = (np.where(is_peak)[0] + 1).tolist()

    candidates = sorted(set(sign_changes.tolist() + local_max))

    # First pass: thin candidates so consecutive accepted ones are >= min apart.
    # A set keeps the membership test in the second pass O(1).
    candidate_pivots = {0}
    last_accepted = pos_arr[0]
    for idx in candidates:
        if pos_arr[idx] - last_accepted >= min_chunk_tokens:
            candidate_pivots.add(idx)
            last_accepted = pos_arr[idx]

    # Second pass: enforce max_chunk_tokens and re-check min spacing against
    # the boundaries actually chosen, including forced ones.
    # final_pivots is strictly increasing by construction.
    final_pivots = [0]
    for i in range(1, len(pos_arr)):
        last_pos = pos_arr[final_pivots[-1]]
        gap = pos_arr[i] - last_pos

        if gap >= max_chunk_tokens:
            best_idx = None
            best_score = -float('inf')
            for j in range(final_pivots[-1] + 1, i + 1):
                if pos_arr[j] - last_pos >= min_chunk_tokens:
                    score = abs(kappa[j])
                    # Strict '>' keeps the earliest index on ties.
                    if score > best_score:
                        best_score = score
                        best_idx = j
            if best_idx is not None:
                final_pivots.append(best_idx)
        elif i in candidate_pivots and gap >= min_chunk_tokens:
            final_pivots.append(i)

    return [int(pos_arr[i]) for i in final_pivots]

def extract_anchor_spans(
    text: str,
    high_fanout_positions: List[int],
    token_char_positions: List[Tuple[int, int]],
    window_tokens: int = 5
) -> List[str]:
    """Extract the text surrounding each high-fan-out token position.

    For every position ``p``, the anchor runs from the start character of token
    ``p - window_tokens`` to the end character of token ``p + window_tokens``.
    Both ends are clipped to the available tokens, and the result is stripped
    of surrounding whitespace. Empty anchors are dropped.

    Positions outside ``[0, len(token_char_positions))`` are skipped. Such a
    position names no token, so it used to raise IndexError or silently yield
    text unrelated to it. Windows that come out empty are also skipped; this
    can happen when ``window_tokens`` is negative.

    Returns:
        Anchor strings in the order of ``high_fanout_positions``.
    """
    if not high_fanout_positions or not token_char_positions:
        return []

    n_tokens = len(token_char_positions)
    anchors = []

    for pos in high_fanout_positions:
        if not 0 <= pos < n_tokens:
            continue

        start_tok = max(0, pos - window_tokens)
        end_tok = min(n_tokens, pos + window_tokens + 1)  # exclusive
        if start_tok >= end_tok:
            continue

        start_char = token_char_positions[start_tok][0]
        end_char = token_char_positions[end_tok - 1][1]

        anchor = text[start_char:end_char].strip()
        if anchor:
            anchors.append(anchor)

    return anchors
