from transformers import AutoTokenizer, AutoModel
import torch
 
class EmbeddingModel:
    """Sentence-embedding wrapper around a Hugging Face tokenizer/model pair.

    Text is tokenized, run through the model without gradient tracking, and
    the per-token outputs are averaged using the attention mask as weights.

    Attributes:
        device: The ``torch.device`` the model and all inputs are placed on.
        tokenizer: Tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        model: Model loaded with ``AutoModel.from_pretrained``.
        hidden_size: Length of the zero vector returned for empty input.
    """

    # Fallback width, used only when the model config has no ``hidden_size``.
    _DEFAULT_HIDDEN_SIZE = 768

    def __init__(self, model_name, device=None) -> None:
        """Load the tokenizer and model for *model_name*.

        Args:
            model_name: Identifier or path passed to ``from_pretrained``.
            device: Optional device (string or ``torch.device``). When
                omitted, the second CUDA device is preferred, then the first
                CUDA device, then the CPU.
        """
        self.device = self._resolve_device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        # Read the width from the loaded model instead of assuming 768, so the
        # empty-input vector matches models with a different hidden size.
        self.hidden_size = getattr(
            self.model.config, "hidden_size", self._DEFAULT_HIDDEN_SIZE
        )

    @staticmethod
    def _resolve_device(device=None):
        """Return the requested device, or the best available default."""
        if device is not None:
            return torch.device(device)
        if torch.cuda.is_available():
            # Keep the historical preference for "cuda:1", but do not fail on
            # single-GPU machines.
            index = 1 if torch.cuda.device_count() > 1 else 0
            return torch.device(f"cuda:{index}")
        return torch.device("cpu")

    # NOTE: ``input`` shadows the builtin, but the name is part of the public
    # signature and is kept so keyword callers continue to work.
    def encode(self, input):
        """Embed *input* and return the mean-pooled result.

        Args:
            input: Text (or a sequence of texts) accepted by the tokenizer.

        Returns:
            For non-empty input, the output of ``mean_pooling``: the token
            axis (dimension 1) is averaged away. For empty input (``len`` of
            0), a 1-D zero tensor of length ``hidden_size`` on ``device``;
            note this has no leading batch dimension.
        """
        if len(input) == 0:
            return torch.zeros(self.hidden_size, device=self.device)
        with torch.no_grad():
            inputs = self.tokenizer(
                input, return_tensors='pt', truncation=True, padding=True
            ).to(self.device)
            outputs = self.model(**inputs)
        return self.mean_pooling(outputs, inputs['attention_mask'])

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over dimension 1, weighted by the mask.

        Args:
            model_output: Indexable model output whose first element is
                expected to hold the per-token embeddings.
            attention_mask: Mask with one entry per token; positions with 0
                do not contribute to the average.

        Returns:
            Masked sum over dimension 1 divided by the mask total for each
            row.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so a fully masked row divides by a tiny number instead of 0.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts