import random
from typing import Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

from ..base_dataset_generator import RectangularDatasetGenerator


# TODO rewrite this class later
class SingleColorModuloDatasetGenerator(RectangularDatasetGenerator):
    """Rectangle dataset labelled by the painted-pixel count modulo a per-color class range.

    Each sample paints a random binary pattern, in one color drawn uniformly from
    ``color_list``, into the rectangle returned by ``_get_rectangle_image``. The
    ``num_classes`` labels are split into one contiguous bucket per color: every
    color gets ``num_classes // len(color_list)`` classes and the last color also
    absorbs the remainder. The label is the number of painted pixels modulo the
    chosen color's bucket size, shifted by the bucket's start offset.

    Args:
        root_path: Forwarded to the base generator.
        num_samples: Forwarded to the base generator.
        num_classes: Forwarded to the base generator; must be at least
            ``len(color_list)`` so that every color bucket is non-empty.
        bg_generator: Background config forwarded to the base generator. When
            omitted, a fresh uniform black 3-channel config is built per instance.
        center: Forwarded to the base generator.
        length: Forwarded to the base generator; ``self.length`` also sizes the
            random binary pattern.
        debug: When true, prints intermediate values and shows each image.
        color_list: 3-channel colors to paint with; one is chosen per image.

    Raises:
        ValueError: If ``color_list`` is empty or has more colors than classes.
    """

    def __init__(
        self,
        root_path: str,
        num_samples: int,
        num_classes: int,
        bg_generator: Optional[Dict] = None,
        center: Optional[Tuple[int, int]] = None,
        length: int = 50,
        debug: bool = False,
        color_list: Tuple[Tuple[int, ...], ...] = ((255, 255, 255),)) -> None:
        if bg_generator is None:
            # Built per call so instances never share (or mutate) one default dict.
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))
        if not color_list:
            raise ValueError('color_list must contain at least one color')

        # make image first, and then make label
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            make_img_first=True,
            bg_generator=bg_generator,
            center=center,
            length=length,
            debug=debug)
        self.color_list = color_list

        num_colors = len(self.color_list)
        if self.num_classes < num_colors:
            # an empty bucket would later cause a modulo-by-zero in _make_label
            raise ValueError(
                f'num_classes ({self.num_classes}) must be >= the number of colors ({num_colors})')
        # Equal share per color; the last color also takes the remainder.
        share = self.num_classes // num_colors
        self.classes_per_color = [share] * (num_colors - 1) + [self.num_classes - share * (num_colors - 1)]

        # define a member variable to track the color idx of the current image
        self.current_color_idx = None

        # this term is to save the desired sum in _make_image, a fix only in single color patch
        self.desired_sum: int = 0

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Paint a random binary pattern in one randomly chosen color.

        ``label`` is ignored: images are made first and labelled afterwards.
        Stores the chosen color index and the painted-pixel count on ``self``
        for ``_make_label``.

        Returns:
            The painted image and the ground-truth mask, whose rectangle region is
            set to 255 where a pixel was painted and 0 where it was not.
        """
        img, rr, cc, gt_mask = self._get_rectangle_image()

        # 1 = paint with the chosen color, 0 = keep the background value
        random_pixel_value = np.random.randint(2, size=(self.length, self.length))
        self.current_color_idx = random.randrange(len(self.color_list))
        # plain int so the derived label is a Python int, as annotated
        self.desired_sum = int(random_pixel_value.sum())

        color = self.color_list[self.current_color_idx]
        for ch in range(3):
            img[rr, cc, ch] = random_pixel_value * color[ch] + (1 - random_pixel_value) * img[rr, cc, ch]

        # this is a fast fix to preserve the current interface and functionalities;
        # the mask does not depend on the channel, so it is written once
        gt_mask[rr, cc] = random_pixel_value * 255

        if self.debug:
            print("random pixel value", random_pixel_value)
            print(f"image shape is {img.shape}")
            plt.imshow(img)
            plt.show()
        return img, gt_mask

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Return the class of the image most recently built by ``_make_image``.

        This dataset uses the modulo of the summed features as the class. ``img``
        and ``kwargs`` are unused; the label comes from the painted-pixel count
        and color index that ``_make_image`` stored.
        """
        assert self.current_color_idx is not None, '_make_image must run before _make_label'
        idx = self.current_color_idx

        # position within the color's bucket, then shift by the bucket's start offset
        label = self.desired_sum % self.classes_per_color[idx]
        label += sum(self.classes_per_color[:idx])

        if self.debug:
            print("label for this image is: ", label)
        return label

    def make_data_point(self) -> Tuple[np.ndarray, np.ndarray, int]:
        """Build one ``(image, gt_mask, label)`` sample.

        Overrides the base-class method so that the image is made first and
        ``gt_mask`` is also passed to ``_make_label``.
        """
        img, gt_mask = self._make_image(None)
        label = self._make_label(img, gt_mask=gt_mask)
        return img, gt_mask, label
