"""Phase 2: Tiered quality filter + sha256-seeded selection.

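Reads labeled_steps.jsonl, applies hard rejects + quality tiering to every step
marked deletable, and picks one candidate per problem: the top of a
deterministic ranking by tier, then reasoning/outcome quality, with a
sha256-derived tiebreak.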

Usage:
    python3 -m step_edit.clean_and_select \
        --input  step_edit/data/minif2f/labeled_steps.jsonl \
        --output step_edit/data/minif2f/selected_steps.jsonl
"""
from __future__ import annotations

import argparse
import hashlib
import json
import random
import re
import sys
from collections import Counter
from pathlib import Path

# ROOT is the parent of the directory that contains this file. It is
# prepended to sys.path at most once, presumably so that project packages
# resolve when this file is executed directly instead of via `python -m`
# (confirm against how the pipeline is launched).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# ── Configurable thresholds (CLI args) ──────────────────────────────────────

# Defaults for --min-reasoning-chars / --min-outcome-chars (wired up in main()).
DEFAULT_MIN_REASONING_CHARS = 30
DEFAULT_MIN_OUTCOME_CHARS = 10

# ── Blacklists ──────────────────────────────────────────────────────────────

# Connective / filler phrases. classify_step lowercases the reasoning text,
# collapses whitespace and .,;:!? punctuation to single spaces, and rejects
# the step ("vacuous_phrase") when the result is exactly one of these.
VACUOUS_PHRASES = {
    "so", "thus", "hence", "therefore", "clearly", "obviously",
    "then", "next", "first", "second", "finally", "now", "here",
    "note", "notice", "observe", "also", "further", "moreover",
    "indeed", "similarly", "and", "but", "however", "wlog",
    "consequently", "trivially", "simplifying",
    "we have", "we have that", "we get", "we obtain", "we see",
    "we see that", "we note", "we note that", "we know",
    "we know that", "we conclude",
    "in the diagram", "by this", "by that", "this gives us",
    "that means", "as a result", "it follows", "it follows that",
    "so that", "such that", "as before", "as above", "from this",
    "from the above", "by the above",
}

# Case-insensitive "does the reasoning describe a concrete action?" check.
# Each alternative is a word prefix; the trailing \w* absorbs any suffix, and
# the pattern is used with .search(), so a hit anywhere in the text counts.
# NOTE(review): short prefixes such as "as", "add", "let", "set" and "sum" also
# match unrelated words ("assume", "address", "letter", "settle", "summer"),
# and "by.*theorem" etc. can span arbitrary text within a line, so this is
# very permissive. Confirm that is intended before tightening it; any change
# alters which steps pass R1 and how candidates are ranked.
SUBSTANTIVE_VERB_RE = re.compile(
    r"\b(substitut|expand|factor|comput|derive|appl|rewrite|square|multipl|"
    r"divid|integrat|differentiat|sum|simplif|cancel|invert|transpos|"
    r"normaliz|equate|solve|rearrang|combine|isolate|reduce|evaluat|"
    r"subtract|add|multiply|square root|take the|let|set|define|"
    r"plug|insert|convert|transform|express|manipulat|"
    r"by.*theorem|by.*inequality|by.*lemma|by.*definition|"
    r"using|since|because|as)\w*\b",
    re.I,
)

# Outcome texts that begin with a subordinating connective ("which ...",
# "so that ...") read as fragments; classify_step rejects them with
# "dangling_fragment". Anchored at the start via ^ and used with .match().
DANGLING_FRAGMENT_RE = re.compile(
    r"^\s*(which\b|that\s+is\b|thereby\b|whereas\b|whereby\b|"
    r"so\s+that\b|for\s+a\s+total\b|meaning\s+that\b|implying\s+that\b)",
    re.I,
)
# Outcome texts that open with a demonstrative + verb ("This gives ...",
# "These are ...") presumably depend on an antecedent elsewhere in the proof;
# classify_step rejects them with "anaphoric_start".
ANAPHORIC_START_RE = re.compile(
    r"^\s*(this|that|these|those)\s+(is|are|was|were|gives|means|implies|shows|proves|yields)\b",
    re.I,
)


# ── Quality assessment ──────────────────────────────────────────────────────

def classify_step(step: dict, min_reasoning: int, min_outcome: int, total_steps: int
                  ) -> tuple[str, str]:
    """Classify one candidate step into a quality tier.

    Hard-reject checks run first, in order; the first failing check returns
    ``("reject", <reason>)``. A step that passes every check is tiered by the
    surface form of its outcome text.

    Args:
        step: Labeled step record. Reads ``reasoning_text``, ``outcome_text``
            and ``full_text``; missing or ``None`` values count as empty.
        min_reasoning: Reasoning shorter than this many characters must
            contain math (``$`` or ``\\(``) or a substantive verb to pass R1.
        min_outcome: Outcomes shorter than this many characters are rejected.
        total_steps: Number of steps in the whole proof. Proofs with at most
            two steps are demoted to bronze (R6).

    Returns:
        ``(tier, reason)`` where tier is one of ``"gold"``, ``"silver"``,
        ``"bronze"`` or ``"reject"``. ``reason`` is non-empty only for rejects.
    """
    reasoning = (step.get("reasoning_text") or "").strip()
    outcome = (step.get("outcome_text") or "").strip()
    full_text = (step.get("full_text") or "").strip()

    if not reasoning or not outcome or not full_text:
        return "reject", "empty_field"

    # ── R1: Substantive reasoning (HARD) ────────────────────────────────
    norm_reasoning = re.sub(r"[\s.,;:!?]+", " ", reasoning.lower()).strip()
    if norm_reasoning in VACUOUS_PHRASES:
        return "reject", f"vacuous_phrase:{norm_reasoning[:30]}"
    if len(reasoning) < min_reasoning:
        has_math = "$" in reasoning or "\\(" in reasoning
        has_action = bool(SUBSTANTIVE_VERB_RE.search(reasoning))
        if not (has_math or has_action):
            return "reject", f"short_no_math_no_action:{len(reasoning)}"

    # ── R3: Dangling fragment (HARD) ────────────────────────────────────
    if DANGLING_FRAGMENT_RE.match(outcome):
        return "reject", "dangling_fragment"
    if ANAPHORIC_START_RE.match(outcome):
        return "reject", "anaphoric_start"

    # ── P0.2: outcome must be a contiguous substring of full_text (HARD) ─
    # Catches the case where Gemini paraphrases outcome_text (e.g., adds
    # punctuation/capitalization not present in full_text). Try three levels
    # of tolerance: raw → unescape → whitespace-normalized. If none match, reject.
    full_ue = full_text.replace("\\\\", "\\")
    outcome_ue = outcome.replace("\\\\", "\\")
    outcome_in_full = (
        outcome in full_text
        or outcome_ue in full_ue
        or re.sub(r"\s+", " ", outcome_ue).strip() in re.sub(r"\s+", " ", full_ue).strip()
    )
    if not outcome_in_full:
        return "reject", "outcome_not_substring_of_full"

    # ── R4: Reconstruction (HARD) ───────────────────────────────────────
    # reasoning + outcome should reassemble (approximately) into full_text.
    norm_full = re.sub(r"\s+", " ", full_text.lower()).strip()
    norm_join = re.sub(r"\s+", " ", (reasoning + " " + outcome).lower()).strip()
    if norm_join not in norm_full and norm_full not in norm_join:
        if len(norm_join) > 10 and len(norm_full) > 10:
            from difflib import SequenceMatcher
            # autojunk=False is essential here. With the default heuristic,
            # once the second sequence has >= 200 items, every character that
            # occurs in more than ~1% of it (spaces, common letters) is treated
            # as junk. For natural-language text that drives ratio() far below
            # the true similarity and falsely rejects long steps.
            ratio = SequenceMatcher(None, norm_join, norm_full, autojunk=False).ratio()
            if ratio < 0.85:
                return "reject", f"split_inconsistent:ratio={ratio:.2f}"

    # ── R2 + R6: Outcome quality + step count → tier assignment ─────────
    if len(outcome) < min_outcome:
        return "reject", f"outcome_too_short:{len(outcome)}"

    # Check if outcome is a bare number/symbol (degenerate): strip inline
    # math, LaTeX commands, and punctuation; if nothing is left and the
    # outcome is short and has no relation symbol, reject it.
    stripped = re.sub(r"\$[^$]*\$", "", outcome).strip()
    stripped = re.sub(r"\\[a-zA-Z]+(\{[^}]*\})*", "", stripped).strip()
    stripped = re.sub(r"[\\\{\}_\^.,;:!?\s]", "", stripped)
    if not stripped and len(outcome) < 15 and "=" not in outcome and "<" not in outcome and ">" not in outcome:
        return "reject", "degenerate_outcome"

    # Tier assignment from surface features of the outcome.
    starts_upper = outcome[0].isupper() or outcome[0] == "\\"
    has_english = len([w for w in re.sub(r"\$[^$]*\$", "", outcome).split()
                       if w.isalpha() and len(w) >= 2]) >= 2
    ends_properly = outcome.rstrip().endswith((".", "$", ")"))
    has_equation = any(c in outcome for c in "=<>≤≥≠∈∀∃")

    if starts_upper and has_english and ends_properly:
        tier = "gold"
    elif has_equation or (starts_upper and ends_properly):
        tier = "silver"
    else:
        tier = "bronze"

    # R6 soft demotion: if proof has ≤ 2 steps, demote to bronze
    if total_steps <= 2:
        tier = "bronze"

    return tier, ""


def seeded_rng(problem_name: str) -> random.Random:
    """Return a ``random.Random`` deterministically seeded from *problem_name*.

    The seed is the first 64 bits (big-endian) of
    ``sha256("<problem_name>:step_v2")``, so the same name always yields the
    same generator state.
    """
    # The "step_v2" suffix is deliberately unchanged so that seeds match those
    # used to produce existing selected_steps.jsonl files.
    seed_material = f"{problem_name}:step_v2".encode("utf-8")
    leading_bytes = hashlib.sha256(seed_material).digest()[:8]
    return random.Random(int.from_bytes(leading_bytes, "big"))


def deterministic_sort_key(step: dict, problem_name: str, tier_rank: int) -> tuple:
    """Build a quality-ranked sort key for one candidate step (sort descending).

    The tuple is compared left to right, and every component is "bigger is
    better":

      1. ``tier_rank`` as passed in (caller uses 2=gold, 1=silver, 0=bronze).
      2. Reasoning substance, 0-4: +2 if the reasoning contains ``$`` or
         ``\\(``, and +2 if it matches ``SUBSTANTIVE_VERB_RE``.
      3. Outcome readability, 0-3: one point each for starting with an
         uppercase letter or backslash, having at least two alphabetic
         words of length >= 2 outside ``$...$``, and ending in ``.``, ``$``
         or ``)``.
      4. Length of the stripped reasoning text.
      5. A sha256-derived integer, so that ordering is fully deterministic
         even when all earlier components tie.

    is_last is intentionally NOT in the sort key — see docs/exp4_step_delete.md P0.1.
    """
    reasoning_text = (step.get("reasoning_text") or "").strip()
    outcome_text = (step.get("outcome_text") or "").strip()

    # Component 2: reasoning substance.
    substance = 0
    if "$" in reasoning_text or "\\(" in reasoning_text:
        substance += 2
    if SUBSTANTIVE_VERB_RE.search(reasoning_text):
        substance += 2

    # Component 3: outcome readability (one point per satisfied feature).
    prose_words = [
        word
        for word in re.sub(r"\$[^$]*\$", "", outcome_text).split()
        if word.isalpha() and len(word) >= 2
    ]
    readability = sum((
        bool(outcome_text) and (outcome_text[0].isupper() or outcome_text[0] == "\\"),
        len(prose_words) >= 2,
        outcome_text.rstrip().endswith((".", "$", ")")),
    ))

    # Component 5: first 64 bits (big-endian) of the sha256 digest.
    seed_material = f"{problem_name}:step_v2:{step.get('step_idx', -1)}".encode("utf-8")
    tiebreak = int.from_bytes(hashlib.sha256(seed_material).digest()[:8], "big")

    return (tier_rank, substance, readability, len(reasoning_text), tiebreak)


# ── Main ────────────────────────────────────────────────────────────────────

def main() -> None:
    """CLI entry point: filter labeled steps and select one per problem.

    Reads JSONL from ``--input`` (one problem per line, with ``problem_name``
    and ``steps``). Every step flagged ``deletable`` is classified with
    :func:`classify_step`, and the top passing candidate is chosen by
    :func:`deterministic_sort_key`. Writes one JSONL row per problem to
    ``--output`` and prints a summary of tiers and reject reasons.
    """
    parser = argparse.ArgumentParser(description="Phase 2: filter + select")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--min-reasoning-chars", type=int, default=DEFAULT_MIN_REASONING_CHARS)
    parser.add_argument("--min-outcome-chars", type=int, default=DEFAULT_MIN_OUTCOME_CHARS)
    args = parser.parse_args()

    # Read with an explicit UTF-8 encoding, matching the UTF-8 output written
    # below, instead of the platform locale default. `with` closes the handle.
    with open(args.input, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]

    # Tier → rank fed into the sort key (bigger is better). Loop-invariant.
    tier_rank = {"gold": 2, "silver": 1, "bronze": 0}

    reject_reasons = Counter()
    tier_counts = Counter()
    n_selected = 0
    n_no_candidate = 0

    out_rows = []
    for row in rows:
        name = row.get("problem_name", "?")
        steps = row.get("steps", []) or []
        total_steps = len(steps)

        # Classify every step that Gemini marked deletable; keep tier alongside
        # so the sort key can use it.
        classified_candidates = []  # list of (step, tier)
        for step in steps:
            if not step.get("deletable"):
                continue
            tier, reason = classify_step(step, args.min_reasoning_chars,
                                         args.min_outcome_chars, total_steps)
            if tier == "reject":
                reject_reasons[reason] += 1
            else:
                tier_counts[tier] += 1
                classified_candidates.append((step, tier))

        if not classified_candidates:
            n_no_candidate += 1
            out_rows.append({"problem_name": name, "selected_step": None})
            continue

        # P0.1: deterministic quality-ranked selection. max() returns the
        # first maximal candidate, which is exactly what a stable descending
        # sort followed by [0] returned. The sha256 tiebreak is the last key
        # component, so ordering is fully deterministic across reviewers.
        picked, picked_tier = max(
            classified_candidates,
            key=lambda sc: deterministic_sort_key(sc[0], name, tier_rank[sc[1]]),
        )
        n_selected += 1
        out_rows.append({
            "problem_name": name,
            "selected_step": picked,
            "selected_tier": picked_tier,
        })

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        for r in out_rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print("=== Phase 2 Summary ===")
    print(f"Input problems:       {len(rows)}")
    print(f"Selected (has pick):  {n_selected}")
    print(f"No candidate:         {n_no_candidate}")
    print("\nTier distribution (all passing candidates across all problems):")
    for t in ("gold", "silver", "bronze"):
        print(f"  {t}: {tier_counts[t]}")
    print("\nReject reasons:")
    for reason, count in reject_reasons.most_common():
        print(f"  {reason}: {count}")
    print(f"\nOutput: {args.output}")


if __name__ == "__main__":
    main()
