# imports
import random
import time
import openai
import os
import re
from copy import deepcopy
import tiktoken
import asyncio

# Shared tokenizer: a single cl100k_base encoder created once at import time and
# reused by the token-counting helpers (as their fallback / function-schema encoding).
CL100K_ENCODER = tiktoken.get_encoding("cl100k_base")

# API key taken from the environment (None when the variable is unset). It is not
# handed to the openai client in this module; presumably the openai package reads
# the same environment variable itself — confirm for the installed SDK version.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Per-model limits used by this module:
#   tpm        -> tokens per minute
#   rpm        -> requests per minute
#   max_tokens -> context length
# All rate limits are placeholders; check the limits of your own account/tier.
OPENAI_MODEL_DETAILS_MAP = {
    # Default chat model.
    # NOTE(review): OpenAI's model page lists a much larger (~1M-token) context window
    # for gpt-4.1; 128000 may be a deliberately conservative cap — confirm.
    'gpt-4.1': dict(tpm=200_000, rpm=2_000, max_tokens=128_000),
    'gpt-3.5-turbo': dict(tpm=90_000, rpm=3_500, max_tokens=4_097),
    'gpt-3.5-turbo-16k': dict(tpm=180_000, rpm=3_500, max_tokens=16_385),
    'gpt-4': dict(tpm=40_000, rpm=200, max_tokens=8_192),
    # Embedding models usually get very generous limits; adjust as needed.
    'text-embedding-ada-002': dict(tpm=1_000_000, rpm=3_000, max_tokens=8_191),
}

# Models pinned module-wide: the chat and embedding request helpers overwrite any
# caller-supplied "model" argument with these values.
DEFAULT_MODEL = "gpt-4.1"
EMBEDDING_MODEL = "text-embedding-ada-002"

def get_llm_config(config, logger, name, rate_limiter):
    """
    Assemble the keyword-argument dictionary used for chat-completion calls.

    Sampling options are read from ``config.run`` and transport/retry options from
    ``config.setup``. Any attribute missing on those objects falls back to a built-in
    default instead of raising. Keys that start with ``_`` are private settings for the
    retry / rate-limit wrappers rather than OpenAI request parameters.

    Args:
        config: Object exposing ``run`` and ``setup`` attribute namespaces.
        logger: Stored under ``_logger``.
        name: Stored under ``_name``.
        rate_limiter: Stored under ``_rate_limiter``.

    Returns:
        A new dict. ``logger`` and ``rate_limiter`` are stored by reference (not copied).
    """
    setup = config.setup
    backoff_prefix = "api_retry_with_exponential_backoff__"
    backoff_defaults = (
        ("initial_delay", 1),
        ("exponential_base", 2),
        ("jitter", True),
        ("max_retries", 10),
    )
    backoff = [
        (suffix, getattr(setup, backoff_prefix + suffix, fallback))
        for suffix, fallback in backoff_defaults
    ]

    run = config.run
    run_defaults = (
        ("model", "gpt-4.1"),
        ("temperature", 0.7),
        ("top_p", 1.0),
        ("frequency_penalty", 0.0),
        ("presence_penalty", 0.0),
        ("stop", None),
    )
    settings = {key: getattr(run, key, fallback) for key, fallback in run_defaults}

    settings["request_timeout"] = getattr(setup, "api_request_timeout", 60)
    settings["stream"] = getattr(setup, "api_stream", False)
    settings["_open_ai_rate_limit_requests_per_minute"] = getattr(
        setup, "open_ai_rate_limit_requests_per_minute", 3000
    )
    settings["_use_azure_api"] = False  # only the native OpenAI API is supported
    settings["_logger"] = logger
    settings["_name"] = name
    settings["_rate_limiter"] = rate_limiter
    settings.update(
        ("_retry_with_exponential_backoff__" + suffix, value) for suffix, value in backoff
    )
    return settings


def get_llm_config_prev(config, logger, name, rate_limiter):
    """
    Legacy builder for the LLM request configuration dictionary.

    Every value is read directly from ``config.run`` / ``config.setup`` with no
    defaults, so a missing attribute raises AttributeError. ``model`` is pinned to
    DEFAULT_MODEL regardless of the config.

    Args:
        config: Object exposing ``run`` and ``setup`` attribute namespaces.
        logger: Stored under ``_logger``.
        name: Stored under ``_name``.
        rate_limiter: Stored under ``_rate_limiter``.

    Returns:
        A deep copy of the assembled dict.
    """
    # NOTE(review): deepcopy also copies `logger` and `rate_limiter`. If the rate limiter
    # is meant to be shared between callers (or holds a lock, which deepcopy cannot
    # copy), this may be unintended — confirm before relying on this legacy path.
    return deepcopy({
        'model': DEFAULT_MODEL,  # pinned; config.run.model is intentionally ignored
        'temperature': config.run.temperature,
        'top_p': config.run.top_p,
        'frequency_penalty': config.run.frequency_penalty,
        'presence_penalty': config.run.presence_penalty,
        'stop': config.run.stop,
        'request_timeout': config.setup.api_request_timeout,
        'stream': config.setup.api_stream,
        '_open_ai_rate_limit_requests_per_minute': config.setup.open_ai_rate_limit_requests_per_minute,
        '_logger': logger,
        '_name': name,
        '_rate_limiter': rate_limiter,
        '_retry_with_exponential_backoff__initial_delay': config.setup.api_retry_with_exponential_backoff__initial_delay,
        # Previously misspelled as "...backbackoff__exponential_base", which always raised AttributeError.
        '_retry_with_exponential_backoff__exponential_base': config.setup.api_retry_with_exponential_backoff__exponential_base,
        '_retry_with_exponential_backoff__jitter': config.setup.api_retry_with_exponential_backoff__jitter,
        '_retry_with_exponential_backoff__max_retries': config.setup.api_retry_with_exponential_backoff__max_retries
    })

def setup_chat_rate_limiter(config: dict):
    """
    Return the ``(requests_per_minute, tokens_per_minute)`` budget for the chat model.

    ``config`` is accepted for interface compatibility but is not read. The limits
    always come from DEFAULT_MODEL's entry in OPENAI_MODEL_DETAILS_MAP, falling back
    to the gpt-4.1 entry when DEFAULT_MODEL is not listed.
    """
    gpt41_limits = OPENAI_MODEL_DETAILS_MAP['gpt-4.1']
    limits = OPENAI_MODEL_DETAILS_MAP.get(DEFAULT_MODEL, gpt41_limits)
    return limits['rpm'], limits['tpm']

def get_model_max_tokens(model_name=DEFAULT_MODEL):
    """
    Return the configured context length for ``model_name``.

    Unknown models use the gpt-4.1 entry of OPENAI_MODEL_DETAILS_MAP. An entry without
    a ``max_tokens`` key yields 128000.
    """
    gpt41_details = OPENAI_MODEL_DETAILS_MAP['gpt-4.1']
    details = OPENAI_MODEL_DETAILS_MAP.get(model_name, gpt41_details)
    return details.get('max_tokens', 128000)

def pretty_print_chat_messages(
    messages,
    num_tokens=None,
    max_tokens=None,
    logger=None,
    response_msg=False,
    step_idx=None,
    total_steps=None,
    max_re_tries=None,
    re_tries=None
):
    """
    Print chat messages to stdout with ANSI colours and mirror them to ``logger``.

    Each message is rendered as ``[Role] content``. A message with a truthy
    ``function_call`` is rendered as ``[Function Call] [name] arguments`` instead.
    Unless ``response_msg`` is set, a trailing token-usage line follows when both
    ``num_tokens`` and ``max_tokens`` are truthy. That line is prefixed with step
    progress (1-based) when ``step_idx`` and ``total_steps`` are given, and with retry
    counts when ``re_tries`` and ``max_re_tries`` are also given.

    Args:
        messages: Iterable of message dicts (``role``, ``content``, optional ``function_call``).
        num_tokens: Tokens used so far.
        max_tokens: Token budget the usage is measured against.
        logger: Optional object with an ``info`` method; it receives the uncoloured text.
        response_msg: When True, print a ``[LLM RESPONSE MESSAGE]`` header and skip the usage line.
        step_idx: Zero-based index of the current step.
        total_steps: Total number of steps.
        max_re_tries: Retry budget shown alongside progress.
        re_tries: Retries used so far.
    """
    reset = "\033[0m"
    palette = {
        "system": "\033[95m",      # light magenta (also used for unknown roles)
        "user": "\033[94m",        # light blue
        "assistant": "\033[92m",   # light green
        "tokens": "\033[91m",      # light red, for the usage line
    }

    def emit(text, colour=None):
        # The terminal gets the coloured form; the logger always gets plain text.
        print(text if colour is None else f"{colour}{text}{reset}")
        if logger:
            logger.info(text)

    if response_msg:
        emit("[LLM RESPONSE MESSAGE]")

    for message in messages:
        role = message.get("role", "system")
        colour = palette.get(role, palette["system"])
        label = role.capitalize()
        body = message.get("content", "")
        call = message.get("function_call", None)
        if call:
            text = f"[Function Call] [{call.get('name', '<no name>')}] {call.get('arguments', '<no arguments>')}"
        else:
            text = f"[{label}] {body}"
        emit(text, colour)

    # No usage line for responses, or when either count is missing or zero.
    if response_msg or not (num_tokens and max_tokens):
        return

    percent_used = num_tokens / max_tokens * 100.0
    remaining = max_tokens - num_tokens
    fields = []
    if step_idx is not None and total_steps is not None:
        fields.append(f"Progress: Step {step_idx + 1}/{total_steps}")
        if max_re_tries is not None and re_tries is not None:
            fields.append(f"Retries: {re_tries}/{max_re_tries}")
    fields.append(f"Token Capacity Used: {percent_used:.2f}%")
    fields.append(f"Tokens remaining {remaining}")
    emit("[" + " | ".join(fields) + "]", palette["tokens"])


def pretty_print_chat_messages_prev(messages, num_tokens=None, max_tokens=None, logger=None, response_msg=False, step_idx=None, total_steps=None, max_re_tries=None, re_tries=None):
    """
    Legacy variant of the coloured chat-message printer.

    Prints each message as ``[Role] content`` (or ``[Function Call] [name] arguments``)
    and mirrors the uncoloured text to ``logger.info`` when a logger is given. Unless
    ``response_msg`` is set, a token-usage line follows when both ``num_tokens`` and
    ``max_tokens`` are truthy. It is prefixed with step/retry progress when those
    arguments are supplied.

    A message without ``role`` is treated as ``"system"``. A ``function_call`` missing
    ``name`` or ``arguments`` prints a placeholder. Both cases previously raised KeyError.
    """
    COLORS = {
        "system": "\033[95m",      # Light Magenta
        "user": "\033[94m",        # Light Blue
        "assistant": "\033[92m",   # Light Green
        "tokens": "\033[91m"       # Light Red
    }

    if response_msg:
        print("[LLM RESPONSE MESSAGE]")  # header is printed without colour codes
        if logger:
            logger.info("[LLM RESPONSE MESSAGE]")

    for msg in messages:
        role = msg.get('role', 'system')
        color = COLORS.get(role, COLORS["system"])  # unknown roles use the system colour
        formatted_role = role.capitalize()
        content = msg.get('content')
        function_call = msg.get('function_call')

        if function_call:
            formatted_role = "Function Call"
            fn_name = function_call.get('name', '<no name>')
            fn_args = function_call.get('arguments', '<no arguments>')
            line = f"[{formatted_role}] [{fn_name}] {fn_args}"
        else:
            line = f"[{formatted_role}] {content}"
        print(f"{color}{line}\033[0m")  # reset colour at the end of each line
        if logger:
            logger.info(line)

    if not response_msg:
        if step_idx is not None and total_steps is not None:
            if num_tokens and max_tokens:
                token_capacity_used_percent = (num_tokens / max_tokens) * 100.0
                tokens_remaining = max_tokens - num_tokens
                retries_part = ""
                if max_re_tries is not None and re_tries is not None:
                    retries_part = f"Retries: {re_tries}/{max_re_tries} | "
                progress_message = (
                    f"[Progress: Step {step_idx + 1}/{total_steps} | {retries_part}"
                    f"Token Capacity Used: {token_capacity_used_percent:.2f}% | "
                    f"Tokens remaining {tokens_remaining}]"
                )
                print(f"{COLORS['tokens']}{progress_message}\033[0m")
                if logger:
                    logger.info(progress_message)
        elif num_tokens and max_tokens:
            token_capacity_used_percent = (num_tokens / max_tokens) * 100.0
            tokens_remaining = max_tokens - num_tokens
            message = f"[Token Capacity Used: {token_capacity_used_percent:.2f}% | Tokens remaining {tokens_remaining}]"
            print(f"{COLORS['tokens']}{message}\033[0m")
            if logger:
                logger.info(message)


def chat_completion_rl(**kwargs):
    """
    Call the OpenAI chat-completions API, retrying transient errors with exponential backoff.

    Private keyword arguments consumed here:
        _retry_with_exponential_backoff__initial_delay (default 1): starting delay in seconds.
        _retry_with_exponential_backoff__exponential_base (default 2): multiplier per retry.
        _retry_with_exponential_backoff__jitter (default True): when truthy, the multiplier
            is scaled by an extra random factor in [1, 2).
        _retry_with_exponential_backoff__max_retries (default 10): retries allowed before giving up.
        _use_azure_api: discarded.
    ``_logger`` and ``_name`` are read for log messages and left in ``kwargs``; the inner
    request function strips them. ``model`` is always forced to DEFAULT_MODEL.

    Returns:
        With ``stream`` falsy, the result of ``chat_completion_rl_inner``. Otherwise the
        dict built by ``async_chat_completion_rl_inner``, run to completion via ``asyncio.run``.

    Raises:
        Exception: after more than ``max_retries`` RateLimitError / APIError / Timeout /
            ServiceUnavailableError failures. The last API error is chained as ``__cause__``.
        Any other exception from the inner call propagates unchanged.
    """
    initial_delay = kwargs.pop('_retry_with_exponential_backoff__initial_delay', 1)
    exponential_base = kwargs.pop('_retry_with_exponential_backoff__exponential_base', 2)
    jitter = kwargs.pop('_retry_with_exponential_backoff__jitter', True)
    max_retries = kwargs.pop('_retry_with_exponential_backoff__max_retries', 10)
    kwargs.pop('_use_azure_api', None)  # Azure support was removed; ignore the flag

    stream = kwargs.get('stream', False)
    kwargs['model'] = DEFAULT_MODEL

    logger = kwargs.get('_logger', None)
    name = kwargs.get('_name', None)
    log_or_print = logger.info if logger else print

    # Transient failures worth retrying; anything else propagates immediately.
    retryable_errors = (
        openai.error.RateLimitError,
        openai.error.APIError,
        openai.error.Timeout,
        openai.error.ServiceUnavailableError,
    )

    num_retries = 0
    delay = initial_delay

    while True:
        try:
            if stream:
                # NOTE(review): asyncio.run raises RuntimeError when called from a thread that
                # already runs an event loop, so streaming only works from synchronous callers.
                return asyncio.run(async_chat_completion_rl_inner(**kwargs))
            return chat_completion_rl_inner(**kwargs)
        except retryable_errors as e:
            num_retries += 1
            log_or_print(f"[{name}][OpenAI API Request Error] {type(e)} {e.args} | num_retries: {num_retries} / {max_retries}")

            if num_retries > max_retries:
                log_or_print(f"[{name}][OpenAI API Request Error] Exception Maximum number of retries ({max_retries}) exceeded.")
                # Chain the last API error so callers and tracebacks can see what actually failed.
                raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from e

            delay *= exponential_base * (1 + jitter * random.random())

            # Prefer the server's own "Please retry after N" hint when the error carries one.
            match = re.search(r'Please retry after (\d+(?:\.\d+)?)', str(e))
            if match:
                delay = float(match.group(1)) + 0.5  # small safety buffer

            delay = min(delay, 60)  # never sleep longer than a minute per attempt

            log_or_print(f"[{name}][OpenAI API Request Error] {type(e)} {e.args} | num_retries: {num_retries} / {max_retries} | Now sleeping for {delay:.2f} seconds")
            time.sleep(delay)

async def async_chat_completion_rl_inner(**kwargs):
    """
    Send one streaming chat-completion request and reassemble the streamed deltas.

    Private ``_``-prefixed keys, ``api_*`` keys and ``engine`` are removed from ``kwargs``,
    and ``model`` is forced to DEFAULT_MODEL. ``stream`` is left in place so the API
    streams. The ``delta`` of each chunk's first choice is merged into one message dict:
    dict-valued fields (e.g. ``function_call``) are merged key by key, and other fields
    are concatenated with ``+``. ``None`` values are treated as "nothing to add".

    Returns:
        ``{"choices": [{"message": <merged message>}]}``.
    """
    # Private bookkeeping keys are not OpenAI parameters.
    kwargs.pop('_logger', None)
    kwargs.pop('_name', None)
    kwargs.pop('_open_ai_rate_limit_requests_per_minute', None)
    rate_limiter = kwargs.pop('_rate_limiter', None)

    kwargs['model'] = DEFAULT_MODEL

    # Drop Azure-specific keys; 'engine' is removed too, matching the synchronous path.
    for key in [k for k in kwargs if k.startswith('api_') or k == 'engine']:
        del kwargs[key]

    if rate_limiter:
        rate_limiter.consume(**kwargs)
    responses = await openai.ChatCompletion.acreate(**kwargs)

    def _append(existing, addition):
        # Deltas may carry explicit nulls (e.g. "content": null next to a function_call).
        # Treat None as empty instead of failing on None + str.
        if existing is None:
            return addition
        if addition is None:
            return existing
        return existing + addition

    message = {}
    async for chunk in responses:
        if 'choices' not in chunk or len(chunk['choices']) == 0:
            continue
        # to_dict_recursive() is presumably the legacy OpenAIObject -> plain dict conversion.
        delta = chunk['choices'][0]['delta'].to_dict_recursive()
        for key, value in delta.items():
            current = message.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    current[sub_key] = _append(current.get(sub_key), sub_value)
            else:
                message[key] = _append(current, value)

    return {"choices": [{"message": message}]}

def chat_completion_rl_inner(**kwargs):
    """
    Send a single, non-streaming chat-completion request (no retries).

    Removes the private ``_``-prefixed bookkeeping keys, any ``api_*`` keys, ``engine``
    and ``stream`` from ``kwargs``, and forces ``model`` to DEFAULT_MODEL. If a
    ``_rate_limiter`` was supplied, its ``consume`` method is called with the remaining
    request arguments before the request is sent.

    Returns:
        The raw ``openai.ChatCompletion.create`` response. API errors are not caught
        here; ``chat_completion_rl`` handles retries.
    """
    kwargs.pop('_logger', None)
    kwargs.pop('_name', None)
    kwargs.pop('_open_ai_rate_limit_requests_per_minute', None)
    rate_limiter = kwargs.pop('_rate_limiter', None)

    kwargs['model'] = DEFAULT_MODEL

    # Keys starting with 'api_' and 'engine' are Azure-specific; the native API rejects them.
    for key in [k for k in kwargs if k.startswith('api_') or k == 'engine']:
        del kwargs[key]
    kwargs.pop('stream', None)  # this path always waits for the complete response

    if rate_limiter:
        rate_limiter.consume(**kwargs)
    return openai.ChatCompletion.create(**kwargs)

# -------------------------------------------------------------------------------------------------
# Embedding

def replace_newlines(input_data):
    """
    Replace every ``"\\n"`` with a single space.

    Args:
        input_data: A string, or a list of strings.

    Returns:
        The transformed string, or a new list of transformed strings.

    Raises:
        ValueError: if ``input_data`` is neither a string nor a list made up only of strings.
    """
    error_message = "Input should be either a string or a list of strings"
    if isinstance(input_data, str):
        return input_data.replace("\n", " ")
    if isinstance(input_data, list):
        # Validate items up front so a bad element yields the documented ValueError
        # instead of an AttributeError from item.replace.
        if not all(isinstance(item, str) for item in input_data):
            raise ValueError(error_message)
        return [item.replace("\n", " ") for item in input_data]
    raise ValueError(error_message)

def embedding_rl(input, **kwargs):
    """
    Request embeddings for ``input``, retrying transient OpenAI errors with exponential backoff.

    Private keyword arguments consumed here:
        _retry_with_exponential_backoff__initial_delay (default 1): starting delay in seconds.
        _retry_with_exponential_backoff__exponential_base (default 2): multiplier per retry.
        _retry_with_exponential_backoff__jitter (default True): when truthy, the multiplier
            is scaled by an extra random factor in [1, 2).
        _retry_with_exponential_backoff__max_retries (default 10): retries allowed before giving up.
        _use_azure_api: discarded.
    ``_logger`` and ``_name`` are read for log messages and left in ``kwargs``; the inner
    request function strips them. ``model`` is always forced to EMBEDDING_MODEL.

    Returns:
        Whatever ``embedding_rl_inner`` returns for ``input``.

    Raises:
        Exception: after more than ``max_retries`` RateLimitError / APIError / Timeout /
            ServiceUnavailableError failures. The last API error is chained as ``__cause__``.
        Any other exception from the inner call propagates unchanged.
    """
    initial_delay = kwargs.pop('_retry_with_exponential_backoff__initial_delay', 1)
    exponential_base = kwargs.pop('_retry_with_exponential_backoff__exponential_base', 2)
    jitter = kwargs.pop('_retry_with_exponential_backoff__jitter', True)
    max_retries = kwargs.pop('_retry_with_exponential_backoff__max_retries', 10)
    kwargs.pop('_use_azure_api', None)  # Azure support was removed; ignore the flag

    kwargs['model'] = EMBEDDING_MODEL

    logger = kwargs.get('_logger', None)
    name = kwargs.get('_name', None)
    log_or_print = logger.info if logger else print

    # Transient failures worth retrying; anything else propagates immediately.
    retryable_errors = (
        openai.error.RateLimitError,
        openai.error.APIError,
        openai.error.Timeout,
        openai.error.ServiceUnavailableError,
    )

    num_retries = 0
    delay = initial_delay

    while True:
        try:
            return embedding_rl_inner(input, **kwargs)
        except retryable_errors as e:
            num_retries += 1
            log_or_print(f"[{name}][OpenAI API Request Error] {type(e)} {e.args} | num_retries: {num_retries} / {max_retries}")

            if num_retries > max_retries:
                log_or_print(f"[{name}][OpenAI API Request Error] Exception Maximum number of retries ({max_retries}) exceeded.")
                # Chain the last API error so callers and tracebacks can see what actually failed.
                raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from e

            delay *= exponential_base * (1 + jitter * random.random())

            # Prefer the server's own "Please retry after N" hint when the error carries one.
            match = re.search(r'Please retry after (\d+(?:\.\d+)?)', str(e))
            if match:
                delay = float(match.group(1)) + 0.5  # small safety buffer

            delay = min(delay, 60)  # never sleep longer than a minute per attempt

            log_or_print(f"[{name}][OpenAI API Request Error] {type(e)} {e.args} | num_retries: {num_retries} / {max_retries} | Now sleeping for {delay:.2f} seconds")
            time.sleep(delay)

def embedding_rl_inner(input, **kwargs):
    """
    Send one embeddings request (no retries) and return the vectors in input order.

    Private ``_``-prefixed keys, ``api_*`` keys and ``engine`` are removed from ``kwargs``,
    ``model`` is forced to EMBEDDING_MODEL, and newlines in ``input`` are replaced with
    spaces via ``replace_newlines``.

    Returns:
        A list holding the ``embedding`` field of each item in ``response["data"]``.

    Raises:
        AssertionError: if the response items' ``index`` fields are not 0, 1, 2, ... in order.
        Other errors from the API call propagate; ``embedding_rl`` handles retries.
    """
    kwargs.pop('_logger', None)
    kwargs.pop('_name', None)
    kwargs.pop('_open_ai_rate_limit_requests_per_minute', None)
    # The rate limiter is removed but not consulted: token accounting for embeddings
    # is not implemented yet.
    kwargs.pop('_rate_limiter', None)

    kwargs['model'] = EMBEDDING_MODEL

    # Keys starting with 'api_' and 'engine' are Azure-specific; the native API rejects them.
    for key in [k for k in kwargs if k.startswith('api_') or k == 'engine']:
        del kwargs[key]

    kwargs['input'] = replace_newlines(input)

    response = openai.Embedding.create(**kwargs)

    data = response["data"]
    for position, item in enumerate(data):
        # Explicit raise instead of `assert`, so the ordering check survives `python -O`.
        # AssertionError is kept so existing callers see the same exception type.
        if item["index"] != position:
            raise AssertionError(
                f"Embedding response out of order: item {position} has index {item['index']}"
            )
    return [item["embedding"] for item in data]

def num_tokens_consumed_by_chat_request(messages, max_tokens=15, n=1, functions='', **kwargs):
    """
    Estimate the total token cost of a chat request.

    The estimate is the prompt size (messages, plus function schemas when ``functions``
    is non-empty) plus ``n * max_tokens`` reserved for the completion(s). Extra keyword
    arguments are ignored except ``model``, which selects the tokenizer rules
    (default DEFAULT_MODEL).
    """
    model_name = kwargs.get('model', DEFAULT_MODEL)
    prompt_tokens = num_tokens_from_messages(messages, model=model_name)
    if functions:
        prompt_tokens += num_tokens_from_functions(functions, model=model_name)

    # Reserve the full completion budget up front; the actual completion may be shorter.
    completion_reserve = n * max_tokens
    return prompt_tokens + completion_reserve

def num_tokens_from_messages(messages, model=DEFAULT_MODEL):
    """
    Return the approximate number of prompt tokens used by a list of chat messages.

    Every value of every message (role, content, name, ...) is tokenized. Non-string
    values are tokenized via ``str(value)``. A fixed per-message overhead is added,
    plus a per-``name`` adjustment and 3 tokens for the reply primer.

    Args:
        messages: List of message dicts.
        model: Model name used to pick the tiktoken encoding and the overhead constants.
            Models unknown to tiktoken fall back to cl100k_base with a printed warning.

    Returns:
        Integer token estimate.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = CL100K_ENCODER

    # Defaults follow the gpt-3.5-turbo-0613 / gpt-4-0613 accounting and are used for
    # every model that is not special-cased below (including gpt-4 and gpt-4.1).
    tokens_per_message = 3
    tokens_per_name = 1
    if model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            text = value if isinstance(value, str) else str(value)
            # disallowed_special=() counts strings such as "<|endoftext|>" inside user
            # content as ordinary text, instead of tiktoken raising ValueError.
            num_tokens += len(encoding.encode(text, disallowed_special=()))
            if key == "name":
                num_tokens += tokens_per_name

    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def num_tokens_from_functions(functions, model=DEFAULT_MODEL):
    """
    Return an approximate token count for a list of function (tool) schemas.

    Counts the tokens of each function's name and description and, when present,
    of each top-level parameter's name, type, description and enum values, plus one
    level of nested ``items.properties``. Fixed adjustments (+2 per type/description,
    +3 per enum value, -3 per enum, +11 per function with properties, +12 overall)
    approximate the JSON framing. The result is an estimate, not an exact count.

    Args:
        functions: Iterable of function-schema dicts; each must have a ``name`` key.
        model: Accepted for interface symmetry; not used (cl100k_base is always used).

    Returns:
        Integer token estimate.
    """
    encoding = CL100K_ENCODER

    def _count(value):
        # JSON-schema values are not always strings (integer enums, union types such as
        # ["string", "null"], null descriptions). Count their text form instead of letting
        # tiktoken raise TypeError. disallowed_special=() treats special-token strings as text.
        text = value if isinstance(value, str) else str(value)
        return len(encoding.encode(text, disallowed_special=()))

    num_tokens = 0
    for function in functions:
        function_tokens = _count(function['name'])
        function_tokens += _count(function.get('description', ''))

        if 'parameters' in function:
            parameters = function['parameters']
            if 'properties' in parameters:
                for prop_name, prop in parameters['properties'].items():
                    function_tokens += _count(prop_name)

                    if 'type' in prop:
                        function_tokens += 2 + _count(prop['type'])
                    if 'description' in prop:
                        function_tokens += 2 + _count(prop['description'])
                    if 'enum' in prop:
                        function_tokens -= 3  # adjust for list start/end characters
                        for option in prop['enum']:
                            function_tokens += 3 + _count(option)

                    # One level of nested array item schemas.
                    items = prop.get('items')
                    if isinstance(items, dict):
                        function_tokens += _count(items.get('type', ''))
                        for nested_name, nested in items.get('properties', {}).items():
                            function_tokens += _count(nested_name)
                            if 'type' in nested:
                                function_tokens += _count(nested['type'])
                            if 'description' in nested:
                                function_tokens += _count(nested['description'])
                            if 'enum' in nested:
                                for nested_option in nested['enum']:
                                    function_tokens += _count(nested_option)

                function_tokens += 11  # approximation for the surrounding JSON structure

        num_tokens += function_tokens

    num_tokens += 12  # approximation for the surrounding call message
    return num_tokens

if __name__ == "__main__":
    # Manual smoke tests. The live chat call only succeeds when OPENAI_API_KEY is set in
    # the environment; otherwise its failure is caught and reported below.
    test_model = DEFAULT_MODEL

    # One sample conversation, used for the live chat call and the offline token count.
    sample_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
        {"role": "user", "content": "Where was it played?"},
    ]
    # Explicit backoff settings stand in for a real config object.
    sample_retry_settings = {
        "_retry_with_exponential_backoff__initial_delay": 1,
        "_retry_with_exponential_backoff__exponential_base": 2,
        "_retry_with_exponential_backoff__jitter": True,
        "_retry_with_exponential_backoff__max_retries": 3,
        "_use_azure_api": False,
    }

    print(f"--- Testing chat_completion_rl with model: {test_model} ---")
    try:
        response = chat_completion_rl(
            model=test_model,
            messages=sample_messages,
            max_tokens=5,
            temperature=0.9,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["\n", "Human:", "AI:"],
            **sample_retry_settings,
        )
        print("Test Chat Completion Response:")
        print(response)
    except Exception as e:
        print(f"Chat Completion Test Failed (This is expected if OPENAI_API_KEY is not set): {e}")

    print("\n--- Testing replace_newlines utility ---")
    newline_cases = [
        ("Hello\nWorld", "Hello World"),
        (["Hello\nWorld", "Python\nRocks"], ["Hello World", "Python Rocks"]),
        ("No newline here", "No newline here"),
        (["No newline here", "Neither here"], ["No newline here", "Neither here"]),
        ("\n\n", "  "),
        (["\n", "\n\n"], [" ", "  "]),
    ]
    for raw, expected in newline_cases:
        assert replace_newlines(raw) == expected
    print("All replace_newlines tests passed! ✅")

    print("\n--- Testing num_tokens_from_messages utility ---")
    tokens = num_tokens_from_messages(sample_messages, model=test_model)
    print(f"Token count for test messages ({test_model}): {tokens}")