# import
import pdb

import torch
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    Resize,
    ToTensor,
)
from transformers import AutoModel, AutoProcessor

# Per-channel normalization statistics used by the custom tensor pipeline
# below. These match the values commonly used for CLIP preprocessing —
# confirm against the processor's preprocessor config if the checkpoint changes.
IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
# Square input resolution fed to the vision tower.
IMAGE_SIZE = 224

# load model
# Fall back to CPU so the script still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"

processor = AutoProcessor.from_pretrained(processor_name_or_path)
# eval() disables dropout (and other train-only behavior) so scoring is not
# stochastic; it does not disable autograd, so gradients can still be computed.
model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)

# Tensor-only (and therefore differentiable) stand-in for the HF image
# processor. It contains no ToTensor step, so it expects a float tensor of
# shape (C, H, W) or (B, C, H, W), presumably already scaled to [0, 1]
# (e.g. ToTensor() output). The shorter side is resized to IMAGE_SIZE with
# bicubic interpolation and then center-cropped. This avoids squashing the
# image to a square, which would distort its aspect ratio and leave nothing
# for the crop to do.
custom_image_processor = Compose(
    [
        Resize(IMAGE_SIZE, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(IMAGE_SIZE),
        Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
    ]
)

def _is_unit_range_float_tensor(image):
    """Return True if *image* is a non-empty float tensor with values in [0, 1]."""
    if not (torch.is_tensor(image) and image.is_floating_point()):
        return False
    if image.numel() == 0:
        return False
    return bool(image.min() >= 0) and bool(image.max() <= 1)


def calc_probs(prompt, images):
    """Score images against a text prompt and backpropagate the scores.

    Despite the name, this returns raw logit-scaled similarity scores, not
    softmax probabilities.

    Args:
        prompt: Text prompt (or prompts) passed to the module-level processor.
            Only the first prompt's row of similarities is scored (``[0]``).
        images: Image or list of images accepted by the processor, e.g. PIL
            images or float tensors already scaled to [0, 1] (``ToTensor()``).

    Returns:
        Tuple ``(scores, image_embs, text_embs)``. ``scores`` has one entry
        per image; both embedding tensors are L2-normalized along the last
        dimension.

    Side effects:
        Calls ``backward()`` on the summed scores, which accumulates ``.grad``
        on any model parameters that require grad. The gradient w.r.t. the
        pixel values is stored on a local leaf tensor and is not returned.
        Prints the pixel-value shape and the scores.
    """
    # The HF image processor rescales by 1/255 by default. Inputs that are
    # already float tensors in [0, 1] must not be rescaled a second time,
    # otherwise every pixel collapses toward zero before normalization.
    image_list = images if isinstance(images, (list, tuple)) else [images]
    already_scaled = bool(image_list) and all(
        _is_unit_range_float_tensor(img) for img in image_list
    )
    image_kwargs = {"do_rescale": False} if already_scaled else {}

    # preprocess (padding/truncation/max_length only apply to text inputs)
    image_inputs = processor(
        images=images,
        return_tensors="pt",
        **image_kwargs,
    ).to(device)

    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    # embed: a fresh leaf tensor so gradients w.r.t. pixel values can be taken
    pixel_values = image_inputs["pixel_values"].detach().requires_grad_(True)
    print(pixel_values.shape)

    image_embs = model.get_image_features(pixel_values=pixel_values)
    image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

    text_embs = model.get_text_features(**text_inputs)
    text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

    # score: one scaled cosine similarity per image for the first prompt
    scores = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
    # backward() requires a scalar. Each score depends only on its own image,
    # so summing leaves per-image pixel gradients unchanged and also works
    # when more than one image is scored.
    scores.sum().backward()
    print(scores)

    return scores, image_embs, text_embs

# Grad mode is on by default; enable_grad() only matters if a caller has
# globally disabled it.
with torch.enable_grad():
    # The with-block closes the file handle. convert("RGB") guarantees a
    # 3-channel tensor even for RGBA, palette, or grayscale files.
    with Image.open("test.jpg") as image_file:
        pil_images = [ToTensor()(image_file.convert("RGB"))]
    # NOTE: calc_probs detaches the processed pixel values before building its
    # own leaf tensor, so no gradient flows back to this input tensor.
    pil_images[0] = pil_images[0].requires_grad_(True)
    prompt = "66666666666666"
    scores, image_embs, text_embs = calc_probs(prompt, pil_images)