#!/usr/bin/env python3
"""
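Reformat `anonymous/lcb_bugbench` so that, under the default
`--bug_variant weak_then_strong`, the output `buggy` column is chosen as:

  buggy := buggy_solution_weak     (if non-empty)
      else buggy_solution_strong   (if non-empty)
      else buggy_solution          (the source's existing value)

(`--bug_variant weak` / `--bug_variant strong` always take that single column.)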

And output the dataset in the **same raw column schema** as `anonymous/bugbench`:

  task_id, instruct_prompt, buggy, canonical_solution, test,
  complete_prompt, code_prompt, entry_point, doc_struct, libs

Usage:
  python examples/bugs/hf_dataset_formatting/reformat_lcb_bugbench_pick_weak.py \
    --source_repo anonymous/lcb_bugbench \
    --splits test \
    --repo_id anonymous/lcb_bugbench_weak \
    --bug_variant weak_then_strong

Auth:
- If you are not logged in via `huggingface-cli login`, pass `--token` or set
  `HF_TOKEN` / `HUGGINGFACEHUB_API_TOKEN`.
"""

from __future__ import annotations

import argparse
import json
import os
import re
from typing import Any, Dict, Iterable, Optional

from datasets import Dataset, DatasetDict, Features, Value, load_dataset


# Raw column layout of `anonymous/bugbench`, in order. Every converted row has
# exactly these keys and each one is typed as a string in the output schema.
BUGBENCH_COLUMNS = [
    # identity / prompts
    "task_id", "instruct_prompt",
    # code under test and its reference fix
    "buggy", "canonical_solution", "test",
    # prompt variants and call target
    "complete_prompt", "code_prompt", "entry_point",
    # auxiliary fields (always filled with constants by this script)
    "doc_struct", "libs",
]


def _parse_splits(s: str) -> Iterable[str]:
    # Accept comma-separated list.
    for part in (s or "").split(","):
        part = part.strip()
        if part:
            yield part


def _is_nullish(x: Any) -> bool:
    if x is None:
        return True
    if isinstance(x, str):
        return not x.strip()
    return False


def _extract_code_from_fence(text: Any) -> str:
    """
    Best-effort extraction of code from a ```python fenced block.
    If not fenced, returns the original string.
    """
    if text is None:
        return ""
    s = str(text)
    if "```" not in s:
        return s.strip()
    # handle ```python\n...\n``` and ```\n...\n```
    start = s.find("```")
    if start == -1:
        return s.strip()
    rest = s[start + 3 :]
    # strip optional language tag
    if rest.lstrip().startswith("python"):
        rest = rest.lstrip()[len("python") :]
    # strip leading newline
    rest = rest.lstrip("\n")
    end = rest.rfind("```")
    if end != -1:
        rest = rest[:end]
    return rest.strip("\n").rstrip()


def _choose_buggy_solution(ex: Dict[str, Any], bug_variant: str) -> str:
    weak = ex.get("buggy_solution_weak", None)
    strong = ex.get("buggy_solution_strong", None)
    current = ex.get("buggy_solution", None)
    bug_variant = str(bug_variant or "").strip().lower()
    if bug_variant == "weak":
        return "" if weak is None else str(weak)
    if bug_variant == "strong":
        return "" if strong is None else str(strong)
    if bug_variant == "weak_then_strong":
        if not _is_nullish(weak):
            return str(weak)
        if not _is_nullish(strong):
            return str(strong)
        return "" if current is None else str(current)
    raise ValueError("--bug_variant must be one of: weak, strong, weak_then_strong")


def _to_bugbench_rows(ds: Dataset, bug_variant: str) -> list[dict[str, Any]]:
    """
    Convert LCB bugbench rows into BugBench raw schema (string-typed columns).

    This is the eager counterpart of `_iter_bugbench_rows`. It collects that
    generator's rows into a list rather than duplicating the per-row
    conversion, so the two code paths cannot drift apart.

    Raises:
        ValueError: if ``bug_variant`` is not recognised by `_choose_buggy_solution`
            (raised while the rows are being built).
    """
    return list(_iter_bugbench_rows(ds, bug_variant))


def _iter_bugbench_rows(examples: Iterable[Dict[str, Any]], bug_variant: str) -> Iterable[Dict[str, Any]]:
    """
    Streaming-friendly version of `_to_bugbench_rows` that yields rows one-by-one.
    """
    for ex in examples:
        uid = str(ex.get("uid", "") or "").strip()
        task_id = uid or str(ex.get("task_id", "") or "").strip() or "lcb_bugbench_unknown"

        instruct_prompt = str(ex.get("problem", "") or "").strip() or str(ex.get("question", "") or "").strip()

        buggy_solution = _choose_buggy_solution(ex, bug_variant)
        buggy_code = _extract_code_from_fence(buggy_solution)

        ref_solution = ex.get("reference_solution", None)
        canonical_solution = _extract_code_from_fence(ref_solution)

        test = ex.get("ground_truth", "")
        test = "" if test is None else str(test)

        complete_prompt = str(ex.get("question", "") or "").strip()

        starter_code = ex.get("starter_code", "")
        code_prompt = "" if starter_code is None else str(starter_code)

        metadata = ex.get("metadata", {}) or {}
        entry_point = ""
        if isinstance(metadata, dict):
            entry_point = str(metadata.get("func_name", "") or "")

        yield {
            "task_id": task_id,
            "instruct_prompt": instruct_prompt,
            "buggy": buggy_code,
            "canonical_solution": canonical_solution,
            "test": test,
            "complete_prompt": complete_prompt,
            "code_prompt": code_prompt,
            "entry_point": entry_point,
            "doc_struct": "",
            "libs": "[]",
        }


def _bugbench_features() -> Features:
    """Build the output schema: one ``Value("string")`` feature per BugBench column, in column order."""
    schema = {}
    for column in BUGBENCH_COLUMNS:
        schema[column] = Value("string")
    return Features(schema)


# Source columns read by `_iter_bugbench_rows`. All other columns are dropped
# before conversion so heavy debug columns are never materialized in Python.
_NEEDED_SOURCE_COLUMNS = frozenset(
    {
        "uid",
        "task_id",
        "problem",
        "question",
        "buggy_solution",
        "buggy_solution_weak",
        "buggy_solution_strong",
        "reference_solution",
        "ground_truth",
        "starter_code",
        "metadata",
    }
)

# Output columns echoed (truncated) when printing a sample row.
_SAMPLE_KEYS = ("task_id", "entry_point", "instruct_prompt", "buggy", "canonical_solution", "test")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--source_repo", default="anonymous/lcb_bugbench")
    ap.add_argument("--splits", default="test", help="comma-separated (e.g. train,validation,test)")
    ap.add_argument("--repo_id", required=True, help="Destination dataset repo id, e.g. anonymous/lcb_bugbench_weak")
    ap.add_argument(
        "--bug_variant",
        default="weak_then_strong",
        choices=["weak_then_strong", "weak", "strong"],
        help="Which buggy solution to use when constructing the BugBench-style `buggy` column.",
    )
    ap.add_argument("--config_name", default="default")
    ap.add_argument("--private", action="store_true")
    ap.add_argument(
        "--token",
        default=None,
        help=(
            "HF token used both to load the source and to push the result; "
            "if omitted uses HF_TOKEN / HUGGINGFACEHUB_API_TOKEN or cached login."
        ),
    )
    ap.add_argument("--dry_run", action="store_true", help="Process and print a sample, but don't push")
    return ap


def _resolve_token(cli_token: Optional[str]) -> Optional[str]:
    """Return ``--token``, else ``$HF_TOKEN``, else ``$HUGGINGFACEHUB_API_TOKEN``, else None."""
    return (
        cli_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        or None
    )


def _print_sample(header: str, row: Dict[str, Any], max_chars: int = 200) -> None:
    """Print ``header``, then each of `_SAMPLE_KEYS` from ``row`` truncated to ``max_chars``."""
    print(header)
    for k in _SAMPLE_KEYS:
        v = str(row.get(k, ""))
        if len(v) > max_chars:
            v = v[:max_chars] + "..."
        print(f" - {k}: {v}")


def _load_source_split(source_repo: str, split: str, token: Optional[str]) -> Dataset:
    """Load one split of the source dataset, keeping only `_NEEDED_SOURCE_COLUMNS`."""
    # Pass the token here too: otherwise a private/gated source repo fails even
    # when --token was supplied.
    ds = load_dataset(source_repo, split=split, token=token)
    drop_cols = [c for c in ds.column_names if c not in _NEEDED_SOURCE_COLUMNS]
    if drop_cols:
        ds = ds.remove_columns(drop_cols)
    return ds


def main() -> None:
    """
    Convert each requested split of the source repo to the BugBench schema and
    push the result to ``--repo_id``.

    With ``--dry_run``, only the first converted row of each split is printed
    and nothing is built or pushed. Exits with a usage error if ``--splits``
    names no split.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    token = _resolve_token(args.token)
    bug_variant = str(args.bug_variant)

    splits = list(_parse_splits(str(args.splits)))
    if not splits:
        # Otherwise an empty DatasetDict would be pushed.
        ap.error("--splits must name at least one split")

    out: Dict[str, Dataset] = {}
    for split in splits:
        ds = _load_source_split(str(args.source_repo), split, token)

        if args.dry_run:
            # Avoid building the full output Dataset; convert and show only the first row.
            first = next(_iter_bugbench_rows(ds, bug_variant), None)
            if first is not None:
                _print_sample(f"[{split}] cols={BUGBENCH_COLUMNS}", first)
            continue

        # Pass inputs via gen_kwargs rather than a lambda closing over the loop variable.
        ds2 = Dataset.from_generator(
            _iter_bugbench_rows,
            features=_bugbench_features(),
            gen_kwargs={"examples": ds, "bug_variant": bug_variant},
        )
        out[split] = ds2

        if len(ds2) > 0:
            _print_sample(f"[{split}] rows={len(ds2)} cols={ds2.column_names}", ds2[0])

    if args.dry_run:
        return

    DatasetDict(out).push_to_hub(
        str(args.repo_id),
        config_name=str(args.config_name),
        private=bool(args.private),
        token=token,
    )
    print(f"Pushed dataset to {args.repo_id!r} (config={args.config_name!r}).")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


