#!/usr/bin/env python3
"""
Preprocess BugBench-style datasets from HuggingFace repo IDs and register them.

This script:
1. Takes a list of HF repo IDs
2. Loads each dataset's splits (default: train, test, test_small)
3. Preprocesses each split into RLLM format
4. Registers each split with DatasetRegistry
5. Optionally pushes all processed splits back to HF (with corresponding push repo IDs)

NOTE: To push, you must be authenticated (e.g. `huggingface-cli login`).
"""

import json
import re
import argparse
from typing import Optional, List, Dict, Tuple

from datasets import load_dataset, DatasetDict

from rllm.data.dataset import DatasetRegistry
from rllm.data.utils import fetch_live_code_bench_system_prompt


def _truncate_instruct_prompt(instruct_prompt: str) -> str:
    pattern = r"You should write self-contained code starting with\s*:?\s*"
    truncated = re.split(pattern, instruct_prompt or "", maxsplit=1)[0]
    return truncated.rstrip()


def _ensure_fenced_python(code: str) -> str:
    s = (code or "").strip("\n")
    if "```" in s:
        return s
    return f"```python\n{s}\n```"


def _maybe_build_full_code(code_prompt: str, solution_or_body: str) -> str:
    cp = (code_prompt or "").strip("\n")
    sol = (solution_or_body or "").strip("\n")

    sol_lower = sol.lstrip().lower()
    if "```" in sol or sol_lower.startswith("import ") or sol_lower.startswith("from ") or "def " in sol:
        return sol

    if cp:
        return cp + "\n" + sol
    return sol


def _get_uid(ex: dict) -> str:
    return str(ex.get("uid") or ex.get("task_id") or "")


def preprocess_example(example: dict, idx: int, registry_name: str):
    """Convert one dataset row into the RLLM record layout.

    Two input shapes are handled:

    * Raw BugBench-style rows (any of ``instruct_prompt``, ``code_prompt``,
      ``canonical_solution``, ``test`` present and none of ``question``,
      ``reference_solution``, ``ground_truth``): the question is built from the
      truncated instruction prompt and the starter code.
    * Everything else is treated as already (partly) preprocessed: existing
      fields are kept where set and missing ones are filled in from their raw
      counterparts.

    Args:
        example: The source row.
        idx: Position of the row in its split; used for ``index`` and as part
            of the fallback uid.
        registry_name: Registry name; used as the default ``data_source`` and
            as the prefix of the fallback uid.

    Returns:
        A dict with the keys ``question``, ``reference_solution``,
        ``buggy_solution``, ``ground_truth``, ``data_source``, ``uid``,
        ``index``, ``starter_code``, ``instruct_prompt``,
        ``truncated_instruct_prompt``, ``complete_prompt``, ``code_prompt``,
        ``entry_point`` and ``metadata``.
    """
    raw_keys = ("instruct_prompt", "code_prompt", "canonical_solution", "test")
    processed_keys = ("question", "reference_solution", "ground_truth")
    looks_raw = any(key in example for key in raw_keys)
    looks_processed = any(key in example for key in processed_keys)

    # Values shared by both input shapes.
    fallback_uid = f"{registry_name}_{idx}"
    func_name = example.get("entry_point", None)
    default_metadata = {} if func_name is None else {"func_name": func_name}

    # 1) Raw BugBench-style -> normalize
    if looks_raw and not looks_processed:
        prompt_text = example.get("instruct_prompt", "") or ""
        starter = example.get("code_prompt", "") or ""

        short_prompt = _truncate_instruct_prompt(prompt_text)
        question_text = fetch_live_code_bench_system_prompt(short_prompt, starter)

        reference_code = _maybe_build_full_code(starter, example.get("canonical_solution", "") or "")
        buggy_code = _maybe_build_full_code(starter, example.get("buggy", "") or "")

        return {
            "question": question_text,
            "reference_solution": _ensure_fenced_python(reference_code),
            "buggy_solution": _ensure_fenced_python(buggy_code),
            "ground_truth": example.get("test", ""),
            "data_source": str(registry_name),
            "uid": str(example.get("task_id", fallback_uid)),
            "index": int(idx),
            "starter_code": starter,
            "instruct_prompt": prompt_text,
            "truncated_instruct_prompt": short_prompt,
            "complete_prompt": example.get("complete_prompt", None),
            "code_prompt": starter,
            "entry_point": func_name,
            "metadata": default_metadata,
        }

    # 2) Already-preprocessed BugBench-like -> normalize
    buggy_source = (
        example.get("buggy_solution", None)
        or example.get("buggy_sampled_solution", None)
        or example.get("buggy", None)
        or ""
    )
    reference_source = example.get("reference_solution", None) or example.get("canonical_solution", None) or ""

    prompt_text = example.get("instruct_prompt", "") or ""
    starter = example.get("code_prompt", "") or example.get("starter_code", "") or ""

    question_text = example.get("question", None) or ""
    if question_text:
        # Keep the stored question; only fill in the truncated prompt if it is missing.
        short_prompt = example.get("truncated_instruct_prompt", "") or (
            _truncate_instruct_prompt(prompt_text) if prompt_text else ""
        )
    else:
        # No usable question stored: rebuild it from the prompt and starter code.
        short_prompt = _truncate_instruct_prompt(prompt_text)
        question_text = fetch_live_code_bench_system_prompt(short_prompt, starter)

    uid_value = example.get("uid", None) or example.get("task_id", None) or fallback_uid
    reference_code = _maybe_build_full_code(starter, reference_source)
    buggy_code = _maybe_build_full_code(starter, buggy_source)

    return {
        "question": question_text,
        "reference_solution": _ensure_fenced_python(reference_code),
        "buggy_solution": _ensure_fenced_python(buggy_code),
        "ground_truth": example.get("ground_truth", None) or example.get("test", None) or "",
        "data_source": str(example.get("data_source", None) or registry_name),
        "uid": str(uid_value),
        "index": int(example.get("index", idx)),
        "starter_code": example.get("starter_code", None) or example.get("code_prompt", None) or "",
        "instruct_prompt": prompt_text,
        "truncated_instruct_prompt": short_prompt,
        "complete_prompt": example.get("complete_prompt", None),
        "code_prompt": starter,
        "entry_point": func_name,
        "metadata": example.get("metadata", None) or default_metadata,
    }


def process_and_register_split(
    repo_id: str,
    registry_name: str,
    split: str,
    excluded_indexes_file: Optional[str] = None,
) -> Tuple[Optional[object], Optional[object]]:
    """
    Load, preprocess, and register a single split from a dataset.

    Args:
        repo_id: HuggingFace repo ID (e.g., "anonymous/bugbench_new")
        registry_name: Name to use in DatasetRegistry (e.g., "bugbench")
        split: Split to load from HF dataset (e.g., "train", "test", "test_small")
        excluded_indexes_file: Optional path to excluded indexes JSON file

    Returns:
        A tuple ``(registered, processed)``: the object returned by
        ``DatasetRegistry.register_dataset`` and the preprocessed dataset that
        was registered. ``(None, None)`` is returned when the split does not
        exist or cannot be loaded.
    """
    print(f"\n  [Processing split] {split}")

    # Load excluded task_ids if provided
    excluded_task_ids = set()
    if excluded_indexes_file:
        try:
            with open(excluded_indexes_file, "r") as f:
                excluded_data = json.load(f)
            excluded_indexes = excluded_data.get("excluded_indexes", {}).get("bigcodebench", {}).get("v0.1.0_hf", [])
            excluded_task_ids = {f"BigCodeBench/{idx}" for idx in excluded_indexes}
        except Exception:
            # Deliberately silent best-effort: process_and_register_dataset reads
            # the same file up front and prints the warning, so repeating it for
            # every split would only add noise. No exclusions are applied here.
            pass

    # Load dataset from HF
    print(f"    [load] Loading {repo_id} split={split}...")
    try:
        raw = load_dataset(str(repo_id), split=split)
    except Exception as direct_err:
        # Try loading as DatasetDict and get the split
        try:
            dataset_dict = load_dataset(str(repo_id))
            if split not in dataset_dict:
                print(f"    [SKIP] Split '{split}' not found in {repo_id}")
                return None, None
            raw = dataset_dict[split]
        except Exception as fallback_err:
            # Report both failures: the direct load and the fallback can fail
            # for different reasons, and either one may be the real cause.
            print(
                f"    [SKIP] Failed to load {repo_id} split={split}: {direct_err} "
                f"(fallback load also failed: {fallback_err})"
            )
            return None, None

    print(f"    [load] Loaded {len(raw)} examples")

    # Filter excluded task_ids if any
    if excluded_task_ids:
        before = len(raw)
        raw = raw.filter(lambda ex: _get_uid(ex) not in excluded_task_ids, num_proc=16)
        after = len(raw)
        print(f"    [filter] After excluding task_ids: {before} -> {after} examples")

    # Preprocess
    print(f"    [preprocess] Preprocessing {len(raw)} examples...")

    def preprocess_fn(example, idx):
        return preprocess_example(example, idx, registry_name)

    # remove_columns drops every source column, so only the keys returned by
    # preprocess_example remain in the processed split.
    processed = raw.map(
        preprocess_fn,
        with_indices=True,
        writer_batch_size=10,
        num_proc=16,
        remove_columns=raw.column_names,
    )
    print(f"    [preprocess] Preprocessed {len(processed)} examples")

    # Register with DatasetRegistry
    print(f"    [register] Registering as {registry_name}/{split}...")
    registered = DatasetRegistry.register_dataset(str(registry_name), processed, split)
    print(f"    [register] Registered {len(registered)} examples")
    print(f"    [register] Path: {registered.get_data_path()}")

    return registered, processed


def process_and_register_dataset(
    repo_id: str,
    registry_name: str,
    splits: Optional[List[str]] = None,
    excluded_indexes_file: Optional[str] = None,
    push_to_hub: bool = False,
    push_repo_id: Optional[str] = None,
    num_train_examples: Optional[int] = None,
) -> Dict[str, object]:
    """
    Load, preprocess, and register all splits from a dataset.

    Args:
        repo_id: HuggingFace repo ID (e.g., "anonymous/bugbench_new")
        registry_name: Name to use in DatasetRegistry (e.g., "bugbench")
        splits: List of splits to load and register. ``None`` (the default)
            means ["train", "test", "test_small"]; an empty list processes
            no splits.
        excluded_indexes_file: Optional path to excluded indexes JSON file
        push_to_hub: Whether to push processed dataset back to HF
        push_repo_id: HF repo ID to push to (defaults to repo_id)
        num_train_examples: Optional number of training examples for train_{num} split (matches bigcodebench_new/train_{num})

    Returns:
        Dict mapping split name -> registered dataset object. Splits that are
        missing or fail to process are left out of the dict.
    """
    import traceback

    # A None sentinel avoids sharing one mutable default list across calls.
    if splits is None:
        splits = ["train", "test", "test_small"]

    print(f"\n{'='*80}")
    print(f"Processing: {repo_id}")
    print(f"{'='*80}")

    # Read the exclusion file once here purely to report its status; each split
    # re-reads it inside process_and_register_split to apply the exclusions.
    if excluded_indexes_file:
        try:
            with open(excluded_indexes_file, "r") as f:
                excluded_data = json.load(f)
            excluded_indexes = excluded_data.get("excluded_indexes", {}).get("bigcodebench", {}).get("v0.1.0_hf", [])
            excluded_task_ids = {f"BigCodeBench/{idx}" for idx in excluded_indexes}
            print(f"[excluded] Loaded {len(excluded_task_ids)} excluded task_ids")
        except FileNotFoundError:
            print(f"[WARN] excluded_indexes_file not found: {excluded_indexes_file} (skipping exclusions)")
        except Exception as e:
            print(f"[WARN] Failed to load excluded_indexes_file: {e} (skipping exclusions)")

    # Process each split
    registered_splits = {}
    processed_splits = {}

    for split in splits:
        try:
            result = process_and_register_split(
                repo_id=repo_id,
                registry_name=registry_name,
                split=split,
                excluded_indexes_file=excluded_indexes_file,
            )
            if result is not None and result[0] is not None:  # registered is not None
                registered, processed = result
                registered_splits[split] = registered
                processed_splits[split] = processed
        except Exception as e:
            # One failing split must not stop the remaining splits.
            print(f"    [ERROR] Failed to process {split}: {e}")
            traceback.print_exc()

    # Create custom train_{num_examples} split if requested (matching bigcodebench_new/train_{n} task_ids)
    if num_train_examples is not None and "train" in processed_splits:
        train_custom_name = f"train_{num_train_examples}"
        bcb_split_name = f"train_{num_train_examples}"
        print(f"\n  [Creating {train_custom_name}] Matching task_ids from bigcodebench_new/{bcb_split_name}...")

        # Load bigcodebench_new/train_{n} to get the task_ids
        try:
            bcb_dataset = DatasetRegistry.load_dataset("bigcodebench_new", bcb_split_name)
            if bcb_dataset is None:
                raise ValueError(f"Dataset bigcodebench_new/{bcb_split_name} not found in registry")
            bcb_task_ids = {ex["task_id"] for ex in bcb_dataset}
            print(f"    Loaded {len(bcb_task_ids)} task_ids from bigcodebench_new/{bcb_split_name}")
        except Exception as e:
            print(f"    [ERROR] Failed to load bigcodebench_new/{bcb_split_name}: {e}")
            print(f"    [SKIP] Cannot create {train_custom_name} without bigcodebench reference split")
            bcb_task_ids = None

        if bcb_task_ids is not None:
            train_dataset = processed_splits["train"]

            # Filter to only include examples with matching task_ids
            # In bugbench, the uid field contains the task_id
            train_custom_dataset = train_dataset.filter(
                lambda ex: ex.get("uid") in bcb_task_ids or ex.get("task_id") in bcb_task_ids,
                num_proc=16
            )

            print(f"    Filtered: {len(train_custom_dataset)} examples (from {len(train_dataset)} train examples)")

            if len(train_custom_dataset) == 0:
                print(f"    [WARN] No matching task_ids found. Skipping {train_custom_name}.")
            else:
                # Register the custom train split
                print(f"    [register] Registering as {registry_name}/{train_custom_name}...")
                train_custom_registered = DatasetRegistry.register_dataset(str(registry_name), train_custom_dataset, train_custom_name)
                print(f"    [register] Registered {len(train_custom_registered)} examples")
                print(f"    [register] Path: {train_custom_registered.get_data_path()}")

                registered_splits[train_custom_name] = train_custom_registered
                processed_splits[train_custom_name] = train_custom_dataset

    # Push to HF if requested (nothing is pushed when no split was processed)
    if push_to_hub and processed_splits:
        dst_repo = str(push_repo_id or repo_id)
        print(f"\n  [push_to_hub] Pushing all splits to {dst_repo}...")
        dd = DatasetDict(processed_splits)
        dd.push_to_hub(repo_id=dst_repo)
        print(f"  [push_to_hub] Done pushing {len(processed_splits)} splits to {dst_repo}")

    return registered_splits


def main():
    """Command-line entry point: preprocess and register each requested HF dataset.

    Parses the CLI flags, pairs every repo ID with a registry name and an
    optional push target, runs ``process_and_register_dataset`` for each repo
    (one failing repo does not stop the others), and prints a summary of the
    registered splits.

    Raises:
        ValueError: If ``--registry-names`` or ``--push-repo-ids`` is given with
            a different number of entries than ``--repo-ids``.
    """
    import traceback

    # Built-in repos and the registry names they are registered under.
    default_repo_ids = [
        "anonymous/bugbench_v2",
        "anonymous/bugbench_human_new",
        "anonymous/bugbench_qwen7b_sampled_new",
        "anonymous/bugbench_gpt-oss-20b_sampled_new",
        "anonymous/bugbench_adversarial_new",
    ]
    default_registry_names = [
        "bugbench",
        "bugbench_human",
        "bugbench_qwen7b_sampled",
        "bugbench_gpt-oss-20b_sampled",
        "bugbench_adversarial",
    ]

    parser = argparse.ArgumentParser(
        description="Preprocess BugBench(-like) HF datasets into RLLM format and register them."
    )
    parser.add_argument(
        "--repo-ids",
        type=str,
        nargs="+",
        default=default_repo_ids,
        help="List of HuggingFace repo IDs (e.g., 'anonymous/bugbench_v2 anonymous/bugbench_human_new')",
    )
    # default=None (rather than the built-in name list) so that a custom
    # --repo-ids list does not collide with five unrelated default names.
    parser.add_argument(
        "--registry-names",
        type=str,
        nargs="+",
        default=None,
        help=(
            "List of registry names (one per repo-id). If not provided, the built-in repos use "
            "their standard registry names and any other repo uses the part of its ID after the last '/'."
        ),
    )
    parser.add_argument(
        "--splits",
        type=str,
        nargs="+",
        default=["train", "test", "test_small"],
        help="Splits to load from each HF dataset (default: ['train', 'test', 'test_small'])",
    )
    parser.add_argument(
        "--excluded-indexes-file",
        type=str,
        default=None,
        help="Optional path to excluded indexes JSON file",
    )
    # Paired on/off flags sharing one destination. Previously the lone
    # --no-push-to-hub flag was store_true with default=True, so it was always
    # set and pushing could never be enabled. Not pushing stays the default.
    parser.add_argument(
        "--push-to-hub",
        dest="push_to_hub",
        action="store_true",
        default=False,
        help="Push processed datasets to Hugging Face after registering (default: don't push).",
    )
    parser.add_argument(
        "--no-push-to-hub",
        dest="push_to_hub",
        action="store_false",
        help="Don't push processed datasets to Hugging Face (this is the default).",
    )
    parser.add_argument(
        "--push-repo-ids",
        type=str,
        nargs="+",
        default=None,
        help="List of HF repo IDs to push to (one per repo-id). Defaults to --repo-ids.",
    )
    parser.add_argument(
        "--num-train-examples",
        type=int,
        default=None,
        help="Create a train_{num} split matching task_ids from bigcodebench_new/train_{num}",
    )

    args = parser.parse_args()

    repo_ids = args.repo_ids
    n_repos = len(repo_ids)

    # Validate and set registry names
    if args.registry_names:
        if len(args.registry_names) != n_repos:
            raise ValueError(
                f"Number of --registry-names ({len(args.registry_names)}) must match "
                f"number of --repo-ids ({n_repos})"
            )
        registry_names = args.registry_names
    else:
        # Built-in repos keep their standard names; anything else is derived
        # from the repo ID (the part after the last '/').
        known_names = dict(zip(default_repo_ids, default_registry_names))
        registry_names = [known_names.get(repo_id, repo_id.split("/")[-1]) for repo_id in repo_ids]

    # Validate push repo IDs
    if args.push_repo_ids:
        if len(args.push_repo_ids) != n_repos:
            raise ValueError(
                f"Number of --push-repo-ids ({len(args.push_repo_ids)}) must match "
                f"number of --repo-ids ({n_repos})"
            )
        push_repo_ids = args.push_repo_ids
    else:
        push_repo_ids = [None] * n_repos  # Will default to repo_id in function

    print("=" * 80)
    print("Preparing BugBench Datasets")
    print("=" * 80)
    print(f"Repositories: {repo_ids}")
    print(f"Registry names: {registry_names}")
    print(f"Splits: {args.splits}")
    if args.num_train_examples:
        print(f"Custom train split: train_{args.num_train_examples} (matching bigcodebench_new/train_{args.num_train_examples})")
    push_to_hub = args.push_to_hub
    print(f"Push to hub: {push_to_hub}")
    if push_to_hub:
        print(f"Push repo IDs: {push_repo_ids if args.push_repo_ids else repo_ids}")
    print("=" * 80)

    # Process each dataset
    all_registered_splits = []
    for repo_id, registry_name, push_repo_id in zip(repo_ids, registry_names, push_repo_ids):
        try:
            registered_splits = process_and_register_dataset(
                repo_id=repo_id,
                registry_name=registry_name,
                splits=args.splits,
                excluded_indexes_file=args.excluded_indexes_file,
                push_to_hub=push_to_hub,
                push_repo_id=push_repo_id,
                num_train_examples=args.num_train_examples,
            )
            all_registered_splits.append(registered_splits)
        except Exception as e:
            # Keep going with the remaining repos; record an empty result so the
            # summary lists stay aligned with repo_ids.
            print(f"\n[ERROR] Failed to process {repo_id}: {e}")
            traceback.print_exc()
            all_registered_splits.append({})

    # Summary
    print("\n" + "=" * 80)
    print("Summary")
    print("=" * 80)
    # Build list of splits to show in summary (include custom train split if created)
    summary_splits = list(args.splits)
    if args.num_train_examples:
        custom_train_name = f"train_{args.num_train_examples}"
        if custom_train_name not in summary_splits:
            summary_splits.append(custom_train_name)

    for repo_id, registry_name, registered_splits in zip(repo_ids, registry_names, all_registered_splits):
        if registered_splits:
            print(f"\n  {repo_id} -> {registry_name}:")
            for split_name in summary_splits:
                if split_name in registered_splits:
                    registered = registered_splits[split_name]
                    print(f"    {split_name}: {len(registered)} examples")
                    print(f"      Path: {registered.get_data_path()}")
                else:
                    print(f"    {split_name}: NOT REGISTERED")
    print("=" * 80)
    print("\n✅ Done!")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
