"""
Unified OpenAI API handler supporting multiple models and endpoints
"""

import time
import json
import logging
from typing import Dict, Any, Tuple, Optional
from openai import OpenAI
from .base_provider import BaseProvider
from .utils import TokenCounter, CostCalculator, RequestLogger


class OpenAIHandler(BaseProvider):
    """Handler for all OpenAI models with automatic endpoint detection.

    Requests are routed by (lower-cased) model-name prefix, see ``_get_model_type``:

    * ``gpt-5`` / ``gpt5``      -> Responses API with ``reasoning`` / ``text`` options
    * ``o1`` / ``o3`` / ``o4``  -> Responses API (system prompt folded into the input)
    * anything else             -> Chat Completions API

    This class reads ``self.model_name``, ``self.api_key``, ``self.max_tokens``,
    ``self.config`` and ``self.logger`` without assigning them; they are
    presumably provided by ``BaseProvider.__init__`` (TODO confirm).
    """

    # Lower-cased model-name prefixes that select each request path.
    _GPT5_PREFIXES = ('gpt-5', 'gpt5')
    _REASONING_PREFIXES = ('o1', 'o3', 'o4')

    # Content-part ``type`` values that carry plain text in Responses API output.
    _TEXT_CONTENT_TYPES = ('output_text', 'text')

    def __init__(self, config: Dict[str, Any]):
        """Initialize the OpenAI handler.

        Args:
            config: Provider configuration. Passed to ``BaseProvider``; this class
                additionally reads ``endpoint`` here and ``reasoning_effort`` /
                ``verbosity`` on GPT-5 calls.
        """
        super().__init__(config)

        # One SDK client serves both the chat/completions and responses endpoints.
        self.client = OpenAI(api_key=self.api_key)

        # Configured endpoint (e.g. 'responses' for o3/o3-pro).
        # NOTE(review): stored but not consulted by ``invoke``; routing is purely
        # prefix-based via ``_get_model_type``.
        self.configured_endpoint = self.config.get('endpoint')

        # Utilities for local token estimation and request logging.
        self.token_counter = TokenCounter()
        self.request_logger = RequestLogger('OpenAI')

    def _get_model_type(self) -> str:
        """Classify ``self.model_name`` to pick the request path.

        Returns:
            'gpt5' for GPT-5 series, 'reasoning' for O1/O3/O4 series,
            'standard' for everything else (e.g. GPT-4o).
        """
        model_lower = self.model_name.lower()

        # GPT-5 is checked first so its names never fall through to other paths.
        if model_lower.startswith(self._GPT5_PREFIXES):
            return 'gpt5'
        if model_lower.startswith(self._REASONING_PREFIXES):
            return 'reasoning'
        return 'standard'

    @staticmethod
    def _combine_prompts(prompt: str, system_prompt: Optional[str]) -> str:
        """Prepend *system_prompt* (if non-empty) to *prompt*, separated by a blank line."""
        return f"{system_prompt}\n\n{prompt}" if system_prompt else prompt

    @staticmethod
    def _responses_usage(sdk_response) -> Tuple[Optional[int], Optional[int]]:
        """Return ``(input_tokens, output_tokens)`` from a Responses API result.

        Either value is ``None`` when the response has no ``usage`` or the
        attribute is absent.
        """
        usage = getattr(sdk_response, 'usage', None)
        if not usage:
            return None, None
        return getattr(usage, 'input_tokens', None), getattr(usage, 'output_tokens', None)

    def _build_metadata(self, input_text: str, response_text: str,
                        input_tokens: Optional[int], output_tokens: Optional[int],
                        endpoint: str) -> Dict[str, Any]:
        """Assemble the usage/cost metadata dict shared by all handlers.

        API-reported token counts are used when present. A count that is
        ``None`` is estimated locally from the corresponding text.

        Args:
            input_text: Text sent to the model (used only for the fallback estimate).
            response_text: Text returned by the model (fallback estimate only).
            input_tokens: API-reported prompt/input token count, or ``None``.
            output_tokens: API-reported completion/output token count, or ``None``.
            endpoint: Endpoint label recorded in the metadata.

        Returns:
            Dict with input/output/total tokens, cost, model and endpoint.
        """
        if input_tokens is None:
            input_tokens = self.token_counter.count_tokens_openai(input_text, self.model_name)
        if output_tokens is None:
            output_tokens = self.token_counter.count_tokens_openai(response_text, self.model_name)

        cost = CostCalculator.calculate_openai_cost(self.model_name, input_tokens, output_tokens)

        return {
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'total_tokens': input_tokens + output_tokens,
            'cost': cost,
            'model': self.model_name,
            'endpoint': endpoint,
        }

    def _invoke_standard_model(self, prompt: str, system_prompt: str = None, **kwargs) -> Tuple[str, Dict]:
        """Handle standard models (GPT-4o series) via the chat/completions endpoint.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt, sent as a separate system message.
            **kwargs: Optional ``temperature`` (default 1.0), ``top_p`` (1.0),
                ``frequency_penalty`` (0), ``presence_penalty`` (0) and
                ``max_tokens`` (default ``self.max_tokens``).

        Returns:
            Tuple of (response_text, metadata).

        Raises:
            ValueError: If the completion carries no text content.
            Exception: Any SDK error is logged and re-raised.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        params = {
            'model': self.model_name,
            'messages': messages,
            'temperature': kwargs.get('temperature', 1.0),
            'top_p': kwargs.get('top_p', 1.0),
            'frequency_penalty': kwargs.get('frequency_penalty', 0),
            'presence_penalty': kwargs.get('presence_penalty', 0),
            'max_completion_tokens': kwargs.get('max_tokens', self.max_tokens),
        }

        try:
            response = self.client.chat.completions.create(**params)
            response_text = response.choices[0].message.content
            # ``content`` can be None (e.g. a refusal). Returning it would break the
            # str contract and the token-count fallback below.
            if response_text is None:
                raise ValueError("Chat completions response contained no text content")

            usage = response.usage
            metadata = self._build_metadata(
                # The fallback estimate covers the system prompt as well as the user prompt.
                self._combine_prompts(prompt, system_prompt),
                response_text,
                usage.prompt_tokens if usage else None,
                usage.completion_tokens if usage else None,
                'chat/completions',
            )
            return response_text, metadata

        except Exception as e:
            self.logger.error(f"OpenAI API call failed: {e}")
            raise

    def _invoke_reasoning_model(self, prompt: str, system_prompt: str = None, **kwargs) -> Tuple[str, Dict]:
        """Handle reasoning models (O1/O3/O4 series) via the responses endpoint.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt, prepended to the prompt text
                rather than sent as a separate message.
            **kwargs: Optional ``max_tokens`` (default ``self.max_tokens``).

        Returns:
            Tuple of (response_text, metadata).

        Raises:
            ValueError: If no text can be extracted from the response.
            Exception: Any SDK error is logged and re-raised.
        """
        full_prompt = self._combine_prompts(prompt, system_prompt)

        try:
            sdk_response = self.client.responses.create(
                model=self.model_name,
                input=full_prompt,
                temperature=1.0,  # Fixed for reasoning models
                max_output_tokens=kwargs.get('max_tokens', self.max_tokens)
            )

            response_text = self._extract_text_from_responses(sdk_response)
            if not response_text:
                raise ValueError("Could not extract text from responses API")

            input_tokens, output_tokens = self._responses_usage(sdk_response)
            metadata = self._build_metadata(
                full_prompt, response_text, input_tokens, output_tokens, 'responses'
            )
            return response_text, metadata

        except Exception as e:
            self.logger.error(f"Error processing response: {e}")
            raise

    def _invoke_gpt5_model(self, prompt: str, system_prompt: str = None, **kwargs) -> Tuple[str, Dict]:
        """Handle GPT-5 series via the responses endpoint with GPT-5 options.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt, prepended to the prompt text.
            **kwargs: Optional ``reasoning_effort`` and ``verbosity``. Each falls
                back to the same config key, then to 'medium'. Also accepts
                ``max_tokens`` (default ``self.max_tokens``).

        Returns:
            Tuple of (response_text, metadata).

        Raises:
            ValueError: If no text can be extracted from the response.
            Exception: Any SDK error is logged and re-raised.
        """
        full_prompt = self._combine_prompts(prompt, system_prompt)

        try:
            effort = kwargs.get('reasoning_effort', self.config.get('reasoning_effort', 'medium'))
            verbosity = kwargs.get('verbosity', self.config.get('verbosity', 'medium'))

            sdk_response = self.client.responses.create(
                model=self.model_name,
                input=full_prompt,
                reasoning={"effort": effort},
                text={"verbosity": verbosity},
                max_output_tokens=kwargs.get('max_tokens', self.max_tokens)
            )

            response_text = self._extract_text_from_responses(sdk_response)
            if not response_text:
                raise ValueError("Could not extract text from GPT-5 responses API")

            input_tokens, output_tokens = self._responses_usage(sdk_response)
            metadata = self._build_metadata(
                full_prompt, response_text, input_tokens, output_tokens, 'responses'
            )
            return response_text, metadata

        except Exception as e:
            self.logger.error(f"Error processing GPT-5 response: {e}")
            raise

    def invoke(self, prompt: str, system_prompt: str = None, **kwargs) -> Tuple[str, Dict]:
        """Invoke the OpenAI model through the handler matching its model type.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            **kwargs: Forwarded to the selected ``_invoke_*`` handler.

        Returns:
            Tuple of (response_text, metadata). ``metadata['time_taken']`` is the
            elapsed time in seconds.
        """
        # perf_counter is monotonic, so the duration is immune to wall-clock changes.
        start_time = time.perf_counter()

        handlers = {
            'reasoning': self._invoke_reasoning_model,
            'gpt5': self._invoke_gpt5_model,
        }
        # 'standard' and any unknown type use chat/completions.
        handler = handlers.get(self._get_model_type(), self._invoke_standard_model)
        response_text, metadata = handler(prompt, system_prompt, **kwargs)

        metadata['time_taken'] = time.perf_counter() - start_time

        self.request_logger.log_request(prompt, response_text, metadata)

        return response_text, metadata

    def _extract_text_from_responses(self, sdk_response) -> Optional[str]:
        """Extract plain text from a Responses API SDK response.

        Tries, in order:
        1. the ``output_text`` attribute;
        2. the first non-empty text content part found via attribute access on
           ``output`` blocks;
        3. the same search over the dict from ``model_dump()`` / ``dict()``.

        Returns:
            The stripped text, or ``None`` if no text could be found.
        """
        # 1. Direct text attribute.
        output_text = getattr(sdk_response, 'output_text', None)
        if output_text:
            return output_text.strip()

        # 2. Structured output blocks (SDK objects).
        output = getattr(sdk_response, 'output', None)
        if isinstance(output, list):
            for block in output:
                contents = getattr(block, 'content', None)
                if not isinstance(contents, list):
                    continue
                for content in contents:
                    text = getattr(content, 'text', None)
                    if getattr(content, 'type', None) in self._TEXT_CONTENT_TYPES and text:
                        return text.strip()

        # 3. Serialized fallback (pydantic v2 ``model_dump`` or v1 ``dict``).
        if hasattr(sdk_response, 'model_dump'):
            serialize = sdk_response.model_dump
        elif hasattr(sdk_response, 'dict'):
            serialize = sdk_response.dict
        else:
            return None

        try:
            result = serialize()
        except Exception as exc:
            # Best effort: a payload we cannot serialize simply has no extractable text.
            logging.getLogger(__name__).debug(f"Could not serialize responses payload: {exc}")
            return None

        if not isinstance(result, dict):
            return None

        # Explicit type guards replace the former blanket ``except: pass``.
        for item in result.get('output') or []:
            if not isinstance(item, dict):
                continue
            for content in item.get('content') or []:
                if not isinstance(content, dict):
                    continue
                text = content.get('text')
                if content.get('type') in self._TEXT_CONTENT_TYPES and isinstance(text, str) and text:
                    return text.strip()

        return None

    def count_tokens(self, text: str) -> int:
        """Count tokens in *text* with the OpenAI tokenizer for ``self.model_name``."""
        return self.token_counter.count_tokens_openai(text, self.model_name)