import os.path as osp
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional, Tuple

import numpy as np
from PIL import Image
from skimage.draw import rectangle
from tqdm import tqdm

from ..utils import mkdir_or_exist
from .background import RandomBackgroundGenerator, UniformBackgroundGenerator


class BaseDatasetGenerator(ABC):
    """Base class of a dataset generator.

    Subclasses provide :meth:`_make_image` and :meth:`_make_label`. :meth:`make_dataset`
    then writes ``num_samples`` images and gt masks to disk, sorted into one sub-folder
    per class label.

    Args:
        root_path: The root of the data.
        num_samples: Number of samples to be generated.
        num_classes: Number of classes contained in the dataset.
        make_img_first: If True, first generate the image and gt mask, then assign the
            label (``_make_label`` receives the image). Otherwise, first generate the
            label, and then generate the image and gt mask for that label
            (``_make_image`` receives the label).
        bg_generator: Config of the background generator. The ``'type'`` key selects
            the generator (``'uniform'`` or ``'random'``, case-insensitive) and the
            remaining keys are passed to its constructor. ``None`` (the default) means
            ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``. The dict is
            deep-copied, so the caller's dict is never modified.

    Raises:
        KeyError: If ``bg_generator`` has no ``'type'`` key.
        TypeError: If ``bg_generator['type']`` is not a supported generator type.
    """

    def __init__(
            self,
            root_path: str,
            num_samples: int,
            num_classes: int,
            make_img_first: bool,
            bg_generator: Optional[Dict] = None,
    ) -> None:
        self.root_path = root_path
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.make_img_first = make_img_first

        # A None sentinel avoids a shared mutable default argument.
        if bg_generator is None:
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))

        # Copy so that popping 'type' never modifies the caller's config.
        bg_cfg = deepcopy(bg_generator)
        bg_generator_type = bg_cfg.pop('type')
        type_key = bg_generator_type.lower()
        if type_key == 'uniform':
            self.bg_generator = UniformBackgroundGenerator(**bg_cfg)
        elif type_key == 'random':
            self.bg_generator = RandomBackgroundGenerator(**bg_cfg)
        else:
            raise TypeError(f'Unsupported type specifier for background generator: {bg_generator_type}')

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Make image and the gt mask.

        Args:
            label: Class label, or ``None`` when ``make_img_first`` is True. It is not
                necessarily used in this method.

        Returns:
            A tuple of image and gt mask.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # Deliberately not an @abstractmethod, so subclasses that never call it can
        # still be instantiated. Failing loudly here replaces the old silent ``None``,
        # which surfaced as an opaque unpacking error in ``make_data_point``.
        raise NotImplementedError(f'{type(self).__name__} must implement _make_image().')

    @abstractmethod
    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Make label.

        Args:
            img: Synthesized image, or ``None`` when ``make_img_first`` is False. It is
                not necessarily used in this method.

        Returns:
            Class label.
        """
        pass

    def make_data_point(self) -> Tuple[np.ndarray, np.ndarray, int]:
        """Generate one ``(img, gt_mask, label)`` sample.

        The order of generation is controlled by ``self.make_img_first``: either the
        image is made first and labelled afterwards, or the label is made first and
        the image is made for it.
        """
        if self.make_img_first:
            img, gt_mask = self._make_image(None)
            label = self._make_label(img)
        else:
            label = self._make_label(None)
            img, gt_mask = self._make_image(label)
        return img, gt_mask, label

    def make_dataset(self) -> None:
        """Create the synthetic dataset on disk.

        Sample ``i`` with label ``k`` is written as PNG to
        ``<root_path>/images/<k>/<i>.png``. Its gt mask is written, as an ``'L'``-mode
        image, to ``<root_path>/gt_masks/<k>/<i>.png``. Label folders are created only
        for ``range(num_classes)``, so ``_make_label`` must return labels in that range.
        """
        # create one image folder and one gt-mask folder per class
        mkdir_or_exist(self.root_path)
        for class_label in range(self.num_classes):
            mkdir_or_exist(osp.join(self.root_path, 'images', str(class_label)))
            mkdir_or_exist(osp.join(self.root_path, 'gt_masks', str(class_label)))

        for img_index in tqdm(range(self.num_samples)):
            img, gt_mask, label = self.make_data_point()
            img = Image.fromarray(img)
            img.save(osp.join(self.root_path, 'images', str(label), f"{img_index}.png"))

            gt_mask = Image.fromarray(gt_mask, mode='L')
            gt_mask.save(osp.join(self.root_path, 'gt_masks', str(label), f"{img_index}.png"))


class RectangularDatasetGenerator(BaseDatasetGenerator, ABC):
    """Base class of dataset generators whose foreground is a single square patch.

    The square has side ``length``. Its center comes from a center generator: the fixed
    ``center`` when one is given, otherwise a random center drawn for a 224x224 canvas.
    The gt mask marks the square's pixels with 255 and all other pixels with 0.

    Args:
        root_path: The root of the data.
        num_samples: Number of samples to be generated.
        num_classes: Number of classes contained in the dataset.
        make_img_first: See ``BaseDatasetGenerator``.
        bg_generator: Background generator config. ``None`` (the default) means
            ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``.
        center: ``(row, col)`` pixel center of the square, or ``None`` to draw a random
            center for every sample.
        length: Side length of the square in pixels.
        debug: Stored as ``self.debug``; not read by the methods of this class.
    """

    def __init__(
            self,
            root_path: str,
            num_samples: int,
            num_classes: int,
            make_img_first: bool,
            bg_generator: Optional[Dict] = None,
            center: Optional[Tuple[int, int]] = (112, 112),
            length: int = 50,
            debug: bool = False,
    ) -> None:
        # A None sentinel avoids a shared mutable default argument.
        if bg_generator is None:
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            make_img_first=make_img_first,
            bg_generator=bg_generator)
        if center is None:
            # NOTE(review): the random center range assumes a 224x224 canvas; confirm it
            # matches the size of the backgrounds produced by ``bg_generator``.
            self.center_generator = self.random_center_generator((224, 224), patch_length=length)
        else:
            self.center_generator = self.fixed_center_generator(center)
        self.length = length
        self.debug = debug

        # variables to track the ground truth mask
        # center of the most recently drawn square (set by ``_get_rectangle_image``)
        self.current_center = None
        # NOTE(review): a dict despite the name; not read by the methods visible here,
        # presumably used by subclasses.
        self.ground_truth_mask_list = {}

    @staticmethod
    def random_center_generator(image_size: Tuple[int, int], patch_length: int):
        """Yield random ``(row, col)`` centers forever.

        Coordinate ``i`` is drawn with ``np.random.randint`` from the half-open range
        ``[patch_length, image_size[i] - patch_length)``. This keeps a square of side
        ``patch_length`` centered there inside the image with some margin. Because
        this is a generator, ``randint`` only raises its ``ValueError`` (when
        ``2 * patch_length >= image_size[i]``) on the first ``next()`` call.
        """
        assert patch_length is not None, "Please provide a valid patch length!"
        while True:
            center_row = np.random.randint(patch_length, image_size[0] - patch_length)
            center_col = np.random.randint(patch_length, image_size[1] - patch_length)
            yield center_row, center_col

    @staticmethod
    def fixed_center_generator(center: Tuple[int, int]):
        """Yield the same ``center`` forever."""
        while True:
            yield center

    @staticmethod
    def get_mask(img, rr, cc):
        """Return a uint8 gt mask with 255 at pixels ``(rr, cc)`` and 0 elsewhere.

        Only the spatial size of ``img`` (its first two dimensions) is used, so the
        image content does not matter. Both ``(H, W)`` and ``(H, W, C)`` images are
        accepted.
        """
        # Size the mask from the spatial dims directly instead of allocating a
        # full-channel array and slicing channel 0 (which also failed on 2-D images).
        mask = np.zeros(np.shape(img)[:2], dtype=np.uint8)
        # single-channel uint8 binary mask: 255 inside the square, 0 outside
        mask[rr, cc] = 255
        return mask

    def _get_rectangle_image(self) -> Tuple[np.ndarray, ...]:
        """Draw a background and compute a ``length`` x ``length`` square on it.

        The next center is taken from ``self.center_generator`` and stored in
        ``self.current_center``. The square's top-left corner is
        ``center - length // 2`` on both axes.

        Returns:
            ``(img, rr, cc, gt_mask)``: the background image (not modified here), the
            row and column coordinates of the square's pixels as returned by
            ``skimage.draw.rectangle``, and the uint8 gt mask from :meth:`get_mask`.
        """
        img = self.bg_generator.get_background()
        center = next(self.center_generator)
        self.current_center = center
        half = self.length // 2
        start = (center[0] - half, center[1] - half)
        rr, cc = rectangle(start, extent=(self.length, self.length))

        gt_mask = self.get_mask(img, rr, cc)
        return img, rr, cc, gt_mask
