from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

from ..base_dataset_generator import RectangularDatasetGenerator


class ModuloDatasetGenerator(RectangularDatasetGenerator):
    """Dataset labelled by the number of white pixels modulo ``num_classes``.

    The image is generated first (``make_img_first=True``). Every pixel
    indexed by the rectangle returned by ``_get_rectangle_image`` is set to
    either 0 or 255, at random, in the first three colour channels. The
    label is then derived from the image content.
    """

    # Number of colour channels written per pixel by ``_make_image``. The
    # label computation divides by the same factor, so the two must agree.
    _NUM_COLOR_CHANNELS = 3
    # Value written for a "set" pixel.
    _WHITE = 255

    def __init__(
            self,
            root_path: str,
            num_samples: int,
            num_classes: int,
            bg_generator: Optional[Dict] = None,
            center: Optional[Tuple[int, int]] = (112, 112),
            length: int = 50,
            debug: bool = False,
    ) -> None:
        """Configure the generator.

        Args:
            root_path: Forwarded to the base generator.
            num_samples: Forwarded to the base generator.
            num_classes: Number of classes; labels lie in
                ``[0, num_classes)``.
            bg_generator: Background config forwarded to the base generator.
                ``None`` (the default) means a uniform, 3-channel black
                background: ``dict(type='uniform', channels=3,
                base_color=(0, 0, 0))``.
            center: Forwarded to the base generator.
            length: Side length of the random square drawn by
                ``_make_image``.
            debug: When true, intermediate values are printed and plotted.
        """
        if bg_generator is None:
            # Build a fresh config per instance rather than sharing one
            # mutable default dict between all instances.
            bg_generator = dict(
                type='uniform', channels=3, base_color=(0, 0, 0))
        # make image first, and then make label
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            make_img_first=True,
            bg_generator=bg_generator,
            center=center,
            length=length,
            debug=debug,
        )

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Fill the rectangle with random black/white pixels.

        Args:
            label: Unused; this generator builds the image before the label.

        Returns:
            ``(img, gt_mask)``. The pixels indexed by the rectangle
            coordinates hold 0 or 255, in the first three channels of
            ``img`` and in ``gt_mask``.
        """
        img, rr, cc, gt_mask = self._get_rectangle_image()

        # One random bit (0 or 1) per pixel of the length x length square.
        random_pixel_value = np.random.randint(2, size=(self.length, self.length))
        pixel_values = random_pixel_value * self._WHITE
        for channel in range(self._NUM_COLOR_CHANNELS):
            img[rr, cc, channel] = pixel_values

        # this is a fast fix to preserve the current interface and
        # functionalities: the mask mirrors the drawn pixels. It only needs
        # to be written once, not once per channel.
        gt_mask[rr, cc] = pixel_values

        if self.debug:
            print("random pixel value", random_pixel_value)
            plt.imshow(img)
            plt.show()
        return img, gt_mask

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Use the number of white pixels, modulo ``num_classes``, as class.

        The whole image is summed. The result equals the white-pixel count
        only if the background contributes nothing to the sum, as with the
        default black background. A non-black background shifts the label.

        Args:
            img: Image produced by ``_make_image``.
            **kwargs: Ignored.

        Returns:
            ``(sum(img) // (255 * 3)) % num_classes``.

        Raises:
            ValueError: If ``img`` is None; the label is derived from it.
        """
        if img is None:
            raise ValueError(
                "ModuloDatasetGenerator requires an image to compute the label")

        # Accumulate in float64. float32 represents integers exactly only up
        # to 2**24, so large squares could round the sum and shift the label.
        sum_val = int(np.sum(img, dtype=np.float64))

        # Each white pixel contributes 255 in each of the 3 channels, so this
        # division yields the number of white pixels; the modulo maps that
        # count onto a class index.
        num_white = sum_val // (self._WHITE * self._NUM_COLOR_CHANNELS)
        label = num_white % self.num_classes

        if self.debug:
            print("label for this image is: ", label)
        return label
