"""Phase 4: Independent audit of step_delete_unsound.jsonl.

Re-runs all quality checks on the final built data and reports pass/fail per metric.
Exits with status 1 if any file fails a target (or if no file is given).

Usage:
    python3 -m step_edit.audit_step \
        step_edit/data/minif2f/step_delete_unsound.jsonl \
        step_edit/data/math500/step_delete_unsound.jsonl
"""
import json
import re
import sys
from collections import Counter

# Reasoning texts that, once lower-cased and with punctuation/whitespace
# collapsed, consist of nothing but one of these connectives carry no
# mathematical content and are treated as vacuous.
_VACUOUS_SINGLE_WORDS = (
    "so", "thus", "hence", "therefore", "clearly", "obviously",
    "then", "next", "first", "second", "finally", "now", "here",
    "note", "notice", "observe", "also", "further", "moreover",
    "indeed", "similarly", "and", "but", "however", "wlog",
    "consequently", "trivially", "simplifying",
)
_VACUOUS_MULTI_WORD = (
    "we have", "we have that", "we get", "we obtain", "we see",
    "we see that", "we note", "we note that", "we know",
    "we know that", "we conclude",
    "in the diagram", "by this", "by that", "this gives us",
    "that means", "as a result", "it follows", "it follows that",
    "so that", "such that", "as before", "as above", "from this",
    "from the above", "by the above",
)
VACUOUS_PHRASES = set(_VACUOUS_SINGLE_WORDS) | set(_VACUOUS_MULTI_WORD)

# Outcome texts that open with one of these subordinating leads read as a
# fragment torn out of a longer sentence (case-insensitive, leading
# whitespace allowed).
_DANGLING_LEADS = (
    r"which\b", r"that\s+is\b", r"thereby\b", r"whereas\b", r"whereby\b",
    r"so\s+that\b", r"for\s+a\s+total\b", r"meaning\s+that\b",
    r"implying\s+that\b",
)
DANGLING_RE = re.compile(r"^\s*(" + "|".join(_DANGLING_LEADS) + ")", re.IGNORECASE)

# Outcome texts that open with "<demonstrative> <verb>" (e.g. "This implies")
# refer back to a step that may no longer be present.
_DEMONSTRATIVES = ("this", "that", "these", "those")
_LINKING_VERBS = (
    "is", "are", "was", "were", "gives", "means", "implies", "shows",
    "proves", "yields",
)
ANAPHORIC_RE = re.compile(
    r"^\s*(" + "|".join(_DEMONSTRATIVES) + r")\s+(" + "|".join(_LINKING_VERBS) + r")\b",
    re.IGNORECASE,
)


def unescape_only(s):
    """Collapse every escaped backslash pair into a single backslash.

    Falsy input (``None`` or ``""``) is returned unchanged.
    """
    if not s:
        return s
    return s.replace("\\\\", "\\")


# Audit thresholds, kept together so they are easy to review and tune.
_MIN_ROWS = 150  # a file must contain at least this many rows to pass
_MIN_CONFIDENCE = 0.9  # minimum step_edit_match_confidence per row
_SHRINK_MIN, _SHRINK_MAX = 0.005, 0.80  # accepted fractional length reduction
_DEGENERATE_MAX_LEN = 10  # only outcomes shorter than this are checked for degeneracy
_KNOWN_TIERS = ("gold", "silver", "bronze")

# Failure keys that must be exactly zero for a file to pass.
_HARD_ZERO = (
    "vacuous_reasoning", "degenerate_outcome", "dangling_fragment",
    "anaphoric_start", "edit_not_applied", "span_out_of_range",
    "outcome_not_at_span", "prefix_changed", "suffix_changed",
    "low_confidence_in_gold",
)


def _load_rows(path):
    """Parse a JSONL file into a list of dicts, skipping blank lines.

    Raises json.JSONDecodeError (message prefixed with ``path:line``) on a
    malformed record.
    """
    rows = []
    with open(path, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type, but point at the offending record.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return rows


def _check_row(r, reasoning, outcome, fails, bad_shrinks):
    """Run the per-row quality checks on one record.

    *reasoning* and *outcome* are the already-stripped target texts.
    Increments ``fails[<key>]`` once per failed check.  Shrinkage ratios
    outside [_SHRINK_MIN, _SHRINK_MAX] are also appended to *bad_shrinks*
    for reporting.
    """
    # Null-valued fields are treated like missing ones instead of crashing len().
    orig = r.get("original_informal_proof") or ""
    edited = r.get("informal_proof") or ""
    start = r.get("step_edit_span_start")
    end = r.get("step_edit_span_end")
    method = r.get("step_edit_matched_by", "?")

    # ── Hard zero targets ───────────────────────────────────────────────
    # 1. Vacuous reasoning: nothing left but a connective such as "thus".
    norm = re.sub(r"[\s.,;:!?]+", " ", reasoning.lower()).strip()
    if norm in VACUOUS_PHRASES:
        fails["vacuous_reasoning"] += 1

    # 2. Degenerate outcome: a short text that is empty once inline $...$
    #    math and TeX/punctuation characters are removed.
    if len(outcome) < _DEGENERATE_MAX_LEN:
        stripped = re.sub(r"\$[^$]*\$", "", outcome).strip()
        stripped = re.sub(r"[\\\{\}_\^.,;:!?\s]", "", stripped)
        if not stripped:
            fails["degenerate_outcome"] += 1

    # 3. Dangling fragment / anaphoric opening.
    if DANGLING_RE.match(outcome):
        fails["dangling_fragment"] += 1
    if ANAPHORIC_RE.match(outcome):
        fails["anaphoric_start"] += 1

    # 4. Edit applied.
    if orig == edited:
        fails["edit_not_applied"] += 1

    # 5. Span verification (exact/unescape matches only): the outcome must sit
    #    at *start* in the edited proof, with prefix and suffix untouched.
    if start is not None and end is not None and method in ("exact", "unescape"):
        if not (0 <= start <= end <= len(orig)):
            # Slicing with such a span yields meaningless comparisons.
            fails["span_out_of_range"] += 1
        else:
            inserted = unescape_only(outcome) if method == "unescape" else outcome
            ins_end = start + len(inserted)
            if edited[start:ins_end] != inserted:
                fails["outcome_not_at_span"] += 1
            if edited[:start] != orig[:start]:
                fails["prefix_changed"] += 1
            # If the insertion overruns the edited text, outcome_not_at_span
            # has already fired; skip the suffix comparison.
            if ins_end <= len(edited) and edited[ins_end:] != orig[end:]:
                fails["suffix_changed"] += 1

    # 6. Proof shortened by a plausible fraction.
    if orig:
        shrink = 1.0 - len(edited) / len(orig)
        if not (_SHRINK_MIN <= shrink <= _SHRINK_MAX):
            fails["bad_shrinkage"] += 1
            bad_shrinks.append(shrink)

    # 7. Match confidence.  A null value counts as 0 rather than raising.
    # NOTE(review): applied to every row, not only tier == "gold", despite
    # the key name — confirm this is intended.
    conf = r.get("step_edit_match_confidence") or 0
    if conf < _MIN_CONFIDENCE:
        fails["low_confidence_in_gold"] += 1


def audit(path: str) -> bool:
    """Audit one JSONL file of step edits and print a report to stdout.

    Returns True only when every hard-zero target has zero failures and
    the file holds at least _MIN_ROWS rows.  An empty file is reported and
    returns False.
    """
    rows = _load_rows(path)
    n = len(rows)
    if n == 0:
        print(f"\n=== {path}: EMPTY (0 rows) ===")
        return False

    fails = Counter()
    bad_shrinks = []
    tier_counts = Counter()
    method_counts = Counter()
    is_last_count = 0
    starts_upper = 0
    reasoning_lens = []
    outcome_lens = []

    for r in rows:
        reasoning = (r.get("step_edit_target_reasoning_text") or "").strip()
        outcome = (r.get("step_edit_target_outcome_text") or "").strip()

        tier_counts[r.get("step_edit_tier", "?")] += 1
        method_counts[r.get("step_edit_matched_by", "?")] += 1
        if r.get("step_edit_is_last", False):
            is_last_count += 1
        if outcome[:1].isupper():
            starts_upper += 1
        reasoning_lens.append(len(reasoning))
        outcome_lens.append(len(outcome))

        _check_row(r, reasoning, outcome, fails, bad_shrinks)

    # ── Report ──────────────────────────────────────────────────────────
    print(f"\n{'='*60}")
    print(f"  AUDIT: {path}")
    print(f"{'='*60}")
    print(f"Total rows: {n}")
    print("\nTier distribution:")
    for t in _KNOWN_TIERS:
        c = tier_counts.get(t, 0)
        print(f"  {t}: {c} ({100*c/n:.0f}%)")
    # Surface tiers outside the expected set instead of silently dropping them.
    for t, c in tier_counts.items():
        if t not in _KNOWN_TIERS:
            print(f"  {t} (unexpected): {c} ({100*c/n:.0f}%)")
    print("\nMatch methods:")
    for m, c in method_counts.most_common():
        print(f"  {m}: {c}")
    print(f"\nis_last=true: {is_last_count} ({100*is_last_count/n:.0f}%)")
    reasoning_lens.sort()
    outcome_lens.sort()
    print(f"Reasoning length: median={reasoning_lens[n//2]} p10={reasoning_lens[n//10]} min={reasoning_lens[0]}")
    print(f"Outcome length:   median={outcome_lens[n//2]} p10={outcome_lens[n//10]} min={outcome_lens[0]}")

    print("\n--- HARD-ZERO TARGETS ---")
    all_pass = True
    for key in _HARD_ZERO:
        c = fails.get(key, 0)
        status = "✅ PASS" if c == 0 else f"❌ FAIL ({c})"
        print(f"  {key}: {status}")
        if c > 0:
            all_pass = False

    # Informational failures (e.g. shrinkage) that do not affect the verdict.
    other_fails = {k: v for k, v in fails.items() if k not in _HARD_ZERO}
    if other_fails:
        print("\n--- OTHER ISSUES ---")
        for key, count in sorted(other_fails.items()):
            detail = ""
            if key == "bad_shrinkage" and bad_shrinks:
                detail = f" (observed {min(bad_shrinks):.1%} to {max(bad_shrinks):.1%})"
            print(f"  {key}: {count}{detail}")

    print("\n--- SOFT QUALITY ---")
    print(f"  outcome starts uppercase: {starts_upper}/{n} ({100*starts_upper/n:.0f}%)")

    if n >= _MIN_ROWS:
        print(f"\n  Total rows ≥ {_MIN_ROWS}: ✅ PASS")
    else:
        print(f"\n  Total rows ≥ {_MIN_ROWS}: ❌ FAIL ({n})")
        all_pass = False

    overall = "✅ ALL TARGETS MET" if all_pass else "❌ SOME TARGETS FAILED"
    print(f"\n{'='*60}")
    print(f"  OVERALL: {overall}")
    print(f"{'='*60}")
    return all_pass


def main():
    """Command-line entry point: audit every JSONL path given as an argument.

    All files are audited even after one fails.  Exits with status 0 when
    every file passes and 1 otherwise, including when no path was given.
    """
    paths = sys.argv[1:]
    if not paths:
        print("Usage: python -m step_edit.audit_step <file1.jsonl> [file2.jsonl ...]")
        sys.exit(1)

    # Materialise the full list first so no audit is short-circuited.
    results = [audit(p) for p in paths]
    sys.exit(0 if all(results) else 1)


if __name__ == "__main__":
    main()
