
import mimetypes
import os
import random
import sys
from datetime import datetime
from functools import partial

from fastapi import FastAPI, File, UploadFile,Form
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse
from google import genai
from openai import OpenAI

from generate_agents import generate_single_step
from utils import detect_text_save

async def to_thread(func, *args, **kwargs):
    """Run the blocking call ``func(*args, **kwargs)`` in the loop's default executor.

    Lets the async request handlers invoke synchronous, long-running helpers
    without stalling the event loop.

    Returns:
        Whatever ``func`` returns. An exception raised by ``func`` is re-raised
        to the awaiting coroutine.
    """
    # Inside a coroutine, get_running_loop() is the form the asyncio docs
    # recommend over get_event_loop().
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, so bind kwargs via partial.
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))

# Generation client handed to generate_single_step. The key comes from the
# environment so it never has to be committed to source; when the variable is
# unset an empty string is passed, exactly as before.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", ""))
# Alternative backend (OpenAI is imported at the top of this file):
# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))

# Model location forwarded to generate_single_step; empty unless MODEL_PATH is set.
model_path = os.environ.get("MODEL_PATH", "")

# Directory that receives raw uploads and the layout images derived from them.
UPLOAD_DIR = "uploads"
# Directory passed to generate_single_step for its generated files.
output_dir = "output"

app = FastAPI()

# Create both working directories up front; existing ones are left untouched.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

import asyncio

# GPU indices that requests may run on; each id is one token in gpu_queue.
GPU_IDS = [0, 1]

# Pool of currently free GPU ids. A request takes an id before running the
# model and puts it back afterwards, so at most len(GPU_IDS) requests hold a
# GPU at the same time.
# NOTE(review): on Python < 3.10 an asyncio.Queue binds to the event loop that
# is current at construction; if this module is imported before the serving
# loop starts, confirm it is the same loop.
gpu_queue = asyncio.Queue()
for free_gpu in GPU_IDS:
    gpu_queue.put_nowait(free_gpu)

async def process_with_model(file, action: str, gpu_id: int):
    """Save an uploaded UI screenshot, extract its text layout, and run one generation step.

    Args:
        file: The uploaded image (``UploadFile`` when called from ``upload_image``);
            its ``filename`` attribute and async ``read()`` are used.
        action: Action text forwarded unchanged to ``generate_single_step``.
        gpu_id: GPU index borrowed by the caller, forwarded to ``generate_single_step``.

    Returns:
        The first path returned by ``generate_single_step``, i.e. the new UI image.
    """
    # Several random bytes make concurrent uploads within the same second
    # (e.g. one per GPU) practically collision-free; a 2-digit suffix collided
    # about 1 time in 90 and let requests overwrite each other's files.
    random_suffix = os.urandom(4).hex()
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    run_id = f"{timestamp}_{random_suffix}"

    # The client-supplied name is untrusted: keep only its final path component
    # so it cannot point outside UPLOAD_DIR, and split the extension properly
    # rather than assuming it is exactly 4 characters (".jpeg", no extension, ...).
    original_name = os.path.basename(file.filename or "") or "upload.png"
    stem, _ = os.path.splitext(original_name)
    ui_image_path = os.path.join(UPLOAD_DIR, f"{run_id}_{original_name}")
    layout_image_path = os.path.join(UPLOAD_DIR, f"{run_id}_{stem}_layout.png")

    content = await file.read()
    with open(ui_image_path, "wb") as f:
        f.write(content)

    # detect_text_save is synchronous; run it in a worker thread so it does not
    # block the event loop for every other request.
    # NOTE(review): this allows two concurrent calls (one per GPU slot) — confirm
    # detect_text_save is safe to call from multiple threads.
    await to_thread(
        detect_text_save,
        ui_image_path,
        layout_image_path,
        min_width=10,
        min_height=10,
        output_path_id=None,
    )
    new_ui_path, _new_layout_path = await to_thread(
        generate_single_step,
        ui_image_path,
        layout_image_path,
        action,
        client,
        output_dir,
        run_id,
        gpu_id,
        model_path,
    )
    return new_ui_path

@app.post("/predict/")
async def upload_image(file: UploadFile = File(...), action: str = Form(...)):
    """Predict the next UI screen for an uploaded screenshot and an action.

    A GPU id is borrowed from ``gpu_queue`` for the whole request and is always
    returned, even when processing raises.

    Args:
        file: Uploaded UI screenshot (multipart form field ``file``).
        action: Action description (multipart form field ``action``).

    Returns:
        The generated image as a download named ``world_model.png``.
    """
    # Waits here while every GPU id is checked out.
    gpu_id = await gpu_queue.get()
    try:
        new_ui_path = await process_with_model(file, action, gpu_id)
        # Derive the Content-Type from the file actually produced instead of a
        # fixed "image/jpeg" that contradicted the .png download name; JPEG
        # remains the fallback when the extension is unknown.
        media_type = mimetypes.guess_type(new_ui_path)[0] or "image/jpeg"
        return FileResponse(
            path=new_ui_path,
            media_type=media_type,
            filename='world_model.png',
        )
    finally:
        # gpu_queue is unbounded, so put_nowait never blocks or raises QueueFull.
        gpu_queue.put_nowait(gpu_id)

