"""Recompute leaderboard metrics from a system's per-topic summary JSON.

Usage:
    python -m ckm_benchmark.recompute --summary results/lite_summary.json
    python -m ckm_benchmark.recompute --summary results/lite_summary.json results/batch_summary.json

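Prints, for each summary file, the row that would appear on the leaderboard.
When two or more files are given, also runs a paired Wilcoxon signed-rank test
on per-topic hit rates for every pair of files (requires scipy).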
"""

import argparse
import json
import statistics as st
from pathlib import Path

try:
    from scipy.stats import wilcoxon

    _SCIPY_AVAILABLE = True
except ImportError:
    _SCIPY_AVAILABLE = False


def summary_row(label: str, data: list[dict]) -> dict:
    """Collapse a list of per-topic result records into one leaderboard row.

    Every field is read with ``dict.get``, so an absent field counts as zero.
    Token usage prefers ``total_generation_tokens`` and otherwise falls back to
    ``init_tokens + evolution_tokens``. Means divide by ``max(1, len(data))``,
    so an empty ``data`` produces zeros instead of raising.

    Args:
        label: Name shown for this row (the caller passes the file stem).
        data: Per-topic records.

    Returns:
        A dict with keys ``label``, ``n_topics``, ``yield_mean``,
        ``hit_rate_mean``, ``best_match_mean``, ``tokens_mean_M`` (mean tokens
        divided by 1e6), ``duration_median_min`` (median ``duration`` / 60),
        ``duration_total_h`` (summed ``duration`` / 3600), ``coverage`` (a
        ``"covered/total"`` string) and ``coverage_pct``.
    """
    n_topics = len(data)
    divisor = max(1, n_topics)

    def pluck(key, default):
        return [record.get(key, default) for record in data]

    yield_vals = pluck("yield", 0)
    hit_vals = pluck("hit_rate", 0.0)
    match_vals = pluck("best_match_score", 0.0)

    token_vals = []
    for record in data:
        # The split-token fallback is evaluated for every record, even when
        # the total is present (mirrors the dict.get default semantics).
        split_tokens = record.get("init_tokens", 0) + record.get("evolution_tokens", 0)
        token_vals.append(record.get("total_generation_tokens", split_tokens))

    seconds = pluck("duration", 0.0)
    # A topic is "covered" when its hit rate is strictly positive.
    covered = sum(hit > 0 for hit in hit_vals)

    return {
        "label": label,
        "n_topics": n_topics,
        "yield_mean": sum(yield_vals) / divisor,
        "hit_rate_mean": sum(hit_vals) / divisor,
        "best_match_mean": sum(match_vals) / divisor,
        "tokens_mean_M": sum(token_vals) / divisor / 1e6,
        "duration_median_min": st.median(seconds) / 60 if seconds else 0,
        "duration_total_h": sum(seconds) / 3600,
        "coverage": f"{covered}/{n_topics}",
        "coverage_pct": 100 * covered / divisor,
    }


def print_row(row: dict) -> None:
    """Write one leaderboard row (as built by ``summary_row``) to stdout.

    Output is a blank line, a ``=== label (n=N) ===`` header, then one
    indented ``name : value`` line per metric.
    """
    report = [
        f"\n=== {row['label']} (n={row['n_topics']}) ===",
        f"  yield/topic   : {row['yield_mean']:.2f}",
        f"  hit rate %    : {row['hit_rate_mean']:.3f}",
        f"  coverage      : {row['coverage']} = {row['coverage_pct']:.1f}%",
        f"  best match    : {row['best_match_mean']:.3f}",
        f"  tokens (M)    : {row['tokens_mean_M']:.3f}",
        f"  duration mdn  : {row['duration_median_min']:.1f} min",
        f"  total wallclk : {row['duration_total_h']:.1f} h",
    ]
    print("\n".join(report))


def maybe_paired_wilcoxon(a: list[dict], b: list[dict], label_a: str, label_b: str) -> None:
    """Print a paired Wilcoxon signed-rank test on the hit rates of two systems.

    Records are matched by their ``"slug"`` key. A record without one raises
    ``KeyError``, and if a slug repeats within one list the last record wins.
    Only slugs present in both lists are compared, and a missing ``hit_rate``
    counts as 0. Pairs with identical hit rates are dropped before testing.
    The test is skipped, with a message, when scipy is unavailable or fewer
    than 5 differing pairs remain.

    The printed ``diff`` is mean(a) - mean(b) over *all* shared slugs, so a
    positive value means ``a`` scored higher on average.
    """
    if not _SCIPY_AVAILABLE:
        print("\n(scipy not installed; skipping paired Wilcoxon test)")
        return

    index_a = {record["slug"]: record for record in a}
    index_b = {record["slug"]: record for record in b}
    shared = sorted(index_a.keys() & index_b.keys())

    pairs = [(index_a[slug].get("hit_rate", 0), index_b[slug].get("hit_rate", 0)) for slug in shared]
    differing = [(left, right) for left, right in pairs if left != right]
    if len(differing) < 5:
        print(f"\n{label_a} vs {label_b}: too few non-zero diffs for Wilcoxon")
        return

    left_vals = [left for left, _ in differing]
    right_vals = [right for _, right in differing]
    _, p = wilcoxon(left_vals, right_vals)

    mean_a = sum(left for left, _ in pairs) / len(pairs)
    mean_b = sum(right for _, right in pairs) / len(pairs)
    diff = mean_a - mean_b
    print(
        f"\nWilcoxon paired test on hit rates:"
        f"\n  {label_a} vs {label_b}: diff={diff:+.3f}pp, p={p:.4f}"
        f" (n_nonzero={len(differing)}/{len(shared)})"
    )


def main() -> None:
    """Command-line entry point: print leaderboard metrics for each summary file.

    Each path given via ``--summary`` is loaded as JSON, summarised with
    ``summary_row`` (labelled by the file's stem) and printed with
    ``print_row``. When two or more files are given, every unordered pair is
    then compared with ``maybe_paired_wilcoxon``, in the order the files were
    passed on the command line.
    """
    from itertools import combinations

    parser = argparse.ArgumentParser(description="Recompute CKM Benchmark leaderboard metrics.")
    parser.add_argument(
        "--summary",
        nargs="+",
        type=Path,
        required=True,
        help="One or more per-topic summary JSON files.",
    )
    args = parser.parse_args()

    # (label, per-topic records) for each file, kept for the pairwise tests below.
    loaded: list[tuple[str, list[dict]]] = []
    for path in args.summary:
        # JSON interchange is UTF-8 (RFC 8259). Read it explicitly as UTF-8
        # rather than the platform locale encoding, which mis-decodes or
        # fails on non-ASCII content on e.g. Windows.
        with path.open(encoding="utf-8") as fh:
            data = json.load(fh)
        print_row(summary_row(path.stem, data))
        loaded.append((path.stem, data))

    # Pairwise Wilcoxon on hit rate. combinations() yields nothing for fewer
    # than two files, so no explicit count guard is needed.
    for (label_a, data_a), (label_b, data_b) in combinations(loaded, 2):
        maybe_paired_wilcoxon(data_a, data_b, label_a, label_b)


if __name__ == "__main__":
    main()
