from typing import Any


import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Hugging Face Hub checkpoint for LLaVA 1.5 (7B).
model_id = "llava-hf/llava-1.5-7b-hf"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the vision-language model in half precision, then move it to `device`.
# NOTE(review): float16 weights are used even when `device` is "cpu", where
# half-precision kernel coverage is limited — confirm this is intended.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model = model.to(device)

# The processor is used both for the chat template / tokenization and for
# preprocessing the image.
processor = AutoProcessor.from_pretrained(model_id)

def generate_caption(image_tensor: torch.Tensor, *, max_new_tokens: int = 77) -> str:
    """
    Generate a short caption for a single image using LLaVA.

    Args:
        image_tensor: A single image tensor of shape (C, H, W). Values are
            expected to be floats in [0, 1]; anything outside that range is
            clamped before scaling to 8-bit. A single-channel image is passed
            to PIL as a 2-D grayscale array.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        str: The generated caption only (the prompt text is not included).
    """
    # (C, H, W) float tensor -> (H, W, C) uint8 array for PIL.
    # detach(): .numpy() refuses tensors that require grad.
    # clamp(): out-of-range values would otherwise wrap around in the uint8 cast.
    pixels = (
        image_tensor.detach()
        .float()
        .clamp(0.0, 1.0)
        .mul(255)
        .to(torch.uint8)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )
    if pixels.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W) instead.
        pixels = pixels[..., 0]
    image_pil = Image.fromarray(pixels)

    conversation = [
        {"role": "user", "content": [{"type": "text", "text": "Concise description, in under 20 words, highlighting its unique visual characteristic."}, {"type": "image"}]},
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Match the model's own device/dtype rather than hard-coding float16.
    inputs = processor(images=image_pil, text=prompt, return_tensors="pt").to(model.device, model.dtype)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # For this decoder-only model, generate() returns the prompt tokens followed
    # by the new tokens; drop the whole prompt so only the caption is decoded.
    prompt_len = inputs["input_ids"].shape[-1]
    caption = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

    return caption.strip()
