"""Retroactively classify failure types across completed runs.

Walks over Autocomp run directories (`output/.../eval-results-iter-*/`),
iterative-baseline directories (`output/.../chain_*/turn_*/`), and
best-of-N directories (`output/.../eval/`), reads the stored eval result
JSONs, inspects the error/stdout text, and reports per-run failure-type
breakdowns.

Also produces a combined table across runs.

Usage:
    python -m autocomp.baselines.classify_failures \
        output/baselines/iterative \
        output/baselines/best_of_n \
        output/jaxbench-sweep

Flags:
    --format text|json    # output format (default text)
    --per-file            # also write per-run classifier JSON next to summary.json
"""
from __future__ import annotations

import argparse
import json
import pathlib
import re
from collections import Counter, defaultdict


# ---------------------------------------------------------------------------
# Classifier
# ---------------------------------------------------------------------------

def classify_error(stdout: str, error: str = "") -> str:
    """Classify the captured output of a failed eval into a failure type.

    The two texts are concatenated and matched against a fixed, ordered list
    of patterns; the first matching rule wins, so more specific categories
    are tested before more generic ones.

    Args:
        stdout: Captured standard output of the eval (``None`` is treated as
            empty).
        error: Captured error text of the eval (``None`` is treated as
            empty).

    Returns:
        One of:
            "no_output"           — both texts are empty / whitespace only
            "no_code_extracted"   — LLM produced no code to evaluate
            "missing_workload"    — file didn't define workload()
            "correctness_error"   — code ran but output mismatched reference
            "timeout"             — a timeout was reported
            "oom"                 — out-of-memory
            "pallas_unsupported"  — unimplemented/unsupported in Pallas lowering
            "pallas_api_error"    — runtime error: wrong Pallas API usage
            "compile_error"       — JIT/compile-time error (TracerError, ShapeError, …)
            "import_error"        — syntax/import failure before workload ran
            "runtime_error"       — anything else

        This function never returns "success"; it only looks at text and has
        no access to the latency/correctness fields of a result.
    """
    text = (stdout or "") + "\n" + (error or "")
    if not text.strip():
        return "no_output"

    lower = text.lower()

    # --- structural markers first ---
    # The bracketed form "[no code extracted]" contains this phrase too.
    if "no code extracted" in lower:
        return "no_code_extracted"
    if "does not define a workload" in lower:
        return "missing_workload"
    if "correctness check failed" in lower:
        return "correctness_error"

    # --- timeouts / resource ---
    if "timed out" in lower or "timeout" in lower:
        return "timeout"
    # "oom" must not be glued to other letters, otherwise ordinary words such
    # as "room", "zoom" or "bloom" would be reported as out-of-memory.
    if (
        "resource_exhausted" in lower
        or "out of memory" in lower
        or re.search(r"(?<![a-z])oom(?![a-z])", lower)
    ):
        return "oom"

    # --- Pallas-specific patterns (before the generic fallback) ---
    # Unimplemented / unsupported primitive in Pallas TPU lowering
    if re.search(r"(?i)unimplemented primitive in pallas", text):
        return "pallas_unsupported"
    if re.search(r"(?i)not.*supported.*pallas", text):
        return "pallas_unsupported"

    # Wrong Pallas API usage (attribute errors, missing args on pallas_call, etc.)
    if re.search(r"module 'jax\.experimental\.pallas'.* has no attribute", text):
        return "pallas_api_error"
    if re.search(r"pallas_call\(\) (missing|got).*argument", text):
        return "pallas_api_error"
    if (
        "BlockSpec" in text
        and re.search(r"(?i)(block_shape|index_map|error)", text)
        and "pallas" in lower
        and ("typeerror" in lower or "valueerror" in lower
             or "attributeerror" in lower)
    ):
        return "pallas_api_error"

    # --- compile / tracing errors (matched case-insensitively) ---
    compile_signals = (
        "tracerintegerconversionerror",
        "concretizationtypeerror",
        "shapeerror",
        "mosaicerror",
        "xlaruntimeerror",
        "loweringerror",
        "typeerror: jvp",
        "could not broadcast",
        "incompatible shapes",
    )
    if any(signal in lower for signal in compile_signals):
        return "compile_error"

    # --- parse-time python errors ---
    if "SyntaxError" in text and "File \"/tmp/autocomp_jaxbench" in text:
        return "import_error"
    if "IndentationError" in text:
        return "import_error"
    if re.search(r"ImportError:|ModuleNotFoundError:", text):
        return "import_error"

    # --- everything else ---
    # Named Python exceptions and unrecognised output share the same label,
    # so no further pattern matching is needed here.
    return "runtime_error"


# ---------------------------------------------------------------------------
# Collectors
# ---------------------------------------------------------------------------

def _load_result_json(p: pathlib.Path) -> dict | None:
    """Read ``p`` and parse it as a JSON object, best-effort.

    Args:
        p: Path of the stored eval result file.

    Returns:
        The parsed JSON object, or ``None`` when the file cannot be read,
        is not valid JSON, or its top-level value is not an object (a list,
        string, number, ... cannot be queried with ``.get`` by callers).
    """
    try:
        data = json.loads(p.read_text())
    except (OSError, ValueError, RecursionError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError. RecursionError:
        # pathologically nested input.
        return None
    return data if isinstance(data, dict) else None


def _classify_result(res: dict) -> str:
    """Map one stored eval result to a failure-type label.

    Resolution order:
      1. "success" when ``correct`` is truthy and ``latency`` is not None.
      2. The stored ``failure_type`` when it is truthy and is not one of
         the placeholder strings "None" / "null".
      3. Otherwise the label derived by ``classify_error`` from the stored
         ``stdout`` and ``error`` texts (missing/None values become "").
    """
    succeeded = res.get("correct") and res.get("latency") is not None
    if succeeded:
        return "success"

    stored_label = res.get("failure_type")
    if stored_label and stored_label not in ("None", "null"):
        return stored_label

    stdout_text = res.get("stdout", "") or ""
    error_text = res.get("error", "") or ""
    return classify_error(stdout_text, error_text)


def collect_iterative_run(run_dir: pathlib.Path) -> list[str]:
    """Collect failure-type labels from an iterative run directory.

    Visits ``chain_*/turn_*/result.json`` under ``run_dir`` in sorted order
    and yields one label per result file that exists. A result file for
    which the loader returns None is counted as "no_output".
    """
    collected: list[str] = []
    chains = (entry for entry in sorted(run_dir.glob("chain_*")) if entry.is_dir())
    for chain in chains:
        for turn in sorted(chain.glob("turn_*")):
            result_path = turn / "result.json"
            if not result_path.exists():
                # Turn without a stored result: nothing to classify.
                continue
            parsed = _load_result_json(result_path)
            collected.append(
                "no_output" if parsed is None else _classify_result(parsed)
            )
    return collected


def collect_best_of_n_run(run_dir: pathlib.Path) -> list[str]:
    """Collect failure-type labels from a best-of-N run directory.

    Per-candidate results are read from ``eval/code_*_result.txt`` in sorted
    order; files whose name contains "_full" are ignored. Returns an empty
    list when ``run_dir`` has no ``eval`` directory. A result file for which
    the loader returns None is counted as "no_output".
    """
    eval_dir = run_dir / "eval"
    if not eval_dir.is_dir():
        return []

    collected: list[str] = []
    candidates = [
        path
        for path in sorted(eval_dir.glob("code_*_result.txt"))
        if "_full" not in path.name
    ]
    for path in candidates:
        parsed = _load_result_json(path)
        if parsed is None:
            collected.append("no_output")
        else:
            collected.append(_classify_result(parsed))
    return collected


def collect_autocomp_run(run_dir: pathlib.Path) -> list[str]:
    """Collect failure-type labels from an Autocomp run directory.

    Results are read from ``eval-results-iter-*/code_*_result.txt`` in sorted
    order; files whose name contains "_full" are ignored. A result file for
    which the loader returns None is counted as "no_output".
    """
    collected: list[str] = []
    for iteration_dir in sorted(run_dir.glob("eval-results-iter-*")):
        result_files = (
            path
            for path in sorted(iteration_dir.glob("code_*_result.txt"))
            if "_full" not in path.name
        )
        for path in result_files:
            parsed = _load_result_json(path)
            label = "no_output" if parsed is None else _classify_result(parsed)
            collected.append(label)
    return collected


# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------

def _looks_like_autocomp_run(p: pathlib.Path) -> bool:
    """Return True when ``p`` has at least one ``eval-results-iter-*`` entry."""
    first_match = next(p.glob("eval-results-iter-*"), None)
    return first_match is not None


def _looks_like_iterative_run(p: pathlib.Path) -> bool:
    """Return True when ``p`` has at least one ``chain_*/turn_*/result.json``."""
    for _match in p.glob("chain_*/turn_*/result.json"):
        # One hit is enough; no need to enumerate the rest.
        return True
    return False


def _looks_like_best_of_n_run(p: pathlib.Path) -> bool:
    """Return True when ``p/eval`` is a directory holding a ``code_*_result.txt``."""
    eval_dir = p / "eval"
    if not eval_dir.is_dir():
        return False
    return next(eval_dir.glob("code_*_result.txt"), None) is not None


def collect(root: pathlib.Path) -> dict[str, list[str]]:
    """Classify every run found directly under `root`, and `root` itself.

    Each immediate sub-directory of `root` (in sorted order) and then `root`
    itself is probed for the Autocomp, iterative and best-of-N layouts, in
    that priority; the first layout that matches decides which collector is
    used. Directories matching no layout are left out.

    Returns: {str(run_dir): [label, label, ...]} — empty when `root` is not
    a directory.
    """
    out: dict[str, list[str]] = {}
    if not root.is_dir():
        return out

    def _pick_collector(candidate: pathlib.Path):
        # Layout probes, highest priority first.
        if _looks_like_autocomp_run(candidate):
            return collect_autocomp_run
        if _looks_like_iterative_run(candidate):
            return collect_iterative_run
        if _looks_like_best_of_n_run(candidate):
            return collect_best_of_n_run
        return None

    children = sorted(child for child in root.iterdir() if child.is_dir())
    # `root` goes last so that a root which is itself a run is handled too.
    for candidate in [*children, root]:
        collector = _pick_collector(candidate)
        if collector is not None:
            out[str(candidate)] = collector(candidate)

    return out


def format_report(all_results: dict[str, list[str]]) -> str:
    """Pretty-print a per-run + combined summary table.

    Layout: one row per run (sorted by run key, shown as the final path
    component truncated to 60 characters), followed by a TOTAL row. Columns
    are the run name, the number of evals ``n``, the number of successes
    ``ok``, and one column per failure type ordered by overall frequency.

    Each failure-type column is as wide as its label (at least 10
    characters), so long labels such as "correctness_error" stay aligned
    with their counts. "success" is reported only in the ``ok`` column, not
    repeated among the failure types.

    Args:
        all_results: Mapping of run directory to its list of labels.

    Returns:
        The table as a single newline-joined string.
    """
    combined: Counter = Counter()
    for labels in all_results.values():
        combined.update(labels)

    # Most frequent failure types first; "success" already has the "ok" column.
    failure_types = [t for t, _ in combined.most_common() if t != "success"]
    widths = [max(10, len(str(t))) for t in failure_types]

    def _row(name: str, counts: Counter) -> str:
        # One data line: fixed-width prefix plus one cell per failure type.
        cells = "  ".join(
            f"{counts.get(t, 0):>{w}d}" for t, w in zip(failure_types, widths)
        )
        return (
            f"{name:<60s} "
            f"{sum(counts.values()):>5d} "
            f"{counts.get('success', 0):>5d}  "
            + cells
        )

    header = f"{'run':<60s} {'n':>5s} {'ok':>5s}  " + "  ".join(
        f"{str(t):>{w}s}" for t, w in zip(failure_types, widths)
    )
    rule = "-" * len(header)

    lines = [header, rule]
    for run_name in sorted(all_results):
        lines.append(
            _row(pathlib.Path(run_name).name[:60], Counter(all_results[run_name]))
        )
    lines.append(rule)
    lines.append(_row("TOTAL", combined))
    return "\n".join(lines)


def main():
    """Command-line entry point: classify runs and print the breakdown.

    Collects labels for every root given on the command line (later roots
    overwrite earlier entries with the same run path), optionally writes a
    ``failure_breakdown.json`` into each run directory, and prints either
    the text table or a JSON mapping of run path to label counts.
    """
    # __doc__ is None when docstrings are stripped (python -OO); fall back to
    # no description instead of crashing on None.split.
    summary = (__doc__ or "").split("\n")[0]
    parser = argparse.ArgumentParser(description=summary or None)
    parser.add_argument(
        "roots", nargs="+",
        # collect() inspects each root and its immediate children only.
        help="Run directories, or directories whose immediate "
             "sub-directories are runs.",
    )
    parser.add_argument("--format", choices=["text", "json"], default="text")
    parser.add_argument("--per-file", action="store_true",
                        help="Also write failure_breakdown.json next to each run's summary.json.")
    args = parser.parse_args()

    all_results: dict[str, list[str]] = {}
    for root in args.roots:
        all_results.update(collect(pathlib.Path(root)))

    if args.per_file:
        for run_dir, labels in all_results.items():
            out = {"n": len(labels), "breakdown": dict(Counter(labels))}
            (pathlib.Path(run_dir) / "failure_breakdown.json").write_text(
                json.dumps(out, indent=2)
            )

    if args.format == "json":
        print(json.dumps(
            {r: dict(Counter(ls)) for r, ls in all_results.items()}, indent=2
        ))
    else:
        print(format_report(all_results))


# Run the CLI only when this module is executed as a script (or via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()
