#%%
import torch
import open_clip
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
from transformers import BlipProcessor, BlipForImageTextRetrieval

# %%
# Locations of the base BLIP image-text-matching checkpoint and of the state
# dict that is loaded on top of it.
model_path = "pretrained_frameworks/blip-itm-base-coco"
save_path = "finetuned_models/blip_finetuned/blip_finetuned_20.pth"

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained(model_path)
model = BlipForImageTextRetrieval.from_pretrained(model_path)

# Overwrite the pretrained weights with the ones stored at save_path.
# NOTE(review): torch.load unpickles the file, so save_path must only point at
# checkpoints from a trusted source (consider weights_only=True if the
# installed torch version supports it).
model.load_state_dict(torch.load(save_path, map_location=device))

# Move the model to the target device once and switch to evaluation mode.
model = model.to(device)
model.eval()

#%%

# Captions to compare with the test image. Each caption is paired with one
# entry of the image batch built below.
texts = ["The image depicts a serene outdoor setting with a person sitting alone on a wooden bench facing a calm body of water. The individual has long, reddish-brown hair and is wearing a light-colored top and dark pants. The bench is positioned on a concrete slab, surrounded by patches of dry grass and a pebbled shoreline. The water appears tranquil, reflecting the clear sky above. In the distance, a seagull is visible, adding a touch of wildlife to the scene. The overall atmosphere is peaceful and contemplative, with no other people or animals in sight.",
         "The image depicts a close-up view of a computer screen displaying a 3D model or rendering. A person's hand is pointing at the screen, which shows a geometric shape with orange and blue hues. The background is blurred, but it appears to be an indoor setting with a desk or table. The screen displays a list of items on the left side, possibly related to the 3D model, and a 3D view of the model in the center. The model features a triangular shape with a white base and orange edges, suggesting a simplified representation of a structure or object. The overall mood is focused and analytical, as if someone is examining or working with the 3D model."]

# Decode the test image a single time and release the file handle right away.
# convert() returns a separate converted copy, so the result stays usable
# after the file has been closed.
with Image.open("test_imgs/image_org.jpg") as image_file:
    rgb_image = image_file.convert("RGB")

# One image entry per caption, so the image and text batches have equal length.
images = [rgb_image] * len(texts)

image_inputs = processor(images=images, return_tensors="pt").to(device)
text_inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(device)

# Keep only the first token's hidden state of each sequence as its feature
# vector; no gradients are needed for inference.
# NOTE(review): these are the raw encoder outputs. BlipForImageTextRetrieval
# also exposes vision_proj / text_proj layers for its contrastive head --
# confirm that comparing unprojected states matches how the checkpoint was
# fine-tuned.
with torch.no_grad():
    image_features = model.vision_model(**image_inputs).last_hidden_state[:, 0, :]
    text_features = model.text_encoder(**text_inputs).last_hidden_state[:, 0, :]

# %%
# Row-wise cosine similarity between the image and text feature batches,
# reduced over dim=1.
cosine_similarities = torch.nn.functional.cosine_similarity(
    image_features,
    text_features,
    dim=1,
)

# Shift and halve so the [-1, 1] cosine range maps onto [0, 1].
scaled_similarities = cosine_similarities.add(1).div(2)

# Report one score per image/text pair.
for pair_index, score in enumerate(scaled_similarities.tolist()):
    print(f"[{pair_index}] Scaled similarity (0-1): {score:.4f}")

# %%
