#!/usr/bin/env python3
"""
Reformat `anonymous/adversarial_bugbench_bcb` into the same *column schema* as `anonymous/bugbench`
and push to the Hugging Face Hub.

Why this exists:
- `anonymous/bugbench` has a specific raw schema used by downstream tools/scripts.
- `anonymous/adversarial_bugbench_bcb` contains mutation/solver metadata and code snippets, but
  doesn't have BugBench's raw columns.
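- This script joins against `anonymous/bugbench` by `task_id` (falling back to the `--bcb_repo`
  dataset for tasks missing there) to recover prompt/test metadata, then replaces `buggy` with
  the mutated buggy body taken from the source `response` field. `canonical_solution` is kept
  verbatim from the joined metadata.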

Usage:
  python examples/bugs/reformat_bugbench.py \
    --source_repo anonymous/adversarial_bugbench_bcb \
    --bugbench_repo anonymous/bugbench \
    --split train \
    --repo_id anonymous/adversarial_bugbench_bcb_bugbench_format

Auth:
- If you are not logged in via `huggingface-cli login`, pass `--token` or set
  `HF_TOKEN` / `HUGGINGFACEHUB_API_TOKEN`.
"""

from __future__ import annotations

import argparse
import ast
import json
import os
import textwrap
from typing import Any, Dict, List, Optional, Tuple

from datasets import Dataset, DatasetDict, Features, Value, load_dataset


# Output column names, in order. `_convert_one` emits exactly these keys for every row and
# `_bugbench_features` declares each of them as a string feature.
BUGBENCH_COLUMNS = [
    "task_id",
    "instruct_prompt",
    "buggy",
    "canonical_solution",
    "test",
    "complete_prompt",
    "code_prompt",
    "entry_point",
    "doc_struct",
    "libs",
]


def _strip_code_fences(s: str) -> str:
    t = (s or "").strip()
    if t.startswith("```"):
        # Handle ```python ...``` or ``` ...```
        t = t.strip("`")
        # After stripping backticks we may have "python\n..."
        if "\n" in t:
            first, rest = t.split("\n", 1)
            if first.strip().lower() in {"python", "py"}:
                return rest.strip().rstrip("`").strip()
        return t.strip().rstrip("`").strip()
    return t


def _try_parse_json_list(s: Any) -> Optional[List[Any]]:
    if s is None:
        return None
    if isinstance(s, list):
        return s
    if not isinstance(s, str):
        return None
    txt = s.strip()
    if not txt:
        return None
    # Most fields are JSON-encoded strings like '["..."]'
    try:
        out = json.loads(txt)
        return out if isinstance(out, list) else None
    except Exception:
        pass
    # Some could be python-literal-like; fall back to ast.literal_eval
    try:
        out = ast.literal_eval(txt)
        return out if isinstance(out, list) else None
    except Exception:
        return None


def _extract_function_body_from_code(code: str, func_name: str) -> Optional[str]:
    """
    Return the function body (as a string with 4-space indentation) for `func_name`
    from full Python source `code`. Returns None if not found or unparsable.
    """
    code = (code or "").strip("\n")
    if not code or not func_name:
        return None

    try:
        tree = ast.parse(code)
    except Exception:
        tree = None

    if tree is not None:
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name:
                if not getattr(node, "body", None):
                    return ""
                lines = code.splitlines()
                start = int(node.body[0].lineno) - 1
                end = int(getattr(node, "end_lineno", node.body[-1].end_lineno or node.body[-1].lineno))
                body_lines = lines[start:end]
                dedented = textwrap.dedent("\n".join(body_lines)).rstrip()
                # Re-indent to match BugBench convention: body-only with 4-space indent.
                return "\n".join(("    " + ln if ln.strip() else ln) for ln in dedented.splitlines()).rstrip()

    # Fallback: naive text slicing after "def func_name"
    marker = f"def {func_name}"
    idx = code.find(marker)
    if idx == -1:
        return None
    after = code[idx:]
    # Find first newline after def line
    nl = after.find("\n")
    if nl == -1:
        return None
    body_text = after[nl + 1 :].rstrip()
    dedented = textwrap.dedent(body_text).rstrip()
    return "\n".join(("    " + ln if ln.strip() else ln) for ln in dedented.splitlines()).rstrip()


def _normalize_body_like_bugbench(text: str, entry_point: str) -> str:
    """
    Coerce a code snippet into the shape BugBench uses for `buggy`/`canonical_solution`:
    only the function body, indented by 4 spaces.

    Code fences are removed first. If the snippet then turns out to be full source that
    contains `entry_point`, that function's body is extracted; otherwise the snippet is
    assumed to already be a body and only its indentation is normalized. Empty input
    yields "".
    """
    snippet = _strip_code_fences(text).strip("\n")
    if not snippet:
        return ""

    # Full source containing the function: hand back just its body.
    extracted = _extract_function_body_from_code(snippet, entry_point)
    if extracted is not None:
        return extracted

    # Body-only snippet: remove the shared margin, then indent non-blank lines by 4 spaces.
    flattened = textwrap.dedent(snippet).rstrip()
    reindented = [("    " + line) if line.strip() else line for line in flattened.splitlines()]
    return "\n".join(reindented).rstrip()


def _pick_best_solution(
    *,
    solutions: Any,
    solution_scores: Any,
    entry_point: str,
) -> str:
    """
    Choose one entry from `solutions` and return it normalized to a BugBench-style body.

    Both arguments are decoded with `_try_parse_json_list`. When the scores line up
    one-to-one with the solutions and every score converts to `float`, the highest-scoring
    solution wins; in every other case the first solution is used. Returns "" when there
    are no solutions.
    """
    candidates = _try_parse_json_list(solutions) or []
    ratings = _try_parse_json_list(solution_scores) or []

    if not candidates:
        return ""

    # Index 0 unless a complete, numeric score list says otherwise.
    winner = 0
    if ratings and len(ratings) == len(candidates):
        try:
            numeric = [float(rating) for rating in ratings]
        except Exception:
            pass
        else:
            winner = max(range(len(numeric)), key=numeric.__getitem__)

    return _normalize_body_like_bugbench(str(candidates[winner]), entry_point)


def _build_bugbench_index(bugbench_ds: Dataset) -> Dict[str, Dict[str, Any]]:
    idx: Dict[str, Dict[str, Any]] = {}
    for ex in bugbench_ds:
        tid = ex.get("task_id", None)
        if tid is None:
            continue
        idx[str(tid)] = dict(ex)
    return idx


def _convert_one(
    *,
    adv_ex: Dict[str, Any],
    bugbench_by_task_id: Dict[str, Dict[str, Any]],
    bcb_by_task_id: Dict[str, Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    orig_task_id = str(adv_ex.get("task_id", "") or "").strip()
    if not orig_task_id:
        return None

    # Prefer BugBench (already filtered/curated), fall back to upstream BigCodeBench.
    meta = bugbench_by_task_id.get(orig_task_id) or bcb_by_task_id.get(orig_task_id)
    if meta is None:
        return None

    # Always keep the original task_id (matches BugBench / BigCodeBench conventions).
    out_task_id = orig_task_id

    entry_point = str(meta.get("entry_point", "") or "").strip()
    if not entry_point:
        # BugBench expects entry_point, so drop if missing.
        return None

    buggy_body = _normalize_body_like_bugbench(str(adv_ex.get("response", "") or ""), entry_point)
    # Keep BugBench's canonical solution *verbatim* so the output remains fully compatible
    # with the original `anonymous/bugbench` format (it's already body-only with correct indentation).
    canonical_body = str(meta.get("canonical_solution", "") or "").rstrip()

    out = {
        "task_id": out_task_id,
        "instruct_prompt": str(meta.get("instruct_prompt", "") or ""),
        "buggy": str(buggy_body or ""),
        "canonical_solution": str(canonical_body or ""),
        "test": str(meta.get("test", "") or ""),
        "complete_prompt": str(meta.get("complete_prompt", "") or ""),
        "code_prompt": str(meta.get("code_prompt", "") or ""),
        "entry_point": str(entry_point),
        "doc_struct": str(meta.get("doc_struct", "") or ""),
        "libs": str(meta.get("libs", "") or ""),
    }

    # Ensure all required columns exist (and no extras).
    return {k: out.get(k, "") for k in BUGBENCH_COLUMNS}


def _bugbench_features() -> Features:
    """Build the output schema: every column in `BUGBENCH_COLUMNS` typed as a string."""
    schema = {}
    for column in BUGBENCH_COLUMNS:
        schema[column] = Value("string")
    return Features(schema)


def main() -> None:
    """
    Command-line entry point.

    Loads the BugBench, fallback (`--bcb_repo`) and source datasets, converts each source
    example with `_convert_one`, prints a summary plus a truncated sample row, and, unless
    `--dry_run` is given, pushes the result to `--repo_id` on the Hugging Face Hub.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_repo", default="anonymous/adversarial_bugbench_bcb")
    parser.add_argument("--bugbench_repo", default="anonymous/bugbench")
    parser.add_argument("--bcb_repo", default="anonymous/bigcodebench", help="Fallback metadata source")
    parser.add_argument("--bcb_split", default="v0.1.0_hf", help="Split to use from --bcb_repo")
    parser.add_argument("--split", default="train")
    parser.add_argument("--repo_id", default="anonymous/adversarial_bugbench_bcb_bugbench_format")
    parser.add_argument("--config_name", default="default")
    parser.add_argument("--private", action="store_true")
    parser.add_argument(
        "--token",
        default=None,
        help="HF token; if omitted uses HF_TOKEN / HUGGINGFACEHUB_API_TOKEN or cached login.",
    )
    parser.add_argument("--dry_run", action="store_true")
    parser.add_argument("--max_rows", type=int, default=0, help="0 = all (useful for quick smoke tests)")
    args = parser.parse_args()

    # Explicit --token wins, then the two environment variables in order; None otherwise.
    token = args.token
    for env_name in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN"):
        if token:
            break
        token = os.environ.get(env_name)
    token = token or None

    split_name = str(args.split)
    bugbench = load_dataset(args.bugbench_repo, split=split_name)
    bcb = load_dataset(str(args.bcb_repo), split=str(args.bcb_split))
    adv = load_dataset(args.source_repo, split=split_name)

    bugbench_by_task_id = _build_bugbench_index(bugbench)
    bcb_by_task_id = _build_bugbench_index(bcb)

    max_rows = int(args.max_rows)
    limit = max_rows if max_rows > 0 else None

    rows: List[Dict[str, Any]] = []
    dropped_missing_join = 0
    recovered_from_bcb = 0
    dropped_missing_entry = 0

    for position, record in enumerate(adv):
        if limit is not None and position >= limit:
            break
        example = dict(record)
        task_id = str(example.get("task_id", "") or "").strip()

        # Same lookup `_convert_one` performs; repeated here only to keep the drop counters.
        meta = bugbench_by_task_id.get(task_id) or bcb_by_task_id.get(task_id)
        if meta is None:
            dropped_missing_join += 1
            continue
        if task_id not in bugbench_by_task_id and task_id in bcb_by_task_id:
            recovered_from_bcb += 1
        if not str(meta.get("entry_point", "") or "").strip():
            dropped_missing_entry += 1
            continue

        converted = _convert_one(
            adv_ex=example,
            bugbench_by_task_id=bugbench_by_task_id,
            bcb_by_task_id=bcb_by_task_id,
        )
        if converted is not None:
            rows.append(converted)

    summary = (
        f"Loaded adv={len(adv)} bugbench={len(bugbench)} bcb={len(bcb)} | "
        f"output_rows={len(rows)} dropped_missing_join={dropped_missing_join} recovered_from_bcb={recovered_from_bcb} "
        f"dropped_missing_entry={dropped_missing_entry}"
    )
    print(summary)

    ds_out = Dataset.from_list(rows, features=_bugbench_features())
    print("Output columns:", ds_out.column_names)
    print("Output features:", ds_out.features)
    if len(ds_out) > 0:
        print("\nSample output row (truncated):")
        sample = ds_out[0]
        for column in BUGBENCH_COLUMNS:
            shown = str(sample.get(column, ""))
            # Keep the preview readable: cap each value at 240 characters.
            if len(shown) > 240:
                shown = shown[:240] + "..."
            print(f" - {column}: {shown}")

    if args.dry_run:
        return

    DatasetDict({split_name: ds_out}).push_to_hub(
        str(args.repo_id),
        config_name=str(args.config_name),
        private=bool(args.private),
        token=token,
    )
    print(f"Pushed reformatted dataset to {args.repo_id!r} (config={args.config_name!r}, split={args.split!r}).")


if __name__ == "__main__":
    main()