import argparse
import csv
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

try:
    from proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, representation_text, theorem_split
    from run_experiments import build_models, ranked_from_estimator
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, representation_text, theorem_split
    from scripts.run_experiments import build_models, ranked_from_estimator


def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return ``(precision, recall, f1)`` for the given confusion counts.

    Any ratio whose denominator is zero is reported as ``0.0`` instead of
    raising ``ZeroDivisionError``.
    """

    def ratio(numerator: float, denominator: float) -> float:
        # Guarded division: an empty denominator yields 0.0.
        return numerator / denominator if denominator else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1 = ratio(2 * precision * recall, precision + recall)
    return precision, recall, f1


def frequency_bucket(count: int, *, head_min: int = 50, mid_min: int = 10) -> str:
    """Classify a training-set label frequency as ``"head"``, ``"mid"`` or ``"rare"``.

    Args:
        count: Number of training examples carrying the label.
        head_min: Smallest count classified as ``"head"`` (default 50).
        mid_min: Smallest count classified as ``"mid"`` (default 10); anything
            below it is ``"rare"``.

    Returns:
        ``"head"`` if ``count >= head_min``, else ``"mid"`` if
        ``count >= mid_min``, else ``"rare"``.

    Raises:
        ValueError: If ``mid_min`` exceeds ``head_min`` (inverted thresholds
            would make the ``"mid"`` band meaningless).
    """
    if mid_min > head_min:
        raise ValueError(f"mid_min ({mid_min}) must not exceed head_min ({head_min})")
    if count >= head_min:
        return "head"
    if count >= mid_min:
        return "mid"
    return "rare"


# Placeholder prediction used when the estimator returns an empty ranking.
_NO_PREDICTION = ""


def evaluate(
    rows: list[dict[str, Any]],
    representation: str,
    model_name: str,
    test_ratio: float,
    seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Train ``model_name`` on a theorem-level split and score each tactic family.

    Rows are split with ``theorem_split``, converted with
    ``prepare_representation_rows``, and the model's top-ranked label is used as
    the prediction for every test row (an empty ranking counts as a miss).

    Args:
        rows: Normalized rows; each must carry a ``"tactic_family"`` key.
        representation: Representation name passed to the representation helpers.
        model_name: Key into the mapping returned by ``build_models()``.
        test_ratio: Passed through to ``theorem_split``.
        seed: Passed through to ``theorem_split`` and echoed in the grouped rows.

    Returns:
        ``(per_family, grouped)``. ``per_family`` has one row per family (sorted
        by name) appearing as a gold or predicted label, with train/test counts,
        frequency bucket, precision, recall and F1. ``grouped`` has one row per
        bucket (``head``, ``mid``, ``rare``) with the unweighted mean F1 of the
        families in that bucket (0.0 when the bucket is empty).

    Raises:
        KeyError: If ``model_name`` is not provided by ``build_models()``.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    train_rows, test_rows = prepare_representation_rows(train_rows, test_rows, representation)

    models = build_models()
    if model_name not in models:
        raise KeyError(f"Unknown model {model_name!r}; available: {list(models)}")
    model = models[model_name]

    x_train = [representation_text(row, representation) for row in train_rows]
    y_train = [row["tactic_family"] for row in train_rows]
    x_test = [representation_text(row, representation) for row in test_rows]
    y_test = [row["tactic_family"] for row in test_rows]
    model.fit(x_train, y_train)
    ranked = ranked_from_estimator(model, x_test)
    y_pred = [items[0] if items else _NO_PREDICTION for items in ranked]

    per_family = _per_family_metrics(y_test, y_pred, Counter(y_train))
    grouped = _group_by_bucket(per_family, split_meta["strategy"], seed)
    return per_family, grouped


def _per_family_metrics(
    y_test: list[str],
    y_pred: list[str],
    train_counts: "Counter[str]",
) -> list[dict[str, Any]]:
    """Build one precision/recall/F1 row per family from paired gold/predicted labels."""
    # Single pass over the pairs instead of rescanning the test set per label.
    tp: Counter = Counter()
    fp: Counter = Counter()
    fn: Counter = Counter()
    for gold, pred in zip(y_test, y_pred):
        if gold == pred:
            tp[gold] += 1
        else:
            fn[gold] += 1
            fp[pred] += 1
    test_counts = Counter(y_test)

    labels = set(y_test) | set(y_pred)
    # An abstention is not a tactic family. Keep the sentinel only when it truly
    # occurs as a gold label; otherwise it would add a spurious 0-F1 "rare"
    # family. Abstentions are still counted as false negatives above.
    if _NO_PREDICTION not in test_counts:
        labels.discard(_NO_PREDICTION)

    per_family = []
    for label in sorted(labels):
        precision, recall, f1 = prf(tp[label], fp[label], fn[label])
        per_family.append(
            {
                "family": label,
                "train_count": train_counts[label],
                "test_count": test_counts[label],
                "bucket": frequency_bucket(train_counts[label]),
                "precision": precision,
                "recall": recall,
                "f1": f1,
            }
        )
    return per_family


def _group_by_bucket(
    per_family: list[dict[str, Any]],
    split_strategy: Any,
    seed: int,
) -> list[dict[str, Any]]:
    """Macro-average per-family F1 within each frequency bucket (head, mid, rare)."""
    f1_by_bucket: dict[str, list[float]] = defaultdict(list)
    for row in per_family:
        f1_by_bucket[row["bucket"]].append(row["f1"])

    grouped = []
    for bucket in ["head", "mid", "rare"]:
        values = f1_by_bucket.get(bucket, [])
        grouped.append(
            {
                "bucket": bucket,
                "n_families": len(values),
                "macro_f1": sum(values) / len(values) if values else 0.0,
                "split_strategy": split_strategy,
                "seed": seed,
            }
        )
    return grouped


def write_csv(rows: list[dict[str, Any]], path: Path, fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header row of ``fieldnames``.

    Missing parent directories are created; an existing file is overwritten.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with path.open("w", encoding="utf-8", newline="") as handle:
        dict_writer = csv.DictWriter(handle, fieldnames)
        dict_writer.writeheader()
        for record in rows:
            dict_writer.writerow(record)


def main() -> None:
    """Parse CLI options, evaluate one model, and write the two F1 tables as CSV."""
    cli = argparse.ArgumentParser(description="Report per-family and rare/mid/head grouped F1.")
    cli.add_argument("--data", type=Path, required=True)
    cli.add_argument("--representation", default="state_only")
    cli.add_argument("--model", default="tfidf_linear_svm")
    cli.add_argument("--test-ratio", type=float, default=0.3)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--family-output", type=Path, default=Path("results/tables/per_family_f1.csv"))
    cli.add_argument("--group-output", type=Path, default=Path("results/tables/grouped_family_f1.csv"))
    opts = cli.parse_args()

    dataset = [normalize_entry(record, include_optional=True) for record in load_jsonl(opts.data)]
    family_table, group_table = evaluate(
        dataset,
        opts.representation,
        opts.model,
        opts.test_ratio,
        opts.seed,
    )

    family_columns = ["family", "train_count", "test_count", "bucket", "precision", "recall", "f1"]
    group_columns = ["bucket", "n_families", "macro_f1", "split_strategy", "seed"]
    write_csv(family_table, opts.family_output, family_columns)
    write_csv(group_table, opts.group_output, group_columns)

    print(f"Wrote per-family F1 to: {opts.family_output}")
    print(f"Wrote grouped F1 to: {opts.group_output}")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
