import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim

def compute_iou(image1, image2, threshold=128):
    """
    Computes the Intersection over Union (IoU) between two images.

    Non-boolean images are binarized with ``pixel > threshold``. Boolean
    masks are used as-is: comparing ``True``/``False`` against a threshold
    such as the default 128 would otherwise mark every pixel as background.

    Args:
        image1 (ndarray): First grayscale or binary image.
        image2 (ndarray): Second grayscale or binary image.
        threshold (int): Threshold used to binarize non-boolean images.

    Returns:
        float: IoU score in [0, 1]; 0.0 when neither image has any
        foreground pixel (empty union).

    Raises:
        AssertionError: If the two images have different shapes.
    """
    assert image1.shape == image2.shape, f"Shape mismatch: {image1.shape} vs {image2.shape}"

    # Binarize explicitly; boolean masks are already binary.
    bin1 = image1 if image1.dtype == bool else image1 > threshold
    bin2 = image2 if image2.dtype == bool else image2 > threshold

    intersection = np.logical_and(bin1, bin2).sum()
    union = np.logical_or(bin1, bin2).sum()

    # Always return a Python float, including for the empty-union case.
    return float(intersection / union) if union != 0 else 0.0


def compute_mse(image1, image2):
    """
    Return the mean squared pixel difference between two same-shaped images.

    Both inputs are cast to float32 before subtracting, so unsigned integer
    images do not wrap around on negative differences.
    """
    assert image1.shape == image2.shape, f"Shape mismatch: {image1.shape} vs {image2.shape}"
    first = image1.astype(np.float32)
    second = image2.astype(np.float32)
    residual = np.subtract(first, second)
    return np.square(residual).mean()

def compute_ssim(image1, image2, data_range=None):
    """
    Computes the Structural Similarity Index (SSIM) between two images.

    3-channel images are treated as BGR and 4-channel images as BGRA
    (OpenCV channel order). Both are converted to grayscale before scoring;
    any other input is passed to scikit-image unchanged.

    Args:
        image1 (ndarray): First image (grayscale, BGR or BGRA).
        image2 (ndarray): Second image, same shape as ``image1``.
        data_range (float, optional): Dynamic range of the pixel values,
            forwarded to ``structural_similarity``. Leave as None for
            integer images. Floating-point images should set it explicitly;
            newer scikit-image releases raise without it.

    Returns:
        float: Mean SSIM score.

    Raises:
        AssertionError: If the two images have different shapes.
    """
    assert image1.shape == image2.shape, f"Shape mismatch: {image1.shape} vs {image2.shape}"
    if image1.ndim == 3 and image1.shape[2] in (3, 4):
        code = cv2.COLOR_BGR2GRAY if image1.shape[2] == 3 else cv2.COLOR_BGRA2GRAY
        gray1 = cv2.cvtColor(image1, code)
        gray2 = cv2.cvtColor(image2, code)
    else:
        gray1 = image1
        gray2 = image2
    # Only the mean score is used, so don't request the full SSIM map.
    return ssim(gray1, gray2, data_range=data_range)

def compute_feature_match_score(image1, image2):
    """
    Score how similar two images are by counting matched ORB descriptors.

    Images with a channel axis are converted from BGR to grayscale. Up to 500
    ORB keypoints are detected per image, and their binary descriptors are
    paired with a brute-force Hamming matcher using cross-checking. A pair is
    kept only if each descriptor is the other's nearest neighbour.

    Returns:
        int: Number of cross-checked matches, or 0 when either image yields
        no descriptors.
    """
    def _as_gray(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    gray_a = _as_gray(image1)
    gray_b = _as_gray(image2)

    detector = cv2.ORB_create(nfeatures=500)
    _, descriptors_a = detector.detectAndCompute(gray_a, None)
    _, descriptors_b = detector.detectAndCompute(gray_b, None)

    # detectAndCompute returns None descriptors when no keypoints are found.
    if descriptors_a is None or descriptors_b is None:
        return 0

    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    return len(matcher.match(descriptors_a, descriptors_b))

def isolate_red_shapes_from_rgb(image_rgb: np.ndarray, background_color=(255, 255, 255)) -> np.ndarray:
    """
    Keeps only red shapes in an RGB image and replaces everything else with a solid background color.

    A pixel counts as red when, in OpenCV's 8-bit HSV space, its hue lies in
    [0, 10] or [160, 180] and its saturation and value are both >= 100. Red
    wraps around the hue circle, hence the two ranges.

    Args:
        image_rgb (np.ndarray): Input image in RGB channel order, shape (H, W, 3).
            The HSV thresholds above are tuned for 8-bit (uint8) pixel values.
        background_color (tuple): RGB tuple for background (default: white).

    Returns:
        np.ndarray: uint8 image of the same shape with only red pixels retained,
        rest filled with ``background_color``.

    Raises:
        ValueError: If ``image_rgb`` is not a NumPy array of shape (H, W, 3).
    """
    # isinstance also rejects None.
    if not isinstance(image_rgb, np.ndarray):
        raise ValueError("Input must be a valid NumPy image array in RGB format.")
    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(
            f"Expected an RGB image of shape (H, W, 3), got shape {image_rgb.shape}."
        )

    # Convert directly; equivalent to the RGB->BGR->HSV round trip, with one pass fewer.
    hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

    # 8-bit OpenCV HSV stores hue as degrees / 2 (range 0-180), so red sits at both ends.
    low_red = cv2.inRange(hsv, np.array([0, 100, 100]), np.array([10, 255, 255]))
    high_red = cv2.inRange(hsv, np.array([160, 100, 100]), np.array([180, 255, 255]))
    is_red = cv2.bitwise_or(low_red, high_red) > 0

    # Start from a solid background and copy only the red pixels over it.
    result_rgb = np.full_like(image_rgb, background_color, dtype=np.uint8)
    result_rgb[is_red] = image_rgb[is_red]

    return result_rgb
