"""Utilities for layout generation and visualization."""
import os

import matplotlib as mpl
import matplotlib.lines as lines
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from google import genai
from google.genai import types

GEMINI_2DOT5_PRO_PREVIEW_05_06 = "gemini-2.5-pro-preview-05-06"

# Wall indices. They index per-wall sequences such as the ``wall_flags`` array
# built in draw_room, and WALL_NAMES[i] is the name of the wall with index i.
TOP, BOTTOM, LEFT, RIGHT = range(4)

# Positions inside the two-entry colour palettes used by draw_room:
# palette[EDGES_COLOR] is the edge colour, palette[FILL_COLOR] the face colour.
EDGES_COLOR, FILL_COLOR = 0, 1

WALL_FLAGS = [TOP, BOTTOM, LEFT, RIGHT]
WALL_NAMES = ["top", "bottom", "left", "right"]


def draw_room_shape(ax, objects, width, depth, wall_flags):
    """Add the room's four walls and any hatched cut-out regions to ``ax``.

    A wall whose entry in ``wall_flags`` is truthy is drawn as a solid black
    line of width 1; otherwise it is drawn dashed with width 0.5 (draw_room
    clears the flag for walls that an "open" room's description mentions).
    Walls are added in index order TOP, BOTTOM, LEFT, RIGHT, where "top" is
    the segment at y == 0 and "bottom" the segment at y == depth.

    Args:
        ax: Matplotlib axes that receives the lines and patches.
        objects: Iterable of object dicts. Every object labelled
            ``"cutout_area"`` is drawn as an unfilled, ``"///"``-hatched
            rectangle from its ``bbox``, which must unpack to exactly
            ``(y1, x1, y2, x2)``.
        width: Room extent along x.
        depth: Room extent along y.
        wall_flags: Per-wall truthy/falsy values indexed by TOP..RIGHT.
    """
    wall_segments = {
        TOP: ([0, width], [0, 0]),
        BOTTOM: ([0, width], [depth, depth]),
        LEFT: ([0, 0], [0, depth]),
        RIGHT: ([width, width], [0, depth]),
    }

    for wall, (xs, ys) in wall_segments.items():
        solid = wall_flags[wall]
        ax.add_line(
            lines.Line2D(
                xs,
                ys,
                linewidth=1 if solid else 0.5,
                linestyle="solid" if solid else "dashed",
                color="black",
            )
        )

    # Cut-out areas (used for L-shaped layouts) are shown as hatched regions.
    for obj in objects:
        if obj["label"] != "cutout_area":
            continue
        top, left, bottom, right = obj["bbox"]
        ax.add_patch(
            patches.Rectangle(
                (left, top), right - left, bottom - top, fill=False, hatch="///"
            )
        )


def draw_living_room_layout(room_info):
    """Display a quick, experimental plan view of a single living room.

    The origin is the top-left corner. x runs along ``room["width"]``, and the
    y axis is inverted so that y grows downwards along ``room["depth"]``.
    Every object with a non-empty ``bbox`` of ``(y1, x1, y2, x2)`` is drawn as
    a translucent light-blue rectangle, labelled at its centre.

    A ``"window"`` whose bbox has zero height (checked first) or zero width is
    thickened to 0.2 units, centred on its line, so that it stays visible. The
    figure is displayed with ``plt.show()`` and is not closed afterwards.

    Args:
        room_info: Dict providing ``room["width"]``, ``room["depth"]`` and an
            ``objects`` list of dicts carrying ``label`` and an optional
            ``bbox``.
    """
    room = room_info["room"]
    room_width = room["width"]
    room_depth = room["depth"]
    objects = room_info["objects"]
    window_nudge = 0.1

    _, ax = plt.subplots(figsize=(10, 8))

    # Room boundary.
    ax.add_patch(
        patches.Rectangle(
            (0, 0),
            room_width,
            room_depth,
            linewidth=2,
            edgecolor="black",
            facecolor="none",
        )
    )

    for obj in objects:
        bbox = obj.get("bbox")
        if not bbox:
            continue

        top, left, bottom, right = bbox
        label = obj["label"]

        # Both edges move towards each other, so the extent becomes -0.2.
        # A negative-size Rectangle still spans 0.2 units centred on the
        # original line, and the label centre is unchanged.
        if label == "window":
            if top == bottom:
                top, bottom = top + window_nudge, bottom - window_nudge
            elif left == right:
                left, right = left + window_nudge, right - window_nudge

        span_x = right - left
        span_y = bottom - top

        ax.add_patch(
            patches.Rectangle(
                (left, top),
                span_x,
                span_y,
                linewidth=1,
                edgecolor="blue",
                facecolor="lightblue",
                alpha=0.5,
            )
        )
        ax.text(
            left + span_x / 2,
            top + span_y / 2,
            label,
            ha="center",
            va="center",
            fontsize=8,
        )

    ax.set_xlim(0, room_width)
    ax.set_ylim(0, room_depth)
    ax.set_xlabel("Width (m)")
    ax.set_ylabel("Depth (m)")
    ax.set_title("Living Room Layout (Top-Left Origin)")
    ax.set_aspect("equal")
    ax.invert_yaxis()
    ax.grid(True)
    plt.show()


def draw_room(room_info, png_file_path: str | None = None, print_to_screen: bool = True):
    """Render a plan view of a (possibly open or L-shaped) room.

    Coordinates use a top-left origin. x runs along the room width and y runs
    along the room depth, with the y axis inverted so that y == 0 is at the
    top. Rendering settings (300 dpi figure, thin hatch lines) apply only to
    this call and are not left behind in the global ``matplotlib.rcParams``.

    Args:
        room_info: Layout dict. Reads ``room["width"]``, ``room["depth"]``,
            ``room["shape"]``, ``room["shape_description"]`` (only when the
            shape is ``"open"``; missing or None counts as empty),
            ``objects``, ``layout_id`` and, if present,
            ``layout_style["chosen_layout_style"]``. Each object is a dict
            with a ``label`` and a ``bbox`` whose first four entries are
            ``(y1, x1, y2, x2)``. Objects whose ``bbox`` is missing or None
            are not drawn as furniture.
        png_file_path: When truthy, the figure is saved there at 300 dpi.
        print_to_screen: When True, ``plt.show()`` is called before closing.

    Raises:
        KeyError: If a required key of ``room_info`` or an object's ``label``
            is missing.
    """
    room = room_info["room"]
    width = room["width"]
    depth = room["depth"]
    room_type = room["shape"]
    objects = room_info["objects"]

    # Every wall is solid by default. For an "open" room, any wall whose name
    # appears (plain substring match) in the shape description is marked open.
    wall_flags = np.ones(len(WALL_FLAGS), dtype=bool)
    if room_type == "open":
        shape_desc = (room.get("shape_description") or "").lower()
        for index, name in enumerate(WALL_NAMES):
            if name in shape_desc:
                wall_flags[index] = False

    # Basic colour palettes: [edge colour, fill colour].
    glass = ["royalblue", "aliceblue"]
    lamp = ["darkgoldenrod", "papayawhip"]
    wood = ["brown", "linen"]
    sofa = ["chocolate", "seashell"]

    color_dict = {
        "armchair": sofa,
        "bed": sofa,
        "bench": sofa,
        "bookshelf": wood,
        "chair": sofa,
        "closet_alcove": wood,
        "coffee_table": wood,
        "cutout_area": ["black", "white"],
        "desk": wood,
        "door": wood,
        "dresser": wood,
        "fireplace": ["maroon", "mistyrose"],
        "floor": ["goldenrod", "cornsilk"],
        "floor_lamp": lamp,
        "loveseat": sofa,
        "mirror": glass,
        "nightstand": wood,
        "ottoman": sofa,
        "plant": ["darkgreen", "mintcream"],
        "rug": ["navy", "lavender"],
        "side_table": wood,
        "sofa": sofa,
        "table_lamp": lamp,
        "television": ["dimgrey", "lightgrey"],
        "tv_stand": wood,
        "wardrobe": wood,
        "window": glass,
    }

    # Thickness given to zero-thickness objects lying on an outer wall.
    min_obj_width = 0.1
    underlay_labels = {"tv_stand", "side_table"}
    unlabelled = {"rug", "side_table", "tv_stand"}

    def draw_object(ax, obj):
        """Draw one object's rectangle and, for most labels, its name."""
        y1, x1, y2, x2 = obj["bbox"][:4]

        # Normalise the corners so that (x1, y1) is the top-left even for a
        # reversed bbox. Previously the rectangle was anchored at the given
        # x1/y1 with an abs() extent, displacing it from its own label.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        # Labels run along the longer side of the object.
        rotation = 0 if (x2 - x1) >= (y2 - y1) else 90

        colours = color_dict.get(obj["label"], ["black", "white"])

        # Zero-thickness objects (e.g. doors, windows) touching an outer wall
        # are thickened into the room so they are visible. Zero-thickness
        # objects elsewhere are left as they are.
        if x1 == x2:
            if x2 == width:
                x1 = width - min_obj_width
            if x1 == 0:
                x2 = min_obj_width
        if y1 == y2:
            if y2 == depth:
                y1 = depth - min_obj_width
            if y1 == 0:
                y2 = min_obj_width

        ax.add_patch(
            patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                fill=True,
                linewidth=0.5,
                edgecolor=colours[EDGES_COLOR],
                facecolor=colours[FILL_COLOR],
                alpha=1.0,
            )
        )

        if obj["label"] not in unlabelled:
            ax.text(
                0.5 * (x1 + x2),
                0.5 * (y1 + y2),
                obj["label"],
                color=colours[EDGES_COLOR],
                fontsize=2 if obj["label"] == "table_lamp" else 4,
                ha="center",
                va="center",
                rotation=rotation,
            )

    # Layering: rugs first, then the underlay furniture, then everything else.
    # Cut-outs are drawn by draw_room_shape, not here.
    drawable = [obj for obj in objects if obj.get("bbox") is not None]
    layers = (
        [obj for obj in drawable if obj["label"] == "rug"],
        [obj for obj in drawable if obj["label"] in underlay_labels],
        [
            obj
            for obj in drawable
            if obj["label"] not in {"cutout_area", "rug"} | underlay_labels
        ],
    )

    layout_style = room_info.get("layout_style", {})
    if "chosen_layout_style" in layout_style:
        style_suffix = " " + layout_style["chosen_layout_style"] + ", "
    else:
        style_suffix = " "
    title = f"ID: {room_info['layout_id']},{style_suffix}{room_type}, {width * depth:.2f} sqm"

    # Scope the rendering tweaks to this call. Saving and showing stay inside
    # the context because matplotlib may read these rcParams at draw time.
    with mpl.rc_context({"figure.dpi": 300, "hatch.linewidth": 0.25}):
        fig, ax = plt.subplots()
        try:
            ax.set_xlim(-0.5, width + 0.5)
            ax.set_ylim(depth + 0.5, -0.5)

            # Hatched floor.
            ax.add_patch(
                patches.Rectangle(
                    (0, 0),
                    width,
                    depth,
                    fill=True,
                    color=color_dict["floor"][FILL_COLOR],
                    hatch="///",
                )
            )

            draw_room_shape(ax, objects, width, depth, wall_flags)

            for layer in layers:
                for obj in layer:
                    draw_object(ax, obj)

            ax.set_title(title)
            ax.set_aspect("equal")
            ax.grid(False)

            if png_file_path:
                fig.savefig(png_file_path, bbox_inches="tight", dpi=300)

            if print_to_screen:
                plt.show()
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)


def generate(prompt: str, model_id: str):
    """Send ``prompt`` to a Gemini model and return the streamed reply as text.

    Environment variables are first loaded with ``load_dotenv()`` (python-dotenv's
    default ``.env`` discovery). ``GEMINI_API_KEY`` is then read and passed to
    the client. The request is a single user turn with one text part and a
    ``text/plain`` response MIME type.

    Args:
        prompt: User text to send.
        model_id: Model name, e.g. ``GEMINI_2DOT5_PRO_PREVIEW_05_06``.

    Returns:
        The concatenated text of all streamed chunks. Chunks whose ``text`` is
        empty or None are skipped.
    """
    load_dotenv()
    client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

    user_turn = types.Content(
        role="user",
        parts=[types.Part.from_text(text=prompt)],
    )
    request_config = types.GenerateContentConfig(response_mime_type="text/plain")

    stream = client.models.generate_content_stream(
        model=model_id,
        contents=[user_turn],
        config=request_config,
    )

    pieces = [chunk.text for chunk in stream if chunk.text]
    return "".join(pieces)
