import os
import json
import fire
from dotenv import load_dotenv
from google import genai
from google.genai import types
from google.genai.types import GenerateContentConfig

from prompts import (
    counter_type_choice, layout_generation_prompt,
    l_shaped_kitchen_spec, u_shaped_kitchen_spec, g_shaped_kitchen_spec,
    galley_kitchen_spec, island_kitchen_spec, one_row_kitchen_spec,
    open_room, l_shaped_room, rectangular_room
) # add prompts to the path


def load_api_key() -> str:
    """Return the Gemini API key from the environment.

    ``load_dotenv()`` is called first so that a local ``.env`` file can
    supply the variable.

    Raises:
        ValueError: if ``GEMINI_API_KEY`` is unset or empty.
    """
    load_dotenv()
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return key
    raise ValueError("GEMINI_API_KEY not found in environment variables.")


def generate(prompt: str, model_id: str, api_key: str) -> str:
    """Send ``prompt`` to ``model_id`` as a single user turn and return the reply.

    The response is streamed with a ``text/plain`` MIME type. The text of
    every chunk is concatenated, and chunks with empty/``None`` text are
    skipped. A fresh ``genai.Client`` is created on every call.
    """
    client = genai.Client(api_key=api_key)
    request = [
        types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
    ]
    stream = client.models.generate_content_stream(
        model=model_id,
        contents=request,
        config=GenerateContentConfig(response_mime_type="text/plain"),
    )
    return "".join(piece.text for piece in stream if piece.text)


def pick_right_spec(layout_style: str) -> str:
    """Return the kitchen spec prompt that matches the chosen counter style.

    ``layout_style`` is lower-cased and searched for each keyword as a
    substring, in this priority order: l-shaped, u-shaped, galley, island,
    one row, g-shaped. The first hit wins, so "L-shaped with island" resolves
    to the L-shaped spec. Returns "" when no keyword matches.
    """
    spec_by_keyword = (
        ("l-shaped", l_shaped_kitchen_spec),
        ("u-shaped", u_shaped_kitchen_spec),
        ("galley", galley_kitchen_spec),
        ("island", island_kitchen_spec),
        ("one row", one_row_kitchen_spec),
        ("g-shaped", g_shaped_kitchen_spec),
    )
    style = layout_style.lower()
    return next((spec for keyword, spec in spec_by_keyword if keyword in style), "")


def pick_one_shot_example(shape: str) -> str:
    """Return the one-shot example room whose keyword appears in ``shape``.

    Keywords are matched case-insensitively as substrings, in the order
    open, l-shape, rectangular, and the first hit wins. Returns "" when none
    match.
    """
    examples = (
        ("open", open_room),
        ("l-shape", l_shaped_room),
        ("rectangular", rectangular_room),
    )
    lowered = shape.lower()
    for keyword, example in examples:
        if keyword in lowered:
            return example
    return ""


def _parse_json_response(text: str):
    """Parse a model reply as JSON, tolerating a surrounding Markdown code fence.

    Accepts an opening fence of ```json, ```JSON or a bare ```, plus an
    optional closing ```. Unfenced text is parsed as-is.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # Drop an optional "json" language tag right after the opening fence.
        if cleaned[:4].lower() == "json":
            cleaned = cleaned[4:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return json.loads(cleaned)


def process_room_layouts(
    model_id: str = "gemini-2.5-pro-preview-05-06",
    room_batch_dir: str = "generated_results/base_rooms",
    output_dir: str = "generated_results/kitchens",
    room_number: int = 600,
    batch_size: int = 20,
):
    """Generate a kitchen layout for every room in the batch files.

    Reads ``room_batch_dir/generated_kitchens_{i}.json`` for each batch index.
    Each room goes through two model calls: one that chooses a counter style
    and one that generates the layout. The parsed layout is written to
    ``output_dir/kitchen_{layout_id}.json``.

    Args:
        model_id: Gemini model identifier passed to ``generate``.
        room_batch_dir: directory holding the batch JSON files.
        output_dir: destination for per-room layouts; created if missing.
        room_number: total number of rooms across all batches.
        batch_size: rooms per batch file. It is used both to count the batch
            files (a trailing partial batch is included) and to offset each
            batch's ``layout_id`` into a global id.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if the API key is
            missing (raised by ``load_api_key``).

    Exceptions raised while generating, parsing or saving a single room are
    printed, and that room is skipped rather than aborting the run.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    api_key = load_api_key()
    os.makedirs(output_dir, exist_ok=True)

    # Ceiling division: a trailing partial batch must not be dropped.
    # A batch file that does not exist is reported and skipped below.
    batch_count = -(-room_number // batch_size)

    for i in range(batch_count):
        batch_path = os.path.join(room_batch_dir, f"generated_kitchens_{i}.json")
        if not os.path.isfile(batch_path):
            print(f"Skipped missing batch file: {batch_path}")
            continue

        with open(batch_path, "r", encoding="utf-8") as file:
            processed_rooms = json.load(file)

        for room in processed_rooms:
            # Shift the id by this batch's offset to make it unique across
            # batches. This assumes ids in each file are batch-local
            # (0..batch_size-1); TODO confirm against the room generator.
            room['layout_id'] = room['layout_id'] + i * batch_size
            print(f"Processing room {room['layout_id']}")

            try:
                # Step 1: Counter type selection
                counter_type_prompt = counter_type_choice.format(room_input=room)
                counter_response = generate(counter_type_prompt, model_id, api_key)
                chosen_counter_style = _parse_json_response(counter_response)
                print(f"Chosen Counter Style: {chosen_counter_style['chosen_counter_style']}")

                # Step 2: Layout generation. The whole parsed step-1 object is
                # passed as layout_style; only its 'chosen_counter_style'
                # field drives the spec lookup.
                layout_prompt = layout_generation_prompt.format(
                    room_details=room,
                    layout_style=chosen_counter_style,
                    kitchen_spec=pick_right_spec(chosen_counter_style['chosen_counter_style']),
                    layout_room=pick_one_shot_example(room['room']['shape'])
                )

                layout_response = generate(layout_prompt, model_id, api_key)
                generated_room = _parse_json_response(layout_response)

                # Save generated layout
                output_path = os.path.join(output_dir, f"kitchen_{room['layout_id']}.json")
                with open(output_path, "w", encoding="utf-8") as out_file:
                    json.dump(generated_room, out_file, indent=4)

                print(f"Saved layout to {output_path}")

            except Exception as e:
                # Best-effort per room: report the failure and continue.
                print(f"Error processing room {room['layout_id']}: {e}")


if __name__ == "__main__":
    # Hand process_room_layouts to python-fire, which builds the
    # command-line interface from its parameters.
    fire.Fire(process_room_layouts)
