import argparse
import json
import random
from pathlib import Path
from typing import Any

try:
    from proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, theorem_split
    from run_search import candidate_rankings, parse_strategies
except ImportError:
    from scripts.proofstate_common import load_jsonl, normalize_entry, prepare_representation_rows, theorem_split
    from scripts.run_search import candidate_rankings, parse_strategies


def theorem_rows_before(rows: list[dict[str, Any]], query: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the rows of the query's theorem that come before it, in step order.

    A row is kept when it has the same ``theorem`` and ``file`` as *query* and an
    integer ``step_index`` strictly smaller than the query's. The result is sorted
    by that integer ``step_index``. Rows with equal indices keep their input
    order, because the sort is stable.
    """
    target_theorem = query["theorem"]
    target_file = query["file"]
    cutoff = int(query["step_index"])

    def _precedes_query(candidate: dict[str, Any]) -> bool:
        # Short-circuit in the same order as the checks are listed: a row from
        # another theorem is never asked for its file or step index.
        return (
            candidate["theorem"] == target_theorem
            and candidate["file"] == target_file
            and int(candidate["step_index"]) < cutoff
        )

    earlier = list(filter(_precedes_query, rows))
    earlier.sort(key=lambda candidate: int(candidate["step_index"]))
    return earlier


def write_jsonl(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    Parent directories are created if missing, and an existing file is
    overwritten. Keys are sorted and non-ASCII text is written verbatim
    (``ensure_ascii=False``).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n" for row in rows
        )


def run_with_leandojo(
    repo_path: str,
    commit: str,
    all_rows: list[dict[str, Any]],
    query: dict[str, Any],
    tactic: str,
    timeout: int,
    build_deps: bool,
) -> dict[str, Any]:
    """Run one candidate *tactic* at the query's proof state through LeanDojo.

    The query's proof state is rebuilt in three steps:

    1. Open a ``Dojo`` on ``query["theorem"]`` in ``query["file"]``.
    2. Replay, in ``step_index`` order, the ``next_tactic`` of every earlier row
       of the same theorem/file found in *all_rows*.
    3. Run *tactic* on the resulting state.

    Returns a dict with:
      * ``status``: one of
        - ``"executed"`` when the candidate was run;
        - ``"replay_failed"`` when an earlier step could not be replayed;
        - ``"error"`` when LeanDojo could not be imported or raised an exception.
      * ``accepted``: True iff Lean accepted the candidate, i.e. it produced a new
        ``TacticState`` or finished the proof (``ProofFinished``).
      * ``error``: the Lean result or exception text, or ``""`` when accepted.
      * ``failed_step_index``: only present for ``"replay_failed"``.

    Any ``Exception`` raised while setting up or running the Dojo is turned into
    an ``"error"`` record rather than propagated. This keeps one failing
    candidate from aborting a whole sweep.
    """
    try:
        from lean_dojo import Dojo, LeanGitRepo, ProofFinished, TacticState, Theorem
    except ImportError as exc:
        return {"status": "error", "accepted": False, "error": f"LeanDojo import failed: {exc}"}

    try:
        repo = LeanGitRepo(repo_path, commit)
        theorem = Theorem(repo, query["file"], query["theorem"])
        prior_rows = theorem_rows_before(all_rows, query)
        with Dojo(theorem, timeout=timeout, build_deps=build_deps) as (dojo, state):
            current_state = state
            for prior in prior_rows:
                result = dojo.run_tac(current_state, prior["next_tactic"])
                # run_tac returns the new TacticState itself on success. None of
                # LeanDojo's result types has a `.state` attribute. Any other
                # result (e.g. a LeanError, or the proof finishing before the
                # query step) means the query's state cannot be reached.
                if not isinstance(result, TacticState):
                    return {
                        "status": "replay_failed",
                        "accepted": False,
                        "error": str(result),
                        "failed_step_index": prior["step_index"],
                    }
                current_state = result
            result = dojo.run_tac(current_state, tactic)
            # Closing the goal entirely is also a successful tactic application.
            accepted = isinstance(result, (TacticState, ProofFinished))
            return {"status": "executed", "accepted": accepted, "error": "" if accepted else str(result)}
    except Exception as exc:
        # Include the exception type: some exceptions (e.g. timeouts) have an
        # empty str() and would otherwise leave no trace in the cache.
        return {"status": "error", "accepted": False, "error": f"{type(exc).__name__}: {exc}"}


def build_execution_records(
    rows: list[dict[str, Any]],
    strategies: list[str],
    representation: str,
    test_ratio: float,
    seed: int,
    max_k: int,
    sample_size: int,
    family_weight: float,
    family_model: str,
    execute: bool,
    repo_path: str,
    commit: str,
    timeout: int,
    build_deps: bool,
) -> list[dict[str, Any]]:
    """Build one record per (strategy, test query, top-``max_k`` candidate).

    Steps:

    1. Split *rows* with ``theorem_split``.
    2. Shuffle the test rows with ``random.Random(seed)`` and keep the first
       *sample_size*; 0 or a negative value keeps all of them.
    3. Rank candidates per strategy with ``candidate_rankings``.
    4. Emit a record for each of the first *max_k* candidates.

    When *execute* is true, each candidate is run through ``run_with_leandojo``
    and its result fields are merged into the record. Otherwise the record is
    marked ``"queued"`` with ``accepted=None``.
    """
    train_rows, test_rows, split_meta = theorem_split(rows, test_ratio=test_ratio, seed=seed)
    rng = random.Random(seed)
    sampled_test = list(test_rows)
    rng.shuffle(sampled_test)
    # A negative slice bound would silently drop rows from the end, so only a
    # positive sample_size truncates.
    if sample_size and sample_size > 0:
        sampled_test = sampled_test[:sample_size]

    # Each LeanDojo run opens a fresh Dojo and replays the proof prefix, so it is
    # expensive. Strategies frequently share top candidates, and a single ranking
    # may repeat a tactic string. Execute each (query, tactic) pair once and
    # reuse the result, which also keeps outcomes identical across strategies.
    execution_cache: dict[tuple[Any, Any, Any, str], dict[str, Any]] = {}

    records: list[dict[str, Any]] = []
    for strategy in strategies:
        # NOTE(review): this call does not depend on `strategy`. It could be
        # hoisted out of the loop if prepare_representation_rows and
        # candidate_rankings never mutate the rows they receive. Confirm first.
        prepared_train, prepared_test = prepare_representation_rows(train_rows, sampled_test, representation)
        ranked = candidate_rankings(
            prepared_train,
            prepared_test,
            representation,
            strategy,
            family_weight=family_weight,
            family_model=family_model,
        )
        for query, candidates in zip(prepared_test, ranked):
            query_id = f"{query.get('file', '')}:{query.get('theorem', '')}:{query.get('step_index', 0)}"
            for rank, candidate in enumerate(candidates[:max_k], start=1):
                tactic = candidate["next_tactic"]
                record = {
                    "query_id": query_id,
                    "strategy": strategy,
                    "rank": rank,
                    "k": max_k,
                    "representation": representation,
                    "seed": seed,
                    "split_strategy": split_meta["strategy"],
                    "query_file": query["file"],
                    "query_theorem": query["theorem"],
                    "query_step_index": query["step_index"],
                    "gold_family": query["tactic_family"],
                    "gold_tactic": query["next_tactic"],
                    "candidate_family": candidate["tactic_family"],
                    "candidate_tactic": tactic,
                }
                if execute:
                    # Key on the raw query fields rather than query_id, which
                    # could collide if a path or theorem name contains ':'.
                    cache_key = (query["file"], query["theorem"], query["step_index"], tactic)
                    if cache_key not in execution_cache:
                        execution_cache[cache_key] = run_with_leandojo(
                            repo_path,
                            commit,
                            rows,
                            query,
                            tactic,
                            timeout,
                            build_deps,
                        )
                    record.update(execution_cache[cache_key])
                else:
                    record.update({"status": "queued", "accepted": None, "error": ""})
                records.append(record)
    return records


def summarize(records: list[dict[str, Any]], max_k: int) -> list[dict[str, Any]]:
    """Compute per-strategy Lean acceptance rates at several top-k cutoffs.

    Records are grouped by ``strategy`` and then by ``query_id``. For each
    cutoff k in {1, 3, 5, max_k} that does not exceed *max_k*:

    * a query counts as executable when any of its records with
      ``rank <= k`` has a non-``None`` ``accepted``;
    * it counts as a hit when any of those records has ``accepted is True``.

    ``lean_accept_at_{k}`` is hits / executable, or ``""`` when no query was
    executable (e.g. every record is still queued). Cutoffs above *max_k* are
    omitted: only *max_k* candidates per query exist, so such a column would just
    repeat the ``max_k`` value under a misleading label.

    Returns one row per strategy, in sorted strategy order.
    """
    by_strategy: dict[str, dict[str, list[dict[str, Any]]]] = {}
    for record in records:
        per_query = by_strategy.setdefault(record["strategy"], {})
        per_query.setdefault(record["query_id"], []).append(record)

    # Deduplicate (max_k may equal 1, 3 or 5) and keep columns in ascending order.
    cutoffs = sorted({k for k in (1, 3, 5, max_k) if k <= max_k})

    output: list[dict[str, Any]] = []
    for strategy in sorted(by_strategy):
        query_groups = list(by_strategy[strategy].values())
        row: dict[str, Any] = {"strategy": strategy, "queries": len(query_groups)}
        for k in cutoffs:
            hits = 0
            executable = 0
            for items in query_groups:
                top = [item for item in items if int(item["rank"]) <= k]
                if any(item.get("accepted") is not None for item in top):
                    executable += 1
                if any(item.get("accepted") is True for item in top):
                    hits += 1
            row[f"lean_accept_at_{k}"] = hits / executable if executable else ""
        output.append(row)
    return output


def main() -> None:
    """Command-line entry point: build, and optionally execute, the top-k queue.

    Writes one JSONL record per (strategy, query, candidate) to
    ``--cache-output`` and per-strategy acceptance rates to ``--summary-output``.
    """
    parser = argparse.ArgumentParser(description="Create or execute LeanDojo top-k tactic execution queues.")
    parser.add_argument("--data", type=Path, required=True, help="JSONL file of proof-state rows.")
    parser.add_argument(
        "--strategy",
        default="unguided,family_guided,family_soft,family_top_m,family_rrf",
        help="Comma-separated list of ranking strategies to compare.",
    )
    parser.add_argument(
        "--representation",
        default="state_only",
        help="Proof-state representation passed to prepare_representation_rows.",
    )
    parser.add_argument("--test-ratio", type=float, default=0.3, help="Test ratio passed to theorem_split.")
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for the split and for shuffling the sampled test queries.",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=500,
        help="Number of shuffled test queries to evaluate; 0 uses every test query.",
    )
    parser.add_argument("--max-k", type=int, default=5, help="Number of top-ranked candidates recorded per query.")
    parser.add_argument("--family-weight", type=float, default=0.25, help="family_weight passed to candidate_rankings.")
    parser.add_argument(
        "--family-model",
        default="logistic_regression",
        help="family_model passed to candidate_rankings.",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Run each candidate through LeanDojo instead of only queueing it.",
    )
    parser.add_argument("--repo", default=".", help="Repository passed to LeanDojo's LeanGitRepo.")
    parser.add_argument("--commit", default="HEAD", help="Commit passed to LeanDojo's LeanGitRepo.")
    parser.add_argument("--timeout", type=int, default=600, help="Timeout passed to LeanDojo's Dojo.")
    parser.add_argument(
        "--no-build-deps",
        action="store_true",
        # The flag is inverted below: it passes build_deps=False to Dojo.
        help="Pass build_deps=False to LeanDojo's Dojo (dependencies are built by default).",
    )
    parser.add_argument("--cache-output", type=Path, default=Path("results/execution/lean_execution_cache.jsonl"))
    parser.add_argument("--summary-output", type=Path, default=Path("results/tables/lean_execution_summary.jsonl"))
    args = parser.parse_args()
    # Reject values that would silently produce an empty or truncated queue.
    if args.max_k < 1:
        parser.error("--max-k must be at least 1")
    if args.sample_size < 0:
        parser.error("--sample-size must be non-negative (0 uses every test query)")

    rows = [normalize_entry(row, include_optional=True) for row in load_jsonl(args.data)]
    records = build_execution_records(
        rows,
        parse_strategies(args.strategy),
        args.representation,
        args.test_ratio,
        args.seed,
        args.max_k,
        args.sample_size,
        args.family_weight,
        args.family_model,
        args.execute,
        args.repo,
        args.commit,
        args.timeout,
        not args.no_build_deps,
    )
    write_jsonl(records, args.cache_output)
    write_jsonl(summarize(records, args.max_k), args.summary_output)
    action = "executed" if args.execute else "queued"
    print(f"{action} {len(records)} candidate executions in {args.cache_output}")
    print(f"Wrote summary to: {args.summary_output}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
