#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import re
import statistics
import tarfile
import tempfile
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path


# Matches a line that opens a theorem or lemma. Group 1 is the keyword and
# group 2 the identifier that follows it. No leading modifier (such as
# `private`) is accepted in front of the keyword.
# NOTE(review): the identifier class has no `.` and the pattern ends in `\b`,
# so a dotted name is captured only up to its first `.`, and a trailing `'`
# that is followed by a non-word character is left out of the captured name.
DECL_RE = re.compile(r"^\s*(theorem|lemma)\s+([A-Za-z_][A-Za-z0-9_']*)\b")
# Matches a line that opens any of the listed declaration commands, optionally
# preceded by exactly one of `private`, `noncomputable` or `protected`.
# parse_declarations uses these lines as the points where a span ends.
TOP_RE = re.compile(
    r"^\s*(?:private\s+|noncomputable\s+|protected\s+)?"
    r"(theorem|lemma|def|abbrev|instance|structure|class|inductive|axiom|constant|example)\b"
)
# Matches an `import` line; group 1 is the text after the keyword with the
# surrounding whitespace trimmed.
IMPORT_RE = re.compile(r"^\s*import\s+(.+?)\s*$")

# Ordered (chain label, path prefixes) rules consulted by chain_for: the first
# rule with a prefix that the root-relative POSIX path starts with wins, and
# files matching no rule are labelled "Other". Prefixes ending in "/" select a
# directory; "FormalSLT/PACBayes" has no trailing slash, so it matches every
# relative path that begins with that text.
CHAIN_RULES: list[tuple[str, tuple[str, ...]]] = [
    ("Risk/ERM spine", ("FormalSLT.lean", "FormalSLT/Risk.lean", "FormalSLT/ERM.lean", "FormalSLT/GhostSample.lean", "FormalSLT/UniformConvergence.lean")),
    ("Probability", ("FormalSLT/Probability/",)),
    ("Rademacher", ("FormalSLT/Rademacher/",)),
    ("PAC-VC", ("FormalSLT/VC/",)),
    ("PAC-Bayes", ("FormalSLT/PACBayes",)),
    ("Stability", ("FormalSLT/AlgorithmicStability.lean", "FormalSLT/Stability/")),
    ("Azuma", ("FormalSLT/Azuma/",)),
    ("Covering", ("FormalSLT/Covering/",)),
]

# Declaration names that analyze never counts as internal-reference targets,
# on top of its minimum-length filter on names.
EXCLUDED_REFERENCE_NAMES = {"symm", "mono", "left", "right", "zero", "one", "add", "mul", "sub", "pow"}


@dataclass(frozen=True)
class Declaration:
    """One theorem or lemma found in a Lean source file, with its line span."""

    # Keyword that opened the declaration: "theorem" or "lemma".
    kind: str
    # Identifier captured after the keyword.
    name: str
    # POSIX-style path of the containing file, relative to the source root.
    file: str
    # Chain label assigned to the containing file by chain_for.
    chain: str
    # 1-based line number of the line that opens the declaration.
    start_line: int
    # 1-based line number of the last line in the span, i.e. the line just
    # before the next top-level declaration (or the last line of the file).
    end_line: int
    # Number of lines in the span; never less than 1.
    span_lines: int
    # Comment-stripped text of the span, lines joined with "\n".
    body: str


def strip_block_comments(lines: list[str]) -> list[str]:
    """Remove Lean comments from source lines, keeping the line count unchanged.

    Block comments (``/- ... -/``, which also covers ``/-- ... -/`` and
    ``/-! ... -/``) may span several lines and may be nested; text is only
    kept again once every opened comment has been closed. A ``--`` outside
    any block comment discards the remainder of that line.

    String literals are not recognised, so comment markers that appear inside
    a Lean string are treated as comment delimiters as well.

    Args:
        lines: Source lines without their line terminators.

    Returns:
        A list of the same length as ``lines`` where each entry is the input
        line with all commented characters removed (possibly the empty string).
    """
    cleaned: list[str] = []
    depth = 0  # number of currently open block comments; carried across lines
    for line in lines:
        kept: list[str] = []
        pos = 0
        length = len(line)
        while pos < length:
            if line.startswith("/-", pos):
                # An opener counts at every depth so that nested comments are
                # balanced instead of ending at the first closer.
                depth += 1
                pos += 2
            elif depth > 0:
                if line.startswith("-/", pos):
                    depth -= 1
                    pos += 2
                else:
                    pos += 1  # commented-out character
            elif line.startswith("--", pos):
                break  # line comment: ignore the rest of this line
            else:
                kept.append(line[pos])
                pos += 1
        cleaned.append("".join(kept))
    return cleaned


def chain_for(formalslt_root: Path, path: Path) -> str:
    """Return the chain label for a source file.

    The file's POSIX path relative to ``formalslt_root`` is compared against
    the prefixes in ``CHAIN_RULES`` in order; the label of the first rule with
    a matching prefix is returned, or ``"Other"`` when no rule matches.

    Raises:
        ValueError: If ``path`` is not located under ``formalslt_root``.
    """
    relative = path.relative_to(formalslt_root).as_posix()
    # str.startswith accepts a tuple of prefixes, and an exact match is also a
    # prefix match, so one call per rule covers every case.
    matching_labels = (label for label, prefixes in CHAIN_RULES if relative.startswith(prefixes))
    return next(matching_labels, "Other")


def lean_files(formalslt_root: Path) -> list[Path]:
    """List the Lean files to analyse under ``formalslt_root``.

    The result starts with the top-level ``FormalSLT.lean`` file, followed by
    every ``*.lean`` path found recursively below the ``FormalSLT`` directory
    in sorted order. Paths that do not exist are left out.
    """
    umbrella = formalslt_root / "FormalSLT.lean"
    nested = sorted((formalslt_root / "FormalSLT").rglob("*.lean"))
    return [candidate for candidate in (umbrella, *nested) if candidate.exists()]


def parse_declarations(formalslt_root: Path, path: Path) -> list[Declaration]:
    """Extract every theorem/lemma declaration from one Lean source file.

    Comments are removed with ``strip_block_comments`` before matching, so
    commented-out declarations are not reported. A declaration starts on a
    line matching ``DECL_RE`` and its span runs up to, but not including, the
    next later line matching ``TOP_RE``, or to the end of the file.

    Args:
        formalslt_root: Source root; recorded file paths are relative to it.
        path: Lean file to scan. Must be located under ``formalslt_root``.

    Returns:
        The declarations of the file in source order.

    Raises:
        ValueError: If ``path`` is not located under ``formalslt_root``.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Lean sources are UTF-8 and full of non-ASCII symbols; decode explicitly
    # instead of relying on the platform's locale encoding.
    lines = strip_block_comments(path.read_text(encoding="utf-8").splitlines())
    line_count = len(lines)

    # Identical for every declaration of this file, so computed once.
    rel = path.relative_to(formalslt_root).as_posix()
    chain = chain_for(formalslt_root, path)

    # Ascending line indices where some top-level declaration begins.
    top_starts = [idx for idx, line in enumerate(lines) if TOP_RE.match(line)]

    declarations: list[Declaration] = []
    cursor = 0  # index into top_starts; only ever moves forward
    for idx, line in enumerate(lines):
        match = DECL_RE.match(line)
        if match is None:
            continue
        # Advance to the first top-level start strictly after this line. Both
        # sequences are ascending, so the whole scan stays linear.
        while cursor < len(top_starts) and top_starts[cursor] <= idx:
            cursor += 1
        end = top_starts[cursor] if cursor < len(top_starts) else line_count
        declarations.append(
            Declaration(
                kind=match.group(1),
                name=match.group(2),
                file=rel,
                chain=chain,
                start_line=idx + 1,
                # `end` is a 0-based exclusive index, which equals the 1-based
                # number of the last line inside the span.
                end_line=end,
                span_lines=max(1, end - idx),
                body="\n".join(lines[idx:end]),
            )
        )
    return declarations


def parse_imports(path: Path) -> tuple[int, int, int]:
    """Count the import lines of a Lean file by the prefix of their target.

    Comments are removed with ``strip_block_comments`` first, so imports that
    are commented out are not counted. Each line matching ``IMPORT_RE`` is
    classified by the text after the ``import`` keyword: targets starting
    with ``Mathlib``, targets starting with ``FormalSLT``, and all others.

    Args:
        path: Lean file to scan.

    Returns:
        A ``(mathlib, internal, other)`` tuple of import-line counts.

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    mathlib = 0
    internal = 0
    other = 0
    # Lean sources are UTF-8; decode explicitly rather than depending on the
    # platform's locale encoding.
    source_lines = path.read_text(encoding="utf-8").splitlines()
    for line in strip_block_comments(source_lines):
        match = IMPORT_RE.match(line)
        if match is None:
            continue
        target = match.group(1)
        if target.startswith("Mathlib"):
            mathlib += 1
        elif target.startswith("FormalSLT"):
            internal += 1
        else:
            other += 1
    return mathlib, internal, other


def analyze(formalslt_root: Path) -> dict[str, object]:
    """Compute per-chain and overall source-structure statistics for a tree.

    Every file returned by ``lean_files`` contributes its line and import
    counts to the row of its chain. Every theorem/lemma declaration adds to
    the declaration counts of its chain and is searched for whole-word
    occurrences of the names of other local declarations. Only names of at
    least five characters that are not in ``EXCLUDED_REFERENCE_NAMES`` are
    searched for.

    Args:
        formalslt_root: Root directory of the extracted source tree.

    Returns:
        A JSON-serialisable dict with the keys ``"totals"`` (summed counters
        plus ``theorem_lemma_decls`` and ``decls_with_local_refs_pct``, which
        is ``0.0`` when there are no declarations), ``"by_chain"`` (one
        counter row per chain, ordered by chain label),
        ``"top_internal_references"`` (up to 20 ``(name, count)`` pairs,
        highest count first, ties ordered by name) and ``"method_note"``.
    """
    files = lean_files(formalslt_root)
    declarations = [decl for path in files for decl in parse_declarations(formalslt_root, path)]

    # Sorted so that the scan below does not depend on set iteration order,
    # which changes between interpreter runs for strings.
    local_names = sorted(
        {
            decl.name
            for decl in declarations
            if len(decl.name) >= 5 and decl.name not in EXCLUDED_REFERENCE_NAMES
        }
    )
    # One compiled whole-word pattern per name, built once instead of once per
    # (declaration, name) pair.
    name_patterns = [(name, re.compile(rf"\b{re.escape(name)}\b")) for name in local_names]

    local_ref_counter: Counter[str] = Counter()
    chain_rows: dict[str, dict[str, int]] = defaultdict(
        lambda: {
            "files": 0,
            "source_lines": 0,
            "theorems": 0,
            "lemmas": 0,
            "mathlib_imports": 0,
            "internal_imports": 0,
            "other_imports": 0,
            "decl_span_lines": 0,
            "decls_with_local_refs": 0,
            "local_ref_occurrences": 0,
        }
    )

    for path in files:
        row = chain_rows[chain_for(formalslt_root, path)]
        mathlib_imports, internal_imports, other_imports = parse_imports(path)
        row["files"] += 1
        # Lean sources are UTF-8; decode explicitly rather than depending on
        # the platform's locale encoding.
        row["source_lines"] += len(path.read_text(encoding="utf-8").splitlines())
        row["mathlib_imports"] += mathlib_imports
        row["internal_imports"] += internal_imports
        row["other_imports"] += other_imports

    for decl in declarations:
        row = chain_rows[decl.chain]
        row[f"{decl.kind}s"] += 1
        row["decl_span_lines"] += decl.span_lines
        local_hits = 0
        for name, pattern in name_patterns:
            # A declaration does not reference itself. The substring test is a
            # cheap necessary condition that skips most regex scans.
            if name == decl.name or name not in decl.body:
                continue
            count = len(pattern.findall(decl.body))
            if count:
                local_ref_counter[name] += count
                local_hits += count
        if local_hits:
            row["decls_with_local_refs"] += 1
            row["local_ref_occurrences"] += local_hits

    by_chain = {chain: dict(row) for chain, row in sorted(chain_rows.items())}

    summed_keys = (
        "source_lines",
        "theorems",
        "lemmas",
        "mathlib_imports",
        "internal_imports",
        "decl_span_lines",
        "decls_with_local_refs",
        "local_ref_occurrences",
    )
    totals: dict[str, int | float] = {"files": len(files)}
    for key in summed_keys:
        totals[key] = sum(row[key] for row in by_chain.values())
    decl_total = totals["theorems"] + totals["lemmas"]
    totals["theorem_lemma_decls"] = decl_total
    # A tree without any theorem or lemma must not abort the whole report.
    totals["decls_with_local_refs_pct"] = (
        totals["decls_with_local_refs"] / decl_total if decl_total else 0.0
    )

    # Break ties by name so that the ranking is the same on every run.
    top_internal_references = sorted(
        local_ref_counter.items(), key=lambda item: (-item[1], item[0])
    )[:20]

    return {
        "totals": totals,
        "by_chain": by_chain,
        "top_internal_references": top_internal_references,
        "method_note": (
            "Static source proxy. Declaration spans run from each theorem/lemma line to the next top-level declaration. "
            "Internal-reference counts match local theorem/lemma names textually inside those spans, excluding short/common names. "
            "This is not a proof of mathematical novelty and does not measure Mathlib lemma calls directly."
        ),
    }


def extract_tarball(tarball: Path, destination: Path) -> Path:
    with tarfile.open(tarball, "r:gz") as tf:
        tf.extractall(destination, filter="data")
    candidates = [destination / "formalslt", destination / "formalslt_anonymized"]
    for candidate in candidates:
        if (candidate / "FormalSLT.lean").exists():
            return candidate
    for candidate in destination.iterdir():
        if candidate.is_dir() and (candidate / "FormalSLT.lean").exists():
            return candidate
    raise FileNotFoundError("could not find extracted FormalSLT root")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(description="Static FormalSLT source-structure proxy for supplementary review.")
    parser.add_argument("--source", type=Path, help="Path to extracted formalslt root.")
    parser.add_argument("--tarball", type=Path, default=Path("formalslt_anonymized.tar.gz"), help="Path to formalslt_anonymized.tar.gz.")
    parser.add_argument("--json-out", type=Path, default=Path("SOURCE_STRUCTURE_ANALYSIS.json"))
    return parser


def main() -> None:
    """Run the analysis from the command line.

    The source tree is taken from ``--source`` when given; otherwise
    ``--tarball`` is unpacked into a temporary directory that is removed once
    the analysis has finished. The full result is written as JSON to
    ``--json-out`` and a short summary is printed to standard output.
    """
    args = _build_arg_parser().parse_args()

    if args.source is None:
        # The extracted tree is only needed while analyze() runs.
        with tempfile.TemporaryDirectory() as scratch:
            result = analyze(extract_tarball(args.tarball, Path(scratch)))
    else:
        result = analyze(args.source)

    report = json.dumps(result, indent=2, sort_keys=True) + "\n"
    args.json_out.write_text(report)

    totals = result["totals"]
    summary = [
        f"Files: {totals['files']}",
        f"Source lines: {totals['source_lines']}",
        f"Theorem/lemma declarations: {totals['theorem_lemma_decls']}",
        f"Mathlib import lines: {totals['mathlib_imports']}",
        f"Internal FormalSLT import lines: {totals['internal_imports']}",
        f"Declarations with local refs: {totals['decls_with_local_refs']} ({totals['decls_with_local_refs_pct']:.1%})",
    ]
    for entry in summary:
        print(entry)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
