import torch
from PIL import Image
from transformers import (
    LlavaNextProcessor, LlavaNextForConditionalGeneration,
    AutoProcessor, AutoModelForVision2Seq,
    Idefics2Processor, Idefics2ForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, GenerationConfig,
    InstructBlipProcessor, InstructBlipForConditionalGeneration, IdeficsForVisionText2Text,
    Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
)

def load_model(args):
    """Load a vision-language model together with its processor and tokenizer.

    Parameters:
        args: Object exposing ``model_type`` (str, matched case-insensitively)
            and ``model_path`` (forwarded unchanged to ``from_pretrained``).

    Returns:
        tuple: ``(model, processor, tokenizer)``.

    Raises:
        ValueError: If ``args.model_type`` is not one of the supported types.
    """
    supported = ("llava", "qwen", "instructionblip", "idefics")
    model_type = args.model_type.lower()
    model_path = args.model_path

    # Fail fast, before any weights are touched, and build the message from
    # the tuple above so it always lists the types that are really handled.
    if model_type not in supported:
        raise ValueError(
            f"Unsupported model type: {model_type}. Supported: {', '.join(supported)}"
        )

    # LLaVA and InstructBLIP are loaded in float16; Qwen and IDEFICS below
    # explicitly request bfloat16 instead.
    dtype = torch.float16

    if model_type == "llava":
        print(f"Loading LLaVA model from {model_path} ...")
        model = LlavaNextForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=dtype,
        )
        processor = LlavaNextProcessor.from_pretrained(model_path)
        tokenizer = processor.tokenizer
        return model, processor, tokenizer

    if model_type == "qwen":
        print(f"Loading Qwen-VL model from {model_path} ...")
        # The lower and upper pixel bounds are deliberately the same value.
        # NOTE(review): presumably this pins the per-image visual token
        # budget -- confirm against the processor documentation.
        pixel_budget = 256 * 28 * 28
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        processor = AutoProcessor.from_pretrained(
            model_path,
            min_pixels=pixel_budget,
            max_pixels=pixel_budget,
        )
        return model, processor, tokenizer

    if model_type == "instructionblip":
        processor = InstructBlipProcessor.from_pretrained(model_path)
        model = InstructBlipForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=dtype,
        ).eval()
        tokenizer = processor.tokenizer
        return model, processor, tokenizer

    # Only "idefics" can reach this point thanks to the guard above.
    print(f"Loading IDEFICS model from {model_path} ...")
    processor = AutoProcessor.from_pretrained(model_path)
    model = IdeficsForVisionText2Text.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    # Replace the generation config with the one stored next to the checkpoint.
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    return model, processor, tokenizer


def format_conversation(user_prompt, processor, tokenizer, image_path, model_type="llava"):
    """
    Build the conversation structure expected by the given model type.

    Parameters:
        user_prompt (str): The user query.
        processor: Unused; kept so existing call sites keep working.
        tokenizer: Unused; kept so existing call sites keep working.
        image_path (str): Path of the image; embedded in the "qwen" and
            "deepseek" conversations.
        model_type (str): One of "llava", "qwen", "deepseek", "idefics",
            "idefics2" (case-insensitive).

    Returns:
        list: A list of message dicts in the conversation format for that model.

    Raises:
        ValueError: If ``model_type`` is not one of the types listed above.
    """
    model_type = model_type.lower()

    if model_type == "llava":
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image"},
                ],
            }
        ]

    if model_type == "qwen":
        # Chat-message layout consumed by the processor's chat template.
        # `tokenizer.from_list_format` only exists on the legacy Qwen-VL
        # remote-code tokenizer, so it cannot be used with Qwen2.5-VL.
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": user_prompt},
                ],
            }
        ]

    if model_type == "deepseek":
        return [
            {
                "role": "User",
                "content": f"<image_placeholder>{user_prompt}",
                "images": [image_path],
            },
            {"role": "Assistant", "content": ""},
        ]

    # "idefics" is the name load_model uses; "idefics2" is kept as an alias.
    if model_type in ("idefics", "idefics2"):
        return [
            {"role": "user", "content": "<image>" + "\n" + user_prompt}
        ]

    raise ValueError(f"Unsupported model type: {model_type}")


def call_model_wrapper(model, processor, tokenizer, model_type):
    """
    Bind a loaded model to a simple ``call_model(image_path, user_prompt)`` function.

    Parameters:
        model: Loaded generation model (must expose ``device``, ``dtype`` and
            ``generate``).
        processor: Processor matching ``model``.
        tokenizer: Tokenizer used to decode the generated ids.
        model_type (str): "llava", "qwen", "idefics" or "idefics2"
            (case-insensitive).

    Returns:
        callable: ``call_model(image_path, user_prompt) -> str``. It raises
        ``ValueError`` when called with an unsupported ``model_type``.
    """
    # Normalise once so dispatch agrees with format_conversation.
    model_type = model_type.lower()

    def call_model(image_path, user_prompt):
        conversation = format_conversation(
            user_prompt, processor, tokenizer, image_path, model_type=model_type
        )

        # Render the text prompt first so an unsupported type fails before
        # any image I/O happens.
        if model_type == "llava":
            # The chat template inserts the image placeholder and role markers.
            prompt = processor.apply_chat_template(
                conversation, add_generation_prompt=True
            )
        elif model_type in ("idefics", "idefics2"):
            prompt = conversation[0]["content"]  # e.g., "<image>\nUser question"
        elif model_type == "qwen":
            prompt = processor.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the `with` block ends.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Keyword arguments: the positional order of text/images is not the
        # same across processor classes and library versions.
        if model_type == "qwen":
            inputs = processor(text=[prompt], images=[image], return_tensors="pt")
        else:
            inputs = processor(text=prompt, images=image, return_tensors="pt")

        # Cast to the dtype the model was actually loaded in rather than a
        # hard-coded float16 (the model may have been loaded in bfloat16).
        inputs = inputs.to(model.device, model.dtype)
        output = model.generate(**inputs, max_new_tokens=512)
        return tokenizer.decode(output[0], skip_special_tokens=True).strip()

    return call_model
