from collections import Counter

import numpy as np
import pytest
from matplotlib import pyplot as plt
from tqdm import tqdm

from sde.dataset_generators import ModuloDatasetGenerator, UniformDatasetGenerator


@pytest.mark.parametrize('mode', ["modulo"])
def test_make_image(mode, tmp_path):
    """Generate a debug image and a large sample of data points from a dataset generator.

    First a single image is made with ``debug=True`` (label ``-1``). Then debug output is
    switched off, 50000 data points are drawn via ``make_data_point`` and the number of
    samples per label is printed and plotted as a bar chart.

    The chart is saved under ``tmp_path`` rather than shown with ``plt.show()``, so the
    test never blocks on an interactive matplotlib window, and the figure is closed
    afterwards so it does not leak into other tests.
    """
    print("make single image")
    print(mode)
    if mode == "modulo":
        generator = ModuloDatasetGenerator(
            str(tmp_path), num_samples=2, num_classes=100, center=(112, 112), length=50, debug=True)
    else:
        raise NotImplementedError

    generator._make_image(label=-1)
    generator.debug = False

    print("make a dataset of images")
    label_counts = Counter()
    for _ in tqdm(range(50000)):
        # make_data_point returns a 3-tuple; only the label (last element) is used here.
        _, _, label = generator.make_data_point()
        label_counts[label] += 1

    print("number of created classes", len(label_counts))
    # Print as a plain dict to keep the original output format.
    print("class label and their samples", dict(label_counts))

    # Draw on a dedicated figure instead of the implicit current one, and convert the
    # dict views to lists so matplotlib receives plain sequences.
    fig, ax = plt.subplots()
    ax.bar(list(label_counts.keys()), list(label_counts.values()))
    ax.set_title("sample of classes")
    fig.savefig(tmp_path / "class_histogram.png")
    plt.close(fig)


def test_image_center():
    """Check the random and fixed center generators of ModuloDatasetGenerator.

    A new generator is created for every draw, so the test does not depend on either
    generator yielding more than one value.
    """
    # Random center generator: each draw is compared against the origin (0, 0) only,
    # not against the previous draw. Consecutive random centers could share a
    # coordinate, so a draw-to-draw check would presumably be flaky.
    origin = (0, 0)
    for _ in range(100):
        center = next(ModuloDatasetGenerator.random_center_generator(image_size=(224, 224), patch_length=15))
        assert center[0] != origin[0]
        assert center[1] != origin[1]

    # Fixed center generator: must always yield exactly the center it was given.
    fixed_center = (153, 168)
    for _ in range(100):
        center = next(ModuloDatasetGenerator.fixed_center_generator(fixed_center))
        assert center[0] == fixed_center[0]
        assert center[1] == fixed_center[1]


def test_uniform_dataset_generation(tmp_path):
    """Smoke-test image creation of UniformDatasetGenerator for labels 0, 10 and 99.

    The generator runs in debug mode; the test passes as long as ``_make_image``
    does not raise for any of the chosen labels.
    """
    uniform_gen = UniformDatasetGenerator(
        str(tmp_path),
        num_classes=100,
        num_samples=10,
        center=(112, 112),
        length=50,
        debug=True,
    )
    for label in (0, 10, 99):
        uniform_gen._make_image(label=label)


def test_gt_mask_modulo(tmp_path):
    """Check that the ground-truth mask marks exactly the pure-white pixels of the image.

    For 10 images generated with ``label=None``, the boolean map ``gt_mask == 255`` must
    equal the map of pixels whose first three channels are all 255.
    """
    generator = ModuloDatasetGenerator(
        str(tmp_path), num_samples=10, num_classes=100, center=(112, 112), length=50, debug=False)
    for _ in range(10):
        img, gt_mask = generator._make_image(label=None)
        # Same as AND-ing img[:, :, 0], img[:, :, 1] and img[:, :, 2] == 255.
        white_pixels = np.all(img[:, :, :3] == 255, axis=-1)
        assert np.array_equal(gt_mask == 255, white_pixels)
