"""Build unsound symbol-edit datasets from Gemini's direct picks.

Labels file format (--labels; one JSON line per problem, matched to --input rows by name):
  {"problem_name": ..., "statement_edit": {...} | null, "proof_edit": {...} | null}

Each non-null pick is {symbol, family, char_offset_start, char_offset_end, context}.
Perturbation is a fixed swap table in perturb_symbol (>/<, +/-, ×/÷, ...).
"""

from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path
import sys
from typing import Dict, List, Optional

# Parent of the directory containing this file (parents[0] is the file's own
# directory, parents[1] is one level above it).
ROOT = Path(__file__).resolve().parents[1]
# Prepend ROOT to sys.path unless it is already listed. Presumably this is what
# lets the `symbol_edit.common` import below resolve when the file is run
# directly as a script -- TODO confirm against the repository layout.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from symbol_edit.common import (
    EDIT_TYPES,
    ensure_dir,
    load_jsonl,
    perturb_symbol,
    replace_symbol,
    write_jsonl,
)


# Maps each edit-type key ("statement_edit" / "proof_edit") to the part of the
# problem that edit targets. build_edited_record uses the value to choose which
# text to modify and stores it under record["symbol_edit_source"].
SOURCE_OF = {"statement_edit": "statement", "proof_edit": "proof"}


def _coerce_offset(value: object) -> int:
    """Return *value* as an int, or -1 when it is missing or not integer-like."""
    try:
        return int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError, OverflowError):
        # An explicit JSON null, a non-numeric string or a non-finite float
        # ends up here. -1 is the same sentinel the caller already uses as the
        # default for an absent offset.
        return -1


def build_edited_record(problem: dict, pick: dict, edit_type: str) -> Optional[dict]:
    """Apply one symbol swap to a problem and return the edited record.

    Args:
        problem: Source problem row; ``informal_statement`` and
            ``informal_proof`` are read from it (missing keys count as "").
        pick: Selected symbol for this edit. Reads ``symbol``,
            ``char_offset_start``, ``char_offset_end``, ``context`` and
            ``family``.
        edit_type: A key of ``SOURCE_OF``; decides whether the statement or
            the proof is edited.

    Returns:
        A shallow copy of ``problem`` with the edited text in
        ``informal_statement`` / ``informal_proof``, the untouched texts under
        ``original_informal_statement`` / ``original_informal_proof``, and
        ``symbol_edit_*`` metadata describing the swap. ``None`` when the pick
        has no symbol, ``perturb_symbol`` yields nothing new,
        ``replace_symbol`` returns ``None``, or the replacement leaves the
        text unchanged.

    Raises:
        KeyError: If ``edit_type`` is not a key of ``SOURCE_OF``.
    """
    source = SOURCE_OF[edit_type]

    # Treat an explicit null like a missing symbol; str(None) would otherwise
    # produce the literal text "None".
    raw_symbol = pick.get("symbol")
    old_symbol = "" if raw_symbol is None else str(raw_symbol).strip()
    if not old_symbol:
        return None

    # Malformed or null offsets degrade to -1 (the "not provided" default)
    # rather than raising and aborting the whole build.
    # NOTE(review): assumes replace_symbol tolerates -1 offsets, as the
    # original defaults imply -- its implementation is not visible here.
    start = _coerce_offset(pick.get("char_offset_start", -1))
    end = _coerce_offset(pick.get("char_offset_end", -1))
    context = pick.get("context", "")

    new_symbol = perturb_symbol(old_symbol)
    if new_symbol is None or new_symbol == old_symbol:
        return None

    statement = str(problem.get("informal_statement", ""))
    proof = str(problem.get("informal_proof", ""))
    target = statement if source == "statement" else proof

    edited = replace_symbol(target, start, end, old_symbol, new_symbol, context)
    # Only one of the two texts is touched, so "nothing changed" reduces to
    # comparing the edited text with the text it was derived from.
    if edited is None or edited == target:
        return None

    if source == "statement":
        edited_statement, edited_proof = edited, proof
    else:
        edited_statement, edited_proof = statement, edited

    record = dict(problem)
    record["original_informal_statement"] = statement
    record["original_informal_proof"] = proof
    record["informal_statement"] = edited_statement
    record["informal_proof"] = edited_proof
    record["symbol_edit_type"] = edit_type
    record["symbol_edit_source"] = source
    record["symbol_edit_family"] = pick.get("family")
    record["symbol_edit_old_symbol"] = old_symbol
    record["symbol_edit_new_symbol"] = new_symbol
    record["symbol_edit_context"] = context
    record["symbol_edit_char_offset_start"] = start
    record["symbol_edit_char_offset_end"] = end
    return record


def main() -> None:
    """Command-line entry point.

    Loads the problems (``--input``) and the per-problem symbol picks
    (``--labels``), builds one edited record per problem and edit type via
    ``build_edited_record``, writes ``<edit_type>_unsound.jsonl`` for every
    entry of ``EDIT_TYPES`` plus ``summary.json`` into ``--output_dir``, and
    prints the summary to stdout. A positive ``--limit`` keeps only the first
    N problems.
    """
    cli = argparse.ArgumentParser(description="Build unsound symbol-edit datasets")
    cli.add_argument("--input", default="./datasets_validation/minif2f/dataset.jsonl")
    cli.add_argument("--labels", default="./symbol_edit/data/selected_symbols.jsonl")
    cli.add_argument("--output_dir", default="./symbol_edit/data")
    cli.add_argument("--limit", type=int, default=0)
    opts = cli.parse_args()

    out_root = ensure_dir(opts.output_dir)
    all_problems = load_jsonl(opts.input)
    label_rows = load_jsonl(opts.labels)
    if opts.limit > 0:
        all_problems = all_problems[: opts.limit]

    # Index label rows by problem name; a later duplicate overrides an earlier one.
    picks_by_problem: Dict[str, dict] = {}
    for row in label_rows:
        picks_by_problem[row["problem_name"]] = row

    built: Dict[str, List[dict]] = {}
    for kind in EDIT_TYPES:
        built[kind] = []
    skip_counts: Counter = Counter()

    for item in all_problems:
        item_name = item["name"]
        picks = picks_by_problem.get(item_name)
        if picks is None:
            skip_counts["missing_labels"] += 1
            continue

        for kind in EDIT_TYPES:
            chosen = picks.get(kind)
            if not chosen:
                # No pick recorded (null or absent) for this edit type.
                skip_counts[f"null_{kind}"] += 1
                continue
            result = build_edited_record(item, chosen, kind)
            if result is None:
                # A pick existed but could not be turned into an edit.
                skip_counts[f"skip_{kind}"] += 1
            else:
                built[kind].append(result)

    for kind in EDIT_TYPES:
        write_jsonl(out_root / f"{kind}_unsound.jsonl", built[kind])

    sizes = {}
    for kind, rows in built.items():
        sizes[kind] = len(rows)
    summary = {
        "input_size": len(all_problems),
        "output_sizes": sizes,
        "skipped": dict(skip_counts),
    }

    summary_path = out_root / "summary.json"
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    print(json.dumps(summary, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
