from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np 


class SBERTSelector:
    """Pick the most representative caption from a set of candidates.

    Every candidate is embedded with a Hugging Face encoder (by default the
    ``sentence-transformers/bert-base-nli-mean-tokens`` checkpoint) using
    attention-masked mean pooling. The selected caption is the one whose
    embedding has the highest average cosine similarity to all candidates,
    i.e. the "consensus" caption of the set.
    """

    DEFAULT_MODEL = 'sentence-transformers/bert-base-nli-mean-tokens'

    def __init__(self, model_name: str = DEFAULT_MODEL, max_length: int = 128) -> None:
        """Load the tokenizer and encoder from the Hugging Face Hub.

        Args:
            model_name: Hub id or local path of the encoder checkpoint.
            max_length: Maximum number of tokens kept per sentence; longer
                sentences are truncated when encoded.
        """
        # Only constructor-level options belong here. Padding, truncation and
        # the tensor type are per-call options passed in choose_best_caption.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length=max_length,
            do_lower_case=True,
        )
        self.model = AutoModel.from_pretrained(model_name)

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings per sentence, ignoring padding positions.

        Args:
            model_output: Encoder output whose first element holds the token
                embeddings, shaped (batch, seq_len, hidden).
            attention_mask: 1 for real tokens and 0 for padding, shaped
                (batch, seq_len).

        Returns:
            Tensor of shape (batch, hidden): one embedding per sentence.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp avoids a division by zero for a row with no unmasked tokens.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def choose_best_caption(self, sentences):
        """Return the caption that is, on average, most similar to all others.

        Example::

            selector.choose_best_caption([
                "A restaurant has modern wooden tables and chairs.",
                "A long restaurant table with rattan rounded back chairs.",
                "A table is adorned with wooden chairs with blue accents.",
            ])

        Args:
            sentences: Iterable of candidate captions. A single string is
                treated as a one-element list, not as a sequence of characters.

        Returns:
            ``(index, sentence)`` of the selected caption, with ``index`` a
            plain ``int``.

        Raises:
            ValueError: If ``sentences`` is empty.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)
        if not sentences:
            raise ValueError("choose_best_caption() requires at least one sentence")
        if len(sentences) == 1:
            # Nothing to compare against; skip the forward pass entirely.
            return 0, sentences[0]

        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors='pt'
        )
        # Keep inputs on the same device as the model (e.g. after model.to("cuda")).
        encoded_input = encoded_input.to(self.model.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Mean pooling over non-padding tokens: one vector per sentence.
        sentence_embeddings = self.mean_pooling(
            model_output, encoded_input['attention_mask']
        ).cpu().numpy()

        # Each column includes the self-similarity (~1.0), which shifts every
        # mean by the same amount and so does not affect the argmax.
        mean_sims = cosine_similarity(sentence_embeddings).mean(axis=0)
        best_sent_idx = int(np.argmax(mean_sims))
        return best_sent_idx, sentences[best_sent_idx]
        