import torch
from typing import Union, List, Optional, Dict
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig


class Phi4MultimodalWrapper:
    """Inference wrapper around a Phi-4 multimodal instruct checkpoint.

    The wrapper loads the processor, the causal-LM model and the generation
    config via ``from_pretrained`` and exposes :meth:`get_prediction`, which
    assembles a single-turn chat prompt from optional context passages
    (image and/or caption), an optional passage prompt, one or more query
    images and a user text prompt, then greedily decodes a response.

    NOTE: everything is loaded with ``trust_remote_code=True``, i.e. Python
    code shipped with the checkpoint is executed. Only point ``model_name``
    at sources you trust.
    """

    # Chat-template tags placed around the user turn.
    _USER_TAG = "<|user|>"
    _ASSISTANT_TAG = "<|assistant|>"
    _END_TAG = "<|end|>"
    # Seconds to wait on a remote image before giving up; without a timeout
    # ``requests`` can block indefinitely on an unresponsive host.
    _HTTP_TIMEOUT = 30

    def __init__(
        self,
        model_name: str = "microsoft/Phi-4-multimodal-instruct",
        torch_dtype=torch.bfloat16,
        device: str = "cuda",
        attn_implementation: str = "flash_attention_2",
        system_prompt: str = "<|system|>You are a helpful assistant.<|end|>",
        max_new_tokens: int = 1536,
    ):
        """
        Initialize the Phi-4 multimodal instruct model, processor, and default system prompt.

        Args:
            model_name: Checkpoint identifier or path passed to ``from_pretrained``.
            torch_dtype: Dtype the model weights are loaded in.
            device: Device the model is placed on and inputs are moved to.
            attn_implementation: Attention backend forwarded to the model as
                ``_attn_implementation``.
            system_prompt: Fully tagged system segment prepended to every prompt.
            max_new_tokens: Upper bound on generated tokens per call.
        """
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.system_prompt = system_prompt

        # Load processor and model with remote code
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # Place the whole model on ``device`` at load time. The previous
        # ``device_map="auto"`` + ``.to(device)`` combination conflicts once
        # accelerate has dispatched (or offloaded) the model across devices,
        # and inputs are always sent to ``self.device`` anyway.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            _attn_implementation=attn_implementation,
        )

        # Load generation config
        self.generation_config = GenerationConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )

    @staticmethod
    def _normalize_images(images) -> List[str]:
        """Return ``images`` as a list: a lone string is wrapped, ``None`` becomes empty."""
        if images is None:
            return []
        if isinstance(images, str):
            return [images]
        return list(images)

    @staticmethod
    def _is_remote(source: str) -> bool:
        """Tell whether ``source`` is an http(s) URL rather than a local path."""
        return source.startswith(("http://", "https://"))

    @staticmethod
    def _collect_image_sources(
        images: List[str],
        passages: Optional[List[Dict[str, str]]] = None,
    ) -> List[str]:
        """List image sources in placeholder order: passage images, then main images."""
        sources = [
            passage["image_path"]
            for passage in passages or ()
            if passage.get("image_path")
        ]
        sources.extend(images)
        return sources

    def _build_prompt(
        self,
        images: List[str],
        prompt: str,
        passages: Optional[List[Dict[str, str]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> str:
        """Assemble the tagged prompt text with one numbered placeholder per image."""
        parts: List[str] = [self.system_prompt, self._USER_TAG]
        image_index = 1

        # Passages come first; a passage may carry an image, a caption, or both.
        for passage in passages or ():
            if passage.get("image_path"):
                parts.append(f"<|image_{image_index}|>")
                image_index += 1
            caption = passage.get("caption")
            if caption:
                parts.append(caption)

        if passage_prompt:
            parts.append(passage_prompt)

        # Main images continue the numbering started by the passage images.
        for _ in images:
            parts.append(f"<|image_{image_index}|>")
            image_index += 1

        parts.extend((prompt, self._END_TAG, self._ASSISTANT_TAG))
        # NOTE: segments are joined without separators; callers that want
        # whitespace between captions and prompts must include it themselves.
        return "".join(parts)

    @classmethod
    def _load_image(cls, source: str):
        """Load one image from an http(s) URL or a local path as a PIL image."""
        from PIL import Image

        if cls._is_remote(source):
            import io

            import requests

            response = requests.get(source, timeout=cls._HTTP_TIMEOUT)
            # Fail on HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
        else:
            image = Image.open(source)
        # Decode eagerly so a corrupt image fails here and the lazily-opened
        # file does not need to stay open until the processor touches it.
        image.load()
        return image

    def get_prediction(
        self,
        images: Union[str, List[str]],
        prompt: str,
        passages: Optional[List[Dict[str, str]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate a response conditioned on a system prompt, optional passages (image + caption),
        an optional passage_prompt, one or more images, and a user text prompt.

        Args:
            images: One image source or a list of them; each is an http(s)
                URL or a local file path.
            prompt: User text placed after all image placeholders.
            passages: Optional context entries, each a dict that may contain
                ``"image_path"`` and/or ``"caption"``.
            passage_prompt: Optional text inserted after the passages and
                before the main images.

        Returns:
            The decoded text of the newly generated tokens for the single
            prompt in the batch.

        Raises:
            requests.HTTPError: If a remote image responds with an error status.
            OSError: If an image cannot be opened or decoded.
        """
        main_images = self._normalize_images(images)
        prompt_text = self._build_prompt(
            main_images, prompt, passages, passage_prompt
        )

        # Images must be supplied in the same order as their placeholders.
        pil_images = [
            self._load_image(source)
            for source in self._collect_image_sources(main_images, passages)
        ]

        inputs = self.processor(
            text=prompt_text,
            # A text-only request passes None rather than an empty list.
            images=pil_images or None,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                generation_config=self.generation_config,
            )

        # Drop the echoed prompt tokens; keep only what the model produced.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[:, prompt_length:]
        return self.processor.batch_decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]


# Example usage
if __name__ == "__main__":
    # Demo call: two context passages (one URL image, one filesystem path),
    # a passage prompt, and a single query image given as a URL.
    demo_model = Phi4MultimodalWrapper()
    context_passages = [
        dict(image_path="https://example.com/img1.jpg", caption="First scene"),
        dict(image_path="/local/path/img2.png", caption="Second scene"),
    ]
    answer = demo_model.get_prediction(
        "https://www.ilankelman.org/stopsigns/australia.jpg",
        "Please summarize all visuals.",
        passages=context_passages,
        passage_prompt="Here are additional context captions.",
    )
    print("Response:", answer)
