import matplotlib.pyplot as plt
import numpy as np
import pytest

from sde.dataset_generators import RandomBackgroundGenerator, UniformBackgroundGenerator


@pytest.mark.parametrize("base_color", [(255, 255, 231), (53, 132, 76), (0, 21, 51)])
def test_uniform_background(base_color, plot=False):
    """Verify that UniformBackgroundGenerator paints every pixel with ``base_color``.

    Only the three-channel configuration is exercised; the single-channel case
    is intentionally not covered.

    Args:
        base_color: Per-channel value expected at every pixel of the background.
        plot: When True, show the generated background with matplotlib for a
            manual visual check. It has a default value, so pytest does not try
            to resolve it as a fixture.
    """
    bg = UniformBackgroundGenerator(channels=3, base_color=base_color).get_background()

    expected_shape = (224, 224, 3)
    assert bg.shape == expected_shape

    # Channel ``channel`` of every pixel must equal base_color[channel].
    for channel in range(len(base_color)):
        assert np.all(bg[..., channel] == base_color[channel])

    if plot:
        plt.imshow(bg)
        plt.show()


@pytest.mark.parametrize("exclude_colors", [[(255, 255, 231), (53, 132, 76)]])
def test_random_background(exclude_colors, plot=False):
    """Verify that RandomBackgroundGenerator never emits an excluded colour.

    The check inspects a single generated background. It confirms the shape
    and that no pixel exactly matches any of ``exclude_colors``.

    Args:
        exclude_colors: Colours (one value per channel) that must not appear
            as a full pixel anywhere in the background.
        plot: When True, show the generated background with matplotlib for a
            manual visual check. It has a default value, so pytest does not try
            to resolve it as a fixture.
    """
    bg = RandomBackgroundGenerator(channels=3, exclude_colors=exclude_colors).get_background()

    assert bg.shape == (224, 224, 3)

    for excluded in exclude_colors:
        # A pixel counts as a match only if all of its channels equal the
        # excluded colour, hence the reduction over the last axis.
        matches = np.all(bg == excluded, axis=-1)
        assert not np.any(matches)

    if plot:
        plt.imshow(bg)
        plt.show()
