import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel
from typing import Union


class AbstractEncoder(nn.Module):
    """Base ``nn.Module`` for encoders; subclasses are expected to override ``encode``."""

    def __init__(self):
        """Initialise the underlying ``nn.Module``; no state of its own is set up."""
        super().__init__()

    def encode(self, *args, **kwargs):
        """Encode the given inputs.

        Raises:
            NotImplementedError: always, since the base class provides no encoding.
        """
        raise NotImplementedError()


class FrozenCLIPEmbedder(AbstractEncoder):
    """Text encoder built on the Hugging Face CLIP tokenizer and text transformer.

    At construction time the transformer is switched to eval mode and every
    parameter of this module has ``requires_grad`` turned off.
    """

    def __init__(self,
                 tokenizer_version="openai/clip-vit-large-patch14",
                 text_model_version="openai/clip-vit-large-patch14",
                 device="cuda",
                 max_length=77):
        """Load tokenizer and text model, freeze the model, and move it to ``device``.

        Args:
            tokenizer_version: identifier handed to ``CLIPTokenizer.from_pretrained``.
            text_model_version: identifier handed to ``CLIPTextModel.from_pretrained``.
            device: target passed to ``.to()`` for the transformer and, in
                ``forward``, for the token ids.
            max_length: length the tokenizer pads/truncates each input to.
        """
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version)
        self.transformer = CLIPTextModel.from_pretrained(text_model_version)
        self.freeze()
        self.transformer.to(device)

    def freeze(self):
        """Put the transformer in eval mode and disable gradients on all parameters."""
        self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the transformer's ``last_hidden_state``."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        # Only the token ids are fed to the model; they must live on the model's device.
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)


class CLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face).

    Tokenization and encoding are separate steps here: ``tokenize`` turns text
    into token ids and ``forward`` runs the text transformer on token ids.
    ``encode`` chains the two. The transformer is put in eval mode; this class
    does not change ``requires_grad`` on its parameters.
    """

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        """Load the tokenizer and text model and move the model to ``device``.

        Args:
            version: identifier handed to both ``CLIPTokenizer.from_pretrained``
                and ``CLIPTextModel.from_pretrained``.
            device: target passed to ``.to()`` for the transformer and, in
                ``tokenize``, for the token ids.
            max_length: length the tokenizer pads/truncates each input to.
        """
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.transformer.eval().to(device)

    def tokenize(self, text):
        """Return the ``input_ids`` tensor for ``text``, moved to ``self.device``.

        Inputs are padded and truncated to ``self.max_length``.
        """
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    def forward(self, tokens):
        """Run the text transformer on token ids and return its ``last_hidden_state``.

        Args:
            tokens: token ids as produced by ``tokenize`` (not raw text).
        """
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        """Return the transformer's ``last_hidden_state`` for ``text``.

        Raw text (a string, or a list/tuple of strings) is tokenized first,
        because ``forward`` only understands token ids. Any other input is
        assumed to already be token ids and is passed through unchanged, so
        callers that hand in pre-tokenized input keep working.
        """
        if isinstance(text, (str, list, tuple)):
            text = self.tokenize(text)
        return self(text)
    
    
class CLIPEmbedderWithProject(AbstractEncoder):
    """Wraps a CLIP text embedder and applies a linear projection to its output."""

    def __init__(self,
                 clip_embedder: Union[FrozenCLIPEmbedder, CLIPEmbedder],
                 embed_dim=768,
                 project_dim=1280):
        """Store the embedder and build the projection layer.

        Args:
            clip_embedder: embedder whose output is projected; its ``device``
                attribute decides where the projection layer is placed.
            embed_dim: input feature size of the projection (``nn.Linear``
                ``in_features``); must match the embedder's output width.
            project_dim: output feature size of the projection.
        """
        super().__init__()
        self.clip_embedder = clip_embedder
        self.project = nn.Linear(embed_dim, project_dim)
        # Keep the projection on the same device the embedder was configured with.
        self.project.to(clip_embedder.device)

    def forward(self, text):
        """Embed ``text`` with the wrapped embedder and project the result.

        Args:
            text: raw text (string or list/tuple of strings), or token ids when
                the wrapped embedder's ``forward`` consumes token ids.

        Returns:
            The output of the linear projection applied to the embedder output.
        """
        # FrozenCLIPEmbedder tokenizes inside its own forward, but CLIPEmbedder's
        # forward expects token ids. When the wrapped embedder exposes a separate
        # ``tokenize`` step and we were given raw text, run that step first so
        # both supported embedder types accept the same input.
        tokenize = getattr(self.clip_embedder, "tokenize", None)
        if tokenize is not None and isinstance(text, (str, list, tuple)):
            text = tokenize(text)
        z = self.clip_embedder(text)
        return self.project(z)

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
