from __future__ import annotations

import argparse
import csv
import json
import math
import random
from pathlib import Path
from statistics import mean
from typing import Any


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse a JSON Lines file into a list of decoded objects.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The decoded values, in file order.
    """
    parsed: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            parsed.append(json.loads(raw_line))
    return parsed


def read_csv(path: Path) -> list[dict[str, str]]:
    """Read a headered CSV file into a list of row dictionaries.

    Args:
        path: Location of the UTF-8 encoded CSV file; its first row supplies
            the column names.

    Returns:
        One dictionary per data row, in file order.
    """
    with path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        table = [entry for entry in reader]
    return table


def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Write row dictionaries to a CSV file, creating parent directories.

    The header is the alphabetically sorted union of the keys of all rows, so
    rows may have differing key sets; a key absent from a row is written as
    an empty cell.

    Args:
        rows: Row dictionaries to serialise.
        path: Destination file; it is overwritten if it already exists.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    columns: set[str] = set()
    for entry in rows:
        columns.update(entry)
    header = sorted(columns)

    with path.open("w", encoding="utf-8", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=header)
        out.writeheader()
        out.writerows(rows)


def hit_maps(records: list[dict[str, Any]], ks: list[int]) -> dict[str, dict[int, dict[str, bool]]]:
    """Build per-strategy, per-cutoff hit indicators for every query.

    Records are grouped by their ``"strategy"`` and ``"query_id"`` values. A
    group counts as a hit at cutoff ``k`` when at least one of its records has
    ``int(record["rank"]) <= k`` and an ``"accepted"`` value that is exactly
    ``True``.

    Args:
        records: Candidate records carrying ``"strategy"``, ``"query_id"`` and
            ``"rank"`` keys, and optionally ``"accepted"``.
        ks: Rank cutoffs to evaluate.

    Returns:
        A mapping ``strategy -> k -> query_id -> hit``.
    """
    by_pair: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for entry in records:
        pair = (entry["strategy"], entry["query_id"])
        if pair not in by_pair:
            by_pair[pair] = []
        by_pair[pair].append(entry)

    result: dict[str, dict[int, dict[str, bool]]] = {}
    for name in sorted({pair[0] for pair in by_pair}):
        result[name] = {cutoff: {} for cutoff in ks}

    for (name, query), attempts in by_pair.items():
        per_cutoff = result[name]
        for cutoff in ks:
            hit = False
            for attempt in attempts:
                # Rank is checked first; stop at the first accepted attempt.
                if int(attempt["rank"]) <= cutoff and attempt.get("accepted") is True:
                    hit = True
                    break
            per_cutoff[cutoff][query] = hit
    return result


def bootstrap_ci(diffs: list[float], n_bootstrap: int, seed: int) -> tuple[float, float]:
    """Return a percentile-bootstrap interval for the mean of ``diffs``.

    Each replicate draws ``len(diffs)`` values with replacement and records
    their mean. The sorted replicate means are then indexed at
    ``int(0.025 * (n_bootstrap - 1))`` and ``int(0.975 * (n_bootstrap - 1))``.

    Args:
        diffs: Paired differences to resample.
        n_bootstrap: Number of bootstrap replicates; must be at least 1 when
            ``diffs`` is non-empty.
        seed: Seed for the private ``random.Random`` generator, which makes
            the result reproducible.

    Returns:
        ``(low, high)`` bounds, or ``(0.0, 0.0)`` when ``diffs`` is empty.

    Raises:
        ValueError: If ``diffs`` is non-empty and ``n_bootstrap`` is below 1.
    """
    if not diffs:
        return 0.0, 0.0
    if n_bootstrap < 1:
        # Without replicates there is nothing to index; fail with a clear
        # message instead of an IndexError on the empty sample list.
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    rng = random.Random(seed)
    n = len(diffs)
    # Draw order is one randrange call per element, replicate by replicate.
    replicate_means = sorted(
        sum(diffs[rng.randrange(n)] for _ in range(n)) / n
        for _ in range(n_bootstrap)
    )
    last = n_bootstrap - 1
    return replicate_means[int(0.025 * last)], replicate_means[int(0.975 * last)]


def exact_mcnemar_pvalue(strategy_only: int, baseline_only: int) -> float:
    """Compute a two-sided exact p-value from the two discordant-pair counts.

    The smaller count is treated as the tail of a Binomial(n, 0.5)
    distribution with ``n = strategy_only + baseline_only``; the cumulative
    tail probability is doubled and capped at 1.0.

    Args:
        strategy_only: Pairs where only the strategy succeeded.
        baseline_only: Pairs where only the baseline succeeded.

    Returns:
        The p-value, or 1.0 when the two counts sum to zero.
    """
    discordant = strategy_only + baseline_only
    if discordant == 0:
        return 1.0

    smaller = min(strategy_only, baseline_only)
    # Accumulate binomial coefficients as exact integers before dividing.
    cumulative = 0
    for successes in range(smaller + 1):
        cumulative += math.comb(discordant, successes)
    one_sided = cumulative / (2 ** discordant)
    return min(1.0, 2.0 * one_sided)


def significance_rows(
    records: list[dict[str, Any]],
    baseline: str,
    ks: list[int],
    n_bootstrap: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Compare every non-baseline strategy against the baseline at each cutoff.

    For each strategy and each ``k``, only queries present for both the
    strategy and the baseline are used. Each output row reports the two hit
    rates, their mean paired difference, a bootstrap interval for that
    difference (seeded with ``seed + k``), and the discordant-pair counts with
    their exact McNemar p-value.

    Args:
        records: Candidate records, as consumed by ``hit_maps``.
        baseline: Name of the reference strategy.
        ks: Rank cutoffs to evaluate.
        n_bootstrap: Number of bootstrap replicates per comparison.
        seed: Base seed for the bootstrap.

    Returns:
        One row per (strategy, k) pair, ordered by strategy name then by the
        order of ``ks``. The list is empty when the baseline has no records.
        Rate and difference cells are empty strings when no queries are shared.
    """
    hits_by_strategy = hit_maps(records, ks)
    if baseline not in hits_by_strategy:
        return []

    challengers = [name for name in sorted(hits_by_strategy) if name != baseline]
    table: list[dict[str, Any]] = []
    for challenger in challengers:
        for cutoff in ks:
            reference = hits_by_strategy[baseline][cutoff]
            candidate = hits_by_strategy[challenger][cutoff]
            shared = sorted(reference.keys() & candidate.keys())
            n_pairs = len(shared)

            deltas = [float(candidate[q]) - float(reference[q]) for q in shared]
            wins = sum(1 for q in shared if candidate[q] and not reference[q])
            losses = sum(1 for q in shared if reference[q] and not candidate[q])
            ci_low, ci_high = bootstrap_ci(deltas, n_bootstrap, seed + cutoff)

            if n_pairs:
                candidate_rate: Any = sum(float(candidate[q]) for q in shared) / n_pairs
                reference_rate: Any = sum(float(reference[q]) for q in shared) / n_pairs
                mean_delta: Any = mean(deltas)
            else:
                # No shared queries: leave the point estimates blank.
                candidate_rate = reference_rate = mean_delta = ""

            row = {
                "strategy": challenger,
                "baseline": baseline,
                "metric": f"accept_at_{cutoff}",
                "n_pairs": n_pairs,
                "strategy_rate": candidate_rate,
                "baseline_rate": reference_rate,
                "diff": mean_delta,
                "bootstrap_ci_low": ci_low,
                "bootstrap_ci_high": ci_high,
                "mcnemar_b_strategy_only": wins,
                "mcnemar_c_baseline_only": losses,
                "mcnemar_p": exact_mcnemar_pvalue(wins, losses),
            }
            table.append(row)
    return table


def metric(row: dict[str, str], name: str) -> float:
    """Read a numeric column from a row, falling back to its ``_mean`` variant.

    The value stored under ``name`` is used when it is truthy; otherwise the
    value under ``f"{name}_mean"`` is tried.

    Args:
        row: Row mapping column names to cell text.
        name: Base column name to look up.

    Returns:
        The cell parsed as a float, or NaN when neither column holds a
        non-empty value.
    """
    raw = row.get(name)
    if not raw:
        raw = row.get(f"{name}_mean")
    if not raw:
        raw = ""
    if raw == "":
        return float("nan")
    return float(raw)


def proxy_execution_gap_rows(search_rows: list[dict[str, str]], execution_rows: list[dict[str, str]]) -> list[dict[str, Any]]:
    """Join search and execution tables by strategy and report their gap.

    Rows are keyed by their ``"strategy"`` value; if a table repeats a
    strategy, its last row is the one used. Only strategies present in both
    tables produce output.

    Args:
        search_rows: Rows supplying the ``exact_tactic_success_at_*`` and
            ``family_success_at_5`` columns.
        execution_rows: Rows supplying the ``accept_at_*_all`` and
            ``candidate_execution_coverage`` columns.

    Returns:
        One row per shared strategy, sorted by strategy name, including the
        difference between the execution accept@5 and the proxy exact@5.
    """
    search_lookup: dict[str, dict[str, str]] = {}
    for entry in search_rows:
        search_lookup[entry["strategy"]] = entry
    execution_lookup: dict[str, dict[str, str]] = {}
    for entry in execution_rows:
        execution_lookup[entry["strategy"]] = entry

    table: list[dict[str, Any]] = []
    for name in sorted(search_lookup.keys() & execution_lookup.keys()):
        proxy_row = search_lookup[name]
        lean_row = execution_lookup[name]
        exact_at_5 = metric(proxy_row, "exact_tactic_success_at_5")
        accept_at_5 = metric(lean_row, "accept_at_5_all")

        gap: dict[str, Any] = {"strategy": name}
        gap["proxy_exact_at_1"] = metric(proxy_row, "exact_tactic_success_at_1")
        gap["proxy_exact_at_5"] = exact_at_5
        gap["proxy_family_at_5"] = metric(proxy_row, "family_success_at_5")
        gap["lean_accept_at_1"] = metric(lean_row, "accept_at_1_all")
        gap["lean_accept_at_5"] = accept_at_5
        gap["lean_minus_proxy_exact_at_5"] = accept_at_5 - exact_at_5
        gap["candidate_execution_coverage"] = metric(lean_row, "candidate_execution_coverage")
        table.append(gap)
    return table


def main() -> None:
    """Parse command-line options and write the two result tables.

    The significance table is computed from the classified JSONL cache and
    written first. The search and execution CSV tables are then read, and the
    comparison table built from them is written. Both output paths are
    printed at the end.
    """
    cli = argparse.ArgumentParser(description="Paired uncertainty tests for reconstructed Lean execution.")
    # Options are registered in this order so the generated help is stable.
    option_specs: list[tuple[str, dict[str, Any]]] = [
        ("--classified-cache", {"type": Path, "default": Path("results/execution/state_reconstruction_direct_sample500_k5_classified.jsonl")}),
        ("--execution-table", {"type": Path, "default": Path("results/tables/execution_audit_by_strategy.csv")}),
        ("--search-table", {"type": Path, "default": Path("results/tables/s4_search_validation_aggregate.csv")}),
        ("--baseline", {"default": "unguided"}),
        ("--ks", {"default": "1,3,5"}),
        ("--bootstrap", {"type": int, "default": 5000}),
        ("--seed", {"type": int, "default": 42}),
        ("--significance-output", {"type": Path, "default": Path("results/tables/execution_accept_significance.csv")}),
        ("--gap-output", {"type": Path, "default": Path("results/tables/trace_execution_gap.csv")}),
    ]
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()

    classified = load_jsonl(options.classified_cache)
    cutoffs: list[int] = []
    for token in options.ks.split(","):
        # Skip empty pieces such as the one left by a trailing comma.
        if token.strip():
            cutoffs.append(int(token))

    significance_table = significance_rows(
        classified,
        options.baseline,
        cutoffs,
        options.bootstrap,
        options.seed,
    )
    write_csv(significance_table, options.significance_output)

    search_table = read_csv(options.search_table)
    execution_table = read_csv(options.execution_table)
    write_csv(proxy_execution_gap_rows(search_table, execution_table), options.gap_output)

    print(f"Wrote execution significance table to: {options.significance_output}")
    print(f"Wrote trace/execution comparison table to: {options.gap_output}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
