from __future__ import annotations

import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Any


# Preferred ordering of strategy names when sorting output rows. The sort keys
# in this file place any strategy not listed here after the listed ones
# (they fall back to position 99).
STRATEGY_ORDER = ["unguided", "family_soft", "family_guided", "family_rrf", "family_top_m"]


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file into a list of parsed records.

    The file is opened as UTF-8. Lines that are empty or contain only
    whitespace are skipped; every other line is decoded with ``json.loads``.
    """
    parsed: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            parsed.append(json.loads(raw_line))
    return parsed


def write_csv(rows: list[dict[str, Any]], path: Path, fieldnames: list[str] | None = None) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header row.

    Missing parent directories are created. When ``fieldnames`` is omitted,
    the columns are the alphabetically sorted union of every key found in
    ``rows``. Row keys that are not among the columns are ignored rather
    than raising.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    columns = fieldnames
    if columns is None:
        # Derive the header from the data: union of keys, sorted for a stable order.
        discovered: set[str] = set()
        for row in rows:
            discovered.update(row)
        columns = sorted(discovered)

    with path.open("w", encoding="utf-8", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns, extrasaction="ignore")
        csv_writer.writeheader()
        for row in rows:
            csv_writer.writerow(row)


def top_k_groups(records: list[dict[str, Any]], k: int) -> dict[tuple[str, str], list[dict[str, Any]]]:
    """Group records with rank <= ``k`` by ``(strategy, query_id)``.

    Each record's ``"rank"`` is converted with ``int`` before comparison, so
    numeric strings are accepted. Within every group the records are ordered
    by ascending rank. Groups that would be empty are not present in the
    result.
    """
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list)
    within_cutoff = (entry for entry in records if int(entry["rank"]) <= k)
    for entry in within_cutoff:
        group_key = (entry["strategy"], entry["query_id"])
        grouped[group_key].append(entry)
    for bucket in grouped.values():
        bucket.sort(key=lambda entry: int(entry["rank"]))
    return grouped


def summarize_alternatives(records: list[dict[str, Any]], k: int) -> list[dict[str, Any]]:
    """Build one summary row per strategy over the top-``k`` candidates of each query.

    For every ``(strategy, query)`` group the gold tactic is read from the
    group's best-ranked record, and the group is classified as:

    * exact: some candidate's tactic string equals the gold tactic;
    * accept: some candidate has ``accepted`` set to ``True``;
    * accepted-not-gold: some accepted candidate differs from the gold tactic;
    * accept-without-exact / exact-without-accept: one holds but not the other.

    Each row reports the number of queries, the cutoff ``k``, and both the
    rate and the raw count for each category. Rows are ordered by
    ``STRATEGY_ORDER``, with unlisted strategies last in name order.

    Note that the column names carry a fixed ``_at_5`` suffix whatever ``k``
    is; the ``"k"`` column records the cutoff actually used.
    """

    def strategy_sort_key(name: str) -> tuple[int, str]:
        position = STRATEGY_ORDER.index(name) if name in STRATEGY_ORDER else 99
        return (position, name)

    def fraction(count: int, total: int) -> float:
        return count / total if total else 0.0

    per_strategy: dict[str, list[list[dict[str, Any]]]] = defaultdict(list)
    for (strategy_name, _query_id), candidates in top_k_groups(records, k).items():
        per_strategy[strategy_name].append(candidates)

    summary_rows = []
    for strategy_name in sorted(per_strategy, key=strategy_sort_key):
        candidate_lists = per_strategy[strategy_name]
        total_queries = len(candidate_lists)
        n_exact = n_accept = n_alt_accept = n_accept_only = n_exact_only = 0

        for candidates in candidate_lists:
            gold = str(candidates[0].get("gold_tactic", ""))
            has_exact = has_accept = has_alt_accept = False
            for candidate in candidates:
                matches_gold = str(candidate.get("candidate_tactic", "")) == gold
                is_accepted = candidate.get("accepted") is True
                has_exact = has_exact or matches_gold
                has_accept = has_accept or is_accepted
                has_alt_accept = has_alt_accept or (is_accepted and not matches_gold)

            if has_exact:
                n_exact += 1
            if has_accept:
                n_accept += 1
            if has_alt_accept:
                n_alt_accept += 1
            if has_accept and not has_exact:
                n_accept_only += 1
            if has_exact and not has_accept:
                n_exact_only += 1

        row = {
            "strategy": strategy_name,
            "queries": total_queries,
            "k": k,
            "proxy_exact_at_5_on_sample": fraction(n_exact, total_queries),
            "lean_accept_at_5": fraction(n_accept, total_queries),
            "accepted_not_gold_at_5": fraction(n_alt_accept, total_queries),
            "accept_without_exact_at_5": fraction(n_accept_only, total_queries),
            "exact_without_accept_at_5": fraction(n_exact_only, total_queries),
            "proxy_exact_at_5_count": n_exact,
            "lean_accept_at_5_count": n_accept,
            "accepted_not_gold_at_5_count": n_alt_accept,
            "accept_without_exact_at_5_count": n_accept_only,
            "exact_without_accept_at_5_count": n_exact_only,
        }
        summary_rows.append(row)
    return summary_rows


def accepted_not_gold_examples(records: list[dict[str, Any]], k: int, max_examples: int) -> list[dict[str, Any]]:
    """Collect examples where an accepted top-``k`` candidate differs from the gold tactic.

    Groups of top-``k`` records per ``(strategy, query_id)`` are visited in
    ``STRATEGY_ORDER`` (unlisted strategies last), then by ``query_id``. For
    each group, the best-ranked candidate that has ``accepted`` set to
    ``True`` and whose tactic string differs from the group's gold tactic
    contributes one row; groups without such a candidate contribute nothing.

    Args:
        records: Classified candidate records.
        k: Rank cutoff; only records with rank <= ``k`` are considered.
        max_examples: Maximum number of rows to return. Zero or a negative
            value yields an empty list.

    Returns:
        At most ``max_examples`` rows with the keys ``strategy``,
        ``query_id``, ``rank``, ``gold_family``, ``candidate_family``,
        ``gold_tactic`` and ``accepted_candidate``.
    """
    groups = top_k_groups(records, k)
    ordered_groups = sorted(
        groups.items(),
        key=lambda item: (
            STRATEGY_ORDER.index(item[0][0]) if item[0][0] in STRATEGY_ORDER else 99,
            item[0][1],
        ),
    )
    rows: list[dict[str, Any]] = []
    for (strategy, query_id), items in ordered_groups:
        # Check the budget before inspecting a group so that a non-positive
        # max_examples never lets a row through.
        if len(rows) >= max_examples:
            break
        gold_tactic = str(items[0].get("gold_tactic", ""))
        for item in items:
            candidate = str(item.get("candidate_tactic", ""))
            if item.get("accepted") is True and candidate != gold_tactic:
                rows.append(
                    {
                        "strategy": strategy,
                        "query_id": query_id,
                        "rank": item.get("rank", ""),
                        "gold_family": item.get("gold_family", ""),
                        "candidate_family": item.get("candidate_family", ""),
                        "gold_tactic": gold_tactic,
                        "accepted_candidate": candidate,
                    }
                )
                # Group keys are unique, so stopping here is enough to
                # guarantee one example per (strategy, query).
                break
    return rows


def main() -> None:
    """Command-line entry point.

    Loads the classified JSONL cache, computes the per-strategy summary and
    the accepted-but-not-gold examples for the requested ``--k``, writes both
    tables as CSV files, and prints the two output paths.
    """
    parser = argparse.ArgumentParser(
        description="Analyze Lean-accepted candidates that differ from the traced gold tactic."
    )
    # (flag, type, default) for every supported option, registered in order.
    option_specs = [
        (
            "--classified-cache",
            Path,
            Path("results/execution/state_reconstruction_direct_sample500_k5_classified.jsonl"),
        ),
        ("--k", int, 5),
        ("--summary-output", Path, Path("results/tables/execution_accepted_alternatives.csv")),
        ("--examples-output", Path, Path("results/tables/execution_accepted_not_gold_examples.csv")),
        ("--max-examples", int, 8),
    ]
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    options = parser.parse_args()

    cached_records = load_jsonl(options.classified_cache)
    summary_rows = summarize_alternatives(cached_records, options.k)
    example_rows = accepted_not_gold_examples(cached_records, options.k, options.max_examples)

    # The examples table uses an explicit column order; the summary table
    # falls back to write_csv's default column selection.
    example_columns = [
        "strategy",
        "query_id",
        "rank",
        "gold_family",
        "candidate_family",
        "gold_tactic",
        "accepted_candidate",
    ]
    write_csv(summary_rows, options.summary_output)
    write_csv(example_rows, options.examples_output, example_columns)

    print(f"Wrote accepted-alternative summary to: {options.summary_output}")
    print(f"Wrote accepted-not-gold examples to: {options.examples_output}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
