import numpy as np
import cv2
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import h5py
import hdf5plugin
import os
from tqdm import tqdm
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def train_kmeans(X, n_clusters=5, random_state=None, init_method="k-means++", n_init="auto"):
    """Trains a K-means clustering model on standardized data.

    Features are standardized with ``StandardScaler`` before clustering, and the
    resulting centers are mapped back to the original feature scale.

    Args:
        X (numpy.ndarray): Input data matrix with samples as rows and features as columns.
        n_clusters (int, optional): Number of clusters to form. Capped at the
            number of samples, because KMeans rejects more clusters than samples.
            Defaults to 5.
        random_state (int or None, optional): Random seed for reproducibility. Defaults to None.
        init_method (str, optional): Initialization method for cluster centers. Defaults to "k-means++".
        n_init (str or int, optional): Number of initializations to perform. Defaults to "auto".

    Returns:
        tuple: A tuple containing cluster sizes and original-scale cluster centers.
            - cluster_sizes (numpy.ndarray): Size of each cluster, one entry per
              cluster (empty clusters count as 0).
            - centers_original (numpy.ndarray): Cluster centers in the original data scale.

    Raises:
        ValueError: If ``X`` contains no samples.
    """
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("train_kmeans: X contains no samples")
    # A tiny region can have fewer pixels than requested points; KMeans would
    # raise in that case, so use at most one cluster per sample.
    n_clusters = min(n_clusters, n_samples)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    kmeans = KMeans(
        n_clusters=n_clusters,
        init=init_method,
        n_init=n_init,
        max_iter=300,
        random_state=random_state
    )
    kmeans.fit(X_scaled)

    centers_original = scaler.inverse_transform(kmeans.cluster_centers_)
    # minlength keeps one entry per cluster even if trailing clusters ended up empty.
    cluster_sizes = np.bincount(kmeans.labels_, minlength=n_clusters)
    return cluster_sizes, centers_original



def split_areas(mask):
    """Break a mask into one filled mask per outer contour.

    Each external contour of ``mask`` is drawn filled (value 255) onto its own
    zero-initialised uint8 canvas. Because only external contours are used,
    holes inside a region are filled in the resulting sub-mask.

    Args:
        mask (numpy.ndarray): Input mask; non-zero pixels are foreground.

    Returns:
        list: One uint8 mask per region, ordered from largest to smallest
            contour area.
    """
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    by_area = sorted(outlines, key=cv2.contourArea, reverse=True)

    def _filled(outline):
        canvas = np.zeros_like(mask, dtype=np.uint8)
        cv2.drawContours(canvas, [outline], -1, 255, -1)
        return canvas

    return [_filled(outline) for outline in by_area]



def get_alpha_box(alpha, margin=0):
    """Calculates the bounding box of the foreground of an alpha mask.

    External contours with an area of 3 px² or less are treated as specks and
    ignored. The box of the remaining contours is grown by ``margin`` on every
    side and clipped to ``[0, width]`` x ``[0, height]``.

    Args:
        alpha (numpy.ndarray): Single-channel 8-bit mask; non-zero values mark
            regions of interest.
        margin (int, optional): Additional margin to add around the calculated bounding box. Defaults to 0.

    Returns:
        tuple: Bounding box coordinates (minx, miny, maxx, maxy) as integers.
            If every contour is a speck (e.g. the mask holds only 1-pixel-wide
            lines), the box of all non-zero pixels is used instead. An all-zero
            mask yields ``(0, 0, 0, 0)`` rather than an inverted box.
    """
    height, width = alpha.shape[:2]
    contours, _ = cv2.findContours(alpha, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    kept = [c for c in contours if cv2.contourArea(c) > 3]
    if kept:
        # Contour points have shape (N, 1, 2) in (x, y) order.
        points = np.concatenate(kept).reshape(-1, 2)
        xs, ys = points[:, 0], points[:, 1]
    else:
        # Nothing survived the area filter: fall back to raw pixels so callers
        # never receive an inverted (width, height, 0, 0) box.
        ys, xs = np.nonzero(alpha)
        if xs.size == 0:
            return 0, 0, 0, 0
    minx = int(np.clip(xs.min() - margin, 0, width))
    miny = int(np.clip(ys.min() - margin, 0, height))
    maxx = int(np.clip(xs.max() + margin, 0, width))
    maxy = int(np.clip(ys.max() + margin, 0, height))
    return minx, miny, maxx, maxy



def prob2logits(prob, eps=1e-5):
    """Converts probability values to logit values.

    Computes ``-log(1 / (p + eps) - 1 + eps)``, i.e. ``log(p / (1 - p))`` with
    ``eps`` terms that keep the result finite at ``p == 0`` and ``p == 1``
    (roughly -11.5 and +23.0 for the default ``eps``).

    Args:
        prob (numpy.ndarray or float): Probability values, expected in [0, 1].
            Values outside that range are clamped into it first.
        eps (float, optional): Small epsilon value to prevent division by zero. Defaults to 1e-5.

    Returns:
        numpy.ndarray or float: Converted logit values.
    """
    # Inputs slightly outside [0, 1] would make the log argument negative (NaN),
    # so clamp them. Values already in range are left untouched.
    prob = np.clip(prob, 0.0, 1.0)
    logits = - np.log(1 / (prob + eps) - 1 + eps)
    return logits



def get_points_logits(mask, num_points, input_point=None, logits=None):
    """Generates point prompts and mask logits from a mask for model prediction.

    Each output is only computed when the caller did not supply it, so passing
    both ``input_point`` and ``logits`` returns them unchanged without touching
    ``mask``.

    Args:
        mask (numpy.ndarray): 2-D mask; pixels ``> 0`` are foreground. When
            ``logits`` is None, the mask is divided by 255, so it is assumed
            to hold 0..255 values.
        num_points (int): Number of K-means centers to use as point prompts.
        input_point (numpy.ndarray or None, optional): Pre-defined point
            coordinates in (x, y) order. If None, the points are the K-means
            centers of the foreground pixel coordinates. Defaults to None.
        logits (numpy.ndarray or None, optional): Pre-defined logits. If None,
            the mask is resized to 256x256, scaled to [0, 1] and converted with
            ``prob2logits``, giving shape (1, 256, 256). Defaults to None.

    Returns:
        tuple: A tuple containing logits and point coordinates.
            - logits (numpy.ndarray): The supplied logits, or logits derived from the mask.
            - input_point (numpy.ndarray): The supplied points, or K-means centers
              of the mask's foreground.
    """
    if input_point is None:
        # np.where yields (rows, cols); reverse to (x, y) pixel coordinates.
        np_points = np.stack(np.where(mask > 0)[::-1], -1)
        _, input_point = train_kmeans(np_points, num_points, random_state=0)
    if logits is None:
        input_mask = cv2.resize(mask, (256, 256))
        input_mask = input_mask.astype(float)[None] / 255
        logits = prob2logits(input_mask)
    return logits, input_point



def remove_mask_noise_points(mask, size=3):
    """Clean a [0, 1] mask with a morphological opening followed by a closing.

    The opening (erode, then dilate) removes small isolated specks. The closing
    (dilate, then erode) then fills small holes and gaps. Both steps use an
    elliptical structuring element of diameter ``2 * size + 1``, applied to an
    8-bit copy of the mask.

    Args:
        mask (numpy.ndarray): Mask with values in [0, 1].
        size (int, optional): Radius of the elliptical kernel. Defaults to 3.

    Returns:
        numpy.ndarray: Cleaned float mask with values in [0, 1].
    """
    diameter = 2 * size + 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
    as_u8 = (mask * 255).astype(np.uint8)
    opened = cv2.morphologyEx(as_u8, cv2.MORPH_OPEN, ellipse)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, ellipse)
    return closed.astype(float) / 255


class SamRefiner:
    """Refines segmentation masks with a SAM2 image predictor.

    ``refine`` splits the mask into connected regions and crops each region
    (plus a margin) from the image. ``run`` then prompts SAM2 on the crop with
    K-means point prompts and the region as a low-resolution mask prior. The
    predicted logits are fed back as the next prior until consecutive
    predictions agree.

    Args:
        checkpoint (str, optional): Path of the SAM2 checkpoint to load.
            Defaults to "./sam2.1_hiera_large.pt".
        model_cfg (str, optional): Config name passed to ``build_sam2``.
            Defaults to "configs/sam2.1/sam2.1_hiera_l.yaml".
        device (str or None, optional): Torch device. ``None`` selects CUDA
            when available, otherwise CPU.
    """

    def __init__(
        self,
        checkpoint="./sam2.1_hiera_large.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
        device=None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        sam2 = build_sam2(model_cfg, checkpoint, device=device)
        self.model = SAM2ImagePredictor(sam2)
        # Not read anywhere in this class; kept for code that may access it.
        self.masks = None

    def set_image(self, image):
        """Sets the image for the SAM model to process.

        Args:
            image (numpy.ndarray): Image passed to ``SAM2ImagePredictor.set_image``.
                NOTE(review): SAM2 documents RGB input here; confirm that callers
                pass RGB.
        """
        self.model.set_image(image)

    def run(self, mask, max_iter=15, thresh=0.99, num_points=10, dilate_size=3):
        """Runs iterative refinement on a mask using the image set by ``set_image``.

        Point prompts and the initial logit prior are derived once from
        ``mask``. Each iteration feeds the previous low-resolution logits back
        to the predictor as ``mask_input``. Refinement stops after ``max_iter``
        iterations, or once the IoU between consecutive predictions exceeds
        ``thresh``.

        Args:
            mask (numpy.ndarray): Initial 2-D mask to refine (0..255 values).
            max_iter (int, optional): Maximum number of refinement iterations. Defaults to 15.
            thresh (float, optional): IoU threshold to stop refinement early. Defaults to 0.99.
            num_points (int, optional): Number of points to sample from the mask. Defaults to 10.
            dilate_size (int, optional): Kernel radius for noise removal; 0 disables it. Defaults to 3.

        Returns:
            numpy.ndarray: Refined 2-D mask. On an early stop this is the
            previous iteration's mask, which overlaps the latest prediction
            with IoU > ``thresh``. Returns all zeros if ``max_iter`` is 0 or
            the first prediction is empty.
        """
        logits, input_point = get_points_logits(mask, num_points)
        # The point prompts never change across iterations; only the mask prior does.
        input_label = np.ones(len(input_point))
        last_mask = np.zeros((1,) + mask.shape[:2])
        for _ in range(max_iter):
            masks, _scores, logits = self.model.predict(
                point_coords=input_point,
                point_labels=input_label,
                mask_input=logits,
                multimask_output=False,
            )
            prev_fg = last_mask > 0.5
            curr_fg = masks > 0.5
            inter = np.logical_and(prev_fg, curr_fg).sum()
            union = np.logical_or(prev_fg, curr_fg).sum()
            # Two empty masks are identical; avoid 0/0 -> NaN (which never converges).
            iou = inter / union if union else 1.0
            if dilate_size > 0:
                masks[0] = remove_mask_noise_points(masks[0], size=dilate_size)
            if iou > thresh:
                break
            last_mask = masks
        return last_mask[0]

    @staticmethod
    def _expand_box(box, margin, width, height):
        """Grow ``box`` by 10% of its size per axis (capped at ``margin``), clipped to the image."""
        minx, miny, maxx, maxy = box
        margin_x = min(int((maxx - minx) * 0.1), margin)
        margin_y = min(int((maxy - miny) * 0.1), margin)
        return (
            max(minx - margin_x, 0),
            max(miny - margin_y, 0),
            min(maxx + margin_x, width),
            min(maxy + margin_y, height),
        )

    def refine(self, image, mask, margin=20):
        """Refines a segmentation mask by processing each detected region separately.

        Regions covering less than 0.2% of the mask are copied through
        unrefined. All other regions are refined on their own image crop, and
        the results are merged.

        Args:
            image (numpy.ndarray): Input image corresponding to the mask.
            mask (numpy.ndarray): Segmentation mask to refine (non-zero = foreground).
            margin (int, optional): Upper bound on the extra context added around
                each region's box. Defaults to 20.

        Returns:
            numpy.ndarray: Final refined uint8 mask with values in {0, 255}.
        """
        height, width = image.shape[:2]
        total_area = mask.size
        mask_final = np.zeros_like(mask, dtype=int)
        for sub_mask in split_areas(mask):
            region = (sub_mask > 1).astype(int)
            if region.sum() / total_area < 0.002:
                mask_final = mask_final + region
                continue
            x1, y1, x2, y2 = self._expand_box(
                get_alpha_box(sub_mask, 10), margin, width, height
            )
            if x2 <= x1 or y2 <= y1:
                # No usable box (nothing survived get_alpha_box's speck filter):
                # keep the region unrefined rather than feeding SAM an empty crop.
                mask_final = mask_final + region
                continue
            crop_image = image[y1:y2, x1:x2].copy()
            crop_mask = sub_mask[y1:y2, x1:x2].copy()
            self.set_image(crop_image)
            # Morphological clean-up is skipped on small crops.
            dilate_size = 0 if min(crop_image.shape[:2]) < 200 else 3
            crop_mask_new = self.run(
                crop_mask, max_iter=10, thresh=0.999, num_points=20, dilate_size=dilate_size
            )
            mask_final[y1:y2, x1:x2] = mask_final[y1:y2, x1:x2] + crop_mask_new
        return ((mask_final > 0) * 255).astype(np.uint8)



def draw(image, mask):
    """Return a copy of ``image`` with the outer contours of ``mask`` outlined.

    The original image is not modified. Contours are drawn in (0, 255, 0)
    with a 1-pixel line.

    Args:
        image (numpy.ndarray): Image to annotate.
        mask (numpy.ndarray): Mask whose external contours are drawn.

    Returns:
        numpy.ndarray: Annotated copy of ``image``.
    """
    outlines, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    canvas = image.copy()
    cv2.drawContours(canvas, outlines, -1, (0, 255, 0), 1)
    return canvas


if __name__ == "__main__":
    sam = SamRefiner()
    hdf5_read_path = "./image_masks.h5"
    clevel = 6  # Zstd compression level for the refined-mask datasets

    # Opened in append mode: refined masks are written back into each group
    # under "mask_refine/<lower-cased caption>". Captions already present there
    # are skipped, so an interrupted run can simply be restarted.
    with h5py.File(hdf5_read_path, 'a') as h5_file:
        for group_name in tqdm(sorted(h5_file.keys())):
            group = h5_file[group_name]
            image_ds = group["image/image"]
            # Images without a "mode" attribute are assumed to be BGR.
            mode = image_ds.attrs.get("mode", "BGR")
            if isinstance(mode, bytes):
                # Fixed-length string attributes come back from h5py as bytes.
                mode = mode.decode()
            image = image_ds[:]
            # SAM2ImagePredictor.set_image documents RGB input, so normalise
            # colour images to RGB before refinement.
            if mode != "RGB" and image.ndim == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            group_mask = group["mask"]
            if "mask_refine" not in group:
                group_mask_refine = group.create_group("mask_refine")
            else:
                try:
                    group_mask_refine = group["mask_refine"]
                except Exception as e:
                    # The object exists but cannot be opened (e.g. a broken link):
                    # rebuild it. Ctrl-C / SystemExit are deliberately not caught.
                    print(f"{group_name} mask_refine unreadable ({e!r}), recreating")
                    del group["mask_refine"]
                    group_mask_refine = group.create_group("mask_refine")
            if len(group_mask) == 0:
                print(f"{group_name} is empty")

            for mask_key in group_mask:
                # HDF5 names are case-sensitive: read with the stored key, but
                # store the refined mask under the lower-cased caption.
                mask_caption = mask_key.lower()
                if mask_caption in group_mask_refine:
                    continue
                mask_node = group_mask[mask_key]
                if not isinstance(mask_node, h5py.Dataset):
                    continue
                mask = mask_node[:]
                try:
                    mask_refine = sam.refine(image, mask)
                except (RuntimeError, ValueError, cv2.error) as e:
                    # Best effort: keep the unrefined mask and move on.
                    print(f"{group_name} {mask_caption} refine failed: {e!r}")
                    mask_refine = mask.copy()
                group_mask_refine.create_dataset(
                    mask_caption,
                    data=mask_refine,
                    **hdf5plugin.Zstd(clevel=clevel),
                    chunks=True,
                )
