#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import math
import re
import statistics
import tarfile
import tempfile
from collections import Counter, defaultdict, deque
from dataclasses import dataclass
from pathlib import Path


# Line-leading `theorem` / `lemma` header; captures (kind, name).
# Lean identifiers may end in apostrophes (`foo'`). A plain trailing `\b` can
# never sit after an apostrophe, so it forced the match to back off and report
# `foo` for `foo'`. The lookahead ends the name only where no further
# identifier character (word character or apostrophe) follows.
# NOTE: modifiers such as `private` and dotted names (`Foo.bar`) are not
# captured here; for `Foo.bar` only the leading `Foo` component is returned.
DECL_RE = re.compile(r"^\s*(theorem|lemma)\s+([A-Za-z_][A-Za-z0-9_']*)(?![\w'])")

# Any line that starts a new top-level command; used to delimit where the
# preceding declaration's span ends. Modifiers may be stacked
# (e.g. `private noncomputable def`), so zero or more are accepted.
TOP_RE = re.compile(
    r"^\s*(?:(?:private|noncomputable|protected)\s+)*"
    r"(theorem|lemma|def|abbrev|instance|structure|class|inductive|axiom|constant|example)\b"
)

# `import <target>` line; captures the target with surrounding whitespace trimmed.
IMPORT_RE = re.compile(r"^\s*import\s+(.+?)\s*$")

# Files that make up the "Risk/ERM spine" chain, as POSIX paths relative to
# the snapshot root.
_RISK_ERM_SPINE_PREFIXES: tuple[str, ...] = (
    "FormalSLT.lean",
    "FormalSLT/Risk.lean",
    "FormalSLT/ERM.lean",
    "FormalSLT/GhostSample.lean",
    "FormalSLT/UniformConvergence.lean",
)

# Ordered (chain label, path prefixes) rules. A file belongs to the first rule
# one of whose prefixes starts its root-relative POSIX path, so order matters.
CHAIN_RULES: list[tuple[str, tuple[str, ...]]] = [
    ("Risk/ERM spine", _RISK_ERM_SPINE_PREFIXES),
    ("Probability", ("FormalSLT/Probability/",)),
    ("Rademacher", ("FormalSLT/Rademacher/",)),
    ("PAC-VC", ("FormalSLT/VC/",)),
    ("PAC-Bayes", ("FormalSLT/PACBayes",)),
    ("Stability", ("FormalSLT/AlgorithmicStability.lean", "FormalSLT/Stability/")),
    ("Azuma", ("FormalSLT/Azuma/",)),
    ("Covering", ("FormalSLT/Covering/",)),
]

# Generic declaration names that are never used as reference targets because
# they are too likely to appear in unrelated text.
EXCLUDED_REFERENCE_NAMES = {
    "add", "left", "mono", "mul", "one",
    "pow", "right", "sub", "symm", "zero",
}


@dataclass(frozen=True)
class Declaration:
    """One theorem/lemma found by the static scan; a node of the reference graph."""

    # Unique node key of the form "<relative file>:<1-based start line>:<name>".
    id: str
    # Declaration keyword captured by DECL_RE: "theorem" or "lemma".
    kind: str
    # Declaration name as captured by DECL_RE.
    name: str
    # POSIX path of the source file, relative to the snapshot root.
    file: str
    # Chain label assigned by chain_for (a CHAIN_RULES label or "Other").
    chain: str
    # 1-based line number of the declaration header.
    start_line: int
    # 1-based last line of the span: the line before the next top-level
    # declaration, or the last line of the file.
    end_line: int
    # Number of lines in the span (at least 1).
    span_lines: int
    # Comment-stripped text of the span; this is what gets scanned for references.
    body: str


def strip_block_comments(lines: list[str]) -> list[str]:
    """Remove Lean comments from *lines*, keeping one output line per input line.

    Block comments (``/- ... -/``, including ``/-- ... -/`` doc comments) may
    span several lines and may be nested; the nesting depth is carried across
    lines. Outside a block comment, ``--`` discards the rest of the line.
    Line numbering is preserved: a line that is entirely comment becomes ``""``.

    The scan is purely lexical and does not recognise string literals, so
    comment markers inside a string are treated as comments.

    Args:
        lines: Source lines without trailing newlines.

    Returns:
        A list of the same length with comment text removed.
    """
    cleaned: list[str] = []
    depth = 0  # current block-comment nesting level, persists across lines
    for line in lines:
        i = 0
        out: list[str] = []
        while i < len(line):
            if line.startswith("/-", i):
                # Lean block comments nest, so an opener counts at any depth.
                depth += 1
                i += 2
            elif depth > 0:
                if line.startswith("-/", i):
                    depth -= 1
                    i += 2
                else:
                    i += 1  # comment text: drop it
            elif line.startswith("--", i):
                break  # line comment: ignore the remainder of this line
            else:
                out.append(line[i])
                i += 1
        cleaned.append("".join(out))
    return cleaned


def chain_for(formalslt_root: Path, path: Path) -> str:
    """Return the chain label for *path*, or ``"Other"`` when no rule applies.

    The path is made relative to *formalslt_root* and compared, in POSIX form,
    against each rule's prefixes in CHAIN_RULES order; the first rule with a
    matching prefix wins.
    """
    relative = path.relative_to(formalslt_root).as_posix()
    # str.startswith accepts a tuple of alternatives; an exact match is also a prefix match.
    matching_labels = (label for label, prefixes in CHAIN_RULES if relative.startswith(prefixes))
    return next(matching_labels, "Other")


def lean_files(formalslt_root: Path) -> list[Path]:
    """List the Lean sources to scan under *formalslt_root*.

    The root ``FormalSLT.lean`` comes first, followed by every ``*.lean`` file
    found recursively under ``FormalSLT/`` in sorted order. Paths that do not
    exist are left out.
    """
    entry_point = formalslt_root / "FormalSLT.lean"
    library = sorted((formalslt_root / "FormalSLT").rglob("*.lean"))
    return [candidate for candidate in (entry_point, *library) if candidate.exists()]


def parse_declarations(formalslt_root: Path, path: Path) -> list[Declaration]:
    """Extract the theorem/lemma declarations of one Lean file.

    Comments are stripped first. Each line matching DECL_RE starts a
    declaration whose span runs up to (not including) the next line matching
    TOP_RE, or to the end of the file.

    Args:
        formalslt_root: Snapshot root; file paths and chains are relative to it.
        path: Lean source file to scan. It is read as UTF-8.

    Returns:
        One Declaration per matched header, in file order.

    Raises:
        ValueError: If *path* is not located under *formalslt_root*.
    """
    # Lean sources are Unicode-heavy; do not depend on the locale's default encoding.
    raw = path.read_text(encoding="utf-8").splitlines()
    cleaned = strip_block_comments(raw)

    starts: list[tuple[int, str, str]] = []  # (0-based line index, kind, name)
    top_starts: list[int] = []  # 0-based indices of every top-level command, ascending
    for idx, line in enumerate(cleaned):
        if TOP_RE.match(line):
            top_starts.append(idx)
        match = DECL_RE.match(line)
        if match:
            starts.append((idx, match.group(1), match.group(2)))

    rel = path.relative_to(formalslt_root).as_posix()
    chain = chain_for(formalslt_root, path)  # identical for every declaration in this file

    declarations: list[Declaration] = []
    cursor = 0  # position in top_starts; both lists ascend, so it only moves forward
    for idx, kind, name in starts:
        while cursor < len(top_starts) and top_starts[cursor] <= idx:
            cursor += 1
        # 0-based index of the next top-level command == 1-based last line of this span.
        end = top_starts[cursor] if cursor < len(top_starts) else len(raw)
        declarations.append(
            Declaration(
                id=f"{rel}:{idx + 1}:{name}",
                kind=kind,
                name=name,
                file=rel,
                chain=chain,
                start_line=idx + 1,
                end_line=end,
                span_lines=max(1, end - idx),
                body="\n".join(cleaned[idx:end]),
            )
        )
    return declarations


def parse_imports(path: Path) -> tuple[int, int, int]:
    """Count the import lines of a Lean file by target family.

    Comments are stripped before matching IMPORT_RE. A target is classified by
    its textual prefix: ``Mathlib``, then ``FormalSLT``, otherwise "other".

    Args:
        path: Lean source file to scan. It is read as UTF-8.

    Returns:
        ``(mathlib, internal, other)`` import counts.
    """
    mathlib = 0
    internal = 0
    other = 0
    # Lean sources are Unicode-heavy; do not depend on the locale's default encoding.
    source_lines = path.read_text(encoding="utf-8").splitlines()
    for line in strip_block_comments(source_lines):
        match = IMPORT_RE.match(line)
        if not match:
            continue
        target = match.group(1)
        if target.startswith("Mathlib"):
            mathlib += 1
        elif target.startswith("FormalSLT"):
            internal += 1
        else:
            other += 1
    return mathlib, internal, other


def extract_tarball(tarball: Path, destination: Path) -> Path:
    with tarfile.open(tarball, "r:gz") as tf:
        tf.extractall(destination, filter="data")
    candidates = [destination / "formalslt", destination / "formalslt_anonymized"]
    for candidate in candidates:
        if (candidate / "FormalSLT.lean").exists():
            return candidate
    for candidate in destination.iterdir():
        if candidate.is_dir() and (candidate / "FormalSLT.lean").exists():
            return candidate
    raise FileNotFoundError("could not find extracted FormalSLT root")


def build_reference_edges(declarations: list[Declaration]) -> tuple[set[tuple[str, str]], dict[str, object]]:
    """Find textual references between declarations.

    A declaration name is a reference target only if it is unique among
    *declarations*, at least five characters long, and not listed in
    EXCLUDED_REFERENCE_NAMES. An edge ``(source.id, target.id)`` is added when
    a target name occurs as a whole identifier in the source's body;
    self-references are skipped.

    Args:
        declarations: All declarations taking part in the graph.

    Returns:
        ``(edges, metadata)`` where *metadata* always holds the keys
        ``unique_target_names``, ``ambiguous_names`` and
        ``excluded_reference_names``.
    """
    name_counts = Counter(decl.name for decl in declarations)
    unique_targets = {
        decl.name: decl
        for decl in declarations
        if name_counts[decl.name] == 1
        and len(decl.name) >= 5
        and decl.name not in EXCLUDED_REFERENCE_NAMES
    }
    # Built before the early return so that both exits report the same keys
    # and the real number of ambiguous names.
    metadata: dict[str, object] = {
        "unique_target_names": len(unique_targets),
        "ambiguous_names": sum(1 for count in name_counts.values() if count > 1),
        "excluded_reference_names": sorted(EXCLUDED_REFERENCE_NAMES),
    }
    if not unique_targets:
        return set(), metadata

    # Longest names first so the alternation prefers the most specific target.
    alternation = "|".join(re.escape(name) for name in sorted(unique_targets, key=len, reverse=True))
    # An apostrophe continues a Lean identifier: `foo'` is not a reference to
    # `foo`, so the closing boundary rejects a following word character or
    # apostrophe. (A plain `\b` would accept `foo` inside `foo'`.)
    target_pattern = re.compile(r"\b(" + alternation + r")(?![\w'])")

    edges: set[tuple[str, str]] = set()
    for source in declarations:
        for match in target_pattern.finditer(source.body):
            target = unique_targets[match.group(1)]
            if target.id != source.id:
                edges.add((source.id, target.id))
    return edges, metadata


def weak_components(node_ids: list[str], edges: set[tuple[str, str]]) -> list[list[str]]:
    """Group nodes into weakly connected components, largest first.

    Edge direction is ignored. Every node in *node_ids* appears in exactly one
    component; isolated nodes form singleton components. Both endpoints of
    every edge must be present in *node_ids*.
    """
    neighbours: dict[str, set[str]] = {node_id: set() for node_id in node_ids}
    for left, right in edges:
        neighbours[left].add(right)
        neighbours[right].add(left)

    visited: set[str] = set()
    found: list[list[str]] = []
    for root in node_ids:
        if root in visited:
            continue
        visited.add(root)
        members: list[str] = []
        frontier = deque((root,))
        # Breadth-first sweep of everything reachable from this root.
        while frontier:
            here = frontier.popleft()
            members.append(here)
            fresh = [neighbour for neighbour in neighbours[here] if neighbour not in visited]
            visited.update(fresh)
            frontier.extend(fresh)
        found.append(members)
    return sorted(found, key=len, reverse=True)


def strongly_connected_components(node_ids: list[str], edges: set[tuple[str, str]]) -> list[list[str]]:
    """Return the strongly connected components of the directed graph.

    This is Tarjan's algorithm driven by an explicit work stack rather than
    recursion, so long reference paths cannot exceed the interpreter's
    recursion limit. Components are emitted in reverse topological order: a
    component appears only after every component reachable from it.

    Args:
        node_ids: All graph nodes; roots are tried in this order.
        edges: Directed ``(source, target)`` pairs; both ends must be in *node_ids*.

    Returns:
        A list of components, each a list of node ids.
    """
    adjacency: dict[str, list[str]] = {node_id: [] for node_id in node_ids}
    for source, target in edges:
        adjacency[source].append(target)

    next_index = 0
    stack: list[str] = []  # Tarjan's node stack
    on_stack: set[str] = set()
    indices: dict[str, int] = {}
    lowlinks: dict[str, int] = {}
    components: list[list[str]] = []

    for root in node_ids:
        if root in indices:
            continue
        indices[root] = lowlinks[root] = next_index
        next_index += 1
        stack.append(root)
        on_stack.add(root)
        # Each frame is (node, iterator over its not-yet-examined successors).
        work = [(root, iter(adjacency[root]))]
        while work:
            node_id, successors = work[-1]
            descended = False
            for target in successors:
                if target not in indices:
                    # Tree edge: open a frame for the child and resume this one later.
                    indices[target] = lowlinks[target] = next_index
                    next_index += 1
                    stack.append(target)
                    on_stack.add(target)
                    work.append((target, iter(adjacency[target])))
                    descended = True
                    break
                if target in on_stack:
                    lowlinks[node_id] = min(lowlinks[node_id], indices[target])
            if descended:
                continue

            # All successors handled: close the frame and report to the parent.
            work.pop()
            if work:
                parent = work[-1][0]
                lowlinks[parent] = min(lowlinks[parent], lowlinks[node_id])
            if lowlinks[node_id] == indices[node_id]:
                # node_id is the root of a component: pop it off the node stack.
                component: list[str] = []
                while True:
                    current = stack.pop()
                    on_stack.remove(current)
                    component.append(current)
                    if current == node_id:
                        break
                components.append(component)
    return components


def local_reference_depths(node_ids: list[str], edges: set[tuple[str, str]]) -> tuple[dict[str, int], int]:
    """Compute each node's reference depth after collapsing cycles.

    Strongly connected components are merged into single vertices. A merged
    vertex with no outgoing edges has depth 1; otherwise its depth is one more
    than the deepest vertex it points to. Every node takes the depth of its
    component.

    Returns:
        ``(node_depths, cyclic_components)`` where *cyclic_components* is the
        number of components containing more than one node.
    """
    sccs = strongly_connected_components(node_ids, edges)
    owner: dict[str, int] = {
        member: scc_index for scc_index, members in enumerate(sccs) for member in members
    }

    # Condensation graph: edges between distinct components only.
    condensed: dict[int, set[int]] = defaultdict(set)
    for source, target in edges:
        from_scc = owner[source]
        to_scc = owner[target]
        if from_scc != to_scc:
            condensed[from_scc].add(to_scc)

    cache: dict[int, int] = {}

    def longest_chain(scc_index: int) -> int:
        known = cache.get(scc_index)
        if known is not None:
            return known
        below = [longest_chain(dep) for dep in condensed.get(scc_index, ())]
        cache[scc_index] = 1 + max(below, default=0)
        return cache[scc_index]

    # Resolve components in index order first, then map nodes onto them.
    scc_depths = [longest_chain(scc_index) for scc_index in range(len(sccs))]
    node_depths = {node_id: scc_depths[owner[node_id]] for node_id in node_ids}
    cyclic_components = sum(len(members) > 1 for members in sccs)
    return node_depths, cyclic_components


def percentile(values: list[int], q: float) -> float:
    """Return the *q*-quantile of *values* using linear interpolation.

    *q* is a fraction (``0.9`` for the 90th percentile). The fractional rank
    ``(len(values) - 1) * q`` is interpolated between its two neighbouring
    order statistics. An empty input yields ``0.0``.
    """
    if not values:
        return 0.0
    ranked = sorted(values)
    position = (len(ranked) - 1) * q
    below, above = math.floor(position), math.ceil(position)
    if below != above:
        # Weight each neighbour by its closeness to the fractional rank.
        return ranked[below] * (above - position) + ranked[above] * (position - below)
    return float(ranked[below])


def analyze(formalslt_root: Path) -> dict[str, object]:
    """Run the full static local-reference analysis for one source snapshot.

    Parses every Lean file under *formalslt_root*, builds the textual
    reference graph between theorem/lemma declarations, and aggregates totals,
    per-chain statistics, a chain-to-chain edge matrix and top-20 rankings.

    Args:
        formalslt_root: Directory containing ``FormalSLT.lean`` and ``FormalSLT/``.

    Returns:
        A JSON-serialisable dict with the keys ``totals``, ``by_chain``,
        ``chain_order``, ``chain_local_reference_matrix``,
        ``top_textually_referenced_declarations``, ``top_depth_declarations``
        and ``method_note``.
    """
    files = lean_files(formalslt_root)
    declarations = [decl for path in files for decl in parse_declarations(formalslt_root, path)]
    declarations_by_id = {decl.id: decl for decl in declarations}
    node_ids = [decl.id for decl in declarations]
    edges, metadata = build_reference_edges(declarations)
    depths, cyclic_components = local_reference_depths(node_ids, edges)
    components = weak_components(node_ids, edges)
    # Number of declarations that sit in a component with at least one other declaration.
    non_isolated = sum(len(component) for component in components if len(component) > 1)

    inbound: Counter[str] = Counter()
    outbound: Counter[str] = Counter()
    chain_matrix: dict[str, Counter[str]] = defaultdict(Counter)
    # Iterate in sorted order: Counter.most_common breaks ties by insertion
    # order, and set iteration order is not stable between runs.
    for source_id, target_id in sorted(edges):
        source = declarations_by_id[source_id]
        target = declarations_by_id[target_id]
        inbound[target_id] += 1
        outbound[source_id] += 1
        chain_matrix[source.chain][target.chain] += 1

    chain_rows: dict[str, dict[str, object]] = defaultdict(
        lambda: {
            "declarations": 0,
            "outgoing_edges": 0,
            "incoming_edges": 0,
            "within_chain_edges": 0,
            "cross_chain_edges": 0,
            "decls_with_dependencies": 0,
            "decls_textually_referenced_by_others": 0,
            "depth_max": 0,
            "depth_median": 0.0,
            "depth_p90": 0.0,
        }
    )

    # Per-chain declaration and edge tallies.
    depths_by_chain: dict[str, list[int]] = defaultdict(list)
    for decl in declarations:
        row = chain_rows[decl.chain]
        row["declarations"] = int(row["declarations"]) + 1
        row["outgoing_edges"] = int(row["outgoing_edges"]) + outbound[decl.id]
        row["incoming_edges"] = int(row["incoming_edges"]) + inbound[decl.id]
        row["decls_with_dependencies"] = int(row["decls_with_dependencies"]) + int(outbound[decl.id] > 0)
        row["decls_textually_referenced_by_others"] = int(row["decls_textually_referenced_by_others"]) + int(inbound[decl.id] > 0)
        depths_by_chain[decl.chain].append(depths[decl.id])

    # Split each source chain's outgoing edges into within-chain and cross-chain.
    for source_chain, targets in chain_matrix.items():
        for target_chain, count in targets.items():
            bucket = "within_chain_edges" if source_chain == target_chain else "cross_chain_edges"
            chain_rows[source_chain][bucket] = int(chain_rows[source_chain][bucket]) + count

    for chain, values in depths_by_chain.items():
        row = chain_rows[chain]
        row["depth_max"] = max(values)
        row["depth_median"] = statistics.median(values)
        row["depth_p90"] = percentile(values, 0.9)

    # Rule order first, then any remaining chain (the "Other" fallback), so the
    # matrix and figures cover every chain that has declarations.
    chain_order = [chain for chain, _ in CHAIN_RULES if chain in chain_rows]
    chain_order.extend(sorted(chain for chain in chain_rows if chain not in chain_order))
    matrix_rows = [
        {
            "source_chain": source_chain,
            **{target_chain: chain_matrix[source_chain][target_chain] for target_chain in chain_order},
        }
        for source_chain in chain_order
    ]

    top_textually_referenced = [
        {
            "name": declarations_by_id[node_id].name,
            "kind": declarations_by_id[node_id].kind,
            "chain": declarations_by_id[node_id].chain,
            "file": declarations_by_id[node_id].file,
            "line": declarations_by_id[node_id].start_line,
            "incoming_reference_count": count,
        }
        for node_id, count in inbound.most_common(20)
    ]

    # Deepest first; ties broken by incoming references, then by span length.
    top_depth = sorted(declarations, key=lambda decl: (depths[decl.id], inbound[decl.id], decl.span_lines), reverse=True)[:20]

    return {
        "totals": {
            "declarations": len(declarations),
            "edges": len(edges),
            "declarations_with_textual_local_reference_edges": sum(1 for decl in declarations if outbound[decl.id] > 0),
            "declarations_textually_referenced_by_others": sum(1 for decl in declarations if inbound[decl.id] > 0),
            "non_isolated_declarations": non_isolated,
            "weak_components": len(components),
            "largest_weak_component": len(components[0]) if components else 0,
            "max_static_local_reference_depth": max(depths.values(), default=0),
            "median_static_local_reference_depth": statistics.median(depths.values()) if depths else 0,
            "p90_static_local_reference_depth": percentile(list(depths.values()), 0.9),
            "cyclic_sccs_from_static_scan": cyclic_components,
            **metadata,
        },
        "by_chain": {k: dict(v) for k, v in sorted(chain_rows.items())},
        "chain_order": chain_order,
        "chain_local_reference_matrix": matrix_rows,
        "top_textually_referenced_declarations": top_textually_referenced,
        "top_depth_declarations": [
            {
                "name": decl.name,
                "kind": decl.kind,
                "chain": decl.chain,
                "file": decl.file,
                "line": decl.start_line,
                "depth": depths[decl.id],
                "outgoing_reference_count": outbound[decl.id],
                "incoming_reference_count": inbound[decl.id],
            }
            for decl in top_depth
        ],
        "method_note": (
            "Static declaration local-reference graph. Nodes are comment-skipped line-leading theorem/lemma declarations. "
            "An edge A -> B is added when A's declaration span textually references B's unique local theorem/lemma name. "
            "Duplicate, short, and common names are ignored. The scan can miss references through duplicate names, notation, generated names, local aliases, defs, instances, or Mathlib declarations. "
            "It can also count names appearing in theorem statements rather than proof bodies. Raw chain-to-chain edge counts are not normalized by chain size. "
            "Depth is computed after collapsing static cycles. This is not Lean kernel dependency extraction and does not measure proof difficulty, mathematical novelty, or agent productivity."
        ),
    }


def write_matrix_csv(result: dict[str, object], output: Path) -> None:
    """Write the chain-to-chain edge matrix of *result* to *output* as CSV.

    The header is ``source_chain`` followed by the chains of
    ``result["chain_order"]``; each entry of
    ``result["chain_local_reference_matrix"]`` becomes one row.
    """
    columns = ["source_chain", *list(result["chain_order"])]
    matrix_rows = list(result["chain_local_reference_matrix"])
    with output.open("w", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        table.writerows(matrix_rows)


def write_markdown(result: dict[str, object], output: Path) -> None:
    """Render *result* as a Markdown report and write it to *output*.

    The report contains the regeneration command, a headline-count table, a
    per-chain table, the ten most referenced declarations, the ten deepest
    declarations, and the method note.
    """
    totals = result["totals"]
    by_chain = result["by_chain"]
    most_referenced = result["top_textually_referenced_declarations"][:10]
    deepest = result["top_depth_declarations"][:10]

    def table_row(cells: list[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    headline_counts = [
        ("Theorem/lemma declarations", str(totals["declarations"])),
        ("Static local-reference edges", str(totals["edges"])),
        (
            "Declarations with at least one textual local-reference edge",
            str(totals["declarations_with_textual_local_reference_edges"]),
        ),
        (
            "Declarations textually referenced by at least one other declaration",
            str(totals["declarations_textually_referenced_by_others"]),
        ),
        ("Non-isolated declarations", str(totals["non_isolated_declarations"])),
        ("Weak connected components", str(totals["weak_components"])),
        ("Largest weak component", str(totals["largest_weak_component"])),
        ("Maximum static local-reference depth", str(totals["max_static_local_reference_depth"])),
        ("Median static local-reference depth", f"{totals['median_static_local_reference_depth']:.1f}"),
        ("90th-percentile static local-reference depth", f"{totals['p90_static_local_reference_depth']:.1f}"),
    ]

    # Title, provenance note and the command that regenerates the report.
    report = [
        "# Static Declaration Local-Reference Graph",
        "",
        "This report is generated from the submitted FormalSLT source snapshot. It is a static source proxy, not Lean kernel dependency extraction.",
        "",
        "Command:",
        "",
        "```bash",
        "python3 scripts/analyze_local_reference_graph.py \\",
        "  --tarball formalslt_anonymized.tar.gz \\",
        "  --json-out LOCAL_REFERENCE_GRAPH_ANALYSIS.json \\",
        "  --markdown-out LOCAL_REFERENCE_GRAPH_ANALYSIS.md \\",
        "  --matrix-csv LOCAL_REFERENCE_CHAIN_MATRIX.csv \\",
        "  --fig-dir figures",
        "```",
        "",
        "## Headline Counts",
        "",
        "| Quantity | Value |",
        "|---|---:|",
    ]
    report += [table_row([label, value]) for label, value in headline_counts]

    report += [
        "",
        "## By Chain",
        "",
        "| Chain | Decls | Edges out | Edges in | Within-chain edges | Cross-chain edges | Decls with refs | Decls referenced | Max depth | Median depth |",
        "|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
    ]
    integer_columns = (
        "declarations",
        "outgoing_edges",
        "incoming_edges",
        "within_chain_edges",
        "cross_chain_edges",
        "decls_with_dependencies",
        "decls_textually_referenced_by_others",
        "depth_max",
    )
    for chain, stats in by_chain.items():
        cells = [chain]
        cells += [str(stats[column]) for column in integer_columns]
        cells.append(f"{stats['depth_median']:.1f}")
        report.append(table_row(cells))

    report += [
        "",
        "## Most Textually Referenced Local Declarations",
        "",
        "| Declaration | Chain | File:line | Incoming refs |",
        "|---|---|---|---:|",
    ]
    for entry in most_referenced:
        report.append(
            table_row(
                [
                    f"`{entry['name']}`",
                    f"{entry['chain']}",
                    f"`{entry['file']}:{entry['line']}`",
                    f"{entry['incoming_reference_count']}",
                ]
            )
        )

    report += [
        "",
        "## Deepest Static Local-Reference Spans",
        "",
        "| Declaration | Chain | File:line | Depth | Outgoing refs | Incoming refs |",
        "|---|---|---|---:|---:|---:|",
    ]
    for entry in deepest:
        report.append(
            table_row(
                [
                    f"`{entry['name']}`",
                    f"{entry['chain']}",
                    f"`{entry['file']}:{entry['line']}`",
                    f"{entry['depth']}",
                    f"{entry['outgoing_reference_count']}",
                    f"{entry['incoming_reference_count']}",
                ]
            )
        )

    report += [
        "",
        "## Method Boundary",
        "",
        str(result["method_note"]),
        "",
        "Use this report as evidence of source-level compositional structure. Do not use it as a theorem-difficulty metric, a novelty claim, an agent-productivity metric, or a substitute for the Lean build and axiom audit.",
    ]
    output.write_text("\n".join(report) + "\n")


def plot_figures(result: dict[str, object], fig_dir: Path) -> None:
    """Draw the two report figures into *fig_dir*, each as PDF and PNG.

    ``local_reference_chain_matrix`` is a heatmap of chain-to-chain edge
    counts with non-zero cells annotated. ``local_reference_depth_by_chain``
    shows, per chain, the maximum depth as a bar and the median as a dot.
    The directory is created if needed. matplotlib and numpy are imported
    lazily, so they are only required when figures are produced.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    fig_dir.mkdir(parents=True, exist_ok=True)
    chains = list(result["chain_order"])
    per_chain = result["by_chain"]

    def save_and_close(figure, stem: str) -> None:
        figure.savefig(fig_dir / f"{stem}.pdf")
        figure.savefig(fig_dir / f"{stem}.png", dpi=220)
        plt.close(figure)

    # Figure 1: chain-to-chain edge-count heatmap.
    counts = np.array(
        [[matrix_row[chain] for chain in chains] for matrix_row in result["chain_local_reference_matrix"]],
        dtype=float,
    )
    heat_fig, heat_ax = plt.subplots(figsize=(7.4, 5.1))
    image = heat_ax.imshow(counts, cmap="Blues")
    heat_ax.set_xticks(range(len(chains)), labels=chains, rotation=35, ha="right")
    heat_ax.set_yticks(range(len(chains)), labels=chains)
    heat_ax.set_xlabel("Referenced chain")
    heat_ax.set_ylabel("Source declaration chain")
    heat_ax.set_title("Static local-reference edges by chain")
    n_rows, n_cols = counts.shape[0], counts.shape[1]
    for row_index in range(n_rows):
        for col_index in range(n_cols):
            edge_count = int(counts[row_index, col_index])
            if edge_count:
                # Only non-zero cells are labelled, to keep the grid readable.
                heat_ax.text(col_index, row_index, str(edge_count), ha="center", va="center", fontsize=8)
    heat_fig.colorbar(image, ax=heat_ax, fraction=0.046, pad=0.04, label="edge count")
    heat_fig.tight_layout()
    save_and_close(heat_fig, "local_reference_chain_matrix")

    # Figure 2: per-chain maximum (bar) and median (dot) depth.
    max_depths = [float(per_chain[chain]["depth_max"]) for chain in chains]
    median_depths = [float(per_chain[chain]["depth_median"]) for chain in chains]
    positions = np.arange(len(chains))
    depth_fig, depth_ax = plt.subplots(figsize=(7.4, 4.7))
    depth_ax.barh(positions, max_depths, color="#2b6cb0", alpha=0.76, label="max depth")
    depth_ax.scatter(median_depths, positions, color="#111827", s=26, label="median depth", zorder=3)
    depth_ax.set_yticks(positions, labels=chains)
    depth_ax.invert_yaxis()
    depth_ax.set_xlabel("static local-reference depth")
    depth_ax.set_title("Depth of textual local theorem/lemma references by chain")
    depth_ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=2, frameon=False)
    depth_ax.grid(axis="x", color="#d1d5db", linewidth=0.6, alpha=0.8)
    depth_fig.tight_layout(rect=[0, 0.05, 1, 1])
    save_and_close(depth_fig, "local_reference_depth_by_chain")


def main() -> None:
    """Command-line entry point: analyse a snapshot and write all outputs.

    The snapshot is taken from ``--source`` when given; otherwise
    ``--tarball`` is unpacked into a temporary directory. The JSON, Markdown,
    CSV and figure outputs are then written and a short summary is printed.
    """
    parser = argparse.ArgumentParser(description="Static theorem/lemma local-reference graph for FormalSLT.")
    parser.add_argument("--source", type=Path, help="Path to extracted formalslt root.")
    path_options = (
        ("--tarball", "formalslt_anonymized.tar.gz"),
        ("--json-out", "LOCAL_REFERENCE_GRAPH_ANALYSIS.json"),
        ("--markdown-out", "LOCAL_REFERENCE_GRAPH_ANALYSIS.md"),
        ("--matrix-csv", "LOCAL_REFERENCE_CHAIN_MATRIX.csv"),
        ("--fig-dir", "figures"),
    )
    for flag, default_path in path_options:
        parser.add_argument(flag, type=Path, default=Path(default_path))
    args = parser.parse_args()

    if not args.source:
        # The extracted tree is only needed while analyze() runs.
        with tempfile.TemporaryDirectory() as scratch:
            result = analyze(extract_tarball(args.tarball, Path(scratch)))
    else:
        result = analyze(args.source)

    serialized = json.dumps(result, indent=2, sort_keys=True)
    args.json_out.write_text(serialized + "\n")
    write_markdown(result, args.markdown_out)
    write_matrix_csv(result, args.matrix_csv)
    plot_figures(result, args.fig_dir)

    totals = result["totals"]
    summary = (
        ("Declarations", "declarations"),
        ("Static local-reference edges", "edges"),
        ("Declarations with textual local-reference edges", "declarations_with_textual_local_reference_edges"),
        ("Declarations referenced by others", "declarations_textually_referenced_by_others"),
        ("Largest weak component", "largest_weak_component"),
        ("Max static local-reference depth", "max_static_local_reference_depth"),
    )
    for label, key in summary:
        print(f"{label}: {totals[key]}")


if __name__ == "__main__":
    main()
