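"""
chat.py - Chat clients for cloud and local AI models.

This module provides a unified interface for interacting with a wide range of
Large Language Models (LLMs):

* ``CloudChatClient`` (also exported as ``ChatClient`` for backward
  compatibility) routes requests through the litellm library to providers
  such as OpenAI, Anthropic, Google, and more, and tracks the cost of API
  usage.
* ``LocalChatClient`` talks to a local OpenAI-compatible server (e.g. vLLM)
  through the OpenAI client.

Both clients support tool/function calling, retry failed requests, and offer
an async wrapper. On import, environment variables (such as API keys) are
loaded from a .env file in the parent directory of this module.

This implementation aims for simplicity, elegance, and modularity, leveraging
litellm to abstract away the complexities of individual API providers.
"""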

import os
from typing import List, Dict, Optional, Any, Union
import asyncio
import time

import litellm
from dotenv import load_dotenv
from openai import OpenAI

def load_environment_variables():
    """
    Load environment variables from the .env file in this module's parent directory.

    The file loaded is ``<parent of the directory containing this file>/.env``.
    If ``__file__`` is not defined (e.g. in interactive sessions), this falls
    back to python-dotenv's default ``load_dotenv()`` lookup instead.

    Only a missing ``__file__`` triggers the fallback. Any other error raised
    while loading the .env file propagates, rather than being silently masked
    by loading a different file.
    """
    try:
        module_path = __file__
    except NameError:
        # No __file__ (e.g. interactive session): use python-dotenv's default lookup.
        load_dotenv()
        return

    # Two dirname() calls: the directory holding this file, then its parent.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(module_path)))
    load_dotenv(os.path.join(parent_dir, '.env'))

# --- Import-time setup ---
# Populate os.environ from the .env file next to this package (see
# load_environment_variables) as soon as the module is imported.
load_environment_variables()

# --- litellm global switches ---
# drop_params: let litellm drop request parameters a provider does not support
#              (see litellm documentation for the exact semantics).
# set_verbose: keep litellm's detailed logging off; flip to True when debugging.
# For even more output, litellm._turn_on_debug() can be called here instead.
litellm.drop_params = True
litellm.set_verbose = False

class ChatResponse:
    """
    Response object that mimics OpenAI's chat completion response format.

    Attributes:
        content: Text content of the model's message; may be None.
        tool_calls: Tool calls requested by the model; an empty list when
            none were given.
    """

    def __init__(self, content: Optional[str] = None, tool_calls: Optional[List] = None):
        self.content = content
        # Normalize None/empty to [] so callers can iterate without a None check.
        self.tool_calls = tool_calls or []

    def get(self, key: str, default=None):
        """Return attribute *key*, or *default* if it is absent (dict-like access for backward compatibility)."""
        return getattr(self, key, default)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(content={self.content!r}, tool_calls={self.tool_calls!r})"

class CloudChatClient:
    """
    A client for interacting with cloud-based AI chat models via litellm.

    Supports function calling (tools), retries with exponential backoff,
    per-client cost accumulation, and an async wrapper that runs the blocking
    request in a worker thread.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ):
        """
        Initializes the CloudChatClient with specified model and parameters.

        Args:
            model (str): The name of the model to use (e.g., 'openai/gpt-4o', 'claude-3-sonnet-20240229').
                         litellm will automatically route the request.
            temperature (float): The sampling temperature for generation (0.0 to 2.0).
            max_tokens (Optional[int]): The maximum number of tokens for the model's response.
            max_completion_tokens (Optional[int]): Alternative to max_tokens for certain providers.
                                                   Ignored when max_tokens is also given.
            api_key (Optional[str]): Explicitly provide an API key. If None, litellm
                                     searches for the relevant environment variable.
            **kwargs (Any): Additional keyword arguments to be passed directly to
                            the `litellm.completion` call.
        """
        self.model = model
        self.temperature = temperature

        # At most one token-limit parameter is kept; max_tokens wins if both are given.
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens if max_tokens is None else None

        self.api_key = api_key
        self.extra_params = kwargs
        # Running total of the per-call costs reported by litellm.completion_cost.
        self.total_cost = 0.0

    def chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Optional[Union[str, ChatResponse]]:
        """
        Sends a list of messages to the configured LLM and returns the response.

        Any exception raised by a request is printed and the request is retried,
        up to 3 attempts in total with pauses of 1s and then 2s between them.

        Args:
            messages: The conversation history.
            tools: List of tools/functions available for the model to call.
            tool_choice: How the model should choose tools ("auto", "none", or a
                specific tool); only sent when ``tools`` is non-empty.
            **kwargs: Per-call parameters for ``litellm.completion``; they take
                precedence over the instance-level settings.

        Returns:
            A ChatResponse if tools are provided, otherwise the message content.
            Returns None if every attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                return self._chat_with_litellm(messages, tools=tools, tool_choice=tool_choice, **kwargs)
            except Exception as e:
                print(f"Attempt {attempt + 1}/{max_retries} failed: {e}")
                if attempt < max_retries - 1:
                    # Exponential backoff between attempts: 1s, then 2s.
                    wait_time = 2 ** attempt
                    print(f"Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
        print("All retry attempts failed.")
        return None

    async def chat_async(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Optional[Union[str, ChatResponse]]:
        """Async version of :meth:`chat`; runs the blocking call in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat(messages, tools=tools, tool_choice=tool_choice, **kwargs),
        )

    def _chat_with_litellm(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Union[str, ChatResponse]:
        """
        Perform a single litellm request (no retries) and convert the result.

        Parameter precedence, lowest to highest: instance settings (model,
        temperature, api_key, token limit), constructor ``**kwargs``, per-call
        ``**kwargs``, then ``tools``/``tool_choice``.
        """
        overrides = {**self.extra_params, **kwargs}

        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "logprobs": False,
        }

        # Add API key if provided (a per-call api_key in overrides still wins).
        if self.api_key:
            params["api_key"] = self.api_key

        # Apply the instance token limit only when the caller didn't pass one, so a
        # per-call limit wins and two different limits are never sent together.
        caller_sets_limit = any(key in overrides for key in ("max_tokens", "max_completion_tokens"))
        if not caller_sets_limit:
            if self.max_tokens is not None:
                params["max_tokens"] = self.max_tokens
            elif self.max_completion_tokens is not None:
                params["max_completion_tokens"] = self.max_completion_tokens

        params.update(overrides)

        # Add tools if provided; tool_choice is meaningless without tools.
        if tools:
            params["tools"] = tools
            if tool_choice:
                params["tool_choice"] = tool_choice

        print(f"\n---> Calling model: {self.model}")
        response = litellm.completion(**params)

        # Cost accounting is best-effort. If the pricing lookup fails (e.g. for a model
        # litellm has no price data for), the completion still succeeded. Raising here
        # would make chat() retry and pay for duplicate calls.
        try:
            cost = litellm.completion_cost(completion_response=response)
        except Exception as cost_error:
            print(f"Could not compute cost for this call: {cost_error}")
            cost = None
        if cost is not None:
            self.total_cost += cost
            print(f"Cost for this call: ${cost:.6f}")
            print(f"Total cost: ${self.total_cost:.6f}")

        print("<--- Model response received successfully.")

        message = response.choices[0].message
        if not tools:
            # No tools offered: return the content directly.
            return message.content
        # Tools were offered: always return a ChatResponse (tool_calls is [] if none were made).
        return ChatResponse(
            content=message.content,
            tool_calls=getattr(message, "tool_calls", None) or [],
        )

class LocalChatClient:
    """
    A client for interacting with local AI models through the OpenAI client
    (e.g. a vLLM server exposing an OpenAI-compatible endpoint).

    Supports function calling (tools), retries with exponential backoff, and
    an async wrapper that runs the blocking request in a worker thread.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000/v1",
        **kwargs: Any,
    ):
        """
        Initializes the LocalChatClient with specified model and parameters.

        Args:
            model (str): The name of the model to use on the local server.
            temperature (float): The sampling temperature for generation (0.0 to 2.0).
            max_tokens (Optional[int]): The maximum number of tokens for the model's response.
            api_key (Optional[str]): API key; "EMPTY" is used when omitted
                                     (VLLM typically doesn't require a real one).
            base_url (str): The base URL of the local server.
            **kwargs (Any): Additional parameters forwarded to every
                            ``chat.completions.create`` call.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.base_url = base_url
        self.extra_params = kwargs

        # Initialize OpenAI client for local deployment
        self.openai_client = OpenAI(
            api_key=api_key or "EMPTY",  # Local servers typically don't require a real API key
            base_url=base_url,
        )
        print(f"Initialized OpenAI client for local server with base_url: {base_url}")

    def chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Optional[Union[str, ChatResponse]]:
        """
        Sends a list of messages to the local model and returns the response.

        Any exception raised by a request is printed and the request is retried,
        up to 3 attempts in total with pauses of 1s and then 2s between them.

        Args:
            messages: The conversation history.
            tools: List of tools/functions available for the model to call.
            tool_choice: How the model should choose tools ("auto", "none", or a
                specific tool); only sent when ``tools`` is non-empty.
            **kwargs: Per-call request parameters; they take precedence over the
                instance-level settings.

        Returns:
            A ChatResponse if tools are provided, otherwise the message content.
            Returns None if every attempt failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                return self._chat_with_openai(messages, tools=tools, tool_choice=tool_choice, **kwargs)
            except Exception as e:
                print(f"Attempt {attempt + 1}/{max_retries} failed: {e}")
                if attempt < max_retries - 1:
                    # Exponential backoff between attempts: 1s, then 2s.
                    wait_time = 2 ** attempt
                    print(f"Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
        print("All retry attempts failed.")
        return None

    async def chat_async(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Optional[Union[str, ChatResponse]]:
        """Async version of :meth:`chat`; runs the blocking call in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat(messages, tools=tools, tool_choice=tool_choice, **kwargs),
        )

    def _chat_with_openai(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None,
        **kwargs: Any,
    ) -> Union[str, ChatResponse]:
        """
        Perform a single request against the local server (no retries).

        Parameter precedence, lowest to highest: instance settings (model,
        temperature, max_tokens), constructor ``**kwargs``, per-call
        ``**kwargs``, then ``tools``/``tool_choice``. Parameters whose value
        is None are dropped before the request is sent.
        """
        print(f"\n---> Calling local model: {self.model}")

        overrides = {**self.extra_params, **kwargs}

        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
        }

        # Apply the instance token limit only when the caller didn't pass one, so a
        # per-call limit wins and two different limits are never sent together.
        caller_sets_limit = any(key in overrides for key in ("max_tokens", "max_completion_tokens"))
        if self.max_tokens is not None and not caller_sets_limit:
            params["max_tokens"] = self.max_tokens

        params.update(overrides)

        # Add tools if provided; tool_choice is meaningless without tools.
        if tools:
            params["tools"] = tools
            if tool_choice:
                params["tool_choice"] = tool_choice

        # Remove None values so unset options are omitted rather than sent as null.
        params = {k: v for k, v in params.items() if v is not None}

        response = self.openai_client.chat.completions.create(**params)

        print("<--- Local server response received successfully.")

        message = response.choices[0].message
        if not tools:
            # No tools offered: return the content directly.
            return message.content
        # Tools were offered: always return a ChatResponse (tool_calls is [] if none were made).
        return ChatResponse(
            content=message.content,
            tool_calls=getattr(message, "tool_calls", None) or [],
        )

# Older public name, kept so existing `ChatClient` imports keep working.
# It refers to the litellm-backed cloud client.
ChatClient = CloudChatClient
