import pandas as pd
import torch 
from pathlib import Path
from PIL import Image
import clip


# Compute a per-prompt CLIP diversity score for sets of generated images.
#
# For every prompt id listed in `ids_csv`, the script loads the `k` images
# generated for that prompt, embeds them with CLIP ViT-L/14, and measures the
# mean squared Euclidean distance between all unordered pairs of L2-normalised
# embeddings. Per-prompt scores are written to `output_csv`, and the average
# over all prompts is printed at the end.

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
output_csv = "same_budget_clip.csv"
model_name = 'v1.4'  # not referenced anywhere in the visible script
ids_csv = 'complete_prompts.csv'  # list of prompt ids
all_image_path = 'budget_4650_v1.4/'
k = 4  # number of generated images per prompt

# ---------------------------------------------------------------------------
# Model and inputs
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)
model.eval()

file = pd.read_csv(ids_csv)
all_ids = file["id"].tolist()

rows = []            # one {"img_id", "clip_div"} record per prompt, for the CSV
all_image_divs = []  # per-prompt scores as plain floats, for the final average

# ---------------------------------------------------------------------------
# Per-prompt diversity
# ---------------------------------------------------------------------------
for img_id in all_ids:
    print(f'processing id: {img_id}')

    # Paths of all generated images for this prompt. The previous
    # "file does not exist" fallback rebuilt exactly the same paths, so it has
    # been dropped; a missing image still raises when it is opened below.
    image_paths = [f'{all_image_path}{img_id}_img_{i}_k_4_chains_3.jpg' for i in range(k)]

    image_features = []

    with torch.no_grad():
        for path in image_paths:
            # Close the file handle as soon as the pixels have been consumed.
            with Image.open(path) as pil_image:
                image = preprocess(pil_image.convert("RGB"))
            image = image.unsqueeze(0).to(device)

            feat = model.encode_image(image)
            # L2-normalise each embedding before comparing them.
            feat = feat / feat.norm(dim=1, keepdim=True)
            image_features.append(feat)

    # Cast to float32 so cdist does not complain about the embedding dtype.
    image_features = torch.cat(image_features, dim=0).float()

    # Pairwise squared Euclidean distances between the embeddings.
    dist_sq = torch.cdist(image_features, image_features, p=2) ** 2

    # Sum over pairs i < j, then normalise by the number of pairs k*(k-1)/2
    # (equation 9 in the paper, per the original author's note).
    clip_div = dist_sq.triu(diagonal=1).sum()
    clip_div = (2.0 / (k * (k - 1))) * clip_div

    clip_div_value = clip_div.item()
    rows.append({
        "img_id": img_id,
        "clip_div": clip_div_value
    })
    # Previously nothing was ever added to this list, so the average below
    # always failed with ZeroDivisionError.
    all_image_divs.append(clip_div_value)

# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
# write once at the end
df = pd.DataFrame(rows)
df.to_csv(output_csv, index=False)

if all_image_divs:
    avg = sum(all_image_divs) / len(all_image_divs)
    print(f'average CLIP diversity: {avg}')
else:
    print('no prompt ids were processed; average CLIP diversity is undefined')