import asyncio
import json
import math
import random
import time
from typing import Any, Dict, List, Optional, Union

import tiktoken
from json_repair import repair_json
from openai import (
    APIConnectionError,
    APIError,
    APITimeoutError,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import (
    after_log,
    before_sleep_log,
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from src.utils.logsetup import logger
from src.config import config
from src.schema.message import ROLE_VALUES, Message
from src.utils.exceptions import TokenLimitExceeded
from src.utils.message_validator import message_validator

# Model names that `ask`/`ask_tool` treat as reasoning models: for these the
# request carries `max_completion_tokens` instead of `max_tokens`, and no
# `temperature` parameter is sent.
REASONING_MODELS = ["o1", "o3-mini", "o3", "o4-mini"]


class TokenCounter:
    """Estimate how many tokens a list of chat messages will consume.

    Text is measured with the tokenizer handed to the constructor (any object
    exposing ``encode(text)`` that returns a sized sequence). Images are
    estimated from the tile-based constants below.
    """

    # Token constants
    BASE_MESSAGE_TOKENS = 4
    FORMAT_TOKENS = 2
    LOW_DETAIL_IMAGE_TOKENS = 85
    HIGH_DETAIL_TILE_TOKENS = 170

    # Image processing constants
    MAX_SIZE = 2048
    HIGH_DETAIL_TARGET_SHORT_SIDE = 768
    TILE_SIZE = 512
    # Size assumed for an image whose dimensions are not supplied.
    DEFAULT_IMAGE_SIZE = (1024, 1024)

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def count_text(self, text: str) -> int:
        """Return the token count of ``text``; empty or ``None`` counts as 0."""
        return 0 if not text else len(self.tokenizer.encode(text))

    def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
        """Return the token estimate for a high-detail image of the given size.

        Non-positive dimensions cannot be tiled, so only the base image cost
        (``LOW_DETAIL_IMAGE_TOKENS``) is returned for them.
        """
        if width <= 0 or height <= 0:
            return self.LOW_DETAIL_IMAGE_TOKENS

        # Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
        if width > self.MAX_SIZE or height > self.MAX_SIZE:
            scale = self.MAX_SIZE / max(width, height)
            # Clamp to 1px so extreme aspect ratios cannot truncate to zero
            # and cause a division by zero in step 2.
            width = max(1, int(width * scale))
            height = max(1, int(height * scale))

        # Step 2: Scale so the shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
        scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
        scaled_width = int(width * scale)
        scaled_height = int(height * scale)

        # Step 3: Count number of 512px tiles
        tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
        tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
        total_tiles = tiles_x * tiles_y

        # Step 4: Calculate final token count
        return total_tiles * self.HIGH_DETAIL_TILE_TOKENS + self.LOW_DETAIL_IMAGE_TOKENS

    def count_image(self, image_item: dict) -> int:
        """Return the token estimate for one image content item.

        ``detail`` and ``dimensions`` are looked up on the item itself first
        and then inside its nested ``image_url`` dict. ``detail == "low"``
        costs a flat ``LOW_DETAIL_IMAGE_TOKENS``; anything else uses the tile
        calculation, falling back to ``DEFAULT_IMAGE_SIZE`` when no usable
        ``(width, height)`` pair is present.
        """
        image_url = image_item.get("image_url")
        nested = image_url if isinstance(image_url, dict) else {}

        detail = image_item.get("detail") or nested.get("detail") or "auto"
        if detail == "low":
            return self.LOW_DETAIL_IMAGE_TOKENS

        dimensions = image_item.get("dimensions") or nested.get("dimensions")
        try:
            width, height = dimensions
        except (TypeError, ValueError):
            # Missing or malformed dimensions: use the assumed default size.
            width, height = self.DEFAULT_IMAGE_SIZE
        return self._calculate_high_detail_tokens(width, height)

    def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
        """Return the token count of a message ``content`` value.

        ``content`` may be a plain string or a list mixing strings, text items
        (dicts with ``"text"``) and image items (dicts with ``"image_url"``).
        Items of any other shape contribute nothing.
        """
        if not content:
            return 0

        if isinstance(content, str):
            return self.count_text(content)

        token_count = 0
        for item in content:
            if isinstance(item, str):
                token_count += self.count_text(item)
            elif isinstance(item, dict):
                if "text" in item:
                    token_count += self.count_text(item["text"])
                elif "image_url" in item:
                    token_count += self.count_image(item)

        return token_count

    def count_tool_calls(self, tool_calls: List[dict]) -> int:
        """Return the tokens used by tool-call function names and arguments.

        ``None`` is treated like an empty list.
        """
        token_count = 0
        for tool_call in tool_calls or []:
            if "function" in tool_call:
                function = tool_call["function"]
                token_count += self.count_text(function.get("name", ""))
                token_count += self.count_text(function.get("arguments", ""))
        return token_count

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Return the total token estimate for a list of message dicts.

        Each message costs ``BASE_MESSAGE_TOKENS`` plus its role, content,
        tool calls, ``name`` and ``tool_call_id``; ``FORMAT_TOKENS`` is added
        once for the whole list.
        """
        total_tokens = self.FORMAT_TOKENS  # Base format tokens

        for message in messages:
            tokens = self.BASE_MESSAGE_TOKENS  # Base tokens per message

            # Add role tokens
            tokens += self.count_text(message.get("role", ""))

            # Add content tokens
            if "content" in message:
                tokens += self.count_content(message["content"])

            # Add tool calls tokens
            if "tool_calls" in message:
                tokens += self.count_tool_calls(message["tool_calls"])

            # Add name and tool_call_id tokens
            tokens += self.count_text(message.get("name", ""))
            tokens += self.count_text(message.get("tool_call_id", ""))

            total_tokens += tokens

        return total_tokens


class LLM:
    # Cache of created instances keyed by config name; `__new__` returns the
    # cached instance when one already exists for the requested name.
    _instances: Dict[str, "LLM"] = {}

    def __new__(
            cls, config_name: str = "default"
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(self, config_name: str = "default"):
        """Load the ``config_name`` LLM settings and build the API client.

        Settings come from ``config.llm``; an unknown ``config_name`` falls
        back to the ``"default"`` entry. The method is a no-op when the
        instance already has a ``client`` attribute.
        """
        # Already initialised: leave the existing state untouched.
        if hasattr(self, "client"):
            return

        all_settings = config.llm
        settings = all_settings.get(config_name, all_settings["default"])

        # Core model/connection settings
        self.model = settings.model
        self.max_tokens = settings.max_tokens
        self.temperature = settings.temperature
        self.api_key = settings.api_key
        self.api_version = settings.api_version
        self.base_url = settings.base_url
        self.thinking_enabled = settings.thinking_enabled
        self.venus_platform_token = config.get_config(category="llm", key="venus_platform_token")

        # Timeout/retry tuning, each with a built-in default
        self.default_timeout = getattr(settings, "default_timeout", 300)  # seconds (5 min)
        self.max_timeout = getattr(settings, "max_timeout", 900)  # seconds (15 min)
        self.max_retries = getattr(settings, "max_retries", 6)
        self.base_retry_delay = getattr(settings, "base_retry_delay", 1.0)
        self.max_retry_delay = getattr(settings, "max_retry_delay", 60.0)

        # Cumulative token accounting; the input cap is optional
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.max_input_tokens = getattr(settings, "max_input_tokens", None)

        logger.info(f"Initializing LLM, config_name: {config_name}, model: {self.model}")

        # Pick the tokenizer for the model; unknown models use cl100k_base
        try:
            self.tokenizer = tiktoken.encoding_for_model(self.model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

        # Async client created with the default request timeout
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=self.default_timeout,
        )
        self.token_counter = TokenCounter(self.tokenizer)

        # Request-level counters
        self.request_stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "timeout_errors": 0,
            "rate_limit_errors": 0,
            "api_errors": 0,
            "total_retry_attempts": 0,
            "avg_response_time": 0.0,
            "total_response_time": 0.0,
        }

        # Per-method timing accumulators (seconds and call counts)
        self.timing_stats = {
            "ask": {
                "total_time": 0.0,
                "call_count": 0,
                "msg_processing_time": 0.0,
                "api_call_time": 0.0,
                "prefix_filling_time": 0.0,  # streaming requests
                "output_phase_time": 0.0,  # streaming requests
            },
            "ask_tool": {
                "total_time": 0.0,
                "call_count": 0,
                "msg_processing_time": 0.0,
                "api_call_time": 0.0,
                "tool_processing_time": 0.0,
            },
            "ask_json": {
                "total_time": 0.0,
                "call_count": 0,
                "json_parsing_time": 0.0,
            },
            "ask_with_files": {
                "total_time": 0.0,
                "call_count": 0,
                "msg_processing_time": 0.0,
                "file_processing_time": 0.0,
                "api_call_time": 0.0,
            },
        }

        # Per-tool timing, shaped {tool_name: {"total_time": 0.0, "call_count": 0}}
        self.tool_timing_stats = {}

    def count_tokens(self, text: str) -> int:
        """Calculate the number of tokens in a text"""
        if not text:
            return 0
        return len(self.tokenizer.encode(text))

    def count_message_tokens(self, messages: List[dict]) -> int:
        return self.token_counter.count_message_tokens(messages)

    def update_token_count(self, input_tokens: int, output_tokens: int) -> None:
        """Add one request's usage to the cumulative totals and log both.

        The totals are updated unconditionally, whether or not
        ``max_input_tokens`` is configured.
        """
        self.total_input_tokens = self.total_input_tokens + input_tokens
        self.total_output_tokens = self.total_output_tokens + output_tokens

        summary = (
            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
            f" Output={output_tokens}, Cumulative Output={self.total_output_tokens}"
        )
        logger.info(summary)

    def get_total_input_tokens(self) -> int:
        """Get total input tokens"""
        return self.total_input_tokens
    
    def get_total_output_tokens(self) -> int:
        """Get total output tokens"""
        return self.total_output_tokens
    
    def get_total_tokens(self) -> int:
        """Get total tokens (input + output)"""
        return self.total_input_tokens + self.total_output_tokens
    
    def get_token_usage_summary(self) -> Dict[str, int]:
        """Get token usage summary"""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_input_tokens + self.total_output_tokens,
            "max_input_tokens": self.max_input_tokens,
            "remaining_input_tokens": (
                self.max_input_tokens - self.total_input_tokens 
                if self.max_input_tokens is not None else None
            )
        }
    
    def reset_token_counts(self) -> None:
        """Log the current totals, then zero both cumulative token counters."""
        previous_input = self.total_input_tokens
        previous_output = self.total_output_tokens
        previous_total = self.get_total_tokens()
        logger.info(
            f"Resetting token counts. Previous totals - Input: {previous_input}, "
            f"Output: {previous_output}, Total: {previous_total}"
        )
        self.total_input_tokens = self.total_output_tokens = 0
    
    def print_token_usage(self) -> None:
        """Print token usage statistics"""
        usage = self.get_token_usage_summary()
        print("\n📊 Token Usage Statistics:")
        print(f"   Input Tokens:  {usage['total_input_tokens']:,}")
        print(f"   Output Tokens: {usage['total_output_tokens']:,}")
        print(f"   Total Tokens:  {usage['total_tokens']:,}")
        
        if usage['max_input_tokens'] is not None:
            print(f"   Input Limit:   {usage['max_input_tokens']:,}")
            if usage['remaining_input_tokens'] is not None:
                print(f"   Remaining:     {usage['remaining_input_tokens']:,}")
                usage_percent = (usage['total_input_tokens'] / usage['max_input_tokens']) * 100
                print(f"   Usage Rate:    {usage_percent:.1f}%")
    
    def get_task_token_summary(self) -> str:
        """Get current task token usage summary (simple version)"""
        total_input = self.total_input_tokens
        total_output = self.total_output_tokens
        total_tokens = total_input + total_output
        
        return f"📊 Task Token Summary: Input={total_input:,} | Output={total_output:,} | Total={total_tokens:,}"
    
    def _calculate_timeout(self, attempt: int, base_timeout: Optional[float] = None) -> float:
        """Calculate timeout with exponential increase for retries"""
        if base_timeout is None:
            base_timeout = self.default_timeout
            
        # Exponential timeout increase: base * (1.5 ^ attempt)
        timeout = base_timeout * (1.5 ** attempt)
        return min(timeout, self.max_timeout)

    def _calculate_retry_delay(self, attempt: int) -> float:
        """Calculate retry delay with exponential backoff and jitter"""
        delay = self.base_retry_delay * (2 ** attempt)
        delay = min(delay, self.max_retry_delay)
        # Add jitter to prevent thundering herd
        jitter = random.uniform(0, delay * 0.1)
        return delay + jitter
    
    def _validate_tool_conversation(self, messages: List[dict]) -> List[dict]:
        """Validate and clean tool conversation history to ensure tool_use and tool_result match"""
        cleaned_messages = []
        tool_calls_in_progress = {}
        
        for msg in messages:
            msg_copy = msg.copy()
            
            # Handle tool_calls in assistant messages
            if msg.get('role') == 'assistant' and 'tool_calls' in msg:
                # Record all tool call IDs
                for tool_call in msg['tool_calls']:
                    if 'id' in tool_call:
                        tool_calls_in_progress[tool_call['id']] = True
                cleaned_messages.append(msg_copy)
            
            # Handle tool messages, validate if tool_call_id is valid
            elif msg.get('role') == 'tool':
                tool_call_id = msg.get('tool_call_id')
                if tool_call_id and tool_call_id in tool_calls_in_progress:
                    # Valid tool result, add to messages and remove record
                    cleaned_messages.append(msg_copy)
                    del tool_calls_in_progress[tool_call_id]
                else:
                    # Invalid tool result, log warning but don't add to messages
                    logger.warning(f"Found orphaned tool result, tool_call_id: {tool_call_id}, removed")
            
            # Handle other messages, especially check for tool_use_id issues in content
            else:
                # Clean possible tool_use_id references in message content
                if 'content' in msg_copy:
                    content = msg_copy['content']
                    if isinstance(content, list):
                        # Check each element in content array
                        cleaned_content = []
                        for item in content:
                            if isinstance(item, dict):
                                # Remove tool_use_id field that might cause ValidationException
                                item_copy = item.copy()
                                if 'tool_use_id' in item_copy:
                                    logger.warning(f"Removed tool_use_id from content: {item_copy.get('tool_use_id')}")
                                    del item_copy['tool_use_id']
                                cleaned_content.append(item_copy)
                            else:
                                cleaned_content.append(item)
                        msg_copy['content'] = cleaned_content
                
                cleaned_messages.append(msg_copy)
        
        # Check for incomplete tool calls
        if tool_calls_in_progress:
            logger.warning(f"Found incomplete tool calls: {list(tool_calls_in_progress.keys())}")
        
        return cleaned_messages

    def check_token_limit(self, input_tokens: int) -> bool:
        """Check if token limits are exceeded"""
        if self.max_input_tokens is not None:
            return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
        # If max_input_tokens is not set, always return True
        return True

    def get_limit_error_message(self, input_tokens: int) -> str:
        """Generate error message for token limit exceeded"""
        if (
                self.max_input_tokens is not None
                and (self.total_input_tokens + input_tokens) > self.max_input_tokens
        ):
            return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"

        return "Token limit exceeded"

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Messages flagged with a truthy ``cache`` key have their content turned
        into a list of content items, the last of which is tagged with
        ``cache_control = {"type": "ephemeral"}``; the ``cache`` key itself is
        dropped. This is done on copies, so dicts passed in by the caller are
        never modified. A flagged message whose content is missing, ``None``
        or empty simply loses the ``cache`` key.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format. Messages
            with neither ``content`` nor ``tool_calls`` are omitted.

        Raises:
            ValueError: If a message lacks ``role`` or has a role not in ROLE_VALUES
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            # Convert Message objects to dictionaries
            if isinstance(message, Message):
                message = message.to_dict()

            if not isinstance(message, dict):
                raise TypeError(f"Unsupported message type: {type(message)}")

            # Validate required fields
            if "role" not in message:
                raise ValueError("Message dict must contain 'role' field")

            if message.get('cache', False):
                # Shallow copy: the caller's dict keeps its `cache` flag and
                # original content, so formatting the same history again
                # yields the same result.
                message = dict(message)
                message.pop('cache')

                content = message.get("content")
                if isinstance(content, str):
                    content = [{"type": "text", "text": content}]
                    message["content"] = content
                elif isinstance(content, list):
                    # Convert string items to proper text objects
                    content = [
                        {"type": "text", "text": item} if isinstance(item, str) else item
                        for item in content
                    ]
                    message["content"] = content

                # Tag a copy of the last item so the caller's item dict is
                # left untouched; skip when there is nothing taggable.
                if isinstance(content, list) and content and isinstance(content[-1], dict):
                    content[-1] = {**content[-1], "cache_control": {"type": "ephemeral"}}

            # Only include messages with content or tool_calls
            if "content" in message or "tool_calls" in message:
                formatted_messages.append(message)

        # Validate all roles
        invalid_roles = [
            msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
        ]
        if invalid_roles:
            raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")

        return formatted_messages

    async def create_completions(self, params):
        """Send ``params`` to the chat-completions endpoint and return the response.

        ``params`` is modified in place before the call: when
        ``self.thinking_enabled`` is not None it is written into
        ``params["extra_body"]["thinking_enabled"]``, and an empty
        ``extra_headers`` dict is added if the key is absent.
        """
        thinking = self.thinking_enabled
        if thinking is not None:
            body = params.get("extra_body", {})
            body["thinking_enabled"] = thinking
            params["extra_body"] = body

        # Keep any headers already supplied; otherwise start with none.
        params.setdefault("extra_headers", {})

        return await self.client.chat.completions.create(**params)

    @retry(
        wait=wait_random_exponential(min=6, max=60),
        stop=stop_after_attempt(6),
        # Retry ordinary failures only. TokenLimitExceeded cannot succeed on a
        # repeat of the same request, and non-Exception BaseExceptions
        # (cancellation, KeyboardInterrupt) must propagate immediately.
        retry=retry_if_exception(
            lambda exc: isinstance(exc, Exception)
            and not isinstance(exc, TokenLimitExceeded)
        ),
    )
    async def ask(
            self,
            messages: List[Union[dict, Message]],
            system_msgs: Optional[List[Union[dict, Message]]] = None,
            temperature: Optional[float] = None,
    ):
        """
        Ask LLM a question and get the response.

        The messages are formatted, checked against the input token limit,
        cleaned by ``_validate_tool_conversation`` and ``message_validator``,
        and then sent as a non-streaming request. Token usage reported by the
        API is added to the cumulative counters, and timing statistics are
        updated whether the call succeeds or fails.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            temperature (float): Sampling temperature for the response;
                ignored for models listed in REASONING_MODELS

        Returns:
            str: The model's response

        Raises:
            TokenLimitExceeded: If token limits are exceeded (never retried)
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails
            Exception: For unexpected errors
        """
        # Start timing
        start_time = time.time()
        msg_processing_time = 0.0
        api_call_time = 0.0

        try:
            # Format messages
            msg_processing_start = time.time()
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Calculate tokens and check limits
            input_tokens = self.count_message_tokens(messages)
            if not self.check_token_limit(input_tokens):
                raise TokenLimitExceeded(
                    self.get_limit_error_message(input_tokens))
            msg_processing_time = time.time() - msg_processing_start

            # Non-streaming request
            api_call_start = time.time()

            # Set request parameters
            params = {
                "model": self.model,
                "messages": messages,
                "stream": False
            }

            # Set different parameters based on model type
            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            # Enhanced message validation - includes tool conversation validation
            messages = self._validate_tool_conversation(messages)

            validation_result = message_validator.validate_messages_before_api_call(messages)
            if not validation_result.is_valid:
                logger.error(f"🚫 Message validation failed: {validation_result.error_details}")
                # Try to create fallback message
                fallback_msg = message_validator.create_fallback_message("general")
                messages = [fallback_msg]
                logger.warning("🔄 Using fallback message to replace invalid content")
            elif validation_result.fixes_applied:
                logger.info(f"🔧 Messages cleaned, applied fixes: {validation_result.fixes_applied}")
                messages = validation_result.cleaned_content

            # Always send the validated list. Previously the output of
            # _validate_tool_conversation was discarded whenever the validator
            # reported the messages valid with no fixes.
            params["messages"] = messages

            response = await self.create_completions(params)
            api_call_time = time.time() - api_call_start

            # Handle response
            if not response.choices or not response.choices[0].message.content:
                raise ValueError("Empty or invalid response from LLM")

            # Update token statistics using usage field
            if hasattr(response, "usage") and response.usage:
                self.update_token_count(
                    response.usage.prompt_tokens,
                    response.usage.completion_tokens
                )
            else:
                logger.warning("Response does not contain usage information")

            return response.choices[0].message.content

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            logger.error(f"Request: {params}")

            # Update error statistics
            if isinstance(oe, RateLimitError):
                self.request_stats['rate_limit_errors'] += 1
            elif isinstance(oe, (APITimeoutError, APIConnectionError)):
                self.request_stats['timeout_errors'] += 1
            elif isinstance(oe, APIError):
                self.request_stats['api_errors'] += 1

            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error(
                    "Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise
        finally:
            # Calculate total elapsed time
            elapsed_time = time.time() - start_time

            # Update timing statistics
            ask_timing = self.timing_stats["ask"]
            ask_timing["total_time"] += elapsed_time
            ask_timing["call_count"] += 1
            ask_timing["msg_processing_time"] += msg_processing_time
            ask_timing["api_call_time"] += api_call_time

            # Log detailed time distribution; the parenthesised breakdown is
            # added only when at least one phase was actually timed.
            breakdown = []
            if msg_processing_time > 0:
                breakdown.append(f"message processing: {msg_processing_time:.2f}s")
            if api_call_time > 0:
                breakdown.append(f"API call: {api_call_time:.2f}s")

            log_message = f"⏱️ LLM ask() execution time: {elapsed_time:.2f} seconds"
            if breakdown:
                log_message += f" ({', '.join(breakdown)})"
            logger.info(log_message)
    
    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive statistics"""
        stats = self.request_stats.copy()
        stats.update({
            'token_usage': {
                'total_input_tokens': self.total_input_tokens,
                'total_output_tokens': self.total_output_tokens,
                'total_tokens': self.total_input_tokens + self.total_output_tokens,
            },
            'success_rate': (
                stats['successful_requests'] / stats['total_requests'] * 100
                if stats['total_requests'] > 0 else 0
            ),
            'avg_retries_per_request': (
                stats['total_retry_attempts'] / stats['total_requests']
                if stats['total_requests'] > 0 else 0
            )
        })
        return stats

    def print_stats(self):
        """Print comprehensive statistics"""
        stats = self.get_stats()
        print("\n" + "="*50)
        print("📊 Enhanced LLM Statistics")
        print("="*50)
        print(f"Total Requests: {stats['total_requests']}")
        print(f"Successful: {stats['successful_requests']}")
        print(f"Failed: {stats['failed_requests']}")
        print(f"Success Rate: {stats['success_rate']:.2f}%")
        print(f"Avg Response Time: {stats['avg_response_time']:.2f}s")
        print(f"Avg Retries/Request: {stats['avg_retries_per_request']:.2f}")
        print(f"\nError Breakdown:")
        print(f"  Timeouts: {stats['timeout_errors']}")
        print(f"  Rate Limits: {stats['rate_limit_errors']}")
        print(f"  API Errors: {stats['api_errors']}")
        print(f"\nToken Usage:")
        print(f"  Input: {stats['token_usage']['total_input_tokens']:,}")
        print(f"  Output: {stats['token_usage']['total_output_tokens']:,}")
        print(f"  Total: {stats['token_usage']['total_tokens']:,}")
        print("="*50)

    async def ask_json(
            self,
            messages: List[Union[dict, Message]],
            system_msgs: Optional[List[Union[dict, Message]]] = None,
            temperature: Optional[float] = None,
            return_str: bool = False,
    ):
        """
        Ask the LLM and interpret the reply as JSON.

        The reply from ``ask`` is stripped, unwrapped from a Markdown code
        fence when one is present, passed through ``repair_json`` and then
        decoded.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            temperature: Sampling temperature forwarded to ``ask``
            return_str: When True, return the repaired JSON text instead of
                the decoded object

        Returns:
            The decoded JSON value, the repaired JSON string when
            ``return_str`` is True, or ``{}`` when decoding fails.

        Raises:
            Whatever ``ask`` raises; only JSON decoding errors are absorbed.
        """
        # Start timing
        start_time = time.time()
        ask_elapsed_time = 0.0
        try:
            # Call ask method. Its duration is recorded even when it raises,
            # so a failed call is never booked as JSON handling time below.
            ask_start_time = time.time()
            try:
                response_text = await self.ask(messages, system_msgs, temperature)
            finally:
                ask_elapsed_time = time.time() - ask_start_time

            # Try to parse JSON response
            try:
                # Clean response text, remove possible non-JSON prefixes or suffixes
                cleaned_text = response_text.strip()

                # If response contains a ```json fence, extract its content
                if "```json" in cleaned_text:
                    start = cleaned_text.find("```json") + 7
                    end = cleaned_text.find("```", start)
                    if end > start:
                        cleaned_text = cleaned_text[start:end].strip()
                # If only ``` markers (no language specified), also extract the content
                elif cleaned_text.startswith("```") and cleaned_text.endswith("```"):
                    cleaned_text = cleaned_text[3:-3].strip()

                # Parse JSON
                json_response = repair_json(cleaned_text)

                if return_str:
                    return json_response

                return json.loads(json_response)
            except json.JSONDecodeError as e:
                logger.error(f"JSON parsing error: {e}")
                logger.error(f"Original response: {response_text}")

                return {}
        finally:
            # Calculate elapsed time and update statistics (excluding ask method time)
            elapsed_time = time.time() - start_time
            # Only calculate time for additional operations like JSON parsing
            json_only_time = elapsed_time - ask_elapsed_time
            self.timing_stats["ask_json"]["total_time"] += json_only_time
            self.timing_stats["ask_json"]["call_count"] += 1
            logger.info(
                f"⏱️ LLM ask_json() execution time: {json_only_time:.2f} seconds (excluding ask call)")

    @retry(
        wait=wait_random_exponential(min=6, max=60),
        stop=stop_after_attempt(6),
        # Retry on any ordinary exception EXCEPT TokenLimitExceeded: an
        # over-limit prompt fails identically on every attempt, so retrying it
        # only burns the back-off delay. A plain callable is used because
        # listing `Exception` in retry_if_exception_type() also matches
        # TokenLimitExceeded.
        retry=lambda retry_state: (
            retry_state.outcome is not None
            and retry_state.outcome.failed
            and isinstance(retry_state.outcome.exception(), Exception)
            and not isinstance(
                retry_state.outcome.exception(), TokenLimitExceeded)
        ),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """Send a chat completion request that may use tools and return the reply.

        Messages are formatted, token-counted (including the tool
        descriptions), validated and, if necessary, cleaned or replaced by a
        fallback message before the request is issued. Timing and error
        counters on the instance are updated on every attempt.

        Args:
            messages: Conversation messages to send.
            system_msgs: Optional system messages placed before ``messages``.
            timeout: Request timeout forwarded in the request parameters.
            tools: Optional tool definitions; each must be a dict with a
                ``"type"`` key.
            temperature: Sampling temperature; falls back to
                ``self.temperature`` when ``None``. Not sent for models listed
                in ``REASONING_MODELS``.
            **kwargs: Extra request parameters merged into the request.

        Returns:
            The ``message`` object of the first choice in the response, with
            a falsy ``content`` normalised to an empty string.

        Raises:
            TokenLimitExceeded: If the input token count fails
                ``check_token_limit``. Never retried.
            ValueError: If a tool definition is malformed or the response has
                no usable choice.
            OpenAIError: Re-raised after logging and updating error counters.
            Exception: Any other error, re-raised after logging.
        """
        start_time = time.time()
        tool_names = []  # Names of the tools offered to the model (for logging)
        used_tools = []  # Names of the tools the model actually called
        params = {}  # Bound up front so the error handler can always log it
        msg_processing_time = 0.0
        api_call_time = 0.0
        tool_processing_time = 0.0

        try:
            # Format messages
            msg_processing_start = time.time()
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Calculate input token count
            input_tokens = self.count_message_tokens(messages)

            # If there are tools, add the token count of their descriptions
            tools_tokens = 0
            tool_processing_start = time.time()
            if tools:
                for tool in tools:
                    tools_tokens += self.count_tokens(str(tool))
                    # Extract the tool name when the definition has the
                    # expected shape; malformed tools are rejected below.
                    function_spec = (
                        tool.get("function") if isinstance(tool, dict) else None
                    )
                    if isinstance(function_spec, dict) and "name" in function_spec:
                        tool_names.append(function_spec["name"])

            input_tokens += tools_tokens
            tool_processing_time += time.time() - tool_processing_start
            # This interval also spans the tool token counting above, so the
            # two figures overlap.
            msg_processing_time = time.time() - msg_processing_start

            # Check if token limits are exceeded
            if not self.check_token_limit(input_tokens):
                error_message = self.get_limit_error_message(input_tokens)
                # Raise a special exception that is excluded from retries
                raise TokenLimitExceeded(error_message)

            # Validate tools if provided
            tool_processing_start = time.time()
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError(
                            "Each tool must be a dict with 'type' field")
            tool_processing_time += time.time() - tool_processing_start

            # Set up the completion request
            params = {
                "model": self.model,
                "messages": messages,
                "tools": tools,
                "timeout": timeout,
                **kwargs,
            }

            # DEBUG: Print tools content
            if tools:
                logger.info(f"🔧 DEBUG: Number of tools passed to API: {len(tools)}")

            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            # Enhanced message validation - includes tool conversation validation
            messages = self._validate_tool_conversation(messages)
            params["messages"] = messages

            validation_result = message_validator.validate_messages_before_api_call(messages)
            if not validation_result.is_valid:
                logger.error(f"🚫 Message validation failed: {validation_result.error_details}")
                # Replace the whole conversation with a single fallback message
                messages = [message_validator.create_fallback_message("tool call")]
                params["messages"] = messages
                logger.warning("🔄 Using fallback message to replace invalid content")
            elif validation_result.fixes_applied:
                logger.info(f"🔧 Messages cleaned, applied fixes: {validation_result.fixes_applied}")
                messages = validation_result.cleaned_content
                params["messages"] = messages

            # Execute API call
            api_call_start = time.time()
            response = await self.create_completions(params)
            api_call_time = time.time() - api_call_start

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                logger.error(f"Invalid or empty response from LLM: {response}")
                raise ValueError("Invalid or empty response from LLM")

            # Update token counts
            self.update_token_count(
                response.usage.prompt_tokens, response.usage.completion_tokens)

            # Analyze and count tool calls
            tool_processing_start = time.time()
            response_message = response.choices[0].message

            if not response_message.content:
                response_message.content = ''

            tool_calls = getattr(response_message, "tool_calls", None) or []
            used_tools.extend(
                tool_call.function.name
                for tool_call in tool_calls
                if hasattr(tool_call, "function") and tool_call.function.name
            )

            if used_tools:
                # Individual tool cost cannot be measured here, so the API
                # call time is split evenly. The share is computed once, after
                # all calls are known, so every call gets the same amount.
                per_tool_time = api_call_time / len(used_tools)
                for tool_name in used_tools:
                    if tool_name not in self.tool_timing_stats:
                        self.tool_timing_stats[tool_name] = {
                            "total_time": 0.0, "call_count": 0}
                    self.tool_timing_stats[tool_name]["call_count"] += 1
                    self.tool_timing_stats[tool_name]["total_time"] += per_tool_time
            tool_processing_time += time.time() - tool_processing_start

            return response_message

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            logger.error(f"Request: {params}")

            # Update error statistics
            if isinstance(oe, RateLimitError):
                self.request_stats['rate_limit_errors'] += 1
            elif isinstance(oe, (APITimeoutError, APIConnectionError)):
                self.request_stats['timeout_errors'] += 1
            elif isinstance(oe, APIError):
                self.request_stats['api_errors'] += 1

            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error(
                    "Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
        finally:
            # Runs for every attempt, successful or not
            elapsed_time = time.time() - start_time

            # Update timing statistics
            ask_tool_stats = self.timing_stats["ask_tool"]
            ask_tool_stats["total_time"] += elapsed_time
            ask_tool_stats["call_count"] += 1
            ask_tool_stats["msg_processing_time"] += msg_processing_time
            ask_tool_stats["api_call_time"] += api_call_time
            ask_tool_stats["tool_processing_time"] += tool_processing_time

            # Collect only the details that apply, then join them, so the
            # parentheses stay balanced whichever timings are available.
            details = []
            if msg_processing_time > 0:
                details.append(f"message processing: {msg_processing_time:.2f}s")
            if api_call_time > 0:
                details.append(f"API call: {api_call_time:.2f}s")
            if tool_processing_time > 0:
                details.append(f"tool processing: {tool_processing_time:.2f}s")

            # Add tool usage information
            if used_tools:
                tools_str = ", ".join(used_tools)
                details.append(f"used tools: {tools_str}")
            elif tool_names:
                available_tools_str = ", ".join(tool_names)
                details.append(f"available tools: {available_tools_str} (unused)")

            log_message = f"⏱️ LLM ask_tool() execution time: {elapsed_time:.2f} seconds"
            if details:
                log_message += " (" + ", ".join(details) + ")"
            logger.info(log_message)
