import numpy as np
import cv2
from shapely.geometry import Polygon
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt

def load_data(bbox_file, scores_file, masks_file, image_file):
    """Load the detection arrays and the image they belong to.

    Args:
        bbox_file: Path of the ``.npy`` file holding the boxes.
        scores_file: Path of the ``.npy`` file holding the scores.
        masks_file: Path of the ``.npy`` file holding the masks.
        image_file: Path of the image, read with ``cv2.imread``.

    Returns:
        Tuple ``(bboxes, scores, masks, image)``.
    """
    # NOTE(review): cv2.imread is documented to return None instead of raising
    # when the file cannot be read; callers may want to check for that.
    bboxes, scores, masks = (
        np.load(path) for path in (bbox_file, scores_file, masks_file)
    )
    return bboxes, scores, masks, cv2.imread(image_file)

def calculate_area(bbox):
    """Return ``(x2 - x1) * (y2 - y1)`` for a box laid out as ``[x1, y1, x2, y2, ...]``.

    Only the first four entries are read; extra trailing entries are ignored.
    No clamping is applied, so an inverted box yields a non-positive value.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height

def calculate_mbr(bboxes):
    """Return the minimum bounding rectangle that encloses every box in *bboxes*.

    Args:
        bboxes: 2-D array whose columns 0-3 are ``x1, y1, x2, y2``; any further
            columns are ignored.

    Returns:
        ``[x_min, y_min, x_max, y_max]``, where the minima come from columns 0/1
        and the maxima from columns 2/3.
    """
    lows = np.min(bboxes[:, 0:2], axis=0)
    highs = np.max(bboxes[:, 2:4], axis=0)
    return [lows[0], lows[1], highs[0], highs[1]]

def calculate_iou(box1, box2):
    """Return the IoU of two boxes and the fraction of *box2* covered by *box1*.

    Boxes are laid out as ``[x1, y1, x2, y2, ...]``; only the first four
    entries are read.

    Args:
        box1: First box.
        box2: Second box; the cover ratio is measured relative to its area.

    Returns:
        ``(iou, cover_ratio)`` where ``iou = inter / (area1 + area2 - inter)``
        and ``cover_ratio = inter / area2``. When a denominator is not positive
        (degenerate or zero-area boxes), the corresponding ratio is reported as
        0.0 instead of raising ``ZeroDivisionError`` or producing NaN.
    """
    inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = inter_w * inter_h

    # Same formula as calculate_area, inlined so this helper is self-contained.
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union = area1 + area2 - intersection
    iou = intersection / union if union > 0 else 0.0
    cover_ratio = intersection / area2 if area2 > 0 else 0.0
    return iou, cover_ratio

def filter_bboxes(bboxes, scores, masks, cover_threshold=0.67):
    """Drop large boxes that mostly enclose the boxes smaller than them.

    Boxes are visited from largest to smallest area. For the current box, the
    minimum bounding rectangle (MBR) of all smaller boxes is computed and the
    cover ratio ``intersection(current, MBR) / area(MBR)`` is obtained from
    ``calculate_iou``. If the ratio is above ``cover_threshold``, the current
    box is discarded and the scan moves on. Otherwise the current box and
    every smaller box are kept and the scan stops. The smallest box is always
    kept if the scan reaches it.

    Args:
        bboxes: Array of boxes, one per row, with columns 0-3 read as
            ``x1, y1, x2, y2`` (via ``calculate_area`` / ``calculate_mbr``).
        scores: Array indexed in parallel with *bboxes*.
        masks: Array indexed in parallel with *bboxes*.
        cover_threshold: Cover ratio above which a box is treated as an
            enclosing box and dropped. The default matches the previously
            hard-coded value.

    Returns:
        ``(bboxes, scores, masks, indices)`` restricted to the kept entries.
        ``indices`` is a list of positions into the inputs, ordered from largest
        to smallest box area. All outputs are empty when the input is empty.
    """
    areas = np.array([calculate_area(bbox) for bbox in bboxes])
    order = np.argsort(areas)[::-1]  # largest area first

    filtered_indices = []
    for pos, idx in enumerate(order):
        remaining = bboxes[order[pos + 1:]]

        if len(remaining) == 0:
            # Nothing smaller left to compare against: keep the last box.
            filtered_indices.append(idx)
            break

        mbr = calculate_mbr(remaining)
        _, cover_ratio = calculate_iou(bboxes[idx], mbr)
        if cover_ratio <= cover_threshold:
            # The current box does not swallow the rest: keep it and all smaller boxes.
            filtered_indices.extend(order[pos:])
            break

    keep = np.asarray(filtered_indices, dtype=np.intp)
    return bboxes[keep], scores[keep], masks[keep], filtered_indices

def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks.

    Elements are treated as foreground when truthy, following
    ``np.logical_and`` / ``np.logical_or`` semantics.

    Returns:
        ``|mask1 AND mask2| / |mask1 OR mask2|`` as a float. If both masks are
        empty, the union is zero and 0.0 is returned instead of NaN (the
        original 0/0 also emitted a RuntimeWarning).
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    return float(intersection / union)

def nms_masks(masks, scores, iou_threshold=0.4):
    """Greedy mask non-maximum suppression.

    The highest-scoring remaining mask is kept, and every remaining mask whose
    IoU with it (via ``mask_iou``) is greater than ``iou_threshold`` is
    discarded. This repeats until no masks remain.

    Args:
        masks: Array indexed as ``masks[k, 0]`` to get the k-th 2-D mask
            (presumably shape (N, 1, H, W) -- TODO confirm with callers).
        scores: 1-D array of N scores; higher is better.
        iou_threshold: Masks with IoU <= this value survive each round.

    Returns:
        Integer NumPy array of kept indices, highest score first. The array is
        still integer-typed when empty, so it can be used for indexing directly.
    """
    order = scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        best = order[0]
        keep.append(best)

        rest = order[1:]
        ious = np.array(
            [mask_iou(masks[best, 0], masks[j, 0]) for j in rest], dtype=float
        )
        # Only candidates that do not overlap the kept mask too much survive.
        order = rest[ious <= iou_threshold]

    # Explicit dtype: np.array([]) would be float64 and unusable as an index.
    return np.array(keep, dtype=np.intp)

def draw_masks(image, masks, alpha=0.5):
    """Overlay every mask on *image* in a random colour.

    Args:
        image: Image array. The overlay is built with ``np.zeros_like(image)``,
            so it shares the image's shape and dtype. A 3-component colour is
            painted, so a 3-channel image is assumed.
        masks: Iterable of masks, each read as ``mask[0]`` (presumably shape
            (N, 1, H, W) -- TODO confirm). Pixels with value > 0.5 are painted.
            Where masks overlap, later masks overwrite earlier ones.
        alpha: Weight given to the colour overlay in ``cv2.addWeighted``; the
            original image keeps weight 1.

    Returns:
        The blended image produced by ``cv2.addWeighted``.
    """
    overlay = np.zeros_like(image)
    for mask in masks:
        # randint's upper bound is exclusive: use 256 so 255 is reachable.
        color = np.random.randint(0, 256, size=3)
        overlay[mask[0] > 0.5] = color

    return cv2.addWeighted(image, 1, overlay, alpha, 0)

def suppress_redundant_masks(bboxes, scores, masks):
    """Run the full suppression pipeline and return the surviving masks.

    Pipeline:
      1. ``filter_bboxes`` drops large boxes that enclose most smaller boxes.
      2. ``nms_masks`` performs mask NMS on what is left.
      3. ``rectify_intersection`` makes the kept masks pairwise disjoint.
      4. ``erode_masks`` erodes them after axis 1 is squeezed out.

    Returns:
        ``(final_masks, nms_indices)``. ``final_masks`` are the eroded boolean
        masks. ``nms_indices[k]`` is the position of ``final_masks[k]`` in the
        original ``bboxes`` / ``scores`` / ``masks`` inputs.
    """
    _, filtered_scores, filtered_masks, filtered_indices = filter_bboxes(
        bboxes, scores, masks
    )
    # Force integer indices: an empty float array cannot be used to index.
    keep_indices = np.asarray(
        nms_masks(filtered_masks, filtered_scores), dtype=np.intp
    )

    final_masks = rectify_intersection(filtered_masks[keep_indices])
    final_masks = erode_masks(final_masks.squeeze(axis=1))

    # Map positions in the NMS output back to positions in the original inputs.
    nms_indices = np.asarray(filtered_indices, dtype=np.intp)[keep_indices]
    return final_masks, nms_indices


def filter_overlapping_masks(self, masks, area_threshold=0.90, reverse_area_threshold=0.20):
    """Resolve nested masks by keeping either the container or the contained mask.

    Masks are visited from largest to smallest area, and ties keep their input
    order. Each candidate is compared with every mask accepted so far; those
    masks are all at least as large. When the overlap covers at least
    ``area_threshold`` of the smaller mask (the candidate lies mostly inside
    an accepted mask):

    * if the overlap also covers at least ``reverse_area_threshold`` of the
      larger mask, the two are treated as duplicates and the candidate is
      dropped;
    * otherwise the candidate is only a small part of the larger mask, and
      the larger (accepted) mask is removed instead.

    Masks with zero area never trigger either rule, so they are always kept.

    Args:
        self: Unused; kept for interface compatibility (this looks lifted from
            a class method -- TODO confirm whether it can be dropped).
        masks: Sequence of equally shaped boolean masks (NumPy arrays).
        area_threshold: Minimum ``intersection / smaller_area`` for the
            candidate to count as nested inside an accepted mask.
        reverse_area_threshold: Minimum ``intersection / larger_area`` for the
            nested pair to count as duplicates (candidate dropped) rather than
            container/part (container dropped).

    Returns:
        A list of the surviving masks (the objects taken from *masks*) in the
        order they were accepted. An empty input yields an empty list.
    """
    # NOTE: earlier tuned values were area_threshold=0.90, reverse_area_threshold=0.50.
    masks = list(masks)
    if not masks:
        return []

    # Compute each area once instead of re-summing inside the pairwise loop.
    areas = [np.sum(mask) for mask in masks]
    # Stable sort, largest first; equal areas keep their input order.
    order = sorted(range(len(masks)), key=lambda k: areas[k], reverse=True)

    accepted = [order[0]]  # indices into ``masks``
    for cand in order[1:]:
        cand_mask = masks[cand]
        cand_area = areas[cand]
        to_remove = set()
        duplicate = False

        for pos, k in enumerate(accepted):
            smaller_area = min(cand_area, areas[k])
            if smaller_area == 0:
                # An empty mask has an undefined overlap ratio; it never matches.
                continue
            intersection = np.sum(cand_mask & masks[k])
            if intersection / smaller_area >= area_threshold:
                larger_area = max(cand_area, areas[k])
                if intersection / larger_area >= reverse_area_threshold:
                    duplicate = True
                else:
                    to_remove.add(pos)

        if to_remove:
            accepted = [k for pos, k in enumerate(accepted) if pos not in to_remove]
        if not duplicate:
            accepted.append(cand)

    return [masks[k] for k in accepted]

def rectify_intersection(masks):
    """Make masks pairwise disjoint by removing each overlap from the smaller mask.

    Pairs ``(i, j)`` with ``i < j`` are visited in index order. When two masks
    overlap, the shared pixels are cleared from the mask with the smaller
    current area; on a tie they are cleared from mask ``j``. Masks are updated
    as the loop proceeds, so later comparisons see already-rectified masks.

    Args:
        masks: Array of shape ``(m, 1, height, width)``. Its dtype must support
            ``&`` and ``~``; boolean is presumably expected -- TODO confirm.

    Returns:
        A new array of shape ``(m, 1, height, width)``. The input array is left
        unmodified. Previously ``reshape`` returned a view and the caller's
        array was overwritten in place.
    """
    m, _, height, width = masks.shape
    # Drop the singleton channel axis; copy so the caller's data is not mutated.
    planes = masks.reshape(m, height, width).copy()

    for i in range(m):
        for j in range(i + 1, m):
            overlap = planes[i] & planes[j]
            if not np.any(overlap):
                continue

            # Clear the shared pixels from the smaller mask (mask j loses ties).
            if np.sum(planes[i]) < np.sum(planes[j]):
                planes[i] = planes[i] & ~overlap
            else:
                planes[j] = planes[j] & ~overlap

    return planes.reshape(m, 1, height, width)

def erode_masks(masks):
    """Erode every mask with a 2x2 all-ones kernel and return boolean results.

    Args:
        masks: Array whose first axis indexes masks. Each ``masks[i]`` is
            converted to uint8 and passed to ``cv2.erode``.

    Returns:
        Boolean array with the same shape as *masks*.
    """
    structuring = np.ones((2, 2), np.uint8)
    # Preallocate as uint8 so the cv2 output can be stored directly.
    result = np.zeros_like(masks, dtype=np.uint8)
    for idx, plane in enumerate(masks):
        result[idx] = cv2.erode(plane.astype(np.uint8), structuring)
    return result.astype(bool)

def remove_1pixel_border(masks):
    """Zero the outermost row and column on every side of each mask, in place.

    Assumes *masks* is indexed as ``masks[:, row, col]`` (presumably shape
    (N, H, W) -- TODO confirm). The same array object is returned for
    convenience.
    """
    edges = (
        np.s_[:, 0, :],   # top row
        np.s_[:, -1, :],  # bottom row
        np.s_[:, :, 0],   # left column
        np.s_[:, :, -1],  # right column
    )
    for edge in edges:
        masks[edge] = 0
    return masks
