"""Build unsound number-edit datasets from Gemini's direct picks.

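Inputs:
  --input   problems JSONL; each row is read for "name", "informal_statement"
            and "informal_proof".
  --labels  picks JSONL, one line per problem:
              {"problem_name": ..., "statement_edit": {...} | null, "proof_edit": {...} | null}

Each non-null pick is a dict with keys old_value, char_offset_start,
char_offset_end and context; the offsets refer to the problem's
informal_statement (statement_edit) or informal_proof (proof_edit).
Perturbation is deterministic via SHA256 in perturb_numeric_string.

Outputs (in --output_dir): statement_edit_unsound.jsonl,
proof_edit_unsound.jsonl and summary.json.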
"""

from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path
import sys
from typing import Dict, List, Optional

# Parent of the directory holding this script (presumably the repository root
# that contains the `number_edit` package). It is prepended to sys.path, at most
# once, so the `number_edit.common` import below resolves when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)

from number_edit.common import (
    ensure_dir,
    load_jsonl,
    perturb_numeric_string,
    replace_span,
    write_jsonl,
)


# Label field in the picks file -> which text of the problem it edits.
SOURCE_OF = {"statement_edit": "statement", "proof_edit": "proof"}
# Label fields to process, in output order (dict keys keep insertion order),
# derived from SOURCE_OF so the two can never disagree.
EDIT_TYPES = tuple(SOURCE_OF)


def _coerce_offset(value: object) -> int:
    """Convert a pick's character offset to ``int``, mapping bad values to -1.

    ``None`` (a JSON null) and values that ``int()`` rejects become -1. This is
    the same sentinel used when the key is missing entirely, so a single
    malformed pick no longer raises and aborts the whole run.
    """
    if value is None:
        return -1
    try:
        return int(value)
    except (TypeError, ValueError):
        return -1


def _as_text(value: object) -> str:
    """Return ``value`` as a string, treating ``None`` as the empty string."""
    return "" if value is None else str(value)


def build_edited_record(problem: dict, pick: dict, edit_type: str) -> Optional[dict]:
    """Apply one number pick to a problem and return the edited copy.

    Args:
        problem: Dataset row. Reads ``name``, ``informal_statement`` and
            ``informal_proof``; it is not mutated (a shallow copy is returned).
        pick: One pick with ``old_value``, ``char_offset_start``,
            ``char_offset_end`` and optional ``context``.
        edit_type: A key of ``SOURCE_OF`` selecting whether the statement or
            the proof is edited.

    Returns:
        A copy of ``problem`` with the picked number in the selected text
        replaced by its perturbation, plus ``original_informal_*`` and
        ``number_edit_*`` metadata. Returns ``None`` in three cases: the pick
        has no value; the perturbation is unavailable or unchanged; or
        ``replace_span`` rejects the span.

    Raises:
        KeyError: if ``edit_type`` is not a key of ``SOURCE_OF``.
    """
    source = SOURCE_OF[edit_type]

    # A JSON null must not turn into the literal string "None".
    old_value = _as_text(pick.get("old_value")).strip()
    start = _coerce_offset(pick.get("char_offset_start", -1))
    end = _coerce_offset(pick.get("char_offset_end", -1))

    if not old_value:
        return None

    new_value = perturb_numeric_string(
        old_value,
        problem_name=problem.get("name", ""),
        source=source,
    )
    if new_value is None or new_value == old_value:
        return None

    # Unedited texts keyed by source; only the picked one is rewritten.
    originals = {
        "statement": _as_text(problem.get("informal_statement")),
        "proof": _as_text(problem.get("informal_proof")),
    }
    # Offsets (including the -1 sentinel) are handed to replace_span unchanged;
    # it decides whether the span is acceptable.
    edited = replace_span(originals[source], start, end, old_value, new_value)
    if edited is None:
        return None
    updated = dict(originals)
    updated[source] = edited

    record = dict(problem)
    record["original_informal_statement"] = originals["statement"]
    record["original_informal_proof"] = originals["proof"]
    record["informal_statement"] = updated["statement"]
    record["informal_proof"] = updated["proof"]
    record["number_edit_type"] = edit_type
    record["number_edit_source"] = source
    record["number_edit_old_value"] = old_value
    record["number_edit_new_value"] = new_value
    record["number_edit_context"] = pick.get("context", "")
    record["number_edit_char_offset_start"] = start
    record["number_edit_char_offset_end"] = end
    return record


def main() -> None:
    """Read problems and picks; write one unsound dataset per edit type plus a summary.

    Each problem is matched to its picks by name and every edit type in
    EDIT_TYPES is attempted. When ``--limit`` > 0, only the first ``--limit``
    problems are used. Successful edits go to
    ``<output_dir>/<edit_type>_unsound.jsonl``. Every reason a record was not
    produced is tallied in ``<output_dir>/summary.json``, and the summary is
    also printed to stdout.
    """
    cli = argparse.ArgumentParser(description="Build unsound number-edit datasets")
    cli.add_argument("--input", default="./datasets_validation/minif2f/dataset.jsonl")
    cli.add_argument("--labels", default="./number_edit/data/selected_numbers.jsonl")
    cli.add_argument("--output_dir", default="./number_edit/data")
    cli.add_argument("--limit", type=int, default=0)
    opts = cli.parse_args()

    out_dir = ensure_dir(opts.output_dir)
    problems = load_jsonl(opts.input)
    label_rows = load_jsonl(opts.labels)
    if opts.limit > 0:
        problems = problems[: opts.limit]

    # Index picks by problem name; a later row wins if a name repeats.
    picks_for: Dict[str, dict] = {}
    for row in label_rows:
        picks_for[row["problem_name"]] = row

    per_type: Dict[str, List[dict]] = {edit_type: [] for edit_type in EDIT_TYPES}
    reasons: Counter = Counter()

    for problem in problems:
        picks = picks_for.get(problem["name"])
        if picks is None:
            reasons["missing_labels"] += 1
            continue

        for edit_type in EDIT_TYPES:
            pick = picks.get(edit_type)
            if not pick:
                reasons[f"null_{edit_type}"] += 1
            else:
                record = build_edited_record(problem, pick, edit_type)
                if record is None:
                    reasons[f"skip_{edit_type}"] += 1
                else:
                    per_type[edit_type].append(record)

    for edit_type, rows in per_type.items():
        write_jsonl(out_dir / f"{edit_type}_unsound.jsonl", rows)

    summary = {
        "input_size": len(problems),
        "output_sizes": {edit_type: len(rows) for edit_type, rows in per_type.items()},
        "skipped": dict(reasons),
    }
    summary_path = out_dir / "summary.json"
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    print(json.dumps(summary, indent=2))


# Build the datasets only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
