#!/usr/bin/env python3
"""
Test to verify MBPP adapter splits have no data leakage.
"""

from supervised_code.data_generation.dataset_adapters import MBPPAdapter


def test_mbpp_adapter_no_data_leakage():
    """Test that the MBPP train, valid and test splits are pairwise disjoint.

    Every split is loaded via ``MBPPAdapter.load_dataset`` and reduced to the
    set of identifiers returned by ``MBPPAdapter.extract_problem_name``. The
    test fails if any split yields no identifiers (which would make the
    overlap checks pass vacuously) or if any two splits share an identifier.
    Per-split counts and any overlapping identifiers are printed to make a
    failure easy to diagnose.
    """
    adapter = MBPPAdapter()
    split_names = ("train", "valid", "test")

    # Map each split name to the set of problem identifiers it contains.
    split_ids = {}
    for split in split_names:
        names = [
            adapter.extract_problem_name(example)
            for example in adapter.load_dataset(split)
        ]
        split_ids[split] = set(names)
        # Report both counts: a gap between them means several examples in
        # this split map to the same identifier, which the set-based overlap
        # checks below would otherwise hide.
        print(
            f"{split.capitalize()} split: {len(names)} examples "
            f"({len(split_ids[split])} unique IDs)"
        )

    # An empty split makes every intersection empty, so the leakage checks
    # below would pass without actually testing anything.
    for split, ids in split_ids.items():
        assert ids, f"{split} split is empty; overlap checks would be vacuous"

    # Every unordered pair of splits must be disjoint. Valid/test overlap is
    # leakage too: choices made on the valid split would also see test problems.
    pairs = [("train", "valid"), ("train", "test"), ("valid", "test")]
    overlaps = {(a, b): split_ids[a] & split_ids[b] for a, b in pairs}

    print("\nOverlap analysis:")
    for (a, b), shared in overlaps.items():
        print(f"{a.capitalize()} ∩ {b.capitalize()}: {len(shared)} examples")
    for (a, b), shared in overlaps.items():
        if shared:
            print(
                f"{a.capitalize()}/{b.capitalize()} overlap IDs: {sorted(shared)}"
            )

    for (a, b), shared in overlaps.items():
        assert not shared, f"{a.capitalize()} and {b} splits overlap: {shared}"

    print("\n✅ All assertions passed - no data leakage detected!")


if __name__ == "__main__":
    # Allow running the leakage check as a plain script, without pytest;
    # a failing assertion propagates and yields a non-zero exit status.
    test_mbpp_adapter_no_data_leakage()
