import random
from os import path as osp
from typing import Dict, List, Optional, Tuple

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from skimage.draw import disk, ellipse, polygon, rectangle
from tqdm import tqdm

from sde.utils import mkdir_or_exist
from ..base_dataset_generator import RectangularDatasetGenerator


class MultiPatchSumDataset(RectangularDatasetGenerator):
    """Synthetic dataset with one randomly shaped, coloured patch per class colour.

    Every entry of ``color_list`` is a class. Each image gets one patch per
    colour (rectangle, ellipse, triangle or disk), each drawn inside its own
    ``length`` x ``length`` box; boxes never overlap. Every pixel of a patch
    takes the patch colour with probability 1/2. ``_make_label`` labels an image
    with the colour that covers the most pixels.
    """

    def __init__(
        self,
        root_path: str,
        num_samples: int,
        color_list: List[List[int]],
        bg_generator: Optional[Dict] = None,
        length: int = 50,
        debug: bool = False,
    ) -> None:
        """
        Args:
            root_path: Root directory the dataset is written to.
            num_samples: Number of images to generate.
            color_list: One colour per class, each channel value in 0-255.
            bg_generator: Background generator config forwarded to the base class.
                ``None`` (the default) selects
                ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``; a new
                dict is created per instance so the default is never shared.
            length: Side length of the square box each patch is drawn in.
            debug: In debug mode every patch pixel is painted 255 and each
                generated image is displayed.
        """
        if bg_generator is None:
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))
        # make image first, and then make label
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=len(color_list),
            make_img_first=True,
            bg_generator=bg_generator,
            length=length,
            debug=debug,
        )
        # keep the colour list, validate it; there is exactly one patch per colour
        self.color_list = color_list
        for color in color_list:
            self._check_valid_color(color)
        self.num_patches = len(self.color_list)

        # overwrite center generator to enable multi-patch generation
        self.center_generator = self.multi_rectangle_center_generator(
            self.num_patches, patch_length=length, patch_width=length)

    @staticmethod
    def _check_valid_color(color):
        """Assert that every channel value of ``color`` lies in 0-255."""
        for i in color:
            assert 0 <= i <= 255, f'color value invalid, expected to be in 0-255, get {i}'

    def multi_rectangle_center_generator(self, num_patches, patch_length, patch_width):
        """Endlessly yield lists of ``num_patches`` non-overlapping box corners.

        Each yielded value is the result of ``_generate_non_overlapping_points``:
        a list of ``(row, col)`` top-left corners, or ``None`` if placement failed.
        """
        while True:
            yield self._generate_non_overlapping_points(num_patches, patch_length, patch_width)

    def _generate_non_overlapping_points(self, num_patches, length, width, image_width=224, image_height=224):
        """Randomly place ``num_patches`` non-overlapping ``length`` x ``width`` boxes.

        Points are ``(row, col)`` top-left corners as consumed by ``skimage.draw``.
        The first coordinate is bounded by ``image_width - length`` and the second
        by ``image_height - width`` (the two bounds coincide for the default
        square 224 x 224 image).

        Returns:
            A list of ``num_patches`` points, or ``None`` if they could not all be
            placed within ``tolerance`` random attempts.
        """
        points = []
        tolerance = 1000
        max_x = image_width - length
        max_y = image_height - width
        # a box larger than the image can never be placed, so do not sample at all
        attempts = tolerance if (max_x >= 0 and max_y >= 0) else 0

        for _ in range(attempts):
            if len(points) >= num_patches:
                break
            # sample only in-bounds corners instead of rejecting out-of-bounds ones
            point = (random.randint(0, max_x), random.randint(0, max_y))
            if not any(self._rectangle_intersect(point, length, width, p) for p in points):
                points.append(point)

        # success is decided by the number of placed points, not by the attempt
        # counter: finishing on the very last attempt is still a success
        return points if len(points) >= num_patches else None

    @staticmethod
    def _rectangle_intersect(point, length, width, other_point):
        """Return True if boxes with top-left corners ``point`` and ``other_point`` overlap.

        Both orientations (``length`` x ``width`` and ``width`` x ``length``) are
        checked. Boxes that merely touch (corner distance equal to the side) do
        not count as intersecting.
        """
        dx = abs(point[0] - other_point[0])
        dy = abs(point[1] - other_point[1])
        return (dx < length and dy < width) or (dx < width and dy < length)

    def _get_multi_patch_image(self) -> Tuple[np.ndarray, List[Dict]]:
        """Create a background and one random shape per patch location.

        Only the shape coordinates are computed here; colouring happens in
        ``_make_image``.

        Returns:
            The background image and a list of ``{'rr': rows, 'cc': cols}``
            pixel coordinates, one entry per patch.

        Raises:
            RuntimeError: if the patches could not be placed without overlap.
        """
        img = self.bg_generator.get_background()
        points = next(self.center_generator)
        if points is None:
            raise RuntimeError(
                f'could not place {self.num_patches} non-overlapping patches of size '
                f'{self.length}; reduce the number of colors or the patch length')
        patches = []
        for point in points:
            rr, cc = self._get_random_shape(img, point, self.length, self.length)
            patches.append({'rr': rr, 'cc': cc})
        return img, patches

    def _get_random_shape(self, img, point, bbox_length, bbox_width) -> Tuple[np.ndarray, ...]:
        """Return ``(rr, cc)`` pixel coordinates of a random shape inside a box.

        The box spans rows ``point[0] .. point[0] + bbox_length - 1`` and columns
        ``point[1] .. point[1] + bbox_width - 1``. Every shape is built to stay
        inside that box, so shapes in touching neighbouring boxes do not overlap.
        """
        shape_id = random.randrange(4)
        if shape_id == 0:
            # rectangle filling the whole box
            return self._get_rectangle(point, bbox_length, bbox_width)
        if shape_id == 1:
            # ellipse centred in the (square part of the) box with random radii
            min_length = min(bbox_length, bbox_width)
            center = (point[0] + min_length // 2, point[1] + min_length // 2)
            r_radius = random.randint(min_length // 4, min_length // 2)
            c_radius = random.randint(min_length // 4, min_length // 2)
            return self._get_ellipse(center, r_radius, c_radius)
        if shape_id == 2:
            # triangle: one vertex on the top edge, one on the left edge and one on
            # either the right or the bottom edge. The last row/column of the box is
            # ``start + size - 1``; using ``start + size`` would step into the
            # neighbouring box (or past the image border).
            last_row = point[0] + bbox_length - 1
            last_col = point[1] + bbox_width - 1
            point_1 = (point[0], point[1] + random.randrange(bbox_width))
            point_2 = (point[0] + random.randrange(bbox_length), point[1])
            if random.randrange(2) == 0:
                point_3 = (point[0] + random.randrange(bbox_length), last_col)
            else:
                point_3 = (last_row, point[1] + random.randrange(bbox_width))
            return self._get_triangle([point_1, point_2, point_3])
        # disk centred in the (square part of the) box
        min_length = min(bbox_length, bbox_width)
        center = (point[0] + min_length // 2, point[1] + min_length // 2)
        return self._get_disk(center, min_length // 2, img.shape)

    @staticmethod
    def _get_rectangle(point, length, width) -> Tuple[np.ndarray, ...]:
        """Coordinates of the ``length`` x ``width`` rectangle starting at ``point``."""
        return rectangle(point, extent=(length, width))

    @staticmethod
    def _get_disk(point, radius, img_shape) -> Tuple[np.ndarray, ...]:
        """Coordinates of a disk centred at ``point``, clipped to ``img_shape``."""
        return disk(point, radius, shape=img_shape)

    @staticmethod
    def _get_ellipse(point, r_radius, c_radius) -> Tuple[np.ndarray, ...]:
        """Coordinates of an axis-aligned ellipse centred at ``point``."""
        return ellipse(point[0], point[1], r_radius, c_radius)

    @staticmethod
    def _get_triangle(vertices: List) -> Tuple[np.ndarray, ...]:
        """Coordinates of the triangle with the three ``(row, col)`` ``vertices``."""
        assert len(vertices) == 3
        r = [vertex[0] for vertex in vertices]
        c = [vertex[1] for vertex in vertices]
        return polygon(r, c)

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Paint one randomly coloured patch per colour and build the GT masks.

        Args:
            label: Unused here; the image is made before the label.

        Returns:
            The image and a ``uint8`` array of ``num_patches + 1`` masks. Mask ``k``
            belongs to ``color_list[k]``; the last mask is the sum of all others.
        """
        img, patches = self._get_multi_patch_image()

        # one GT mask per colour (same order as self.color_list) plus one overall
        gt_masks = [None] * (self.num_patches + 1)

        # each colour is a class, so assign the colours to the patches randomly
        color_ids = list(range(len(self.color_list)))
        random.shuffle(color_ids)
        for color_id, patch in zip(color_ids, patches):
            rr, cc = patch['rr'], patch['cc']
            mask = self.get_mask(img, rr, cc)
            # 1 where the pixel takes the patch colour, 0 where it keeps the background
            painted = np.random.randint(2, size=rr.shape)
            if self.debug:
                # debug mode (meant for test cases): every patch pixel is 255
                img[rr, cc, :3] = 255
            else:
                color = np.asarray(self.color_list[color_id][:3])
                img[rr, cc, :3] = np.where(painted[..., None] == 1, color, img[rr, cc, :3])
                mask[rr, cc] = painted * 255
            gt_masks[color_id] = mask

        # patches never overlap, so summing the per-colour masks gives the overall mask
        gt_masks[-1] = np.array(gt_masks[:-1]).sum(0)
        if self.debug:
            plt.imshow(img)
            plt.show()
        return img, np.array(gt_masks).astype(np.uint8)

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Return the index of the colour that covers the most pixels of ``img``.

        A pixel counts for a colour only if its first three channels equal the
        colour exactly. Ties resolve to the lowest index in ``color_list``.
        """
        rgb = img[:, :, :3]
        color_counts = [
            int(np.all(rgb == np.asarray(color[:3]), axis=-1).sum())
            for color in self.color_list
        ]
        return color_counts.index(max(color_counts))

    @staticmethod
    def _merge_mask(gt_masks: Optional[np.ndarray], label, **kwargs) -> np.ndarray:
        """Merge the per-colour masks into a single mask.

        With 0/255 ``uint8`` masks (as produced by ``_make_image``), pixels of the
        ``label`` colour become 255 (127 + 128) and pixels of all other colours
        become 127.
        """
        merged_mask = gt_masks[-1] // 2
        merged_mask += np.ceil(gt_masks[label] / 2).astype(np.uint8)
        return merged_mask

    def make_dataset(self) -> None:
        """Create the synthetic dataset on disk.

        Images are written to ``<root_path>/images/<label>/<index>.png`` and the
        merged masks to ``<root_path>/gt_masks/<label>/<index>.png``.
        """
        # create one image folder and one mask folder per class
        mkdir_or_exist(self.root_path)
        for class_label in range(self.num_classes):
            mkdir_or_exist(osp.join(self.root_path, 'images', str(class_label)))
            mkdir_or_exist(osp.join(self.root_path, 'gt_masks', str(class_label)))

        for img_index in tqdm(range(self.num_samples)):
            img, gt_masks, label = self.make_data_point()
            Image.fromarray(img).save(
                osp.join(self.root_path, 'images', str(label), f"{img_index}.png"))

            # save the merged mask for multi color sum
            gt_mask = Image.fromarray(self._merge_mask(gt_masks, label), mode='L')
            gt_mask.save(osp.join(self.root_path, 'gt_masks', str(label), f"{img_index}.png"))
