# sam2_api.py

import base64
import contextlib
import io
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model

# FastAPI application exposing the segmentation endpoint defined below.
app = FastAPI()

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model is built once, at import time, so every request reuses the same
# weights. /predict talks to it through the processor's set_image /
# set_text_prompt calls.
model = build_sam3_image_model(
    checkpoint_path="facebook/sam3/sam3.pt",
    device=device,
)
segmentation_model = Sam3Processor(model)

class PredictRequest(BaseModel):
    """JSON request body for ``POST /predict``."""

    # Base64 text of the encoded input image file.
    image_base64: str
    # Boxes given as [x1, y1, x2, y2].
    # NOTE(review): required by the schema, but predict() never reads it.
    bboxes: List[List[float]]
    # Noun phrases; predict() issues each one as its own text prompt.
    nouns: List[str]

class PredictResponse(BaseModel):
    """JSON response body returned by ``POST /predict``."""

    # Base64 text of a PNG holding the union of all predicted masks.
    mask_base64: str

def decode_image(base64_str: str) -> Image.Image:
    """Decode base64 image-file text into an RGB PIL image.

    Accepts either a bare base64 payload or a data URL of the form
    ``data:<mime>;base64,<payload>`` (what browsers produce for uploaded or
    canvas images).

    Args:
        base64_str: base64 text of an encoded image file (PNG, JPEG, ...).

    Returns:
        The decoded image converted to mode ``"RGB"``.

    Raises:
        ValueError: if the text is not valid base64 (``binascii.Error`` is a
            ValueError subclass) or contains non-ASCII characters.
        OSError: if PIL cannot identify or fully read the image
            (e.g. ``PIL.UnidentifiedImageError``).
    """
    # ',' and ':' are not in the base64 alphabet, so a genuine bare payload
    # never matches this; only data-URL input is affected. Without stripping,
    # b64decode would silently drop the punctuation and decode the header
    # letters ("dataimage/png...") as part of the payload.
    header, sep, payload = base64_str.partition(",")
    if sep and header.startswith("data:"):
        base64_str = payload
    image_data = base64.b64decode(base64_str)
    # Close the source image once the independent RGB copy has been made.
    with Image.open(io.BytesIO(image_data)) as img:
        return img.convert("RGB")

def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Serialise a mask as base64 text of a PNG image.

    Each value is scaled by 255 and cast to uint8, so a boolean mask becomes
    an image with 0 where the mask is unset and 255 where it is set.
    """
    pixels = (mask * 255).astype(np.uint8)
    with io.BytesIO() as png_buffer:
        Image.fromarray(pixels).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")

@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Segment every instance of each noun phrase and return the union mask.

    The image is passed to ``set_image`` once. Each noun in ``request.nouns``
    is then issued as a text prompt against that shared state. Every returned
    instance mask is OR-ed into a single boolean mask the size of the input
    image.

    NOTE(review): ``request.bboxes`` is accepted by the request schema but is
    not used as a prompt here; only the text prompts drive segmentation.

    Returns:
        PredictResponse whose ``mask_base64`` is a base64 PNG: 255 where any
        predicted mask is set, 0 elsewhere. An empty ``nouns`` list yields an
        all-zero mask.

    Raises:
        HTTPException: 400 if ``image_base64`` cannot be decoded as an image.
    """
    try:
        image = decode_image(request.image_base64)
    except (ValueError, OSError, Image.DecompressionBombError) as err:
        # Malformed client input is a client error, not a 500.
        raise HTTPException(
            status_code=400, detail=f"Could not decode image_base64: {err}"
        ) from err

    # bf16 autocast only applies on CUDA. On a CPU-only host a "cuda"
    # autocast would just be disabled by torch (with a warning), so use a
    # no-op context instead.
    amp_context = (
        torch.autocast("cuda", dtype=torch.bfloat16)
        if device == "cuda"
        else contextlib.nullcontext()
    )

    mask_all = np.zeros((image.height, image.width), dtype=bool)
    with torch.inference_mode(), amp_context:
        inference_state = segmentation_model.set_image(image)
        for noun in request.nouns:
            output = segmentation_model.set_text_prompt(
                state=inference_state, prompt=noun
            )
            for mask in output["masks"]:
                # mask[0] drops each instance's leading axis (presumably a
                # single channel -- TODO confirm the processor's mask shape).
                mask_all |= mask[0].cpu().numpy().astype(bool)

    # PNG encoding needs neither inference mode nor autocast.
    return PredictResponse(mask_base64=encode_mask_to_base64(mask_all))