from pprint import pprint
import argparse
import base64
import json
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from threading import Thread
from typing import List, Literal, Optional, Union, get_args
import PIL
import requests
import torch
from PIL import Image as PILImage
from pydantic import BaseModel
from llava.entry import load
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.utils import disable_torch_init
from llava import media


class TextContent(BaseModel):
    """A plain-text part of a chat message's content list.

    ``type`` only validates when it is exactly the string ``"text"``.
    """

    type: Literal["text"]
    text: str  # the text payload of this part


class MediaURL(BaseModel):
    """Container holding the location of a media item as a single string.

    The string is not validated here beyond being a ``str``.
    NOTE(review): presumably an http(s) URL, local path, or base64 data URL
    of the kinds ``load_image`` accepts — confirm against callers.
    """

    url: str


class ImageContent(BaseModel):
    """An image part of a chat message's content list.

    ``type`` only validates when it is exactly ``"image_url"``; the image
    location itself lives in ``image_url.url``.
    """

    type: Literal["image_url"]
    image_url: MediaURL  # nested object, so the URL is at image_url.url


class ChatMessage(BaseModel):
    """One chat turn: the speaker's role and what they sent.

    ``content`` is either a plain string or an ordered list of parts, where
    each list item must validate as a ``TextContent`` or an ``ImageContent``.
    """

    role: Literal["user", "assistant"]  # only these two roles validate
    content: Union[str, List[Union[TextContent, ImageContent]]]


# Patterns for inline base64 data URLs. Group 1 is the media subtype
# (e.g. "png", "jpeg", "mp4"); group 2 is the base64 payload.
# re.DOTALL lets the payload span several lines: base64 encoders commonly
# wrap output (e.g. at 76 columns), and base64.b64decode (non-validating by
# default) discards the embedded newlines when decoding.
IMAGE_CONTENT_BASE64_REGEX = re.compile(r"^data:image/(png|jpe?g);base64,(.*)$", re.DOTALL)
VIDEO_CONTENT_BASE64_REGEX = re.compile(r"^data:video/(mp4);base64,(.*)$", re.DOTALL)


def load_image(image_url: str, *, timeout: Optional[float] = 60.0) -> PILImage.Image:
    """Load an image from a URL, a base64 data URL, or a local file, as RGB.

    Args:
        image_url: One of
            * an ``http://`` or ``https://`` URL, downloaded with ``requests``;
            * an absolute path starting with ``/`` (always treated as a file);
            * a ``data:image/(png|jpg|jpeg);base64,...`` data URL;
            * a path to any other existing local file (e.g. a relative path).
        timeout: Seconds ``requests`` waits for the server before giving up.
            ``None`` waits indefinitely.

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the download returns an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
        ValueError: If ``image_url`` matches none of the supported forms.
    """
    if image_url.startswith(("http://", "https://")):
        response = requests.get(image_url, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    elif image_url.startswith("/"):
        source = image_url
    else:
        match_results = IMAGE_CONTENT_BASE64_REGEX.match(image_url)
        if match_results is not None:
            source = BytesIO(base64.b64decode(match_results.group(2)))
        elif os.path.isfile(image_url):
            # Relative / non-POSIX-rooted paths to files that actually exist.
            source = image_url
        else:
            raise ValueError(f"Invalid image url: {image_url[:64]}")

    # The context manager closes the underlying file handle right away;
    # convert() loads the pixel data and returns a new, independent image.
    with PILImage.open(source) as image:
        return image.convert("RGB")


class NVILAWrapper:
    """Convenience wrapper around an NVILA model obtained from ``llava.entry.load``.

    ``get_prediction`` builds one interleaved prompt (context passages, an
    optional passage instruction, the query image(s), then the question)
    and passes it to the model's ``generate_content``.
    """

    def __init__(
        self,
        model_path: str = "Efficient-Large-Model/NVILA-Lite-15B",
        max_new_tokens: int = 1536,
        torch_dtype=torch.bfloat16,
        device: str = "cuda",
    ):
        """Load the model.

        Args:
            model_path: Model identifier / checkpoint location passed to ``load``.
            max_new_tokens: Stored as ``self.max_new_tokens``.
                NOTE(review): not forwarded to ``generate_content``, so it
                currently has no effect on generation — TODO wire it through
                the model's generation config.
            torch_dtype: NOTE(review): accepted but unused; ``load`` only
                receives ``model_path``.
            device: Stored as ``self.device`` but not passed to ``load``;
                placement is whatever ``load`` decides.
        """
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.model = load(model_path)

    @staticmethod
    def _passage_parts(passage) -> list:
        """Turn one passage into prompt parts: an optional image, then its text."""
        if isinstance(passage, str):
            return [passage]
        if isinstance(passage, dict):
            passage_image = passage.get("image_path")
            passage_text = passage["caption"]
        elif isinstance(passage, (tuple, list)) and len(passage) == 2:
            # Lists are accepted too: tuples become lists after a JSON round-trip.
            passage_image, passage_text = passage
        else:
            # Unrecognised passage shapes are skipped (long-standing behavior).
            return []
        parts = []
        if passage_image:  # None / "" means a text-only passage
            parts.append(load_image(passage_image))
        parts.append(passage_text)
        return parts

    def get_prediction(
        self,
        image_path: Optional[Union[str, List[str]]],
        prompt: str,
        passages: Optional[Union[List[str], List[tuple], List[dict]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> str:
        """Run the model on an interleaved passages + image(s) + question prompt.

        Args:
            image_path: A single image location (anything ``load_image``
                accepts) or a list/tuple of them. ``None`` or an empty
                sequence gives a text-only query.
            prompt: The question; always the last prompt element.
            passages: Optional context placed first, in order. Each item is a
                string; a 2-item ``(image, text)`` tuple or list whose image
                may be falsy; or a dict with ``"caption"`` and an optional
                ``"image_path"``. Other item types are skipped.
            passage_prompt: Optional instruction inserted after the passages
                and before the query image(s).

        Returns:
            Whatever ``self.model.generate_content`` returns for the prompt.

        Raises:
            KeyError: If a dict passage has no ``"caption"`` key.
            Any exception raised by ``load_image`` for an unloadable image.
        """
        combined_prompt = []
        for passage in passages or []:
            combined_prompt.extend(self._passage_parts(passage))

        if passage_prompt:
            combined_prompt.append(passage_prompt)

        if image_path is None:
            image_paths = []
        elif isinstance(image_path, (list, tuple)):
            image_paths = image_path
        else:
            image_paths = [image_path]
        combined_prompt.extend(load_image(path) for path in image_paths)

        combined_prompt.append(prompt)
        return self.model.generate_content(combined_prompt)


if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the original hard-coded
    # example, so running the script without arguments behaves as before.
    parser = argparse.ArgumentParser(description="Run a single NVILA prediction.")
    parser.add_argument(
        "--model-path",
        default="Efficient-Large-Model/NVILA-Lite-15B",
        help="Model identifier / checkpoint passed to NVILAWrapper.",
    )
    parser.add_argument(
        "--image",
        nargs="+",
        default="/drl_nas1/ckddls1321/data/coco/val2014/COCO_val2014_000000192168.jpg",
        help="One or more image URLs, data URLs, or local paths.",
    )
    parser.add_argument(
        "--prompt",
        default="What is the outfit this man is wearing called?",
        help="Question to ask about the image(s).",
    )
    args = parser.parse_args()

    # Initialize the NVILA wrapper with the requested model.
    nvila_wrapper = NVILAWrapper(model_path=args.model_path)

    # args.image is either the default string or a list of paths; get_prediction accepts both.
    result = nvila_wrapper.get_prediction(args.image, args.prompt)
    print("Prediction:", result)
