import argparse
import csv
from collections import Counter
from pathlib import Path
from typing import Any

try:
    from proofstate_common import load_jsonl, normalize_entry, representation_text, theorem_split
    from run_experiments import build_models, ranked_from_estimator
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, representation_text, theorem_split
    from scripts.run_experiments import build_models, ranked_from_estimator


def truncate(value: str, limit: int = 180) -> str:
    """Collapse whitespace in ``value`` and cap it at ``limit`` characters.

    Every run of whitespace (spaces, tabs, newlines) becomes a single space and
    leading/trailing whitespace is dropped, so the text fits on one line.
    Text longer than ``limit`` is cut and suffixed with ``"..."`` so the result
    is exactly ``limit`` characters long. When ``limit`` is too small to hold
    the ellipsis (``limit < 3``), the text is simply cut to ``max(limit, 0)``
    characters, so the result never exceeds ``limit``.
    """
    ellipsis = "..."
    value = " ".join(value.split())
    if len(value) <= limit:
        return value
    if limit < len(ellipsis):
        # A negative slice bound here would keep most of the string instead
        # of shortening it, so cut plainly without an ellipsis.
        return value[: max(limit, 0)]
    return value[: limit - len(ellipsis)] + ellipsis


def analyze(
    rows: list[dict[str, Any]],
    representation: str,
    model_name: str,
    test_ratio: float,
    seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Train one model on a theorem-level split and tabulate its top-1 mistakes.

    The rows are split with ``theorem_split``. The estimator named
    ``model_name`` in ``build_models()`` is fit on the training rows'
    ``representation_text`` against their ``tactic_family`` labels. The
    first-ranked prediction for each test row is then compared with its gold
    ``tactic_family``.

    Args:
        rows: Normalized step records; each must provide ``tactic_family``,
            ``theorem``, ``step_index``, ``next_tactic``, ``main_goal`` and
            ``local_context`` (an iterable of strings).
        representation: Representation name passed to ``representation_text``.
        model_name: Key into the mapping returned by ``build_models()``.
        test_ratio: Forwarded to ``theorem_split``.
        seed: Forwarded to ``theorem_split`` and recorded on each error row.

    Returns:
        ``(confusion_rows, errors)``:
        - ``confusion_rows`` holds one ``{"gold", "predicted", "count"}`` dict
          per observed (gold, predicted) pair, most frequent first.
        - ``errors`` holds one dict per misclassified test row.

    Raises:
        KeyError: If ``model_name`` is not provided by ``build_models()``.
        ValueError: If the estimator yields a different number of rankings
            than there are test rows.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)

    models = build_models()
    if model_name not in models:
        available = ", ".join(sorted(models))
        raise KeyError(f"Unknown model {model_name!r}; available models: {available}")
    model = models[model_name]

    x_train = [representation_text(row, representation) for row in train_rows]
    y_train = [row["tactic_family"] for row in train_rows]
    x_test = [representation_text(row, representation) for row in test_rows]
    model.fit(x_train, y_train)
    # Materialize so the length can be checked even if a generator is returned.
    ranked = list(ranked_from_estimator(model, x_test))
    if len(ranked) != len(test_rows):
        # zip() would silently drop the unmatched tail and skew both tables.
        raise ValueError(
            f"ranked_from_estimator returned {len(ranked)} rankings "
            f"for {len(test_rows)} test rows"
        )

    # Identical for every error row, so build it once.
    test_theorems = ",".join(split_meta["test_theorems"])

    confusions: Counter[tuple[str, str]] = Counter()
    errors: list[dict[str, Any]] = []
    for row, ranking in zip(test_rows, ranked):
        pred = ranking[0]  # top-1 prediction
        gold = row["tactic_family"]
        confusions[(gold, pred)] += 1
        if pred != gold:
            errors.append(
                {
                    "seed": seed,
                    "representation": representation,
                    "model": model_name,
                    "theorem": row["theorem"],
                    "step_index": row["step_index"],
                    "gold": gold,
                    "predicted": pred,
                    "next_tactic": row["next_tactic"],
                    "main_goal": truncate(row["main_goal"]),
                    "local_context": truncate("; ".join(row["local_context"])),
                    "split_test_theorems": test_theorems,
                }
            )

    confusion_rows = [
        {"gold": gold, "predicted": pred, "count": count}
        for (gold, pred), count in confusions.most_common()
    ]
    return confusion_rows, errors


def write_csv(rows: list[dict[str, Any]], path: Path, fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header line.

    Missing parent directories are created first, and an existing file at
    ``path`` is overwritten. Columns appear in ``fieldnames`` order.
    ``csv.DictWriter`` raises on dict keys not listed in ``fieldnames``.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    # newline="" hands line-ending control to the csv module.
    with open(path, mode="w", newline="", encoding="utf-8") as handle:
        out = csv.DictWriter(handle, fieldnames=fieldnames)
        out.writeheader()
        for record in rows:
            out.writerow(record)


def main() -> None:
    """Command-line entry point: train, evaluate, and dump analysis tables.

    Reads the JSONL dataset given by ``--data`` and normalizes each entry with
    optional fields kept. It runs ``analyze`` for the chosen representation,
    model, split ratio and seed. It then writes the confusion counts and the
    misclassified examples to the two output CSV paths.
    """
    cli = argparse.ArgumentParser(description="Generate confusion and error-analysis tables.")
    cli.add_argument("--data", type=Path, default=Path("data/leandojo_steps_checked.jsonl"))
    cli.add_argument("--representation", default="structured")
    cli.add_argument("--model", default="tfidf_linear_svm")
    cli.add_argument("--test-ratio", type=float, default=0.3)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--confusion-output", type=Path, default=Path("results/tables/confusion_matrix.csv"))
    cli.add_argument("--errors-output", type=Path, default=Path("results/tables/error_examples.csv"))
    opts = cli.parse_args()

    normalized = [normalize_entry(raw, include_optional=True) for raw in load_jsonl(opts.data)]
    confusion_rows, error_rows = analyze(
        normalized,
        representation=opts.representation,
        model_name=opts.model,
        test_ratio=opts.test_ratio,
        seed=opts.seed,
    )

    # Column order of the error-examples CSV; must match the keys analyze() emits.
    error_columns = [
        "seed",
        "representation",
        "model",
        "theorem",
        "step_index",
        "gold",
        "predicted",
        "next_tactic",
        "main_goal",
        "local_context",
        "split_test_theorems",
    ]
    write_csv(confusion_rows, opts.confusion_output, ["gold", "predicted", "count"])
    write_csv(error_rows, opts.errors_output, error_columns)

    print(f"Wrote confusion table to: {opts.confusion_output}")
    print(f"Wrote error examples to: {opts.errors_output}")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
