                                                     
import sys               
import os               
import logging
import torch               
               
               
               
# Put the directory two levels above this file on sys.path before the project
# import below (presumably so `fortress` resolves when this file is run
# directly as a script -- confirm against the repository layout).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
               
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, Gemma3ForConditionalGeneration                 
from fortress.config import get_config
from typing import Optional, List             
import traceback               


# Module-level logger, named after this module; used by EmbeddingModel for all diagnostics.
logger = logging.getLogger(__name__)

class EmbeddingModel:
    """
    Wraps a pre-trained Hugging Face language model for three tasks: pooled
    text embeddings, perplexity, and per-token log probabilities.

    The checkpoint name, prompt template, pooling strategy and Qwen option are
    read from the ``embedding_model`` section returned by ``get_config()``.
    The loading path is chosen from the model name:

    * large Gemma 3 instruction checkpoints (4b / 12b / 27b): ``AutoProcessor``
      plus ``Gemma3ForConditionalGeneration`` in bfloat16 with ``device_map="auto"``;
    * names containing "qwen": ``AutoTokenizer`` plus ``AutoModelForCausalLM``
      with ``device_map="auto"`` (float32 for the 0.5B / 0.6B sizes, bfloat16 otherwise);
    * anything else: ``AutoTokenizer`` plus ``AutoModelForCausalLM`` in float32,
      moved to CUDA when available and to CPU otherwise.

    Attributes:
        model_name: Hugging Face model identifier.
        input_prompt_template: Template string from the config (stored only;
            no method of this class reads it).
        pooling_strategy: Lower-cased pooling mode, ``'mean'`` or ``'cls'``.
        qwen_enable_thinking: Value from the config (stored only; the chat
            template is always rendered with ``enable_thinking=False``).
        tokenizer: Tokenizer (or the processor itself when it exposes no
            ``.tokenizer``).
        processor: ``AutoProcessor`` for large Gemma 3 models, else ``None``.
        generative_model_instance: Full model with LM head (perplexity, log-probs).
        embedding_model_instance: Backbone used for hidden states (the full
            model when no backbone attribute is found).
        device: Device that inputs are moved to.
        is_large_gemma3_model / is_qwen_model: Flags derived from ``model_name``.
    """

    # Gemma 3 checkpoints that must be loaded via Gemma3ForConditionalGeneration.
    _LARGE_GEMMA3_VARIANTS = ("gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it")
    # Qwen sizes that are loaded in float32 instead of bfloat16.
    _SMALL_QWEN_SIZES = ("0.6b", "0.5b")
    # Truncation length applied to every tokenizer call in this class.
    _MAX_LENGTH = 512

    def __init__(self):
        """
        Reads the ``embedding_model`` configuration and loads tokenizer and model.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        # `or {}` guards against a missing or empty (None) config section.
        config = (get_config() or {}).get('embedding_model') or {}
        self.model_name = config.get('model_name', "google/gemma-3-1b-it")
        self.input_prompt_template = config.get('input_prompt_template', "{prompt_text}")
        self.pooling_strategy = config.get('pooling_strategy', 'mean').lower()
        self.qwen_enable_thinking = config.get('qwen_enable_thinking', True)

        self.tokenizer = None
        self.processor = None
        self.embedding_model_instance = None
        self.generative_model_instance = None
        self.device = None

        self.is_large_gemma3_model = self._is_large_gemma3(self.model_name)
        self.is_qwen_model = "qwen" in self.model_name.lower()

        try:
            if self.is_large_gemma3_model:
                logger.info(f"Initializing {self.model_name} using Gemma3ForConditionalGeneration pipeline.")
            elif self.is_qwen_model:
                logger.info(f"Initializing {self.model_name} using Qwen-compatible pipeline.")
            else:
                logger.info(f"Initializing {self.model_name} using AutoModelForCausalLM pipeline.")

            self.processor, self.tokenizer, full_model = self._load_backend(self.model_name)
            self.generative_model_instance = full_model
            self.embedding_model_instance = self._select_backbone(full_model)

            if self.is_large_gemma3_model or self.is_qwen_model:
                # Loaded with device_map="auto": ask the model where it lives.
                self.device = full_model.device
            else:
                # _load_backend already moved the model to this same device.
                self.device = self._default_device()

            # eval() recurses into submodules, so the backbone is covered too.
            self.generative_model_instance.eval()

            logger.info(f"EmbeddingModel initialized with {self.model_name} on {self.device}")

        except Exception as e:
            logger.error(f"Error initializing EmbeddingModel ({self.model_name}): {e}", exc_info=True)
            raise

    # ------------------------------------------------------------------
    # Loading helpers
    # ------------------------------------------------------------------

    @classmethod
    def _is_large_gemma3(cls, model_name):
        """Return True when model_name names one of the large Gemma 3 checkpoints."""
        lowered = model_name.lower()
        return any(variant in lowered for variant in cls._LARGE_GEMMA3_VARIANTS)

    @classmethod
    def _qwen_dtype(cls, model_name):
        """Return float32 for the small Qwen sizes and bfloat16 for all others."""
        lowered = model_name.lower()
        if any(size in lowered for size in cls._SMALL_QWEN_SIZES):
            return torch.float32
        return torch.bfloat16

    @staticmethod
    def _default_device():
        """Return the CUDA device when CUDA is available, otherwise the CPU device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def _load_backend(cls, model_name):
        """
        Load processor, tokenizer and full model for model_name.

        Returns:
            A ``(processor, tokenizer, model)`` tuple; ``processor`` is ``None``
            unless the large-Gemma-3 path was taken.
        """
        if cls._is_large_gemma3(model_name):
            processor = AutoProcessor.from_pretrained(model_name)
            if hasattr(processor, 'tokenizer'):
                tokenizer = processor.tokenizer
            else:
                logger.warning(f"AutoProcessor for {model_name} does not have a .tokenizer attribute. Attempting to use processor directly for tokenization.")
                tokenizer = processor
            model = Gemma3ForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
            return processor, tokenizer, model

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if "qwen" in model_name.lower():
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=cls._qwen_dtype(model_name),
                device_map="auto",
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
            )
            # No device_map on this path, so place the model explicitly.
            model.to(cls._default_device())
        return None, tokenizer, model

    def _select_backbone(self, full_model):
        """Pick the submodule used for hidden states, falling back to the full model."""
        if hasattr(full_model, 'model'):
            return full_model.model
        if self.is_qwen_model and not self.is_large_gemma3_model:
            if hasattr(full_model, 'transformer'):
                return full_model.transformer
            logger.warning(f"The Qwen model {self.model_name} doesn't have expected attributes. Using full model for embeddings.")
            return full_model
        logger.error(f"The model {self.model_name} of type {type(full_model)} does not have a '.model' attribute. Using full model for embeddings, which may cause issues.")
        return full_model

    # ------------------------------------------------------------------
    # Tokenization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _render_chat_prompt(tokenizer, text):
        """Render text as a single user turn through the tokenizer's chat template."""
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

    def _encode_qwen_batch(self, texts):
        """
        Chat-template and tokenize each text separately, then right-pad to the
        longest sequence. Returns a dict with ``input_ids`` and ``attention_mask``.
        """
        encoded = []
        for text in texts:
            prompt = self._render_chat_prompt(self.tokenizer, text)
            enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=False, max_length=self._MAX_LENGTH)
            encoded.append((enc.input_ids, enc.attention_mask))

        longest = max(ids.shape[1] for ids, _ in encoded)
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0

        id_rows = []
        mask_rows = []
        for ids, mask in encoded:
            missing = longest - ids.shape[1]
            # Padding on the right keeps position 0 a real token for 'cls' pooling.
            id_rows.append(torch.nn.functional.pad(ids, (0, missing), value=pad_id))
            mask_rows.append(torch.nn.functional.pad(mask, (0, missing), value=0))

        return {
            'input_ids': torch.cat(id_rows, dim=0),
            'attention_mask': torch.cat(mask_rows, dim=0),
        }

    @staticmethod
    def _ensure_pad_token(tokenizer, model, target_model_name, add_if_missing):
        """
        Make sure the tokenizer has a pad token.

        Falls back to the EOS token; when there is none either, a '[PAD]'
        token is added (and the model embeddings resized) only if
        ``add_if_missing`` is true, otherwise a warning is logged.
        """
        if getattr(tokenizer, 'pad_token', None) is not None:
            return
        eos_token = getattr(tokenizer, 'eos_token', None)
        if eos_token:
            tokenizer.pad_token = eos_token
            logger.info(f"Set pad_token to eos_token for {target_model_name} for log probability calculation.")
        elif add_if_missing:
            logger.warning(f"Tokenizer for {target_model_name} has no pad_token or eos_token. Adding a default pad_token '[PAD]'.")
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model.resize_token_embeddings(len(tokenizer))
        else:
            logger.warning(f"pad_token is None and eos_token is not available for {target_model_name}. Padding issues may occur.")

    # ------------------------------------------------------------------
    # Embeddings
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_last_hidden_state(outputs):
        """Return the final-layer hidden states from a model output, or None (after logging)."""
        if getattr(outputs, 'hidden_states', None) is not None:
            return outputs.hidden_states[-1]
        if hasattr(outputs, 'last_hidden_state'):
            hidden = outputs.last_hidden_state
            if hidden is None:
                logger.error(f"Failed to extract hidden states from model output of type: {type(outputs)}.")
            return hidden
        if isinstance(outputs, torch.Tensor):
            return outputs
        logger.error(f"Unexpected output type or structure from model: {type(outputs)}. Cannot extract hidden states.")
        logger.error(f"Available attributes in outputs: {dir(outputs)}")
        return None

    def get_embedding(self, texts: List[str]) -> Optional[torch.Tensor]:
        """
        Embed a batch of texts by pooling the model's last hidden states.

        Args:
            texts: Texts to embed. A single ``str`` is treated as a batch of one.

        Returns:
            A tensor with one row per text; an empty ``(0, hidden_size)`` tensor
            for empty input; ``None`` when the model is not initialized, the
            pooling strategy is unsupported, or any error occurs (logged).
        """
        if self.embedding_model_instance is None or self.tokenizer is None:
            logger.error("Embedding model or tokenizer not initialized properly.")
            return None

        if isinstance(texts, str):
            texts = [texts]
        elif texts is not None:
            texts = list(texts)

        if not texts:
            logger.warning("get_embedding called with an empty list of texts.")
            model_config = getattr(self.embedding_model_instance, 'config', None)
            hidden_size = getattr(model_config, 'hidden_size', None)
            if hidden_size is None:
                logger.warning(f"Cannot determine hidden_size for {self.model_name} to return empty tensor.")
                hidden_size = 0
            return torch.empty(0, hidden_size, device=self.device)

        try:
            if self.is_large_gemma3_model:
                target_device = self.embedding_model_instance.device
            else:
                target_device = self.device

            if self.is_qwen_model and hasattr(self.tokenizer, 'apply_chat_template'):
                inputs = self._encode_qwen_batch(texts)
            else:
                inputs = self.tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=self._MAX_LENGTH)
            inputs = {key: value.to(target_device) for key, value in inputs.items()}

            with torch.no_grad():
                outputs = self.embedding_model_instance(**inputs, output_hidden_states=True)

            last_hidden_states = self._extract_last_hidden_state(outputs)
            if last_hidden_states is None:
                return None

            if self.pooling_strategy == 'mean':
                # Average over real tokens only: padded positions are zeroed by the mask.
                mask = inputs['attention_mask'].to(last_hidden_states.device).unsqueeze(-1).float()
                summed = (last_hidden_states * mask).sum(dim=1)
                # Clamp so an all-padding row cannot divide by zero.
                token_counts = mask.sum(dim=1).clamp(min=1e-9)
                return summed / token_counts
            if self.pooling_strategy == 'cls':
                # Hidden state of the first position of every sequence.
                return last_hidden_states[:, 0]

            logger.error(f"Unsupported pooling strategy: {self.pooling_strategy}")
            return None

        except Exception as e:
            preview = str(texts[0])[:50]
            logger.error(f"Error during embedding generation for texts '{preview}...': {e}", exc_info=True)
            return None

    # ------------------------------------------------------------------
    # Perplexity
    # ------------------------------------------------------------------

    def get_perplexity(self, text: str) -> Optional[float]:
        """
        Calculates the perplexity of a given text using the configured model.

        Perplexity is ``exp`` of the language-modeling loss the model reports
        for the (truncated) text; lower values mean the model is less
        surprised by the text.

        Returns:
            The perplexity, or ``None`` for empty text, an uninitialized model,
            a missing loss, or any error (logged).
        """
        if self.generative_model_instance is None or self.tokenizer is None:
            logger.error("Model or tokenizer for perplexity not initialized properly.")
            return None
        if not text:
            logger.warning("get_perplexity called with empty text.")
            return None

        try:
            model = self.generative_model_instance
            if model.device.type == 'cuda':
                torch.cuda.empty_cache()

            prompt = text
            if self.is_qwen_model and hasattr(self.tokenizer, 'apply_chat_template'):
                prompt = self._render_chat_prompt(self.tokenizer, text)
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self._MAX_LENGTH).to(model.device)
            input_ids = inputs.input_ids

            with torch.no_grad():
                # Passing the inputs as labels makes the model return its LM loss.
                loss = model(input_ids, labels=input_ids).loss

            if loss is None:
                logger.error(f"Could not calculate loss for perplexity for text '{text[:50]}...'.")
                return None

            perplexity = torch.exp(loss).item()
            logger.debug(f"Perplexity for text '{text[:50]}...': {perplexity}")
            return perplexity

        except Exception as e:
            logger.error(f"Error calculating perplexity for text '{text[:50]}...': {e}", exc_info=True)
            return None

    # ------------------------------------------------------------------
    # Token log probabilities
    # ------------------------------------------------------------------

    def _resolve_log_prob_backend(self, model_name_override, caller, label):
        """
        Choose the model/tokenizer used for log-probability scoring.

        Without an override (or when the override equals the configured
        model) the already-loaded instance is reused. Otherwise the override
        is loaded from scratch; NOTE: it is not cached, so every call with an
        override reloads the model.

        Returns:
            ``(model, tokenizer, is_qwen, target_model_name)`` or ``None`` on
            failure (already logged).
        """
        target_model_name = model_name_override if model_name_override else self.model_name

        if not model_name_override or target_model_name == self.model_name:
            if self.generative_model_instance is None or self.tokenizer is None:
                logger.error("Model for log probabilities not initialized properly.")
                return None
            return self.generative_model_instance, self.tokenizer, self.is_qwen_model, target_model_name

        logger.info(f"Loading {target_model_name} for {caller} due to override.")
        try:
            if torch.cuda.is_available():
                # Free cached blocks before a second model is brought onto the GPU.
                torch.cuda.empty_cache()
            _, tokenizer, model = self._load_backend(target_model_name)
            model.eval()
        except Exception as e:
            logger.error(f"Error loading override model {target_model_name} for {label}: {e}", exc_info=True)
            return None
        return model, tokenizer, "qwen" in target_model_name.lower(), target_model_name

    @staticmethod
    def _gather_token_log_probs(logits, input_ids, attention_mask):
        """
        Turn logits into per-text lists of token log probabilities.

        Position ``j`` of the logits predicts token ``j + 1``. A prediction is
        kept only when both the target token and the position it was predicted
        from are real (attention mask 1), so results do not depend on whether
        the tokenizer pads on the left or on the right.
        """
        input_ids = input_ids.to(logits.device)
        attention_mask = attention_mask.to(logits.device)

        per_text = []
        # One row at a time keeps the log-softmax buffer at (seq_len, vocab) size.
        for item_logits, item_ids, item_mask in zip(logits, input_ids, attention_mask):
            log_probs = torch.nn.functional.log_softmax(item_logits[:-1], dim=-1)
            token_log_probs = log_probs.gather(-1, item_ids[1:].unsqueeze(-1)).squeeze(-1)
            scored = (item_mask[1:] == 1) & (item_mask[:-1] == 1)
            # A single masked select + tolist avoids a device sync per token.
            per_text.append(token_log_probs[scored].float().tolist())
        return per_text

    def _score_texts(self, texts, model, tokenizer, is_qwen):
        """Tokenize texts as one padded batch, run the model and return per-text log probabilities."""
        if model.device.type == 'cuda':
            torch.cuda.empty_cache()

        if is_qwen and hasattr(tokenizer, 'apply_chat_template'):
            texts = [self._render_chat_prompt(tokenizer, text) for text in texts]
        encoded = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=self._MAX_LENGTH).to(model.device)

        with torch.no_grad():
            # Only the logits are needed, so no labels are passed (skips the loss computation).
            logits = model(**encoded).logits

        return self._gather_token_log_probs(logits, encoded.input_ids, encoded.attention_mask)

    def get_token_source_log_probabilities(self, text: str, model_name_override: Optional[str] = None) -> List[float]:
        """
        Calculates the log probability of each token in the input text given the preceding tokens,
        using the full CausalLM or ConditionalGeneration model.

        Args:
            text: The input string.
            model_name_override: Optional Hugging Face model identifier to override the one from config.

        Returns:
            A list of floats, one per token except the first (which has no
            preceding context to be predicted from), so its length is the
            number of tokens minus one. An empty list is returned for empty
            text or on any failure (logged).
        """
        if not text:
            logger.warning("get_token_source_log_probabilities called with empty text.")
            return []

        backend = self._resolve_log_prob_backend(
            model_name_override, "get_token_source_log_probabilities", "log probabilities"
        )
        if backend is None:
            return []
        model, tokenizer, is_qwen, target_model_name = backend

        # A single text is never actually padded, so no '[PAD]' token is added here.
        self._ensure_pad_token(tokenizer, model, target_model_name, add_if_missing=False)

        try:
            log_probabilities = self._score_texts([text], model, tokenizer, is_qwen)[0]
        except Exception as e:
            logger.error(f"Error calculating token source log probabilities for '{text[:50]}...': {e}", exc_info=True)
            return []

        logger.debug(f"Calculated {len(log_probabilities)} log probabilities for text '{text[:50]}'.")
        return log_probabilities

    def get_token_source_log_probabilities_batch(self, texts: List[str], model_name_override: Optional[str] = None) -> List[List[float]]:
        """
        Calculates the log probability of each token for a batch of texts.

        Args:
            texts: List of input strings.
            model_name_override: Optional Hugging Face model identifier to override the one from config.

        Returns:
            List[List[float]]: Each inner list holds the log probabilities for
            one input text (padding excluded, first token not scored). On
            failure one empty list per input text is returned; an empty input
            list yields ``[]``.
        """
        if not texts:
            logger.warning("get_token_source_log_probabilities_batch called with empty list of texts.")
            return []

        backend = self._resolve_log_prob_backend(
            model_name_override, "get_token_source_log_probabilities_batch", "batched log probabilities"
        )
        if backend is None:
            return [[] for _ in texts]
        model, tokenizer, is_qwen, target_model_name = backend

        # Batches are padded, so a pad token must exist; add one if necessary.
        self._ensure_pad_token(tokenizer, model, target_model_name, add_if_missing=True)

        try:
            batch_log_probabilities = self._score_texts(list(texts), model, tokenizer, is_qwen)
        except Exception as e:
            logger.error(f"Error calculating batched token source log probabilities: {e}", exc_info=True)
            return [[] for _ in texts]

        logger.debug(f"Calculated batched log probabilities for {len(texts)} texts.")
        return batch_log_probabilities


if __name__ == '__main__':
    # Manual smoke test: loads the configured model (and, when the default is
    # not a Qwen model, Qwen/Qwen3-0.6B as well) and exercises every public
    # method once, printing the outcome with rich.
    from rich.console import Console

    def _drop_cached_config():
        """Remove the `config` attribute from get_config when it is present."""
        if hasattr(get_config, "config"):
            delattr(get_config, "config")

    def _check_default_model(out):
        """Run the embedding, perplexity and log-probability checks on the default model."""
        out.print("[cyan]Testing with default model from settings.yaml...[/cyan]")
        default_model = EmbeddingModel()

        sample_texts = [
            "This is a test sentence for embedding.",
            "Another sentence to be embedded.",
            "Exploring batch processing capabilities.",
        ]
        out.print(f"Test texts (batch of {len(sample_texts)}):")
        for position, sample in enumerate(sample_texts, start=1):
            out.print(f"  {position}. {sample}")

        vectors = default_model.get_embedding(sample_texts)
        if vectors is None:
            out.print("[bold red]Failed to get embeddings with default model.[/bold red]")
        else:
            out.print(f"Generated embeddings shape: {vectors.shape}")
            assert vectors.shape[0] == len(sample_texts), "Batch size mismatch"
            out.print("[green]Default model embedding test successful.[/green]")

        perplexity_sample = "This is a sample sentence for perplexity calculation."
        score = default_model.get_perplexity(perplexity_sample)
        if score is None:
            out.print("[bold red]Failed to get perplexity with default model.[/bold red]")
        else:
            out.print(f"Perplexity for '{perplexity_sample}': {score:.4f}")
            out.print("[green]Default model perplexity test successful.[/green]")

        log_prob_sample = "This is a test for log probabilities."
        token_scores = default_model.get_token_source_log_probabilities(log_prob_sample)
        if not token_scores:
            out.print("[bold red]Failed to get log probabilities with default model.[/bold red]")
        else:
            out.print(f"Log probabilities for '{log_prob_sample}' (first 5): {token_scores[:5]}")
            out.print(f"Number of log probabilities: {len(token_scores)}")
            out.print("[green]Default model log probabilities test successful.[/green]")

        return default_model

    def _check_qwen_model(out, default_model):
        """Repeat the embedding check with a Qwen model unless the default already is one."""
        out.print("[cyan]Testing with Qwen model (if different from default)...[/cyan]")
        if "qwen" in default_model.model_name.lower():
            out.print("[yellow]Default model is already a Qwen model, skipping separate Qwen test.[/yellow]")
            return

        # Temporarily point the configuration at the Qwen checkpoint.
        previous_name = get_config()['embedding_model']['model_name']
        get_config()['embedding_model']['model_name'] = "Qwen/Qwen3-0.6B"
        _drop_cached_config()

        qwen_model = EmbeddingModel()
        qwen_vectors = qwen_model.get_embedding(["Test for Qwen model."])
        if qwen_vectors is None:
            out.print("[bold red]Failed to get embeddings with Qwen model.[/bold red]")
        else:
            out.print(f"Qwen model embeddings shape: {qwen_vectors.shape}")
            out.print("[green]Qwen model test successful.[/green]")

        # Put the original model name back.
        get_config()['embedding_model']['model_name'] = previous_name
        _drop_cached_config()

    console = Console()
    console.print("[bold green]Starting EmbeddingModel basic execution test...[/bold green]")

    try:
        _check_qwen_model(console, _check_default_model(console))
    except Exception as e:
        console.print(f"[bold red]Error during EmbeddingModel basic test: {e}[/bold red]")
        console.print(traceback.format_exc())

    console.print("[bold green]EmbeddingModel basic execution test finished.[/bold green]")
