from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

from ..base_dataset_generator import RectangularDatasetGenerator


class GaussianDatasetGenerator(RectangularDatasetGenerator):
    """Generate images whose rectangle region is filled with Gaussian noise.

    The image is built first (``make_img_first=True`` is passed to the base
    class) and the label is derived from it afterwards. The region returned by
    ``_get_rectangle_image`` is overwritten with a ``(length, length)`` array of
    values ``64 * z + 128``, where ``z`` is standard-normal noise (mean 128,
    std 64). The label bins how far the image's mean noise level deviates
    from 128.

    Args:
        root_path: Forwarded to the base class.
        num_samples: Forwarded to the base class.
        num_classes: Number of label bins; labels lie in
            ``[0, num_classes - 1]``.
        bg_generator: Background generator config forwarded to the base
            class. ``None`` (the default) uses a fresh
            ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``.
        center: Forwarded to the base class.
        length: Side length of the ``(length, length)`` noise array.
        debug: If true, print intermediate values and display the image.
    """

    def __init__(
            self,
            root_path: str,
            num_samples: int,
            num_classes: int,
            bg_generator: Optional[Dict] = None,
            center: Optional[Tuple[int, int]] = (112, 112),
            length: int = 50,
            debug: bool = False) -> None:
        if bg_generator is None:
            # Build the default per instance: a dict literal in the signature
            # is evaluated once and would be shared (and mutable) across all
            # instances.
            bg_generator = dict(
                type='uniform', channels=3, base_color=(0, 0, 0))
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            make_img_first=True,
            bg_generator=bg_generator,
            center=center,
            length=length,
            debug=debug)

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Fill the base rectangle with Gaussian noise.

        Args:
            label: Unused. Labels are derived from the finished image by
                ``_make_label``.

        Returns:
            ``(img, gt_mask)`` as produced by ``_get_rectangle_image``, with
            the ``img[rr, cc]`` pixels overwritten by the noise.
        """
        img, rr, cc, gt_mask = self._get_rectangle_image()

        # Standard-normal noise shifted to mean 2. After scaling by 64 the
        # pixel values have mean 128 and std 64; this is Gaussian, NOT
        # uniform over 0-255.
        random_pixel_value = np.random.randn(self.length, self.length) + 2
        if self.debug:
            print("random pixel value", random_pixel_value)
        pixel_values = random_pixel_value * 64

        # Roughly 5% of samples land outside [0, 255] (beyond ~2 std).
        # Writing out-of-range floats into an integer array is an undefined
        # cast in NumPy. Writing fractional floats truncates toward zero,
        # which biases the image mean used by _make_label by about -0.5.
        # For integer images, round to nearest and saturate at the dtype
        # limits instead. Float images are written unchanged.
        if np.issubdtype(img.dtype, np.integer):
            limits = np.iinfo(img.dtype)
            pixel_values = np.clip(
                np.rint(pixel_values), limits.min, limits.max)
        img[rr, cc] = pixel_values
        if self.debug:
            plt.imshow(img)
            plt.show()
        return img, gt_mask

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Derive a class label from the mean noise level of ``img``.

        The image sum is divided by ``64 * length ** 2`` and shifted by -2.
        This recovers the sample mean of the standard-normal noise written by
        ``_make_image``. That deviation is scaled by ``10 * num_classes`` and
        quantized into equal-width bins placed symmetrically around zero.
        The result is clamped to ``[0, num_classes - 1]``.

        Args:
            img: Image produced by ``_make_image``; must not be ``None``.
            **kwargs: Ignored.

        Returns:
            Integer label in ``[0, num_classes - 1]``.

        Raises:
            ValueError: If ``img`` is ``None``.
        """
        if img is None:
            raise ValueError("img must not be None when deriving the label")
        # NOTE(review): the whole image is summed but the result is normalised
        # by only length * length pixels. This assumes background pixels are 0
        # (the default base_color) and one stored value per noise pixel —
        # confirm for other bg_generator configs.
        sum_val = np.sum(img, dtype=np.float64) / 64
        mean_deviation = sum_val / (self.length * self.length) - 2
        scaled = mean_deviation * 10 * self.num_classes
        # floor() rather than int(): int() truncates toward zero, which makes
        # the bin around zero twice as wide as every other bin. Adding
        # num_classes / 2 before flooring centres the bins symmetrically for
        # both even and odd num_classes.
        label = int(np.floor(scaled + self.num_classes / 2))
        label = min(self.num_classes - 1, max(0, label))
        if self.debug:
            print("label for this image is: ", label)
        return label