import csv
import argparse
from pathlib import Path

try:
    from proofstate_common import load_jsonl, normalize_entry, validate_entry, write_jsonl
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, validate_entry, write_jsonl

# Default file locations, relative to the current working directory. Each one
# can be overridden with the matching command-line flag.
INPUT_FILE = Path("data", "pilot_pairs.jsonl")
OUTPUT_JSONL = Path("data", "pilot_pairs_checked.jsonl")
OUTPUT_CSV = Path("data", "pilot_pairs_checked.csv")

def _check_entries(raw_entries) -> tuple:
    """Validate and normalize every raw entry, collecting all problems.

    Args:
        raw_entries: Iterable of parsed JSONL records, in input order.

    Returns:
        A ``(normalized_entries, errors)`` pair. ``normalized_entries`` holds
        one normalized record per input record; records with problems are kept,
        not dropped. ``errors`` holds every message returned by
        ``validate_entry``, plus one message per record whose normalized
        ``tactic_family`` is ``"unknown"``.
    """
    normalized_entries = []
    errors = []
    # 1-based index reported as "Line N". Presumably this equals the JSONL line
    # number as long as load_jsonl does not skip blank lines (verify).
    for line_no, entry in enumerate(raw_entries, start=1):
        errors.extend(validate_entry(entry, line_no))
        normalized = normalize_entry(entry)
        if normalized["tactic_family"] == "unknown":
            errors.append(f"Line {line_no}: could not infer tactic_family")
        normalized_entries.append(normalized)
    return normalized_entries, errors


def _write_csv(entries, path: Path) -> None:
    """Write a one-row-per-entry CSV snapshot of ``entries`` to ``path``.

    Parent directories are created if needed. The ``local_context`` list is
    flattened into a single ``[a; b; ...]`` cell so each entry stays on one row.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                "file",
                "theorem",
                "step_index",
                "main_goal",
                "local_context",
                "next_tactic",
                "tactic_family",
            ]
        )
        writer.writerows(
            [
                entry["file"],
                entry["theorem"],
                entry["step_index"],
                entry["main_goal"],
                "[" + "; ".join(entry["local_context"]) + "]",
                entry["next_tactic"],
                entry["tactic_family"],
            ]
            for entry in entries
        )


def main() -> None:
    """Validate and normalize the pilot JSONL dataset from the command line.

    Reads ``--input``, validates and normalizes each record, and prints any
    problems found. It then writes the normalized records to ``--output-jsonl``
    and a CSV inspection snapshot to ``--output-csv``.

    Exits with status 1 and a message on stderr if the input file is missing
    or cannot be loaded. Validation problems are reported but do not change
    the exit status, and the outputs are still written so they can be inspected.
    """
    parser = argparse.ArgumentParser(description="Validate and normalize the pilot JSONL dataset.")
    parser.add_argument("--input", type=Path, default=INPUT_FILE, help="Raw JSONL input path.")
    parser.add_argument("--output-jsonl", type=Path, default=OUTPUT_JSONL, help="Cleaned JSONL output path.")
    parser.add_argument("--output-csv", type=Path, default=OUTPUT_CSV, help="CSV inspection output path.")
    args = parser.parse_args()

    # Fatal errors go to stderr with a non-zero status so shell pipelines and
    # CI can detect the failure. parser.exit raises SystemExit.
    if not args.input.exists():
        parser.exit(1, f"Error: {args.input} does not exist.\n")

    try:
        raw_entries = load_jsonl(args.input)
    except (OSError, ValueError) as exc:
        # OSError covers unreadable paths, e.g. a directory or missing permissions.
        parser.exit(1, f"Error: {exc}\n")

    output_entries, all_errors = _check_entries(raw_entries)

    if all_errors:
        print("Found problems:")
        for err in all_errors:
            print(" -", err)
    else:
        print("No formatting problems found.")

    write_jsonl(output_entries, args.output_jsonl)
    _write_csv(output_entries, args.output_csv)

    print(f"Wrote cleaned file to: {args.output_jsonl}")
    print(f"Wrote CSV snapshot to: {args.output_csv}")


if __name__ == "__main__":
    main()
