from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

from ..base_dataset_generator import RectangularDatasetGenerator


class UniformDatasetGenerator(RectangularDatasetGenerator):
    """Generate patches whose class is encoded by a count of value pixels.

    Labels are drawn uniformly at random from ``range(num_classes)``. The
    label is created before the image (``make_img_first=False``).

    Each sample contains a ``length x length`` rectangular patch made of two
    kinds of pixels:

    * base pixels (``BASE_PIXEL_VALUE``), which keep the patch
      distinguishable from the background, and
    * value pixels (``VALUE_PIXEL_VALUE``), whose count determines the class.

    Class ``k`` gets a value-pixel count drawn uniformly from
    ``[length**2 * k // num_classes, length**2 * (k + 1) // num_classes)``,
    so the count ranges of different classes are disjoint.
    """

    # Intensity written to each of the 3 channels for base / value pixels.
    BASE_PIXEL_VALUE = 127
    VALUE_PIXEL_VALUE = 255

    def __init__(
            self,
            root_path: str,
            num_samples: int,
            num_classes: int,
            bg_generator: Optional[Dict] = None,
            center: Optional[Tuple[int, int]] = (112, 112),
            length: int = 50,
            debug: bool = False,
            discrete: bool = True) -> None:
        """Configure the generator.

        Args:
            root_path: Forwarded unchanged to ``RectangularDatasetGenerator``.
            num_samples: Forwarded unchanged to the base class.
            num_classes: Number of classes; labels are in ``[0, num_classes)``.
            bg_generator: Background generator config passed to the base
                class. ``None`` (the default) means a fresh
                ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``.
            center: Forwarded unchanged to the base class.
            length: Side length of the square patch, in pixels (forwarded to
                the base class; assumed to be stored as ``self.length``).
            debug: If true, ``_make_image`` prints the patch and shows the
                image with matplotlib.
            discrete: Not used by this class; accepted for interface
                compatibility.
        """
        if bg_generator is None:
            # Build a new dict per call. A dict literal used as the default
            # value would be shared by every instance, so any mutation of it
            # (here or in the base class) would leak into later instances.
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))
        # make label first, and then make image
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            bg_generator=bg_generator,
            make_img_first=False,
            center=center,
            length=length,
            debug=debug)

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Paint a patch encoding ``label`` and return ``(img, gt_mask)``.

        Args:
            label: Class index in ``[0, num_classes)``. It must not be
                ``None``, because this generator makes the label first.

        Returns:
            The image with the patch painted in, and the ground-truth mask
            returned by ``_get_rectangle_image``.

        Raises:
            ValueError: If ``label`` is ``None`` or out of range, or if the
                patch has fewer pixels than there are classes (some class
                would then get an empty count range).
        """
        if label is None or not 0 <= label < self.num_classes:
            raise ValueError(
                f'label must be in [0, {self.num_classes}), got {label!r}')

        img, rr, cc, gt_mask = self._get_rectangle_image()

        num_pixels = self.length * self.length
        # Integer floor division gives the exact class boundaries. Float
        # division can round just below an integer (49 * (1 / 49) < 1).
        low = num_pixels * label // self.num_classes
        high = num_pixels * (label + 1) // self.num_classes
        if high <= low:
            raise ValueError(
                f'a patch of {num_pixels} pixels cannot encode '
                f'{self.num_classes} classes')
        # randint's upper bound is exclusive, so class ranges don't overlap.
        desired_sum = np.random.randint(low, high)

        # Place exactly `desired_sum` value pixels, then shuffle. The shuffle
        # makes every arrangement with that count equally likely, so no
        # probabilistic pre-pass or correction loop is needed.
        value_mask = np.zeros(num_pixels, dtype=bool)
        value_mask[:desired_sum] = True
        np.random.shuffle(value_mask)
        value_mask = value_mask.reshape(self.length, self.length)

        # Map value pixels -> 255 and base pixels -> 127 directly. The
        # previous `x[x < 1] = 127; x * 255` produced 127 * 255 for base pixels.
        patch = np.where(value_mask, self.VALUE_PIXEL_VALUE, self.BASE_PIXEL_VALUE)
        for channel in range(3):
            img[rr, cc, channel] = patch

        if self.debug:
            print("random pixel value", patch)
            plt.imshow(img)
            plt.show()
        return img, gt_mask

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Draw a label uniformly at random from ``range(num_classes)``.

        ``img`` and ``kwargs`` are ignored: the label is made before the image.
        """
        return np.random.randint(self.num_classes)
