#!/usr/bin/env python3
"""Minimal GRPO-MVP reward utilities for the supplementary artifact."""

from __future__ import annotations

import argparse
import json
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List


def strip_isabelle_comments(text: str) -> str:
    """Replace every Isabelle ``(* ... *)`` comment in *text* with a space.

    Isabelle comments nest, so ``(* a (* b *) c *)`` is one comment.  The
    pattern below only matches an *innermost* comment, i.e. one whose body
    contains neither ``(*`` nor ``*)``.  Re-applying the substitution until
    the text stops changing therefore peels nested comments from the inside
    out.  An unterminated ``(*`` is left in place.

    Args:
        text: Source text; non-``str`` values are converted with ``str()``.

    Returns:
        The text with every balanced comment replaced by a single space.
    """
    # Tempered token: consume any character that does not start "(*" or "*)".
    innermost = re.compile(r"\(\*(?:(?!\(\*|\*\)).)*\*\)", re.DOTALL)
    previous = None
    cleaned = str(text)
    while previous != cleaned:
        previous = cleaned
        cleaned = innermost.sub(" ", cleaned)
    return cleaned


def contains_forbidden_placeholder(text: str, placeholders: Iterable[str] = ("sorry", "oops")) -> bool:
    """Return True if *text* uses a placeholder token outside of comments.

    Comments are removed first with ``strip_isabelle_comments``, so a mention
    such as ``(* TODO: remove sorry *)`` is not flagged.  Tokens are matched
    as whole words (``\\b`` boundaries), case-insensitively.

    Args:
        text: Isabelle proof text to inspect.
        placeholders: Tokens to search for.  A single ``str`` is treated as
            one token rather than as an iterable of characters.  Empty tokens
            are ignored; if none remain, nothing can match and the result is
            False.

    Returns:
        Whether any placeholder token occurs as a whole word.
    """
    if isinstance(placeholders, str):
        placeholders = (placeholders,)
    tokens = [token for token in placeholders if token]
    if not tokens:
        # An empty alternation would compile to r"\b(?:)\b", which matches at
        # every word boundary and would flag every non-empty proof.
        return False
    cleaned = strip_isabelle_comments(text)
    pattern = r"\b(?:%s)\b" % "|".join(re.escape(token) for token in tokens)
    return re.search(pattern, cleaned, flags=re.IGNORECASE) is not None


def proof_level_mvp(step_results: Iterable[Dict[str, Any]]) -> float:
    """Compute Minimal Valid Progress from PISA proof-level traces.

    Each step result may contain `proof_level_before` and `proof_level_after`.
    A missing `proof_level_before` defaults to 0.0.  A missing
    `proof_level_after` defaults to the step's `proof_level_before`, meaning
    no change.  Positive level changes are summed as the "opened" magnitude
    and negative changes as the "closed" magnitude.

    MVP rewards failed completions only when they both open and close local
    proof contexts, which is a verifier-grounded proxy for partial proof
    progress.  The score is ``min(opened, closed) / max(1.0, opened)``.

    A step is skipped when any of the following holds:

    * it is not a mapping;
    * a level cannot be converted to ``float``;
    * a level is not finite (NaN/inf).

    A single malformed entry therefore cannot turn the reward into NaN.

    Returns:
        The MVP score as a float; 0.0 when nothing was both opened and closed.
    """

    open_magnitude = 0.0
    close_magnitude = 0.0
    for step in step_results:
        try:
            before = float(step.get("proof_level_before", 0.0))
            after = float(step.get("proof_level_after", before))
        except (AttributeError, TypeError, ValueError, OverflowError):
            # Malformed trace entry: non-mapping step, None/unparsable level,
            # or an int too large to represent as a float.
            continue
        if not (math.isfinite(before) and math.isfinite(after)):
            # inf - inf and inf / inf are NaN; letting either through would
            # poison the shaped reward.
            continue
        delta = after - before
        if delta > 0:
            open_magnitude += delta
        elif delta < 0:
            close_magnitude -= delta
    return min(open_magnitude, close_magnitude) / max(1.0, open_magnitude)


def grpo_mvp_reward(
    verification: Dict[str, Any],
    *,
    lambda_mvp: float = 0.2,
    pass_reward: float = 1.0,
    fail_reward: float = 0.0,
) -> Dict[str, float]:
    """Shape the scalar GRPO reward for one verified completion.

    A completion whose ``verification["success"]`` is truthy earns
    ``pass_reward`` and no MVP bonus.  Any other completion earns
    ``fail_reward`` plus ``lambda_mvp`` times the score returned by
    ``proof_level_mvp`` for ``verification["step_results"]``.  A missing or
    falsy ``step_results`` is treated as an empty list.

    Returns:
        A dict with three keys:

        * ``"reward"``: the shaped total.
        * ``"verifier"``: the pass/fail component alone.
        * ``"mvp"``: the raw MVP score, 0.0 on success.
    """
    if verification.get("success"):
        solved = float(pass_reward)
        return {"reward": solved, "verifier": solved, "mvp": 0.0}

    raw_mvp = proof_level_mvp(verification.get("step_results") or [])
    baseline = float(fail_reward)
    bonus = float(lambda_mvp) * float(raw_mvp)
    return {"reward": baseline + bonus, "verifier": baseline, "mvp": float(raw_mvp)}


def difficulty_score(row: Dict[str, Any]) -> int:
    """Return a length-based difficulty proxy for one dataset row.

    The score is the combined character count of ``row["lean"]["formal_proof"]``
    and ``row["lean"]["isabelle_statement"]``, each rendered with ``str()``.
    A missing or falsy ``lean`` entry, or a missing or falsy field,
    contributes 0.
    """
    lean_fields = row.get("lean") or {}
    return sum(
        len(str(lean_fields.get(field) or ""))
        for field in ("formal_proof", "isabelle_statement")
    )


def order_by_difficulty(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a new list containing *rows* sorted easiest-first.

    Rows are ranked by ``difficulty_score``.  Ties are broken by the row's
    ``id`` rendered as a string, where a missing or falsy id becomes ``""``.
    Remaining ties keep their input order because ``sorted`` is stable.  The
    input list is not modified.
    """
    def _rank(row: Dict[str, Any]) -> tuple:
        score = difficulty_score(row)
        return score, str(row.get("id") or "")

    return sorted(rows, key=_rank)


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line is parsed with ``json.loads``; whitespace-only lines
    are skipped.  The file is decoded as UTF-8.  A leading UTF-8 byte-order
    mark, which some editors and Windows tools prepend, is accepted via the
    ``utf-8-sig`` codec.  Without it, ``json.loads`` would reject the first
    line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8-sig") as handle:
        stripped = (line.strip() for line in handle)
        return [json.loads(line) for line in stripped if line]


def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten.  Keys are sorted, so the output does not depend on key
    insertion order.  Non-ASCII characters are written verbatim rather than
    escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n" for row in rows
        )


def main() -> None:
    """Command-line entry point.

    Supports a single ``order`` subcommand, which reads ``--input`` as JSON
    Lines, sorts the rows with ``order_by_difficulty``, and writes them to
    ``--output``.
    """
    cli = argparse.ArgumentParser()
    commands = cli.add_subparsers(dest="command", required=True)
    order_parser = commands.add_parser("order")
    for flag in ("--input", "--output"):
        order_parser.add_argument(flag, type=Path, required=True)
    options = cli.parse_args()
    if options.command != "order":
        return
    write_jsonl(options.output, order_by_difficulty(read_jsonl(options.input)))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
