import numpy as np
import pytest

from sde.dataset_generators import MultiPatchSumDataset, SingleColorModuloDatasetGenerator


@pytest.mark.parametrize(
    "bg_generator",
    [
        dict(type='uniform', channels=3, base_color=(0, 0, 0)),
        dict(type='random', channels=3, exclude_colors=((255, 255, 255),))
    ])
def test_single_color_make_image(bg_generator, tmp_path):
    """
    Check image / label / GT-mask consistency of SingleColorModuloDatasetGenerator.

    1. In debug mode, recompute the label directly from the image: zero out
       every pixel outside ``gt_mask``, count the remaining pixels (each is
       expected to contribute ``255 * 3`` to the sum) and reduce modulo
       ``num_classes``. The result must equal ``_make_label``'s output.
    2. In non-debug mode, for several freshly generated images, the pixels
       marked 255 in ``gt_mask`` must be exactly the pure-white pixels.
    """
    # Single source of truth for the class count used both by the generator
    # and by the modulo below.
    num_classes = 100
    generator = SingleColorModuloDatasetGenerator(
        str(tmp_path), num_classes=num_classes, num_samples=50, length=50, debug=True, bg_generator=bg_generator)
    img, gt_mask = generator._make_image(label=None)
    generated_label = generator._make_label(img, gt_mask=gt_mask)

    # Get the true class from the image. Dividing by 255 * 3 assumes debug
    # mode paints masked pixels white on all 3 channels -- TODO confirm
    # against the generator implementation.
    img[np.logical_not(gt_mask)] = 0.
    num_masked_pixels = img.sum() // (255 * 3)
    calculated_label = num_masked_pixels % num_classes
    assert calculated_label == generated_label

    # Outside debug mode the GT mask must coincide exactly with white pixels
    # (the 'random' background explicitly excludes white).
    generator.debug = False
    for _ in range(10):
        img, gt_mask = generator._make_image(label=None)
        is_white = np.logical_and(
            np.logical_and(img[:, :, 0] == 255, img[:, :, 1] == 255), img[:, :, 2] == 255)
        assert np.array_equal(gt_mask == 255, is_white)


@pytest.mark.parametrize(
    "color_list, length",
    [
        # yapf: disable
        ([[250, 240, 215]], 50),
        ([[250, 240, 215], [5, 124, 62]], 50),
        ([[250, 240, 215], [5, 124, 62], [62, 125, 7], [72, 100, 100]], 50),
        ([[250, 240, 215], [5, 124, 62], [62, 125, 7], [72, 100, 100], [36, 97, 0], [221, 65, 99]], 30)
        # yapf: enable
    ])
def test_multi_patch_make_image(color_list, length):
    """
    Check MultiPatchSumDataset point generation, image sums and GT masks.

    - ``_generate_non_overlapping_points`` returns one integer point per patch.
    - In debug mode the image sum is bounded by ``num_patches`` full
      ``length x length`` squares and matches the combined mask
      (``gt_masks[-1]``).
    - In non-debug mode, each per-patch mask marks exactly the pixels that
      carry that patch's color once the image is restricted to the mask.
    """
    num_patches = len(color_list)
    # Dataset root and sample count are dummies; only the image helpers are used.
    generator = MultiPatchSumDataset('', num_samples=1, length=length, color_list=color_list, debug=True)

    points = generator._generate_non_overlapping_points(num_patches, length=length, width=length)
    for point in points:
        assert isinstance(point[0], int) and isinstance(point[1], int)

    # check generated points match our setting
    assert len(points) == num_patches

    img, gt_masks = generator._make_image(None)

    # In debug mode every patch pixel is 255 on all 3 channels. Patches are not
    # full rectangles, so the total is at most num_patches full squares.
    assert img.sum() <= 3 * 255 * num_patches * length * length

    # The last mask is the union of all patches: in debug mode the image is
    # non-zero exactly there, with 255 on each of the 3 channels.
    assert img.sum() == gt_masks[-1].sum() * 3

    # Per-patch mask correctness: restrict the image to one patch's mask and
    # check that the pixels of that patch's color are exactly the mask.
    generator.debug = False
    img, gt_masks = generator._make_image(None)
    for color, gt_mask in zip(color_list, gt_masks):
        inside = gt_mask == 255
        # Multiply into a new array so one iteration cannot corrupt the image
        # seen by the next (the original mutated `img` in place).
        patch_only = img * inside[:, :, np.newaxis]
        matches_color = np.logical_and(
            np.logical_and(patch_only[:, :, 0] == color[0], patch_only[:, :, 1] == color[1]),
            patch_only[:, :, 2] == color[2])
        assert np.array_equal(inside, matches_color)


@pytest.mark.parametrize(
    "color_list, length",
    [
        # yapf: disable
        ([[250, 240, 215]], 50),
        ([[250, 240, 215], [5, 124, 62]], 50),
        ([[250, 240, 215], [5, 124, 62], [62, 125, 7], [72, 100, 100]], 50),
        ([[250, 240, 215], [5, 124, 62], [62, 125, 7], [72, 100, 100], [36, 97, 0], [221, 65, 99]], 30)
        # yapf: enable
    ])
def test_multi_patch_make_label(color_list, length):
    """
    Smoke test: ``_make_label`` runs on a generated image and yields an ``int``.
    """
    dataset = MultiPatchSumDataset('', num_samples=1, length=length, color_list=color_list, debug=False)
    # Masks are not needed here; only the image feeds into the label.
    image, _ = dataset._make_image(None)
    computed = dataset._make_label(image)
    assert isinstance(computed, int)
