import os
from typing import Optional, Dict, Any, List, Union
from dataclasses import dataclass
from omegaconf import DictConfig
import logging
import asyncio
import time
from omegaconf import OmegaConf

# Import timeout exception
from asyncio import TimeoutError as AsyncTimeoutError

# Fallback per-request timeout (seconds) for OpenAI API calls, used when neither
# model_config.request_timeout nor a per-call request_timeout provides a value.
DEFAULT_LLM_API_TIMEOUT: float = 60.0

@dataclass()
class Message:
    """One chat turn: the speaker's ``role`` and the text ``content``."""

    role: str  # chat role, e.g. "user", "assistant" or "system"
    content: str  # the message text

class Conversation:
    """In-memory, ordered history of ``Message`` objects for a single conversation."""

    def __init__(self) -> None:
        # Messages in insertion order (oldest first).
        self.messages: List[Message] = list()

    def add_message(self, role: str, content: str) -> None:
        """Append a message with the given role and content to the history."""
        new_message = Message(role=role, content=content)
        self.messages.append(new_message)

    def get_messages(self) -> List[Message]:
        """Return the live history list (not a copy; changes to it affect this conversation)."""
        return self.messages

    def clear(self) -> None:
        """Empty the history by binding a brand-new list.

        A list obtained earlier from ``get_messages`` is left untouched.
        """
        self.messages = list()

class LLMCaller:
    """Async wrapper around OpenAI chat completions with per-conversation history.

    History lives in memory in ``self.conversations`` (conversation id ->
    ``Conversation``); nothing in this class evicts entries. The API key is read
    once, at construction time, from ``<secrets_dir>/openai.key``.
    """

    def __init__(self, model_config: Optional[DictConfig] = None, secrets_dir: str = "secrets", logger=None):
        """Resolve the secrets directory and request timeout, then load the API key.

        Args:
            model_config: Optional config used by ``call_model`` as per-instance
                defaults and as the source of ``request_timeout``.
            secrets_dir: Directory holding ``openai.key``. A relative name is
                resolved against the directory two levels above the directory
                containing this file; an absolute path is used as-is
                (``os.path.join`` semantics).
            logger: Logger to use; defaults to ``logging.getLogger(__name__)``.

        Raises:
            ValueError: if the API key file is missing or empty.
        """
        self.logger = logger or logging.getLogger(__name__)

        # Secrets are resolved to <dir of this file>/../../<secrets_dir>, i.e. two
        # levels up, not one. The location is intentionally left as-is so existing
        # key files keep being found.
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        parent_dir = os.path.dirname(current_file_dir)
        self.secrets_dir = os.path.join(os.path.dirname(parent_dir), secrets_dir)
        self.model_config = model_config

        # Start from the module default; a positive numeric
        # model_config.request_timeout replaces it. Per-call values still win.
        self.request_timeout = DEFAULT_LLM_API_TIMEOUT
        if self.model_config and hasattr(self.model_config, 'request_timeout'):
            config_timeout = self.model_config.request_timeout
            # bool is an int subclass; reject it so `True` is not read as a 1 s timeout.
            is_number = isinstance(config_timeout, (int, float)) and not isinstance(config_timeout, bool)
            if is_number and config_timeout > 0:
                self.request_timeout = float(config_timeout)
            else:
                self.logger.warning(f"Invalid request_timeout in model_config: {config_timeout}. Using default {self.request_timeout}s.")

        self.logger.info(f"LLMCaller initialized with request timeout: {self.request_timeout} seconds")

        self._init_api_keys()

        # conversation_id -> Conversation
        self.conversations: Dict[str, Conversation] = {}

    def _init_api_keys(self):
        """Load the OpenAI API key from ``<secrets_dir>/openai.key`` into ``self.openai_key``.

        Raises:
            ValueError: if the key file does not exist (chained to the original
                ``FileNotFoundError``) or contains only whitespace.
        """
        key_path = os.path.join(self.secrets_dir, "openai.key")
        try:
            with open(key_path, encoding="utf-8") as f:
                key = f.read().strip()
        except FileNotFoundError as e:
            self.logger.error(f"API key file not found: {e}")
            self.logger.info(f"Please add your OpenAI API key to: {key_path}")
            raise ValueError(f"Missing OpenAI API key file (expected at {key_path})") from e

        # An empty key would otherwise only surface later as an authentication error.
        if not key:
            self.logger.error(f"API key file is empty: {key_path}")
            raise ValueError(f"OpenAI API key file is empty (at {key_path})")
        self.openai_key = key

    def get_or_create_conversation(self, conversation_id: str) -> Conversation:
        """Return the history for ``conversation_id``, creating an empty one on first use."""
        if conversation_id not in self.conversations:
            self.conversations[conversation_id] = Conversation()
        return self.conversations[conversation_id]

    async def call_openai(
        self,
        conversation_id: str,
        prompt: str,
        model: str = "gpt-4o",
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        n: int = 1,
        request_timeout: Optional[float] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Send ``prompt`` plus the stored conversation history to the chat-completions API.

        The system prompt (if any) is prepended to the request but never stored in
        history. On success, the user prompt and the first returned choice are
        appended to the conversation. On failure, history is left unchanged, so the
        call can be retried without duplicating the user turn.

        Args:
            conversation_id: Key of the conversation whose history is used and extended.
            prompt: New user message.
            model, temperature, n: Forwarded to the API.
            max_tokens: Forwarded only when not ``None``.
            system_prompt: Optional system message prepended to this request only.
            request_timeout: Per-call timeout in seconds; ``None`` uses ``self.request_timeout``.
            **kwargs: Extra API parameters (e.g. ``top_p``), merged last into the request.

        Returns:
            The SDK's chat-completion response object (not a plain dict, despite the
            annotation); e.g. ``response.choices[0].message.content``.

        Raises:
            Re-raises whatever the SDK raises, including ``openai.APITimeoutError``
            on timeout, after logging it.
        """
        # Resolved before anything can fail so the timeout handler can always report it.
        timeout_duration = request_timeout if request_timeout is not None else self.request_timeout

        try:
            from openai import AsyncOpenAI, APITimeoutError
        except ImportError as e:
            self.logger.error(f"Error calling OpenAI API: {e}")
            raise

        self.logger.debug(f"Using timeout duration: {timeout_duration}s for API call to model {model}")

        conv = self.get_or_create_conversation(conversation_id)
        messages = [{"role": msg.role, "content": msg.content} for msg in conv.get_messages()]
        messages.append({"role": "user", "content": prompt})

        if system_prompt is not None:
            self.logger.info(f"Using system prompt for {conversation_id}")
            messages.insert(0, {"role": "system", "content": system_prompt})

        api_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "n": n,
            **kwargs
        }
        if max_tokens is not None:
            api_params["max_tokens"] = max_tokens

        try:
            # The context manager closes the client's HTTP connection pool after the
            # call instead of leaking one open client per request.
            async with AsyncOpenAI(api_key=self.openai_key, timeout=timeout_duration) as client:
                response = await client.chat.completions.create(**api_params)
        except (AsyncTimeoutError, APITimeoutError):
            # The OpenAI v1 SDK reports timeouts as openai.APITimeoutError, which is
            # not an asyncio.TimeoutError; both are handled here.
            self.logger.error(f"OpenAI API call timed out after {timeout_duration} seconds (model: {model}).")
            raise  # lets callers handle the timeout specifically
        except Exception as e:
            self.logger.error(f"Error calling OpenAI API: {e}")
            raise

        # Commit the turn only after success. Both appends happen with no await
        # between them, so the user/assistant pair stays adjacent.
        conv.add_message("user", prompt)
        if response.choices:
            # Only the first choice is recorded when n > 1. The content can be None
            # (e.g. refusals); store "" so the next request carries a valid string.
            conv.add_message("assistant", response.choices[0].message.content or "")

        self.logger.info(f"Successfully called OpenAI API with model {model}")
        return response

    async def call_model(
        self,
        conversation_id: str,
        prompt: str,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        system_prompt: Optional[str] = None,
        model_config_override: Optional[Union[DictConfig, Dict]] = None,
        n: int = 1,
        request_timeout: Optional[float] = None
    ) -> Dict[str, Any]:
        """Resolve call parameters from layered sources and delegate to ``call_openai``.

        Precedence, lowest to highest:
          1. The explicit arguments. ``model`` and ``temperature`` fall back to
             ``"gpt-4o"`` and ``1.0`` when ``None``.
          2. ``self.model_config``: ``name`` (-> model), ``temperature``,
             ``max_tokens``, ``top_p``, ``system_prompt``, ``n``, for each field it
             provides.
          3. ``model_config_override`` (dict or DictConfig): ``name`` or ``model``,
             ``temperature``, ``max_tokens``, ``top_p``, ``n``, ``request_timeout``.

        NOTE(review): because of layer 2, an explicitly passed argument (including
        ``system_prompt``) is replaced by ``self.model_config`` when that config
        provides the same field. Confirm this is intended.

        Resolved parameters equal to ``None`` are not forwarded, so ``call_openai``'s
        own defaults apply. ``system_prompt`` and ``request_timeout`` are always
        forwarded; a ``None`` timeout means ``self.request_timeout``.

        Returns:
            Whatever ``call_openai`` returns (the SDK chat-completion response).

        Raises:
            ValueError: if no model name can be determined.
        """
        # Layer 0: explicit arguments, with built-in defaults for model/temperature.
        final_params = {
            'model': model if model is not None else "gpt-4o",
            'temperature': temperature if temperature is not None else 1.0,
            'max_tokens': max_tokens,
            'top_p': top_p,
            'n': n,
        }
        final_system_prompt = system_prompt

        # Layer 1: instance config. The timeout is already folded into self.request_timeout.
        base_config = getattr(self, 'model_config', None)
        if base_config:
            final_params['model'] = getattr(base_config, 'name', final_params['model'])  # config key is 'name'
            final_params['temperature'] = getattr(base_config, 'temperature', final_params['temperature'])
            final_params['max_tokens'] = getattr(base_config, 'max_tokens', final_params['max_tokens'])
            final_params['top_p'] = getattr(base_config, 'top_p', final_params['top_p'])
            final_system_prompt = getattr(base_config, 'system_prompt', final_system_prompt)
            final_params['n'] = getattr(base_config, 'n', final_params['n'])

        # Layer 2: per-call override.
        if model_config_override:
            override_dict = model_config_override
            if isinstance(model_config_override, DictConfig):
                override_dict = OmegaConf.to_container(model_config_override, resolve=True)

            if isinstance(override_dict, dict):
                # Accept either 'name' (config style) or 'model' (API style).
                final_params['model'] = override_dict.get('name', override_dict.get('model', final_params['model']))
                final_params['temperature'] = override_dict.get('temperature', final_params['temperature'])
                final_params['max_tokens'] = override_dict.get('max_tokens', final_params['max_tokens'])
                final_params['top_p'] = override_dict.get('top_p', final_params['top_p'])
                final_params['n'] = override_dict.get('n', final_params['n'])
                request_timeout = override_dict.get('request_timeout', request_timeout)
            else:
                self.logger.warning(f"Could not convert model_config_override to dict. Type: {type(override_dict)}. Ignoring override.")

        # Drop unset values so call_openai's defaults apply.
        call_openai_args = {k: v for k, v in final_params.items() if v is not None}
        if 'model' not in call_openai_args:
            raise ValueError("LLM model name could not be determined.")

        return await self.call_openai(
            conversation_id=conversation_id,
            prompt=prompt,
            system_prompt=final_system_prompt,
            request_timeout=request_timeout,
            **call_openai_args
        )

async def test_llm_api():
    """Smoke-test the OpenAI connection with one short prompt.

    Configures basic INFO logging, builds an ``LLMCaller`` with no model config
    and sends a single prompt to ``gpt-3.5-turbo`` via ``call_model``.

    Returns:
        ``(True, reply_text)`` on success, or ``(False, error_message)`` if any
        ``Exception`` is raised. Failures are logged with their traceback and are
        not propagated.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("llm_api_test")

    try:
        logger.info("Initializing LLM caller...")
        caller = LLMCaller(logger=logger)

        test_prompt = "Respond with a very brief greeting, just 2-3 words."

        logger.info("Testing API connection with a simple prompt...")
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the latency.
        start_time = time.perf_counter()

        response = await caller.call_model(
            conversation_id="api_test",
            prompt=test_prompt,
            model="gpt-3.5-turbo",  # Using a cheaper model for testing
            temperature=0.7,
            system_prompt="You are a helpful assistant that keeps responses extremely brief."
        )

        elapsed = time.perf_counter() - start_time

        content = response.choices[0].message.content
        logger.info(f"Received response in {elapsed:.2f} seconds: '{content}'")
        logger.info(f"Model: {response.model}")
        logger.info("API test successful!")

        return True, content
    except Exception as e:
        # logger.exception keeps the traceback, which is what makes a failed smoke test debuggable.
        logger.exception(f"API test failed: {e}")
        return False, str(e)

if __name__ == "__main__":
    # Run the API smoke test and reflect its outcome in the exit status, so
    # shell scripts and CI can detect a failed connection check.
    success, _ = asyncio.run(test_llm_api())
    raise SystemExit(0 if success else 1)