import argparse
import csv
import json
from collections import Counter
from pathlib import Path
from typing import Any

try:
    from proofstate_common import (
        REPRESENTATIONS,
        ALL_REPRESENTATIONS,
        accuracy,
        class_frequency_ranking,
        heuristic_predict,
        load_jsonl,
        macro_f1,
        normalize_entry,
        prepare_representation_rows,
        ranked_labels_from_scores,
        representation_text,
        theorem_split,
        top_k_accuracy,
    )
except ImportError:
    from scripts.proofstate_common import (
        REPRESENTATIONS,
        ALL_REPRESENTATIONS,
        accuracy,
        class_frequency_ranking,
        heuristic_predict,
        load_jsonl,
        macro_f1,
        normalize_entry,
        prepare_representation_rows,
        ranked_labels_from_scores,
        representation_text,
        theorem_split,
        top_k_accuracy,
    )


def sklearn_imports() -> dict[str, Any]:
    """Load the scikit-learn classes used to build the learned baselines.

    The imports happen at call time rather than module import time, so importing
    this module does not by itself require scikit-learn.

    Returns:
        Mapping from class name to the imported scikit-learn class.

    Raises:
        SystemExit: If scikit-learn cannot be imported; the original ImportError
            is chained as the cause.
    """
    try:
        from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
        from sklearn.linear_model import LogisticRegression
        from sklearn.naive_bayes import MultinomialNB
        from sklearn.pipeline import Pipeline
        from sklearn.svm import LinearSVC
    except ImportError as missing:
        hint = "scikit-learn is required. Install dependencies with `pip install -r requirements.txt`."
        raise SystemExit(hint) from missing

    return dict(
        CountVectorizer=CountVectorizer,
        TfidfVectorizer=TfidfVectorizer,
        LogisticRegression=LogisticRegression,
        MultinomialNB=MultinomialNB,
        Pipeline=Pipeline,
        LinearSVC=LinearSVC,
    )


def parse_representations(value: str) -> list[str]:
    """Turn the ``--representations`` CLI value into an ordered list of names.

    Accepted forms:

    * ``"all"`` -> every name in ``REPRESENTATIONS``;
    * ``"all_with_oracle"`` -> every name in ``ALL_REPRESENTATIONS`` (presumably
      ``REPRESENTATIONS`` plus the oracle-only variants -- confirm in
      proofstate_common);
    * a comma-separated list of names drawn from ``ALL_REPRESENTATIONS``.

    Surrounding whitespace is ignored. A name repeated in the list is kept only at
    its first position, so each representation is evaluated and written once.

    Raises:
        SystemExit: If any name is unknown, or if the list selects nothing
            (e.g. ``","``).
    """
    choice = value.strip()
    if choice == "all":
        return list(REPRESENTATIONS)
    if choice == "all_with_oracle":
        return list(ALL_REPRESENTATIONS)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    reps = list(dict.fromkeys(item.strip() for item in choice.split(",") if item.strip()))
    unknown = sorted(set(reps) - set(ALL_REPRESENTATIONS))
    if unknown:
        raise SystemExit(
            f"Unknown representations: {unknown}. Expected one of {ALL_REPRESENTATIONS}, all, or all_with_oracle."
        )
    if not reps:
        # Without this, an empty selection would silently produce an empty summary table.
        raise SystemExit(
            f"No representations selected from {value!r}. Expected one of {ALL_REPRESENTATIONS}, all, or all_with_oracle."
        )
    return reps


def ranked_from_estimator(model: Any, x_test: list[str]) -> list[list[str]]:
    """Rank the fitted model's class labels for every test input.

    Strategy, in order of preference:

    1. ``predict_proba`` scores, ranked by ``ranked_labels_from_scores``.
    2. ``decision_function`` scores, ranked the same way. A 1-D binary margin is
       expanded to ``[-margin, margin]`` so it lines up with ``classes_``.
    3. Hard ``predict`` output. The predicted label comes first, followed by the
       remaining labels in ``classes_`` order.

    All labels are converted to ``str``.
    """
    labels = list(map(str, model.classes_))

    if hasattr(model, "predict_proba"):
        return [ranked_labels_from_scores(labels, list(probs)) for probs in model.predict_proba(x_test)]

    if hasattr(model, "decision_function"):
        margins = model.decision_function(x_test)
        if len(labels) == 2 and getattr(margins, "ndim", 1) == 1:
            # For binary problems sklearn returns one signed score per sample,
            # where a positive value favours classes_[1].
            return [ranked_labels_from_scores(labels, [-float(m), float(m)]) for m in margins]
        return [ranked_labels_from_scores(labels, list(margin_row)) for margin_row in margins]

    rankings = []
    for predicted in map(str, model.predict(x_test)):
        others = [label for label in labels if label != predicted]
        rankings.append([predicted, *others])
    return rankings


def metrics_from_ranked(y_true: list[str], ranked: list[list[str]]) -> dict[str, float]:
    """Score ranked predictions against the gold labels.

    The head of each ranking is the hard prediction used for accuracy and
    macro-F1; an empty ranking counts as the prediction ``""``. The full
    rankings feed the top-1, top-3 and top-5 accuracies.
    """
    hard_predictions = []
    for candidates in ranked:
        hard_predictions.append(candidates[0] if candidates else "")

    summary = {
        "accuracy": accuracy(y_true, hard_predictions),
        "macro_f1": macro_f1(y_true, hard_predictions),
    }
    for column, k in (("top1_accuracy", 1), ("top3_accuracy", 3), ("top5_accuracy", 5)):
        summary[column] = top_k_accuracy(y_true, ranked, k)
    return summary


def build_models(seed: int = 0) -> dict[str, Any]:
    """Create fresh, unfitted scikit-learn pipelines for the learned baselines.

    Every pipeline splits text on whitespace only (``tokenizer=str.split``) and
    keeps case (``lowercase=False``). ``token_pattern=None`` stops scikit-learn
    warning that the default pattern is unused when a tokenizer is supplied.

    Args:
        seed: ``random_state`` passed to the estimators that accept one.
            ``dual="auto"`` may select ``LinearSVC``'s dual coordinate-descent
            solver, which shuffles samples. Left unset, that shuffle can make
            scores differ between otherwise identical runs. Defaults to 0.

    Returns:
        Mapping of model name to an unfitted ``Pipeline`` with ``vectorizer`` and
        ``model`` steps.
    """
    skl = sklearn_imports()
    Pipeline = skl["Pipeline"]
    CountVectorizer = skl["CountVectorizer"]
    TfidfVectorizer = skl["TfidfVectorizer"]
    MultinomialNB = skl["MultinomialNB"]
    LogisticRegression = skl["LogisticRegression"]
    LinearSVC = skl["LinearSVC"]

    def whitespace_vectorizer(vectorizer_cls: Any, **extra: Any) -> Any:
        # A new instance per pipeline so fitted vocabularies are never shared.
        return vectorizer_cls(tokenizer=str.split, token_pattern=None, lowercase=False, **extra)

    return {
        "text_naive_bayes": Pipeline(
            [("vectorizer", whitespace_vectorizer(CountVectorizer)), ("model", MultinomialNB())]
        ),
        "tfidf_logistic_regression": Pipeline(
            [
                ("vectorizer", whitespace_vectorizer(TfidfVectorizer, min_df=1)),
                ("model", LogisticRegression(max_iter=1000, class_weight="balanced", random_state=seed)),
            ]
        ),
        "tfidf_linear_svm": Pipeline(
            [
                ("vectorizer", whitespace_vectorizer(TfidfVectorizer, min_df=1)),
                ("model", LinearSVC(class_weight="balanced", dual="auto", random_state=seed)),
            ]
        ),
    }


def evaluate_representation(
    rows: list[dict[str, Any]],
    representation: str,
    test_ratio: float,
    seed: int,
    retrieved_premise_top_n: int = 8,
) -> dict[str, Any]:
    """Split the data, build one input representation, and score every baseline on it.

    The rows are processed in three steps:

    1. Split with ``theorem_split`` using ``test_ratio`` and ``seed``.
    2. Pass both sides through ``prepare_representation_rows``.
    3. Render each row to model input with ``representation_text``.

    Baselines scored on the test side:

    * ``majority_class``: every test row gets the training-label ranking from
      ``class_frequency_ranking``.
    * ``keyword_heuristic``: the ``heuristic_predict`` label first (falling back
      to the top-ranked training label), then the remaining labels in that same
      ranking.
    * Every pipeline from ``build_models``, fitted on the training side.

    Args:
        rows: Normalized examples; each must carry a ``tactic_family`` label.
        representation: Name of the input representation to evaluate.
        test_ratio: Fraction forwarded to ``theorem_split``.
        seed: Seed forwarded to ``theorem_split``.
        retrieved_premise_top_n: Forwarded to ``prepare_representation_rows``.
            It is echoed in the result only for ``retrieved_premise``.

    Returns:
        A dict with the following keys:

        * the representation name;
        * split metadata;
        * the leakage policy;
        * the label distribution over *all* input rows (not per split);
        * per-baseline metrics from ``metrics_from_ranked``.

    Raises:
        ValueError: If the split, or the representation-specific preparation of
            it, leaves the train or test side empty.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    if not train_rows or not test_rows:
        raise ValueError("Split produced empty train/test set. Expand data or adjust test_ratio.")
    train_rows, test_rows = prepare_representation_rows(
        train_rows,
        test_rows,
        representation,
        retrieved_premise_top_n=retrieved_premise_top_n,
    )
    # NOTE(review): it is not visible here whether preparation can drop rows. Guard
    # anyway, so an empty side fails with a clear message instead of an IndexError
    # on label_ranking[0] or an opaque error from model.fit.
    if not train_rows or not test_rows:
        raise ValueError(
            f"Preparing representation {representation!r} left an empty train/test set."
        )

    y_train = [row["tactic_family"] for row in train_rows]
    y_test = [row["tactic_family"] for row in test_rows]
    x_train = [representation_text(row, representation) for row in train_rows]
    x_test = [representation_text(row, representation) for row in test_rows]

    label_ranking = class_frequency_ranking(y_train)
    majority_ranked = [label_ranking for _ in test_rows]
    fallback = label_ranking[0]

    heuristic_ranked = []
    for row in test_rows:
        prediction = heuristic_predict(row, fallback)
        # The heuristic may name a label unseen in training; it still goes first.
        heuristic_ranked.append([prediction] + [label for label in label_ranking if label != prediction])

    baselines: dict[str, Any] = {
        "majority_class": metrics_from_ranked(y_test, majority_ranked),
        "keyword_heuristic": metrics_from_ranked(y_test, heuristic_ranked),
    }

    for model_name, model in build_models().items():
        model.fit(x_train, y_train)
        ranked = ranked_from_estimator(model, x_test)
        baselines[model_name] = metrics_from_ranked(y_test, ranked)

    return {
        "representation": representation,
        "split": split_meta,
        "retrieved_premise_top_n": retrieved_premise_top_n if representation == "retrieved_premise" else None,
        "leakage_policy": (
            "oracle-only next-tactic premise/AST metadata"
            if representation in {"oracle_premise", "premise"}
            else "state-visible inputs only"
        ),
        "label_distribution": dict(Counter(row["tactic_family"] for row in rows)),
        "baselines": baselines,
    }


def write_summary(results: list[dict[str, Any]], output_path: Path) -> None:
    """Write one CSV row per (representation, model) pair from evaluation results.

    Missing parent directories of ``output_path`` are created, and an existing
    file is overwritten. Each metrics dict must hold only the metric columns
    listed below, because ``csv.DictWriter`` raises on unexpected keys by default.
    """
    columns = [
        "representation",
        "model",
        "accuracy",
        "macro_f1",
        "top1_accuracy",
        "top3_accuracy",
        "top5_accuracy",
    ]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Lazily flatten {representation -> {model -> metrics}} into CSV records.
    summary_rows = (
        {"representation": result["representation"], "model": model_name, **model_metrics}
        for result in results
        for model_name, model_metrics in result["baselines"].items()
    )
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        table.writerows(summary_rows)


def main() -> None:
    """Run the selected representation ablations and write their results.

    For each selected representation, one JSON result is written to
    ``<output-dir>/<data stem>_<representation>_seed<seed>.json``. A combined CSV
    table is then written to ``--table-output``.
    """
    parser = argparse.ArgumentParser(description="Run representation ablations for tactic-family prediction.")
    parser.add_argument("--data", type=Path, default=Path("data/pilot_pairs_checked.jsonl"))
    parser.add_argument("--representations", default="all")
    parser.add_argument("--test-ratio", type=float, default=0.3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--retrieved-premise-top-n", type=int, default=8)
    parser.add_argument("--output-dir", type=Path, default=Path("results/classification"))
    parser.add_argument("--table-output", type=Path, default=Path("results/tables/classification_summary.csv"))
    args = parser.parse_args()

    # Catch cheap-to-detect mistakes as usage errors before any data is loaded,
    # rather than as tracebacks from deep inside the pipeline.
    if not 0.0 < args.test_ratio < 1.0:
        parser.error(f"--test-ratio must be strictly between 0 and 1, got {args.test_ratio}")
    if not args.data.exists():
        parser.error(f"--data file does not exist: {args.data}")
    representations = parse_representations(args.representations)

    rows = [normalize_entry(row, include_optional=True) for row in load_jsonl(args.data)]
    args.output_dir.mkdir(parents=True, exist_ok=True)

    results = []
    dataset_name = args.data.stem
    for representation in representations:
        result = evaluate_representation(
            rows,
            representation,
            args.test_ratio,
            args.seed,
            retrieved_premise_top_n=args.retrieved_premise_top_n,
        )
        results.append(result)
        output_path = args.output_dir / f"{dataset_name}_{representation}_seed{args.seed}.json"
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False, sort_keys=True)
        print(f"{representation}: wrote {output_path}")

    write_summary(results, args.table_output)
    print(f"Wrote summary table to: {args.table_output}")


# Run the CLI only when this file is executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
