
import matplotlib.pyplot as plt
import numpy as np

from slot_attention.model.model_utils import to_rgb_from_np_array
# def visualize_segmentation_masks(img, true_masks):
#     """
#     Visualizes segmentation masks on top of an input image and returns the Figure object.

#     Args:
#         img (numpy.ndarray): Input image with shape (C, H, W), where C is the number of channels, and H, W are height and width.
#         true_masks (numpy.ndarray): Stack of true segmentation masks with shape (n_masks, H, W).

#     Returns:
#         matplotlib.figure.Figure: The Figure object containing the visualization.
#     """
#     C, H, W = img.shape
#     img = img.transpose(1, 2, 0)
#     n_masks, H, W = true_masks.shape
#     print(f"img.shape: {img.shape}")
#     print(f"true_masks.shape: {true_masks.shape}")

#     # Create a blank canvas for visualization
#     canvas = img.copy()

#     # Assign a unique color to each mask for visualization
#     colors = plt.cm.get_cmap('tab20', n_masks)

#     # Overlay each mask on the canvas with transparency
#     mask_alpha = 0.7  # Adjust the transparency level
#     for i in range(n_masks):
#         mask_color = colors(i)
#         mask = true_masks[i]
#         print(f"mask.shape: {mask.shape}")
#         print(f"mask_color: {mask_color}")
#         if sum(mask.flatten()) == 0:
#             continue
#         # canvas[mask > 0] = (1 - mask_alpha) * canvas[mask > 0] + mask_alpha * (mask_color[:3] * 255)
#         canvas[mask > 0] = (1 - mask_alpha) * canvas[mask > 0]
#         canvas[mask > 0] = mask_alpha * (mask_color[:3] * 255)

#     # Create a figure and display the input image with overlaid masks
#     fig, ax = plt.subplots()
#     ax.imshow(canvas)
#     ax.axis('off')

#     return fig

def visualize_segmentation_masks(img, true_masks):
    """
    Visualizes segmentation masks on top of an input image and returns the Figure object.

    The figure is a roughly square grid of panels. The first panel shows the
    image on its own; every following panel shows the image with one non-empty
    mask overlaid in red. A mask channel is considered empty (and is skipped)
    when its values do not sum to a positive number.

    Args:
        img (numpy.ndarray): Input image with shape (C, H, W). It is transposed
            to (H, W, C) and passed through ``to_rgb_from_np_array`` before
            being plotted.
        true_masks (numpy.ndarray): Stack of true segmentation masks with shape
            (mask_channels, H, W).

    Returns:
        matplotlib.figure.Figure: The Figure object containing the visualization.

    Raises:
        ValueError: If ``true_masks`` is not 3-dimensional.
    """
    # Unpacking doubles as a rank check: anything but a 3-D stack raises here.
    mask_channels, _height, _width = true_masks.shape
    img = to_rgb_from_np_array(img.transpose(1, 2, 0))

    # Scan *all* channels, so non-empty masks located after an empty channel
    # are still plotted. The same emptiness criterion is used for sizing the
    # grid and for plotting, so the two can never disagree.
    visible_masks = [
        true_masks[channel]
        for channel in range(mask_channels)
        if true_masks[channel].sum() > 0
    ]

    # One panel for the original image plus one per visible mask.
    n_panels = len(visible_masks) + 1

    # Calculate the number of rows and columns for the subplot grid
    num_rows = int(np.ceil(np.sqrt(n_panels)))
    num_cols = int(np.ceil(n_panels / num_rows))

    # squeeze=False guarantees a 2-D array of Axes even for 1x1 or Nx1 grids;
    # without it a single-column grid comes back 1-D and 2-D indexing fails.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(15, 15), squeeze=False
    )
    # Row-major flattening matches the (index // num_cols, index % num_cols)
    # panel placement.
    flat_axes = axes.flatten()

    # Plot the original image
    flat_axes[0].imshow(img)
    flat_axes[0].set_title("Original Image")

    # Plot each segmentation mask, numbering the panels from 1
    for panel_idx, mask in enumerate(visible_masks, start=1):
        ax = flat_axes[panel_idx]
        ax.imshow(img)
        ax.imshow(mask, alpha=0.5, cmap='Reds')
        ax.set_title(f"Mask {panel_idx}")

    # Remove empty subplots left over when the grid has more cells than panels
    for unused_ax in flat_axes[n_panels:]:
        fig.delaxes(unused_ax)

    # Adjust subplot spacing on this figure specifically, rather than on
    # whatever pyplot currently considers the "current" figure.
    fig.tight_layout()

    return fig
