import argparse
import csv
from pathlib import Path
from typing import Any

try:
    from proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, representation_text, theorem_split
    from run_search import candidate_rankings, probability_maps_from_estimator, ranked_from_estimator, train_family_model
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, representation_text, theorem_split
    from scripts.run_search import candidate_rankings, probability_maps_from_estimator, ranked_from_estimator, train_family_model


def first_family_rank(row: dict[str, Any], candidates: list[dict[str, Any]]) -> int:
    """Return the 1-based position of the first candidate sharing the row's tactic family.

    Args:
        row: Mapping with a ``"tactic_family"`` entry to look for.
        candidates: Ranked candidate mappings, each with a ``"tactic_family"`` entry.

    Returns:
        The 1-based index of the first matching candidate, or
        ``len(candidates) + 1`` when no candidate matches.
    """
    matching_positions = (
        position
        for position, contender in enumerate(candidates, start=1)
        if contender["tactic_family"] == row["tactic_family"]
    )
    first_match = next(matching_positions, None)
    if first_match is None:
        # A miss is scored one past the end of the list.
        return len(candidates) + 1
    return first_match


def family_hit_at(row: dict[str, Any], candidates: list[dict[str, Any]], k: int) -> bool:
    """Report whether any of the first ``k`` candidates shares the row's tactic family.

    Args:
        row: Mapping with a ``"tactic_family"`` entry to look for.
        candidates: Ranked candidate mappings, each with a ``"tactic_family"`` entry.
        k: Slice bound applied as ``candidates[:k]``.

    Returns:
        ``True`` if a candidate in ``candidates[:k]`` matches, otherwise ``False``.
    """
    for contender in candidates[:k]:
        if contender["tactic_family"] == row["tactic_family"]:
            return True
    return False


def bucket_name(index: int, n: int) -> str:
    """Map a position within ``n`` items to a tercile label.

    Args:
        index: Position being classified.
        n: Total number of items.

    Returns:
        ``"low"`` when ``index < n / 3``, ``"mid"`` when ``index < 2 * n / 3``,
        and ``"high"`` otherwise.
    """
    # Cutoffs are tested in ascending order; the first one exceeding index wins.
    upper_bounds = (("low", n / 3), ("mid", 2 * n / 3))
    for label, cutoff in upper_bounds:
        if index < cutoff:
            return label
    return "high"


def analyze(
    rows: list[dict[str, Any]],
    representation: str,
    family_model: str,
    family_weight: float,
    test_ratio: float,
    seed: int,
) -> list[dict[str, Any]]:
    """Summarise how family guidance shifts gold-family rank, stratified by confidence.

    The rows are split with ``theorem_split`` and passed through
    ``prepare_representation_rows``. A family model is trained on the training
    part; for each test row its top-ranked family is the prediction and the
    probability assigned to that family is the confidence. Candidate rankings
    are produced in the "unguided", "family_guided" (hard) and "family_soft"
    modes. Test rows are then ordered by confidence, cut into three buckets by
    rank position, and divided by whether the prediction equals the row's
    ``"tactic_family"``.

    Args:
        rows: Input rows handed to ``theorem_split``.
        representation: Representation name forwarded to the helper functions.
        family_model: Model name forwarded to ``train_family_model`` and
            ``candidate_rankings``.
        family_weight: Weight forwarded to the "family_soft" ranking only.
        test_ratio: Forwarded to ``theorem_split``.
        seed: Forwarded to ``theorem_split``.

    Returns:
        One dict per non-empty (bucket, correct/wrong) group, in low/mid/high
        order with "correct" before "wrong". Each holds the group size, mean
        confidence, mean unguided rank, and the mean hard/soft change relative
        to unguided for both rank and family-hit-at-5.
    """
    train_part, test_part, _ = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    train_part, test_part = prepare_representation_rows(train_part, test_part, representation)
    test_texts = [representation_text(entry, representation) for entry in test_part]
    estimator = train_family_model(train_part, representation, family_model)
    family_orderings = ranked_from_estimator(estimator, test_texts)
    family_probs = probability_maps_from_estimator(estimator, test_texts)

    # All three ranking modes share the same data and representation arguments.
    shared_args = (train_part, test_part, representation)
    baseline = candidate_rankings(*shared_args, "unguided", family_model=family_model)
    hard_guided = candidate_rankings(*shared_args, "family_guided", family_model=family_model)
    soft_guided = candidate_rankings(
        *shared_args,
        "family_soft",
        family_weight=family_weight,
        family_model=family_model,
    )

    records = []
    for position, entry in enumerate(test_part):
        ordering = family_orderings[position]
        # An empty ordering yields "" as the prediction, which then gets confidence 0.0
        # unless the probability map happens to contain "".
        top_family = ordering[0] if ordering else ""
        record = {
            "row": entry,
            "pred_correct": top_family == entry["tactic_family"],
            "confidence": family_probs[position].get(top_family, 0.0),
        }
        record["unguided_rank"] = first_family_rank(entry, baseline[position])
        record["hard_rank"] = first_family_rank(entry, hard_guided[position])
        record["soft_rank"] = first_family_rank(entry, soft_guided[position])
        record["unguided_f5"] = family_hit_at(entry, baseline[position], 5)
        record["hard_f5"] = family_hit_at(entry, hard_guided[position], 5)
        record["soft_f5"] = family_hit_at(entry, soft_guided[position], 5)
        records.append(record)

    # sorted() is stable, so equal-confidence records keep their test-set order.
    by_confidence = sorted(records, key=lambda rec: rec["confidence"])
    total = len(records)
    grouped: dict[tuple[str, bool], list[dict[str, Any]]] = {}
    for rank, rec in enumerate(by_confidence):
        group_key = (bucket_name(rank, total), bool(rec["pred_correct"]))
        if group_key not in grouped:
            grouped[group_key] = []
        grouped[group_key].append(rec)

    summary = []
    for bucket in ("low", "mid", "high"):
        for is_correct in (True, False):
            members = grouped.get((bucket, is_correct))
            if not members:
                continue
            size = len(members)
            hard_rank_shift = [m["hard_rank"] - m["unguided_rank"] for m in members]
            soft_rank_shift = [m["soft_rank"] - m["unguided_rank"] for m in members]
            hard_hit_shift = [int(m["hard_f5"]) - int(m["unguided_f5"]) for m in members]
            soft_hit_shift = [int(m["soft_f5"]) - int(m["unguided_f5"]) for m in members]
            summary.append(
                {
                    "confidence_bucket": bucket,
                    "prediction": "correct" if is_correct else "wrong",
                    "n": size,
                    "mean_confidence": sum(m["confidence"] for m in members) / size,
                    "unguided_rank": sum(m["unguided_rank"] for m in members) / size,
                    "hard_delta_rank": sum(hard_rank_shift) / size,
                    "soft_delta_rank": sum(soft_rank_shift) / size,
                    "hard_delta_family_at_5": sum(hard_hit_shift) / size,
                    "soft_delta_family_at_5": sum(soft_hit_shift) / size,
                }
            )
    return summary


def main() -> None:
    """Parse command-line options, run ``analyze``, and write its rows to a CSV file.

    The parent directory of ``--output`` is created if missing, and the
    destination path is printed once the file has been written.
    """
    cli = argparse.ArgumentParser(description="Confidence-stratified exposure analysis for family guidance.")
    cli.add_argument("--data", type=Path, required=True)
    cli.add_argument("--representation", default="state_only")
    cli.add_argument("--family-model", default="logistic_regression")
    cli.add_argument("--family-weight", type=float, default=0.25)
    cli.add_argument("--test-ratio", type=float, default=0.3)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--output", type=Path, default=Path("results/tables/confidence_exposure.csv"))
    options = cli.parse_args()

    entries = [normalize_entry(raw, include_optional=True) for raw in load_jsonl(options.data)]
    table = analyze(
        rows=entries,
        representation=options.representation,
        family_model=options.family_model,
        family_weight=options.family_weight,
        test_ratio=options.test_ratio,
        seed=options.seed,
    )

    # Column order of the CSV; the names match the keys of each dict returned by analyze().
    columns = [
        "confidence_bucket",
        "prediction",
        "n",
        "mean_confidence",
        "unguided_rank",
        "hard_delta_rank",
        "soft_delta_rank",
        "hard_delta_family_at_5",
        "soft_delta_family_at_5",
    ]

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8", newline="") as handle:
        table_writer = csv.DictWriter(handle, fieldnames=columns)
        table_writer.writeheader()
        table_writer.writerows(table)
    print(f"Wrote confidence analysis to: {destination}")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
