"""Common helpers for symbol-edit experiments."""

from __future__ import annotations

import re
import sys
from pathlib import Path
from typing import Dict, Optional

# Repository root (two directory levels above this file). It is prepended to
# sys.path so that the sibling ``number_edit`` package imported below can be
# found when this module is loaded directly.
ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from number_edit.common import ensure_dir, load_jsonl, write_jsonl  # re-export


# ── Target Symbols (unified, no = sign) ──────────────────────────────────────

# Target symbols and their notational variants (ASCII, LaTeX, Unicode).
# Eight distinct symbols are covered: the relations >, <, ≥, ≤ and the
# operators +, -, ×, ÷ (\cdot and · are accepted as multiplication variants).
# Each key maps to its "unsound" replacement: relations flip direction and
# operators swap with their inverse. The table is intentionally not an
# involution: \cdot / · map to the division sign, which maps back to \times / ×.
SYMBOL_SWAP: Dict[str, str] = {
    # Relations: direction swap
    ">": "<",
    "<": ">",
    "\\geq": "\\leq",
    "\\ge": "\\le",
    "≥": "≤",
    "\\leq": "\\geq",
    "\\le": "\\ge",
    "≤": "≥",
    "\\gt": "\\lt",
    "\\lt": "\\gt",
    # Operators: inverse swap
    "+": "-",
    "-": "+",
    "\\times": "\\div",
    "×": "÷",
    "\\cdot": "\\div",
    "·": "÷",
    "\\div": "\\times",
    "÷": "×",
}


def perturb_symbol(symbol: str) -> Optional[str]:
    """Look up the deterministic "unsound" counterpart of ``symbol``.

    Leading/trailing whitespace on ``symbol`` is ignored for the lookup in
    ``SYMBOL_SWAP``. Returns ``None`` when the symbol has no entry.

    If the replacement is a LaTeX command (starts with a backslash) while the
    looked-up symbol is not, a trailing space is appended so the command
    cannot fuse with a following letter (e.g. ``\\leq`` + ``f`` -> ``\\leqf``).
    With the table as currently defined, LaTeX keys map only to LaTeX values
    and non-LaTeX keys only to non-LaTeX values, so no padding is added.
    """
    key = symbol.strip()
    replacement = SYMBOL_SWAP.get(key)
    pad = (
        replacement is not None
        and replacement.startswith("\\")
        and not key.startswith("\\")
    )
    return f"{replacement} " if pad else replacement


# ── Precise Symbol Replacement ──────────────────────────────────────────────

def replace_symbol(
    text: str,
    offset_start: int,
    offset_end: int,
    old_symbol: str,
    new_symbol: str,
    context: str = "",
) -> Optional[str]:
    """Replace a single occurrence of ``old_symbol`` in ``text``.

    Exact offsets are tried first. If they no longer point at ``old_symbol``,
    the function falls back to searching ``text``:

    * Candidate matches are filtered. A ``-`` that is part of a comment
      marker (``--``, ``/-``, ``/--``, ``-/``) is skipped. A LaTeX control
      word such as ``\\le`` only matches when it is not the prefix of a
      longer command (``\\left``, ``\\leq``).
    * If exactly one candidate remains, it is replaced.
    * If several remain and ``context`` is given, the first candidate is
      replaced whose surrounding +/-30 characters contain at least two of
      the words longer than 3 characters taken from the first 40 characters
      of ``context``.
    * Otherwise the candidate nearest ``offset_start`` is replaced, provided
      ``offset_start`` is non-negative.

    Args:
        text: Source text to edit.
        offset_start: Expected start index of ``old_symbol`` in ``text``.
        offset_end: Expected end index (exclusive) of ``old_symbol``.
        old_symbol: Symbol to replace.
        new_symbol: Replacement text.
        context: Optional snippet near the intended occurrence, used to
            disambiguate when the symbol appears several times.

    Returns:
        The edited text, or ``None`` if no suitable occurrence was found.
    """

    def splice(start: int, end: int) -> str:
        """Return ``text`` with ``text[start:end]`` replaced by ``new_symbol``."""
        return text[:start] + new_symbol + text[end:]

    # 1. Exact offsets: trust the caller when they still point at the symbol.
    if 0 <= offset_start < len(text) and offset_end <= len(text):
        if text[offset_start:offset_end] == old_symbol:
            return splice(offset_start, offset_end)

    # An empty pattern would "match" between every pair of characters.
    if not old_symbol:
        return None

    # 2. Fallback: search the whole text.
    pattern = re.escape(old_symbol)
    if old_symbol.startswith("\\") and old_symbol[-1].isalpha():
        # A TeX control word is a backslash plus the longest run of letters,
        # so "\le" at the start of "\left" or "\leq" is a different command.
        pattern += r"(?![A-Za-z])"

    def is_comment_hyphen(pos: int) -> bool:
        """True if the ``-`` at ``pos`` belongs to a comment marker."""
        if old_symbol != "-":
            return False
        prev_ch = text[pos - 1] if pos > 0 else ""
        next_ch = text[pos + 1] if pos + 1 < len(text) else ""
        # Covers "/-", "-/" and *both* hyphens of a "--" run ("/--" included).
        return prev_ch in ("/", "-") or next_ch in ("/", "-")

    matches = [
        m for m in re.finditer(pattern, text) if not is_comment_hyphen(m.start())
    ]
    if not matches:
        return None
    if len(matches) == 1:
        return splice(matches[0].start(), matches[0].end())

    # Multiple occurrences: use context words to pick the right one.
    if context:
        ctx_words = [w for w in context[:40].split() if len(w) > 3]
        for m in matches:
            window = text[max(0, m.start() - 30):m.end() + 30]
            if sum(1 for w in ctx_words if w in window) >= 2:
                return splice(m.start(), m.end())

    # Last resort: pick the occurrence nearest to offset_start.
    if offset_start >= 0:
        best = min(matches, key=lambda m: abs(m.start() - offset_start))
        return splice(best.start(), best.end())

    return None


# ── Edit Types (simplified: just statement vs proof) ─────────────────────────

# Edit categories: a change to the statement or a change to the proof.
EDIT_TYPES = (
    "statement_edit",
    "proof_edit",
)
