import argparse
import csv
import json
from collections import Counter
from collections.abc import Callable
from pathlib import Path
from statistics import mean, median
from typing import Any, Optional

try:
    from proofstate_common import load_jsonl, normalize_entry, tokenize
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, tokenize


def _mean_and_median(values: list[int]) -> tuple[float, float]:
    """Return ``(mean, median)`` of *values* as floats, or ``(0.0, 0.0)`` if empty."""
    if not values:
        return 0.0, 0.0
    # statistics.mean/median return an int for some integer inputs and a float
    # for others (e.g. even-length medians); coerce so the reported metrics
    # always have the same numeric type as the empty-input fallback.
    return float(mean(values)), float(median(values))


def summarize(
    rows: list[dict[str, Any]],
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
) -> dict[str, Any]:
    """Compute aggregate statistics over proof-step rows.

    Each row must provide the keys ``tactic_family``, ``theorem``, ``file``,
    ``main_goal`` and ``local_context``; a missing key raises ``KeyError``.

    Args:
        rows: Proof-step records. They are iterated several times, so pass a
            list rather than a one-shot iterator.
        tokenizer: Function applied to each ``main_goal``. Only the ``len`` of
            its result is used. Defaults to the imported ``tokenize`` helper.

    Returns:
        A dict of scalar metrics (counts, plus mean/median statistics as
        floats, ``0.0`` for empty input) followed by ``label_distribution``.
        That entry maps each tactic family to its count, ordered from most to
        least frequent.
    """
    tok = tokenize if tokenizer is None else tokenizer

    labels = Counter(row["tactic_family"] for row in rows)
    # NOTE(review): steps are grouped by theorem name alone, so identically
    # named theorems from different files are merged -- confirm names are
    # fully qualified.
    steps_per_theorem = Counter(row["theorem"] for row in rows)
    goal_tokens = [len(tok(row["main_goal"])) for row in rows]
    context_sizes = [len(row["local_context"]) for row in rows]

    top_label, top_count = labels.most_common(1)[0] if labels else ("", 0)
    mean_steps, median_steps = _mean_and_median(list(steps_per_theorem.values()))
    mean_goal, median_goal = _mean_and_median(goal_tokens)
    mean_context, median_context = _mean_and_median(context_sizes)

    # Key order matters: write_summary_csv emits rows in this order.
    return {
        "n_steps": len(rows),
        "n_theorems": len(steps_per_theorem),
        "n_files": len({row["file"] for row in rows}),
        "n_labels": len(labels),
        "most_common_label": top_label,
        "most_common_label_count": top_count,
        "mean_steps_per_theorem": mean_steps,
        "median_steps_per_theorem": median_steps,
        "mean_goal_tokens": mean_goal,
        "median_goal_tokens": median_goal,
        "mean_context_size": mean_context,
        "median_context_size": median_context,
        "label_distribution": dict(labels.most_common()),
    }


def write_summary_csv(summary: dict[str, Any], path: Path) -> None:
    """Write the scalar metrics of *summary* as a two-column ``metric,value`` CSV.

    The nested ``label_distribution`` entry is skipped because it is not a
    scalar. All other entries are written in the mapping's iteration order.
    Missing parent directories are created and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    scalars = {
        metric: value
        for metric, value in summary.items()
        if metric != "label_distribution"
    }
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.writer(handle)
        table.writerow(("metric", "value"))
        table.writerows(scalars.items())


def write_label_csv(summary: dict[str, Any], path: Path) -> None:
    """Write ``summary["label_distribution"]`` as a two-column ``label,count`` CSV.

    Rows follow the iteration order of the distribution mapping. Missing
    parent directories are created and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.writer(handle)
        table.writerow(("label", "count"))
        table.writerows(summary["label_distribution"].items())


def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: summarize a JSONL proof-step dataset.

    Loads ``--data`` and normalizes every entry with optional fields included.
    Writes the summary to ``--output-json``, the scalar metrics to
    ``--summary-csv`` and the label counts to ``--label-csv``, then prints
    where the outputs went.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behavior; passing a list
            allows calling the script programmatically.
    """
    parser = argparse.ArgumentParser(description="Summarize proof-step datasets.")
    parser.add_argument("--data", type=Path, default=Path("data/leandojo_steps_checked.jsonl"))
    parser.add_argument("--output-json", type=Path, default=Path("results/dataset/dataset_summary.json"))
    parser.add_argument("--summary-csv", type=Path, default=Path("results/tables/dataset_summary.csv"))
    parser.add_argument("--label-csv", type=Path, default=Path("results/tables/label_distribution.csv"))
    args = parser.parse_args(argv)

    rows = [normalize_entry(row, include_optional=True) for row in load_jsonl(args.data)]
    summary = summarize(rows)

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False keeps non-ASCII goal/label text readable in the JSON.
    args.output_json.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    write_summary_csv(summary, args.summary_csv)
    write_label_csv(summary, args.label_csv)
    print(f"Wrote dataset summary to: {args.output_json}")
    print(f"Wrote dataset tables to: {args.summary_csv} and {args.label_csv}")


if __name__ == "__main__":
    main()
