import argparse
import csv
import json
import math
from collections import Counter
from pathlib import Path
from typing import Any

try:
    from proofstate_common import (
        class_frequency_ranking,
        heuristic_predict,
        load_jsonl,
        normalize_entry,
        prepare_representation_rows,
        ranked_labels_from_scores,
        representation_text,
        theorem_split,
        train_val_test_split,
    )
except ImportError:
    from scripts.proofstate_common import (
        class_frequency_ranking,
        heuristic_predict,
        load_jsonl,
        normalize_entry,
        prepare_representation_rows,
        ranked_labels_from_scores,
        representation_text,
        theorem_split,
        train_val_test_split,
    )


def sklearn_imports() -> dict[str, Any]:
    """Import the scikit-learn pieces this script uses and return them keyed by name.

    The import happens lazily so the module itself can be imported without
    scikit-learn installed.

    Returns:
        A dict with the vectorizers, classifiers, ``Pipeline`` and
        ``cosine_similarity``.

    Raises:
        SystemExit: If scikit-learn cannot be imported.
    """
    try:
        from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
        from sklearn.linear_model import LogisticRegression
        from sklearn.naive_bayes import MultinomialNB
        from sklearn.metrics.pairwise import cosine_similarity
        from sklearn.pipeline import Pipeline
    except ImportError as exc:
        message = "scikit-learn is required. Install dependencies with `pip install -r requirements.txt`."
        raise SystemExit(message) from exc
    return dict(
        CountVectorizer=CountVectorizer,
        TfidfVectorizer=TfidfVectorizer,
        LogisticRegression=LogisticRegression,
        MultinomialNB=MultinomialNB,
        Pipeline=Pipeline,
        cosine_similarity=cosine_similarity,
    )


def ranked_from_estimator(model: Any, x_test: list[str]) -> list[list[str]]:
    """Return, for each text in ``x_test``, the model's class labels best-first.

    When the estimator exposes ``predict_proba`` the labels are ordered by
    ``ranked_labels_from_scores`` over the predicted probabilities. Otherwise
    the predicted label comes first, followed by the remaining classes in
    ``model.classes_`` order. Labels are always returned as strings.
    """
    labels = list(map(str, model.classes_))
    if not hasattr(model, "predict_proba"):
        ranked: list[list[str]] = []
        for predicted in map(str, model.predict(x_test)):
            ranked.append([predicted, *(label for label in labels if label != predicted)])
        return ranked
    return [ranked_labels_from_scores(labels, list(scores)) for scores in model.predict_proba(x_test)]


def probability_maps_from_estimator(model: Any, x_test: list[str]) -> list[dict[str, float]]:
    """Return one ``{label: score}`` map per text in ``x_test``.

    With ``predict_proba`` the scores are the predicted class probabilities.
    Without it, each label gets the reciprocal-rank pseudo-score
    ``1 / (1 + rank)`` from ``ranked_from_estimator`` (1.0 for the predicted
    label, 0.5 for the next, and so on).
    """
    labels = [str(label) for label in model.classes_]
    if hasattr(model, "predict_proba"):
        maps: list[dict[str, float]] = []
        for scores in model.predict_proba(x_test):
            maps.append(dict(zip(labels, map(float, scores))))
        return maps
    return [
        {label: 1.0 / (1.0 + position) for position, label in enumerate(ranking)}
        for ranking in ranked_from_estimator(model, x_test)
    ]


def train_family_model(
    train_rows: list[dict[str, Any]], representation: str, family_model: str
) -> Any:
    """Fit a text classifier that predicts ``tactic_family`` from representation text.

    Both variants split tokens on whitespace and keep case:

    * ``"naive_bayes"``: raw token counts + ``MultinomialNB``.
    * ``"logistic_regression"``: TF-IDF + class-balanced ``LogisticRegression``
      (``max_iter=1000``).

    Returns:
        The fitted scikit-learn ``Pipeline`` (steps ``"vectorizer"``, ``"model"``).

    Raises:
        ValueError: If ``family_model`` is not one of the two variants above.
    """
    skl = sklearn_imports()
    if family_model == "naive_bayes":
        vectorizer = skl["CountVectorizer"](tokenizer=str.split, token_pattern=None, lowercase=False)
        classifier = skl["MultinomialNB"]()
    elif family_model == "logistic_regression":
        vectorizer = skl["TfidfVectorizer"](tokenizer=str.split, token_pattern=None, lowercase=False, min_df=1)
        classifier = skl["LogisticRegression"](max_iter=1000, class_weight="balanced")
    else:
        raise ValueError(f"Unknown family model: {family_model}")
    pipeline = skl["Pipeline"]([("vectorizer", vectorizer), ("model", classifier)])

    features = [representation_text(row, representation) for row in train_rows]
    targets = [row["tactic_family"] for row in train_rows]
    pipeline.fit(features, targets)
    return pipeline


def _interleave_top_families(
    row_similarities: list[float],
    train_rows: list[dict[str, Any]],
    top_families: list[str],
) -> list[dict[str, Any]]:
    """Round-robin the most similar candidates of ``top_families``, then append the rest.

    Each round takes the next-most-similar candidate of every family in
    ``top_families`` (in that order). Candidates of all other families follow
    in descending-similarity order.
    """
    buckets: dict[str, list[tuple[float, dict[str, Any]]]] = {family: [] for family in top_families}
    rest: list[tuple[float, dict[str, Any]]] = []
    for similarity, train_row in zip(row_similarities, train_rows):
        family = train_row["tactic_family"]
        if family in buckets:
            buckets[family].append((similarity, train_row))
        else:
            rest.append((similarity, train_row))
    for bucket in buckets.values():
        bucket.sort(key=lambda item: item[0], reverse=True)
    rest.sort(key=lambda item: item[0], reverse=True)

    ranked: list[dict[str, Any]] = []
    # A round contributes something as long as some bucket still has entries at that offset.
    rounds = max((len(bucket) for bucket in buckets.values()), default=0)
    for offset in range(rounds):
        for family in top_families:
            bucket = buckets[family]
            if offset < len(bucket):
                ranked.append(bucket[offset][1])
    ranked.extend(candidate for _, candidate in rest)
    return ranked


def candidate_rankings(
    train_rows: list[dict[str, Any]],
    test_rows: list[dict[str, Any]],
    representation: str,
    strategy: str,
    family_weight: float = 0.25,
    family_model: str = "logistic_regression",
    top_m: int = 3,
    rrf_k: int = 60,
) -> list[list[dict[str, Any]]]:
    """Rank every training row as a candidate next step for each test row.

    Candidates are scored with TF-IDF cosine similarity between representation
    texts. The vectorizer is fitted on ``train_rows`` only. ``strategy``
    optionally mixes in a tactic-family signal:

    * ``"family_guided"``: order by the trained family model's ranking of the
      candidate's family, then by similarity.
    * ``"oracle_family"``: same, but the gold family is ranked first, followed
      by training families in frequency order.
    * ``"family_soft"``: order by ``similarity + family_weight * P(family)``.
    * ``"family_top_m"``: round-robin over the ``top_m`` predicted families,
      then all remaining candidates by similarity.
    * ``"family_rrf"``: reciprocal-rank fusion of the similarity rank and the
      family rank with constant ``rrf_k``.
    * anything else (e.g. ``"unguided"``): plain similarity.

    Returns:
        One list per test row containing all training rows, best candidate first.
    """
    skl = sklearn_imports()
    vectorizer = skl["TfidfVectorizer"](tokenizer=str.split, token_pattern=None, lowercase=False, min_df=1)
    cosine_similarity = skl["cosine_similarity"]

    train_texts = [representation_text(row, representation) for row in train_rows]
    test_texts = [representation_text(row, representation) for row in test_rows]
    train_matrix = vectorizer.fit_transform(train_texts)
    test_matrix = vectorizer.transform(test_texts)
    similarities = cosine_similarity(test_matrix, train_matrix)

    label_ranking = class_frequency_ranking(row["tactic_family"] for row in train_rows)
    family_rankings: list[list[str]]
    family_probabilities: list[dict[str, float]]
    if strategy in {"family_guided", "family_top_m", "family_rrf", "family_soft"}:
        model = train_family_model(train_rows, representation, family_model)
        family_rankings = ranked_from_estimator(model, test_texts)
        if strategy == "family_soft":
            family_probabilities = probability_maps_from_estimator(model, test_texts)
        else:
            family_probabilities = [{} for _ in test_rows]
    elif strategy == "oracle_family":
        family_rankings = [
            [row["tactic_family"]] + [label for label in label_ranking if label != row["tactic_family"]]
            for row in test_rows
        ]
        family_probabilities = [{row["tactic_family"]: 1.0} for row in test_rows]
    else:
        family_rankings = [label_ranking for _ in test_rows]
        family_probabilities = [{} for _ in test_rows]

    n_train = len(train_rows)
    ranked_candidates: list[list[dict[str, Any]]] = []
    for row_idx, (family_ranking, probabilities) in enumerate(zip(family_rankings, family_probabilities)):
        # Convert this query's similarities to Python floats once; every strategy reuses them.
        row_similarities = [float(similarities[row_idx, idx]) for idx in range(n_train)]

        if strategy == "family_top_m":
            ranked_candidates.append(
                _interleave_top_families(row_similarities, train_rows, family_ranking[:top_m])
            )
            continue

        family_priority = {family: rank for rank, family in enumerate(family_ranking)}
        similarity_rank: dict[int, int] = {}
        if strategy == "family_rrf":
            # Only RRF needs the similarity rank; skip this extra sort for other strategies.
            order = sorted(range(n_train), key=row_similarities.__getitem__, reverse=True)
            similarity_rank = {candidate_idx: rank for rank, candidate_idx in enumerate(order)}

        row_scores: list[tuple[tuple[float, ...], dict[str, Any]]] = []
        for candidate_idx, (similarity, train_row) in enumerate(zip(row_similarities, train_rows)):
            family = train_row["tactic_family"]
            if strategy in {"family_guided", "oracle_family"}:
                # Families absent from the ranking sort after every ranked family.
                priority = family_priority.get(family, len(family_priority) + 1)
                score = (1.0 / (1.0 + priority), similarity)
            elif strategy == "family_soft":
                score = (similarity + family_weight * probabilities.get(family, 0.0), similarity)
            elif strategy == "family_rrf":
                family_rank = family_priority.get(family, len(family_priority)) + 1
                sim_rank = similarity_rank[candidate_idx] + 1
                score = (1.0 / (rrf_k + sim_rank) + 1.0 / (rrf_k + family_rank), similarity)
            else:
                score = (similarity,)
            row_scores.append((score, train_row))
        # Stable sort: ties keep training-row order.
        row_scores.sort(key=lambda item: item[0], reverse=True)
        ranked_candidates.append([candidate for _, candidate in row_scores])
    return ranked_candidates


def evaluate_search(
    rows: list[dict[str, Any]],
    strategy: str,
    representation: str,
    test_ratio: float,
    seed: int,
    max_k: int,
    family_weight: float = 0.25,
    family_model: str = "logistic_regression",
    retrieved_premise_top_n: int = 8,
    top_m: int = 3,
    rrf_k: int = 60,
) -> dict[str, Any]:
    """Evaluate one search strategy on a train/test split of ``rows``.

    ``rows`` are split with ``theorem_split`` and prepared for
    ``representation``. Every test row is then ranked against all training
    rows with :func:`candidate_rankings`.

    Returns:
        A JSON-serializable dict containing:

        * configuration fields, ``None`` when not applicable to the strategy or
          representation;
        * split metadata;
        * ``metrics``: family and exact-tactic success at each cutoff in
          ``1, 3, 5, 10, max_k`` (capped at ``max_k``), plus
          ``mean_family_rank``, the mean position of the first
          correct-family candidate over queries that have one (0.0 if none);
        * ``per_query`` hit records.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    train_rows, test_rows = prepare_representation_rows(
        train_rows,
        test_rows,
        representation,
        retrieved_premise_top_n=retrieved_premise_top_n,
    )
    ranked_candidates = candidate_rankings(
        train_rows,
        test_rows,
        representation,
        strategy,
        family_weight=family_weight,
        family_model=family_model,
        top_m=top_m,
        rrf_k=rrf_k,
    )

    # min(k, max_k) repeats max_k when max_k <= 10; evaluate each distinct cutoff once.
    cutoffs = sorted({min(k, max_k) for k in (1, 3, 5, 10, max_k)})
    family_hits = dict.fromkeys(cutoffs, 0)
    exact_hits = dict.fromkeys(cutoffs, 0)
    family_ranks: list[int] = []
    per_query: list[dict[str, Any]] = []
    # Single pass: per-query hit flags also feed the aggregate counts.
    for test_row, candidates in zip(test_rows, ranked_candidates):
        gold_family = test_row["tactic_family"]
        gold_tactic = test_row["next_tactic"]
        query: dict[str, Any] = {
            "query_id": f"{test_row.get('file', '')}:{test_row.get('theorem', '')}:{test_row.get('step_index', 0)}",
            "theorem": test_row.get("theorem", ""),
            "file": test_row.get("file", ""),
            "step_index": test_row.get("step_index", 0),
            "gold_family": gold_family,
            "gold_tactic": gold_tactic,
            "top_family": candidates[0]["tactic_family"] if candidates else "",
            "top_tactic": candidates[0]["next_tactic"] if candidates else "",
        }
        for k in cutoffs:
            top = candidates[:k]
            family_hit = any(candidate["tactic_family"] == gold_family for candidate in top)
            exact_hit = any(candidate["next_tactic"] == gold_tactic for candidate in top)
            query[f"family_hit_at_{k}"] = family_hit
            query[f"exact_hit_at_{k}"] = exact_hit
            family_hits[k] += family_hit
            exact_hits[k] += exact_hit
        first_family_rank = next(
            (idx for idx, candidate in enumerate(candidates, start=1) if candidate["tactic_family"] == gold_family),
            None,
        )
        if first_family_rank is not None:
            family_ranks.append(first_family_rank)
        per_query.append(query)

    n_queries = len(test_rows)
    metrics: dict[str, float] = {}
    for k in cutoffs:
        metrics[f"family_success_at_{k}"] = family_hits[k] / n_queries if n_queries else 0.0
        metrics[f"exact_tactic_success_at_{k}"] = exact_hits[k] / n_queries if n_queries else 0.0
    metrics["mean_family_rank"] = sum(family_ranks) / len(family_ranks) if family_ranks else 0.0

    return {
        "strategy": strategy,
        "representation": representation,
        "family_weight": family_weight if strategy == "family_soft" else None,
        "family_model": family_model if strategy in {"family_guided", "family_soft", "family_top_m", "family_rrf"} else None,
        "top_m": top_m if strategy == "family_top_m" else None,
        "rrf_k": rrf_k if strategy == "family_rrf" else None,
        "retrieved_premise_top_n": retrieved_premise_top_n if representation == "retrieved_premise" else None,
        "split": split_meta,
        "max_k": max_k,
        "n_train_candidates": len(train_rows),
        "label_distribution": dict(Counter(row["tactic_family"] for row in rows)),
        "metrics": metrics,
        "per_query": per_query,
        "note": (
            "This is a retrieval-search proxy. family_soft ranks candidates by "
            "cosine similarity plus family_weight times predicted family probability."
        ),
    }


def select_family_weight(
    train_rows: list[dict[str, Any]],
    val_rows: list[dict[str, Any]],
    representation: str,
    seed: int,
    max_k: int,
    family_weights: list[float],
    family_model: str,
    retrieved_premise_top_n: int,
) -> tuple[float, list[dict[str, Any]]]:
    """Pick the ``family_soft`` weight that ranks the validation rows best.

    Each candidate weight ranks ``val_rows`` against ``train_rows`` with the
    ``family_soft`` strategy. The winner is chosen by:

    1. exact-tactic success@5, with the cutoff capped at ``max_k``;
    2. then exact-tactic success@1;
    3. then family success@5.

    Any remaining ties go to the smallest weight.

    Args:
        seed: Accepted for interface compatibility; not used.

    Returns:
        ``(selected_weight, per_weight_validation_metrics)``.

    Raises:
        ValueError: If ``family_weights`` is empty.
    """
    del seed  # unused; kept so callers can keep passing it positionally
    if not family_weights:
        raise ValueError("family_weights must contain at least one candidate weight.")

    # Representation preparation does not depend on the weight, so do it once
    # instead of once per candidate weight.
    prepared_train, prepared_val = prepare_representation_rows(
        train_rows,
        val_rows,
        representation,
        retrieved_premise_top_n=retrieved_premise_top_n,
    )
    top5_cutoff = min(5, max_k)
    n = len(prepared_val) or 1  # avoid division by zero on an empty validation split

    val_results = []
    for weight in family_weights:
        ranked = candidate_rankings(
            prepared_train,
            prepared_val,
            representation,
            "family_soft",
            family_weight=weight,
            family_model=family_model,
        )
        exact5 = exact1 = family5 = 0
        for row, candidates in zip(prepared_val, ranked):
            top5 = candidates[:top5_cutoff]
            exact5 += any(candidate["next_tactic"] == row["next_tactic"] for candidate in top5)
            exact1 += any(candidate["next_tactic"] == row["next_tactic"] for candidate in candidates[:1])
            family5 += any(candidate["tactic_family"] == row["tactic_family"] for candidate in top5)
        val_results.append(
            {
                "family_weight": weight,
                "exact_tactic_success_at_5": exact5 / n,
                "exact_tactic_success_at_1": exact1 / n,
                "family_success_at_5": family5 / n,
            }
        )
    best = max(
        val_results,
        key=lambda row: (
            row["exact_tactic_success_at_5"],
            row["exact_tactic_success_at_1"],
            row["family_success_at_5"],
            -row["family_weight"],
        ),
    )
    return float(best["family_weight"]), val_results


def evaluate_search_with_validation(
    rows: list[dict[str, Any]],
    strategy: str,
    representation: str,
    test_ratio: float,
    val_ratio: float,
    seed: int,
    max_k: int,
    family_weights: list[float],
    family_model: str = "logistic_regression",
    retrieved_premise_top_n: int = 8,
    top_m: int = 3,
    rrf_k: int = 60,
) -> dict[str, Any]:
    """Evaluate one strategy on a train/val/test split, tuning ``family_soft`` on val.

    For ``"family_soft"`` the weight is chosen from ``family_weights`` by
    :func:`select_family_weight`, using only the train and validation splits.
    Test labels therefore never influence weight selection. Other strategies
    pass ``family_weights[0]`` (or 0.0 when empty) through; only family_soft
    scoring reads it. Candidates always come from the train split.

    Returns:
        A JSON-serializable dict containing:

        * the same configuration, metrics and ``per_query`` fields as
          :func:`evaluate_search`, with metrics at cutoffs ``1, 3, 5, 10,
          max_k`` capped at ``max_k``;
        * a ``validation_selection`` section describing the weight search.
    """
    train_rows, val_rows, test_rows, split_meta = train_val_test_split(
        rows,
        test_ratio=test_ratio,
        val_ratio=val_ratio,
        seed=seed,
    )
    selected_weight = family_weights[0] if family_weights else 0.0
    validation_results: list[dict[str, Any]] = []
    if strategy == "family_soft":
        selected_weight, validation_results = select_family_weight(
            train_rows,
            val_rows,
            representation,
            seed,
            max_k,
            family_weights,
            family_model,
            retrieved_premise_top_n,
        )
    prepared_train, prepared_test = prepare_representation_rows(
        train_rows,
        test_rows,
        representation,
        retrieved_premise_top_n=retrieved_premise_top_n,
    )
    ranked_candidates = candidate_rankings(
        prepared_train,
        prepared_test,
        representation,
        strategy,
        family_weight=selected_weight,
        family_model=family_model,
        top_m=top_m,
        rrf_k=rrf_k,
    )

    # min(k, max_k) repeats max_k when max_k <= 10; evaluate each distinct cutoff once.
    cutoffs = sorted({min(k, max_k) for k in (1, 3, 5, 10, max_k)})
    family_hits = dict.fromkeys(cutoffs, 0)
    exact_hits = dict.fromkeys(cutoffs, 0)
    family_ranks: list[int] = []
    per_query: list[dict[str, Any]] = []
    # Single pass: per-query hit flags also feed the aggregate counts.
    for test_row, candidates in zip(prepared_test, ranked_candidates):
        gold_family = test_row["tactic_family"]
        gold_tactic = test_row["next_tactic"]
        query: dict[str, Any] = {
            "query_id": f"{test_row.get('file', '')}:{test_row.get('theorem', '')}:{test_row.get('step_index', 0)}",
            "theorem": test_row.get("theorem", ""),
            "file": test_row.get("file", ""),
            "step_index": test_row.get("step_index", 0),
            "gold_family": gold_family,
            "gold_tactic": gold_tactic,
            "top_family": candidates[0]["tactic_family"] if candidates else "",
            "top_tactic": candidates[0]["next_tactic"] if candidates else "",
        }
        for k in cutoffs:
            top = candidates[:k]
            family_hit = any(candidate["tactic_family"] == gold_family for candidate in top)
            exact_hit = any(candidate["next_tactic"] == gold_tactic for candidate in top)
            query[f"family_hit_at_{k}"] = family_hit
            query[f"exact_hit_at_{k}"] = exact_hit
            family_hits[k] += family_hit
            exact_hits[k] += exact_hit
        first_family_rank = next(
            (idx for idx, candidate in enumerate(candidates, start=1) if candidate["tactic_family"] == gold_family),
            None,
        )
        if first_family_rank is not None:
            family_ranks.append(first_family_rank)
        per_query.append(query)

    n_queries = len(prepared_test)
    metrics: dict[str, float] = {}
    for k in cutoffs:
        metrics[f"family_success_at_{k}"] = family_hits[k] / n_queries if n_queries else 0.0
        metrics[f"exact_tactic_success_at_{k}"] = exact_hits[k] / n_queries if n_queries else 0.0
    metrics["mean_family_rank"] = sum(family_ranks) / len(family_ranks) if family_ranks else 0.0

    return {
        "strategy": strategy,
        "representation": representation,
        "family_weight": selected_weight if strategy == "family_soft" else None,
        "family_model": family_model if strategy in {"family_guided", "family_soft", "family_top_m", "family_rrf"} else None,
        "top_m": top_m if strategy == "family_top_m" else None,
        "rrf_k": rrf_k if strategy == "family_rrf" else None,
        "retrieved_premise_top_n": retrieved_premise_top_n if representation == "retrieved_premise" else None,
        "split": split_meta,
        "max_k": max_k,
        "n_train_candidates": len(prepared_train),
        "label_distribution": dict(Counter(row["tactic_family"] for row in rows)),
        "metrics": metrics,
        "per_query": per_query,
        "validation_selection": {
            "selected_family_weight": selected_weight if strategy == "family_soft" else None,
            "validation_results": validation_results,
            "selection_metric": "exact_tactic_success_at_5, then exact@1, then family@5",
        },
        "note": (
            "Validation mode keeps test labels out of family_weight selection; "
            "family_soft uses the selected validation weight on the test split."
        ),
    }


def write_summary(results: list[dict[str, Any]], output_path: Path) -> None:
    """Write one CSV row per result: configuration columns, then every metric.

    Metric columns are the sorted union of metric names across ``results``.
    A result lacking a metric gets an empty cell, and ``None`` values are
    written as empty cells. Missing parent directories of ``output_path`` are
    created first.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    config_fields = ["strategy", "representation", "family_model", "family_weight", "top_m", "rrf_k"]
    metric_names = sorted(set().union(*(result["metrics"] for result in results)))
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=[*config_fields, *metric_names])
        writer.writeheader()
        for result in results:
            # strategy/representation are required; the remaining config fields are optional.
            config = {"strategy": result["strategy"], "representation": result["representation"]}
            config.update({field: result.get(field) for field in config_fields[2:]})
            writer.writerow({**config, **result["metrics"]})


def parse_strategies(value: str) -> list[str]:
    """Parse ``--strategy``: ``"all"`` or a comma-separated list of strategy names.

    Surrounding whitespace is ignored. A repeated name is kept once (first
    occurrence wins), so no experiment is run twice and no output file is
    overwritten by an identical run.

    Raises:
        SystemExit: If a name is unknown or no strategy is given.
    """
    choices = ["unguided", "family_guided", "family_soft", "family_top_m", "family_rrf", "oracle_family"]
    if value.strip() == "all":
        return choices
    # dict.fromkeys de-duplicates while preserving the user's order.
    strategies = list(dict.fromkeys(item.strip() for item in value.split(",") if item.strip()))
    unknown = sorted(set(strategies) - set(choices))
    if unknown:
        raise SystemExit(f"Unknown strategies: {unknown}. Expected one of {choices} or all.")
    if not strategies:
        raise SystemExit(f"--strategy did not name any strategy. Expected one of {choices} or all.")
    return strategies


def parse_family_weights(value: str) -> list[float]:
    """Parse ``--family-weights`` (comma-separated numbers) into a list of floats.

    Empty items are skipped. Every weight must be a finite, non-negative number.

    Raises:
        SystemExit: On a non-numeric, non-finite or negative item, or when no
            weight is given.
    """
    weights = []
    for item in value.split(","):
        item = item.strip()
        if not item:
            continue
        try:
            weight = float(item)
        except ValueError:
            # Report bad CLI input as a usage error instead of a traceback.
            raise SystemExit(f"--family-weights contains a non-numeric value: {item!r}.") from None
        if not math.isfinite(weight):
            raise SystemExit("--family-weights must be finite numbers.")
        if weight < 0:
            raise SystemExit("--family-weights must be non-negative.")
        weights.append(weight)
    if not weights:
        raise SystemExit("--family-weights did not contain any numeric values.")
    return weights


def parse_family_model(value: str) -> str:
    """Validate ``--family-model`` and return it unchanged.

    Raises:
        SystemExit: If ``value`` is neither ``"logistic_regression"`` nor
            ``"naive_bayes"``.
    """
    supported = frozenset({"naive_bayes", "logistic_regression"})
    if value in supported:
        return value
    raise SystemExit(f"Unknown --family-model {value!r}. Expected one of {sorted(supported)}.")


def main() -> None:
    """Command-line entry point: run the requested search strategies and save results.

    Loads and normalizes ``--data``, then evaluates every requested strategy.
    It uses a train/test split, or a train/val/test split with
    ``--use-validation``. It writes one JSON file per run into
    ``--output-dir`` and a CSV summary of all runs to ``--table-output``.
    """
    parser = argparse.ArgumentParser(description="Run lightweight retrieval-search experiments.")
    parser.add_argument("--data", type=Path, default=Path("data/pilot_pairs_checked.jsonl"))
    parser.add_argument(
        "--strategy",
        default="family_guided",
        help=(
            "Comma-separated strategies or 'all'. Choices: unguided, family_guided, "
            "family_soft, family_top_m, family_rrf, oracle_family."
        ),
    )
    parser.add_argument("--representation", default="normalized")
    parser.add_argument("--test-ratio", type=float, default=0.3)
    parser.add_argument("--val-ratio", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-k", type=int, default=20)
    parser.add_argument(
        "--use-validation",
        action="store_true",
        help="Use a train/val/test split and select the family_soft weight on the validation split.",
    )
    parser.add_argument("--retrieved-premise-top-n", type=int, default=8)
    parser.add_argument(
        "--top-m",
        type=int,
        default=3,
        help="Number of predicted families interleaved by family_top_m.",
    )
    parser.add_argument(
        "--rrf-k",
        type=int,
        default=60,
        help="Reciprocal-rank-fusion constant used by family_rrf.",
    )
    parser.add_argument(
        "--family-weights",
        default="0.25",
        help="Comma-separated weights for family_soft scoring.",
    )
    parser.add_argument(
        "--family-model",
        default="logistic_regression",
        help=(
            "Family predictor used by family_guided, family_soft, family_top_m and family_rrf: "
            "logistic_regression or naive_bayes."
        ),
    )
    parser.add_argument("--output-dir", type=Path, default=Path("results/search"))
    parser.add_argument("--table-output", type=Path, default=Path("results/tables/search_summary.csv"))
    args = parser.parse_args()

    rows = [normalize_entry(row, include_optional=True) for row in load_jsonl(args.data)]
    strategies = parse_strategies(args.strategy)
    family_weights = parse_family_weights(args.family_weights)
    family_model = parse_family_model(args.family_model)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    model_strategies = {"family_guided", "family_soft", "family_top_m", "family_rrf"}
    validation_suffix = "_val" if args.use_validation else ""
    results = []
    for strategy in strategies:
        # Without validation, family_soft is evaluated once per weight. With
        # validation it selects its own weight, so every strategy runs once.
        sweep_weights = strategy == "family_soft" and not args.use_validation
        for family_weight in (family_weights if sweep_weights else [0.0]):
            if args.use_validation:
                result = evaluate_search_with_validation(
                    rows,
                    strategy,
                    args.representation,
                    args.test_ratio,
                    args.val_ratio,
                    args.seed,
                    args.max_k,
                    family_weights=family_weights,
                    family_model=family_model,
                    retrieved_premise_top_n=args.retrieved_premise_top_n,
                    top_m=args.top_m,
                    rrf_k=args.rrf_k,
                )
            else:
                result = evaluate_search(
                    rows,
                    strategy,
                    args.representation,
                    args.test_ratio,
                    args.seed,
                    args.max_k,
                    family_weight=family_weight,
                    family_model=family_model,
                    retrieved_premise_top_n=args.retrieved_premise_top_n,
                    top_m=args.top_m,
                    rrf_k=args.rrf_k,
                )
            results.append(result)
            # In validation mode this is the weight chosen on the validation split.
            selected_weight = result.get("family_weight", family_weight)
            weight_suffix = f"_w{selected_weight:g}" if strategy == "family_soft" else ""
            model_suffix = f"_{family_model}" if strategy in model_strategies else ""
            output_path = args.output_dir / (
                f"{args.data.stem}_{strategy}{model_suffix}{weight_suffix}{validation_suffix}"
                f"_{args.representation}_seed{args.seed}.json"
            )
            with output_path.open("w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False, sort_keys=True)
            print(f"{strategy}{weight_suffix}: wrote {output_path}")

    write_summary(results, args.table_output)
    print(f"Wrote search summary to: {args.table_output}")


if __name__ == "__main__":
    # Run the experiment CLI only when executed as a script, not when imported.
    main()
