from google import genai
from google.genai import types
from google.genai.types import GenerateContentConfig, GoogleSearch, Tool

import os

from dotenv import load_dotenv

import matplotlib as mpl
import matplotlib.lines as lines
import matplotlib.patches as patches
import matplotlib.pyplot as plt

import numpy as np

# Gemini model identifier string.
GEMINI_2DOT5_PRO_PREVIEW_05_06 = "gemini-2.5-pro-preview-05-06"

# Wall indices; they match the positions used in WALL_FLAGS and WALL_NAMES.
TOP, BOTTOM, LEFT, RIGHT = range(4)

# Positions inside a two-element [edge colour, fill colour] list.
EDGES_COLOR, FILL_COLOR = 0, 1

# Parallel lists: wall index -> wall name.
WALL_FLAGS = [TOP, BOTTOM, LEFT, RIGHT]
WALL_NAMES = ["top", "bottom", "left", "right"]


def draw_room_shape(ax, objects, width, depth, wall_flags):
    """Add the four room walls and any hatched cut-out rectangles to *ax*.

    Walls are added in ``WALL_NAMES`` order. A truthy entry in *wall_flags*
    yields a solid black line of width 1; a falsy entry yields a dashed black
    line of width 0.5. Every object whose ``"label"`` is ``"cutout_area"`` is
    added as an unfilled rectangle hatched with ``"///"``; its ``"bbox"`` is
    unpacked as ``(y1, x1, y2, x2)``.
    """
    # End points of each wall as (x values, y values), keyed by wall index.
    wall_segments = {
        TOP: ([0, width], [0, 0]),
        BOTTOM: ([0, width], [depth, depth]),
        LEFT: ([0, 0], [0, depth]),
        RIGHT: ([width, width], [0, depth]),
    }

    for wall_index in range(len(WALL_NAMES)):
        if wall_flags[wall_index]:
            stroke_style, stroke_width = "solid", 1
        else:
            stroke_style, stroke_width = "dashed", 0.5

        segment = wall_segments.get(wall_index)
        if segment is None:
            continue

        xs, ys = segment
        wall = lines.Line2D(xs, ys, linewidth=stroke_width, linestyle=stroke_style, color="black")
        ax.add_line(wall)

    # Cut-outs (used for L-shaped layouts) are hatched, unfilled rectangles.
    cutout_objects = (item for item in objects if item["label"] == "cutout_area")
    for item in cutout_objects:
        y_start, x_start, y_end, x_end = item["bbox"]
        ax.add_patch(
            patches.Rectangle(
                (x_start, y_start),
                x_end - x_start,
                y_end - y_start,
                fill=False,
                hatch="///",
            )
        )


def draw_living_room_layout(room_info):
    """Show a quick (experimental) plan view of a single living room.

    The y-axis is inverted so that the origin sits in the top-left corner.

    Args:
        room_info: Mapping providing ``room_info["room"]["width"]``,
            ``room_info["room"]["depth"]`` and ``room_info["objects"]``.
            Each object is a mapping; those with a non-empty ``"bbox"``
            (unpacked as ``(y1, x1, y2, x2)``) are drawn as a rectangle
            annotated with the object's ``"label"``. Objects without a
            bounding box are skipped. The axis labels assume metres.

    The figure is displayed with ``plt.show()`` and is always closed
    afterwards so that repeated calls do not accumulate open figures.
    """
    room_width = room_info["room"]["width"]
    room_depth = room_info["room"]["depth"]
    objects = room_info["objects"]

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Outer room rectangle
        ax.add_patch(
            patches.Rectangle(
                (0, 0), room_width, room_depth, linewidth=2, edgecolor="black", facecolor="none"
            )
        )

        for obj in objects:
            if not obj.get("bbox"):
                continue

            y1, x1, y2, x2 = obj["bbox"]

            # Widen zero-thickness windows by 0.1 on each side so that the
            # rectangle is visible; the centre of the window is unchanged.
            if obj["label"] == "window":
                if y1 == y2:
                    y1 -= 0.1
                    y2 += 0.1
                elif x1 == x2:
                    x1 -= 0.1
                    x2 += 0.1

            rect_width = x2 - x1
            rect_height = y2 - y1

            ax.add_patch(
                patches.Rectangle(
                    (x1, y1),
                    rect_width,
                    rect_height,
                    linewidth=1,
                    edgecolor="blue",
                    facecolor="lightblue",
                    alpha=0.5,
                )
            )

            # Label at the centre of the rectangle.
            ax.text(
                x1 + rect_width / 2,
                y1 + rect_height / 2,
                obj["label"],
                ha="center",
                va="center",
                fontsize=8,
            )

        ax.set_xlim(0, room_width)
        ax.set_ylim(0, room_depth)
        ax.set_xlabel("Width (m)")
        ax.set_ylabel("Depth (m)")
        ax.set_title("Living Room Layout (Top‑Left Origin)")
        ax.set_aspect("equal")
        ax.invert_yaxis()
        ax.grid(True)
        plt.show()
    finally:
        # Release the figure even if drawing or showing fails.
        plt.close(fig)


def draw_room(room_info, png_file_path: str | None = None, print_to_screen: bool = True):
    """Render a plan view of a room, optionally saving and/or showing it.

    Args:
        room_info: Mapping describing the layout. Keys read here:
            ``room`` (with ``width``, ``depth``, ``shape`` and, when
            ``shape`` is ``"open"``, ``shape_description``), ``objects``
            (mappings with ``label`` and ``bbox``; the first four ``bbox``
            values are unpacked as ``(y1, x1, y2, x2)``), ``layout_id`` and,
            optionally, ``layout_style["chosen_layout_style"]``.
        png_file_path: When truthy, the figure is saved to this path at
            300 dpi with a tight bounding box.
        print_to_screen: When true, the figure is shown via ``plt.show()``.

    Side effects:
        Sets the global ``figure.dpi`` and ``hatch.linewidth`` rcParams; the
        new values persist after the call. The figure is always closed before
        returning, including when drawing, saving or showing raises.
    """
    room = room_info["room"]
    width = room["width"]
    depth = room["depth"]
    room_type = room["shape"]
    objects = room_info["objects"]

    # Every wall starts solid. For an "open" room, each wall whose name occurs
    # (as a substring) in the shape description is drawn dashed instead.
    wall_flags = np.ones(len(WALL_FLAGS), dtype=bool)
    if room_type == "open":
        shape_desc = room["shape_description"].lower()
        for wall_index, wall_name in enumerate(WALL_NAMES):
            if wall_name in shape_desc:
                wall_flags[wall_index] = False

    # Basic colour palettes, each stored as [edge colour, fill colour]
    glass = ["royalblue", "aliceblue"]
    lamp = ["darkgoldenrod", "papayawhip"]
    wood = ["brown", "linen"]
    sofa = ["chocolate", "seashell"]

    color_dict = {
        "armchair": sofa,
        "bed": sofa,
        "bench": sofa,
        "bookshelf": wood,
        "chair": sofa,
        "closet_alcove": wood,
        "coffee_table": wood,
        "cutout_area": ["black", "white"],
        "desk": wood,
        "door": wood,
        "dresser": wood,
        "fireplace": ["maroon", "mistyrose"],
        "floor": ["goldenrod", "cornsilk"],
        "floor_lamp": lamp,
        "loveseat": sofa,
        "mirror": glass,
        "nightstand": wood,
        "ottoman": sofa,
        "plant": ["darkgreen", "mintcream"],
        "rug": ["navy", "lavender"],
        "side_table": wood,
        "sofa": sofa,
        "table_lamp": lamp,
        "television": ["dimgrey", "lightgrey"],
        "tv_stand": wood,
        "wardrobe": wood,
        "window": glass,
    }

    def draw_object(obj):
        """Add one object's rectangle (and, for most labels, its text) to ``ax``."""
        y1, x1, y2, x2 = obj["bbox"][:4]

        obj_width = abs(x2 - x1)
        obj_depth = abs(y2 - y1)

        # The label is rotated to run along the object's longer side; this is
        # decided from the bbox as given, before any zero-size adjustment.
        horizontal = obj_width >= obj_depth
        rotation = 0 if horizontal else 90

        # Unknown labels fall back to a black edge on a white fill.
        colours = color_dict.get(obj["label"], ["black", "white"])

        # Handle zero-dimension objects (doors, windows): when they lie on a
        # room boundary, give them a small thickness pointing into the room.
        min_obj_width = 0.1
        if obj_width == 0:
            if x2 == width:
                x1 = width - min_obj_width
            if x1 == 0:
                x2 = min_obj_width
        if obj_depth == 0:
            if y2 == depth:
                y1 = depth - min_obj_width
            if y1 == 0:
                y2 = min_obj_width

        obj_width = abs(x2 - x1)
        obj_depth = abs(y2 - y1)

        rect = patches.Rectangle(
            (x1, y1),
            obj_width,
            obj_depth,
            fill=True,
            linewidth=0.5,
            edgecolor=colours[EDGES_COLOR],
            facecolor=colours[FILL_COLOR],
            alpha=1.0,
        )
        ax.add_patch(rect)

        # Rugs and the small underlay pieces are drawn without a text label.
        if obj["label"] not in {"rug", "side_table", "tv_stand"}:
            font_size = 2 if obj["label"] == "table_lamp" else 4
            ax.text(
                0.5 * (x1 + x2),
                0.5 * (y1 + y2),
                obj["label"],
                color=colours[EDGES_COLOR],
                fontsize=font_size,
                ha="center",
                va="center",
                rotation=rotation,
            )

    # Matplotlib tweaks (global settings; they stay in effect after the call)
    mpl.rcParams["figure.dpi"] = 300
    mpl.rcParams["hatch.linewidth"] = 0.25

    fig, ax = plt.subplots()
    try:
        ax.set_xlim(-0.5, width + 0.5)
        # Reversed limits put y = 0 at the top of the plot.
        ax.set_ylim(depth + 0.5, -0.5)

        # Floor hatch
        ax.add_patch(
            patches.Rectangle(
                (0, 0),
                width,
                depth,
                fill=True,
                color=color_dict["floor"][FILL_COLOR],
                hatch="///",
            )
        )

        # Room outline
        draw_room_shape(ax, objects, width, depth, wall_flags)

        # Objects are added in three passes so that later ones sit on top:
        # rugs first, then small furniture underlays, then everything else
        # (cut-outs were already handled with the room outline).
        for obj in objects:
            if obj["label"] == "rug":
                draw_object(obj)

        for obj in objects:
            if obj["label"] in {"tv_stand", "side_table"}:
                draw_object(obj)

        for obj in objects:
            if obj["label"] not in {"cutout_area", "rug", "tv_stand", "side_table"}:
                draw_object(obj)

        style_suffix = (
            " " + room_info["layout_style"]["chosen_layout_style"] + ", "
            if "layout_style" in room_info and "chosen_layout_style" in room_info["layout_style"]
            else " "
        )

        # The reported area is width * depth of the bounding rectangle.
        ax.set_title(
            f"ID: {room_info['layout_id']},{style_suffix}{room_type}, {width * depth:.2f} sqm"
        )
        ax.set_aspect("equal")
        ax.grid(False)

        if png_file_path:
            fig.savefig(png_file_path, bbox_inches="tight", dpi=300)

        if print_to_screen:
            plt.show()
    finally:
        # Release the figure even if drawing, saving or showing fails.
        plt.close(fig)


def generate(prompt: str, model_id: str):
    """Send *prompt* to the Gemini model *model_id* and return the reply text.

    ``load_dotenv()`` is called first, then the API key is read from the
    ``GEMINI_API_KEY`` environment variable. The prompt is sent as a single
    user message through the streaming generate-content call with a
    ``text/plain`` response MIME type, and the non-empty ``text`` of every
    streamed chunk is concatenated into one string.
    """
    load_dotenv()  # load variables from a .env file, if one is found
    gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

    user_message = types.Content(
        role="user",
        parts=[types.Part.from_text(text=prompt)],
    )
    plain_text_config = types.GenerateContentConfig(response_mime_type="text/plain")

    stream = gemini_client.models.generate_content_stream(
        model=model_id,
        contents=[user_message],
        config=plain_text_config,
    )

    # Chunks whose text is empty or None are skipped.
    pieces = []
    for chunk in stream:
        if chunk.text:
            pieces.append(chunk.text)
    return "".join(pieces)
