import os
from PIL import Image
import torch
import json
from transformers import CLIPProcessor, CLIPModel
import torch
import numpy as np
import pandas as pd
from utils import CLIPTextWrapper


# Root of the FairFace dataset, relative to the working directory.
# Paths below are built with os.path.join so the value works with or without
# a trailing separator (plain string concatenation would turn 'data' into
# 'datatrain/').
im_dir = ''  # relative path to fairface dataset
train_dir = os.path.join(im_dir, 'train/')
val_dir = os.path.join(im_dir, 'val/')

# Validation labels. The 'file' column holds image paths that are resolved
# against im_dir when the images are opened.
val_df = pd.read_csv(os.path.join(im_dir, 'fairface_label_val.csv'))
image_pths = list(val_df['file'])


# For each CLIP checkpoint, embed every image listed in image_pths and write
# the label table plus an 'img_embed' column to
# ff_embed_csvs/<checkpoint name>_val.csv.
#
# val_df and image_pths are the module-level label table and path list; the
# 'img_embed' column of val_df is overwritten on every iteration, so the
# table does not need to be re-read per checkpoint.

# Create the output directory up front so a missing folder is reported
# before any embedding work is done instead of at the final to_csv call.
out_dir = "ff_embed_csvs"
os.makedirs(out_dir, exist_ok=True)

bs = 256  # number of images embedded per call
num_images = len(image_pths)

for model_ID in ["openai/clip-vit-large-patch14",
     "openai/clip-vit-base-patch32",
     "openai/clip-vit-large-patch14-336",
     "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"]:

    # The wrapper is the only object used to embed images; no separate
    # CLIPModel / CLIPProcessor is loaded here because nothing consumes them.
    wrapped_clip = CLIPTextWrapper(model_ID)

    img_embeds = []
    for batch_start in range(0, num_images, bs):
        # Clamp so the final, possibly partial, batch stops at the last image.
        batch_end = min(batch_start + bs, num_images)

        image_files = []
        for rel_path in image_pths[batch_start:batch_end]:
            img = Image.open(os.path.join(im_dir, rel_path))
            img.load()  # read the pixel data now rather than lazily
            image_files.append(img)

        batch_embeds = wrapped_clip.get_joint_image_embed(image_files)
        # Keep only plain CPU tensors so nothing extra is held across batches
        # and the later .numpy() call cannot fail on a device or autograd
        # tensor. Both calls are no-ops if the wrapper already returns
        # detached CPU tensors.
        img_embeds.append(batch_embeds.detach().cpu())

        print(f"{batch_end}/{num_images}")

    all_img_embeds = torch.cat(img_embeds)

    # One embedding row per label row, in the same order as image_pths.
    # NOTE(review): the arrays are stored as objects in a CSV column, so they
    # are written out as text -- confirm the downstream reader expects that.
    val_df['img_embed'] = list(all_img_embeds.numpy())
    val_df.to_csv(os.path.join(out_dir, f"{model_ID.split('/')[-1]}_val.csv"))
