import enum
from copy import deepcopy

import numpy as np
from skimage import img_as_ubyte
from skimage.transform import rescale, resize
try:
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultPredictor
    DETECTRON_INSTALLED = True
except:
    print("Detectron v2 is not installed")
    DETECTRON_INSTALLED = False

from .countless.countless2d import zero_corrected_countless


class ObjectMask:
    """Binary object mask stored as a tight crop plus its position on a canvas.

    ``height`` and ``width`` are the size of the full canvas the object lives
    on. ``mask`` holds only the cropped window of the object, located on the
    canvas at rows ``up:down`` and columns ``left:right``. The window may be
    moved partly or fully outside the canvas; ``crop_to_canvas`` trims it back.

    Transforming methods accept ``inplace``; when it is false they operate on
    (and return) a deep copy, otherwise they mutate and return ``self``.
    """

    def __init__(self, mask):
        """Build the object from a full-canvas 2D mask array."""
        self.height, self.width = mask.shape
        row_span, col_span = self._get_limits(mask)
        self.up, self.down = row_span
        self.left, self.right = col_span
        # Copy so that later edits never alias the caller's array.
        self.mask = mask[self.up:self.down, self.left:self.right].copy()

    @staticmethod
    def _get_limits(mask):
        """Return ``((up, down), (left, right))`` half-open bounds of truthy cells.

        For a mask without any truthy cell ``argmax`` yields 0 on both passes,
        so the bounds span the whole array.
        """
        def span(flags):
            first = flags.argmax()
            past_last = len(flags) - flags[::-1].argmax()
            return first, past_last

        return span(mask.any(axis=1)), span(mask.any(axis=0))

    def _clean(self):
        """Reset to an empty object: zero-size window at the canvas origin."""
        self.up = self.down = self.left = self.right = 0
        self.mask = np.empty((0, 0))

    def horizontal_flip(self, inplace=False):
        """Mirror the cropped window left-right; its canvas position is kept."""
        target = self if inplace else deepcopy(self)
        target.mask = target.mask[:, ::-1]
        return target

    def vertical_flip(self, inplace=False):
        """Mirror the cropped window top-bottom; its canvas position is kept."""
        target = self if inplace else deepcopy(self)
        target.mask = target.mask[::-1, :]
        return target

    def image_center(self):
        """Return ``(y, x)`` of the centre of the object's window on the canvas."""
        half_height = (self.down - self.up) / 2
        half_width = (self.right - self.left) / 2
        return self.up + half_height, self.left + half_width

    def rescale(self, scaling_factor, inplace=False):
        """Resize the object by ``scaling_factor`` around its window centre.

        The cropped mask is resampled with ``skimage.transform.rescale``
        (``order=0``), re-cropped to its truthy extent, and the window bounds
        are recomputed so the window centre stays (up to rounding) in place.
        """
        target = self if inplace else deepcopy(self)

        # The centre depends only on the old bounds, so take it before resizing.
        center_y, center_x = target.image_center()

        resized = rescale(target.mask.astype(float), scaling_factor, order=0) > 0.5
        row_span, col_span = target._get_limits(resized)
        target.mask = resized[row_span[0]:row_span[1], col_span[0]:col_span[1]]

        new_height, new_width = target.mask.shape
        target.up = int(round(center_y - new_height / 2))
        target.down = target.up + new_height
        target.left = int(round(center_x - new_width / 2))
        target.right = target.left + new_width
        return target

    def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False):
        """Trim the parts of the window that lie outside the canvas.

        :param vertical: trim along rows (top/bottom overflow).
        :param horizontal: trim along columns (left/right overflow).
        :param inplace: mutate ``self`` instead of a deep copy.
        :return: the trimmed object; it is reset to an empty object when the
            window lies entirely outside the canvas along a processed axis.
        """
        target = self if inplace else deepcopy(self)

        if vertical:
            if target.up >= target.height or target.down <= 0:
                target._clean()
            else:
                top_overflow = max(-target.up, 0)
                bottom_overflow = max(target.down - target.height, 0)
                if top_overflow:
                    target.mask = target.mask[top_overflow:]
                    target.up = 0
                if bottom_overflow:
                    target.mask = target.mask[:-bottom_overflow]
                    target.down = target.height

        if horizontal:
            if target.left >= target.width or target.right <= 0:
                target._clean()
            else:
                left_overflow = max(-target.left, 0)
                right_overflow = max(target.right - target.width, 0)
                if left_overflow:
                    target.mask = target.mask[:, left_overflow:]
                    target.left = 0
                if right_overflow:
                    target.mask = target.mask[:, :-right_overflow]
                    target.right = target.width

        return target

    def restore_full_mask(self, allow_crop=False):
        """Render the object onto a boolean canvas of shape ``(height, width)``.

        :param allow_crop: when true, ``self`` is itself trimmed to the canvas;
            otherwise the trimming is done on a copy and ``self`` is untouched.
        """
        visible = self.crop_to_canvas(inplace=allow_crop)
        canvas = np.zeros((visible.height, visible.width), dtype=bool)
        canvas[visible.up:visible.down, visible.left:visible.right] = visible.mask
        return canvas

    def shift(self, vertical=0, horizontal=0, inplace=False):
        """Translate the window by the given row/column offsets (no cropping)."""
        target = self if inplace else deepcopy(self)
        target.up, target.down = target.up + vertical, target.down + vertical
        target.left, target.right = target.left + horizontal, target.right + horizontal
        return target

    def area(self):
        """Return the sum of the cropped mask (pixel count for boolean masks)."""
        return self.mask.sum()


class RigidnessMode(enum.Enum):
    """Selects which objects count as foreground when validating a moved mask.

    Consumed by ``SegmentationMask.get_masks`` to build the ``foreground`` list
    that candidate placements are checked against.
    """

    # Foreground is only the object the moved mask was produced from.
    soft = 0
    # Foreground is every "thing" segment detected in the scene.
    rigid = 1


class SegmentationMask:
    """Propose masks by relocating detected objects to new places in an image.

    A detectron2 panoptic model segments the image. Every sufficiently small
    "thing" instance is randomly rescaled, optionally flipped, and shifted;
    candidate placements are searched on a downsampled label map and the
    accepted augmentation parameters are then replayed on the full-resolution
    instance mask.
    """

    def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid,
                 max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4,
                 max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5,
                 max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True,
                 max_vertical_shift=0.1, position_shuffle=True):
        """
        :param confidence_threshold: float; confidence threshold of the panoptic segmentator for an instance
        to be accepted.
        :param rigidness_mode: RigidnessMode member (or its value)
            when soft, checks intersection only with the object from which the mask_object was produced
            when rigid, checks intersection with any foreground class object
        :param max_object_area: float; upper bound on the image-area fraction of an object for it to be
        considered as mask_object.
        :param min_mask_area: float; lower bound on the image-area fraction of a mask for it to be valid
        :param downsample_levels: int; the segmentation is downsampled ``log2(width) - downsample_levels``
        times before shifted masks are searched;
        :param num_variants_per_mask: int; maximal number of the masks for the same object;
        :param max_mask_intersection: float; maximum allowed area fraction of intersection for 2 masks
        produced by horizontal shift of the same mask_object; higher value -> more diversity
        :param max_foreground_coverage: float; maximum allowed area fraction of intersection for foreground object to be
        covered by mask; lower value -> less the objects are covered
        :param max_foreground_intersection: float; maximum allowed area of intersection for the mask with foreground
        object; lower value -> mask is more on the background than on the objects
        :param max_hidden_area: upper bound on part of the object hidden by shifting object outside the screen area;
        :param max_scale_change: allowed scale change for the mask_object;
        :param horizontal_flip: if horizontal flips are allowed;
        :param max_vertical_shift: amount of vertical movement allowed (fraction of the image height);
        :param position_shuffle: if true, candidate horizontal positions are tried in random order,
        otherwise from left to right.
        """

        # NOTE: kept as an assert so callers still see AssertionError when detectron2 is missing.
        assert DETECTRON_INSTALLED, 'Cannot use SegmentationMask without detectron2'
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
        self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
        self.predictor = DefaultPredictor(self.cfg)

        # Enum lookup accepts either a member or its raw value and rejects anything else.
        self.rigidness_mode = RigidnessMode(rigidness_mode)
        self.max_object_area = max_object_area
        self.min_mask_area = min_mask_area
        self.downsample_levels = downsample_levels
        self.num_variants_per_mask = num_variants_per_mask
        self.max_mask_intersection = max_mask_intersection
        self.max_foreground_coverage = max_foreground_coverage
        self.max_foreground_intersection = max_foreground_intersection
        self.max_hidden_area = max_hidden_area
        self.position_shuffle = position_shuffle

        self.max_scale_change = max_scale_change
        self.horizontal_flip = horizontal_flip
        self.max_vertical_shift = max_vertical_shift

    def get_segmentation(self, img):
        """Run the panoptic predictor on ``img`` (converted to uint8).

        :return: the ``(panoptic_seg, segments_info)`` pair stored under the
            predictor output key ``"panoptic_seg"``.
        """
        im = img_as_ubyte(img)
        panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"]
        return panoptic_seg, segment_info

    @staticmethod
    def _is_power_of_two(n):
        """Return True when the integer ``n`` is a non-zero power of two."""
        return (n != 0) and (n & (n-1) == 0)

    def identify_candidates(self, panoptic_seg, segments_info):
        """Return ids of "thing" segments that are small enough to be moved.

        A segment qualifies when its share of the label map is strictly below
        ``max_object_area``.

        :param panoptic_seg: label map; used through tensor methods
            (``.int().detach().cpu().numpy()``), so presumably a torch tensor.
        :param segments_info: iterable of dicts with ``"isthing"`` and ``"id"``.
        """
        potential_mask_ids = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = (panoptic_seg == segment["id"]).int().detach().cpu().numpy()
            area = mask.sum().item() / np.prod(panoptic_seg.shape)
            if area >= self.max_object_area:
                continue
            potential_mask_ids.append(segment["id"])
        return potential_mask_ids

    def downsample_mask(self, mask):
        """Downsample a label map by repeated ``zero_corrected_countless`` passes.

        The number of passes is ``log2(width) - downsample_levels``.

        :param mask: 2D integer label map whose sides are powers of two.
        :raises ValueError: if a side is not a power of two, the width is below
            ``2 ** downsample_levels``, or the height is too small for the
            required number of passes.
        """
        height, width = mask.shape
        if not (self._is_power_of_two(height) and self._is_power_of_two(width)):
            raise ValueError("Image sides are not power of 2.")

        num_iterations = width.bit_length() - 1 - self.downsample_levels
        if num_iterations < 0:
            raise ValueError(f"Width is lower than 2^{self.downsample_levels}.")

        if height.bit_length() - 1 < num_iterations:
            raise ValueError("Height is too low to perform downsampling")

        downsampled = mask
        for _ in range(num_iterations):
            downsampled = zero_corrected_countless(downsampled)

        return downsampled

    def _augmentation_params(self):
        """Draw random augmentation parameters from the global numpy RNG.

        :return: dict with ``"scaling_factor"`` (uniform in
            ``1 +- max_scale_change``), ``"horizontal_flip"`` (bool, always
            False when flips are disabled) and ``"vertical_shift"`` (uniform in
            ``+-max_vertical_shift``).
        """
        scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change)
        if self.horizontal_flip:
            horizontal_flip = bool(np.random.choice(2))
        else:
            horizontal_flip = False
        vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift)

        return {
            "scaling_factor": scaling_factor,
            "horizontal_flip": horizontal_flip,
            "vertical_shift": vertical_shift
        }

    def _get_intersection(self, mask_array, mask_object):
        """Return the element-wise AND of ``mask_object`` with its window in ``mask_array``.

        ``mask_object`` is expected to be already cropped to the canvas, since
        its bounds are used directly as slice indices.
        """
        intersection = mask_array[
            mask_object.up:mask_object.down, mask_object.left:mask_object.right
        ] & mask_object.mask
        return intersection

    def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks):
        """Check that ``aug_mask`` does not overlap previously chosen masks too much.

        For every mask in ``prev_masks`` two ratios are compared with
        ``max_mask_intersection``: the overlap relative to the existing mask,
        and ``1 - (visible area of aug_mask outside the overlap) /
        total_mask_area``.

        :param aug_mask: candidate ``ObjectMask`` (cropped to the canvas).
        :param total_mask_area: area of the candidate before any cropping.
        :param prev_masks: full-canvas boolean arrays to compare against.
        :return: True when every comparison stays within the limit.
        """
        for existing_mask in prev_masks:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area
            if (intersection_existing > self.max_mask_intersection) or \
               (intersection_current > self.max_mask_intersection):
                return False
        return True

    def _check_foreground_intersection(self, aug_mask, foreground):
        """Check that ``aug_mask`` neither covers nor sits on foreground objects too much.

        :param aug_mask: candidate ``ObjectMask`` (cropped to the canvas).
        :param foreground: full-canvas boolean arrays of foreground objects.
        :return: False as soon as the overlap exceeds ``max_foreground_coverage``
            of a foreground object or ``max_foreground_intersection`` of the
            candidate itself; True otherwise.
        """
        for existing_mask in foreground:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            if intersection_existing > self.max_foreground_coverage:
                return False
            intersection_mask = intersection_area / aug_mask.area()
            if intersection_mask > self.max_foreground_intersection:
                return False
        return True

    def _move_mask(self, mask, foreground):
        """Search for up to ``num_variants_per_mask`` valid placements of one object.

        Each variant draws random augmentation parameters, clamps them so that
        at most ``max_hidden_area`` of the object leaves the canvas, and then
        tries horizontal positions until one passes both intersection checks.
        The search stops at the first variant for which no position is valid.

        :param mask: full-canvas boolean array of the object (downsampled grid).
        :param foreground: full-canvas boolean arrays the placement must respect.
        :return: list of parameter dicts with keys ``"scaling_factor"``,
            ``"horizontal_flip"``, ``"vertical_shift"`` and
            ``"horizontal_shift"``; shifts are fractions of the canvas size.
        """
        # Obtaining properties of the original mask_object:
        orig_mask = ObjectMask(mask)

        chosen_masks = []
        chosen_parameters = []
        # to fix the case when resizing gives mask_object consisting only of False
        scaling_factor_lower_bound = 0.

        for _ in range(self.num_variants_per_mask):
            # Obtaining augmentation parameters and applying them to the downscaled mask_object
            augmentation_params = self._augmentation_params()
            # Cap the scale by the object's distance to the nearest canvas border on each axis.
            augmentation_params["scaling_factor"] = min([
                augmentation_params["scaling_factor"],
                2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1.,
                2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1.
            ])
            augmentation_params["scaling_factor"] = max([
                augmentation_params["scaling_factor"], scaling_factor_lower_bound
            ])

            aug_mask = deepcopy(orig_mask)
            aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True)
            if augmentation_params["horizontal_flip"]:
                aug_mask.horizontal_flip(inplace=True)
            total_aug_area = aug_mask.area()
            if total_aug_area == 0:
                # The object vanished after shrinking: forbid shrinking for the remaining variants.
                scaling_factor_lower_bound = 1.
                continue

            # Fix if the element vertical shift is too strong and shown area is too small:
            vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area  # share of area taken by rows
            # number of rows which are allowed to be hidden from upper and lower parts of image respectively
            max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area)
            max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area)
            # correcting vertical shift, so not too much area will be hidden
            augmentation_params["vertical_shift"] = np.clip(
                augmentation_params["vertical_shift"],
                -(aug_mask.up + max_hidden_up) / aug_mask.height,
                (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height
            )
            # Applying vertical shift:
            vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"]))
            aug_mask.shift(vertical=vertical_shift, inplace=True)
            aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True)

            # Choosing horizontal shift:
            # remaining hidden-area budget after what the vertical shift already hid
            max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area)
            horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area
            max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area)
            max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area)
            # candidate left-edge positions on the canvas ...
            allowed_shifts = np.arange(-max_hidden_left, aug_mask.width -
                                      (aug_mask.right - aug_mask.left) + max_hidden_right + 1)
            # ... converted to offsets relative to the current left edge
            allowed_shifts = - (aug_mask.left - allowed_shifts)

            if self.position_shuffle:
                np.random.shuffle(allowed_shifts)

            # The set of masks to stay away from is fixed while positions of this variant are tried.
            prev_masks = [mask] + chosen_masks

            mask_is_found = False
            for horizontal_shift in allowed_shifts:
                aug_mask_left = deepcopy(aug_mask)
                aug_mask_left.shift(horizontal=horizontal_shift, inplace=True)
                aug_mask_left.crop_to_canvas(inplace=True)

                is_mask_suitable = (
                    self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks)
                    and self._check_foreground_intersection(aug_mask_left, foreground)
                )
                if is_mask_suitable:
                    aug_draw = aug_mask_left.restore_full_mask()
                    chosen_masks.append(aug_draw)
                    augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width
                    chosen_parameters.append(augmentation_params)
                    mask_is_found = True
                    break

            if not mask_is_found:
                break

        return chosen_parameters

    def _prepare_mask(self, mask):
        """Resize a label map so that both sides are powers of two.

        A side that is not already a power of two is enlarged to the next one.
        Nearest-neighbour resizing (``order=0``) is used and the result is
        rounded and cast to int32.
        """
        height, width = mask.shape
        target_width = width if self._is_power_of_two(width) else (1 << width.bit_length())
        target_height = height if self._is_power_of_two(height) else (1 << height.bit_length())

        return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32')

    def get_masks(self, im, return_panoptic=False):
        """Generate relocated-object masks for an image.

        :param im: input image passed to ``get_segmentation``.
        :param return_panoptic: also return the panoptic label map as a numpy array.
        :return: list of full-resolution ``uint8`` masks whose mean is above
            ``min_mask_area``; with ``return_panoptic`` a
            ``(mask_set, panoptic_array)`` tuple.
        """
        panoptic_seg, segments_info = self.get_segmentation(im)
        potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info)

        # Convert the label map to numpy once and reuse it for every candidate.
        panoptic_np = panoptic_seg.detach().cpu().numpy()

        panoptic_seg_scaled = self._prepare_mask(panoptic_np)
        downsampled = self.downsample_mask(panoptic_seg_scaled)
        scene_objects = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = downsampled == segment["id"]
            if not np.any(mask):
                continue
            scene_objects.append(mask)

        mask_set = []
        for mask_id in potential_mask_ids:
            mask = downsampled == mask_id
            # Small objects can disappear entirely during downsampling.
            if not np.any(mask):
                continue

            if self.rigidness_mode is RigidnessMode.soft:
                foreground = [mask]
            elif self.rigidness_mode is RigidnessMode.rigid:
                foreground = scene_objects
            else:
                raise ValueError(f'Unexpected rigidness_mode: {self.rigidness_mode}')

            masks_params = self._move_mask(mask, foreground)

            full_mask = ObjectMask(panoptic_np == mask_id)

            # Replay the parameters found on the downsampled grid at full resolution.
            for params in masks_params:
                aug_mask = deepcopy(full_mask)
                aug_mask.rescale(params["scaling_factor"], inplace=True)
                if params["horizontal_flip"]:
                    aug_mask.horizontal_flip(inplace=True)

                vertical_shift = int(round(aug_mask.height * params["vertical_shift"]))
                horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"]))
                aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True)
                rendered = aug_mask.restore_full_mask().astype('uint8')
                if rendered.mean() <= self.min_mask_area:
                    continue
                mask_set.append(rendered)

        if return_panoptic:
            return mask_set, panoptic_np
        else:
            return mask_set


def propose_random_square_crop(mask, min_overlap=0.5):
    """Pick a random square crop spanning the short side of ``mask``.

    The crop side equals the shorter image side and the crop slides along the
    longer one. The sliding range is derived from the extent of the masked
    pixels and ``min_overlap`` and is clamped to the image; the start is drawn
    with ``np.random.randint``.

    :param mask: 2D array; values above 0.5 are treated as masked pixels.
    :param min_overlap: fraction of the masked extent used to place the
        sliding range.
    :return: ``(x0, y0, x1, y1)`` of the crop.
    """
    height, width = mask.shape
    rows, cols = np.where(mask > 0.5)  # mask==0 is known fragment and mask==1 is missing

    # The crop slides along the longer axis; square images take the vertical path.
    slide_horizontally = height < width
    if slide_horizontally:
        side, extent, coords = height, width, cols
    else:
        side, extent, coords = width, height, rows

    first, last = coords.min(), coords.max()
    anchor = first + (last - first) * min_overlap
    low = max(0, min(extent - side - 1, anchor - side))
    # Keep the range non-empty so randint always has at least one choice.
    high = max(low + 1, min(extent - side, anchor))
    start = np.random.randint(low, high)

    if slide_horizontally:
        return start, 0, start + side, height
    return 0, start, width, start + side
