import math
import matplotlib
import numpy as np
from home_robot.core.interfaces import Observations
from typing import Any, Dict, List, Optional, Tuple
from sklearn.cluster import DBSCAN
import cv2
import matplotlib.pyplot as plt
import torch
import imageio
from vlfm.utils.img_utils import reorient_rescale_map
import matplotlib.patches as mpatches
import time

def resize_images(images: List[np.ndarray], match_dimension: str = "height", use_max: bool = True, idx_to_use: int = 0) -> List[np.ndarray]:
    """
    Resize images to match either their heights or their widths.

    The other dimension of each image is scaled by the same factor (truncated to
    an integer), so aspect ratios are approximately preserved.

    Args:
        images (List[np.ndarray]): List of NumPy images.
        match_dimension (str): Specify 'height' to match heights, or 'width' to match
            widths.
        use_max (bool): Only used when matching widths. If True, every image is
            resized to the largest width in the list, otherwise to the smallest.
        idx_to_use (int): Only used when matching heights. Index of the image whose
            height becomes the common height.

    Returns:
        List[np.ndarray]: List of resized images. A list with zero or one image is
            returned as-is (same list object, nothing is resized).

    Raises:
        ValueError: If `match_dimension` is neither 'height' nor 'width'. This is
            only checked when the list holds at least two images.
    """
    # With fewer than two images there is nothing to match against. The empty
    # case is handled here as well, because indexing / max() below would fail.
    if len(images) <= 1:
        return images

    if match_dimension == "height":
        new_height = images[idx_to_use].shape[0]
        # cv2.resize takes the target size as (width, height).
        return [
            cv2.resize(img, (int(img.shape[1] * new_height / img.shape[0]), new_height))
            for img in images
        ]

    if match_dimension == "width":
        pick = max if use_max else min
        new_width = pick(img.shape[1] for img in images)
        return [
            cv2.resize(img, (new_width, int(img.shape[0] * new_width / img.shape[1])))
            for img in images
        ]

    raise ValueError("Invalid 'match_dimension' argument. Use 'height' or 'width'.")

def apply_mask(semantic_img, masks):
    """
    Tint the masked regions of an image, modifying it in place.

    For each of the first three masks, every channel of the selected pixels is
    halved and 127 is then added to the channel with the same index as the mask:
    mask 0 boosts channel 0, mask 1 boosts channel 1, mask 2 boosts channel 2
    (i.e. red / green / blue if the image is in RGB order).

    Masks are processed in the order 1, 0, 2, so pixels covered by several masks
    are tinted several times in that order.

    Args:
        semantic_img: Image array indexed as [row, col, channel] with at least
            three channels. Presumably uint8, in which case halving before
            adding 127 keeps the result within range.
        masks: Sequence of at least three index arrays (boolean masks in the
            callers visible in this file) over the first two image dimensions.

    Returns:
        The same `semantic_img` object that was passed in.
    """
    # Mask index doubles as the index of the channel that gets boosted.
    for mask_idx in (1, 0, 2):
        region = masks[mask_idx]
        for channel in range(3):
            plane = semantic_img[:, :, channel]  # view: writes go to semantic_img
            plane[region] //= 2
            if channel == mask_idx:
                plane[region] += 127

    return semantic_img

def apply_mask_abs(semantic_img, masks):
    """
    Paint the masked regions of an image with solid colors, modifying it in place.

    Pixels selected by mask 0 are set to (255, 0, 0), mask 1 to (0, 255, 0) and
    mask 2 to (0, 0, 255) in the first three channels. Masks are applied in
    that order, so where masks overlap the later one wins.

    Args:
        semantic_img: Image array indexed as [row, col, channel] with at least
            three channels.
        masks: Sequence of at least three index arrays (presumably boolean
            masks) over the first two image dimensions.

    Returns:
        The same `semantic_img` object that was passed in.
    """
    palette = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
    for mask_idx, color in enumerate(palette):
        region = masks[mask_idx]
        for channel, value in enumerate(color):
            semantic_img[:, :, channel][region] = value

    return semantic_img


def process_and_write_frame(save_loc, obs, image_dict: dict, text_dict: dict, n_repeats: int=0):
    """
    Write a frame to `save_loc` and, optionally, additional copies of it.

    The frame is rendered once to `save_loc`, then rendered again `n_repeats`
    times to `<save_loc without extension>_<k>.png` for k = 1..n_repeats.

    Args:
        save_loc: Path of the primary output image.
        obs: Observation forwarded unchanged to `write_frame`.
        image_dict (dict): Images forwarded unchanged to `write_frame`.
        text_dict (dict): Texts forwarded unchanged to `write_frame`.
        n_repeats (int): Number of extra copies to write; values <= 0 write none.
    """
    # Function-scope import so this block does not rely on the file's import list.
    import os

    write_frame(save_loc, obs, image_dict, text_dict)

    # Strip the real extension (whatever its length) instead of assuming that
    # the path ends in exactly four characters such as ".png".
    stem = os.path.splitext(save_loc)[0]
    for i in range(n_repeats):
        write_frame(f'{stem}_{i+1}.png', obs, image_dict, text_dict)

def write_frame(save_loc, obs, image_dict, text_dict):
    """
    Compose one debug frame from several images and texts and write it to disk.

    Layout of the written frame, top to bottom:
      * Header: three text blocks (instruction, current skill, gaze target)
        with a color legend pasted over the right half of the header.
      * Row 1: depth, masked RGB, 3DGS segmentation, 3DGS instances,
        3DGS uncertainty.
      * Row 2: 3DGS depth, third-person view, occupancy map, value map.
    The narrower row is padded with black on the right.

    Args:
        save_loc: Output path handed to `imageio.imwrite`.
        obs: Observation with a `task_observations` mapping. If that mapping
            contains 'instance_map', the entries 'instance_classes' and
            'instance_scores' are read as well and per-instance scores are
            drawn on the RGB image. Otherwise `obs.semantic` is compared with
            the 'object_goal', 'start_recep_goal' and 'end_recep_goal' entries
            to build the three masks.
        image_dict: Mapping with images under 'rgb', 'depth', '3dgs_seg',
            '3dgs_instances', '3dgs_uncertainty', '3dgs_depth', 'third_person',
            'occ' and 'value'. Presumably all 3-channel uint8 images (the
            padding below assumes that) -- TODO confirm. As a side effect the
            masked RGB image is stored under 'rgb_mask'.
        text_dict: Mapping with strings under 'value_map', 'gaze_target',
            'current_skill', 'instruction', 'goal_obj', 'start_rec' and
            'end_rec'.
    """
    semantic_img = image_dict['rgb'].copy()
    task_obs = obs.task_observations

    if 'instance_map' in task_obs:
        # One boolean mask per highlighted class (classes 1..3 -> masks 0..2).
        masks = [np.full(semantic_img.shape[:2], False) for _ in range(3)]
        texts = []
        positions = []
        for i_instance, (instance_class, instance_score) in enumerate(
            zip(task_obs['instance_classes'], task_obs['instance_scores'])
        ):
            if instance_class == 4:
                instance_class = 0  # set misc and others as same class
            if instance_class > 0:
                mask = task_obs['instance_map'] == i_instance
                masks[instance_class-1] += mask

                # Anchor the score label at the right-most pixel of the
                # instance (top-most among ties, since np.nonzero is row-major),
                # clamped so the text stays inside the image.
                nz = np.nonzero(mask)
                if len(nz[1]) > 0:
                    idx = np.argmax(nz[1])
                    i = max(nz[0][idx], 80)
                    j = min(nz[1][idx], semantic_img.shape[1]-100)

                    texts.append(f'{instance_score:.4f}')
                    # Plain ints: OpenCV point arguments are safest as Python ints.
                    positions.append((int(j), int(i)))

        if texts:
            semantic_img = apply_mask(semantic_img, masks)
            for text, position in zip(texts, positions):
                # Thick black outline first, thinner white text on top.
                cv2.putText(semantic_img, text, position, cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 4, cv2.LINE_AA)
                cv2.putText(semantic_img, text, position, cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

    else:
        semantic_img = apply_mask(semantic_img, [
            obs.semantic == task_obs['object_goal'],
            obs.semantic == task_obs['start_recep_goal'],
            obs.semantic == task_obs['end_recep_goal'],
        ])

    image_dict['rgb_mask'] = semantic_img

    keys_row1 = ['depth', 'rgb_mask', '3dgs_seg', '3dgs_instances', '3dgs_uncertainty']
    keys_row2 = ['3dgs_depth', 'third_person', 'occ', 'value']
    labels_row1 = ['Depth', 'Segmented RGB', '3DGS Segmented', '3DGS Instance', 'Uncertainty']
    labels_row2 = ['3DGS Depth', '', 'Occupancy', f"Value for {text_dict['value_map']}"]

    row1 = [image_dict[k] for k in keys_row1]
    row1 = resize_images(row1, match_dimension="height", idx_to_use=0)
    row2 = [reorient_rescale_map(image_dict[k]) if k in ['occ', 'value'] else image_dict[k] for k in keys_row2]
    row2 = resize_images(row2, match_dimension="height", idx_to_use=0)

    # Labels are drawn in place on the resized copies.
    for img, label in zip(row1, labels_row1):
        cv2.putText(img, label, (10, 45), cv2.FONT_HERSHEY_SIMPLEX, 1.6, (255, 255, 255), 2, cv2.LINE_AA)

    for i_img, (img, label) in enumerate(zip(row2, labels_row2)):
        # White label on the first tile of row 2, black on the others.
        color = (255, 255, 255) if i_img == 0 else (0, 0, 0)
        cv2.putText(img, label, (10, 45), cv2.FONT_HERSHEY_SIMPLEX, 1.6, color, 2, cv2.LINE_AA)

    row1c = np.concatenate(row1, axis=1)
    row2c = np.concatenate(row2, axis=1)

    # Pad the narrower row with black so both rows can be stacked vertically.
    if row1c.shape[1] < row2c.shape[1]:
        row1c = np.concatenate([row1c, np.zeros((row1c.shape[0], row2c.shape[1]-row1c.shape[1], 3), dtype=np.uint8)], axis=1)
    elif row2c.shape[1] < row1c.shape[1]:
        row2c = np.concatenate([row2c, np.zeros((row2c.shape[0], row1c.shape[1]-row2c.shape[1], 3), dtype=np.uint8)], axis=1)

    frame = np.concatenate([row1c, row2c], axis=0)

    body_height = frame.shape[0]

    # Add text to the top of the frame. Each call stacks above the previous
    # one, so the final order from the top is: instruction, skill, gaze target.
    frame = add_text_to_image(frame, text_dict['gaze_target'], top=True)
    frame = add_text_to_image(frame, text_dict['current_skill'], top=True)
    frame = add_text_to_image(frame, text_dict['instruction'], top=True)

    header_height = frame.shape[0] - body_height

    # Render the legend with matplotlib into an off-screen figure sized roughly
    # like the right half of the header, then grab its pixels.
    factor = 40
    fig = plt.figure(figsize=(frame.shape[1]//(2*factor), header_height//factor))
    try:
        plt.axis('off')

        legend_dict = {text_dict['goal_obj'] : (1.0,0.0,0.0), text_dict['start_rec']: (0.0,1.0,0.0), text_dict['end_rec'] : (0.0,0.0,1.0) }
        patch_list = [mpatches.Patch(color=color, label=name) for name, color in legend_dict.items()]

        plt.legend(handles=patch_list, ncol=3, fontsize=70)
        plt.tight_layout()

        fig.canvas.draw()
        # np.array copies the canvas buffer, so the figure can be closed right away.
        legend_img = np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Always release the figure, even if rendering fails; otherwise pyplot
        # keeps it registered and figures pile up across calls.
        plt.close(fig)

    # Drop the alpha channel and fit the legend to the right half of the header.
    legend_img = cv2.resize(legend_img[:,:,:3], (frame.shape[1]//2, header_height))

    frame[:header_height, frame.shape[1]-legend_img.shape[1]:, :] = legend_img

    imageio.imwrite(save_loc, frame)

    # save instance visualization
    # semantics_and_instances = np.concatenate([
    #     image_dict['rgb_mask'],
    #     image_dict['3dgs_seg'],
    #     image_dict['3dgs_instances']
    # ], axis=1)
    # imageio.imwrite(save_loc.replace('combined_', 'instance_'), semantics_and_instances)


# Adapted from VLFM, with modified font sizes.
def add_text_to_image(image: np.ndarray, text: str, top: bool = False) -> np.ndarray:
    """
    Stack a rendered text banner above or below an image.

    The banner is produced by `generate_text_image` with the same width as
    `image`, so the two can be stacked vertically.

    Args:
        image (np.ndarray): Input image.
        text (str): Text to render in the banner.
        top (bool, optional): If True the banner goes above the image,
            otherwise below it.

    Returns:
        np.ndarray: New image consisting of the input image and the banner.
    """
    banner = generate_text_image(image.shape[1], text)
    pieces = [banner, image] if top else [image, banner]
    return np.vstack(pieces)


def generate_text_image(width: int, text: str) -> np.ndarray:
    """
    Generates a white image with the given text drawn in black, word-wrapped.

    The image is `width` pixels wide, but words are wrapped at half of that
    width, so the text occupies the left half only (`write_frame` pastes a
    legend over the right half of the header). A single word wider than the
    wrap limit is drawn on a line of its own and may extend past the limit.

    Args:
        width (int): Width of the image.
        text (str): Text to be drawn. It is split on whitespace; an empty text
            yields a single blank row.

    Returns:
        np.ndarray: uint8 image of shape (row_height * num_rows, width, 3).
    """
    # Define the parameters for the text
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 2.0
    font_thickness = 2
    line_spacing = 20  # Spacing between lines in pixels
    left_margin = 10  # x position at which every line starts
    word_gap = 5  # Horizontal spacing between words in pixels

    # The height of the full text is used as the line height for every row.
    text_size, _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    max_width = width // 2
    row_height = text_size[1] + line_spacing

    # Position of the next word (baseline origin, as expected by cv2.putText)
    x = left_margin
    y = text_size[1] + 10

    to_draw = []
    num_rows = 1
    for word in text.split():
        word_size, _ = cv2.getTextSize(word, font, font_scale, font_thickness)

        # Wrap only when the current line already holds a word; breaking before
        # the first word of a line would just leave an empty row behind.
        if x > left_margin and x + word_size[0] > max_width:
            y += row_height
            x = left_margin
            num_rows += 1

        to_draw.append((word, x, y))

        # Update the position for the next word
        x += word_size[0] + word_gap

    # Create a blank white image with the calculated dimensions
    image = np.full((row_height * num_rows, width, 3), 255, dtype=np.uint8)
    for word, word_x, word_y in to_draw:
        cv2.putText(
            image,
            word,
            (word_x, word_y),
            font,
            font_scale,
            (0, 0, 0),
            font_thickness,
            cv2.LINE_AA,
        )

    return image

def generate_times_image(times: np.ndarray) -> np.ndarray:
    """
    Convert an array of times into a viridis-colored RGB image.

    Entries equal to +inf are replaced by (largest finite value + 1) so they
    map to the top of the colormap. The array is then divided by its maximum
    before the colormap is applied.

    Edge cases:
      * Empty input returns an empty uint8 array with a trailing axis of 3.
      * If every entry is +inf, all pixels get the top colormap color.
      * If the maximum is not positive, the values are left unscaled (avoids
        a 0/0 division); the colormap then clips them to its lowest color.

    Args:
        times (np.ndarray): Array of times; +inf marks entries without a time.
            The input array is not modified.

    Returns:
        np.ndarray: uint8 array of shape `times.shape + (3,)`.
    """
    colormap = matplotlib.colormaps['viridis']
    values = times.astype(np.float32)  # astype copies, the caller's array stays untouched

    if values.size == 0:
        return np.zeros(values.shape + (3,), dtype=np.uint8)

    unreached = values == np.inf
    if unreached.all():
        # No finite value to normalize by: put everything at the top of the map.
        values[...] = 1.0
    else:
        values[unreached] = values[~unreached].max() + 1
        peak = values.max()
        if peak > 0:
            values /= peak

    # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to bytes.
    return (colormap(values)[..., :3] * 255).astype(np.uint8)