"""Phase 3: Build step-delete unsound data with span-based verification.

Reads selected_steps.jsonl + original dataset.jsonl, produces:
  - step_delete_unsound.jsonl  (gold set: exact/unescape match)
  - step_delete_review.jsonl   (review pool: normalized/fuzzy match)

Usage:
    python3 -m step_edit.build_step_delete \
        --labels step_edit/data/minif2f/selected_steps.jsonl \
        --input  datasets_validation/minif2f/dataset.jsonl \
        --output_dir step_edit/data/minif2f
"""
from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path

# ROOT is the directory two levels above this file: parents[0] is the file's
# own directory and parents[1] is that directory's parent.
# NOTE(review): presumably the repository root, given the
# `python3 -m step_edit.build_step_delete` usage in the module docstring.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path, but only if it is not already listed, so
# repeated imports of this module do not add duplicate entries.
# NOTE(review): no ROOT-relative import is visible in this file — confirm this
# bootstrap is still needed.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))


def unescape_only(text: str) -> str:
    return text.replace("\\\\", "\\")


def normalize_both_sides(text: str) -> str:
    t = unescape_only(text)
    t = re.sub(r"\s+", " ", t).strip()
    return t


def find_span(proof: str, full_text: str) -> tuple[int, int, str, str]:
    """Find full_text in proof. Returns (start, end, method, inserted_outcome_basis).

    inserted_outcome_basis is either 'raw' or 'unescaped' — tells the caller how
    to prepare outcome_text for insertion.
    """
    # Strategy 1: Exact
    idx = proof.find(full_text)
    if idx >= 0:
        return idx, idx + len(full_text), "exact", "raw"

    # Strategy 2: Unescape only
    ft_ue = unescape_only(full_text)
    idx = proof.find(ft_ue)
    if idx >= 0:
        return idx, idx + len(ft_ue), "unescape", "unescaped"

    # Strategy 3: Normalized both sides
    proof_norm = normalize_both_sides(proof)
    ft_norm = normalize_both_sides(full_text)
    idx = proof_norm.find(ft_norm)
    if idx >= 0:
        return idx, idx + len(ft_norm), "normalized", "normalized"

    # Strategy 4: Fuzzy (first 30 + last 30 chars)
    if len(full_text) >= 60:
        head = full_text[:30]
        tail = full_text[-30:]
        hi = proof.find(head)
        ti = proof.rfind(tail)
        if hi >= 0 and ti >= 0 and ti > hi:
            return hi, ti + len(tail), "fuzzy", "raw"
        # Also try unescaped
        head_ue = unescape_only(head)
        tail_ue = unescape_only(tail)
        hi = proof.find(head_ue)
        ti = proof.rfind(tail_ue)
        if hi >= 0 and ti >= 0 and ti > hi:
            return hi, ti + len(tail_ue), "fuzzy", "unescaped"

    return -1, -1, "not_found", ""


def build_one(problem: dict, step: dict, tier: str) -> dict | None:
    """Build one step-delete row by replacing a step with its bare outcome.

    The step's ``full_text`` is located in the problem's ``informal_proof``
    via find_span, and that span is replaced by the step's ``outcome_text``
    (prepared according to the match method).

    Args:
        problem: Dataset row; ``informal_proof`` is read and every key is
            copied into the output row.
        step: Selected step; reads ``full_text``, ``outcome_text``,
            ``reasoning_text``, ``step_idx`` and ``is_last``.
        tier: Selection tier label, stored as ``step_edit_tier``.

    Returns:
        A copy of *problem* with the edited proof and ``step_edit_*``
        metadata, or None when the step cannot be applied safely: missing
        texts, outcome not contained in the step, span not found, span
        content not matching the step, or shrinkage outside 0.5%–80%.
    """
    proof = problem.get("informal_proof", "") or ""
    full_text = (step.get("full_text") or "").strip()
    raw_outcome = (step.get("outcome_text") or "").strip()
    reasoning = (step.get("reasoning_text") or "").strip()

    if not proof or not full_text or not raw_outcome:
        return None

    # P0.2: verbatim gate — outcome must be a contiguous substring of full_text
    # (with unescape + whitespace-normalize tolerance). Phase 2's classify_step
    # already enforces this; this is the defensive layer in case Phase 2 rules
    # are ever loosened. Keeps build_step_delete correct on its own terms.
    outcome_in_full = (
        raw_outcome in full_text
        or unescape_only(raw_outcome) in unescape_only(full_text)
        or normalize_both_sides(raw_outcome) in normalize_both_sides(full_text)
    )
    if not outcome_in_full:
        return None

    start, end, method, outcome_basis = find_span(proof, full_text)
    if method == "not_found":
        return None

    # ── Span verification (before editing) ──────────────────────────────
    # The span must be a real, non-empty range of the original proof and, for
    # the non-fuzzy methods, must actually contain the step under the
    # tolerance that method claims. This rejects offsets that do not address
    # the original proof instead of silently cutting the wrong text.
    if not (0 <= start < end <= len(proof)):
        return None
    span_text = proof[start:end]
    if method == "exact" and span_text != full_text:
        return None
    if method == "unescape" and span_text != unescape_only(full_text):
        return None
    if method == "normalized" and (
        normalize_both_sides(span_text) != normalize_both_sides(full_text)
    ):
        return None

    # Prepare inserted_outcome_text based on match method
    if outcome_basis == "raw":
        inserted = raw_outcome
    elif outcome_basis == "unescaped":
        inserted = unescape_only(raw_outcome)
    else:
        inserted = normalize_both_sides(raw_outcome)

    # Perform the replacement
    edited_proof = proof[:start] + inserted + proof[end:]

    # ── Span-based post-build verification ──────────────────────────────
    # 1. Edit applied
    if edited_proof == proof:
        return None

    # 2. Span replaced correctly
    actual = edited_proof[start:start + len(inserted)]
    if actual != inserted:
        return None

    # 3. Content outside span unchanged
    if edited_proof[:start] != proof[:start]:
        return None
    if edited_proof[start + len(inserted):] != proof[end:]:
        return None

    # 4. Proof shortened (0.5–80%)
    shrinkage = 1.0 - len(edited_proof) / len(proof)
    if not (0.005 <= shrinkage <= 0.80):
        return None

    # Build output row
    out = dict(problem)
    out["original_informal_proof"] = proof
    out["informal_proof"] = edited_proof
    out["step_edit_type"] = "step_delete"
    out["step_edit_target_step_idx"] = step.get("step_idx", -1)
    out["step_edit_target_full_text"] = full_text
    out["step_edit_target_outcome_text"] = raw_outcome
    out["step_edit_target_reasoning_text"] = reasoning
    out["step_edit_matched_by"] = method
    out["step_edit_match_confidence"] = {"exact": 1.0, "unescape": 0.9,
                                          "normalized": 0.7, "fuzzy": 0.5}[method]
    out["step_edit_is_last"] = bool(step.get("is_last", False))
    out["step_edit_tier"] = tier
    out["step_edit_span_start"] = start
    out["step_edit_span_end"] = end

    # P0.4: flag globally-repeated reasoning. Check whether the reasoning text
    # appears in the proof ANYWHERE OUTSIDE the deleted span. If yes, the scorer
    # may still see partial "reasoning" content even after span removal. We
    # keep the row in the gold set (span removal is still correct) and tag it
    # so analysis can stratify. The text before and after the span is searched
    # separately: gluing the two sides together would create an adjacency that
    # exists in neither the original nor the edited proof.
    outside_parts = (proof[:start], proof[end:])
    needles = (reasoning, unescape_only(reasoning))
    out["step_edit_reasoning_reappears_elsewhere"] = bool(reasoning) and any(
        needle in part for needle in needles for part in outside_parts
    )
    return out


def main() -> None:
    """CLI entry point: build the step-delete gold set and review pool.

    Reads the original dataset (``--input``) and the selected-step labels
    (``--labels``), both JSONL, runs build_one for each label that has a
    selected step and a known problem, and writes into ``--output_dir``:

      - step_delete_unsound.jsonl  rows matched by 'exact' or 'unescape'
      - step_delete_review.jsonl   rows matched by any other method
      - summary.json               counts per set, tier and match method

    A short summary is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Phase 3: build step-delete")
    parser.add_argument("--labels", required=True, help="selected_steps.jsonl")
    parser.add_argument("--input", required=True, help="Original dataset.jsonl")
    parser.add_argument("--output_dir", required=True)
    args = parser.parse_args()

    # Read inputs as UTF-8 (matching how the outputs are written) and close
    # the handles deterministically instead of leaving them to the GC.
    problems = {}
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                row = json.loads(line)
                # Later rows with the same name overwrite earlier ones.
                problems[row["name"]] = row
    with open(args.labels, encoding="utf-8") as f:
        labels = [json.loads(line) for line in f if line.strip()]

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    gold_rows = []
    review_rows = []
    fail_reasons = {}  # keyed by problem name, so counted once per problem

    for label_row in labels:
        name = label_row.get("problem_name")
        step = label_row.get("selected_step")
        tier = label_row.get("selected_tier", "bronze")
        # Labels without a pick, or for problems not in the dataset, are
        # skipped and not counted as build failures.
        if step is None or name not in problems:
            continue

        result = build_one(problems[name], step, tier)
        if result is None:
            fail_reasons[name] = "build_failed"
            continue

        method = result["step_edit_matched_by"]
        if method in ("exact", "unescape"):
            gold_rows.append(result)
        else:
            review_rows.append(result)

    # Write outputs (both files are always written, even when empty)
    gold_path = out_dir / "step_delete_unsound.jsonl"
    review_path = out_dir / "step_delete_review.jsonl"
    for path, rows in [(gold_path, gold_rows), (review_path, review_rows)]:
        with open(path, "w", encoding="utf-8") as f:
            for r in rows:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

    # Summary
    summary = {
        "input_problems": len(problems),
        "labeled_with_selection": sum(1 for l in labels if l.get("selected_step")),
        "gold_set": len(gold_rows),
        "review_pool": len(review_rows),
        "build_failed": len(fail_reasons),
        "gold_tiers": dict(Counter(r["step_edit_tier"] for r in gold_rows)),
        "gold_methods": dict(Counter(r["step_edit_matched_by"] for r in gold_rows)),
        "review_methods": dict(Counter(r["step_edit_matched_by"] for r in review_rows)),
    }
    with open(out_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print("=== Phase 3 Summary ===")
    print(f"Input:        {summary['input_problems']} problems")
    print(f"With pick:    {summary['labeled_with_selection']}")
    print(f"Gold set:     {summary['gold_set']} (exact/unescape)")
    print(f"Review pool:  {summary['review_pool']} (normalized/fuzzy)")
    print(f"Build failed: {summary['build_failed']}")
    print(f"Gold tiers:   {summary['gold_tiers']}")
    print(f"Gold methods: {summary['gold_methods']}")
    print(f"Output:       {gold_path}")
    if review_rows:
        print(f"              {review_path}")


from collections import Counter

# Run the CLI only when this module is executed as the main program (directly
# or via `python -m`); importing it does not call main().
if __name__ == "__main__":
    main()
