import json
import os
import re

import fire
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dotenv import load_dotenv
from google import genai
from google.genai import types
from google.genai.types import GenerateContentConfig

from prompts import room_generation_prompt


def load_api_key() -> str:
    """Return the Gemini API key taken from the process environment.

    ``load_dotenv()`` is called first, so a key defined in a ``.env`` file
    (if python-dotenv finds one) can also be picked up.

    Returns:
        The value of the ``GEMINI_API_KEY`` environment variable.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is unset or set to an empty string.
    """
    load_dotenv()
    key = os.environ.get("GEMINI_API_KEY", "")
    if key:
        return key
    raise ValueError("GEMINI_API_KEY not found in environment variables.")


def generate(prompt: str, model_id: str, api_key: str) -> str:
    """Send *prompt* to a Gemini model and return the complete streamed reply.

    Args:
        prompt: Text sent as a single user turn.
        model_id: Name of the Gemini model to query.
        api_key: API key used to build the ``genai.Client``.

    Returns:
        The concatenation of every non-empty text chunk in the stream.
    """
    client = genai.Client(api_key=api_key)
    request_config = GenerateContentConfig(response_mime_type="text/plain")
    user_turn = types.Content(
        role="user",
        parts=[types.Part.from_text(text=prompt)],
    )
    stream = client.models.generate_content_stream(
        model=model_id,
        contents=[user_turn],
        config=request_config,
    )
    # Chunks whose ``text`` is empty/None are skipped.
    return "".join(chunk.text for chunk in stream if chunk.text)


def save_json(data, output_path: str):
    """Write *data* to *output_path* as JSON indented by 4 spaces.

    Missing parent directories are created first.

    Args:
        data: Any JSON-serializable object.
        output_path: Destination file path. A bare filename (no directory
            component) is written relative to the current working directory.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)


# Number of rooms a single call with ``room_generation_prompt`` is expected to
# return. NOTE(review): taken from the original ``room_number // 20`` split —
# confirm it matches what the prompt in ``prompts.py`` actually asks for.
ROOMS_PER_BATCH = 20

# Opening Markdown code fence with an optional language tag
# (```json, ```JSON, or a bare ```).
_OPENING_FENCE = re.compile(r"^```[\w-]*")


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if present."""
    cleaned = _OPENING_FENCE.sub("", text.strip(), count=1)
    return cleaned.removesuffix("```").strip()


def run_generation(
    model_id: str = "gemini-2.5-pro-preview-05-06",
    room_number: int = 600,
    output_dir: str = "generated_results/base_room",
):
    """Generate rooms in batches and save each parsed batch as a JSON file.

    Runs ``room_number // ROOMS_PER_BATCH`` model calls. Each reply has any
    Markdown code fence removed and is parsed as JSON. On success it is written
    to ``<output_dir>/generated_kitchens_<i>.json``, where ``i`` is 0-based.
    Batches whose reply is not valid JSON are reported and skipped.

    Args:
        model_id: Gemini model name passed to ``generate``.
        room_number: Total number of rooms requested. Any remainder not
            divisible by ``ROOMS_PER_BATCH`` is dropped, with a warning.
        output_dir: Directory receiving the per-batch JSON files.
    """
    api_key = load_api_key()
    prompt = room_generation_prompt
    num_batches = room_number // ROOMS_PER_BATCH

    dropped = room_number - num_batches * ROOMS_PER_BATCH
    if dropped:
        print(
            f"Warning: room_number={room_number} is not a multiple of "
            f"{ROOMS_PER_BATCH}; {dropped} room(s) will not be generated."
        )

    for i in range(num_batches):
        print(f"Generating batch {i + 1} of {num_batches}...")
        raw_output = generate(prompt, model_id, api_key)
        try:
            processed_rooms = json.loads(_strip_code_fence(raw_output))
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON for batch {i + 1} of {num_batches}: {e}")
            continue

        # File index stays 0-based to keep existing output file names stable.
        output_path = os.path.join(output_dir, f"generated_kitchens_{i}.json")
        save_json(processed_rooms, output_path)
        print(f"Saved to {output_path}")


if __name__ == "__main__":
    # Expose run_generation's keyword arguments as command-line flags.
    fire.Fire(component=run_generation)
