from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from transformers.generation import GenerationConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from transformers.generation import GenerationConfig
import torch
import json
import tqdm
import random
from PIL import Image
# Fix torch's global RNG seed so runs start from the same random state.
torch.manual_seed(1234)

# Lower and upper pixel bounds handed to the processor below; both are set to
# the same value (256 * 28 * 28).
min_pixels = max_pixels = 256 * 28 * 28

# Single source of truth for the checkpoint shared by model, tokenizer and processor.
_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the weights in bfloat16 with automatic device placement, then switch
# the module to evaluation mode.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
processor = AutoProcessor.from_pretrained(
    _MODEL_ID,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)


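# Function to generate caption with grounding -- earlier implementation, kept
# commented out for reference only. It went through tokenizer.from_list_format
# and model.chat rather than the processor-based call_model defined below.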
# def call_model(image_path, text_prompt):
#
#     query = tokenizer.from_list_format([
#         {'image': image_path},
#         {'text': text_prompt},
#     ])
#     response, history = model.chat(tokenizer, query=query, history=None)
#     return response
#


# Explicit answer markers: when the model emits one of these, only the text
# that follows the first occurrence is returned.
_ANSWER_MARKERS = ("ASSISTANT:", "ANSWER_BEGINS_HERE:")


def _strip_answer_marker(text):
    """Return the text after the first answer marker found, else ``text``, stripped."""
    for marker in _ANSWER_MARKERS:
        _, found, tail = text.partition(marker)
        if found:
            return tail.strip()
    return text.strip()


def call_model(image_path, text_prompt):
    """Ask the module-level vision-language model about a single image.

    Builds a one-turn user conversation holding the image and ``text_prompt``,
    renders it with the processor's chat template, samples up to 512 new
    tokens and returns the generated answer.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        text_prompt: Instruction or question to send along with the image.

    Returns:
        The decoded text of the newly generated tokens only. If the model
        emitted an ``ASSISTANT:`` or ``ANSWER_BEGINS_HERE:`` marker, only the
        text after that marker is returned. Surrounding whitespace is stripped.

    Side effects:
        Prints the full decoded sequence (prompt + answer) and the final
        answer to stdout.
    """
    # Image.open is lazy and keeps the file handle open; copy the pixels into
    # memory and close the file immediately so repeated calls do not leak
    # file descriptors.
    with Image.open(image_path) as image_file:
        image = image_file.copy()

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_prompt},
            ],
        }
    ]

    # add_generation_prompt=True already asks the chat template to append the
    # assistant turn header, so no extra "assistant:" suffix is added by hand.
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    real_images, _ = process_vision_info(conversation)
    inputs = processor(
        text=prompt, images=real_images, return_tensors="pt"
    ).to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

    full_output = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    print(full_output)

    # generate() returns prompt tokens followed by the new tokens. Dropping
    # the prompt by length is robust even when the prompt text itself
    # contains words such as "assistant" or one of the answer markers.
    prompt_length = inputs["input_ids"].shape[1]
    response = processor.tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )
    final_output = _strip_answer_marker(response)

    print("\nFinal Output:", final_output)

    return final_output