import random

from torch.utils.data import SequentialSampler

from src.dataset.torch_batch_sampler import BatchSamplerSyncRandom


def test_random_sync():
    """Check that BatchSamplerSyncRandom shares one random value per batch.

    Every element of a batch is expected to carry the same random value at
    index 1, and the first two batches are expected to get different values.
    """
    # A dedicated, seeded RNG keeps failures reproducible and avoids
    # consuming/mutating the global ``random`` state other tests may rely on.
    rng = random.Random(0)
    loaded_data = list(
        BatchSamplerSyncRandom(
            SequentialSampler(range(10)),
            batch_size=3,
            drop_last=False,
            random_generator=lambda: rng.uniform(0.5, 1),
        )
    )
    # The inter-batch comparison below needs two batches; presumably 10 items
    # in batches of 3 (drop_last=False) yields more than that. Fail with a
    # clear message instead of an IndexError if it does not.
    assert len(loaded_data) >= 2, "Expected at least two batches"

    for batch in loaded_data:
        assert all(
            batch[0][1] == k[1] for k in batch
        ), "Random number not the same in batch"

    assert (
        loaded_data[0][0][1] != loaded_data[1][0][1]
    ), "Random number between two batches is the same"
