from PIL import Image
from typing import Optional
import torch
import timm
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms

class ModelZoo:
    """Common interface for models that score generated images.

    Each method here is a no-op hook that returns ``None``; concrete
    subclasses override the hooks with model-specific behaviour.
    """

    def transform(self, image: Image.Image):
        """Preprocess a PIL image into the model's input format.

        Args:
            image: The image to preprocess.

        Returns:
            ``None`` in this base implementation.
        """

    def transform_tensor(self, image_tensor: torch.Tensor):
        """Preprocess an image that is already held in a tensor.

        Args:
            image_tensor: The image tensor to preprocess.

        Returns:
            ``None`` in this base implementation.
        """

    def calculate_loss(self, output: torch.Tensor, target_images: Optional[torch.Tensor]):
        """Compute a loss for ``output``, optionally against ``target_images``.

        Args:
            output: Model input produced by one of the transform hooks.
            target_images: Optional reference data to compare against.

        Returns:
            ``None`` in this base implementation.
        """

    def get_probability(self, output: torch.Tensor, target_images: Optional[torch.Tensor]):
        """Compute a score for ``output``, optionally against ``target_images``.

        Args:
            output: Model input produced by one of the transform hooks.
            target_images: Optional reference data to compare against.

        Returns:
            ``None`` in this base implementation.
        """


class CLIPImageSimilarity(ModelZoo):
    """Score images by cosine similarity of their CLIP image embeddings.

    Images are embedded with the ``openai/clip-vit-base-patch32`` checkpoint
    and compared against the mean of a set of target embeddings.
    """

    # Checkpoint shared by the model and its processor.
    _CHECKPOINT = "openai/clip-vit-base-patch32"
    # Spatial size that ``transform_tensor`` resizes inputs to.
    _INPUT_SIZE = (224, 224)
    # Per-channel normalisation statistics applied by ``transform_tensor``
    # (these match the OpenAI CLIP preprocessing constants).
    _NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
    _NORM_STD = [0.26862954, 0.26130258, 0.27577711]

    def __init__(self, device=None):
        """Load the CLIP model and processor.

        Args:
            device: Device (string or ``torch.device``) to place the model on.
                Defaults to ``'cuda'`` when CUDA is available and ``'cpu'``
                otherwise, so the class also works on CPU-only hosts.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.clip_model = CLIPModel.from_pretrained(self._CHECKPOINT).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(self._CHECKPOINT)

    def transform(self, image: Image.Image) -> torch.Tensor:
        """Run the CLIP processor on ``image`` and return its pixel values.

        Args:
            image: Image (or images) accepted by ``CLIPProcessor``.

        Returns:
            The processor's ``pixel_values`` tensor, moved to the model's device.
        """
        pixel_values = self.clip_processor(images=image, return_tensors="pt")['pixel_values']
        # Follow the model's device rather than assuming CUDA.
        return pixel_values.to(self.clip_model.device)

    def transform_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Resize and normalise an image tensor with differentiable torch ops.

        Args:
            image_tensor: Batched image tensor; bicubic interpolation needs a
                4-D input, presumably ``(N, 3, H, W)`` -- TODO confirm the
                expected value range with callers.

        Returns:
            The tensor resized to 224x224 and channel-normalised. It is not
            moved to another device.
        """
        resized = torch.nn.functional.interpolate(
            image_tensor, size=self._INPUT_SIZE, mode='bicubic', align_corners=False)
        normalize = transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
        return normalize(resized)

    def _mean_target_similarity(self, output: torch.Tensor,
                                target_images: Optional[torch.Tensor]) -> torch.Tensor:
        """Return the mean cosine similarity between ``output`` and the targets."""
        if target_images is None:
            raise ValueError("target_images is required to compute CLIP image similarity")
        image_features = self.clip_model.get_image_features(pixel_values=output)
        # Collapse all targets into one reference embedding of shape (1, D).
        mean_target = target_images.mean(dim=0).reshape(1, -1)
        # (N, D) against (1, D) broadcasts to one similarity per input image.
        return torch.cosine_similarity(image_features, mean_target, dim=1).mean()

    def calculate_loss(self, output: torch.Tensor,
                       target_images: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``1 - mean cosine similarity`` to the mean target embedding.

        Args:
            output: Preprocessed pixel values to embed with CLIP.
            target_images: Target embeddings, averaged over dim 0 and compared
                directly with the CLIP image features (so presumably
                precomputed CLIP image embeddings -- verify against callers).

        Returns:
            A scalar tensor; lower means more similar.

        Raises:
            ValueError: If ``target_images`` is ``None``.
        """
        return 1 - self._mean_target_similarity(output, target_images)

    def get_probability(self, output: torch.Tensor,
                        target_images: Optional[torch.Tensor]) -> torch.Tensor:
        """Return the mean cosine similarity to the mean target embedding.

        Despite the method name, the value is a cosine similarity and not a
        normalised probability.

        Args:
            output: Preprocessed pixel values to embed with CLIP.
            target_images: Target embeddings, averaged over dim 0 and compared
                directly with the CLIP image features (so presumably
                precomputed CLIP image embeddings -- verify against callers).

        Returns:
            A scalar tensor; higher means more similar.

        Raises:
            ValueError: If ``target_images`` is ``None``.
        """
        return self._mean_target_similarity(output, target_images)

