import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Pick one hidden-state vector per row: the one at the last attended position

    Args:
        last_hidden_states: Hidden states indexed as [row, position, ...]
        attention_mask: Mask indexed as [row, position]
            (assumed to hold 1 for real tokens and 0 for padding -- confirm with caller)

    Returns:
        Tensor with one selected vector per row
    """
    num_rows = attention_mask.shape[0]

    # If the final column of the mask sums to the row count, every row ends in a
    # real token (left padding), so the last position is the answer for all rows.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise index each row at (number of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]

class BERTEmbedding:
    """Callable wrapper that embeds texts with the first-token (CLS) state of a BERT-style encoder."""

    def __init__(self, model_name="bert-base-uncased"):
        """
        Initialize a BERT embedding model

        The model is placed on CUDA when available and falls back to CPU otherwise.

        Args:
            model_name: Name of BERT model to use (default: bert-base-uncased)
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load the checkpoint once; only the device move is retried on failure.
        model = AutoModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            try:
                model = model.to('cuda')
            except RuntimeError as e:
                print(f"[WARNING] Failed to load BERT on CUDA: {e}")
                # A failed move may leave some parameters on the GPU; bring all back.
                model = model.to('cpu')
        else:
            # CPU-only torch builds raise AssertionError (not RuntimeError) on
            # .to('cuda'), so availability is checked up front instead of caught.
            print("[WARNING] CUDA is not available; running BERT on CPU")
        self.model = model.eval()

        self.max_length = 512  # Standard BERT sequence length

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            Tensor of L2-normalized embeddings on CPU, shape [batch_size, embedding_dim]
        """
        batch_dict = self.tokenizer(
            input_texts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**batch_dict)

            # Use CLS token embedding (first token)
            embeddings = outputs.last_hidden_state[:, 0]

            # Normalize embeddings to unit length
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()

class SFR2Embedding:
    """Callable wrapper that embeds texts with 'Salesforce/SFR-Embedding-2_R' using last-token pooling."""

    def __init__(self):
        """Load the tokenizer and the bfloat16 model (device placement delegated to device_map="auto")."""
        self.tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-2_R')
        self.model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-2_R', device_map="auto", torch_dtype=torch.bfloat16).eval()
        self.max_length = 4096

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            L2-normalized embeddings on CPU, one row per input text
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: without no_grad the forward pass records autograd
        # state for every layer, inflating peak memory for no benefit.
        with torch.no_grad():
            outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()

class SFREmbedding:
    """Callable wrapper that embeds texts with 'Salesforce/SFR-Embedding-Mistral' using last-token pooling."""

    def __init__(self):
        """Load the tokenizer and the bfloat16 model (device placement delegated to device_map="auto")."""
        self.tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
        self.model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral', device_map="auto", torch_dtype=torch.bfloat16).eval()
        self.max_length = 4096

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            L2-normalized embeddings on CPU, one row per input text
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: without no_grad the forward pass records autograd
        # state for every layer, inflating peak memory for no benefit.
        with torch.no_grad():
            outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()
    
class BGELargeEmbedding:
    """Callable wrapper that embeds texts with 'BAAI/bge-large-en-v1.5' using the first-token state."""

    def __init__(self):
        """Load the tokenizer and model, using CUDA when available and CPU otherwise."""
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5')

        # Fall back to CPU instead of crashing on machines without a usable GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cpu':
            print("[WARNING] CUDA is not available; running BGE-large on CPU")
        self.model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5').to(device).eval()

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            L2-normalized embeddings on CPU, one row per input text
        """
        batch_dict = self.tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: skip autograd bookkeeping during the forward pass.
        with torch.no_grad():
            outputs = self.model(**batch_dict)

            # First element of the model output, first token of every sequence
            embeddings = outputs[0][:, 0]
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()
    
class BGESmallEmbedding:
    """Callable wrapper that embeds texts with 'BAAI/bge-small-en-v1.5' using the first-token state."""

    def __init__(self):
        """Load the tokenizer and model, using CUDA when available and CPU otherwise."""
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')

        # Fall back to CPU instead of crashing on machines without a usable GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cpu':
            print("[WARNING] CUDA is not available; running BGE-small on CPU")
        self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5').to(device).eval()

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            L2-normalized embeddings on CPU, one row per input text
        """
        batch_dict = self.tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: skip autograd bookkeeping during the forward pass.
        with torch.no_grad():
            outputs = self.model(**batch_dict)

            # First element of the model output, first token of every sequence
            embeddings = outputs[0][:, 0]
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()
        
class E5Embedding:
    """Callable wrapper that embeds texts with 'intfloat/e5-mistral-7b-instruct' using last-token pooling."""

    def __init__(self):
        """Load the tokenizer and place the model on CUDA in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
        self.model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct').to('cuda').eval()
        # Token limit passed to the tokenizer for truncation
        self.max_length = 4096

    def __call__(self, input_texts: list[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of texts

        Args:
            input_texts: List of strings to embed

        Returns:
            L2-normalized embeddings on CPU, one row per input text
        """
        batch_dict = self.tokenizer(input_texts, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt")
        batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

        # Inference only: without no_grad the forward pass records autograd
        # state for every layer, inflating peak memory for no benefit.
        with torch.no_grad():
            outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.detach().cpu()
    
def get_embedding_model(model_name):
    """
    Build the embedding model registered under the given name

    Args:
        model_name: One of "sfr2", "sfr", "bge-large", "bge-small", "e5",
            or any name starting with "bert" (passed through to BERTEmbedding)

    Returns:
        A freshly constructed embedding model instance

    Raises:
        ValueError: If the name matches none of the supported models
    """
    # Exact-match names paired with zero-argument factories. The lambdas defer
    # the class lookup until a name is actually selected.
    fixed_factories = (
        ("sfr2", lambda: SFR2Embedding()),
        ("sfr", lambda: SFREmbedding()),
        ("bge-large", lambda: BGELargeEmbedding()),
        ("bge-small", lambda: BGESmallEmbedding()),
        ("e5", lambda: E5Embedding()),
    )
    for known_name, build in fixed_factories:
        if model_name == known_name:
            return build()

    # Anything prefixed with "bert" is treated as a model identifier for BERTEmbedding.
    if model_name.startswith("bert"):
        return BERTEmbedding(model_name)

    supported = [known_name for known_name, _ in fixed_factories] + ["bert-*"]
    raise ValueError(f"Unknown embedding model: {model_name}. Available models: {supported}")