"""
Sanity check that all BugBench datasets are split correctly based on BigCodeBench task_ids.

BigCodeBench task_ids are read from the "bigcodebench_new" dataset in DatasetRegistry.
For each BugBench dataset (from HuggingFace):
- Load train, test, train_large, test_small, test_all splits
- Check that train task_ids are NOT in BigCodeBench test_all
- Check that train_large task_ids are NOT in BigCodeBench test
- Check that test task_ids are NOT in BigCodeBench train
- Check that test_small task_ids are NOT in BigCodeBench train_large
- Check that test_all task_ids are NOT in BigCodeBench train
"""
from datasets import load_dataset
from rllm.data.dataset import DatasetRegistry


def _get_task_id(example: dict) -> str:
    """Extract task_id from example, preferring 'task_id' then 'uid'."""
    task_id = example.get("task_id") or example.get("uid")
    if task_id is None:
        return ""
    task_id_str = str(task_id).strip()
    # Filter out "None" string and empty strings
    if task_id_str == "None" or not task_id_str:
        return ""
    return task_id_str


def _get_repo_id(dataset_name: str) -> str:
    """Convert dataset name to HuggingFace repo ID."""
    return f"anonymous/{dataset_name}"


def _load_bigcodebench_task_ids(dataset_name: str, split: str) -> set:
    """Collect the task_ids of one split of a BigCodeBench dataset from DatasetRegistry.

    A split the registry does not know about, or any error while loading or
    iterating it, is reported as a ``[WARN]`` line and yields an empty set.
    Callers therefore always receive a set.
    """
    try:
        entries = DatasetRegistry.load_dataset(dataset_name, split)
        if entries is None:
            print(f"  [WARN] BigCodeBench {dataset_name}/{split} not found in registry")
            return set()
        # Drop rows whose identifier normalised to "" (or the string "None").
        return {tid for tid in map(_get_task_id, entries) if tid and tid != "None"}
    except Exception as exc:
        print(f"  [WARN] Failed to load BigCodeBench {dataset_name}/{split}: {exc}")
        return set()


def _load_hf_dataset_task_ids(repo_id: str, split: str) -> set:
    """Load the set of task_ids in one split of a HuggingFace Hub dataset.

    Args:
        repo_id: Hub repository ID, e.g. ``"anonymous/bugbench_new"``.
        split: Name of the split to load.

    Returns:
        The non-empty task_ids found in the split, or an empty set if the
        split could not be loaded.
    """
    try:
        ds = load_dataset(repo_id, split=split)
        task_ids = (_get_task_id(example) for example in ds)
        return {task_id for task_id in task_ids if task_id and task_id != "None"}
    except ValueError:
        # Not every dataset publishes every split. `datasets` reports an unknown
        # split name as a ValueError, so stay quiet for that case.
        # NOTE(review): confirm the exception type against the installed version.
        return set()
    except Exception as e:
        # Any other failure (network/auth error, bad repo id, ...) is reported.
        # If it stayed silent, every check on this split would show SKIP and
        # the dataset would look clean.
        print(f"  [WARN] Failed to load {repo_id}/{split}: {e}")
        return set()


def check_no_overlap(
    split_name: str,
    task_ids: set,
    forbidden_set: set,
    forbidden_name: str,
    *,
    sample_size: int = 10,
) -> bool:
    """Check that ``task_ids`` shares no element with ``forbidden_set`` and print the result.

    Args:
        split_name: Label of the split being checked (used only for output).
        task_ids: task_ids of the split under test. If empty, the check is
            reported as SKIP and passes.
        forbidden_set: task_ids the split must not contain.
        forbidden_name: Label of ``forbidden_set`` (used only for output).
        sample_size: Maximum number of overlapping task_ids to print.

    Returns:
        True if there is no overlap (or nothing to check), False otherwise.
    """
    if not task_ids:
        print(f"    {split_name}: (empty or not found) - SKIP")
        return True

    if not forbidden_set:
        # An empty reference set can never overlap, so the result below says
        # nothing. Flag it instead of silently reporting a clean result.
        print(f"    {split_name}: [WARN] {forbidden_name} is empty; overlap check is vacuous")

    overlap = task_ids & forbidden_set
    if overlap:
        print(f"    {split_name}: ❌ {len(overlap)} task_ids overlap with {forbidden_name}!")
        print(f"      Sample: {sorted(overlap)[:sample_size]}")
        return False

    print(f"    {split_name}: ✅ No overlap with {forbidden_name} ({len(task_ids)} task_ids)")
    return True


# Split names loaded from both the BigCodeBench reference and each BugBench dataset.
_SPLITS = ("train", "train_large", "test", "test_small", "test_all")

# Each pair is (BugBench split, BigCodeBench split whose task_ids it must not contain).
_FORBIDDEN_PAIRS = (
    ("train", "test_all"),
    ("train_large", "test"),
    ("test", "train"),
    ("test_small", "train_large"),
    ("test_all", "train"),
)


def _check_bugbench_dataset(dataset_name: str, bcb_splits: dict) -> bool:
    """Run every overlap check for one BugBench dataset and print a verdict.

    Args:
        dataset_name: BugBench dataset name (mapped to a Hub repo via _get_repo_id).
        bcb_splits: Mapping from BigCodeBench split name to its set of task_ids.

    Returns:
        True if every check passed, False otherwise.
    """
    repo_id = _get_repo_id(dataset_name)
    print(f"\n{'='*60}")
    print(f"[{dataset_name}] {repo_id}")
    print(f"{'='*60}")

    print("  Loading splits...")
    bugbench_splits = {
        split: _load_hf_dataset_task_ids(repo_id, split) for split in _SPLITS
    }

    dataset_passed = True
    # If nothing loaded, every check below would SKIP and the dataset would
    # "pass" without actually being checked.
    if not any(bugbench_splits.values()):
        print(f"    ❌ No task_ids could be loaded from any split of {repo_id}")
        dataset_passed = False

    for bugbench_split, bcb_split in _FORBIDDEN_PAIRS:
        passed = check_no_overlap(
            split_name=bugbench_split,
            task_ids=bugbench_splits[bugbench_split],
            forbidden_set=bcb_splits[bcb_split],
            forbidden_name=f"BCB {bcb_split}",
        )
        # The check runs before the `and`, so every failure is still reported.
        dataset_passed = dataset_passed and passed

    if dataset_passed:
        print(f"\n  ✅ [{dataset_name}] PASSED")
    else:
        print(f"\n  ❌ [{dataset_name}] FAILED")
    return dataset_passed


def main() -> int:
    """Check every BugBench dataset for train/test leakage against BigCodeBench.

    BigCodeBench task_ids are read from the "bigcodebench_new" dataset in
    DatasetRegistry. BugBench splits are read from the HuggingFace Hub.

    Returns:
        0 if every check passed, 1 otherwise. The value can be used as a
        process exit status.
    """
    # BugBench datasets to check (from HuggingFace)
    bugbench_datasets = [
        "bugbench_new",
        "bugbench_human_new",
        "bugbench_qwen7b_sampled_new",
        "bugbench_gpt-oss-20b_sampled_new",
        "bugbench_adversarial_new",
    ]

    print("=" * 80)
    print("Sanity Check: BugBench Dataset Splits (No Train/Test Overlap)")
    print("=" * 80)

    # Load BigCodeBench task_ids for all splits
    print("\n[Loading BigCodeBench splits from registry]")
    bcb_splits = {}
    for split in _SPLITS:
        bcb_splits[split] = _load_bigcodebench_task_ids("bigcodebench_new", split)
        print(f"  bigcodebench_new/{split}: {len(bcb_splits[split])} task_ids")

    all_passed = True
    # Any checks against an empty reference split would pass vacuously,
    # so a missing reference split fails the whole run.
    empty_refs = sorted(
        {bcb_split for _, bcb_split in _FORBIDDEN_PAIRS if not bcb_splits[bcb_split]}
    )
    if empty_refs:
        print(f"  ❌ BigCodeBench reference split(s) empty or not loaded: {empty_refs}")
        all_passed = False

    print("\n" + "=" * 80)
    print("Checking BugBench Datasets")
    print("=" * 80)

    for dataset_name in bugbench_datasets:
        if not _check_bugbench_dataset(dataset_name, bcb_splits):
            all_passed = False

    print("\n" + "=" * 80)
    if all_passed:
        print("✅ All datasets passed validation!")
    else:
        print("❌ Some datasets failed validation. See errors above.")
    print("=" * 80)
    return 0 if all_passed else 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit code, so CI/shell callers
    # can detect a failed sanity check. A None return exits with status 0.
    raise SystemExit(main())
