from abc import abstractmethod
from typing import List
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image

from ..model import ScoreModel
from ...constants import HF_CACHE_DIR

PICKSCORE_MODELS = ['pickscore-v1']

class PickScoreModel(ScoreModel):
    "A wrapper for PickScore models"
    def __init__(self,
                 model_name='pickscore-v1',
                 device='cuda',
                 cache_dir=HF_CACHE_DIR):
        assert model_name in PICKSCORE_MODELS
        super().__init__(model_name=model_name,
                         device=device,
                         cache_dir=cache_dir)
    
    def load_model(self):
        """Load the model, tokenizer, image transform
        """
        assert self.model_name == 'pickscore-v1'
        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"

        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)

    def load_images(self,
                    image: List[str]) -> torch.Tensor:
        """Load the image(s), and return a tensor (no preprocessing!!) put on self.device
        """
        image = [self.image_loader(x) for x in image]
        image = self.processor(images=image, padding=True, truncation=True, max_length=77, return_tensors="pt").to(self.device)
        # image = torch.stack(image, dim=0).to(self.device)
        return image
    
    @torch.no_grad()
    def forward(self,
                images: List[str],
                texts: List[str]) -> torch.Tensor:
        """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
        """
        assert len(images) == len(texts)
        image = self.load_images(images)
        text_inputs = self.processor(
            text=texts,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        
        # embed
        image_embs = self.model.get_image_features(**image)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
    
        text_embs = self.model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
    
        # score
        scores = (image_embs * text_embs).sum(dim=-1)
        return scores
