import pytest
import torch
import warnings
from src.datasets.dataset_utils.dataset_splitting import split_by_class_samples  # adjust import path if needed


def test_basic_split_reproducibility():
    """Calling the splitter twice with the same seed must produce the same split."""
    labels = torch.tensor([1, 2, 0, 0, 2, 2, 1, 2, 1])

    def run_split():
        return split_by_class_samples(labels, samples_per_class=2, seed=123)

    first_train, first_val = run_split()
    second_train, second_val = run_split()
    assert first_train == second_train
    assert first_val == second_val


def test_number_of_train_samples_per_class():
    """Each class contributes exactly min(samples_per_class, class size) indices to train.

    A class with at least ``samples_per_class`` members must place exactly that
    many indices in the training split. Its remaining members must land in the
    test split, so that train + test covers the whole class.
    """
    y = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
    samples_per_class = 2
    train_idx, test_idx = split_by_class_samples(y, samples_per_class)

    for cls in torch.unique(y).tolist():
        class_size = int((y == cls).sum())
        n_train = sum(1 for i in train_idx if int(y[i]) == cls)
        n_test = sum(1 for i in test_idx if int(y[i]) == cls)
        # A looser "== samples_per_class or == class_size" check would also accept
        # a split that dumps every sample of a large class into train.
        assert n_train == min(samples_per_class, class_size)
        assert n_train + n_test == class_size


def test_all_indices_are_used_once():
    """Every index of the label tensor appears in exactly one of train or test."""
    labels = torch.tensor([0, 1, 1, 2, 2, 2])
    train_idx, test_idx = split_by_class_samples(labels, 2)
    expected = list(range(len(labels)))
    assert sorted(train_idx + test_idx) == expected
    # No index may be shared between the two splits.
    assert not set(train_idx) & set(test_idx)


def test_invalid_dim_raises():
    """A label tensor that is not one-dimensional is rejected with ValueError."""
    two_d_labels = torch.arange(1, 5).reshape(2, 2)  # [[1, 2], [3, 4]]
    with pytest.raises(ValueError):
        split_by_class_samples(two_d_labels, 1)


def test_invalid_samples_per_class_raises():
    """Raise ValueError when samples_per_class <= 0 (both zero and negative values)."""
    y = torch.tensor([0, 1, 1, 2])
    # The contract covers every non-positive value, so exercise the boundary (0)
    # and a negative value rather than 0 alone.
    for bad_value in (0, -1):
        with pytest.raises(ValueError):
            split_by_class_samples(y, bad_value)


def test_class_with_few_samples_warns():
    """Splitting must warn when a class is smaller than samples_per_class."""
    labels = torch.tensor([0, 1, 1, 1])  # class 0 only has 1 sample
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        train_idx, test_idx = split_by_class_samples(labels, 2)
        messages = [str(entry.message).lower() for entry in caught]
        assert any("no samples of class" in text for text in messages)
    # All class 0 samples go to train
    assert {0} <= set(train_idx)
