from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPTextModel
import torch
import numpy as np
import pandas as pd
from PIL import Image
from utils import CLIPTextWrapper
import os

# Compute CLIP image embeddings for the Spawrious images under
# `spawrious_path` and write one CSV per model into `output_dir`, with one row
# per image: the embedding, its background, its breed, and its file name.
spawrious_path = 'spawrious_2/spawrious224/1/'
num_per_background = 1000  # images per (background, breed) combination
batch_size = 250  # images passed to the wrapper per embedding call
output_dir = "spawrious_embed_csvs"
backgrounds = ["jungle", "mountain", "snow", "desert"]
breeds = ["bulldog", "corgi", "dachshund", "labrador"]

for model_ID in [#"openai/clip-vit-base-patch16",
     "openai/clip-vit-large-patch14",
     "openai/clip-vit-base-patch32",
     "openai/clip-vit-large-patch14-336",
     "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"]:

    # Only the wrapper is used to embed images. CLIPModel / CLIPProcessor were
    # previously loaded here as well but never used, doubling load time and
    # memory for every model.
    wrapped_clip = CLIPTextWrapper(model_ID)

    img_embeds = []  # one entry per successfully embedded batch
    background_label = []
    breed_label = []
    file_name = []
    for background in backgrounds:
        for breed in breeds:
            print(f"{background} + {breed}")
            print()
            try:
                for start in range(0, num_per_background, batch_size):
                    stop = min(start + batch_size, num_per_background)
                    names = [f"{background}_{breed}_{i}.png" for i in range(start, stop)]
                    images = []
                    for name in names:
                        img = Image.open(os.path.join(spawrious_path, background, breed, name))
                        img.load()  # read pixel data now so the whole batch is in memory
                        images.append(img)

                    batch_embeds = wrapped_clip.get_joint_image_embed(images)

                    # Record labels only once the whole batch has been loaded
                    # and embedded. Appending them per image (as before) left
                    # orphan labels whenever a file was missing mid-batch,
                    # misaligning labels with embedding rows.
                    img_embeds.append(batch_embeds)
                    background_label.extend([background] * len(names))
                    breed_label.extend([breed] * len(names))
                    file_name.extend(names)

                    print(f"{stop}/{num_per_background}")
                print()
            except FileNotFoundError:
                # Skip background/breed pairings whose files are not on disk.
                # Any other error propagates instead of being silently
                # reported as a missing combination.
                print(f"{background} + {breed} combo doesn't exist")

    if not img_embeds:
        # torch.cat raises on an empty list; nothing to write for this model.
        print(f"No images embedded for {model_ID}; skipping CSV.")
        continue

    # NOTE(review): assumes get_joint_image_embed returns a torch tensor of
    # shape (len(images), dim); the concatenation and per-row split rely on it.
    all_img_embeds = torch.cat(img_embeds).detach().cpu()
    embed_arrays = list(all_img_embeds.numpy())  # one 1-D array per image row
    df = pd.DataFrame({"img_embed": embed_arrays, "background": background_label, 'breed': breed_label, "filename": file_name})
    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, f"{model_ID.split('/')[-1]}.csv"))  # write embeddings to disk