from __future__ import annotations

"""
Label/docstring policy helpers.

Current policy (pipelines rely on docstrings, not marker blocks):
- A declaration is considered "label-associated" if its docstring starts with the textbook label.
- Statement stage enforces: exactly one *main* label-associated declaration for the current item.
- Any newly introduced top-level declaration must have a docstring immediately above it.
- Proof stage may enforce that new helper declarations include the label in their docstrings.
"""

import difflib
import re

# Keywords that introduce a top-level declaration in the scanned (Lean-style) source.
# The line scanners in this module match a keyword as a line prefix followed by a space.
_DECL_KEYWORDS = (
    "def",
    "abbrev",
    "lemma",
    "theorem",
    "instance",
    "structure",
    "class",
    "inductive",
    "opaque",
)

# Identifier-like tokens: a letter or underscore followed by letters, digits,
# underscores, or primes ('). Dots are not part of a token.
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_']*")


def _find_docstring_start(lines: list[str], decl_line_idx: int) -> int | None:
    """
    Find the docstring `/-- ... -/` that should be associated with a declaration near `decl_line_idx`.
    Allow attribute lines (e.g. `@[simp]`) between the docstring and the declaration.
    """
    j = decl_line_idx - 1
    while j >= 0:
        s = lines[j].strip()
        if not s:
            j -= 1
            continue
        if s.startswith("@[") or s.startswith("attribute "):
            j -= 1
            continue
        break
    if j < 0:
        return None
    sj = lines[j].lstrip()
    if sj.startswith("/--"):
        return j

    # Multi-line docstring: the nearest non-empty line before the declaration can be
    # the closing `-/`. Walk upward to find a matching `/--` start.
    if "-/" in lines[j]:
        k = j
        while k >= 0:
            sk = lines[k].lstrip()
            if sk.startswith("/--"):
                return k
            # Hit another non-doc block-comment start before a docstring start.
            if sk.startswith("/-") and not sk.startswith("/--"):
                return None
            k -= 1
    return None


def _extract_docstring_text(lines: list[str], doc_start: int) -> str:
    parts: list[str] = []
    i = doc_start
    while i < len(lines):
        parts.append(lines[i])
        if "-/" in lines[i]:
            break
        i += 1
    return "\n".join(parts)


def _is_decl_line(line: str) -> bool:
    """
    Return True when `line` opens a declaration.

    Leading whitespace is ignored. Scoping commands (`section`, `end`, `namespace`,
    `open`, `variable`, `universe`) are rejected up front. Leading modifiers are then
    stripped, each at most once and in the listed order, before the remainder is
    checked against `_DECL_KEYWORDS` (keyword followed by a space).

    The modifier list matches the one used by `_decl_kind_from_line`, so that
    `unsafe def ...` and `partial def ...` are recognized as declarations too.
    """
    s = line.lstrip()
    if not s:
        return False
    if s.startswith(("section ", "end ", "namespace ", "open ", "variable ", "universe ")):
        return False
    for mod in ("private ", "protected ", "noncomputable ", "unsafe ", "partial "):
        if s.startswith(mod):
            s = s[len(mod) :].lstrip()
    return s.startswith(tuple(k + " " for k in _DECL_KEYWORDS))


def _docstring_starts_with_label(doc_text: str, label: str) -> bool:
    if not doc_text or not label:
        return False
    lines = doc_text.splitlines()
    if not lines:
        return False

    first_line = lines[0].lstrip()
    candidate = first_line
    if candidate.startswith("/--"):
        candidate = candidate[len("/--") :].lstrip()
        if not candidate:
            # Allow multiline docstrings where the label begins on the first non-empty line
            # after the opening `/--`.
            for ln in lines[1:]:
                s = ln.strip()
                if not s or s == "-/":
                    continue
                candidate = s
                break
    if not candidate.startswith(label):
        return False
    if len(candidate) == len(label):
        return True
    nxt = candidate[len(label)]
    return not nxt.isalnum()


def _decl_name_from_line(line: str) -> str | None:
    """
    Extract the declared name from a declaration line.

    Leading whitespace and modifiers are stripped first (the same modifier prefixes
    `_decl_kind_from_line` handles, so `partial def foo` yields `foo`). The name is
    the first whitespace-delimited word after the keyword, cut at the first of:
    - `.{`  explicit universe parameters (`foo.{u}` -> `foo`, no trailing dot),
    - `(`, `{`, `[`  a binder written directly against the name,
    - `:`  a type ascription written directly against the name.

    Returns:
        The name, or None when the line is not a declaration or declares nothing
        named (e.g. an anonymous `instance : Foo` or `instance [Foo a] : Bar`).
    """
    s = line.lstrip()
    for mod in ("private ", "protected ", "noncomputable ", "unsafe ", "partial "):
        if s.startswith(mod):
            s = s[len(mod) :].lstrip()
    for kw in _DECL_KEYWORDS:
        if s.startswith(kw + " "):
            rest = s[len(kw) :].lstrip()
            if not rest:
                return None
            name = rest.split()[0]
            # `.{` must be cut before `{`, otherwise `foo.{u}` would leave `foo.`.
            for sep in (".{", "(", "{", "[", ":"):
                name = name.split(sep, 1)[0]
            return name.strip() or None
    return None


def _decl_kind_from_line(line: str) -> str | None:
    """
    Return the declaration keyword (`def`, `theorem`, ...) that opens `line`.

    Leading whitespace and the modifiers `private`, `protected`, `noncomputable`,
    `unsafe`, `partial` (each at most once, in that order) are skipped. Returns None
    when what remains does not start with a keyword from `_DECL_KEYWORDS` followed
    by a space.
    """
    remainder = line.lstrip()
    for modifier in ("private ", "protected ", "noncomputable ", "unsafe ", "partial "):
        if remainder.startswith(modifier):
            remainder = remainder[len(modifier) :].lstrip()
    return next((kw for kw in _DECL_KEYWORDS if remainder.startswith(kw + " ")), None)


def _strip_doc_comments_for_token_scan(snippet: str) -> str:
    """
    Remove doc/comment lines from a snippet before lightweight token scanning.
    """
    out: list[str] = []
    in_doc = False
    for ln in snippet.splitlines():
        s = ln.strip()
        if in_doc:
            if "-/" in s:
                in_doc = False
            continue
        if s.startswith("/--"):
            if "-/" not in s:
                in_doc = True
            continue
        if s.startswith("--"):
            continue
        out.append(ln)
    return "\n".join(out)


def find_label_associated_declarations(*, text: str, label: str) -> list[dict[str, object]]:
    """
    Collect every declaration in `text` whose docstring starts with `label`.

    Each result is a dict with:
    - "decl_line": 1-based line number of the declaration,
    - "doc_line": 1-based line number where its docstring opens,
    - "decl_name": the declared name, or None if it could not be extracted,
    - "decl_head": the stripped declaration line.

    Declarations without a docstring above them are ignored.
    """
    all_lines = text.splitlines()
    matches: list[dict[str, object]] = []
    for idx, raw in enumerate(all_lines):
        if not _is_decl_line(raw):
            continue
        doc_idx = _find_docstring_start(all_lines, idx)
        if doc_idx is None:
            continue
        if _docstring_starts_with_label(_extract_docstring_text(all_lines, doc_idx), label):
            matches.append(
                {
                    "decl_line": idx + 1,
                    "doc_line": doc_idx + 1,
                    "decl_name": _decl_name_from_line(raw),
                    "decl_head": raw.strip(),
                }
            )
    return matches


def find_sorry_def_abbrev_declarations(*, text: str) -> list[dict[str, object]]:
    """
    List `def` / `abbrev` declarations whose block contains a `:= sorry` placeholder.

    A declaration's block runs from its own line up to (not including) the next
    declaration line, or to the end of the text for the last one. Declarations
    whose name cannot be extracted are skipped.

    Each result is a dict with "name", "kind" ("def" or "abbrev"), "decl_line"
    (1-based) and "decl_head" (the stripped declaration line).
    """
    lines = text.splitlines()
    decl_indices = [idx for idx, ln in enumerate(lines) if _is_decl_line(ln)]
    # Pair every declaration with the line index at which its block stops.
    boundaries = decl_indices[1:] + [len(lines)]

    found: list[dict[str, object]] = []
    for start, stop in zip(decl_indices, boundaries):
        head = lines[start]
        kind = _decl_kind_from_line(head)
        if kind not in ("def", "abbrev"):
            continue
        name = _decl_name_from_line(head)
        if not name:
            continue
        body = "\n".join(lines[start:stop])
        if re.search(r":=\s*sorry\b", body) is None:
            continue
        found.append(
            {
                "name": name,
                "kind": kind,
                "decl_line": start + 1,
                "decl_head": head.strip(),
            }
        )
    return found


def find_label_main_decl_key_def_placeholder_violations(
    *,
    text: str,
    label: str,
    max_lines: int = 220,
) -> list[str]:
    """
    For the main declaration associated with `label`, report referenced local `def`/`abbrev`
    placeholders (`:= sorry`) that make the statement semantically vacuous.

    The token scan covers the main declaration only: from its docstring up to the
    next declaration line, and never more than `max_lines` lines. Without that bound,
    every placeholder declared later inside the window would be reported as "used"
    merely because its own declaration line falls inside the scanned text.

    Args:
        text: Full file content.
        label: Textbook label identifying the main declaration.
        max_lines: Upper bound on the number of lines scanned from the docstring on.

    Returns:
        One message per placeholder whose name occurs as a token in the main
        declaration; an empty list if there is no main declaration or no match.
    """
    snippet, decl_info = extract_label_main_declaration_snippet(text=text, label=label, max_lines=max_lines)
    if not snippet:
        return []

    if decl_info:
        lines = text.splitlines()
        start = int(decl_info.get("snippet_start_line") or 1) - 1
        end = min(len(lines), start + int(decl_info.get("snippet_line_count") or 0))
        decl_idx = int(decl_info.get("decl_line") or 0) - 1
        if start <= decl_idx < end:
            # Cut the window at the first declaration following the main one.
            for k in range(decl_idx + 1, end):
                if _is_decl_line(lines[k]):
                    end = k
                    break
            snippet = "\n".join(lines[start:end])

    code_snippet = _strip_doc_comments_for_token_scan(snippet)
    snippet_tokens = set(_TOKEN_RE.findall(code_snippet))
    if not snippet_tokens:
        return []

    placeholders = find_sorry_def_abbrev_declarations(text=text)
    violations: list[str] = []
    for d in placeholders:
        name = str(d.get("name") or "")
        if name and name in snippet_tokens:
            line = int(d.get("decl_line") or 0)
            kind = str(d.get("kind") or "def")
            violations.append(f"line {line}: `{kind} {name}` is `:= sorry` but is used by the main statement")
    return violations


def find_label_main_declaration_violations(*, text: str, label: str) -> list[str]:
    """
    Check that exactly one declaration in `text` is associated with `label`.

    Returns an empty list when the rule holds. Otherwise returns either a single
    "missing" message (no match), or a summary line followed by one line per
    matching declaration (more than one match).
    """
    decls = find_label_associated_declarations(text=text, label=label)
    count = len(decls)
    if count == 1:
        return []
    if count == 0:
        return [f"missing main declaration for label={label!r} (no docstring starts with the label)"]

    report: list[str] = [f"found {count} declarations associated with label={label!r} (expected exactly 1)"]
    for entry in decls:
        shown_name = entry.get("decl_name") or "<unknown>"
        report.append(f"line {entry['decl_line']}: {shown_name} ({entry['decl_head']})")
    return report


def find_new_decl_docstring_violations(
    *,
    before_text: str,
    after_text: str,
) -> list[str]:
    """
    Report newly introduced declarations that lack a docstring `/-- ... -/` above them.

    "Newly introduced" means the declaration line lies in a region of `after_text`
    that a line-level diff against `before_text` marks as inserted or replaced.
    Messages use 1-based line numbers of `after_text`, in ascending order.
    """
    old_lines = before_text.splitlines()
    new_lines = after_text.splitlines()
    matcher = difflib.SequenceMatcher(a=old_lines, b=new_lines)

    # Indices (into new_lines) of declaration lines inside inserted/replaced regions.
    new_decl_indices = sorted(
        {
            idx
            for op, _a1, _a2, b1, b2 in matcher.get_opcodes()
            if op in ("insert", "replace")
            for idx in range(b1, b2)
            if _is_decl_line(new_lines[idx])
        }
    )

    return [
        f"line {idx+1}: new declaration is missing a docstring `/-- ... -/` above it"
        for idx in new_decl_indices
        if _find_docstring_start(new_lines, idx) is None
    ]


def find_new_decl_docstring_label_violations(
    *,
    before_text: str,
    after_text: str,
    label: str,
) -> list[str]:
    """
    Report newly introduced declarations whose docstring is missing or omits `label`.

    A declaration counts as new when its line lies in a region of `after_text` that
    a line-level diff against `before_text` marks as inserted or replaced. For each
    such declaration, one message is produced if it has no docstring `/-- ... -/`
    above it, or else if `label` does not occur anywhere in that docstring's text.
    Messages use 1-based line numbers of `after_text`, in ascending order.
    """
    old_lines = before_text.splitlines()
    new_lines = after_text.splitlines()
    matcher = difflib.SequenceMatcher(a=old_lines, b=new_lines)

    new_decl_indices: set[int] = set()
    for op, _a1, _a2, b1, b2 in matcher.get_opcodes():
        if op not in ("insert", "replace"):
            continue
        new_decl_indices.update(idx for idx in range(b1, b2) if _is_decl_line(new_lines[idx]))

    problems: list[str] = []
    for idx in sorted(new_decl_indices):
        doc_idx = _find_docstring_start(new_lines, idx)
        if doc_idx is None:
            problems.append(f"line {idx+1}: new declaration is missing a docstring `/-- ... -/` above it")
        elif label not in _extract_docstring_text(new_lines, doc_idx):
            problems.append(f"line {idx+1}: new declaration's docstring does not include label={label!r}")
    return problems


def extract_label_main_declaration_snippet(
    *,
    text: str,
    label: str,
    context_before: int = 0,
    max_lines: int = 160,
) -> tuple[str | None, dict[str, object] | None]:
    """
    Best-effort extraction of a text window around the main declaration for `label`.

    The first declaration whose docstring starts with `label` is used. The window
    begins `context_before` lines above the docstring (clamped at the top of the
    file; negative values count as 0) and spans at most `max_lines` lines (at
    least 1). The window is a plain line range; it is not cut at the next
    declaration.

    Returns:
        `(None, None)` when no declaration is associated with `label`. Otherwise
        `(snippet, info)`, where `snippet` is the stripped window text (None if it
        is empty) and `info` is the declaration's dict extended with
        "associated_decl_count", "snippet_start_line" (1-based) and
        "snippet_line_count".
    """
    associated = find_label_associated_declarations(text=text, label=label)
    if not associated:
        return None, None

    main = associated[0]
    doc_idx = int(main.get("doc_line") or 1) - 1
    lead = max(0, int(context_before))
    window = max(1, int(max_lines))

    all_lines = text.splitlines()
    first = max(0, doc_idx - lead)
    last = min(len(all_lines), first + window)
    body = "\n".join(all_lines[first:last]).strip()

    info: dict[str, object] = {
        **main,
        "associated_decl_count": len(associated),
        "snippet_start_line": first + 1,
        "snippet_line_count": last - first,
    }
    return (body or None), info
