"""Common helpers for number-edit experiments."""

from __future__ import annotations

import hashlib
import json
import random
import re
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Iterable, List, Optional


def load_jsonl(path: str | Path) -> List[dict]:
    """Read a JSON Lines file into a list of decoded rows.

    Every non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    rather than being handed to the JSON decoder, which would reject them.

    Args:
        path: Location of the UTF-8 encoded file to read.

    Returns:
        The decoded values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with Path(path).open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _skip_collapsed(text: str, pos: int, count: int) -> int:
    """Advance ``pos`` in ``text`` past ``count`` whitespace-collapsed characters.

    A run of consecutive whitespace counts as a single character, which
    mirrors collapsing every whitespace run to one space. The returned index
    is clamped to ``len(text)``.
    """
    size = len(text)
    for _ in range(count):
        if pos >= size:
            break
        if text[pos].isspace():
            # The whole run corresponds to one collapsed space.
            while pos < size and text[pos].isspace():
                pos += 1
        else:
            pos += 1
    return pos


def fix_offset_by_context(
    text: str,
    value: str,
    context: str,
    old_start: int = -1,
    old_end: int = -1,
) -> Optional[tuple[int, int]]:
    """Relocate ``value`` inside ``text`` using ``context`` as a semantic anchor.

    Returns ``(new_start, new_end)`` such that ``text[new_start:new_end] == value``,
    or ``None`` if no reliable match can be found (including when ``value`` or
    ``text`` is empty, or ``value`` does not occur in ``text``).

    Strategy (in order):
      1. If the provided offset is already correct, keep it.
      2. If ``value`` occurs exactly once in ``text``, use that occurrence.
      3. Try to locate ``context`` in ``text``: exact substring, then with
         doubled backslashes undone, then with whitespace runs collapsed.
         For the two exact forms the first occurrence of ``value`` inside the
         context is used. For the whitespace-collapsed form the matched
         context is mapped back to original coordinates and ``value`` is
         searched from its start up to 20 characters past its end.
      4. Use word-overlap: among all occurrences of ``value``, pick the one
         whose +/-60 character neighbourhood shares the most words with
         ``context``; at least two shared words are required.
    """
    if not value or not text:
        return None

    # 1. Existing offset already correct
    if 0 <= old_start < old_end <= len(text) and text[old_start:old_end] == value:
        return old_start, old_end

    # 2. Collect every (possibly overlapping) occurrence of value
    all_positions = []
    cursor = text.find(value)
    while cursor >= 0:
        all_positions.append(cursor)
        cursor = text.find(value, cursor + 1)
    if not all_positions:
        return None
    if len(all_positions) == 1:
        only = all_positions[0]
        return only, only + len(value)

    # 3. Context-based substring match, with progressive normalization
    ctx = (context or "").strip()
    if ctx:
        # 3a/3b. Exact substring, then with JSON-style doubled backslashes undone
        ctx_unesc = ctx.replace("\\\\", "\\")
        candidates = [ctx] if ctx_unesc == ctx else [ctx, ctx_unesc]
        for candidate in candidates:
            ctx_pos = text.find(candidate)
            if ctx_pos >= 0:
                rel = candidate.find(value)
                if rel >= 0:
                    return ctx_pos + rel, ctx_pos + rel + len(value)

        # 3c. Whitespace-collapsed match (handles \n vs space discrepancy)
        text_c = re.sub(r"\s+", " ", text)
        ctx_c = re.sub(r"\s+", " ", ctx_unesc)
        c_pos = text_c.find(ctx_c)
        if c_pos >= 0:
            # Map both ends of the collapsed match back to original indices.
            # The end must be measured in ``text`` characters, not context
            # characters: long whitespace runs in ``text`` would otherwise
            # push ``value`` outside the search window.
            span_start = _skip_collapsed(text, 0, c_pos)
            span_end = _skip_collapsed(text, span_start, len(ctx_c))
            slack = 20  # also accept a value that sits just past the context
            window_end = min(len(text), span_end + slack)
            hit = text.find(value, span_start, window_end)
            if hit >= 0:
                return hit, hit + len(value)

    # 4. Word-overlap fallback: find text window with max distinctive-word match
    word_pattern = r"[A-Za-z][A-Za-z_]{2,}|\d+"
    ctx_words = set(re.findall(word_pattern, ctx))
    ctx_words.discard(value)
    if not ctx_words:
        # No anchors beyond value itself; give up rather than guess
        return None

    best_pos, best_score = -1, 0
    for pos in all_positions:
        window = text[max(0, pos - 60) : min(len(text), pos + len(value) + 60)]
        score = len(ctx_words & set(re.findall(word_pattern, window)))
        if score > best_score:  # strict ">" keeps the earliest position on ties
            best_score = score
            best_pos = pos
    if best_pos >= 0 and best_score >= 2:
        return best_pos, best_pos + len(value)

    return None


def write_jsonl(path: str | Path, rows: Iterable[dict]) -> None:
    """Serialize ``rows`` to ``path`` as JSON Lines, one object per line.

    Missing parent directories are created first. The file is opened in
    write mode with UTF-8 encoding, so existing content is replaced, and
    non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        for record in rows:
            encoded = json.dumps(record, ensure_ascii=False)
            handle.write(encoded + "\n")


def ensure_dir(path: str | Path) -> Path:
    """Create ``path`` (and any missing parents) and return it as a ``Path``.

    An already existing directory is not an error.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return directory


def is_int_string(text: str) -> bool:
    """Return True when ``text`` is entirely digits with an optional leading ``-``."""
    matched = re.fullmatch(r"-?\d+", text)
    return matched is not None


def is_decimal_string(text: str) -> bool:
    """Return True when ``text`` is ``digits.digits`` with an optional leading ``-``.

    Digits are required on both sides of the point, so ``"1."`` and ``".5"``
    are rejected.
    """
    matched = re.fullmatch(r"-?\d+\.\d+", text)
    return matched is not None


def is_fraction_string(text: str) -> bool:
    """Return True when ``text`` is ``int/int``; either side may carry a leading ``-``.

    No whitespace is allowed and a zero denominator is not rejected here.
    """
    matched = re.fullmatch(r"-?\d+/-?\d+", text)
    return matched is not None


def is_percent_string(text: str) -> bool:
    """Return True when ``text`` is an integer or decimal immediately followed by ``%``.

    A leading ``-`` is allowed; a decimal point must have digits on both sides.
    """
    matched = re.fullmatch(r"-?\d+(?:\.\d+)?%", text)
    return matched is not None


def _stable_rng(value: str, context: str = "") -> random.Random:
    """Build a reproducible ``random.Random`` keyed on ``context`` and ``value``.

    The seed is the first 16 hex digits of the SHA-256 of the UTF-8 encoded
    string ``"<context>:<value>"``, read as a base-16 integer.
    """
    key = "{}:{}".format(context, value)
    hex_digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
    seed = int(hex_digest[:16], 16)
    return random.Random(seed)


def _sign_preserving_direction(rng: random.Random, original, delta) -> int:
    """Choose +1 or -1 so that ``original + direction * delta`` keeps its sign.

    - ``original == 0``: always +1, so zero never turns negative.
    - ``delta`` at least as large as ``abs(original)``: the direction is
      forced away from zero (+1 for positives, -1 for negatives).
    - Otherwise either direction is safe and ``rng`` picks one.

    ``rng`` is consumed only in the last case, so forced outcomes do not
    advance the random stream.
    Why: counts, domain bounds, and cardinalities should not flip sign or
    collapse to zero as a side effect of a small perturbation.
    """
    if original == 0:
        return 1
    if delta >= abs(original):
        # Moving toward zero would reach or cross it; step away instead.
        return 1 if original > 0 else -1
    return rng.choice([1, -1])


def perturb_numeric_string(
    value: str,
    problem_name: str = "",
    source: str = "",
    role: str = "",
) -> Optional[str]:
    """Return a small, deterministic perturbation of a numeric string.

    ``value`` is stripped and then handled by the first matching form:

    - Percent (``"12%"``): the number before ``%`` is perturbed recursively
      and the ``%`` is re-attached.
    - Integer: the step is 1-3 for ``|n| <= 10``, 1-5 for ``|n| <= 100``,
      2-15 for ``|n| <= 1000``, otherwise 1 to ``|n| // 20``.
    - Decimal: the step is counted in units of the last written decimal
      place (1-3 units for ``|x| <= 1``, 1-5 for ``|x| <= 10``, otherwise
      1 to ``int(|x| / 5)``); the number of decimal places is preserved.
    - Fraction (``"a/b"``): only the numerator changes, by 1-3 for
      ``|a| <= 10`` and by 1 to ``|a| // 5`` otherwise. A zero denominator
      yields ``None``.

    Randomness comes from ``_stable_rng`` seeded with
    ``"problem_name:source:role"`` and the stripped value, so equal inputs
    give equal outputs. The direction is chosen by
    ``_sign_preserving_direction``, so the result never crosses or lands on
    zero and zero itself moves upward.

    Returns ``None`` when ``value`` matches none of the supported forms.
    """
    token = value.strip()
    rng = _stable_rng(token, f"{problem_name}:{source}:{role}")

    if is_percent_string(token):
        inner = perturb_numeric_string(token[:-1], problem_name, source, role)
        return f"{inner}%" if inner is not None else None

    if is_int_string(token):
        number = int(token)
        magnitude = abs(number)
        if magnitude > 1000:
            step = rng.randint(1, max(1, magnitude // 20))
        elif magnitude > 100:
            step = rng.randint(2, 15)
        elif magnitude > 10:
            step = rng.randint(1, 5)
        else:
            step = rng.randint(1, 3)
        sign = _sign_preserving_direction(rng, number, step)
        return str(number + sign * step)

    if is_decimal_string(token):
        try:
            number = Decimal(token)
        except InvalidOperation:
            return None
        places = len(token.split(".")[1])
        ulp = Decimal(1).scaleb(-places)  # one unit in the last written place
        magnitude = abs(number)
        if magnitude > Decimal("10"):
            steps = rng.randint(1, max(1, int(magnitude / Decimal("5"))))
        elif magnitude > Decimal("1"):
            steps = rng.randint(1, 5)
        else:
            steps = rng.randint(1, 3)
        # Compare in ulp units so the sign check uses the same scale as steps.
        units = int(number / ulp) if number != 0 else 0
        sign = _sign_preserving_direction(rng, units, steps)
        return format(number + sign * steps * ulp, f".{places}f")

    if is_fraction_string(token):
        top_text, bottom_text = token.split("/")
        try:
            top = int(top_text)
            bottom = int(bottom_text)
        except ValueError:
            return None
        if bottom == 0:
            return None
        magnitude = abs(top)
        upper = 3 if magnitude <= 10 else max(1, magnitude // 5)
        step = rng.randint(1, upper)
        sign = _sign_preserving_direction(rng, top, step)
        return f"{top + sign * step}/{bottom}"

    return None


def latex_fraction_patterns(value: str) -> List[str]:
    """Generate regex patterns for LaTeX variants of a fraction like '1/3'.

    Returns an empty list when ``value`` is not accepted by
    ``is_fraction_string``. Otherwise the patterns are, in order:

    - ``\\frac{num}{den}``
    - ``\\fracND`` (brace-less shorthand), only when numerator and
      denominator are each a single digit
    - ``\\dfrac{num}{den}``
    - ``\\tfrac{num}{den}``

    The shorthand is limited to single digits because an unbraced LaTeX
    argument is a single token: ``\\frac125`` means ``\\frac{1}{2}5``, so a
    shorthand pattern for ``12/5`` or ``1/25`` would match the wrong fraction.
    """
    if not is_fraction_string(value):
        return []
    raw_num, raw_den = value.split("/")
    num, den = re.escape(raw_num), re.escape(raw_den)
    braced = r"\{" + num + r"\}\{" + den + r"\}"

    patterns = [r"\\frac" + braced]  # \frac{1}{3}
    if len(raw_num) == 1 and len(raw_den) == 1:
        # Length 1 rules out a sign or a second digit.
        patterns.append(r"\\frac" + num + den)  # \frac13
    patterns.append(r"\\dfrac" + braced)  # \dfrac{1}{3}
    patterns.append(r"\\tfrac" + braced)  # \tfrac{1}{3}
    return patterns


def replace_span(
    text: str,
    start: int,
    end: int,
    old_value: str,
    new_value: str,
) -> Optional[str]:
    """Replace a span; fall back to the nearest match if offsets are noisy.

    Resolution order:
      1. If ``text[start:end]`` is exactly ``old_value``, replace that span.
      2. Otherwise replace the literal occurrence of ``old_value`` whose
         start is closest to ``start``.
      3. Otherwise, when ``new_value`` is a fraction string, look for LaTeX
         spellings of ``old_value`` (from ``latex_fraction_patterns``) and
         replace the one closest to ``start``, keeping its command
         (``\\frac`` / ``\\dfrac`` / ``\\tfrac``). The brace-less shorthand
         is kept only when both new parts are single digits; otherwise
         braces are used so the new fraction is not misparsed.

    ``new_value`` is always used as given and never re-perturbed, so the
    edited text stays consistent with the caller's metadata.

    Returns the edited text, or ``None`` when ``old_value`` is empty or
    cannot be found.
    """
    if not old_value:
        # An empty pattern matches at every position; there is nothing to replace.
        return None

    if 0 <= start < end <= len(text) and text[start:end] == old_value:
        return text[:start] + new_value + text[end:]

    def distance(match):
        return abs(match.start() - start)

    # Try exact string match
    matches = list(re.finditer(re.escape(old_value), text))
    if matches:
        best = min(matches, key=distance)
        return text[: best.start()] + new_value + text[best.end() :]

    # Try LaTeX fraction variants (e.g., "1/3" matches "\frac{1}{3}")
    if is_fraction_string(new_value):
        new_num, new_den = new_value.split("/")
        latex_matches = [
            m
            for pattern in latex_fraction_patterns(old_value)
            for m in re.finditer(pattern, text)
        ]
        if latex_matches:
            best = min(latex_matches, key=distance)
            matched_text = best.group()
            if "\\dfrac" in matched_text:
                command = "\\dfrac"
            elif "\\tfrac" in matched_text:
                command = "\\tfrac"
            else:
                command = "\\frac"
            keep_shorthand = (
                "{" not in matched_text
                and len(new_num) == 1
                and len(new_den) == 1
            )
            if keep_shorthand:
                replacement = f"{command}{new_num}{new_den}"
            else:
                replacement = f"{command}{{{new_num}}}{{{new_den}}}"
            return text[: best.start()] + replacement + text[best.end() :]

    return None


