import torch
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColModernVBert, ColModernVBertProcessor

from datasets import load_dataset

MODEL_NAME = "SmolVEncoder/colvbert-modernbert_base-vidore"

# Prefer the first CUDA GPU, but fall back to the CPU so the script still runs
# on machines without a GPU instead of failing on the hard-coded "cuda:0".
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# FlashAttention-2 is a CUDA-only kernel, so only request it when a GPU is
# selected and the package is installed; `None` leaves the choice of attention
# implementation to the library default.
# NOTE(review): FlashAttention-2 kernels are documented as fp16/bf16-only while
# the model is cast to float32 below — confirm this model handles that
# combination when flash_attention_2 is selected.
ATTN_IMPLEMENTATION = (
    "flash_attention_2"
    if DEVICE.startswith("cuda") and is_flash_attn_2_available()
    else None
)

# Load the retrieval model in full precision and switch it to evaluation mode.
model = ColModernVBert.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE,
    attn_implementation=ATTN_IMPLEMENTATION,
)
model = model.to(torch.float32).eval()

# The processor is loaded from the same checkpoint as the model.
processor = ColModernVBertProcessor.from_pretrained(MODEL_NAME)

# Your inputs: the first five rows of the test split.
samples = load_dataset("vidore/arxivqa_test_subsampled", split="test").take(5)

# Materialize the two columns as plain lists before handing them to the
# processor. Depending on the installed `datasets` version, column access may
# return a lazy column object rather than a list; `list(...)` is a no-op copy
# in the list case and makes the input type explicit in the other.
images = list(samples["image"])
queries = list(samples["query"])

# Process the inputs and move the resulting batches to the model's device.
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass (gradients are not needed for inference).
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

# Score the queries against the images.
# NOTE(review): presumably one row of scores per query, i.e. shape
# (n_queries, n_images) — confirm against the processor documentation.
scores = processor.score_multi_vector(query_embeddings, image_embeddings)

# Print each query followed by its entry in `scores`, separated by a rule.
separator = "-" * 50
for sample, query_scores in zip(samples, scores):
    print(separator)
    print(f"Query: {sample['query']}")
    # `end="\n\n"` emits the blank line that follows each result.
    print(f"Scores: {query_scores}", end="\n\n")
print("Inference completed successfully.")
