#!/usr/bin/env python3
"""
Generate a synthetic BugBench dataset by prompting Qwen2.5-Coder-7B-Instruct to generate bugs.

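This script:
1. Loads the requested splits (default: test and test_small) of the source dataset (default: bugbench_human)
2. For each task, prompts the model with the bug-generator prompt (up to --k attempts, keeping up to --n-bugs bugs)
3. Validates each candidate bug: it must fail at least one unit test without a compile/runtime error
4. Writes the valid bugs (one row per bug) to bugbench_synth splits with the same names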

Example:
  # Start vLLM server first:
  python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen2.5-Coder-7B-Instruct \
    --port 8000 \
    --tensor-parallel-size 1

  # Run the script:
  python examples/bugs/data_processing/generate_bugbench_synth.py \
    --source-dataset bugbench_human \
    --output-repo anonymous/bugbench_synth \
    --base-url http://localhost:8000/v1 \
    --model Qwen/Qwen2.5-Coder-7B-Instruct \
    --splits test test_small

  # For a quick test run with limited examples:
  python examples/bugs/data_processing/generate_bugbench_synth.py \
    --output-repo anonymous/bugbench_synth \
    --base-url http://localhost:8000/v1 \
    --splits test_small \
    --limit 10 \
    --dry-run \
    --verbose
"""

from __future__ import annotations

import argparse
import asyncio
import contextlib
import io
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple

try:
    import httpx
    _HAS_HTTPX = True
except ImportError:
    _HAS_HTTPX = False

from datasets import Dataset as HFDataset
from datasets import DatasetDict as HFDatasetDict
from datasets import load_dataset

from examples.bugs.prompts import _build_bug_generator_prompt
from rllm.data.dataset import DatasetRegistry
from rllm.engine.rollout import OpenAIEngine
from rllm.rewards.code_reward import RewardCodeFn
from rllm.rewards.reward_types import RewardConfig, RewardOutput


def check_bug_validity(
    bug_meta: Dict[str, Any],
    bug_reward_output: Any,
    compile_errors_invalid: bool = True,
) -> Tuple[bool, bool]:
    """
    Decide whether a candidate buggy solution counts as a usable bug.

    A valid bug must:
      1) Have no compile/runtime errors (if compile_errors_invalid=True)
      2) Fail at least one unit test (passed_tests < total_tests OR all_passed=False)

    Args:
        bug_meta: Metadata produced by the reward function for the candidate.
            Keys read: ``total_tests``, ``passed_tests``, ``all_passed``,
            ``test_results`` (list of dicts with an ``error_message``) and
            ``output``. ``None`` is treated as an empty dict.
        bug_reward_output: Reward result for the candidate. Only its
            ``is_correct`` attribute is consulted, and only when the metadata
            gives neither test counts nor ``all_passed=False``.
        compile_errors_invalid: When True, a detected compile/runtime error
            makes the bug invalid regardless of test counts.

    Returns:
        ``(bug_valid, has_compile_error)``.
    """
    meta = bug_meta or {}
    total_tests = meta.get("total_tests")
    passed_tests = meta.get("passed_tests")
    # None (not False) when absent, so that missing metadata reaches the
    # is_correct fallback below instead of being read as "a test failed".
    all_passed = meta.get("all_passed")

    # Fragments of per-test error messages that indicate the code did not run
    # properly (syntax/import/name/type errors) rather than giving a wrong answer.
    compile_patterns = (
        "syntax", "syntaxerror", "compilation", "compile error",
        "cannot compile", "indentation", "invalid syntax",
        "unexpected", "eof", "unterminated", "was never closed",
        "nameerror", "typeerror", "attributeerror", "import error",
        "module not found", "indentationerror",
        # Python's own exception spellings (also used for the output string).
        "importerror", "modulenotfounderror",
    )
    # Fragments searched in the raw output string (BigCodeBench style).
    output_error_patterns = (
        "syntaxerror", "indentationerror", "invalid syntax",
        "was never closed", "unterminated", "eof while scanning",
        "importerror", "modulenotfounderror", "nameerror",
        "typeerror", "attributeerror",
    )

    has_compile_error = False
    test_results = meta.get("test_results", [])
    if isinstance(test_results, list):
        for test in test_results:
            if not isinstance(test, dict):
                continue
            msg_l = str(test.get("error_message", "") or "").lower()
            if not msg_l:
                continue
            if "error during testing:" in msg_l:
                has_compile_error = True
                break
            # "Wrong answer" messages are ordinary test failures, not crashes.
            if "wrong answer" not in msg_l and any(p in msg_l for p in compile_patterns):
                has_compile_error = True
                break

    # Also check the output string for common error patterns. An AssertionError
    # means the tests ran and failed, which is exactly what a valid bug does.
    output_str = str(meta.get("output", "") or "").lower()
    if output_str and not has_compile_error and "assertionerror" not in output_str:
        if any(p in output_str for p in output_error_patterns):
            has_compile_error = True

    if compile_errors_invalid and has_compile_error:
        bug_valid = False
    elif total_tests is not None and passed_tests is not None and total_tests > 0:
        bug_valid = passed_tests < total_tests
    elif all_passed is False:
        bug_valid = True
    else:
        # Metadata is incomplete (or reports all tests passing): use the
        # correctness signal. The compile-error case is already handled by
        # the first branch whenever compile_errors_invalid is set.
        bug_valid = not bool(getattr(bug_reward_output, "is_correct", False))

    return bug_valid, has_compile_error


# Fenced code block that is untagged or tagged as Python ("python", "python3"
# or "py", any case). Spaces/tabs and a CRLF line ending after the opening
# fence are tolerated. Group 1 is the code body.
_FENCE_RE = re.compile(r"```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
# Opening or closing think tag ("<think>", "</think>"), any case, optional inner whitespace.
_THINK_TAG_RE = re.compile(r"</?\s*think\s*>", re.IGNORECASE)


def _resolve_api_key(base_url: str, explicit_api_key: Optional[str]) -> str:
    """Pick the API key to send, substituting ``"EMPTY"`` for non-OpenAI servers.

    Precedence: *explicit_api_key* when truthy, otherwise ``$OPENAI_API_KEY``.
    Surrounding whitespace is trimmed. A blank key is only an error when
    *base_url* targets api.openai.com; any other endpoint (e.g. a local vLLM
    server) receives the placeholder ``"EMPTY"``.

    Raises:
        ValueError: if *base_url* is the OpenAI API and no key is available.
    """
    source = explicit_api_key if explicit_api_key else os.getenv("OPENAI_API_KEY", "")
    key = str(source or "").strip()
    if not key and "api.openai.com" in str(base_url):
        raise ValueError(
            "base_url points to the OpenAI API, but OPENAI_API_KEY is missing/empty. "
            "Please export OPENAI_API_KEY or pass --api-key explicitly."
        )
    return key or "EMPTY"


def _extract_code_maybe_fenced(text: str) -> str:
    """Return the body of the last fenced code block in *text*, else all of *text*.

    ``None`` yields ``""``. The chosen text is stripped of surrounding whitespace.
    """
    if text is None:
        return ""
    raw = str(text)
    blocks = _FENCE_RE.findall(raw)
    chosen = (blocks[-1] or "") if blocks else raw
    return chosen.strip()


def _strip_think_tokens(text: str) -> str:
    """Remove ``<think>`` reasoning tokens from a model response.

    If the response contains no think tag, it is returned unchanged (not
    stripped). Otherwise everything up to and including the LAST closing tag
    is dropped. Any remaining open/close tags are then removed, and the
    result is whitespace-stripped.

    Closing tags are located with ``_THINK_TAG_RE`` itself. That way the
    case/whitespace variants the regex recognises (e.g. ``</THINK>``,
    ``</ think >``) are honoured consistently. The previous case-sensitive
    ``rsplit("</think>")`` missed them and leaked the reasoning into the code.
    """
    s = str(text or "")
    tags = list(_THINK_TAG_RE.finditer(s))
    if not tags:
        return s

    closers = [m for m in tags if m.group(0).startswith("</")]
    if closers:
        # Keep only what follows the final closing tag.
        s = s[closers[-1].end():]
    s = _THINK_TAG_RE.sub("", s)
    return s.strip()


def _to_fenced_python(code_or_text: str) -> str:
    """Return *code_or_text* inside a ```python fence unless it already opens with a fence.

    Only leading/trailing newlines are trimmed first. ``None`` is treated as ``""``.
    """
    body = (code_or_text or "").strip("\n")
    return body if body.startswith("```") else f"```python\n{body}\n```"


def _maybe_prepend_starter(code_prompt: str, code_or_body: str) -> str:
    """Prefix starter code (imports + signature) when the model returned only a body.

    The snippet is treated as already complete when it contains a code fence,
    ``def `` or ``class ``, or begins (ignoring indentation and case) with
    ``import `` / ``from ``. Both inputs have surrounding newlines trimmed.
    An empty starter leaves the snippet untouched.
    """
    starter = (code_prompt or "").strip("\n")
    candidate = (code_or_body or "").strip("\n")
    head = candidate.lstrip().lower()
    looks_complete = (
        "```" in candidate
        or head.startswith(("import ", "from "))
        or "def " in candidate
        or "class " in candidate
    )
    if looks_complete or not starter:
        return candidate
    return starter + "\n" + candidate


def _get_uid(example: Dict[str, Any]) -> str:
    """Return the task identifier: ``uid`` if truthy, else ``task_id``, else ``""`` (always a str)."""
    ident = example.get("uid")
    if not ident:
        ident = example.get("task_id")
    return str(ident) if ident else ""


def _get_problem(task: Dict[str, Any]) -> str:
    """Return the first non-blank string among the known problem-description fields, or ``""``."""
    fields = ("question", "instruct_prompt", "truncated_instruct_prompt", "complete_prompt", "prompt", "problem")
    candidates = (task.get(name) for name in fields)
    return next((c for c in candidates if isinstance(c, str) and c.strip()), "")


def _get_reference_solution(task: Dict[str, Any]) -> str:
    """Return the first non-blank string among the known reference-solution fields, or ``""``."""
    fields = ("reference_solution", "canonical_solution", "solution", "code")
    candidates = (task.get(name) for name in fields)
    return next((c for c in candidates if isinstance(c, str) and c.strip()), "")


def _get_ground_truth(task: Dict[str, Any]) -> str:
    """Return the unit-test source (``ground_truth``, else ``test``) if non-blank, or ``""``."""
    candidates = (task.get(name) for name in ("ground_truth", "test"))
    return next((c for c in candidates if isinstance(c, str) and c.strip()), "")


def _get_entry_point(task: Dict[str, Any]) -> str:
    """Return the stripped entry-point name, or ``""`` when none is found.

    Looks at ``task["entry_point"]`` first, then at ``func_name`` / ``entry_point``
    inside a dict-valued ``task["metadata"]``.
    """
    direct = task.get("entry_point")
    if isinstance(direct, str) and direct.strip():
        return direct.strip()

    meta = task.get("metadata", {})
    if not isinstance(meta, dict):
        return ""
    fallback = meta.get("func_name") or meta.get("entry_point")
    return fallback.strip() if isinstance(fallback, str) and fallback.strip() else ""


def _get_starter_code(task: Dict[str, Any]) -> str:
    """Return the starter code (``starter_code``, else ``code_prompt``) if non-blank, or ``""``."""
    candidates = (task.get(name) for name in ("starter_code", "code_prompt"))
    return next((c for c in candidates if isinstance(c, str) and c.strip()), "")


# Shared redirect state for _quiet_model_call. Interleaved (concurrent) calls
# share ONE stdout/stderr redirect, installed by the first in-flight call and
# removed by the last. Separate per-call redirect_stdout contexts around an
# `await` exit out of order and can leave sys.stdout pointing at a discarded
# buffer, silently swallowing all later output.
_QUIET_REDIRECT = {"depth": 0, "stack": None}


async def _quiet_model_call(engine: OpenAIEngine, messages: List[Dict[str, str]], verbose: bool = False, **kwargs) -> str:
    """Call ``engine.get_model_response`` and return its text, muting stdout/stderr.

    Returns ``out.content``, falling back to ``out.text`` and then ``""``.
    With ``verbose=True`` nothing is redirected (useful for debugging).

    Otherwise stdout and stderr go to a throwaway buffer while at least one
    call is in flight. The redirect is process-global, so anything else
    printed during that window (including from worker threads) is discarded
    too. The depth counter assumes all calls run on a single event-loop
    thread, as asyncio coroutines do.
    """
    if verbose:
        # Don't suppress output in verbose mode for debugging
        out = await engine.get_model_response(messages, **kwargs)
        return out.content or out.text or ""

    state = _QUIET_REDIRECT
    if state["depth"] == 0:
        sink = io.StringIO()
        stack = contextlib.ExitStack()
        stack.enter_context(contextlib.redirect_stdout(sink))
        stack.enter_context(contextlib.redirect_stderr(sink))
        state["stack"] = stack
    state["depth"] += 1
    try:
        out = await engine.get_model_response(messages, **kwargs)
    finally:
        state["depth"] -= 1
        if state["depth"] == 0:
            # Last in-flight call: restore the original streams (LIFO).
            state["stack"].close()
            state["stack"] = None
    return out.content or out.text or ""


def _reward_bugbench(
    reward_fn: RewardCodeFn,
    *,
    test_code: str,
    code: str,
    entry_point: str,
    uid: str,
    index: int,
) -> RewardOutput:
    """Score *code* against the BigCodeBench/BugBench unit tests in *test_code*.

    Builds the ``task_info`` mapping the reward function expects (data source
    ``"bugbench"``, with the entry point repeated as ``metadata.func_name``).
    The code is passed as a Python-fenced action.
    """
    fenced_action = _to_fenced_python(code)
    info = dict(
        data_source="bugbench",
        ground_truth=test_code,
        entry_point=entry_point,
        uid=uid,
        index=index,
        metadata={"func_name": entry_point},
    )
    return reward_fn(task_info=info, action=fenced_action)


async def _reward_bugbench_async(
    *,
    reward_fn: RewardCodeFn,
    executor: ThreadPoolExecutor,
    test_code: str,
    code: str,
    entry_point: str,
    uid: str,
    index: int,
) -> RewardOutput:
    """Await ``_reward_bugbench`` on *executor* so test execution never blocks the event loop."""

    def _run_tests():
        return _reward_bugbench(
            reward_fn,
            test_code=test_code,
            code=code,
            entry_point=entry_point,
            uid=uid,
            index=index,
        )

    event_loop = asyncio.get_running_loop()
    result = await event_loop.run_in_executor(executor, _run_tests)
    return result


async def _generate_bugs(
    *,
    engine: OpenAIEngine,
    reward_fn: RewardCodeFn,
    reward_executor: ThreadPoolExecutor,
    problem: str,
    reference_solution: str,
    starter_code: str,
    test_code: str,
    entry_point: str,
    uid: str,
    index: int,
    n_bugs: int,
    k: int,
    temperature: float,
    top_p: float,
    max_tokens: int,
    max_completion_tokens: int,
    verbose: bool = False,
) -> Tuple[List[str], Dict[str, Any]]:
    """
    Generate up to ``n_bugs`` valid bugs for one task, making at most ``k`` model calls.

    Each attempt samples a response to the bug-generator prompt and strips
    ``<think>`` tokens. It extracts the (last) fenced code block, prepends the
    starter code when the model returned a body-only snippet, and runs the
    unit tests. A candidate is kept when ``check_bug_validity`` accepts it and
    the reward output is not marked correct. A candidate identical to one
    already tested is skipped without re-running the tests. Otherwise it
    would yield a duplicate output row (if valid) or the same rejection.

    Returns:
        ``(valid_bugs, debug)``. ``valid_bugs`` holds the Python-fenced buggy
        solutions (possibly empty). ``debug`` records ``attempts``,
        ``bugs_found``, ``bug_valid_at`` (1-based attempt numbers),
        ``last_meta`` and ``last_has_compile_error``.
    """
    debug: Dict[str, Any] = {
        "attempts": 0,
        "bugs_found": 0,
        "bug_valid_at": [],
        "last_meta": None,
        "last_has_compile_error": None,
    }

    valid_bugs: List[str] = []
    tested = set()  # fenced candidates whose tests have already been run

    # The token budget is the same for every attempt. A positive
    # max_completion_tokens (reasoning models) takes precedence.
    if int(max_completion_tokens) > 0:
        token_kwargs: Dict[str, Any] = {"max_completion_tokens": int(max_completion_tokens)}
    else:
        token_kwargs = {"max_tokens": int(max_tokens)}

    # Build the bug generator prompt once; each attempt re-samples it.
    reference_fenced = _to_fenced_python(reference_solution)
    prompt = _build_bug_generator_prompt(problem=problem, correct_code=reference_fenced)
    messages = [{"role": "user", "content": prompt}]

    for attempt in range(1, int(k) + 1):
        # Stop early once enough bugs have been collected.
        if len(valid_bugs) >= n_bugs:
            break

        debug["attempts"] = attempt

        try:
            raw = await _quiet_model_call(
                engine,
                messages,
                verbose=verbose,
                temperature=float(temperature),
                top_p=float(top_p),
                **token_kwargs,
            )
        except Exception as e:
            if verbose:
                print(f"[uid={uid}] Model call failed on attempt {attempt}/{k}: {e}")
                import traceback
                traceback.print_exc()
            continue

        if verbose:
            print(f"[uid={uid}] Attempt {attempt} raw response: {raw[:500]}...")

        code = _extract_code_maybe_fenced(_strip_think_tokens(raw))
        if not code.strip():
            if verbose:
                print(f"[uid={uid}] Attempt {attempt}: empty code extracted")
            continue

        # Prepend starter code if the model returned only a function body.
        full_code = _maybe_prepend_starter(starter_code, code)
        candidate = _to_fenced_python(full_code)
        if candidate in tested:
            if verbose:
                print(f"[uid={uid}] Attempt {attempt}: duplicate of an earlier candidate, skipping")
            continue

        # Validate the bug by running the unit tests.
        try:
            ro = await _reward_bugbench_async(
                reward_fn=reward_fn,
                executor=reward_executor,
                test_code=test_code,
                code=full_code,
                entry_point=entry_point,
                uid=uid,
                index=index,
            )
        except Exception as e:
            if verbose:
                print(f"[uid={uid}] Reward evaluation failed on attempt {attempt}/{k}: {e}")
            continue
        # Only remember candidates that were actually evaluated, so a failed
        # evaluation (e.g. an executor error) can be retried.
        tested.add(candidate)

        meta = ro.metadata or {}
        debug["last_meta"] = meta
        bug_valid, has_compile_error = check_bug_validity(
            bug_meta=meta, bug_reward_output=ro, compile_errors_invalid=True
        )
        debug["last_has_compile_error"] = bool(has_compile_error)

        if bool(bug_valid) and not bool(ro.is_correct):
            valid_bugs.append(candidate)
            debug["bug_valid_at"].append(attempt)
            debug["bugs_found"] = len(valid_bugs)
            if verbose:
                print(f"[uid={uid}] Valid bug #{len(valid_bugs)} found at attempt {attempt}")
        elif verbose:
            why = "compile/runtime error" if has_compile_error else ("passed all tests" if ro.is_correct else "invalid")
            print(f"[uid={uid}] Attempt {attempt}: invalid bug ({why})")

    debug["bugs_found"] = len(valid_bugs)
    return valid_bugs, debug


def _load_split_tasks(source_dataset: str, split: str) -> Optional[List[Dict[str, Any]]]:
    """Load one split as a list of task dicts; print the error and return None on failure.

    Tries the local DatasetRegistry first. If the registry has no such split,
    it falls back to the HuggingFace repo ``anonymous/<source_dataset>``.
    """
    print(f"Loading {source_dataset} split={split}...")
    try:
        ds = DatasetRegistry.load_dataset(source_dataset, split)
        if ds is None:
            # Fallback: try loading from HuggingFace
            hf_repo = f"anonymous/{source_dataset}"
            print(f"  Registry load failed, trying HuggingFace: {hf_repo}")
            hf_ds = load_dataset(hf_repo, split=split)
            return [dict(ex) for ex in hf_ds]
        tasks = ds.get_data()
        print(f"  Loaded from registry: {ds.get_data_path()}")
        return tasks
    except Exception as e:
        print(f"  Failed to load dataset: {e}")
        import traceback
        traceback.print_exc()
        return None


async def process_split(
    *,
    source_dataset: str,
    split: str,
    engine: OpenAIEngine,
    args: argparse.Namespace,
) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Generate synthetic bugs for every task in one split of the source dataset.

    The split is sliced with ``--start``/``--end``/``--limit``. ``_generate_bugs``
    then runs for each task, with at most ``args.n_parallel`` tasks in flight
    and ``args.batch_size`` tasks per ``asyncio.gather`` batch.

    Returns:
        ``(rows, stats)``. ``rows`` has one output row per valid bug. ``stats``
        holds counters ``total``/``success``/``failed``, plus
        ``tasks_with_bugs`` and ``total_bugs`` when tasks were processed.
        ``failed`` counts tasks that were skipped, produced no valid bug, or
        raised.
    """
    print(f"\n{'='*80}")
    print(f"Processing split: {split}")
    print(f"{'='*80}")

    tasks = _load_split_tasks(source_dataset, split)
    if tasks is None:
        return [], {"total": 0, "success": 0, "failed": 0}

    print(f"Loaded {len(tasks)} tasks")
    if tasks and args.verbose:
        print(f"  First task keys: {list(tasks[0].keys())}")

    # Apply start/end/limit filters
    start = max(0, int(args.start))
    end = len(tasks) if args.end is None else min(len(tasks), int(args.end))
    if end <= start:
        print(f"  Invalid slice: start={start} end={end}")
        return [], {"total": 0, "success": 0, "failed": 0}

    tasks = tasks[start:end]
    if args.limit is not None and args.limit > 0:
        tasks = tasks[: int(args.limit)]

    print(f"Processing {len(tasks)} tasks (slice {start}:{end}, limit={args.limit})")

    reward_fn = RewardCodeFn(RewardConfig())
    # Clamp to >= 1 for both the executor and the semaphore: Semaphore(0)
    # would block every task forever (and a negative value raises).
    n_parallel = max(1, int(args.n_parallel))
    executor = ThreadPoolExecutor(max_workers=min(32, n_parallel))
    sem = asyncio.Semaphore(n_parallel)

    stats = {
        "total": len(tasks),
        "success": 0,
        "failed": 0,
    }
    out_rows: List[Dict[str, Any]] = []

    async def process_one(task: Dict[str, Any], idx: int) -> List[Dict[str, Any]]:
        """Process one task and return a list of rows (one per valid bug found)."""
        async with sem:
            uid = _get_uid(task)
            problem = _get_problem(task)
            reference_solution = _get_reference_solution(task)
            ground_truth = _get_ground_truth(task)
            entry_point = _get_entry_point(task)
            starter_code = _get_starter_code(task)

            if not reference_solution.strip():
                if args.verbose:
                    print(f"[{idx}] uid={uid}: missing reference solution, skipping")
                return []

            if not ground_truth.strip():
                if args.verbose:
                    print(f"[{idx}] uid={uid}: missing ground truth, skipping")
                return []

            buggy_solutions, debug = await _generate_bugs(
                engine=engine,
                reward_fn=reward_fn,
                reward_executor=executor,
                problem=problem,
                reference_solution=reference_solution,
                starter_code=starter_code,
                test_code=ground_truth,
                entry_point=entry_point,
                uid=uid,
                index=idx,
                n_bugs=int(args.n_bugs),
                k=int(args.k),
                temperature=float(args.temperature),
                top_p=float(args.top_p),
                max_tokens=int(args.max_tokens),
                max_completion_tokens=int(args.max_completion_tokens),
                verbose=bool(args.verbose),
            )

            if not buggy_solutions:
                if args.verbose:
                    print(f"[{idx}] uid={uid}: no valid bugs generated after {debug.get('attempts', 0)} attempts")
                return []

            # Build output rows - one per valid bug found
            rows = []
            for bug_idx, buggy_solution in enumerate(buggy_solutions):
                rows.append({
                    "question": task.get("question", problem),
                    "reference_solution": _to_fenced_python(reference_solution),
                    "buggy_solution": buggy_solution,
                    "ground_truth": ground_truth,
                    "data_source": task.get("data_source", "bugbench"),
                    "uid": uid,
                    "index": idx,
                    "bug_index": bug_idx,  # Which bug this is for this task
                    "starter_code": starter_code,
                    "entry_point": entry_point,
                    "metadata": task.get("metadata", {"func_name": entry_point}),
                    # Preserve additional fields from source
                    "instruct_prompt": task.get("instruct_prompt", ""),
                    "truncated_instruct_prompt": task.get("truncated_instruct_prompt", ""),
                    "complete_prompt": task.get("complete_prompt", ""),
                    "code_prompt": task.get("code_prompt", starter_code),
                    # Debug info
                    "bug_generation_debug": debug,
                })

            if args.verbose:
                print(f"[{idx}] uid={uid}: generated {len(rows)} valid bug(s)")
            return rows

    # Process all tasks in batches
    completed = 0
    tasks_with_bugs = 0
    total_bugs = 0
    batch_size = max(int(args.batch_size), 1)
    log_every = max(int(args.log_every), 1)
    last_logged = 0

    try:
        for batch_start in range(0, len(tasks), batch_size):
            batch_tasks = tasks[batch_start:batch_start + batch_size]

            results = await asyncio.gather(
                # idx is the task's position in the loaded split (offset by
                # --start), so sharded runs do not produce colliding indices.
                *[process_one(task, start + batch_start + i) for i, task in enumerate(batch_tasks)],
                return_exceptions=True,
            )

            for r in results:
                if isinstance(r, BaseException):
                    print(f"  Exception during processing: {r}")
                    stats["failed"] += 1
                elif not r:  # Empty list or None
                    stats["failed"] += 1
                else:
                    # r is a list of rows (one per bug found for this task)
                    out_rows.extend(r)
                    stats["success"] += 1
                    tasks_with_bugs += 1
                    total_bugs += len(r)
                completed += 1

            # Progress is only checked once per batch, so log whenever this
            # batch crossed a multiple of log_every (an exact multiple is rare)
            # and always after the final batch.
            crossed = completed // log_every > last_logged // log_every
            if crossed or completed == len(tasks):
                pct = 100.0 * tasks_with_bugs / completed if completed > 0 else 0
                print(f"  Progress: {completed}/{len(tasks)} tasks | {tasks_with_bugs}/{completed} with bugs ({pct:.1f}%) | {total_bugs} total bugs")
                last_logged = completed
    finally:
        # Shut down even if an exception escapes, so worker threads are released.
        executor.shutdown(wait=False, cancel_futures=True)

    pct = 100.0 * tasks_with_bugs / stats['total'] if stats['total'] > 0 else 0
    print(f"\nSplit {split} complete:")
    print(f"  Tasks with valid bugs: {tasks_with_bugs}/{stats['total']} ({pct:.1f}%)")
    print(f"  Total bugs generated: {total_bugs}")
    stats["tasks_with_bugs"] = tasks_with_bugs
    stats["total_bugs"] = total_bugs
    return out_rows, stats


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script (the module docstring is its description)."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Source and output
    parser.add_argument(
        "--source-dataset",
        type=str,
        default="bugbench_human",
        help="Source dataset name in DatasetRegistry (default: bugbench_human)",
    )
    parser.add_argument(
        "--output-repo",
        type=str,
        required=True,
        help="HuggingFace dataset repo id for output, e.g. anonymous/bugbench_synth",
    )
    parser.add_argument(
        "--splits",
        type=str,
        nargs="+",
        default=["test", "test_small"],
        help="Splits to process (default: test test_small)",
    )
    parser.add_argument("--private", action="store_true", help="Push as a private dataset")
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="HuggingFace token (or set HF_TOKEN/HUGGINGFACEHUB_API_TOKEN)",
    )

    # Model configuration
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-Coder-7B-Instruct",
        help="Model name (default: Qwen/Qwen2.5-Coder-7B-Instruct)",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default="http://localhost:30000/v1",
        # Derived from the real default; the text previously claimed port 8000.
        help="OpenAI-compatible base URL (default: %(default)s)",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key (or set OPENAI_API_KEY)",
    )

    # Generation parameters
    parser.add_argument(
        "--n-bugs",
        type=int,
        default=1,
        help="Number of valid bugs to generate per task (default: 1)",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=8,
        help="Max total attempts to generate bugs per task (default: 8)",
    )
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature (default: 0.6)")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p sampling (default: 0.95)")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens (default: 2048)")
    parser.add_argument(
        "--max-completion-tokens",
        type=int,
        default=0,
        help="If >0, use this instead of max_tokens (for reasoning models)",
    )

    # Processing parameters
    parser.add_argument("--n-parallel", type=int, default=16, help="Parallel tasks (default: 16)")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for gather (default: 32)")
    parser.add_argument("--log-every", type=int, default=25, help="Log progress every N completions")
    parser.add_argument("--verbose", action="store_true", help="Print detailed logs")

    # Slice/limit
    parser.add_argument("--start", type=int, default=0, help="Start index (inclusive)")
    parser.add_argument("--end", type=int, default=None, help="End index (exclusive)")
    parser.add_argument("--limit", type=int, default=None, help="Max number of tasks to process per split")

    # Persistence
    parser.add_argument("--dry-run", action="store_true", help="Generate but do not push to HuggingFace")
    parser.add_argument(
        "--save-local",
        type=str,
        default=None,
        help="Optional: save datasets to local directory",
    )
    parser.add_argument(
        "--register-locally",
        action="store_true",
        help="Register generated datasets in local DatasetRegistry",
    )
    parser.add_argument(
        "--local-registry-name",
        type=str,
        default="bugbench_synth",
        help="Name for local DatasetRegistry (default: bugbench_synth)",
    )
    return parser


def _check_server(base_url: str) -> bool:
    """Probe ``<base_url>/models``. Return False only when the server is unreachable.

    A non-200 status or any other error is reported as a warning and the run
    proceeds. The probe is skipped (returning True) when httpx is not installed.
    """
    print("\nTesting connection to vLLM server...")
    if not _HAS_HTTPX:
        print("  (httpx not installed, skipping connection test)")
        return True
    try:
        resp = httpx.get(f"{base_url.rstrip('/')}/models", timeout=10.0)
        if resp.status_code == 200:
            models = resp.json()
            print(f"  Connected! Available models: {[m.get('id') for m in models.get('data', [])]}")
        else:
            print(f"  Warning: Server returned status {resp.status_code}")
    except httpx.ConnectError as e:
        print(f"  ERROR: Cannot connect to {base_url}")
        print("  Make sure the vLLM server is running and accessible.")
        print(f"  Error details: {e}")
        return False
    except Exception as e:
        print(f"  Warning: Connection test failed: {e}")
        print("  Proceeding anyway...")
    return True


async def main_async() -> None:
    """Parse CLI arguments, generate synthetic bugs for each split, then save/push them.

    Steps:
      1. Probe the OpenAI-compatible server and build an ``OpenAIEngine``.
      2. Run ``process_split`` for every requested split and print a summary.
      3. Optionally save parquet files and register the splits locally.
      4. Push a ``DatasetDict`` to the HuggingFace Hub. This is skipped for
         ``--dry-run`` or when no HF token is found.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject values that would stall the run (n_parallel < 1 means
    # asyncio.Semaphore(0)) or make it a silent no-op (n_bugs/k < 1).
    for flag, value in (("--n-parallel", args.n_parallel), ("--n-bugs", args.n_bugs), ("--k", args.k)):
        if int(value) < 1:
            parser.error(f"{flag} must be >= 1 (got {value})")

    # Resolve HF token
    hf_token = (
        args.hf_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        or None
    )

    print("=" * 80)
    print("Synthetic BugBench Generation")
    print("=" * 80)
    print(f"Source dataset: {args.source_dataset}")
    print(f"Output repo: {args.output_repo}")
    print(f"Splits: {args.splits}")
    print(f"Model: {args.model}")
    print(f"Base URL: {args.base_url}")
    print(f"n_bugs (bugs per task): {args.n_bugs}")
    print(f"k (max attempts per task): {args.k}")
    print(f"Temperature: {args.temperature}")
    print(f"N parallel: {args.n_parallel}")
    print("=" * 80)

    # Test connection to the server
    if not _check_server(str(args.base_url)):
        return

    # Initialize the engine
    max_response_length = (
        int(args.max_completion_tokens) if int(args.max_completion_tokens) > 0 else int(args.max_tokens)
    )
    engine = OpenAIEngine(
        model=str(args.model),
        tokenizer=None,
        max_prompt_length=8192,
        max_response_length=max_response_length,
        base_url=str(args.base_url),
        api_key=_resolve_api_key(str(args.base_url), args.api_key),
        sampling_params={},
        api_retries=3,
        verbose=False,
    )

    # Process each split
    all_splits: Dict[str, HFDataset] = {}
    total_stats = {"total": 0, "success": 0, "failed": 0, "tasks_with_bugs": 0, "total_bugs": 0}

    t0 = time.time()

    for split in args.splits:
        rows, stats = await process_split(
            source_dataset=args.source_dataset,
            split=split,
            engine=engine,
            args=args,
        )

        for key in ("total", "success", "failed"):
            total_stats[key] += stats[key]
        total_stats["tasks_with_bugs"] += stats.get("tasks_with_bugs", 0)
        total_stats["total_bugs"] += stats.get("total_bugs", 0)

        if not rows:
            continue
        all_splits[split] = HFDataset.from_list(rows)

        # Optionally register locally
        if args.register_locally:
            try:
                DatasetRegistry.register_dataset(args.local_registry_name, rows, split)
                print(f"Registered {args.local_registry_name}/{split} with {len(rows)} examples")
            except Exception as e:
                print(f"Failed to register locally: {e}")

    elapsed = time.time() - t0
    rate = total_stats["total"] / max(elapsed, 1e-6)

    print("\n" + "=" * 80)
    print("Summary")
    print("=" * 80)
    print(f"Total time: {elapsed:.1f}s ({rate:.2f} tasks/s)")
    print(f"Total tasks processed: {total_stats['total']}")
    pct_with_bugs = 100.0 * total_stats['tasks_with_bugs'] / total_stats['total'] if total_stats['total'] > 0 else 0
    print(f"Tasks with at least 1 valid bug: {total_stats['tasks_with_bugs']}/{total_stats['total']} ({pct_with_bugs:.1f}%)")
    print(f"Total bugs generated: {total_stats['total_bugs']}")
    print(f"Tasks with no valid bugs: {total_stats['failed']}")
    print()
    print("Output splits:")
    for split, ds in all_splits.items():
        print(f"  {split}: {len(ds)} bug examples")
    print("=" * 80)

    if not all_splits:
        print("No rows produced; nothing to push.")
        return

    # Save locally if requested
    if args.save_local:
        os.makedirs(args.save_local, exist_ok=True)
        for split, ds in all_splits.items():
            path = os.path.join(args.save_local, f"{split}.parquet")
            ds.to_parquet(path)
            print(f"Saved {split} to {path}")

    # Push to HuggingFace (all_splits is known to be non-empty here)
    if args.dry_run:
        print("\nDry run mode - not pushing to HuggingFace")
        sample_split = next(iter(all_splits))
        sample_row = all_splits[sample_split][0]
        print(f"\nSample row keys: {list(sample_row.keys())}")
        print(f"Sample uid: {sample_row.get('uid')}")
        return

    if not hf_token:
        print("\nWarning: No HF token found. Set HF_TOKEN or pass --hf-token to push to HuggingFace.")
        print("Use --dry-run to generate without pushing.")
        return

    print(f"\nPushing to HuggingFace: {args.output_repo}")
    dsd = HFDatasetDict(all_splits)
    dsd.push_to_hub(
        str(args.output_repo),
        private=bool(args.private),
        token=hf_token,
    )
    print(f"Done! Dataset available at: https://huggingface.co/datasets/{args.output_repo}")


def main() -> None:
    """Synchronous console entry point: drive ``main_async`` on a fresh event loop."""
    coroutine = main_async()
    asyncio.run(coroutine)


if __name__ == "__main__":
    # Script execution only; importing this module has no side effects.
    main()
