import cv2
import numpy as np


def get_contour_points(pos, origin, size=20):
    """
    Compute the four vertices of the arrow marker drawn for an agent pose

    The first vertex sits on the agent position itself and the third is the
    tip, ``size`` pixels away along the heading. The second and fourth are
    ``size / 1.5`` pixels away at the heading rotated by +/- 4*pi/3. Each
    coordinate is truncated with ``int()`` before the origin offset is added.

    Args:
        pos: Agent pose (x, y, orientation); orientation is an angle in radians
        origin: Offset added to every vertex (uses origin[0] and origin[1])
        size: Distance from the agent position to the arrow tip

    Returns:
        np.ndarray: Array of shape (4, 2) holding the (x, y) vertices
    """
    x, y, o = pos
    off_x, off_y = origin[0], origin[1]
    side = size / 1.5
    turn = np.pi * 4 / 3

    def project(radius, angle):
        # Point at `radius` from (x, y) along `angle`, truncated then offset.
        return (int(x + radius * np.cos(angle)) + off_x,
                int(y + radius * np.sin(angle)) + off_y)

    tail = (int(x) + off_x, int(y) + off_y)
    vertices = [
        tail,
        project(side, o + turn),
        project(size, o),
        project(side, o - turn),
    ]
    return np.array(vertices)


def draw_line(start, end, mat, steps=25, w=1):
    """
    Draw a line between two points on a matrix by stamping square patches

    The segment from ``start`` to ``end`` is sampled at ``steps + 1`` evenly
    spaced points, each rounded to the nearest integer index. For every
    sample (x, y) the block ``mat[x - w:x + w, y - w:y + w]`` is set to 1.
    Patch bounds are clamped at 0, so samples near or beyond the low edge of
    the matrix are clipped instead of wrapping around through Python's
    negative indexing. Bounds past the high edge are clipped by numpy slicing.

    Args:
        start: Starting coordinates (x, y); x indexes axis 0 of ``mat``
        end: Ending coordinates (x, y); x indexes axis 0 of ``mat``
        mat: Target 2-D (or higher) array to draw on; modified in place
        steps: Number of interpolation intervals (must be non-zero)
        w: Half-extent of the stamped patch; each sample fills up to a
           2w x 2w block (w=0 draws nothing)

    Returns:
        np.ndarray: ``mat`` itself, with the line drawn
    """
    for i in range(steps + 1):
        x = int(np.rint(start[0] + (end[0] - start[0]) * i / steps))
        y = int(np.rint(start[1] + (end[1] - start[1]) * i / steps))
        # A negative slice bound would index from the END of the array.
        # The patch would then be silently dropped (e.g. [-1:1] is empty) or
        # painted on the opposite edge, so clamp both bounds at 0.
        x_lo, x_hi = max(x - w, 0), max(x + w, 0)
        y_lo, y_hi = max(y - w, 0), max(y + w, 0)
        mat[x_lo:x_hi, y_lo:y_hi] = 1
    return mat


def init_vis_image(goal_name, rgb_width, rgb_height, sem_map_size):
    """
    Build the blank visualization canvas with section titles and frames

    The canvas is sized by ``calculate_vis_layout``. It starts white, gets a
    centred title above the RGB panel and above the semantic-map panel, and
    gets a one-pixel grey frame drawn just outside each panel region.

    Args:
        goal_name: Name of the target object, shown in the RGB panel title
        rgb_width: Width of RGB observation
        rgb_height: Height of RGB observation
        sem_map_size: Size of semantic map (square)

    Returns:
        np.ndarray: uint8 canvas of shape layout['canvas_size']
    """
    layout = calculate_vis_layout(rgb_width, rgb_height, sem_map_size)

    canvas = np.ones(layout['canvas_size']).astype(np.uint8) * 255

    text_font = cv2.FONT_HERSHEY_SIMPLEX
    text_scale = 1
    text_color = (20, 20, 20)  # BGR
    text_thickness = 2
    frame_color = [100, 100, 100]

    def put_centered(img, label, center):
        # cv2.putText anchors text at its bottom-left corner. Shift left by
        # half the text width and down by half its height so the label is
        # centred on `center`.
        width, height = cv2.getTextSize(label, text_font, text_scale,
                                        text_thickness)[0]
        anchor = (center[0] - width // 2, center[1] + height // 2)
        return cv2.putText(img, label, anchor, text_font, text_scale,
                           text_color, text_thickness, cv2.LINE_AA)

    def draw_frame(img, region):
        # region is (top, bottom, left, right) with exclusive bottom/right.
        # The frame lines sit one pixel outside the region on every side.
        top, bottom, left, right = region
        img[top - 1, left - 1:right + 1] = frame_color
        img[bottom, left - 1:right + 1] = frame_color
        img[top - 1:bottom + 1, left - 1] = frame_color
        img[top - 1:bottom + 1, right] = frame_color

    canvas = put_centered(canvas, "Observations (Goal: {})".format(goal_name),
                          layout['rgb_text_pos'])
    canvas = put_centered(canvas, "Predicted Semantic Map",
                          layout['sem_text_pos'])

    draw_frame(canvas, layout['rgb_region'])
    draw_frame(canvas, layout['sem_region'])

    return canvas


def calculate_vis_layout(rgb_width, rgb_height, sem_map_size):
    """
    Compute pixel geometry for a side-by-side RGB / semantic-map canvas

    Horizontally the canvas is laid out as
    [15 px | RGB | 15 px | semantic map | 15 px]. Both panels are
    top-aligned under a 50 px title band, and a 20 px band is left below
    the taller panel.

    Args:
        rgb_width: Width of RGB observation
        rgb_height: Height of RGB observation
        sem_map_size: Size of semantic map (square)

    Returns:
        dict: Layout with keys
            'canvas_size'  -> (height, width, 3)
            'rgb_region'   -> (y1, y2, x1, x2) of the RGB panel (y2/x2 exclusive)
            'sem_region'   -> (y1, y2, x1, x2) of the semantic-map panel
            'arrow_origin' -> (x, y) of the semantic-map panel's top-left corner
            'rgb_text_pos' -> (x, y) centre for the RGB panel title
            'sem_text_pos' -> (x, y) centre for the semantic-map panel title
    """
    outer = 15    # left/right canvas margin
    spacing = 15  # horizontal gap between the two panels
    header = 50   # title band above the panels
    footer = 20   # band below the taller panel

    # Walk the columns left to right.
    rgb_left = outer
    rgb_right = rgb_left + rgb_width
    sem_left = rgb_right + spacing
    sem_right = sem_left + sem_map_size
    full_width = sem_right + outer

    # Both panels start right below the header; titles sit mid-header.
    panel_top = header
    title_y = header // 2
    full_height = header + max(rgb_height, sem_map_size) + footer

    return dict(
        canvas_size=(full_height, full_width, 3),
        rgb_region=(panel_top, header + rgb_height, rgb_left, rgb_right),
        sem_region=(panel_top, header + sem_map_size, sem_left, sem_right),
        arrow_origin=(sem_left, panel_top),
        rgb_text_pos=(rgb_left + rgb_width // 2, title_y),
        sem_text_pos=(sem_left + sem_map_size // 2, title_y),
    )