import io
import time

from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
from absl import logging

from veoplace.utils.constants import FONT_SIZE
from veoplace.utils.constants import GRID_IMG_HEIGHT
from veoplace.utils.constants import GRID_IMG_WIDTH
from veoplace.utils.constants import MARGIN_BOTTOM
from veoplace.utils.constants import MARGIN_LEFT
from veoplace.utils.constants import MARGIN_RIGHT
from veoplace.utils.constants import MARGIN_TOP
from veoplace.utils.constants import MAX_FONT_SIZE
from veoplace.utils.constants import MIN_FONT_SIZE
from veoplace.utils.constants import STROKE_WIDTH


def _compute_grid_img_size(max_width, max_height, grid_img_width, grid_img_height):
    """Compute grid image size preserving aspect ratio of canvas."""
    if max_width <= 0 or max_height <= 0:
        return grid_img_width, grid_img_height
    # Only adjust if using default square size
    if grid_img_width == GRID_IMG_WIDTH and grid_img_height == GRID_IMG_HEIGHT:
        if max_width >= max_height:
            # Canvas is wider than tall - keep width at default, scale height
            grid_img_height = max(1, int(round(GRID_IMG_WIDTH * max_height / max_width)))
        else:
            # Canvas is taller than wide - keep height at default, scale width
            grid_img_width = max(1, int(round(GRID_IMG_HEIGHT * max_width / max_height)))
    return grid_img_width, grid_img_height


def _load_bold_font(size, cache):
    """Return the bold label font at `size`, memoized in the dict `cache`.

    Tries Arial Bold ("arialbd.ttf") and falls back to Pillow's default font
    resized to `size` when that file cannot be opened. The cache avoids
    re-reading the font file on every step of the fit-to-box search.
    """
    font = cache.get(size)
    if font is None:
        try:
            font = ImageFont.truetype("arialbd.ttf", size)
        except IOError:
            font = ImageFont.load_default().font_variant(size=size)
        cache[size] = font
    return font


def _text_size(draw, text, font):
    """Return the (width, height) of `text`'s bounding box drawn with `font`."""
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top


def _fit_bold_font(draw, text, max_text_width, max_text_height, cache):
    """Return the largest bold font whose rendering of `text` fits the limits.

    Sizes are tried from MAX_FONT_SIZE downward. If no size fits,
    MIN_FONT_SIZE is used and the text may overflow the limits.

    Returns:
        Tuple (font, text_width, text_height) for the chosen size.
    """
    font_size = MAX_FONT_SIZE
    while True:
        font = _load_bold_font(font_size, cache)
        text_width, text_height = _text_size(draw, text, font)
        fits = text_width <= max_text_width and text_height <= max_text_height
        # Stop at MIN_FONT_SIZE instead of stepping one size below it.
        if fits or font_size <= MIN_FONT_SIZE:
            return font, text_width, text_height
        font_size -= 1


def _draw_macro_label(img, draw, label, font_cache):
    """Draw a highlighted macro's name centered inside its rectangle.

    The text is drawn horizontally at MAX_FONT_SIZE when it fits in 90% of the
    rectangle. Otherwise tall rectangles (width < height) get the text rotated
    90 degrees counterclockwise, and other rectangles keep horizontal text; in
    both cases the font shrinks (down to MIN_FONT_SIZE) to fit.

    Args:
        img: Image that rotated text is pasted onto (modified in place).
        draw: ImageDraw bound to `img`, used for measuring and horizontal text.
        label: Dict with 'text', 'x', 'y', 'width', 'height' describing the
            macro rectangle in pixels ('x'/'y' is its top-left corner).
        font_cache: Dict of font size -> font shared across labels.
    """
    text = label['text']
    x, y = label['x'], label['y']
    width, height = label['width'], label['height']

    # Leave 10% of the rectangle as padding.
    avail_width = width * 0.9
    avail_height = height * 0.9

    font = _load_bold_font(MAX_FONT_SIZE, font_cache)
    text_width, text_height = _text_size(draw, text, font)
    fits_default = text_width <= avail_width and text_height <= avail_height

    if not fits_default and width < height:
        # Rotated text runs along the rectangle's height, so the limits swap.
        font, text_width, text_height = _fit_bold_font(
                draw, text, height * 0.9, width * 0.9, font_cache)

        # Render onto a transparent, padded image, then rotate it.
        padding = 10
        txt_img = Image.new(
                'RGBA', (text_width + padding * 2, text_height + padding * 2),
                (255, 255, 255, 0))
        ImageDraw.Draw(txt_img).text((padding, padding), text, fill='black',
                                     font=font, stroke_width=STROKE_WIDTH)
        rotated = txt_img.rotate(90, expand=True)

        # NOTE(review): the extra `- padding` shifts the text left by
        # `padding` px; the original comment said padding/2 — confirm intent.
        text_x = x + (width - rotated.width) / 2 - padding
        text_y = y + (height - rotated.height) / 2
        # The rotated image's own alpha channel serves as the paste mask.
        img.paste(rotated, (int(text_x), int(text_y)), rotated)
        return

    if not fits_default:
        font, text_width, text_height = _fit_bold_font(
                draw, text, avail_width, avail_height, font_cache)

    text_x = x + (width - text_width) / 2
    text_y = y + (height - text_height) / 2
    draw.text((text_x, text_y), text, fill='black', font=font,
              stroke_width=STROKE_WIDTH)


def render(node_pos,
        max_width,
        max_height,
        color_config,
        grid,
        grid_img_width=GRID_IMG_WIDTH,
        grid_img_height=GRID_IMG_HEIGHT,
        margin_left=MARGIN_LEFT,
        margin_right=MARGIN_RIGHT,
        margin_top=MARGIN_TOP,
        margin_bottom=MARGIN_BOTTOM,
        return_bytes=True,
        highlight_nodes=None,
        node_name_to_short_name=None,
        # DREAMPlace full node positions (optional, for rendering stdcells)
        dreamplace_node_x=None,
        dreamplace_node_y=None,
        dreamplace_node_size_x=None,
        dreamplace_node_size_y=None,
        dreamplace_num_movable=None,
        dreamplace_num_terminals=None,
        macro_alpha=1.0):
    """Render a macro placement as an image of the canvas with grid and labels.

    Drawing order: optional DREAMPlace stdcells/terminals (background), grid
    lines, axis labels, macros, red outlines around highlighted macros, and
    finally the names of highlighted macros.

    Args:
        node_pos: Mapping of node name to a 6-tuple
            (x, y, grid_size_x, grid_size_y, real_size_x, real_size_y). x/y is
            the macro's lower-left corner in canvas units with y pointing up;
            real_size_* are its dimensions in the same units. The grid sizes
            are only logged.
        max_width, max_height: Canvas dimensions in node_pos units.
        color_config: Mapping of node name to a Pillow color string or RGB(A)
            tuple; nodes missing from it are drawn gray ('#808080').
        grid: Number of grid cells per axis (grid lines/labels run 0..grid).
        grid_img_width, grid_img_height: Grid drawing area size in pixels;
            the defaults are adjusted to the canvas aspect ratio.
        margin_left, margin_right, margin_top, margin_bottom: Pixel margins
            around the grid area; axis labels go in the left/bottom margins.
        return_bytes: If true return PNG bytes, otherwise the PIL image.
        highlight_nodes: Optional collection of node names to outline in red.
            Those larger than one grid cell are also labeled with their name.
        node_name_to_short_name: Optional mapping used for label text.
        dreamplace_node_x, dreamplace_node_y, dreamplace_node_size_x,
            dreamplace_node_size_y: Optional per-node positions/sizes of all
            DREAMPlace nodes, drawn as background when dreamplace_node_x and
            dreamplace_num_movable are both given. Nodes with width or height
            below 1 are skipped.
        dreamplace_num_movable: Number of leading nodes drawn as stdcells.
        dreamplace_num_terminals: Number of following nodes drawn as
            terminals (None counts as 0); later nodes are not drawn.
        macro_alpha: Macro fill opacity; values below 1.0 produce an RGBA
            image with blended macro fills.

    Returns:
        PNG bytes (with 300 dpi metadata) if return_bytes, else the PIL image.
    """
    # Adjust grid image size to preserve canvas aspect ratio
    grid_img_width, grid_img_height = _compute_grid_img_size(
            max_width, max_height, grid_img_width, grid_img_height)

    # Total image size includes the grid area plus the margins
    img_width = grid_img_width + margin_left + margin_right
    img_height = grid_img_height + margin_top + margin_bottom

    # RGBA is only needed when macros are drawn semi-transparent
    use_alpha = macro_alpha < 1.0
    img = Image.new('RGBA' if use_alpha else 'RGB', (img_width, img_height),
                    color='white')
    draw = ImageDraw.Draw(img)

    # Canvas units -> pixels, for the grid drawing area only
    scale_x = grid_img_width / max_width
    scale_y = grid_img_height / max_height

    # Background: DREAMPlace stdcells/terminals (if provided)
    if dreamplace_node_x is not None and dreamplace_num_movable is not None:
        num_movable = dreamplace_num_movable
        num_drawn = num_movable + (dreamplace_num_terminals or 0)
        stdcell_color = (140, 170, 220)  # medium blue: movable stdcells
        terminal_color = (100, 100, 100)  # dark gray: fixed terminals

        for i in range(len(dreamplace_node_x)):
            if i >= num_drawn:
                break  # Remaining nodes (e.g. terminal_NIs) are not drawn
            x = float(dreamplace_node_x[i])
            y = float(dreamplace_node_y[i])
            w = float(dreamplace_node_size_x[i])
            h = float(dreamplace_node_size_y[i])

            # Skip if size is too small (likely filler or invalid)
            if w < 1 or h < 1:
                continue

            fill_color = stdcell_color if i < num_movable else terminal_color
            x_real = margin_left + x * scale_x
            y_real = margin_top + grid_img_height - (y * scale_y) - (h * scale_y)
            # No outline, to keep drawing fast with many nodes
            draw.rectangle(
                    [(x_real, y_real),
                     (x_real + w * scale_x, y_real + h * scale_y)],
                    fill=fill_color)

    # Grid lines: aim for ~32 lines regardless of grid size
    # e.g., grid=84 -> step=3 (29 lines), grid=512 -> step=16 (33 lines)
    grid_line_step = max(1, grid // 32)
    cell_width = grid_img_width / grid
    cell_height = grid_img_height / grid
    for i in range(0, grid + 1, grid_line_step):
        # Horizontal line, y flipped so index 0 is at the bottom
        y_pos = margin_top + grid_img_height - (i * cell_height)
        draw.line([(margin_left, y_pos), (margin_left + grid_img_width, y_pos)],
                  fill='lightgray', width=2, joint='curve')
        # Vertical line spanning the grid area
        x_pos = margin_left + i * cell_width
        draw.line([(x_pos, margin_top), (x_pos, margin_top + grid_img_height)],
                  fill='lightgray', width=2, joint='curve')

    # Axis labels at the grid line positions
    try:
        font = ImageFont.truetype("arial.ttf", FONT_SIZE)
    except IOError:
        # Fallback to default font if Arial isn't available
        font = ImageFont.load_default().font_variant(size=FONT_SIZE)

    for i in range(0, grid + 1, grid_line_step):
        label = str(i)
        text_width, text_height = _text_size(draw, label, font)

        # x-axis: centered on the vertical line, 5 px below the grid area
        x_line_pos = margin_left + i * cell_width
        draw.text((x_line_pos - text_width / 2, margin_top + grid_img_height + 5),
                  label, fill='black', font=font)

        # y-axis: right-aligned in the left margin (8 px gap), centered on line
        y_line_pos = margin_top + grid_img_height - i * cell_height
        draw.text((margin_left - text_width - 8, y_line_pos - text_height / 2),
                  label, fill='black', font=font)

    # Macros
    highlight_rects = []
    labels = []
    for node_name, (nx, ny, grid_size_x, grid_size_y, real_size_x,
                    real_size_y) in node_pos.items():
        logging.debug('Node %s: grid_size=(%d, %d), real_size=(%.2f, %.2f)',
                      node_name, grid_size_x, grid_size_y, real_size_x,
                      real_size_y)

        width_real = real_size_x * scale_x
        height_real = real_size_y * scale_y
        x_real = margin_left + nx * scale_x
        # Flip y: canvas origin is bottom-left, image origin is top-left
        y_real = margin_top + grid_img_height - (ny * scale_y) - height_real
        rect = [(x_real, y_real), (x_real + width_real, y_real + height_real)]

        color_str = (color_config[node_name] if node_name in color_config
                     else '#808080')
        rgb_color = (ImageColor.getrgb(color_str) if isinstance(color_str, str)
                     else color_str)

        if use_alpha:
            # Composite via an overlay so the semi-transparent fill blends
            # with what has already been drawn.
            rgba_color = (*rgb_color[:3], int(255 * macro_alpha))
            overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
            ImageDraw.Draw(overlay).rectangle(rect, fill=rgba_color,
                                              outline='black')
            img = Image.alpha_composite(img, overlay)
            draw = ImageDraw.Draw(img)
        else:
            draw.rectangle(rect, fill=rgb_color, outline='black')

        if highlight_nodes is not None and node_name in highlight_nodes:
            # Every highlighted macro gets a red outline; only macros larger
            # than one grid cell also get a name label.
            highlight_rects.append(rect)
            macro_grid_width = real_size_x / (max_width / grid)
            macro_grid_height = real_size_y / (max_height / grid)
            if macro_grid_width * macro_grid_height > 1:
                display_name = node_name
                if (node_name_to_short_name is not None
                        and node_name in node_name_to_short_name):
                    display_name = node_name_to_short_name[node_name]
                labels.append({'text': display_name, 'x': x_real, 'y': y_real,
                               'width': width_real, 'height': height_real})

    # Outlines are drawn after all macros so later macros cannot cover them.
    for rect in highlight_rects:
        draw.rectangle(rect, fill=None, outline='red', width=4)

    font_cache = {}
    for label in labels:
        _draw_macro_label(img, draw, label, font_cache)

    if not return_bytes:
        return img
    buffer = io.BytesIO()
    img.save(buffer, format="PNG", dpi=(300, 300))
    return buffer.getvalue()


def render_full_canvas(node_pos,
        max_width,
        max_height,
        color_config,
        grid,
        dreamplace_node_x,
        dreamplace_node_y,
        dreamplace_node_size_x,
        dreamplace_node_size_y,
        dreamplace_num_movable,
        dreamplace_num_terminals,
        grid_img_width=GRID_IMG_WIDTH,
        grid_img_height=GRID_IMG_HEIGHT,
        margin_left=MARGIN_LEFT,
        margin_right=MARGIN_RIGHT,
        margin_top=MARGIN_TOP,
        margin_bottom=MARGIN_BOTTOM,
        return_bytes=True,
        highlight_nodes=None,
        node_name_to_short_name=None,
        macro_alpha=1.0):
    """
    Render the full canvas, including DREAMPlace stdcells and terminals.

    Same as `render`, except that every DREAMPlace argument is mandatory.

    Raises:
        ValueError: if any DREAMPlace position, size, or count argument is None
            (checked in the order positions, sizes, counts).
    """
    required_groups = (
        ((dreamplace_node_x, dreamplace_node_y),
         "render_full_canvas requires dreamplace_node_x/y"),
        ((dreamplace_node_size_x, dreamplace_node_size_y),
         "render_full_canvas requires dreamplace_node_size_x/y"),
        ((dreamplace_num_movable, dreamplace_num_terminals),
         "render_full_canvas requires dreamplace_num_movable/terminals"),
    )
    for values, message in required_groups:
        if any(value is None for value in values):
            raise ValueError(message)

    return render(
        node_pos=node_pos,
        max_width=max_width,
        max_height=max_height,
        color_config=color_config,
        grid=grid,
        dreamplace_node_x=dreamplace_node_x,
        dreamplace_node_y=dreamplace_node_y,
        dreamplace_node_size_x=dreamplace_node_size_x,
        dreamplace_node_size_y=dreamplace_node_size_y,
        dreamplace_num_movable=dreamplace_num_movable,
        dreamplace_num_terminals=dreamplace_num_terminals,
        grid_img_width=grid_img_width,
        grid_img_height=grid_img_height,
        margin_left=margin_left,
        margin_right=margin_right,
        margin_top=margin_top,
        margin_bottom=margin_bottom,
        return_bytes=return_bytes,
        highlight_nodes=highlight_nodes,
        node_name_to_short_name=node_name_to_short_name,
        macro_alpha=macro_alpha,
    )


def render_all_suggestions(
        node_pos,
        ratio_x,
        ratio_y,
        max_width,
        max_height,
        color_config,
        grid,
        all_suggestions,
        highlight_nodes=None,
        grid_img_width=GRID_IMG_WIDTH,
        grid_img_height=GRID_IMG_HEIGHT,
        margin_left=MARGIN_LEFT,
        margin_right=MARGIN_RIGHT,
        margin_bottom=MARGIN_BOTTOM,
        node_name_to_short_name=None,
        return_bytes=True,
        macro_alpha=0.7,
        margin_top=MARGIN_TOP):
    """
    Render multiple suggestions on the same image with colors matching their color groups.

    The current placement is rendered first, with semi-transparent macros so
    the suggestion boxes show through. Each suggestion is then drawn as a
    rectangle outlined in its node's color, with the node's short name on a
    white background in the rectangle's top-left corner.

    Args:
        node_pos: Current node positions, in the format `render` accepts.
        ratio_x, ratio_y: Factors converting suggestion coordinates into
            canvas units along x and y.
        max_width, max_height: Canvas dimensions.
        color_config: Mapping of node name to color; nodes missing from it
            use gray ('#808080'), as in `render`.
        grid: Grid size.
        all_suggestions: Dict of node name -> sequence of (x, y) corner points
            of the suggested region, or None to skip that node. Points 0 and 2
            are used as opposite corners.
        highlight_nodes: Nodes to highlight in the base rendering.
        grid_img_width, grid_img_height: Grid drawing area size in pixels.
        margin_left, margin_right, margin_bottom: Margins in pixels.
        node_name_to_short_name: Optional mapping of node name to label text;
            nodes missing from it are labeled with their full name.
        return_bytes: Whether to return PNG bytes or the PIL image.
        macro_alpha: Fill opacity of macros in the base rendering.
        margin_top: Top margin in pixels, shared by the base rendering and the
            suggestion coordinates so both line up.

    Returns:
        PNG bytes (300 dpi) if return_bytes, else the PIL image with all
        suggestions rendered.
    """
    # Adjust grid image size to preserve canvas aspect ratio
    grid_img_width, grid_img_height = _compute_grid_img_size(
            max_width, max_height, grid_img_width, grid_img_height)

    placement_pil_image = render(
            node_pos,
            max_width=max_width,
            max_height=max_height,
            color_config=color_config,
            grid=grid,
            grid_img_width=grid_img_width,
            grid_img_height=grid_img_height,
            margin_left=margin_left,
            margin_right=margin_right,
            margin_top=margin_top,
            margin_bottom=margin_bottom,
            highlight_nodes=highlight_nodes,
            node_name_to_short_name=node_name_to_short_name,
            return_bytes=False,
            macro_alpha=macro_alpha
    )

    # Canvas units -> pixels, for the grid drawing area only
    scale_x = grid_img_width / max_width
    scale_y = grid_img_height / max_height
    draw = ImageDraw.Draw(placement_pil_image)

    for node_name, suggestion in all_suggestions.items():
        if suggestion is None:
            continue
        if (node_name_to_short_name is not None
                and node_name in node_name_to_short_name):
            short_name = node_name_to_short_name[node_name]
        else:
            short_name = node_name
        hex_color = (color_config[node_name] if node_name in color_config
                     else '#808080')

        # Suggestion coords -> pixels; y is flipped with the same margin_top
        # offset that render() applies to the macros.
        adjusted_coords = [
                (margin_left + sx * ratio_x * scale_x,
                 margin_top + grid_img_height - (sy * ratio_y * scale_y))
                for sx, sy in suggestion]

        (x0, y0), (x2, y2) = adjusted_coords[0], adjusted_coords[2]
        min_x, max_x = min(x0, x2), max(x0, x2)
        min_y, max_y = min(y0, y2), max(y0, y2)

        draw.rectangle([(min_x, min_y), (max_x, max_y)], outline=hex_color,
                       width=4)

        # Largest bold font whose label fits in 80% of the box.
        label_text = f"{short_name}"
        available_width = (max_x - min_x) * 0.8
        available_height = (max_y - min_y) * 0.8
        font_size = MAX_FONT_SIZE
        while True:
            try:
                font = ImageFont.truetype("arialbd.ttf", font_size)
            except IOError:
                font = ImageFont.load_default().font_variant(size=font_size)
            text_bbox = draw.textbbox((0, 0), label_text, font=font)
            fits = (text_bbox[2] - text_bbox[0] <= available_width
                    and text_bbox[3] - text_bbox[1] <= available_height)
            # Stop at MIN_FONT_SIZE rather than stepping one size below it.
            if fits or font_size <= MIN_FONT_SIZE:
                break
            font_size -= 1

        # Label at the top-left of the box, on a white background for contrast
        text_x = min_x + 5
        text_y = min_y + 5
        text_bbox = draw.textbbox((text_x, text_y), label_text, font=font)
        draw.rectangle([text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2,
                        text_bbox[3] + 2], fill="#FFFFFF")
        draw.text((text_x, text_y), label_text, fill=hex_color, font=font)

    if not return_bytes:
        return placement_pil_image
    buffer = io.BytesIO()
    placement_pil_image.save(buffer, format="PNG", dpi=(300, 300))
    return buffer.getvalue()


def render_suggestion(node_pos,
        ratio_x,
        ratio_y,
        max_width,
        max_height,
        color_config,
        grid,
        suggestion,
        highlight_node=None,
        hex_color="#FF0000",
        grid_img_width=GRID_IMG_WIDTH,
        grid_img_height=GRID_IMG_HEIGHT,
        margin_left=MARGIN_LEFT,
        margin_right=MARGIN_RIGHT,
        margin_bottom=MARGIN_BOTTOM,
        return_bytes=True,
        margin_top=MARGIN_TOP):
    """Render the placement with a single suggested region outlined on top.

    Args:
        node_pos: Current node positions, in the format `render` accepts.
        ratio_x, ratio_y: Factors converting suggestion coordinates into
            canvas units along x and y.
        max_width, max_height: Canvas dimensions.
        color_config: Mapping of node name to macro color.
        grid: Grid size.
        suggestion: Sequence of (x, y) corner points of the region; points 0
            and 2 are used as opposite corners.
        highlight_node: Optional node name to highlight in the base rendering.
        hex_color: Outline color of the suggestion rectangle.
        grid_img_width, grid_img_height: Grid drawing area size in pixels.
        margin_left, margin_right, margin_bottom: Margins in pixels.
        return_bytes: Whether to return PNG bytes or the PIL image.
        margin_top: Top margin in pixels, shared by the base rendering and the
            suggestion coordinates so both line up.

    Returns:
        PNG bytes (300 dpi) if return_bytes, else the PIL image.
    """
    # Adjust grid image size to preserve canvas aspect ratio
    grid_img_width, grid_img_height = _compute_grid_img_size(
            max_width, max_height, grid_img_width, grid_img_height)

    placement_pil_image = render(
            node_pos,
            max_width=max_width,
            max_height=max_height,
            color_config=color_config,
            grid=grid,
            grid_img_width=grid_img_width,
            grid_img_height=grid_img_height,
            margin_left=margin_left,
            margin_right=margin_right,
            margin_top=margin_top,
            margin_bottom=margin_bottom,
            highlight_nodes=(
                    [highlight_node] if highlight_node is not None else None),
            return_bytes=False)

    draw = ImageDraw.Draw(placement_pil_image)

    # Canvas units -> pixels, for the grid drawing area only
    scale_x = grid_img_width / max_width
    scale_y = grid_img_height / max_height

    # Suggestion coords -> pixels; y is flipped with the same margin_top
    # offset that render() applies to the macros.
    adjusted_coords = [
            (margin_left + sx * ratio_x * scale_x,
             margin_top + grid_img_height - (sy * ratio_y * scale_y))
            for sx, sy in suggestion]

    # Order the opposite corners so the rectangle is well-formed
    (x0, y0), (x2, y2) = adjusted_coords[0], adjusted_coords[2]
    min_x, max_x = min(x0, x2), max(x0, x2)
    min_y, max_y = min(y0, y2), max(y0, y2)

    # Bold outline so the suggestion stands out over the macros
    draw.rectangle([(min_x, min_y), (max_x, max_y)], outline=hex_color, width=6)

    if not return_bytes:
        return placement_pil_image
    buffer = io.BytesIO()
    placement_pil_image.save(buffer, format="PNG", dpi=(300, 300))
    return buffer.getvalue()
