import os
from typing import Optional, Union, List, Tuple
import torch
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoTokenizer,
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from qwen_vl_utils import process_vision_info  # your custom vision processing utility


class QWenWrapper:
    """Wrapper around a Hugging Face Qwen2-VL / Qwen2.5-VL checkpoint.

    The model and its processor are loaded once in ``__init__``.
    ``get_prediction`` then assembles a single-turn chat message from the
    query image(s), optional context passages and a text prompt, and
    returns the greedily decoded answer.
    """

    # Pixel budget attached to query images and dict-style passage images.
    MAX_PIXELS = 512 * 512

    # References starting with one of these are passed through unchanged;
    # anything else is treated as a local path and gets a "file://" prefix.
    _URI_PREFIXES = ("file://", "http://", "https://", "data:image")

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct",
        max_new_tokens: int = 1536,
        torch_dtype=torch.bfloat16,
        device: str = "cuda",
    ):
        """Load the model and processor for ``model_name``.

        Args:
            model_name: Checkpoint id or path. Must contain ``"Qwen2.5-VL"``
                or ``"Qwen2-VL"`` so the matching model class can be chosen.
            max_new_tokens: Upper bound on generated tokens per prediction.
            torch_dtype: Dtype used to load the weights. Overridden with
                ``torch.float16`` when ``model_name`` contains ``"AWQ"``.
            device: Stored on the instance only; actual placement is decided
                by ``device_map="auto"`` when the model is loaded.

        Raises:
            ValueError: If ``model_name`` matches neither supported family.
        """
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.temperature = 0.9  # kept for API compatibility
        if "AWQ" in model_name:
            torch_dtype = torch.float16

        # Pick exactly one model class; fail fast instead of leaving
        # ``self.model`` undefined for an unrecognised checkpoint name.
        if "Qwen2.5-VL" in model_name:
            model_cls = Qwen2_5_VLForConditionalGeneration
        elif "Qwen2-VL" in model_name:
            model_cls = Qwen2VLForConditionalGeneration
        else:
            raise ValueError(
                f"Unsupported model_name {model_name!r}: expected a name "
                "containing 'Qwen2.5-VL' or 'Qwen2-VL'."
            )

        self.model = model_cls.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map="auto",
            attn_implementation="flash_attention_2",
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    @classmethod
    def _to_image_uri(cls, path: str) -> str:
        """Return ``path`` as an image reference, prefixing bare paths with ``file://``."""
        if path.startswith(cls._URI_PREFIXES):
            return path
        return "file://" + path

    @classmethod
    def _build_message(
        cls,
        image_path: Union[str, List[str], Tuple[str, ...]],
        prompt: str,
        passages: Optional[List[Union[str, Tuple[str, str], dict]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> dict:
        """Build the single user message: passages, passage prompt, images, prompt."""
        if not isinstance(image_path, (list, tuple)):
            image_path = [image_path]

        content = []

        for passage in passages or ():
            if isinstance(passage, dict):
                # Dict passage: optional "image_path", optional "caption".
                if "image_path" in passage:
                    content.append(
                        {
                            "type": "image",
                            "image": cls._to_image_uri(passage["image_path"]),
                            "max_pixels": cls.MAX_PIXELS,
                        }
                    )
                if "caption" in passage:
                    content.append({"type": "text", "text": passage["caption"]})
            elif isinstance(passage, tuple) and len(passage) == 2:
                # Tuple passage: (image, text); a falsy image means text only.
                passage_image, passage_text = passage
                if passage_image:
                    # NOTE(review): unlike the other images, this one carries
                    # no "max_pixels" cap -- confirm whether that is intended.
                    content.append(
                        {"type": "image", "image": cls._to_image_uri(passage_image)}
                    )
                content.append({"type": "text", "text": passage_text})
            elif isinstance(passage, str):
                content.append({"type": "text", "text": passage})
            # Any other passage shape is skipped.

        if passage_prompt:
            content.append({"type": "text", "text": passage_prompt})

        for img in image_path:
            content.append(
                {
                    "type": "image",
                    "image": cls._to_image_uri(img),
                    "max_pixels": cls.MAX_PIXELS,
                }
            )
        content.append({"type": "text", "text": prompt})

        return {"role": "user", "content": content}

    def get_prediction(
        self,
        image_path: Union[str, List[str]],
        prompt: str,
        passages: Optional[List[Union[str, Tuple[str, str], dict]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> str:
        """Generate an answer for the image(s), prompt and optional passages.

        The signature is kept compatible with the earlier vLLM-based
        implementation.

        Args:
            image_path: One image reference or a list of them. Bare paths get
                a ``file://`` prefix; ``file://``, ``http(s)://`` and
                ``data:image`` references are used as given.
            prompt: The question / instruction, placed last in the message.
            passages: Optional context placed before the query images. Each
                entry may be a plain string, an ``(image, text)`` tuple, or a
                dict with optional ``"image_path"`` and ``"caption"`` keys.
                Entries of any other shape are skipped.
            passage_prompt: Optional text inserted between the passages and
                the query images.

        Returns:
            The decoded text of the first generated sequence, or ``""`` if
            decoding produced nothing.
        """
        messages = [
            self._build_message(image_path, prompt, passages, passage_prompt)
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the device the model was actually placed on instead of
        # assuming "cuda".
        inputs = inputs.to(self.model.device)

        # Greedy decoding: sampling is disabled, so self.temperature is unused.
        generated_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
        )
        # Drop the echoed prompt tokens so only the new tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return output_text[0] if output_text else ""


# Example usage:

if __name__ == "__main__":
    # Demo: ask one question about a single COCO validation image using the
    # default checkpoint.
    demo_image = "/drl_nas1/ckddls1321/data/coco/val2014/COCO_val2014_000000192168.jpg"
    demo_question = "What is the outfit this man is wearing called?"

    # Loading the wrapper pulls in the model weights and the processor.
    wrapper = QWenWrapper()

    # A bare string is accepted for a single image.
    answer = wrapper.get_prediction(demo_image, demo_question)
    print("Prediction:", answer)
