import argparse
import csv
import glob
import json
import random
from pathlib import Path
from statistics import mean, stdev
from typing import Any


def load_runs(pattern: str) -> list[dict[str, Any]]:
    """Load every JSON run file matching ``pattern`` that has per-query results.

    Files are visited in sorted path order. A file whose top-level object has
    no ``"per_query"`` key is skipped. Each kept object is tagged with the
    path it came from under ``"_path"``.
    """
    collected: list[dict[str, Any]] = []
    for file_name in sorted(glob.glob(pattern)):
        with open(file_name, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if "per_query" in payload:
            payload["_path"] = file_name
            collected.append(payload)
    return collected


def seed_of(run: dict[str, Any]) -> int:
    """Return the run's seed as an ``int``.

    Prefers ``run["split"]["seed"]``, then a top-level ``run["seed"]``, and
    returns -1 when neither is present.
    """
    default_seed = run.get("seed", -1)
    split_cfg = run.get("split", {})
    return int(split_cfg.get("seed", default_seed))


def aggregate_metrics(runs: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Summarise runs per strategy: run count, family weights, and metric mean/std.

    Runs are grouped by ``run["strategy"]`` and rows are emitted in sorted
    strategy order. For every metric name found in any run's ``metrics`` dict,
    a row gets ``<metric>_mean`` and ``<metric>_std`` (sample standard
    deviation, or 0.0 for a single value). These are computed over the runs of
    that strategy that report the metric. A metric that none of the strategy's
    runs report is omitted from its row.

    ``n_seeds`` is the number of runs in the group. NOTE(review): this equals
    the seed count only if each strategy has exactly one run per seed.
    ``family_weights`` is the sorted, comma-joined set of distinct
    stringified ``family_weight`` values, with missing weights left out.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for run in runs:
        key = run["strategy"]
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(run)

    every_metric = sorted({name for run in runs for name in run.get("metrics", {})})

    summary: list[dict[str, Any]] = []
    for strategy in sorted(grouped):
        members = grouped[strategy]
        weight_labels = sorted({str(member.get("family_weight", "")) for member in members})
        row: dict[str, Any] = {
            "strategy": strategy,
            "n_seeds": len(members),
            "family_weights": ",".join(label for label in weight_labels if label),
        }
        for name in every_metric:
            observed = [
                float(member["metrics"][name])
                for member in members
                if name in member.get("metrics", {})
            ]
            if observed:
                row[f"{name}_mean"] = mean(observed)
                # stdev needs at least two points; a lone value has no spread.
                row[f"{name}_std"] = 0.0 if len(observed) == 1 else stdev(observed)
        summary.append(row)
    return summary


def hit_map(run: dict[str, Any], metric: str) -> dict[str, bool]:
    """Map ``"<seed>:<query_id>"`` keys to whether ``metric`` was truthy for that query.

    The seed prefix (from ``seed_of``) keeps the same query id under different
    seeds distinct when maps from several runs are merged. A per-query record
    without ``metric`` maps to ``False``, so it is counted as a miss.
    NOTE(review): confirm that a missing metric really means "miss" and not
    "not evaluated".
    """
    run_seed = seed_of(run)
    outcome: dict[str, bool] = {}
    for record in run.get("per_query", []):
        key = f"{run_seed}:{record['query_id']}"
        outcome[key] = bool(record.get(metric))
    return outcome


def bootstrap_ci(diffs: list[float], n_bootstrap: int, seed: int) -> tuple[float, float]:
    """Percentile-bootstrap 95% confidence interval for the mean of ``diffs``.

    Draws ``n_bootstrap`` resamples of ``diffs`` with replacement, each the
    same size as ``diffs``, from a private ``random.Random(seed)``, so results
    are reproducible for a given seed. Returns the 2.5th and 97.5th
    percentiles of the resampled means. Each is picked by flooring the rank
    ``q * (n_bootstrap - 1)``, with no interpolation.

    Elements are resampled independently. NOTE(review): if elements are
    correlated (e.g. the same query under several seeds), this ignores the
    clustering and may understate the interval width.

    Args:
        diffs: Per-pair differences.
        n_bootstrap: Number of resamples; must be >= 1.
        seed: Seed for the private RNG.

    Returns:
        ``(low, high)``, or ``(0.0, 0.0)`` when ``diffs`` is empty.

    Raises:
        ValueError: If ``n_bootstrap`` < 1. Without resamples there is no
            percentile to read, and the indexing below would fail opaquely.
    """
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")
    if not diffs:
        return 0.0, 0.0

    rng = random.Random(seed)
    n = len(diffs)
    # Keep the exact draw order (n randrange calls per resample) so results
    # match earlier runs with the same seed.
    means = sorted(
        sum(diffs[rng.randrange(n)] for _ in range(n)) / n
        for _ in range(n_bootstrap)
    )
    last = n_bootstrap - 1
    return means[int(0.025 * last)], means[int(0.975 * last)]


def mcnemar_pvalue(b: int, c: int) -> float:
    """Exact two-sided McNemar p-value from the discordant-pair counts.

    Under the null hypothesis each discordant pair is equally likely to favour
    either side, so ``b ~ Binomial(b + c, 0.5)``. The symmetric two-sided
    p-value is ``min(1, 2 * P(X <= min(b, c)))``. This matches
    ``scipy.stats.binomtest(b, b + c, 0.5).pvalue`` up to float rounding.

    It is computed with exact integer arithmetic, so the result no longer
    depends on whether scipy is installed. The earlier fallback was a
    continuity-corrected chi-square approximation that even returned p < 1
    when ``b == c``.

    Args:
        b: Pairs where only the strategy hit.
        c: Pairs where only the baseline hit.

    Returns:
        The p-value in ``[0, 1]``; 1.0 when there are no discordant pairs.

    Raises:
        ValueError: If either count is negative.
    """
    if b < 0 or c < 0:
        raise ValueError(f"counts must be non-negative, got b={b}, c={c}")
    n = b + c
    # n == 0: nothing to test. |b - c| <= 1: the split is as even as possible,
    # so each tail holds at least half the mass and the doubled tail is >= 1.
    if n == 0 or abs(b - c) <= 1:
        return 1.0

    k = min(b, c)
    term = 1  # C(n, 0)
    tail = 1  # running sum of C(n, i) for i = 0..k
    for i in range(k):
        # C(n, i + 1) = C(n, i) * (n - i) / (i + 1); the division is exact.
        term = term * (n - i) // (i + 1)
        tail += term
    # int / int true division is correctly rounded even for huge operands
    # (it underflows gracefully to 0.0 for extreme imbalances).
    return min(1.0, 2 * tail / 2**n)


def _pooled_hits(runs: list[dict[str, Any]], metric: str) -> dict[str, bool]:
    """Merge ``hit_map`` over ``runs``; a later run overwrites an earlier one on a key clash."""
    pooled: dict[str, bool] = {}
    for run in runs:
        pooled.update(hit_map(run, metric))
    return pooled


def significance_rows(
    runs: list[dict[str, Any]],
    baseline: str,
    metrics: list[str],
    n_bootstrap: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Paired comparison of every non-baseline strategy against ``baseline``.

    For each metric, the per-query hits of all runs of a strategy are pooled
    into one map keyed by ``"<seed>:<query_id>"``. Pairs are formed on the keys
    shared with the pooled baseline map. Each output row reports:

    * ``n_pairs`` and the hit rates of strategy and baseline over those pairs;
    * ``diff``: mean per-pair difference (strategy - baseline);
    * a bootstrap CI for ``diff`` (``bootstrap_ci``, seeded with ``seed``);
    * McNemar discordant counts ``b`` (strategy-only hits) and ``c``
      (baseline-only hits), plus the McNemar p-value.

    When a strategy shares no keys with the baseline, the rates, diff, CI and
    p-value are blank (``""``) rather than a misleading ``[0, 0]`` interval.

    NOTE(review): if two runs of the same strategy share a seed (e.g. different
    ``family_weight`` settings), their keys collide and the later run silently
    replaces the earlier one in the pooled map.

    Returns:
        One row per (metric, non-baseline strategy), or ``[]`` when no run
        belongs to ``baseline``.
    """
    by_strategy: dict[str, list[dict[str, Any]]] = {}
    for run in runs:
        by_strategy.setdefault(run["strategy"], []).append(run)
    baseline_runs = by_strategy.get(baseline, [])
    if not baseline_runs:
        return []

    output: list[dict[str, Any]] = []
    for metric in metrics:
        base_hits = _pooled_hits(baseline_runs, metric)
        for strategy, items in sorted(by_strategy.items()):
            if strategy == baseline:
                continue
            strat_hits = _pooled_hits(items, metric)
            common = sorted(base_hits.keys() & strat_hits.keys())
            b = sum(1 for q in common if strat_hits[q] and not base_hits[q])
            c = sum(1 for q in common if base_hits[q] and not strat_hits[q])
            row: dict[str, Any] = {
                "strategy": strategy,
                "baseline": baseline,
                "metric": metric,
                "n_pairs": len(common),
                "mcnemar_b_strategy_only": b,
                "mcnemar_c_baseline_only": c,
            }
            if common:
                n_pairs = len(common)
                diffs = [float(strat_hits[q]) - float(base_hits[q]) for q in common]
                lo, hi = bootstrap_ci(diffs, n_bootstrap, seed)
                row.update(
                    {
                        "strategy_rate": sum(float(strat_hits[q]) for q in common) / n_pairs,
                        "baseline_rate": sum(float(base_hits[q]) for q in common) / n_pairs,
                        "diff": mean(diffs),
                        "bootstrap_ci_low": lo,
                        "bootstrap_ci_high": hi,
                        "mcnemar_p": mcnemar_pvalue(b, c),
                    }
                )
            else:
                # Nothing to compare: blank every statistic so no spurious
                # zero-width interval or p-value ends up in the table.
                row.update(
                    dict.fromkeys(
                        (
                            "strategy_rate",
                            "baseline_rate",
                            "diff",
                            "bootstrap_ci_low",
                            "bootstrap_ci_high",
                            "mcnemar_p",
                        ),
                        "",
                    )
                )
            output.append(row)
    return output


def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Write ``rows`` to ``path`` as CSV, creating parent directories as needed.

    The header is the sorted union of all keys across ``rows``. A row missing
    a column gets an empty cell.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns: set[str] = set()
    for row in rows:
        columns.update(row)
    header = sorted(columns)
    with path.open("w", newline="", encoding="utf-8") as out:
        writer = csv.DictWriter(out, fieldnames=header)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)


def main() -> None:
    """Command-line entry point.

    Loads every JSON run matching ``--search-glob`` (``load_runs`` skips files
    without ``per_query``). Writes a per-strategy aggregate table and a
    paired-significance table against ``--baseline``, then prints where each
    table was written.

    A warning goes to stderr when no runs match or when the baseline strategy
    is absent, because both cases otherwise produce silently empty tables.
    ``--bootstrap`` values below 1 are rejected up front.
    """
    import sys

    parser = argparse.ArgumentParser(description="Aggregate search runs and compute paired uncertainty tests.")
    parser.add_argument(
        "--search-glob",
        required=True,
        help="glob pattern for run JSON files (quote it so the shell does not expand it)",
    )
    parser.add_argument(
        "--baseline",
        default="unguided",
        help="strategy every other strategy is compared against",
    )
    parser.add_argument(
        "--metrics",
        default="exact_hit_at_1,exact_hit_at_5,family_hit_at_5",
        help="comma-separated per-query metrics to test",
    )
    parser.add_argument("--bootstrap", type=int, default=1000, help="number of bootstrap resamples (>= 1)")
    parser.add_argument("--seed", type=int, default=42, help="seed for the bootstrap RNG")
    parser.add_argument("--aggregate-output", type=Path, default=Path("results/tables/search_aggregate.csv"))
    parser.add_argument("--significance-output", type=Path, default=Path("results/tables/search_significance.csv"))
    args = parser.parse_args()
    if args.bootstrap < 1:
        parser.error("--bootstrap must be >= 1")

    runs = load_runs(args.search_glob)
    if not runs:
        print(f"warning: no run files with 'per_query' matched {args.search_glob!r}", file=sys.stderr)
    metrics = [metric.strip() for metric in args.metrics.split(",") if metric.strip()]

    aggregate = aggregate_metrics(runs)
    significance = significance_rows(runs, args.baseline, metrics, args.bootstrap, args.seed)
    if runs and not any(run.get("strategy") == args.baseline for run in runs):
        print(
            f"warning: baseline strategy {args.baseline!r} not found; significance table will be empty",
            file=sys.stderr,
        )

    write_csv(aggregate, args.aggregate_output)
    write_csv(significance, args.significance_output)
    print(f"Wrote aggregate search table to: {args.aggregate_output}")
    print(f"Wrote paired significance table to: {args.significance_output}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
