import argparse
import csv
import json
from collections import Counter
from pathlib import Path
from typing import Any

try:
    from proofstate_common import (
        ALL_REPRESENTATIONS,
        class_frequency_ranking,
        grouped_split,
        heuristic_predict,
        load_jsonl,
        normalize_entry,
        prepare_representation_rows,
        representation_text,
    )
    from run_experiments import build_models, metrics_from_ranked, parse_representations, ranked_from_estimator
except ImportError:
    from scripts.proofstate_common import (
        ALL_REPRESENTATIONS,
        class_frequency_ranking,
        grouped_split,
        heuristic_predict,
        load_jsonl,
        normalize_entry,
        prepare_representation_rows,
        representation_text,
    )
    from scripts.run_experiments import build_models, metrics_from_ranked, parse_representations, ranked_from_estimator


def evaluate_group_split(
    rows: list[dict[str, Any]],
    representation: str,
    group_field: str,
    test_ratio: float,
    seed: int,
) -> dict[str, Any]:
    """Score every baseline for one representation on a grouped train/test split.

    The rows are split with ``grouped_split`` on ``group_field``, passed
    through ``prepare_representation_rows``, and then scored with two
    non-learned baselines (``majority_class`` and ``keyword_heuristic``)
    plus every estimator returned by ``build_models``.

    Args:
        rows: Entries carrying at least a ``"tactic_family"`` label.
        representation: Must be a member of ``ALL_REPRESENTATIONS``.
        group_field: Field handed to ``grouped_split`` to form the groups.
        test_ratio: Test-share argument forwarded to ``grouped_split``.
        seed: Seed forwarded to ``grouped_split``.

    Returns:
        A dict with the keys ``representation``, ``split`` (the metadata
        returned by ``grouped_split``), ``label_distribution`` (label counts
        over *all* input rows, not only one side of the split) and
        ``baselines`` (baseline/model name -> metrics).

    Raises:
        ValueError: If ``representation`` is not in ``ALL_REPRESENTATIONS``.
    """
    if representation not in ALL_REPRESENTATIONS:
        raise ValueError(f"Unknown representation: {representation}")

    fit_rows, eval_rows, split_meta = grouped_split(rows, group_field, test_ratio, seed)
    fit_rows, eval_rows = prepare_representation_rows(fit_rows, eval_rows, representation)

    y_train = [entry["tactic_family"] for entry in fit_rows]
    y_test = [entry["tactic_family"] for entry in eval_rows]
    x_train = [representation_text(entry, representation) for entry in fit_rows]
    x_test = [representation_text(entry, representation) for entry in eval_rows]

    # Training labels ordered by class_frequency_ranking; its first entry is
    # the default handed to the keyword heuristic.
    frequency_order = class_frequency_ranking(y_train)
    default_label = frequency_order[0] if frequency_order else "unknown"

    # Majority baseline: every test row receives the same ranking.
    majority_ranked = [frequency_order for _ in eval_rows]

    # Heuristic baseline: the heuristic's guess goes first, followed by the
    # remaining labels in frequency order.
    heuristic_ranked = [
        [guess, *(label for label in frequency_order if label != guess)]
        for guess in (heuristic_predict(entry, default_label) for entry in eval_rows)
    ]

    scores: dict[str, Any] = {}
    scores["majority_class"] = metrics_from_ranked(y_test, majority_ranked)
    scores["keyword_heuristic"] = metrics_from_ranked(y_test, heuristic_ranked)

    for name, estimator in build_models().items():
        estimator.fit(x_train, y_train)
        ranked = ranked_from_estimator(estimator, x_test)
        scores[name] = metrics_from_ranked(y_test, ranked)

    label_counts = Counter(entry["tactic_family"] for entry in rows)
    return {
        "representation": representation,
        "split": split_meta,
        "label_distribution": dict(label_counts),
        "baselines": scores,
    }


def write_summary(results: list[dict[str, Any]], output_path: Path) -> None:
    """Write one CSV row per (representation, model) pair to ``output_path``.

    Parent directories are created when missing and an existing file is
    overwritten. Each row holds the result's ``representation``, the model
    name and that model's metrics; metrics missing from the fixed column set
    are left blank, and a metric key outside the column set makes
    ``csv.DictWriter`` raise ``ValueError``.

    Args:
        results: Dicts with a ``"representation"`` entry and a ``"baselines"``
            mapping of model name -> metrics dict.
        output_path: Destination CSV file.
    """
    columns = [
        "representation",
        "model",
        "accuracy",
        "macro_f1",
        "top1_accuracy",
        "top3_accuracy",
        "top5_accuracy",
    ]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        # Rows are generated lazily, in result order and then model order.
        table.writerows(
            {"representation": result["representation"], "model": model_name, **metrics}
            for result in results
            for model_name, metrics in result["baselines"].items()
        )

def main() -> None:
    """Command-line entry point: evaluate each requested representation.

    Loads and normalizes the JSONL file given by ``--data``, runs
    ``evaluate_group_split`` once per representation parsed from
    ``--representations``, writes each result as a JSON file under
    ``--output-dir``, and finally writes the combined CSV table to
    ``--table-output``.
    """
    cli = argparse.ArgumentParser(description="Run representation ablations with file/module-level splits.")
    # Registered in this order so the generated --help output stays stable.
    option_specs = (
        ("--data", {"type": Path, "required": True}),
        ("--representations", {"default": "state_only,state_meta,retrieved_premise"}),
        ("--group-field", {"default": "file"}),
        ("--test-ratio", {"type": float, "default": 0.3}),
        ("--seed", {"type": int, "default": 42}),
        ("--output-dir", {"type": Path, "default": Path("results/classification")}),
        (
            "--table-output",
            {"type": Path, "default": Path("results/tables/file_split_classification_summary.csv")},
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    entries = [normalize_entry(raw, include_optional=True) for raw in load_jsonl(options.data)]
    json_dir = options.output_dir
    json_dir.mkdir(parents=True, exist_ok=True)

    collected = []
    for name in parse_representations(options.representations):
        outcome = evaluate_group_split(entries, name, options.group_field, options.test_ratio, options.seed)
        collected.append(outcome)
        # One JSON file per representation, keyed by dataset stem, group field and seed.
        destination = json_dir / f"{options.data.stem}_{options.group_field}_split_{name}_seed{options.seed}.json"
        payload = json.dumps(outcome, indent=2, ensure_ascii=False, sort_keys=True)
        destination.write_text(payload, encoding="utf-8")
        print(f"{name}: wrote {destination}")

    write_summary(collected, options.table_output)
    print(f"Wrote file/module split summary to: {options.table_output}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
