import torch
from typing import Union, List, Optional, Tuple
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


class GemmaWrapper:
    """Inference wrapper around a Gemma-3 vision-language model.

    Builds a single-turn user chat message from optional context passages,
    one or more query images and a text prompt, runs greedy decoding, and
    returns the decoded completion (prompt tokens excluded).
    """

    def __init__(
        self,
        model_name: str = "google/gemma-3-4b-it",
        max_new_tokens: int = 1024,
        torch_dtype=torch.bfloat16,
        device: str = "cuda",
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the GemmaWrapper with the Gemma-3 model and processor.

        Parameters:
            model_name (str): Hugging Face model id or local checkpoint path.
            max_new_tokens (int): Upper bound on generated tokens per call.
            torch_dtype: dtype the model weights are loaded in. At inference
                time floating-point inputs are cast to the model's dtype, so
                the two always agree.
            device (str): Stored on the instance only; model placement is
                handled by ``device_map="auto"``.
            attn_implementation (str): Attention backend forwarded to
                ``from_pretrained``. Defaults to ``"flash_attention_2"``
                (requires the flash-attn package); pass ``"sdpa"`` or
                ``"eager"`` where it is unavailable.
        """
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.device = device
        # Not read by generation (decoding is greedy); kept for API compatibility.
        self.temperature = 0.9

        # Load the model and processor
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_name)

    @staticmethod
    def _normalize_image_ref(image_ref):
        """Return *image_ref* with a leading ``file://`` scheme removed.

        The transformers image loader accepts http(s) URLs, existing local
        paths and base64 strings, but not ``file://`` URIs, so such URIs are
        turned into plain paths. Any other value is returned unchanged.
        """
        prefix = "file://"
        if isinstance(image_ref, str) and image_ref.startswith(prefix):
            return image_ref[len(prefix):]
        return image_ref

    def _build_message(
        self,
        image_path: Optional[Union[str, List[str]]],
        prompt: str,
        passages: Optional[List[Union[str, Tuple, dict]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> dict:
        """Assemble the single user chat message passed to the processor.

        Content order: passages (in the given order), then ``passage_prompt``,
        then the query image(s), then ``prompt``. Passage items that are not
        a dict, a 2-tuple or a str are ignored.
        """
        content = []

        def add_image(ref):
            content.append({"type": "image", "image": self._normalize_image_ref(ref)})

        def add_text(text):
            content.append({"type": "text", "text": text})

        if passages is not None:
            for passage in passages:
                if isinstance(passage, dict):
                    if passage.get("image_path"):
                        add_image(passage["image_path"])
                    # Skip None captions: the chat template would render them as "None".
                    if passage.get("caption") is not None:
                        add_text(passage["caption"])
                elif isinstance(passage, tuple) and len(passage) == 2:
                    passage_image, passage_text = passage
                    if passage_image:
                        add_image(passage_image)
                    if passage_text is not None:
                        add_text(passage_text)
                elif isinstance(passage, str):
                    add_text(passage)

        if passage_prompt:
            add_text(passage_prompt)

        # Query image(s); None means a text-only query.
        if image_path is None:
            query_images = []
        elif isinstance(image_path, list):
            query_images = image_path
        else:
            query_images = [image_path]
        for img in query_images:
            add_image(img)

        add_text(prompt)
        return {"role": "user", "content": content}

    def get_prediction(
        self,
        image_path: Union[str, List[str]],
        prompt: str,
        passages=None,
        passage_prompt=None,
    ) -> str:
        """
        Get a prediction from the model given image(s), a prompt, and optional passages.

        Parameters:
            image_path (str or List[str] or None): Local file path(s) or http(s)
                URL(s) of the query image(s). A leading ``file://`` is stripped.
                ``None`` sends a text-only query.
            prompt (str): The main prompt text.
            passages (optional): Additional context. Each item may be a str,
                an ``(image, text)`` tuple (either element may be falsy/None to
                omit it), or a dict with optional ``"image_path"`` and
                ``"caption"`` keys.
            passage_prompt (optional): An extra prompt placed after the passages
                and before the query image(s).

        Returns:
            str: The decoded prediction from the model (prompt excluded).
        """
        messages = [self._build_message(image_path, prompt, passages, passage_prompt)]

        # Cast to the model's own dtype rather than a hard-coded one so pixel
        # values match the weights; BatchFeature.to only casts floating-point
        # tensors, so input_ids stay integer.
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.model.dtype)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generated = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
            )

        # Single-sequence batch: drop the echoed prompt tokens.
        new_tokens = generated[0][input_len:]
        return self.processor.decode(new_tokens, skip_special_tokens=True)


# Example usage: run this module directly to describe a sample image.
if __name__ == "__main__":
    wrapper = GemmaWrapper()

    demo_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
    demo_prompt = "Describe this image in detail."

    answer = wrapper.get_prediction(demo_url, demo_prompt)
    print("Prediction:", answer)
