#!/usr/bin/env python3
"""
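Prepare the new BigCodeBench dataset and register its splits (train, train_large,
test_small, test, and optionally train_{N}) based on common_indices.

This script:
1. Loads common_indices from JSON files (common_indices_test_all.json, common_indices_test.json)
2. Loads BCB dataset from anonymous/bigcodebench split v0.1.0_hf
3. Loads excluded indexes from bigcodebench_excluded_indexes.json
4. Creates splits:
   - test_all: task_ids in common_indices_test_all.json (registered under the split name "test")
   - test: task_ids in common_indices_test.json (registered under the split name "test_small")
   - train: all BCB task_ids EXCEPT excluded ones AND ones in test_all
     (the script's reading of "those in the v0.1.0_hf split")
   - train_large: all BCB task_ids EXCEPT excluded ones AND ones in test (the smaller test set)
   - train_{N} (only with --num-train-examples N): the first N examples of a seeded shuffle of train_large
5. Registers all splits with DatasetRegistry under the dataset name "bigcodebench_new"

Usage:
    python prepare_bigcodebench_data_new.py                    # Register all splits
    python prepare_bigcodebench_data_new.py --only-train-large # Only register train_large
    python prepare_bigcodebench_data_new.py --num-train-examples 500  # Create train_500 split with 500 shuffled examples
    python prepare_bigcodebench_data_new.py --num-train-examples 500 --seed 7  # Same, with a custom shuffle seed (default: 42)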
"""
import argparse
import json
import os
import re
from pathlib import Path
from typing import Set

from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry
from rllm.data.utils import fetch_live_code_bench_system_prompt


def preprocess_fn(example, idx):
    """
    Convert one raw BigCodeBench row into the bug generation / fixing task format.

    The task format needs to have the following keys:
    - question: the problem description
    - reference_solution: the reference solution code
    - ground_truth: the test cases
    - data_source: the dataset name
    - uid: the unique identifier for the example
    - index: the index of the example
    - starter_code: the starter code
    - metadata: the metadata for the example

    Args:
        example: Mapping holding the raw row. Reads "instruct_prompt",
            "code_prompt", "canonical_solution", "test", "entry_point" and,
            when present, "task_id".
        idx: Positional index of the row; used for uid/index only when the
            task_id does not end in "/<digits>".

    Returns:
        A dict with the keys listed above plus the original BigCodeBench
        fields (instruct_prompt, truncated_instruct_prompt, code_prompt,
        entry_point, task_id).
    """
    # Remove everything starting from "You should write self-contained code starting with:"
    # from instruct_prompt. A regex is used to tolerate variations in spacing and an
    # optional colon after the marker sentence.
    instruct_prompt = example["instruct_prompt"]
    pattern = r"You should write self-contained code starting with\s*:?\s*"
    truncated_instruct_prompt = re.split(pattern, instruct_prompt, maxsplit=1)[0]

    # Strip trailing whitespace and newlines left in front of the removed marker.
    truncated_instruct_prompt = truncated_instruct_prompt.rstrip()
    code_prompt = example["code_prompt"]
    question = fetch_live_code_bench_system_prompt(truncated_instruct_prompt, code_prompt)
    reference_solution = "```python\n" + code_prompt + "\n" + example["canonical_solution"] + "\n```"

    # Prefer stable id based on task_id suffix (BigCodeBench/1132 -> bigcodebench_1132);
    # fall back to the positional index when the task_id has no numeric suffix.
    task_id = example.get("task_id")
    suffix = None
    if isinstance(task_id, str) and "/" in task_id:
        suffix = task_id.split("/")[-1]
    if suffix is not None and suffix.isdigit():
        uid = f"bigcodebench_{suffix}"
        index = int(suffix)
    else:
        uid = f"bigcodebench_{idx}"
        index = int(idx)

    return {
        "question": question,
        "reference_solution": reference_solution,
        "ground_truth": example["test"],
        "data_source": "bigcodebench",
        # The uid was previously computed but never emitted, although the task
        # format documented above requires it.
        "uid": uid,
        "index": index,
        # Keep BigCodeBench fields so we can build the *native* instruct prompt at eval time.
        # (The `question` above is LCB-formatted for compatibility with other code tasks.)
        "starter_code": code_prompt,
        "instruct_prompt": example.get("instruct_prompt"),
        "truncated_instruct_prompt": truncated_instruct_prompt,
        "code_prompt": code_prompt,
        "entry_point": example["entry_point"],
        "metadata": {"func_name": example["entry_point"]},
        "task_id": task_id,
    }


def load_common_indices(json_path: str) -> Set[str]:
    """Load common indices (task_ids) from a JSON file.

    Args:
        json_path: Path of the JSON file to read. Its top-level value is
            passed to ``set()``, so it is expected to be a list of task_ids.
            Any path type accepted by ``open()`` works.

    Returns:
        The set of entries found in the file (duplicates collapse).
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        return set(json.load(f))


def load_excluded_task_ids(excluded_indexes_path: str, hf_dataset) -> Set[str]:
    """Load excluded task_ids from excluded_indexes.json by mapping indexes to task_ids.

    Args:
        excluded_indexes_path: Path of a JSON file shaped like
            ``{"excluded_indexes": {"bigcodebench": {"v0.1.0_hf": [idx, ...]}}}``.
            Missing levels are treated as "nothing excluded".
        hf_dataset: Sized, integer-indexable collection of rows, each exposing
            a "task_id" entry.

    Returns:
        The set of task_ids found at the in-range excluded row indexes.
    """
    with open(excluded_indexes_path, "r", encoding="utf-8") as f:
        excluded_data = json.load(f)

    # Get excluded indexes from the nested structure
    excluded_indexes = excluded_data.get("excluded_indexes", {}).get("bigcodebench", {}).get("v0.1.0_hf", [])

    # Map indexes to task_ids. The lower bound matters: a negative index would
    # otherwise wrap around and silently exclude an unrelated row.
    num_rows = len(hf_dataset)
    excluded_task_ids = set()
    skipped_indexes = []
    for idx in excluded_indexes:
        if 0 <= idx < num_rows:
            excluded_task_ids.add(hf_dataset[idx]["task_id"])
        else:
            skipped_indexes.append(idx)

    if skipped_indexes:
        print(f"[Excluded] Warning: ignored {len(skipped_indexes)} out-of-range indexes: {skipped_indexes}")
    print(f"[Excluded] Found {len(excluded_indexes)} excluded indexes -> {len(excluded_task_ids)} excluded task_ids")
    return excluded_task_ids


def _build_split(hf_dataset, task_ids, num_proc=16):
    """Keep the rows whose task_id is in task_ids and run preprocess_fn over them."""
    filtered = hf_dataset.filter(
        lambda ex: ex["task_id"] in task_ids,
        num_proc=num_proc,
    )
    print(f"    Filtered: {len(filtered)} examples")

    processed = filtered.map(
        preprocess_fn,
        with_indices=True,
        writer_batch_size=10,
        num_proc=num_proc,
        # Drop the raw columns so only the keys returned by preprocess_fn remain.
        remove_columns=filtered.column_names,
    )
    print(f"    Preprocessed: {len(processed)} examples")
    return processed


def _register_split(registry_name, dataset, split_name):
    """Register dataset as registry_name/split_name, report it, and return the registered object."""
    registered = DatasetRegistry.register_dataset(registry_name, dataset, split_name)
    print(f"  Registered {registry_name}/{split_name}: {len(registered)} examples")
    print(f"    Path: {registered.get_data_path()}")
    return registered


def main():
    """Build the BigCodeBench splits and register them with DatasetRegistry.

    Command-line flags:
        --only-train-large: build and register only the train_large split.
        --num-train-examples N: additionally register train_N, the first N
            examples of a seeded shuffle of train_large.
        --seed: seed for that shuffle (default 42).

    Raises:
        FileNotFoundError: if either common_indices JSON file is missing next
            to this script. A missing excluded-indexes file only prints a
            warning.
    """
    parser = argparse.ArgumentParser(description="Prepare BigCodeBench dataset with splits")
    parser.add_argument("--only-train-large", action="store_true",
                        help="Only create and register train_large split (skip other splits)")
    parser.add_argument("--num-train-examples", type=int, default=None,
                        help="Create a train_{num} split with this many shuffled examples")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for shuffling (default: 42)")
    args = parser.parse_args()

    only_train_large = args.only_train_large
    num_train_examples = args.num_train_examples
    shuffle_seed = args.seed

    # A negative count used to produce an empty split registered as "train_-N".
    if num_train_examples is not None and num_train_examples < 0:
        parser.error("--num-train-examples must not be negative")

    # All input JSON files are expected to live next to this script.
    script_dir = Path(__file__).parent
    data_dir = script_dir

    # File paths
    common_indices_test_all_path = data_dir / "common_indices_test_all.json"
    common_indices_test_path = data_dir / "common_indices_test.json"
    excluded_indexes_path = data_dir / "bigcodebench_excluded_indexes.json"

    print("=" * 80)
    print("Preparing BigCodeBench Dataset (New)")
    if only_train_large:
        print("  [MODE: only-train-large]")
    if num_train_examples:
        print(f"  [MODE: num-train-examples={num_train_examples}, seed={shuffle_seed}]")
    print("=" * 80)

    # Load common indices
    print("\n[Step 1] Loading common indices...")
    if not common_indices_test_all_path.exists():
        raise FileNotFoundError(f"common_indices_test_all.json not found at {common_indices_test_all_path}")
    if not common_indices_test_path.exists():
        raise FileNotFoundError(f"common_indices_test.json not found at {common_indices_test_path}")

    test_all_task_ids = load_common_indices(common_indices_test_all_path)
    test_task_ids = load_common_indices(common_indices_test_path)

    print(f"  test_all: {len(test_all_task_ids)} task_ids")
    print(f"  test: {len(test_task_ids)} task_ids")

    # Load BCB dataset
    print("\n[Step 2] Loading BigCodeBench dataset...")
    hf_dataset = load_dataset("anonymous/bigcodebench", split="v0.1.0_hf")
    print(f"  Loaded {len(hf_dataset)} examples from anonymous/bigcodebench v0.1.0_hf")

    # Get all task_ids in the dataset
    all_task_ids_in_dataset = set(hf_dataset["task_id"])
    print(f"  Total unique task_ids in dataset: {len(all_task_ids_in_dataset)}")

    # Load excluded task_ids
    print("\n[Step 3] Loading excluded task_ids...")
    if not excluded_indexes_path.exists():
        print(f"  Warning: {excluded_indexes_path} not found. No task_ids will be excluded.")
        excluded_task_ids = set()
    else:
        excluded_task_ids = load_excluded_task_ids(excluded_indexes_path, hf_dataset)

    # The v0.1.0_hf split is the loaded dataset itself, so its task_ids are all task_ids.
    print(f"  Task_ids in v0.1.0_hf split: {len(all_task_ids_in_dataset)}")

    # Create splits
    print("\n[Step 4] Creating splits...")

    test_all_dataset = None
    test_dataset = None
    train_dataset = None

    if not only_train_large:
        # test_all split: filter by common_indices_test_all.json
        print("\n  Creating test_all split...")
        test_all_dataset = _build_split(hf_dataset, test_all_task_ids)

        # test split: filter by common_indices_test.json
        print("\n  Creating test split...")
        test_dataset = _build_split(hf_dataset, test_task_ids)

        # train split: all BCB task_ids EXCEPT excluded ones AND ones in test_all (common indices)
        # According to requirement: "train is all bcb task_id's except those that are excluded
        # and those in anonymous/bigcodebench split v0.1.0_hf"
        # We interpret "those in v0.1.0_hf" as task_ids in test_all (common indices used for testing)
        print("\n  Creating train split...")
        train_task_ids = all_task_ids_in_dataset - excluded_task_ids - test_all_task_ids
        print(f"    Train task_ids: {len(train_task_ids)} (all: {len(all_task_ids_in_dataset)}, "
              f"excluded: {len(excluded_task_ids)}, test_all: {len(test_all_task_ids)})")
        train_dataset = _build_split(hf_dataset, train_task_ids)

    # train_large split: all BCB task_ids EXCEPT excluded ones AND ones in test (smaller common indices)
    print("\n  Creating train_large split...")
    train_large_task_ids = all_task_ids_in_dataset - excluded_task_ids - test_task_ids
    print(f"    Train_large task_ids: {len(train_large_task_ids)} (all: {len(all_task_ids_in_dataset)}, "
          f"excluded: {len(excluded_task_ids)}, test: {len(test_task_ids)})")
    train_large_dataset = _build_split(hf_dataset, train_large_task_ids)

    # Create custom train_{num_examples} split if requested (0 means "not requested")
    train_custom_dataset = None
    train_custom_name = None
    if num_train_examples:
        print(f"\n  Creating train_{num_train_examples} split (shuffled, seed={shuffle_seed})...")

        # Shuffle the train_large dataset and select the first num_train_examples
        shuffled_dataset = train_large_dataset.shuffle(seed=shuffle_seed)

        if num_train_examples > len(shuffled_dataset):
            print(f"    Warning: Requested {num_train_examples} examples but only {len(shuffled_dataset)} available.")
            print(f"    Using all {len(shuffled_dataset)} examples.")
            train_custom_dataset = shuffled_dataset
        else:
            train_custom_dataset = shuffled_dataset.select(range(num_train_examples))

        train_custom_name = f"train_{num_train_examples}"
        print(f"    Created: {len(train_custom_dataset)} examples")

    # Register splits with DatasetRegistry
    print("\n[Step 5] Registering splits with DatasetRegistry...")
    registry_name = "bigcodebench_new"

    train_registered = None
    train_custom_registered = None
    test_registered = None
    test_all_registered = None

    if not only_train_large:
        train_registered = _register_split(registry_name, train_dataset, "train")

    train_large_registered = _register_split(registry_name, train_large_dataset, "train_large")

    if train_custom_dataset is not None:
        train_custom_registered = _register_split(registry_name, train_custom_dataset, train_custom_name)

    if not only_train_large:
        # Registered names differ from the local ones: the small common-indices
        # set becomes "test_small" and the full test_all set becomes "test".
        test_registered = _register_split(registry_name, test_dataset, "test_small")
        test_all_registered = _register_split(registry_name, test_all_dataset, "test")

    print("\n" + "=" * 80)
    print("Summary")
    print("=" * 80)
    print(f"Dataset: {registry_name}")
    # Compare against None rather than relying on truthiness, so a registered
    # split is still listed even if it turned out to be empty.
    if train_registered is not None:
        print(f"  train: {len(train_registered)} examples")
    print(f"  train_large: {len(train_large_registered)} examples")
    if train_custom_registered is not None:
        print(f"  {train_custom_name}: {len(train_custom_registered)} examples")
    if test_registered is not None:
        print(f"  test_small: {len(test_registered)} examples")
    if test_all_registered is not None:
        print(f"  test: {len(test_all_registered)} examples")
    print("=" * 80)
    print("\n✅ Done!")


if __name__ == "__main__":
    main()
