import random
from typing import Optional, Tuple

import numpy as np
from skimage.draw import disk, rectangle

from ..base_dataset_generator import BaseDatasetGenerator


class CircleRectDatasetGenerator(BaseDatasetGenerator):
    """Two-class dataset of white squares (label 0) and white disks (label 1).

    Shapes are painted with value 255 onto a uniform black 3-channel
    background. Depending on the ``center`` / ``non_center`` flags an image
    holds a large shape at pixel (112, 112), a small shape near a corner, or
    both. Both shapes follow the class: label 0 draws squares (the small one
    at the top left), any other label draws disks (the small one at the
    lower right).
    """

    def __init__(
        self,
        root_path: str,
        num_samples: int,
        center: bool,
        non_center: bool,
        length: int = 40,
    ) -> None:
        """
        Args:
            root_path: Forwarded to ``BaseDatasetGenerator``.
            num_samples: Forwarded to ``BaseDatasetGenerator``.
            center: Whether to draw the large shape at (112, 112).
            non_center: Whether to draw the small corner shape.
            length: Side of the central square; also used as the radius of
                the central disk.

        Raises:
            ValueError: If both ``center`` and ``non_center`` are False,
                which would yield images containing no object.
        """
        # Validate before the base class does any work so a bad config fails
        # fast; an ``assert`` would be stripped under ``python -O``.
        if not (center or non_center):
            raise ValueError("no object in the image!")
        # first get label, then create image based on the label
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=2,
            make_img_first=False,
            bg_generator=dict(type='uniform', channels=3, base_color=(0, 0, 0)))
        self.length = length
        self.center = center
        self.non_center = non_center

    @staticmethod
    def _paint_shape(
        img: np.ndarray,
        is_rect: bool,
        center: Tuple[int, int],
        size: int,
    ) -> None:
        """Paint 255 into the first three channels of ``img`` in place.

        Draws a ``size`` x ``size`` square centred on ``center`` when
        ``is_rect`` is true, otherwise a disk of radius ``size``. Coordinates
        are clipped to the image so oversized shapes neither wrap around via
        negative indices nor raise ``IndexError``.
        """
        bounds = img.shape[:2]
        if is_rect:
            top_left = (center[0] - size // 2, center[1] - size // 2)
            rr, cc = rectangle(top_left, extent=(size, size), shape=bounds)
        else:
            # NOTE(review): the radius equals the square's side, so disks
            # cover ~pi times the square's area — confirm this is intended.
            rr, cc = disk(center, size, shape=bounds)
        img[rr, cc, :3] = 255

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Draw the class-dependent shapes onto a fresh background.

        Label 0 draws a ``length`` x ``length`` square centred at (112, 112)
        and a 10x10 square centred at (20, 20). Any other label (including
        ``None``) draws a disk of radius ``length`` at (112, 112) and a disk
        of radius 10 at (204, 204). Each part is drawn only when the matching
        ``center`` / ``non_center`` flag is set.

        Returns:
            ``(img, mask)``: the 3D image, and a 2D mask that is an
            independent copy of channel 0, so later in-place edits to the
            image do not alter the mask.
        """
        img = self.bg_generator.get_background()
        is_rect = label == 0

        # NOTE(review): fixed coordinates (112 = 224 / 2, 204 = 224 - 20)
        # presumably assume a 224x224 background — confirm.
        if self.center:
            self._paint_shape(img, is_rect, (112, 112), self.length)

        if self.non_center:
            corner = (20, 20) if is_rect else (204, 204)
            self._paint_shape(img, is_rect, corner, 10)

        # return 3D image and 2D gt mask
        return img, img[:, :, 0].copy()

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Draw a class label uniformly from {0, 1}.

        ``img`` is ignored: labels are produced before images
        (``make_img_first=False``).
        """
        return random.randint(0, 1)
