import argparse
import csv
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

try:
    from proofstate_common import load_jsonl, normalize_entry, normalize_formula_text
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, normalize_formula_text


def normalized_goal_key(row: dict[str, Any]) -> str:
    """Build the string key used to group rows that share a proof state.

    Each item of ``row["local_context"]`` (no items when the key is absent)
    and the value of ``row["main_goal"]`` (empty string when absent) is
    converted with ``str`` and passed through ``normalize_formula_text``.

    Args:
        row: A proof-state record; only ``local_context`` and ``main_goal``
            are read.

    Returns:
        A string of the form ``goal=<goal> context=<h1> ; <h2> ; ...``.
    """
    hypotheses = [
        normalize_formula_text(str(hypothesis))
        for hypothesis in row.get("local_context", [])
    ]
    joined_context = " ; ".join(hypotheses)
    target = normalize_formula_text(str(row.get("main_goal", "")))
    return f"goal={target} context={joined_context}"


def summarize(rows: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    goal_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
    tactic_counts = Counter(row["next_tactic"] for row in rows)
    family_counts = Counter(row["tactic_family"] for row in rows)
    file_counts = Counter(row["file"] for row in rows)
    theorem_counts = Counter(row["theorem"] for row in rows)
    for row in rows:
        goal_groups[normalized_goal_key(row)].append(row)

    duplicate_goal_rows = [group for group in goal_groups.values() if len(group) > 1]
    summary = [
        {"metric": "n_rows", "value": len(rows)},
        {"metric": "n_unique_normalized_state", "value": len(goal_groups)},
        {"metric": "n_duplicate_state_groups", "value": len(duplicate_goal_rows)},
        {"metric": "rows_in_duplicate_state_groups", "value": sum(len(group) for group in duplicate_goal_rows)},
        {"metric": "n_unique_next_tactic", "value": len(tactic_counts)},
        {"metric": "n_repeated_next_tactic", "value": sum(1 for count in tactic_counts.values() if count > 1)},
        {"metric": "n_unique_family", "value": len(family_counts)},
        {"metric": "n_files", "value": len(file_counts)},
        {"metric": "n_theorems", "value": len(theorem_counts)},
    ]

    examples = []
    for key, group in sorted(goal_groups.items(), key=lambda item: len(item[1]), reverse=True):
        if len(group) <= 1:
            continue
        families = Counter(row["tactic_family"] for row in group)
        files = Counter(row["file"] for row in group)
        examples.append(
            {
                "normalized_state": key[:240],
                "count": len(group),
                "families": ";".join(f"{label}:{count}" for label, count in families.most_common()),
                "files": ";".join(f"{label}:{count}" for label, count in files.most_common(5)),
                "example_theorems": ";".join(str(row["theorem"]) for row in group[:5]),
            }
        )
    return summary, examples


def write_csv(rows: list[dict[str, Any]], path: Path, fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header line.

    Missing parent directories are created first. An existing file at
    ``path`` is overwritten.

    Args:
        rows: Dicts to write, one CSV record each.
        path: Destination file.
        fieldnames: Column names, in output order.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", encoding="utf-8", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=fieldnames)
        out.writeheader()
        for record in rows:
            out.writerow(record)


def main() -> None:
    """Command-line entry point.

    Reads the JSONL file given by ``--data``, passes every record through
    ``normalize_entry`` with ``include_optional=True``, summarizes the
    result, writes the two CSV tables and prints their paths.
    """
    cli = argparse.ArgumentParser(description="Analyze near-duplicate proof states and repeated tactics.")
    cli.add_argument("--data", type=Path, required=True)
    cli.add_argument("--summary-output", type=Path, default=Path("results/tables/duplicate_summary.csv"))
    cli.add_argument("--examples-output", type=Path, default=Path("results/tables/duplicate_examples.csv"))
    options = cli.parse_args()

    entries = []
    for raw in load_jsonl(options.data):
        entries.append(normalize_entry(raw, include_optional=True))

    summary_rows, example_rows = summarize(entries)

    summary_columns = ["metric", "value"]
    example_columns = ["normalized_state", "count", "families", "files", "example_theorems"]
    write_csv(summary_rows, options.summary_output, summary_columns)
    write_csv(example_rows, options.examples_output, example_columns)

    print(f"Wrote duplicate summary to: {options.summary_output}")
    print(f"Wrote duplicate examples to: {options.examples_output}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
