#%%
import torch
from transformers import CLIPModel, CLIPProcessor
from torch.utils.data import Dataset
from datasets import load_dataset
from PIL import Image
import numpy as np
from tqdm import tqdm
from region_clip_fusion import RegionCLIPDataset, GatedCrossAttentionFusion
import torch.nn as nn

#%%
# Pick the compute device. GPU index 2 is preferred, but only when the host
# actually exposes that many devices; otherwise use the first GPU, and fall
# back to the CPU when CUDA is unavailable. (Requesting "cuda:2" on a machine
# with fewer than three GPUs fails with an invalid-device error.)
PREFERRED_GPU_INDEX = 2
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
else:
    device = torch.device("cuda:0")

# Load checkpoint. It is read below for "clip_model_state_dict",
# "fusion_state_dict" and, optionally, "fusion_config".
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version's `weights_only` default -- only load trusted checkpoints.
ckpt = torch.load("finetuned_models/finetuned_clip_region/clip_fusion_joint_10fusion_15joint.pth", map_location=device)

# Rebuild CLIP model: start from the pretrained architecture/weights, then
# overwrite the weights with the ones stored in the checkpoint.
clip_model = CLIPModel.from_pretrained("pretrained_frameworks/clip-vit-base-patch16")
clip_model.load_state_dict(ckpt["clip_model_state_dict"])
clip_model.to(device).eval()

# Rebuild Fusion module. When the checkpoint carries no "fusion_config" entry,
# fall back to CLIP's projection dimension and 8 attention heads.
fusion_config = ckpt.get("fusion_config", {"embed_dim": clip_model.config.projection_dim, "num_heads": 8})
fusion = GatedCrossAttentionFusion(
    embed_dim=fusion_config["embed_dim"],
    num_heads=fusion_config["num_heads"]
)
fusion.load_state_dict(ckpt["fusion_state_dict"])
fusion.to(device).eval()

# Alternative setup (disabled): untouched pretrained CLIP combined with fusion
# weights loaded from "fusion_phase1.pth". Uncomment this and disable the
# checkpoint loading above to evaluate that configuration instead.
# clip_model = CLIPModel.from_pretrained("pretrained_frameworks/clip-vit-base-patch16")
# clip_model.to(device).eval()

# fusion = GatedCrossAttentionFusion(
#     embed_dim=clip_model.config.projection_dim,
#     num_heads=8
# )
# fusion.load_state_dict(torch.load("fusion_phase1.pth", map_location=device))
# fusion.to(device).eval()

# Load processor (image preprocessing + text tokenization for CLIP)
processor = CLIPProcessor.from_pretrained("pretrained_frameworks/clip-vit-base-patch16")

#%%
hf_dataset = load_dataset("downloaded_datatset/HumanEdit", split="train")
test_dataset = RegionCLIPDataset(hf_dataset)

# ----- EVALUATE THE FIRST NUM_EVAL_SAMPLES SAMPLES -----
# Single source of truth for the sample budget: it drives both the loop bound
# and the divisor of the average, so the two cannot drift apart.
NUM_EVAL_SAMPLES = 1000
# Clamp to the dataset size so a dataset smaller than the budget does not
# raise IndexError part-way through.
# NOTE(review): assumes RegionCLIPDataset implements __len__ -- confirm.
num_samples = min(NUM_EVAL_SAMPLES, len(test_dataset))

with torch.no_grad():
    total_sim = 0.0
    for i in tqdm(range(num_samples)):
        # Each item unpacks to (full image, region crop, text).
        full_img, region, text = test_dataset[i]

        inputs_region = processor(images=region, return_tensors="pt", padding=True).to(device)
        inputs_full   = processor(images=full_img, return_tensors="pt", padding=True).to(device)
        # truncation=True keeps the token sequence within the tokenizer's
        # maximum length; CLIP's text encoder has a fixed context size and
        # would otherwise fail on an over-long text.
        inputs_text   = processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)

        E_r = clip_model.get_image_features(**inputs_region)
        E_f = clip_model.get_image_features(**inputs_full)
        E_t = clip_model.get_text_features(**inputs_text)

        # Fuse the region and full-image embeddings, then L2-normalize both
        # sides before comparing them.
        E_rf = fusion(E_r, E_f)
        E_rf = nn.functional.normalize(E_rf, dim=-1)
        E_t = nn.functional.normalize(E_t, dim=-1)

        # cosine similarity, rescaled from [-1, 1] to [0, 1]
        sim = torch.nn.functional.cosine_similarity(E_rf, E_t)
        sim_01 = (sim + 1) / 2
        total_sim += sim_01.item()

    if num_samples:
        # The reported number is the mean of the *rescaled* similarity
        # (sim + 1) / 2, not of the raw cosine value.
        print(f"Average cosine similarity (region-text): {total_sim/num_samples}")
    else:
        print("No samples available for evaluation.")

# %%
