from transformers import CLIPProcessor, CLIPModel  
import torch  
from PIL import Image  
import numpy as np

class CLIPColorDetector:
    """Zero-shot colour classifier for segmented objects, backed by a CLIP model.

    Usage: construct once, call :meth:`set_image_path` for an image, then call
    :meth:`check_color_list` for each object in that image.
    """

    # Text prompts scored for every candidate colour. Each colour gets one
    # prompt per template, in this order. The prompt index is mapped back to
    # its colour by dividing by len(_PROMPT_TEMPLATES).
    _PROMPT_TEMPLATES = (
        "The color of {obj} in this photo is {color}.",
        "The {obj} in this photo is {color}-colored.",
    )

    def __init__(self, clip_path, device):
        """Load the CLIP model and processor.

        Args:
            clip_path: Model id or local directory passed to ``from_pretrained``.
            device: Torch device that the model and its inputs are moved to.
        """
        self.model = CLIPModel.from_pretrained(clip_path).to(device)
        self.model.eval()  # inference only
        self.processor = CLIPProcessor.from_pretrained(clip_path)
        self.device = device
        # Populated by set_image_path(); None until an image has been loaded.
        self.image_path = None
        self.image = None
        self.image_blank = None

    def set_image_path(self, image_path):
        """Load ``image_path`` as the current RGB image.

        Also builds ``image_blank``, a uniform grey (#999) canvas of the same
        size. The underlying file is closed once its pixels have been read.
        """
        self.image_path = image_path
        with Image.open(image_path) as img:
            # convert() decodes into a new in-memory image, so it stays valid
            # after the file handle is closed.
            self.image = img.convert('RGB')
        self.image_blank = Image.new("RGB", self.image.size, color='#999')

    def check_color_list(self, bounding_box, object_item, color_list, seg_mask):
        """Return the colour from ``color_list`` that CLIP scores highest for an object.

        Args:
            bounding_box: Sequence whose first four entries are the crop box
                passed to PIL ``Image.crop``, i.e. (left, upper, right, lower).
                Any further entries are ignored.
            object_item: Object name inserted into the text prompts.
            color_list: Candidate colour names (indexable sequence).
            seg_mask: Array-like with H*W elements for the current image.
                Truthy entries mark the object; all other pixels are painted
                white before cropping.

        Returns:
            A one-element list holding the best-scoring colour, or ``[]`` when
            ``color_list`` is empty.

        Raises:
            RuntimeError: if no image has been loaded via :meth:`set_image_path`.
        """
        if len(color_list) == 0:
            # Nothing to rank; don't call the model with zero prompts.
            return []
        if getattr(self, "image", None) is None:
            raise RuntimeError(
                "set_image_path() must be called before check_color_list()"
            )

        crop = self._masked_crop(bounding_box, seg_mask)
        prompts = self._build_prompts(object_item, color_list)
        inputs = self.processor(
            text=prompts, images=crop, return_tensors="pt", padding=True
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Image-to-text similarity scores: one row (a single image), one
        # column per prompt.
        return [self._pick_color(outputs.logits_per_image, color_list)]

    def _masked_crop(self, bounding_box, seg_mask):
        """Whiten pixels outside ``seg_mask`` and crop to ``bounding_box[:4]``."""
        pixels = np.array(self.image)
        height, width = pixels.shape[:2]
        # Add a trailing axis so the mask broadcasts over the RGB channels.
        mask = np.asarray(seg_mask, dtype=bool).reshape(height, width, 1)
        masked = np.where(mask, pixels, np.uint8(255))
        return Image.fromarray(masked).crop(bounding_box[:4])

    @classmethod
    def _build_prompts(cls, object_item, color_list):
        """Build all prompts, grouped by colour in ``color_list`` order."""
        return [
            template.format(obj=object_item, color=color)
            for color in color_list
            for template in cls._PROMPT_TEMPLATES
        ]

    @classmethod
    def _pick_color(cls, logits_per_image, color_list):
        """Map the single highest-scoring prompt back to its colour."""
        best_prompt = logits_per_image.argmax().item()
        return color_list[best_prompt // len(cls._PROMPT_TEMPLATES)]
