import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

import cv2

# ARC colour palette: entry i is the display colour for cell value i when the
# colormap is used with vmin=0, vmax=9 (as paint_grid_with_boxes does).
digit_to_word_plus = [
    "#000000",  # 0 black
    "#0074D9",  # 1 blue
    "#FF4136",  # 2 red
    "#2ECC40",  # 3 green
    "#FFDC00",  # 4 yellow
    "#AAAAAA",  # 5 grey
    "#F012BE",  # 6 magenta
    "#FF851B",  # 7 orange
    "#7FDBFF",  # 8 light blue
    "#870C25",  # 9 maroon
]
# One lookup-table level per palette entry. With the default N=256 the map is
# a smooth gradient, and integer cell values are drawn as slightly blended
# shades of neighbouring palette colours instead of the exact colour.
arc_color_map = LinearSegmentedColormap.from_list(
    name='arc_colors',
    colors=digit_to_word_plus,
    N=len(digit_to_word_plus),
)

def _show_image(title, image, cmap, colorbar=False):
    """Show ``image`` in a new matplotlib figure titled ``title`` (debug view)."""
    plt.figure()
    plt.title(title)
    plt.imshow(image, cmap=cmap)
    if colorbar:
        plt.colorbar()
    plt.show()


def generate_input_type_ids_multi(grid, visualize=False):
    """
    Generate input_type_ids for ``grid`` from the contours of its objects.

    The most frequent value in the grid is assumed to be the background. Every
    other distinct value is treated as an object colour. For each colour, the
    external contours (``cv2.RETR_EXTERNAL``) of the matching cells are found,
    and the bounding rectangle of each contour is filled with a fresh ID.

    Parameters:
    - grid: 2D numpy array representing the grid.
    - visualize: If True, intermediate steps are shown with matplotlib.

    Returns:
    - input_type_ids: int32 array with the same shape as grid. Objects get IDs
      1, 2, ... in processing order, and 0 marks cells outside every object's
      bounding rectangle. Whole rectangles are filled, so background cells
      inside a rectangle also get that object's ID. Where rectangles overlap,
      the object processed later wins.
    """
    input_type_ids = np.zeros_like(grid, dtype=np.int32)
    if input_type_ids.size == 0:
        # Nothing to label; taking argmax of an empty count array would raise.
        return input_type_ids

    # Background = most frequent value. np.unique works for negative and
    # non-integer values, unlike np.bincount. Values come back sorted, so ties
    # resolve to the smallest value, exactly like bincount().argmax().
    values, counts = np.unique(grid, return_counts=True)
    background_value = values[counts.argmax()]
    object_values = values[values != background_value]

    current_object_id = 1  # 0 is reserved for "no object"
    for value in object_values:
        # Binary mask (uint8, 0/255) of the cells holding the current value.
        binary_image = np.where(grid == value, 255, 0).astype(np.uint8)
        if visualize:
            _show_image(f'Binary Image for Object Value {value}', binary_image, 'gray')

        contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if visualize:
            contour_image = np.zeros_like(binary_image)
            cv2.drawContours(contour_image, contours, -1, (255, 255, 255), thickness=cv2.FILLED)
            _show_image(f'Contours for Object Value {value}', contour_image, 'gray')

        for contour in contours:
            # boundingRect always lies inside the image, so no clamping is needed.
            x, y, w, h = cv2.boundingRect(contour)
            input_type_ids[y:y + h, x:x + w] = current_object_id
            current_object_id += 1

    if visualize:
        _show_image('Input Type IDs', input_type_ids, 'tab20', colorbar=True)

    return input_type_ids

def paint_grid_with_boxes(grid, input_type_ids):
    """
    Draw ``grid`` as an annotated heatmap and outline each labelled object.

    Parameters:
    - grid: 2D numpy array of cell values, drawn with ``arc_color_map``
      over the value range 0-9.
    - input_type_ids: array shaped like ``grid``. Cells sharing a non-zero
      value form one object, and 0 marks unlabelled cells (no box is drawn).
    """
    plt.figure(figsize=(10, 8))
    heatmap_ax = sns.heatmap(
        pd.DataFrame(grid),
        annot=True,
        fmt="d",
        linewidths=.5,
        xticklabels=False,
        yticklabels=False,
        cbar=False,
        cmap=arc_color_map,
        vmin=0,
        vmax=9,
    )

    palette_size = len(digit_to_word_plus)
    for label in np.unique(input_type_ids):
        if label == 0:
            continue  # unlabelled cells are not an object
        rows, cols = np.nonzero(input_type_ids == label)
        top, left = rows.min(), cols.min()
        height = rows.max() - top + 1
        width = cols.max() - left + 1
        # Anchored at (column, row): assumes seaborn places cell (r, c) at
        # [c, c+1] x [r, r+1] in data coordinates.
        outline = patches.Rectangle(
            (left, top), width, height,
            linewidth=2,
            edgecolor=digit_to_word_plus[label % palette_size],
            facecolor='none',
        )
        heatmap_ax.add_patch(outline)

    plt.show()

def test_opencv_contour():
    """
    Demo: label the objects in a small sample grid, then plot them with boxes.

    Calls plt.show() several times (visualize=True) and contains no
    assertions, so it is a visual check rather than an automated one.
    NOTE(review): the ``test_`` prefix means pytest will collect this.
    """
    # One string per grid row; each character is a cell value.
    rows = [
        "555555555555",
        "556658555555",
        "556888585555",
        "558585588855",
        "555555555555",
        "555555555555",
    ]
    sample = np.array([[int(ch) for ch in row] for row in rows])
    paint_grid_with_boxes(sample, generate_input_type_ids_multi(sample, visualize=True))


