import torch
from PIL import Image
from transformers import AutoModelForCausalLM


class OvisWrapper:
    """Thin inference wrapper around an Ovis multimodal causal LM.

    Builds an Ovis-style prompt (``<image>`` placeholders followed by text),
    optionally prefixed with retrieved "passages" (text and/or images), and
    runs greedy decoding to produce a single text answer.
    """

    def __init__(self, model_path="AIDC-AI/Ovis2-1B", device="cuda", verbose=False):
        """
        Initialize the OvisWrapper with the Ovis model and tokenizers.

        Args:
            model_path (str): Path to the pretrained Ovis model (default: "AIDC-AI/Ovis2-1B").
            device (str): Device that input tensors are moved to, and that the
                model is moved to after loading (default: "cuda").
            verbose (bool): Whether to print error messages (default: False).
        """
        # NOTE(review): `.to(device)` after `device_map="auto"` assumes the whole
        # model lands on a single device; with a multi-GPU or offloaded device
        # map, moving the dispatched model may warn or fail — confirm.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            multimodal_max_length=32768,
            trust_remote_code=True,  # Ovis ships its modeling code on the Hub
        ).to(device)
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.visual_tokenizer = self.model.get_visual_tokenizer()
        self.device = device
        self.verbose = verbose
        # Sentinel string returned instead of an answer when inputs can't be loaded.
        self.fail_msg = "Failed to obtain answer via API."

    def get_prediction(
        self,
        image_path=None,
        prompt=None,
        max_partition=None,
        passages=None,
        passage_prompt=None,
    ):
        """
        Generate a prediction based on the provided image(s), prompt and passages.

        The final prompt is, in order: ``passage_prompt``, the rendered
        passages, then the query (image placeholders followed by ``prompt``).
        Images are handed to the model in the same order as their ``<image>``
        placeholders appear: passage images first, then query images.

        Args:
            image_path (str, list or None): Path to a single image, a list of
                image paths, or ``None`` for a text-only query.
            prompt (str, optional): Text prompt to accompany the image(s) or
                standalone query.
            max_partition (int, optional): Ovis image-partition parameter.
                Defaults to 4 when more than one image in total (query plus
                passages) is supplied, otherwise 9.
            passages (list, optional): Retrieved context placed before the
                query. Each item is one of:
                  * ``str`` - raw text;
                  * ``dict`` with optional ``"image_path"`` and ``"caption"``;
                  * ``(image_path, text)`` pair (tuple or 2-element list).
                Items of any other shape are ignored.
            passage_prompt (str, optional): Header text placed before the
                passages.

        Returns:
            str: The model's generated response, or ``self.fail_msg`` if any
            image file (query or passage) cannot be found.
        """
        if isinstance(image_path, str):
            query_paths = [image_path]
        else:
            query_paths = list(image_path or [])

        passage_text, passage_paths, passage_count = self._format_passages(passages)
        # Must match placeholder order in the prompt: passages first, then query.
        all_paths = passage_paths + query_paths

        if max_partition is None:
            # Fewer partitions per image when several images share the context.
            max_partition = 4 if len(all_paths) > 1 else 9

        # Continue numbering after the passages so "Image N" labels don't collide.
        query = self._format_query(
            len(query_paths), prompt, first_index=passage_count + 1
        )
        full_query = "\n".join(
            part for part in (passage_prompt, passage_text, query) if part
        )

        pil_images = []
        for path in all_paths:
            try:
                pil_images.append(self._open_image(path))
            except FileNotFoundError:
                if self.verbose:
                    print(f"Image file not found: {path}")
                return self.fail_msg

        return self._generate(full_query, pil_images, max_partition)

    @staticmethod
    def _format_passages(passages):
        """Render passages as prompt text and collect their image paths.

        Returns:
            tuple[str, list, int]: The newline-joined passage text (one
            ``<image>`` placeholder per collected path), the image paths in
            placeholder order, and how many numbered (dict / pair) passages
            were seen.
        """
        segments = []
        image_paths = []
        number = 0
        for passage in passages or ():
            if isinstance(passage, str):
                segments.append(passage)
            elif isinstance(passage, dict):
                # Dict passages are numbered even when they carry only a caption.
                number += 1
                if "image_path" in passage:
                    image_paths.append(passage["image_path"])
                    segments.append(f"Image {number}: <image>")
                if "caption" in passage:
                    segments.append(f"Caption {number}: {passage['caption']}")
            elif isinstance(passage, (tuple, list)) and len(passage) == 2:
                number += 1
                passage_image_path, passage_text = passage
                image_paths.append(passage_image_path)
                segments.append(f"Image {number}: <image>\n{passage_text}")
        return "\n".join(segments), image_paths, number

    @staticmethod
    def _format_query(num_images, prompt, first_index=1):
        """Build the query part of the prompt: image placeholders then ``prompt``.

        A lone image with nothing before it gets a bare ``<image>`` placeholder;
        otherwise each placeholder is labelled ``Image N`` starting at
        ``first_index``. With no images the prompt is returned as-is.
        """
        prompt = prompt or ""
        if num_images == 0:
            return prompt
        if num_images == 1 and first_index == 1:
            return f"<image>\n{prompt}" if prompt else "<image>"
        labels = "\n".join(
            f"Image {first_index + offset}: <image>" for offset in range(num_images)
        )
        return f"{labels}\n{prompt}"

    @staticmethod
    def _open_image(path):
        """Decode the image at ``path`` into memory and release its file handle."""
        with Image.open(path) as image:
            image.load()  # decode now so the file can be closed on exit
            return image.copy()

    def _generate(self, query, pil_images, max_partition):
        """Tokenize ``query`` with ``pil_images`` and greedily decode a response."""
        _, input_ids, pixel_values = self.model.preprocess_inputs(
            query, pil_images, max_partition=max_partition
        )

        # Prepare a batch of one for the model.
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(
                dtype=self.visual_tokenizer.dtype, device=self.visual_tokenizer.device
            )
        # generate() takes one pixel_values entry per batch sample; a text-only
        # sample is passed as [None] rather than a bare None (mirrors the
        # upstream Ovis2 usage example — TODO confirm for the pinned revision).
        pixel_values = [pixel_values]

        gen_kwargs = dict(
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding; sampling knobs disabled below
            top_p=None,
            top_k=None,
            temperature=None,
            repetition_penalty=None,
            eos_token_id=self.model.generation_config.eos_token_id,
            pad_token_id=self.text_tokenizer.pad_token_id,
            use_cache=True,
        )
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                **gen_kwargs,
            )[0]
        return self.text_tokenizer.decode(output_ids, skip_special_tokens=True)


if __name__ == "__main__":
    import argparse

    # Command-line overrides; every default reproduces the original demo run.
    parser = argparse.ArgumentParser(
        description="Run a retrieval-augmented Ovis prediction demo."
    )
    parser.add_argument(
        "--model-path",
        default="AIDC-AI/Ovis2-8B",
        help="Pretrained Ovis model id or local path.",
    )
    parser.add_argument(
        "--image",
        default="/drl_nas1/ckddls1321/data/coco/val2014/COCO_val2014_000000192168.jpg",
        help="Query image path (also reused as the demo passage image).",
    )
    parser.add_argument(
        "--prompt",
        default="What is the outfit this man is wearing called?",
        help="Question asked about the query image.",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print image-loading errors."
    )
    args = parser.parse_args()

    # Loads the real model weights and tokenizers (slow, needs a GPU by default).
    ovis_wrapper = OvisWrapper(model_path=args.model_path, verbose=args.verbose)

    # Passages as a list of dicts with 'image_path' and 'caption'; for the demo
    # the query image doubles as its own "retrieved" passage.
    passages = [
        {"image_path": args.image, "caption": "The monk is wearing robe."},
    ]
    passage_prompt = "Here is relevant image and their corresponding description"

    result = ovis_wrapper.get_prediction(
        image_path=args.image,
        prompt=args.prompt,
        passages=passages,
        passage_prompt=passage_prompt,
    )
    print("Prediction:", result)
