#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
OpenAI Embedding Model Wrapper for the Implicit Embeddings Benchmark.
This wrapper provides a compatible interface with SentenceTransformer for OpenAI embedding models.
"""

import os
import time
import logging
import numpy as np
from typing import List, Dict, Any, Union, Optional
import openai
from tqdm import tqdm
import tiktoken

# Configure logging
# NOTE(review): basicConfig runs at import time and configures the *root* logger
# (level INFO, StreamHandler to stderr) for every program that imports this
# module; it is a no-op if the root logger already has handlers. Consider moving
# this call to the entry-point script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
# Named logger used by the wrapper class and helpers in this module.
logger = logging.getLogger("OpenAIModelWrapper")

# List of valid OpenAI embedding models.
# OpenAIEmbeddingModel only warns when a name is not listed here; unlisted names
# are still sent to the API as-is.
VALID_OPENAI_MODELS = [
    "text-embedding-3-large",
    "text-embedding-3-small",
    "text-embedding-ada-002"
]

# Initialize tiktoken encoder
# Loaded once at import time. tiktoken maps all three models listed above to the
# cl100k_base encoding, so local token counts match what the API counts.
ENCODER = tiktoken.get_encoding("cl100k_base")  # Using OpenAI's encoding
# Per-text token budget applied by truncate_text(). Kept slightly below the
# models' per-input limit (8191 tokens per OpenAI docs — verify) as a margin.
MAX_TOKENS = 8000  # Leave some margin for safety

def truncate_text(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """
    Truncate text to a maximum number of tokens using tiktoken.

    The text is tokenized with the module-level ``ENCODER``. If it has more than
    ``max_tokens`` tokens, only the first ``max_tokens`` are kept and decoded back
    into a string. Text already within the budget is returned unchanged.

    Special-token strings such as ``"<|endoftext|>"`` are tokenized as ordinary
    text. By default ``Encoding.encode`` raises ``ValueError`` when it meets one,
    which would abort encoding of an otherwise valid document.

    Args:
        text: Input text to truncate
        max_tokens: Maximum number of tokens to keep

    Returns:
        The original text if it fits, otherwise its first ``max_tokens`` tokens
        decoded back to a string. A cut can land inside a multi-byte character,
        in which case the decoded tail may end with a U+FFFD replacement character.
    """
    # disallowed_special=() -> treat special-token text as plain text instead of raising.
    tokens = ENCODER.encode(text, disallowed_special=())
    if len(tokens) <= max_tokens:
        return text
    return ENCODER.decode(tokens[:max_tokens])

class OpenAIEmbeddingModel:
    """
    Wrapper for OpenAI embedding models to be compatible with the benchmark infrastructure.
    Provides a similar interface to SentenceTransformer for encoding texts
    (``encode`` and ``get_sentence_embedding_dimension``).
    """

    # Upper bound on the number of texts sent in a single API request. The API
    # accepts more inputs per request; this conservative cap also helps keep
    # requests with long documents under the per-request token limit.
    _MAX_BATCH_SIZE = 256

    # Embedding sizes of the known models; any other model is probed via the API.
    _KNOWN_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Backing field for the lazily created ``client`` property.
    _client = None

    def __init__(self, model_name: str, batch_size: int = 128):
        """
        Initialize the OpenAI embedding model wrapper.

        No API request is made and no client is created here; the OpenAI client
        is created on first use of ``client``.

        Args:
            model_name: Name of the OpenAI embedding model (e.g., "text-embedding-3-large").
                Anything after a '#' is stripped, so "name  # comment" is accepted.
            batch_size: Batch size for API calls (default: 128, to avoid token limits).
                Clamped into the range [1, 256].
        """
        # Clean model name - remove any comments if present
        self.model_name = model_name.split('#')[0].strip()
        self.batch_size = self._clamp_batch_size(batch_size)

        # Validate model name (warn only; unknown names are still sent to the API)
        if self.model_name not in VALID_OPENAI_MODELS:
            logger.warning(f"Model name '{self.model_name}' is not in the list of known OpenAI embedding models: {VALID_OPENAI_MODELS}")
            logger.warning("This may cause errors if the model name is invalid.")

        # Check if API key exists in environment
        if not os.environ.get('OPENAI_API_KEY'):
            logger.warning("OPENAI_API_KEY environment variable not found. You will need to set this before using the model.")

        # The client is created lazily (see ``client``) so that the key can
        # still be set after construction, as the warning above promises.
        self._client = None

        logger.info(f"Initialized OpenAI embedding model: {self.model_name} with batch size {self.batch_size}")

    @property
    def client(self):
        """
        OpenAI API client, created on first access.

        Creation is deferred because ``openai.OpenAI()`` raises immediately when
        no API key is configured. It reads OPENAI_API_KEY from the environment.
        """
        if self._client is None:
            self._client = openai.OpenAI()
        return self._client

    @client.setter
    def client(self, value) -> None:
        # Allows callers to inject a preconfigured client.
        self._client = value

    @classmethod
    def _clamp_batch_size(cls, batch_size: int) -> int:
        """Clamp a requested batch size into the range [1, ``_MAX_BATCH_SIZE``]."""
        return max(1, min(batch_size, cls._MAX_BATCH_SIZE))

    def encode(self, sentences: Union[str, List[str]], show_progress_bar: bool = True, batch_size: Optional[int] = None) -> np.ndarray:
        """
        Encode sentences into embeddings using the OpenAI API.

        Args:
            sentences: Single sentence or list of sentences to encode
            show_progress_bar: Whether to show progress bar during encoding
            batch_size: Optional batch size override (default: use self.batch_size).
                Clamped into the range [1, 256], like the constructor argument.

        Returns:
            NumPy array with one embedding row per input, in input order. An empty
            input list yields an empty array (no API call is made).

        Raises:
            Whatever the OpenAI client raises once retries are exhausted
            (e.g. ``openai.RateLimitError``, ``openai.BadRequestError``).
        """
        # Convert single sentence to list
        if isinstance(sentences, str):
            sentences = [sentences]

        # Truncate each sentence to the token budget. The API rejects empty
        # strings, so a single space is sent instead to keep one row per input.
        sentences = [truncate_text(s) or " " for s in sentences]

        # Use provided batch_size or fall back to instance batch_size, then clamp
        effective_batch_size = self._clamp_batch_size(
            batch_size if batch_size is not None else self.batch_size
        )

        # Prepare batches
        batches = [
            sentences[start:start + effective_batch_size]
            for start in range(0, len(sentences), effective_batch_size)
        ]

        all_embeddings = []

        # Set up progress bar if requested
        batches_iter = tqdm(batches, desc="Encoding with OpenAI API") if show_progress_bar else batches

        for batch in batches_iter:
            try:
                # Make API call with retries for rate limiting
                all_embeddings.extend(self._encode_batch_with_retry(batch))
            except Exception as e:
                logger.error(f"Error encoding batch: {e}")
                raise

        # Convert to numpy array and return
        return np.array(all_embeddings)

    def _encode_batch_with_retry(self, batch: List[str], max_retries: int = 5) -> List[List[float]]:
        """
        Encode a batch of texts with retry logic for rate limiting.

        Rate-limit errors are retried with exponential backoff (1s, 2s, 4s, ...).
        If the batch exceeds the API's per-request token limit, it is split in
        half and each half is encoded separately (order is preserved).

        Args:
            batch: Non-empty list of texts to encode
            max_retries: Maximum number of attempts on rate limiting errors
                (values below 1 are treated as 1)

        Returns:
            List of embeddings, one per text, in the same order as ``batch``

        Raises:
            openai.RateLimitError: if the last attempt is still rate limited.
            openai.BadRequestError: for invalid requests that splitting cannot fix.
        """
        attempts = max(1, max_retries)
        backoff_time = 1  # Initial backoff time in seconds

        for attempt in range(1, attempts + 1):
            try:
                response = self.client.embeddings.create(
                    model=self.model_name,
                    input=batch
                )
            except openai.RateLimitError:
                if attempt >= attempts:
                    logger.error(f"Rate limit exceeded after {max_retries} retries.")
                    raise
                logger.warning(f"Rate limit exceeded. Retrying in {backoff_time} seconds... (Attempt {attempt}/{max_retries})")
                time.sleep(backoff_time)
                backoff_time *= 2  # Exponential backoff
                continue
            except openai.BadRequestError as e:
                message = str(e)
                if "max_tokens_per_request" in message and len(batch) > 1:
                    logger.warning(f"Batch of {len(batch)} texts exceeds the per-request token limit; splitting it in half and retrying.")
                    # Split outside the except clause so later errors are not
                    # chained onto this (already handled) one.
                    break
                # Handle invalid model ID errors specifically
                if "invalid model ID" in message:
                    logger.error(f"Invalid model ID: '{self.model_name}'. Please check that the model name is correct.")
                    logger.error(f"Valid OpenAI embedding models are: {VALID_OPENAI_MODELS}")
                # Handle token limit errors that splitting cannot resolve
                elif "max_tokens_per_request" in message:
                    logger.error(f"Token limit exceeded. Please try with a smaller batch size. Current batch size: {len(batch)}")
                    logger.error(f"Error details: {e}")
                raise
            except Exception as e:
                logger.error(f"Error in OpenAI API call: {e}")
                raise

            # Each returned item carries the index of its input; sort on it so
            # the output order matches ``batch``.
            ordered = sorted(response.data, key=lambda item: item.index)
            return [item.embedding for item in ordered]

        # Only reachable through the ``break`` above. Every other path either
        # returns or raises. The batch held too many tokens for one request.
        middle = len(batch) // 2
        return (
            self._encode_batch_with_retry(batch[:middle], max_retries)
            + self._encode_batch_with_retry(batch[middle:], max_retries)
        )

    def get_sentence_embedding_dimension(self) -> int:
        """
        Return the embedding dimension for the model.

        Known models are answered from a static table without an API call.
        Other models are probed by embedding one test sentence. If that probe
        fails, the error is logged and 1536 is returned.

        Returns:
            Dimension of embeddings (e.g., 1536 for text-embedding-3-small)
        """
        known_dimension = self._KNOWN_DIMENSIONS.get(self.model_name)
        if known_dimension is not None:
            return known_dimension

        # Make a test API call to determine the dimension
        logger.info(f"Determining embedding dimension for unknown model {self.model_name}")
        try:
            test_embedding = self.encode("This is a test sentence", show_progress_bar=False)
            return test_embedding.shape[1]
        except Exception as e:
            logger.error(f"Error determining embedding dimension: {e}")
            # Deliberate best-effort fallback rather than failing the caller
            logger.warning("Defaulting to 1536 dimensions")
            return 1536

# Factory function to create an instance of the wrapper with the same interface as SentenceTransformer
def OpenAIModel(model_name: str, batch_size: int = 128, **kwargs) -> OpenAIEmbeddingModel:
    """
    Build an OpenAI embedding wrapper through a SentenceTransformer-style call.

    Mirrors the ``SentenceTransformer(model_name, ...)`` constructor so benchmark
    code can create OpenAI-backed models the same way as local ones.

    Args:
        model_name: Name of the OpenAI embedding model
        batch_size: Batch size for API calls (default: 128)
        **kwargs: Extra SentenceTransformer-style options; accepted and discarded

    Returns:
        A newly constructed OpenAIEmbeddingModel
    """
    # kwargs exist purely for call-site compatibility and are intentionally unused.
    wrapper = OpenAIEmbeddingModel(model_name=model_name, batch_size=batch_size)
    return wrapper