from transformers import AutoProcessor, InstructBlipProcessor, InstructBlipForConditionalGeneration
from transformers.generation import GenerationConfig
import torch
import random
from PIL import Image

# Seed PyTorch's global RNG before anything else is created.
torch.manual_seed(1234)

# Single source for the Hub checkpoint name used by every loader below.
_CHECKPOINT = "Salesforce/instructblip-vicuna-7b"

# Processor that prepares the image + text inputs for the model.
processor = InstructBlipProcessor.from_pretrained(_CHECKPOINT)

# Model weights in half precision, placed on the CUDA device.
model = InstructBlipForConditionalGeneration.from_pretrained(
    _CHECKPOINT,
    device_map="cuda",
    torch_dtype=torch.float16,
)
model.eval()  # switch to evaluation mode


# Attach the generation config published with the same checkpoint.
model.generation_config = GenerationConfig.from_pretrained(_CHECKPOINT)


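# Earlier version of call_model, kept commented out for reference only
# (superseded by the active call_model definition further down).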
# def call_model(image_path, text_prompt):
#     # Load the image
#     image = Image.open(image_path).convert('RGB')
#
#     # Extract the actual query
#     if "The user query is:" in text_prompt:
#         query = text_prompt.split("The user query is:")[1].strip()
#     else:
#         query = text_prompt
#
#     # Create a more direct prompt
#     simplified_prompt = f"Look at the image and answer: {query}"
#
#     # Process inputs
#     inputs = processor(images=image, text=simplified_prompt, return_tensors="pt").to("cuda")
#
#     # Generate
#     outputs = model.generate(
#         **inputs,
#         do_sample=True,  # Try enabling sampling for more varied responses
#         num_beams=5,  # Increase beam search
#         min_length=30,  # Encourage longer responses
#         top_p=0.9,
#         repetition_penalty=1.5,
#         length_penalty=1.0,
#         max_new_tokens=512,
#         temperature=0.7,  # Lower temperature for more focused responses
#     )
#
#     # Decode the response
#     response = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
#     print(response)
#     hi
#     return response

def call_model(image_path, text_prompt):
    """Answer ``text_prompt`` about the image at ``image_path`` with InstructBLIP.

    A sentinel marker is appended to the prompt. If the decoded output
    contains that marker (for example because the prompt is echoed back),
    only the text following its first occurrence is returned.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        text_prompt: Instruction or question to ask about the image.

    Returns:
        The decoded answer with surrounding whitespace stripped. When the
        marker is absent from the decoded output, the whole decoded text is
        returned instead.
    """
    # Sentinel that is unlikely to appear in natural text.
    answer_marker = "ANSWER_BEGINS_HERE:"

    # Open inside a context manager so the file handle is released promptly;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    formatted_prompt = f"{text_prompt} {answer_marker}"

    # Follow the model's own device and dtype instead of hard-coding "cuda":
    # the weights are loaded in float16, so floating-point inputs (pixel
    # values) are cast to match; integer token ids are only moved.
    inputs = processor(
        images=image,
        text=formatted_prompt,
        return_tensors="pt",
    ).to(model.device, model.dtype)

    # Sampling-based decoding with a single beam.
    outputs = model.generate(
        **inputs,
        do_sample=True,
        num_beams=1,
        max_new_tokens=512,
        use_cache=True,
    )

    full_response = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

    # Keep everything after the FIRST marker; a plain split()[1] would drop
    # any text that follows a repeated marker in the generated answer.
    _, marker_found, answer = full_response.partition(answer_marker)
    if marker_found:
        return answer.strip()

    # Fallback: the marker was not echoed, so the output is already the answer.
    return full_response