import argparse
import io
import re
import warnings
from typing import Any, List, Optional

import torch
from fastapi import FastAPI, File, Form, UploadFile
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

# Silence every warning category for the lifetime of the process
warnings.filterwarnings("ignore")

# Command-line interface: model checkpoint, listening port and target device
parser = argparse.ArgumentParser(description="Host a model with FastAPI.")
parser.add_argument(
    "--model_name",
    type=str,
    default="Salesforce/blip2-flan-t5-xl",
    help="Name of the model to use",
)
parser.add_argument(
    "--port",
    type=int,
    default=9555,
    help="Port to run the server on",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help='Device to run the model on (e.g., "cpu", "cuda")',
)
args = parser.parse_args()

app = FastAPI()

# Precision requested for the model weights: bfloat16 whenever the device
# string mentions CUDA, otherwise "auto"
if "cuda" in args.device:
    dtype = "bfloat16"
else:
    dtype = "auto"


def try_load_model_high_precision(model_id: str, original_dtype: str) -> Any:
    """Load a BLIP-2 checkpoint, preferring float32 over a reduced-precision request.

    Different versions of BLIP2 in HF are showing sensitivity to image
    resolution / fp precision, so when the requested dtype is "bfloat16",
    "float16" or "auto" a float32 load is attempted first. If that attempt
    raises, the reason is printed and the model is loaded with the requested
    dtype instead. Any other requested dtype is used directly.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        original_dtype: The dtype originally requested for the weights.

    Returns:
        The object returned by ``Blip2ForConditionalGeneration.from_pretrained``.

    Raises:
        Exception: Whatever ``from_pretrained`` raises when loading with
            ``original_dtype`` fails as well.
    """
    if original_dtype in ("bfloat16", "float16", "auto"):
        try:
            return Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype="float32")
        except Exception as e:
            # Best-effort upgrade: say why float32 was not possible instead of
            # hiding it, then fall through to the requested dtype.
            print(f"Could not load {model_id} in float32 ({e}); falling back to {original_dtype}.")

    return Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=original_dtype)


# Drop an optional "server-" prefix from the device string:
# server-cuda:0 -> cuda:0 // server-cuda -> cuda // cuda -> cuda // cuda:0 -> cuda:0
args.device = re.sub(r"^server-(cuda(:\d+)?)$", r"\1", args.device)
torch.cuda.empty_cache()

# Guard clause: anything whose name does not contain "blip2" is rejected up front
if "blip2" not in args.model_name.lower():
    raise ValueError(f"Model {args.model_name} is not supported. Currently, only BLIP-2 models are supported.")

processor = Blip2Processor.from_pretrained(args.model_name)
model = try_load_model_high_precision(args.model_name, dtype)
model.to(args.device)


@app.post("/")
async def root() -> dict:
    """Answer a POST to the root path with a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting


def _generate_captions(images: List[Any], prompts: List[str], max_new_tokens: int) -> List[str]:
    """Caption one batch: preprocess the inputs, generate token ids, decode them to text."""
    inputs = processor(images=images, text=prompts, return_tensors="pt")
    # Move to the serving device and cast to the dtype the model was actually
    # loaded with; the module-level `dtype` is only the *requested* precision
    # (a string) and may differ from what the loader ended up using.
    inputs = inputs.to(args.device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)


@app.post("/caption/")
async def caption_images(
    files: List[UploadFile] = File(...), prompt: Optional[List[str]] = Form(None), max_new_tokens: int = Form(32)
) -> dict:
    """Caption the uploaded images, optionally conditioned on per-image prompts.

    The whole batch is captioned in a single pass first. If that raises, the
    failure is printed and every image is retried on its own so that one bad
    input does not cost the captions of the others.

    Args:
        files: Uploaded image files.
        prompt: Optional prompts, paired with ``files`` by position. When
            missing or empty, an empty prompt is used for every image.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        ``{"captions": [...]}`` when the batch pass succeeds. Otherwise
        ``{"captions": [...], "errors": [...]}``, where a failed image gets an
        empty caption and ``errors`` holds one message string per failure.
    """
    images = [Image.open(io.BytesIO(await file.read())) for file in files]

    if not prompt:  # Obs: new versions of Blip2 requires empty prompts as list of empty strings
        prompt = [""] * len(images)

    # Try to caption full batch
    try:
        return {"captions": _generate_captions(images, prompt, max_new_tokens)}
    except Exception as batch_err:
        # Do not hide the reason the fast path failed; the per-image retry follows.
        print(f"Batch captioning failed, retrying one image at a time: {batch_err}")

    # Try to caption one by one
    captions = []
    errors = []
    # NOTE: zip stops at the shorter sequence, so a prompt list shorter than
    # the image list yields fewer captions than images.
    for img, p in zip(images, prompt):
        try:
            captions.append(_generate_captions([img], [p], max_new_tokens)[0])
        except Exception as e:
            print(f"Error: {e}")
            # Report the message text: the response is serialised to JSON and
            # a raw exception object carries no useful payload there.
            errors.append(str(e))
            captions.append("")

    return {"captions": captions, "errors": errors}


if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a script
    from uvicorn import run as serve

    # Bind to every interface on the port chosen on the command line
    serve(app, host="0.0.0.0", port=args.port)
