# This file implements a simple inference class for PickScore.

import os
import pdb

import torch
from PIL import Image
from torchvision.transforms import ToTensor, Resize, CenterCrop, Normalize, Compose
from transformers import AutoProcessor, AutoModel

class PickScore:
    """Score how well image(s) match a text prompt with the PickScore model.

    Text is tokenized with the Hugging Face processor. Images go through a
    torchvision pipeline that operates on tensors, so a tensor image can be
    scored directly without a round trip through PIL.
    """

    def __init__(
        self,
        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1",
        device = "cuda",
    ):
        """Load the processor and model and build the image preprocessing.

        Args:
            processor_name_or_path: id or local path passed to
                ``AutoProcessor.from_pretrained``. In ``score`` it is only
                used to tokenize the prompt.
            model_pretrained_name_or_path: id or local path passed to
                ``AutoModel.from_pretrained``.
            device: device the model is moved to and inputs are sent to.
        """
        self.device = device
        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
        # This wrapper only scores, so use eval mode: any configured dropout is
        # disabled and results are deterministic. eval() does not turn off
        # autograd, so gradients can still flow back to tensor inputs.
        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)
        # NOTE(review): Resize((224, 224)) stretches to exactly 224x224, which
        # makes the following CenterCrop a no-op. The stock CLIP image
        # processor resizes the shortest edge and then center-crops instead.
        # Confirm the stretch is intended; it is kept as-is to preserve scores.
        self.custom_image_processor = Compose([
            Resize((224, 224)),
            CenterCrop((224, 224)),
            # Per-channel normalization with the standard CLIP image statistics.
            Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

    def score(self, img, prompt):
        """Return the PickScore of each image against the first prompt.

        No ``torch.no_grad()`` is applied, so the result stays attached to the
        autograd graph (e.g. for backpropagating into a tensor ``img``).

        Args:
            img: an image file path (``str`` or ``os.PathLike``), a
                ``PIL.Image.Image``, or a ``torch.Tensor`` shaped
                ``(C, H, W)`` or ``(B, C, H, W)``. Tensors are passed
                straight to the preprocessing and are not rescaled. They are
                presumably expected to hold RGB floats in ``[0, 1]``, as
                ``ToTensor`` produces for the other input kinds; confirm with
                callers.
            prompt: text for the processor (a string or list of strings).
                Only the first prompt's scores are returned.

        Returns:
            1-D tensor with one entry per image: ``logit_scale.exp()`` times
            the cosine similarity between the first prompt and that image.

        Raises:
            ValueError: if ``img`` is not one of the supported types.
        """
        if isinstance(img, torch.Tensor):
            pixels = img
        elif isinstance(img, (str, os.PathLike)):
            # Close the file handle deterministically; convert() loads the
            # pixel data into a new image before the file is closed.
            with Image.open(img) as pil_img:
                pixels = ToTensor()(pil_img.convert("RGB"))
        elif isinstance(img, Image.Image):
            pixels = ToTensor()(img.convert("RGB"))
        else:
            raise ValueError("img should be a path, PIL.Image.Image, or torch.Tensor")

        # preprocess; add a batch dimension for a single (C, H, W) image
        pixel_values = self.custom_image_processor(pixels)
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        pixel_values = pixel_values.to(self.device)

        # 77 tokens is the CLIP text encoder's context length.
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)

        # embed, then L2-normalize so the dot product below is a cosine similarity
        image_embs = self.model.get_image_features(pixel_values=pixel_values)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = self.model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        # score: row 0 of the (num_prompts, num_images) similarity matrix is
        # the first prompt against every image.
        scores = self.model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
        return scores