import numpy as np

from generalization_study.modify_datasets import \
    get_fixed_random_mask


def test_get_fixed_random_mask():
    """Two calls with identical arguments must return the same mask.

    This pins down that ``get_fixed_random_mask`` is deterministic: the
    same ``(0.5, 100)`` arguments have to produce an identical result on
    every call.

    ``np.testing.assert_array_equal`` is used instead of
    ``assert np.all(a == b)`` for two reasons. It checks that the shapes
    agree, whereas ``==`` broadcasts, so e.g. a length-1 result could
    compare "equal" to a length-100 one. It also reports which elements
    differ when the check fails.
    """
    first_mask = get_fixed_random_mask(0.5, 100)
    second_mask = get_fixed_random_mask(0.5, 100)
    np.testing.assert_array_equal(first_mask, second_mask)


if __name__ == '__main__':
    # Allow running this module directly (``python <file>``) without pytest:
    # announce the run, name the test, then execute it.
    for banner in ('running tests', 'test_get_fixed_random_mask'):
        print(banner)
    test_get_fixed_random_mask()
