from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from scipy.special import binom
from skimage.draw import polygon2mask

from ..background import RandomBackgroundGenerator
from ..base_dataset_generator import BaseDatasetGenerator


class BezierDatasetGenerator(BaseDatasetGenerator):
    """Synthetic dataset of sparse white pixels scattered inside Bezier blobs.

    The image is split into a ``grid_size`` grid. In every cell a closed,
    smooth Bezier "blob" is built from random control points. The union of the
    rasterised blobs forms a region of interest (RoI). A random subset of the
    RoI pixels is painted 255 on the background, and ``_make_label`` derives
    the label from the image sum. With the default black background, that sum
    is the number of painted pixels.

    Args:
        root_path: Forwarded to ``BaseDatasetGenerator``.
        num_samples: Forwarded to ``BaseDatasetGenerator``.
        num_classes: Forwarded to ``BaseDatasetGenerator``; labels are taken
            modulo this value in ``_make_label``.
        bg_generator: Background generator config forwarded to the base class.
            ``None`` (default) means
            ``dict(type='uniform', channels=3, base_color=(0, 0, 0))``.
            A ``RandomBackgroundGenerator`` is rejected with ``TypeError``.
        grid_size: ``(rows, cols)`` of the grid; one blob is drawn per cell.
        scale: Fraction of the cell size spanned by the random control points.
        rad: Controls the Bezier handle length (see ``get_segment``).
        edgy: Controls how sharply the blob turns at its control points;
            0 gives smooth corners (see ``get_bezier_curve``).
        max_pos_pixels: The RoI is randomly subsampled to at most this many
            pixels before the positive pixels are chosen.
    """

    def __init__(
        self,
        root_path: str,
        num_samples: int,
        num_classes: int,
        bg_generator: Optional[Dict] = None,
        grid_size: Tuple[int, int] = (2, 2),
        scale: float = 0.8,
        rad: float = 0.2,
        edgy: float = 0.0,
        max_pos_pixels: int = 4900,
    ) -> None:
        if bg_generator is None:
            # Fresh dict per instance instead of one shared mutable default.
            bg_generator = dict(type='uniform', channels=3, base_color=(0, 0, 0))
        # Fail fast on an already-built random generator. Presumably it is
        # unsupported because _make_label counts painted pixels by summing the
        # whole image, which a random background would corrupt.
        if isinstance(bg_generator, RandomBackgroundGenerator):
            raise TypeError('RandomBackgroundGenerator is not supported.')
        # NOTE(review): make_img_first presumably tells the base class to call
        # _make_image before _make_label (which needs the image) — confirm.
        super().__init__(
            root_path=root_path,
            num_samples=num_samples,
            num_classes=num_classes,
            make_img_first=True,
            bg_generator=bg_generator)
        self.grid_size = grid_size
        self.scale = scale
        self.rad = rad
        self.edgy = edgy
        self.max_pos_pixels = max_pos_pixels

        # A config dict may also have been turned into a random generator by
        # the base class; _make_image uses self.bg_generator, so check it too.
        if isinstance(getattr(self, 'bg_generator', None), RandomBackgroundGenerator):
            raise TypeError('RandomBackgroundGenerator is not supported.')

    @staticmethod
    def ccw_sort(points: np.ndarray) -> np.ndarray:
        """Return ``points`` ordered by polar angle around their centroid.

        Consecutive points in the result form a star-shaped, non-self-
        intersecting polygon. The angle is ``arctan2(col0, col1)`` of the
        centred points, sorted in ascending order.
        """
        d = points - np.mean(points, axis=0)
        s = np.arctan2(d[:, 0], d[:, 1])
        return points[np.argsort(s), :]

    @staticmethod
    def get_random_points(
            num_points: int,
            tl_loc: Union[np.ndarray, Tuple[int, int]],
            rect_size: Union[np.ndarray, Tuple[int, int]],
            scale: float = 0.8,
            num_trials: int = 0) -> np.ndarray:
        """Sample ``num_points`` reasonably separated random points in a rectangle.

        Points are drawn uniformly in the unit square. The draw is repeated
        until every edge between consecutive angle-sorted points is at least
        ``0.7 / num_points`` long (Euclidean, unit-square coordinates), or until
        the trial counter reaches 200, in which case the last draw is used.

        Args:
            num_points: Number of points to sample.
            tl_loc: Top-left corner of the rectangle (same axis order as
                ``rect_size``).
            rect_size: Extent of the rectangle along each axis.
            scale: Points land in ``tl_loc + [0, scale) * rect_size``.
            num_trials: Number of trials already spent (kept for backward
                compatibility with the former recursive implementation).

        Returns:
            Array of shape ``(num_points, 2)``, in draw order (not sorted).
        """
        tl_loc = np.asarray(tl_loc)
        min_dst = 0.7 / num_points
        trial = num_trials
        while True:
            points = np.random.rand(num_points, 2)
            edges = np.diff(BezierDatasetGenerator.ccw_sort(points), axis=0)
            # Euclidean length of each edge. The previous
            # sqrt(sum(diff) ** 2) measured |dx + dy| instead.
            dst = np.linalg.norm(edges, axis=1)
            if np.all(dst >= min_dst) or trial >= 200:
                return tl_loc + points * scale * rect_size
            trial += 1

    @staticmethod
    def bernstein(n: int, k: int, t: np.ndarray) -> np.ndarray:
        """Evaluate the Bernstein basis polynomial ``b_{k,n}`` at ``t``."""
        return binom(n, k) * t**k * (1. - t)**(n - k)

    @staticmethod
    def bezier(control_points: np.ndarray, num_seg_points: int = 200) -> np.ndarray:
        """Sample a Bezier curve defined by ``control_points``.

        Args:
            control_points: ``(num_ct_points, 2)`` array of control points.
            num_seg_points: Number of evenly spaced parameter values in [0, 1].

        Returns:
            ``(num_seg_points, 2)`` array of curve points.
        """
        num_ct_points = len(control_points)
        t = np.linspace(0, 1, num_seg_points)
        curve = np.zeros((num_seg_points, 2))
        for i in range(num_ct_points):
            curve += np.outer(BezierDatasetGenerator.bernstein(num_ct_points - 1, i, t), control_points[i])
        return curve

    @staticmethod
    def get_segment(
            p1: np.ndarray,
            p2: np.ndarray,
            angle1: np.ndarray,
            angle2: np.ndarray,
            rad: float = 0.3,
            num_seg_points: int = 100) -> np.ndarray:
        """Cubic Bezier segment from ``p1`` to ``p2`` with given end tangents.

        The inner control points sit at distance ``r`` from the end points:
        along ``angle1`` from ``p1``, and opposite to ``angle2`` from ``p2``.

        Returns:
            ``(num_seg_points, 2)`` array of curve points.
        """
        # NOTE(review): r == sqrt(rad) * |p2 - p1| because rad sits inside the
        # sqrt. Kept as-is so generated shapes stay stable; confirm whether
        # rad * |p2 - p1| was intended.
        r = np.sqrt((rad * (p2 - p1)**2).sum())
        control_points = np.zeros((4, 2))
        control_points[0, :] = p1[:]
        control_points[3, :] = p2[:]

        control_points[1, :] = p1 + np.array([r * np.cos(angle1), r * np.sin(angle1)])
        control_points[2, :] = p2 + np.array([r * np.cos(angle2 + np.pi), r * np.sin(angle2 + np.pi)])
        return BezierDatasetGenerator.bezier(control_points, num_seg_points)

    @staticmethod
    def get_bezier_curve(
            points: np.ndarray, rad: float = 0.2, edgy: float = 0, num_seg_points: int = 100) -> np.ndarray:
        """Build a closed, piecewise-cubic Bezier curve through ``points``.

        The points are angle-sorted and the polygon is closed. The tangent at
        each point blends the outgoing edge angle (weight ``p``) with the
        incoming edge angle (weight ``1 - p``), where
        ``p = arctan(edgy) / pi + 0.5``. So ``edgy == 0`` takes the midpoint
        of the two angles.

        Returns:
            ``(len(points) * num_seg_points, 2)`` array of curve points.
        """
        p = np.arctan(edgy) / np.pi + .5
        points = BezierDatasetGenerator.ccw_sort(points)
        # Repeat the first point at the end so the curve is closed.
        points = np.append(points, np.atleast_2d(points[0, :]), axis=0)

        d = np.diff(points, axis=0)
        ang = np.arctan2(d[:, 1], d[:, 0])
        # Map angles from (-pi, pi] to [0, 2*pi).
        ang = (ang >= 0) * ang + (ang < 0) * (ang + 2 * np.pi)
        ang1 = ang
        ang2 = np.roll(ang, 1)
        # Blend outgoing/incoming angles; add pi when they straddle the wrap.
        ang = p * ang1 + (1 - p) * ang2 + (np.abs(ang2 - ang1) > np.pi) * np.pi
        ang = np.append(ang, [ang[0]])

        # Column 2 now stores each point's tangent angle.
        points = np.append(points, np.atleast_2d(ang).T, axis=1)

        segments = []
        for i in range(len(points) - 1):
            seg = BezierDatasetGenerator.get_segment(
                points[i, :2],
                points[i + 1, :2],
                points[i, 2],
                points[i + 1, 2],
                rad=rad,
                num_seg_points=num_seg_points)
            segments.append(seg)
        curve = np.concatenate(segments)

        return curve

    @staticmethod
    def brezier_grids(
            img_size: Tuple[int, int] = (224, 224),
            grid_size: Tuple[int, int] = (2, 2),
            scale: float = 0.8,
            rad: float = 0.2,
            edgy: float = 0,
            num_ct_points: int = 5,
            num_seg_points: int = 100) -> List[np.ndarray]:
        """Draw one closed Bezier curve per cell of a ``grid_size`` grid.

        (The "brezier" spelling is kept for backward compatibility.)

        Args:
            img_size: ``(height, width)`` of the image.
            grid_size: ``(rows, cols)``. Each cell is
                ``(height // rows, width // cols)``; leftover pixels are unused.

        Returns:
            One ``(num_ct_points * num_seg_points, 2)`` array per cell, in
            ``(row, col)`` coordinates, clipped to stay inside its cell.
        """
        img_h, img_w = img_size
        num_rows, num_cols = grid_size
        rect_size = np.asarray([img_h // num_rows, img_w // num_cols])

        tl_ys = np.arange(0, img_h - rect_size[0] + 1, rect_size[0])
        tl_xs = np.arange(0, img_w - rect_size[1] + 1, rect_size[1])

        yv, xv = np.meshgrid(tl_ys, tl_xs, indexing='ij')
        yv = yv.flatten()
        xv = xv.flatten()
        tl_locs = np.stack([yv, xv], axis=1)

        curves = []
        for tl_loc in tl_locs:
            points = BezierDatasetGenerator.get_random_points(num_ct_points, tl_loc, rect_size, scale=scale)
            curve = BezierDatasetGenerator.get_bezier_curve(points, rad=rad, edgy=edgy, num_seg_points=num_seg_points)
            br_loc = tl_loc + rect_size
            curve = np.clip(curve, a_min=tl_loc, a_max=br_loc - 1)
            curves.append(curve)

        return curves

    def _make_image(self, label: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Generate one image and its mask of painted (positive) pixels.

        ``label`` is not used here; the label is derived from the image
        afterwards by ``_make_label``.

        Returns:
            ``(img, mask)``. ``img`` is the background with the positive pixels
            set to 255. ``mask`` is a single-channel uint8 array with 255 at the
            positive pixels and 0 elsewhere.
        """
        # NOTE(review): assumes bg_generator produces a 224x224 image — confirm.
        img_size = (224, 224)
        curves = self.brezier_grids(
            img_size,
            grid_size=self.grid_size,
            scale=self.scale,
            rad=self.rad,
            edgy=self.edgy,
            num_ct_points=5,
            num_seg_points=100)
        # Rasterise every closed curve; their union is the RoI.
        roi = np.any(np.stack([polygon2mask(img_size, curve) for curve in curves], 0), axis=0)
        pos_y, pos_x = np.nonzero(roi)

        if len(pos_y) > self.max_pos_pixels:
            # Shrink the RoI to max_pos_pixels randomly chosen pixels.
            keep = np.random.choice(len(pos_y), size=self.max_pos_pixels, replace=False)
            pos_y, pos_x = pos_y[keep], pos_x[keep]

        # Select a random number of positive pixels from the RoI. randint
        # raises on an empty range, so an empty RoI yields no positives.
        num_roi = len(pos_y)
        if num_roi > 0:
            num_pos = np.random.randint(0, num_roi)
            selected = np.random.choice(num_roi, size=num_pos, replace=False)
            pos_y, pos_x = pos_y[selected], pos_x[selected]

        img = self.bg_generator.get_background()
        img[pos_y, pos_x] = 255
        # Single-channel uint8 mask: 255 at positive pixels, 0 elsewhere.
        mask = np.zeros(img_size, dtype=np.uint8)
        mask[pos_y, pos_x] = 255
        return img, mask

    def _make_label(self, img: Optional[np.ndarray], **kwargs) -> int:
        """Derive the class label from the image.

        The image sum is divided by ``num_channels * 255``. On a black
        background this gives the number of fully painted pixels. The result
        is taken modulo ``num_classes``.

        Raises:
            ValueError: If ``img`` is ``None`` or is not 2-D/3-D.
        """
        if img is None:
            raise ValueError('An image is required to compute the label.')
        if img.ndim == 2:
            num_channels = 1
        elif img.ndim == 3:
            num_channels = img.shape[2]
        else:
            raise ValueError(f'Invalid image shape: {img.shape}')

        sum_val = img.sum()
        sum_val = int(sum_val // (num_channels * 255))
        return sum_val % self.num_classes
