import torch
import clip
from PIL import Image
import pandas as pd 
from pathlib import Path



# Prefer the GPU for CLIP inference; fall back to the CPU when CUDA is absent.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Non-JIT ViT-L/14 CLIP image/text model together with its matching
# image-preprocessing transform; eval() switches the model to inference mode.
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)
model.eval()

# Per-prompt diversity scores, appended to once per prompt id.
all_image_divs = []

# NOTE(review): `model_name` is not read anywhere in this script.
model_name = 'v1.4'

# CSV whose "id" column lists the prompt ids to evaluate.
ids_csv = 'prompt_ids.csv'
file = pd.read_csv(ids_csv)
all_ids = file["id"].tolist()

# Generation settings; together they select the directory
# CLIP/<sd_model>_<steps>/ holding the generated images.
sd_model = "v1.5"
steps = 10
all_image_path = f'CLIP/{sd_model}_{steps}/'

# Number of images generated per prompt (files <id>_0.png ... <id>_{k-1}.png).
k = 4


def _clip_diversity(features):
    """Return the mean pairwise squared L2 distance between rows of *features*.

    Computes (2 / (n * (n - 1))) * sum_{i<j} ||f_i - f_j||^2 for an (n, d)
    feature matrix (equation 9 in the paper, per the original comment).

    Raises:
        ValueError: if fewer than two feature vectors are given, because the
            pairwise average is undefined (the normaliser divides by zero).
    """
    n = features.shape[0]
    if n < 2:
        raise ValueError(f'need at least 2 images to measure diversity, got {n}')
    # torch.pdist returns the n*(n-1)/2 upper-triangle (i < j) distances, so
    # their mean equals the i<j sum scaled by 2 / (n * (n - 1)).
    return torch.pdist(features.float(), p=2).pow(2).mean()


for img_id in all_ids:
    print(f'processing id: {img_id}')
    # Every prompt has k generated images named <id>_0.png ... <id>_{k-1}.png.
    image_paths = [Path(f'{all_image_path}{img_id}_{i}.png') for i in range(k)]
    # Check every image (not just the first) and raise explicitly: an
    # `assert` would be stripped when Python runs with -O.
    missing = [str(p) for p in image_paths if not p.exists()]
    if missing:
        raise FileNotFoundError(
            f'missing generated images for id {img_id}: {missing}'
        )

    # Preprocess all k images and stack them so CLIP encodes them in one pass.
    batch = []
    for path in image_paths:
        # The context manager closes the underlying file once decoded.
        with Image.open(path) as img:
            batch.append(preprocess(img.convert("RGB")))
    batch = torch.stack(batch).to(device)

    with torch.no_grad():
        # Cast to float32 before normalising: the model may return half
        # precision, and the distance computation expects float input.
        image_features = model.encode_image(batch).float()
    # L2-normalise each embedding (standard for comparing CLIP features).
    image_features = image_features / image_features.norm(dim=1, keepdim=True)

    clip_div = _clip_diversity(image_features)
    # Store a plain Python float so no device tensors accumulate in the list.
    all_image_divs.append(clip_div.item())



# Average the per-prompt CLIP diversity scores over all processed prompts.
if not all_image_divs:
    raise ValueError('no prompt ids were processed; cannot average CLIP diversity')
# float() accepts either Python floats or 0-d tensors, so the printed value is
# a plain number rather than a tensor repr.
avg = float(sum(all_image_divs) / len(all_image_divs))
print(f'average CLIP diversity: {avg}')