from __future__ import annotations

import argparse
import csv
import json
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

try:
    from execution_audit_common import (
        EXECUTED_ACCEPT,
        INFRASTRUCTURE_ERROR,
        PARSE_ERROR,
        RECON_ELABORATION,
        RECON_UNKNOWN_IDENTIFIER,
        TIMEOUT,
        classify_execution_error,
        is_executed_class,
    )
except ImportError:
    from scripts.execution_audit_common import (
        EXECUTED_ACCEPT,
        INFRASTRUCTURE_ERROR,
        PARSE_ERROR,
        RECON_ELABORATION,
        RECON_UNKNOWN_IDENTIFIER,
        TIMEOUT,
        classify_execution_error,
        is_executed_class,
    )


# Error-class labels reported as per-class counts in the audit tables.
# Both the by-strategy rows and the overall summary emit one count per label,
# iterating this list in order. "executed_reject" is spelled as a literal
# here; the other labels are constants imported from execution_audit_common.
ERROR_CLASSES = [
    EXECUTED_ACCEPT,
    "executed_reject",
    RECON_UNKNOWN_IDENTIFIER,
    RECON_ELABORATION,
    PARSE_ERROR,
    TIMEOUT,
    INFRASTRUCTURE_ERROR,
]


# (compiled pattern, placeholder) pairs applied in order by
# scrub_environment_paths. Each pattern matches an absolute path prefix --
# POSIX ("/...") or drive-letter ("C:\..." / "C:/...") -- up to and including
# the named directory, so the machine-specific prefix collapses to a
# placeholder while the path tail stays readable.
#
# Path segments deliberately exclude separators: when a segment may itself
# contain a separator, the repeated group can split the same text in many
# ways and a non-matching deep path backtracks exponentially. Either "/" or
# "\" is accepted between segments so backslash-separated Windows paths are
# scrubbed as well as forward-slash ones.
PATH_REPLACEMENTS = [
    (re.compile(r"(?:/|[A-Za-z]:[/\\])(?:[^\s\"'/\\]*[/\\])*ai-for-theorem-proving"), "<project-root>"),
    (re.compile(r"(?:/|[A-Za-z]:[/\\])(?:[^\s\"'/\\]*[/\\])*\.elan"), "<elan-home>"),
]


def scrub_environment_paths(text: str) -> str:
    """Replace machine-specific path prefixes in *text* with placeholders.

    Every pattern in PATH_REPLACEMENTS is substituted in list order; text
    that matches none of them is returned unchanged.

    Args:
        text: Arbitrary text (for example an error message) that may embed
            absolute paths.

    Returns:
        The text with each matched prefix replaced by its placeholder.
    """
    for pattern, replacement in PATH_REPLACEMENTS:
        text = pattern.sub(replacement, text)
    return text


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file and return its decoded records in file order.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads``. The file is read as UTF-8.
    """
    decoded: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            decoded.append(json.loads(raw_line))
    return decoded


def write_jsonl(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created first. Keys are emitted in sorted
    order and non-ASCII characters are written as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(entry, ensure_ascii=False, sort_keys=True) + "\n"
            for entry in rows
        )


def write_csv(rows: list[dict[str, Any]], path: Path, fieldnames: list[str] | None = None) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    if fieldnames is None:
        fieldnames = sorted({key for row in rows for key in row})
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)


def classify_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return annotated shallow copies of *records*; inputs are not mutated.

    For each record, ``classify_execution_error`` is called with the record's
    ``status`` and ``error`` (both coerced with ``str``, defaulting to ``""``)
    and its raw ``accepted`` value. The copy then receives:

    * ``error`` -- replaced by its path-scrubbed string form, only when the
      key is present on the record;
    * ``error_class`` -- the label returned by the classifier;
    * ``audit_status`` -- ``"executed"`` when ``is_executed_class`` accepts
      the label, otherwise ``"non_executable"``.
    """
    annotated = []
    for original in records:
        status_text = str(original.get("status", ""))
        error_text = str(original.get("error", ""))
        label = classify_execution_error(status_text, original.get("accepted"), error_text)

        enriched = dict(original)
        if "error" in enriched:
            # NOTE(review): str() turns a null error into the text "None";
            # confirm that is acceptable for downstream consumers.
            enriched["error"] = scrub_environment_paths(str(enriched["error"]))
        enriched["error_class"] = label
        if is_executed_class(label):
            enriched["audit_status"] = "executed"
        else:
            enriched["audit_status"] = "non_executable"
        annotated.append(enriched)
    return annotated


def hit_at(records: list[dict[str, Any]], k: int) -> bool:
    """Return True if any record ranked within the top *k* was accepted.

    A record counts only when ``int(record["rank"]) <= k`` and its
    ``accepted`` value is exactly ``True`` (truthy non-bool values do not count).
    """
    for candidate in records:
        if int(candidate["rank"]) > k:
            continue
        if candidate.get("accepted") is True:
            return True
    return False


def eligible_at(records: list[dict[str, Any]], k: int) -> bool:
    """Return True when the top-*k* records exist and all count as executed.

    The top-*k* records are those with ``int(record["rank"]) <= k``. The
    result is False when there are none, or when ``is_executed_class``
    rejects the ``error_class`` (coerced with ``str``) of any of them.
    """
    within_k = [candidate for candidate in records if int(candidate["rank"]) <= k]
    if not within_k:
        return False
    for candidate in within_k:
        if not is_executed_class(str(candidate["error_class"])):
            return False
    return True


def summarize_by_strategy(records: list[dict[str, Any]], ks: list[int]) -> list[dict[str, Any]]:
    """Build one summary row per strategy, ordered by sorted strategy name.

    Records are grouped by ``strategy`` and, within a strategy, by
    ``query_id``. Each row holds candidate/query counts, execution coverage,
    reconstruction-failure totals, one count per label in ERROR_CLASSES and,
    for every ``k`` in *ks*:

    * ``accept_at_{k}_all`` -- share of all queries with a hit in the top k;
    * ``executable_queries_at_{k}`` -- number of queries eligible at k;
    * ``accept_at_{k}_executable`` -- hit share among eligible queries, or
      ``""`` when no query is eligible.
    """
    per_strategy: dict[str, list[dict[str, Any]]] = {}
    for record in records:
        per_strategy.setdefault(record["strategy"], []).append(record)

    summary_rows = []
    for strategy in sorted(per_strategy):
        members = per_strategy[strategy]

        by_query: dict[str, list[dict[str, Any]]] = {}
        for member in members:
            by_query.setdefault(member["query_id"], []).append(member)
        query_groups = list(by_query.values())
        n_queries = len(by_query)

        tally = Counter(member["error_class"] for member in members)
        n_executed = sum(1 for member in members if is_executed_class(member["error_class"]))
        n_candidates = len(members)

        # Reconstruction failures: unknown identifiers, elaboration failures
        # and parse errors taken together.
        recon_failures = tally[RECON_UNKNOWN_IDENTIFIER] + tally[RECON_ELABORATION] + tally[PARSE_ERROR]
        covered_at_5 = sum(1 for group in query_groups if eligible_at(group, 5))

        row: dict[str, Any] = {
            "strategy": strategy,
            "queries": n_queries,
            "candidates": n_candidates,
            "executed_candidates": n_executed,
            "candidate_execution_coverage": n_executed / n_candidates if n_candidates else 0.0,
            "query_coverage_at_5": covered_at_5 / n_queries,
            "reconstruction_failures": recon_failures,
            "reconstruction_failure_rate": recon_failures / n_candidates if n_candidates else 0.0,
            "timeouts": tally[TIMEOUT],
            "infrastructure_errors": tally[INFRASTRUCTURE_ERROR],
        }
        row.update((label, tally[label]) for label in ERROR_CLASSES)

        for k in ks:
            hit_flags = [hit_at(group, k) for group in query_groups]
            eligible_flags = [eligible_at(group, k) for group in query_groups]
            total_hits = sum(1 for hit in hit_flags if hit)
            n_eligible = sum(1 for ok in eligible_flags if ok)
            eligible_hits = sum(1 for hit, ok in zip(hit_flags, eligible_flags) if ok and hit)

            row[f"accept_at_{k}_all"] = total_hits / n_queries if by_query else 0.0
            row[f"executable_queries_at_{k}"] = n_eligible
            row[f"accept_at_{k}_executable"] = eligible_hits / n_eligible if n_eligible else ""
        summary_rows.append(row)
    return summary_rows


def summarize_overall(records: list[dict[str, Any]], by_strategy: list[dict[str, Any]], args: argparse.Namespace) -> list[dict[str, Any]]:
    """Build the overall summary as a list of ``{"metric", "value"}`` rows.

    The rows are, in order: run metadata copied from *args*, the query and
    candidate totals with the execution coverage, one count per label in
    ERROR_CLASSES, and ``best_accept_at_5_all`` -- the ``strategy`` of the
    *by_strategy* row with the highest ``accept_at_5_all`` (``""`` when
    *by_strategy* is empty).
    """
    tally = Counter(record["error_class"] for record in records)
    distinct_queries = {record["query_id"] for record in records}
    n_executed = sum(1 for record in records if is_executed_class(record["error_class"]))

    pairs: list[tuple[str, Any]] = [
        ("input_path", str(args.input)),
        ("dataset_path", str(args.dataset_path)),
        ("dataset_rows", args.dataset_rows),
        ("seed", args.seed),
        ("sample_size", args.sample_size),
        ("max_k", args.max_k),
        ("lean_version", args.lean_version),
        ("runtime", args.runtime),
        ("job_id", args.job_id),
        ("walltime", args.walltime),
        ("n_queries", len(distinct_queries)),
        ("n_candidates", len(records)),
        ("executed_candidates", n_executed),
        ("candidate_execution_coverage", n_executed / len(records) if records else 0.0),
    ]
    pairs.extend((label, tally[label]) for label in ERROR_CLASSES)

    if by_strategy:
        best_strategy = max(by_strategy, key=lambda row: row["accept_at_5_all"])["strategy"]
    else:
        best_strategy = ""
    pairs.append(("best_accept_at_5_all", best_strategy))

    return [{"metric": metric, "value": value} for metric, value in pairs]


def error_examples(records: list[dict[str, Any]], max_per_class: int) -> list[dict[str, Any]]:
    """Collect up to *max_per_class* example rows for each non-executed class.

    Records are scanned in order. Those whose ``error_class`` is accepted by
    ``is_executed_class`` are skipped, as are records of a class that has
    already reached its quota. In each example the error text has newlines
    replaced by spaces and is cut to 800 characters.
    """
    taken: Counter[str] = Counter()
    examples = []
    for record in records:
        label = record["error_class"]
        if is_executed_class(label) or taken[label] >= max_per_class:
            continue
        taken[label] += 1

        single_line = str(record.get("error", "")).replace("\n", " ")
        example = {
            "error_class": label,
            "strategy": record.get("strategy", ""),
            "query_id": record.get("query_id", ""),
            "rank": record.get("rank", ""),
            "candidate_tactic": record.get("candidate_tactic", ""),
            "error": single_line[:800],
        }
        examples.append(example)
    return examples


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every option and its default."""
    parser = argparse.ArgumentParser(description="Audit reconstructed-state Lean execution cache.")

    # Options parsed as filesystem paths.
    path_options = [
        ("--input", "results/execution/state_reconstruction_direct_sample500_k5.jsonl"),
        ("--classified-output", "results/execution/state_reconstruction_direct_sample500_k5_classified.jsonl"),
        ("--summary-output", "results/tables/execution_audit_summary.csv"),
        ("--by-strategy-output", "results/tables/execution_audit_by_strategy.csv"),
        ("--examples-output", "results/tables/execution_error_examples.csv"),
    ]
    for flag, default in path_options:
        parser.add_argument(flag, type=Path, default=Path(default))

    # Plain string options that are copied into the overall summary.
    dataset_options = [
        ("--dataset-path", "data/mathlib_subset_s4_submission_main_steps_checked.jsonl"),
        ("--dataset-rows", "3723"),
        ("--seed", "42"),
        ("--sample-size", "500"),
    ]
    for flag, default in dataset_options:
        parser.add_argument(flag, default=default)

    parser.add_argument("--max-k", type=int, default=5)

    run_options = [
        ("--lean-version", "Lean 4.28.0"),
        ("--runtime", "Lean 4.28 direct checking with cached build artifacts"),
        ("--job-id", "redacted"),
        ("--walltime", "redacted"),
    ]
    for flag, default in run_options:
        parser.add_argument(flag, default=default)

    parser.add_argument("--max-examples-per-class", type=int, default=5)
    return parser


def main() -> None:
    """Classify the execution cache and write the audit outputs.

    Reads the input JSONL, classifies every record, then writes the
    classified JSONL, the overall summary CSV, the by-strategy CSV and the
    error-example CSV, and prints where the classified cache and the
    strategy audit were written.
    """
    args = _build_parser().parse_args()

    classified = classify_records(load_jsonl(args.input))
    strategy_rows = summarize_by_strategy(classified, [1, 3, 5])
    write_jsonl(classified, args.classified_output)

    overall_rows = summarize_overall(classified, strategy_rows, args)
    write_csv(overall_rows, args.summary_output, ["metric", "value"])

    write_csv(strategy_rows, args.by_strategy_output)

    example_rows = error_examples(classified, args.max_examples_per_class)
    example_columns = ["error_class", "strategy", "query_id", "rank", "candidate_tactic", "error"]
    write_csv(example_rows, args.examples_output, example_columns)

    print(f"Wrote classified cache to: {args.classified_output}")
    print(f"Wrote strategy audit to: {args.by_strategy_output}")

# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
