# import torch
# from transformers import AutoModel, AutoTokenizer, AutoConfig

# def get_tokenizer(model_name: str):
#     """
#     Get the tokenizer for a given transformer model.
    
#     Args:
#         model_name: Name of the transformer model (e.g., 'bert-base-uncased')
        
#     Returns:
#         A tokenizer instance
#     """
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
    
#     # Add sentence separator token if not already in tokenizer
#     if not hasattr(tokenizer, 'sentence_separator_id'):
#         # Use SEP token for BERT-like models, or EOS for other models
#         if hasattr(tokenizer, 'sep_token_id') and tokenizer.sep_token_id is not None:
#             tokenizer.sentence_separator_id = tokenizer.sep_token_id
#         elif hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
#             tokenizer.sentence_separator_id = tokenizer.eos_token_id
#         else:
#             # Add a custom separator token if necessary
#             tokenizer.add_special_tokens({'additional_special_tokens': ['<sep>']})
#             tokenizer.sentence_separator_id = tokenizer.convert_tokens_to_ids('<sep>')
    
#     return tokenizer

# def get_model(model_name: str):
#     """
#     Get a transformer model with custom getter for the embedding layer.
    
#     Args:
#         model_name: Name of the transformer model (e.g., 'bert-base-uncased')
        
#     Returns:
#         A transformer model with added get_embedding_layer method
#     """
#     model = AutoModel.from_pretrained(model_name)
    
#     # Add a method to get the embedding layer
#     # Different models have different attribute names for embeddings
#     def get_embedding_layer(self):
#         if hasattr(self, 'embeddings'):
#             return self.embeddings.word_embeddings
#         elif hasattr(self, 'wte'):
#             return self.wte
#         elif hasattr(self, 'model') and hasattr(self.model, 'embeddings'):
#             return self.model.embeddings.word_embeddings
#         else:
#             # Fall back to the first layer of the model
#             for module in self.modules():
#                 if isinstance(module, torch.nn.Embedding):
#                     if module.weight.shape[0] >= 1000:  # Likely the token embedding layer
#                         return module
#             raise AttributeError(f"Could not find embedding layer for model {model_name}")
    
#     # Add the method to the model
#     model.get_embedding_layer = get_embedding_layer.__get__(model, type(model))
    
#     return model

# def get_model_config(model_name: str):
#     """
#     Get the configuration of a transformer model.
    
#     Args:
#         model_name: Name of the transformer model (e.g., 'bert-base-uncased')
        
#     Returns:
#         Model configuration
#     """
#     config = AutoConfig.from_pretrained(model_name)
#     return config

# def get_model_hidden_size(model_name: str):
#     """
#     Get the hidden size dimension of a transformer model.
    
#     Args:
#         model_name: Name of the transformer model (e.g., 'bert-base-uncased')
        
#     Returns:
#         Hidden size dimension
#     """
#     config = get_model_config(model_name)
#     if hasattr(config, 'hidden_size'):
#         return config.hidden_size
#     elif hasattr(config, 'd_model'):
#         return config.d_model
#     else:
#         raise AttributeError(f"Could not determine hidden size for model {model_name}")

# def create_attention_mask(input_ids):
#     """
#     Create an attention mask from input IDs.
    
#     Args:
#         input_ids: Tensor of token IDs
        
#     Returns:
#         Attention mask tensor (1 for real tokens, 0 for padding)
#     """
#     return (input_ids != 0).float()

# def pad_sequence(sequences, max_len=None, padding_value=0):
#     """
#     Pad a list of sequences to the same length.
    
#     Args:
#         sequences: List of tensors to pad
#         max_len: Maximum length to pad to (if None, use length of longest sequence)
#         padding_value: Value to use for padding
        
#     Returns:
#         Padded tensor
#     """
#     if max_len is None:
#         max_len = max(seq.size(0) for seq in sequences)
    
#     padded_sequences = []
#     for seq in sequences:
#         if seq.size(0) < max_len:
#             padding = torch.full((max_len - seq.size(0), *seq.size()[1:]), 
#                                  padding_value, 
#                                  dtype=seq.dtype,
#                                  device=seq.device)
#             padded_seq = torch.cat([seq, padding], dim=0)
#         else:
#             padded_seq = seq[:max_len]
#         padded_sequences.append(padded_seq)
    
#     return torch.stack(padded_sequences)
import logging
import torch
from transformers import AutoTokenizer, AutoModel

# Module-level logger, named after this module, used to report invalid model
# names, load errors and fallbacks.
logger = logging.getLogger(__name__)

def get_tokenizer(model_name):
    """
    Load a tokenizer by model name and ensure it exposes ``sentence_separator_id``.

    Args:
        model_name: Name passed to ``AutoTokenizer.from_pretrained``. Empty/None
            values and the placeholder names 'universal-inference-machine' and
            'uim' are replaced by 'bert-base-uncased'.

    Returns:
        The loaded tokenizer. If it did not already have a
        ``sentence_separator_id`` attribute, one is set to the SEP token id,
        or the EOS token id when there is no SEP token, or None when the
        tokenizer defines neither. If anything in the primary path fails, a
        'bert-base-uncased' tokenizer (separator = its SEP token id) is
        returned instead.

    Raises:
        Exception: whatever the fallback ``from_pretrained`` call raises, since
            that call is not guarded.
    """
    try:
        # Names that are known not to be loadable are swapped for the default.
        if model_name in ['universal-inference-machine', 'uim'] or not model_name:
            logger.warning(f"Invalid model name for tokenizer: {model_name}. Using bert-base-uncased instead.")
            model_name = "bert-base-uncased"

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        if not hasattr(tokenizer, 'sentence_separator_id'):
            # Compare against None explicitly: a token id of 0 is valid, and a
            # plain `a or b` would wrongly skip it because 0 is falsy.
            separator_id = tokenizer.sep_token_id
            if separator_id is None:
                separator_id = tokenizer.eos_token_id
            if separator_id is None:
                logger.warning(
                    "Tokenizer for %s defines neither a SEP nor an EOS token; "
                    "sentence_separator_id is None.",
                    model_name,
                )
            tokenizer.sentence_separator_id = separator_id

        return tokenizer
    except Exception as e:
        # Deliberate best-effort: any failure above degrades to the default
        # tokenizer rather than propagating.
        logger.error(f"Error loading tokenizer for {model_name}: {e}")
        logger.info("Falling back to bert-base-uncased tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.sentence_separator_id = tokenizer.sep_token_id
        return tokenizer

def _get_embedding_layer(self):
    """
    Locate the token-embedding layer of a model (bound to the model as a method).

    Lookup order:
      1. Well-known attribute names on the model; if the attribute is a module
         with a ``word_embeddings`` child, that child is returned, otherwise
         the module itself.
      2. The model's own ``get_input_embeddings()`` accessor, when present.
      3. Last resort: a callable that runs the model's forward pass and reads
         ``.embeddings`` from the output.

    Returns:
        A ``torch.nn.Module`` in cases 1 and 2, or a callable in case 3.
    """
    for attr_name in ('embeddings', 'shared', 'embed_tokens', 'embedding'):
        layer = getattr(self, attr_name, None)
        if isinstance(layer, torch.nn.Module):
            return getattr(layer, 'word_embeddings', layer)

    # Covers architectures whose embedding attribute is named differently
    # from the list above.
    accessor = getattr(self, 'get_input_embeddings', None)
    if callable(accessor):
        try:
            layer = accessor()
        except (AttributeError, NotImplementedError):
            layer = None
        if isinstance(layer, torch.nn.Module):
            return layer

    # NOTE(review): kept for backward compatibility; it only works if the
    # forward output exposes an `embeddings` attribute -- confirm with callers.
    return lambda x, **kwargs: self.forward(input_ids=x, **kwargs).embeddings


def get_model(model_name):
    """
    Load a model by model name and ensure it has a ``get_embedding_layer()`` method.

    Args:
        model_name: Name passed to ``AutoModel.from_pretrained``. Empty/None
            values and the placeholder names 'universal-inference-machine' and
            'uim' are replaced by 'bert-base-uncased'.

    Returns:
        The loaded model. If it did not already define ``get_embedding_layer``,
        one is attached that returns the model's token-embedding layer. If
        anything in the primary path fails, a 'bert-base-uncased' model with
        ``get_embedding_layer`` attached is returned instead.

    Raises:
        Exception: whatever the fallback ``from_pretrained`` call raises, since
            that call is not guarded.
    """
    try:
        # Names that are known not to be loadable are swapped for the default.
        if model_name in ['universal-inference-machine', 'uim'] or not model_name:
            logger.warning(f"Invalid model name for model: {model_name}. Using bert-base-uncased instead.")
            model_name = "bert-base-uncased"

        model = AutoModel.from_pretrained(model_name)

        # Only patch when the model does not already provide the method.
        if not hasattr(model, 'get_embedding_layer'):
            model.get_embedding_layer = _get_embedding_layer.__get__(model)

        return model
    except Exception as e:
        # Deliberate best-effort: any failure above degrades to the default
        # model rather than propagating.
        logger.error(f"Error loading model for {model_name}: {e}")
        logger.info("Falling back to bert-base-uncased model")
        model = AutoModel.from_pretrained("bert-base-uncased")
        model.get_embedding_layer = _get_embedding_layer.__get__(model)
        return model