import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from typing import List
from torch import Tensor
import itertools

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pick one hidden state per sequence: the one at its last attended position.

    Args:
        last_hidden_states: Tensor indexed as ``[batch, position, ...]``.
        attention_mask: Tensor indexed as ``[batch, position]``; assumed to
            hold 1 for real tokens and 0 for padding (TODO confirm with the
            tokenizer in use).

    Returns:
        ``last_hidden_states`` reduced over the position axis, one entry per
        batch row.
    """
    # If the final column of the mask sums to the batch size, every row ends
    # on a real token (as with left padding), so the last position is the
    # answer for all rows.
    every_row_ends_on_token = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if every_row_ends_on_token:
        return last_hidden_states[:, -1]

    # Otherwise, index each row at (number of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]
    
class SFR2Embedding:
    """Embedding wrapper around the ``Salesforce/SFR-Embedding-2_R`` checkpoint.

    Calling an instance tokenizes the texts (truncating to ``max_length``
    tokens), runs the model, pools with ``last_token_pool`` and L2-normalizes
    the result.
    """

    def __init__(self):
        """Load the tokenizer and model and set the token limit."""
        self.tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-2_R')
        self.model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-2_R', device_map="auto", torch_dtype=torch.bfloat16).eval()
        # Maximum number of tokens per text passed to the tokenizer in __call__.
        self.max_length = 4096

    def check_truncation(self, input_text: str):
        """Return True if ``input_text`` would be truncated by ``__call__``.

        The text is encoded the same way ``__call__`` encodes it, minus the
        truncation, so any special tokens the tokenizer adds are counted
        against ``max_length`` exactly as they are during embedding.
        (Counting ``tokenizer.tokenize`` output would miss those tokens and
        under-report truncation for texts close to the limit.)

        Args:
            input_text: A single text to check.

        Returns:
            bool: whether the encoded length exceeds ``self.max_length``.
        """
        encoded = self.tokenizer(input_text, truncation=False)
        return len(encoded["input_ids"]) > self.max_length

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()
    
class SFREmbedding:
    """Embedding wrapper around the ``Salesforce/SFR-Embedding-Mistral`` checkpoint.

    Calling an instance tokenizes the texts (truncating to ``max_length``
    tokens), runs the model, pools with ``last_token_pool`` and L2-normalizes
    the result.
    """

    def __init__(self):
        """Load the tokenizer and model and set the token limit."""
        self.tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
        self.model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral', device_map="auto", torch_dtype=torch.bfloat16).eval()
        # Maximum number of tokens per text passed to the tokenizer in __call__.
        self.max_length = 4096

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()
    
class BGELargeEmbedding:
    """Embedding wrapper around the ``BAAI/bge-large-en-v1.5`` checkpoint.

    Calling an instance tokenizes the texts, runs the model, takes the hidden
    state at position 0 of each sequence and L2-normalizes it.
    """

    def __init__(self):
        """Load the tokenizer and model."""
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5')
        self.model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5', device_map="auto", torch_dtype=torch.bfloat16).eval()

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        # No explicit max_length is given; with truncation=True the limit is
        # presumably the tokenizer's own configured maximum (TODO confirm).
        batch_dict = self.tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        # First model output, first position of every sequence.
        embeddings = outputs[0][:, 0]
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()
    
class BGESmallEmbedding:
    """Embedding wrapper around the ``BAAI/bge-small-en-v1.5`` checkpoint.

    Calling an instance tokenizes the texts, runs the model, takes the hidden
    state at position 0 of each sequence and L2-normalizes it.
    """

    def __init__(self):
        """Load the tokenizer and model."""
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
        self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', device_map="auto", torch_dtype=torch.bfloat16).eval()

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        # No explicit max_length is given; with truncation=True the limit is
        # presumably the tokenizer's own configured maximum (TODO confirm).
        batch_dict = self.tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        # First model output, first position of every sequence.
        embeddings = outputs[0][:, 0]
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()
        
class E5Embedding:
    """Embedding wrapper around the ``intfloat/e5-mistral-7b-instruct`` checkpoint.

    Calling an instance tokenizes the texts (truncating to ``max_length``
    tokens), runs the model, pools with ``last_token_pool`` and L2-normalizes
    the result.
    """

    def __init__(self):
        """Load the tokenizer and model and set the token limit."""
        self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
        self.model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct', device_map="auto", torch_dtype=torch.bfloat16).eval()
        # Maximum number of tokens per text passed to the tokenizer in
        # __call__; kept as an attribute like the SFR wrappers do.
        self.max_length = 4096

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()

class BERTEmbedding:
    """Embedding wrapper around the ``bert-base-uncased`` checkpoint.

    Calling an instance tokenizes the texts (truncating to ``max_length``
    tokens), runs the model, takes the hidden state at position 0 of each
    sequence and L2-normalizes it.
    """

    def __init__(self):
        """Load the tokenizer and model and set the token limit."""
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = AutoModel.from_pretrained("bert-base-uncased", device_map="auto", torch_dtype=torch.bfloat16).eval()
        # Maximum number of tokens per text passed to the tokenizer in
        # __call__; kept as an attribute like the SFR wrappers do.
        self.max_length = 512

    @torch.no_grad()  # inference only: do not record an autograd graph
    def __call__(self, input_texts: List[str]):
        """Embed a batch of texts.

        Args:
            input_texts: Texts to embed.

        Returns:
            A CPU tensor with one L2-normalized (along dim 1) embedding per
            input text.
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        # Move every tokenizer output onto the device the model reports.
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
        outputs = self.model(**batch_dict)

        # Hidden state at the first position of every sequence.
        embeddings = outputs.last_hidden_state[:, 0]
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # Gradients are disabled above, so the result is already detached.
        return embeddings.cpu()
        
def batched(iterable, n):
    """Yield successive tuples of up to ``n`` items taken from ``iterable``.

    Example: batched('ABCDEFG', 3) --> ABC DEF G

    Because this is a generator, the ``ValueError`` for ``n < 1`` is raised
    when iteration starts, not when the function is called.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            # Source exhausted: an empty slice ends the generator.
            return
        yield chunk

# Registry mapping a short model key to its embedding wrapper class.
# Values are the classes themselves, not instances; instantiating one runs
# its ``__init__``, which loads the tokenizer and model via ``from_pretrained``.
MODELS = {
    "sfr2": SFR2Embedding,
    "sfr": SFREmbedding,
    "bge-large": BGELargeEmbedding,
    "bge-small": BGESmallEmbedding,
    "e5": E5Embedding,
    "bert": BERTEmbedding,
}
    
def get_embeddings(model, texts: List[str], batch_size):
    """Embed ``texts`` in batches and concatenate the per-batch results.

    Args:
        model: Callable that takes a list of strings and returns a tensor
            whose first dimension indexes the inputs.
        texts: Texts to embed.
        batch_size: Maximum number of texts per ``model`` call; values below
            one make ``batched`` raise ``ValueError``.

    Returns:
        The per-batch tensors concatenated along dim 0. If ``texts`` is
        empty, an empty float32 tensor of shape ``(0, 0)`` is returned,
        because the embedding width is unknown without calling the model.
    """
    # ``batched`` yields tuples; convert each to a list so the model receives
    # the ``List[str]`` its signature declares.
    chunks = [model(list(batch_text)) for batch_text in batched(texts, batch_size)]

    if not chunks:
        # torch.cat rejects an empty list, so return an explicit empty result.
        return torch.empty((0, 0))

    return torch.cat(chunks, dim=0)