from transformers import AutoTokenizer, CLIPModel, AutoProcessor, AutoModel
from fashion_clip.fashion_clip import FashionCLIP
from torchvision import transforms
from PIL import Image
import torch

class CLIPModelExtractor:
    """Feature extractor backed by a locally stored CLIP checkpoint.

    The model, tokenizer and processor are all loaded from the same
    directory with ``local_files_only=True``, so nothing is downloaded.
    """

    def __init__(self, file_path="./clip-model/"):
        """Load the model, tokenizer and processor from ``file_path``.

        Args:
            file_path: Directory holding the pretrained checkpoint. Pass a
                different directory (e.g. ``"./fashion-clip/"``) to swap the
                foundation model without editing this class.
        """
        super().__init__()
        self.model = CLIPModel.from_pretrained(file_path, local_files_only=True)
        self.tokenizer = AutoTokenizer.from_pretrained(file_path, local_files_only=True)
        self.processor = AutoProcessor.from_pretrained(file_path, local_files_only=True)

    @torch.no_grad()
    def encode_label(self, label_texts):
        """Encode label strings into text features.

        Args:
            label_texts: Text(s) handed to the tokenizer; they are padded
                to a common length and returned as PyTorch tensors.

        Returns:
            The result of ``self.model.get_text_features`` for the tokenized
            input (presumably one embedding row per label -- confirm
            against the model in use). Computed with gradients disabled.
        """
        inputs = self.tokenizer(label_texts, padding=True, return_tensors="pt")
        return self.model.get_text_features(**inputs)

    @torch.no_grad()
    def encode_image(self, images, *, debug_save_path=None):
        """Encode a batch of images into image features.

        Args:
            images: Iterable of images, each converted with
                ``transforms.ToPILImage()`` before preprocessing (assumes
                each item is a tensor/array that transform accepts --
                confirm against the caller).
            debug_save_path: Optional file path. When given and the batch
                is non-empty, the first converted image is written there
                for visual inspection. Defaults to ``None`` (no file is
                written), so inference has no disk side effects and does
                not depend on an output directory existing.

        Returns:
            The result of ``self.model.get_image_features`` for the
            preprocessed batch. Computed with gradients disabled.
        """
        to_pil = transforms.ToPILImage()
        pil_images = [to_pil(image) for image in images]

        # Debug aid only: opt-in, and guarded so an empty batch cannot raise
        # IndexError before the processor sees the input.
        if debug_save_path is not None and pil_images:
            pil_images[0].save(debug_save_path)

        inputs = self.processor(images=pil_images, return_tensors="pt")
        return self.model.get_image_features(**inputs)