"""
Common utilities for API providers - Token counting, cost calculation, and request logging.

This module provides shared utilities used by the LLM providers (OpenAI, Google, Anthropic)
to ensure consistent behavior across different APIs while reducing code duplication.

Key Components:
- TokenCounter: Accurate token counting for different model types
- CostCalculator: Precise cost estimation based on current pricing
- RequestLogger: Standardized logging for API requests and responses

Benefits:
- Unified token counting logic across providers
- Accurate cost tracking for budget management
- Consistent logging format for debugging
- Easy maintenance of pricing information
- Reduced code duplication between providers

Example Usage:
    counter = TokenCounter()
    tokens = counter.count_tokens_openai("Hello world", "gpt-4o")
    cost = CostCalculator.calculate_openai_cost("gpt-4o", 10, 20)
    
    logger = RequestLogger('OpenAI')
    logger.log_request(prompt, response, metadata)
"""

import logging
import time
import tiktoken
from typing import Dict, Any, Optional


class TokenCounter:
    """Utility class for counting tokens across different providers.

    Exact counts come from tiktoken (OpenAI) or from a Gemini model object's
    own ``count_tokens`` method (Google). Whenever an exact count cannot be
    obtained, a character-based approximation is returned instead, so the
    public counting methods never raise for ordinary string input.
    """

    # Encoding used when no model is given or the model cannot be resolved.
    _DEFAULT_ENCODING = "cl100k_base"
    # Heuristic used by the fallback: roughly one token per four characters.
    _CHARS_PER_TOKEN = 4

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        # Maps a model name (or 'default') to its resolved tiktoken encoding.
        self._tiktoken_cache = {}

    def count_tokens_openai(self, text: str, model_name: Optional[str] = None) -> int:
        """
        Count tokens for OpenAI models using tiktoken

        Args:
            text: Text to count tokens for; empty or None counts as 0 tokens
            model_name: Optional model name for specific encoding

        Returns:
            Number of tokens (approximated if tiktoken counting fails)
        """
        if not text:
            return 0

        try:
            encoding = self._get_tiktoken_encoding(model_name)
            # disallowed_special=() makes tiktoken encode special-token
            # look-alikes (e.g. "<|endoftext|>") as ordinary text instead of
            # raising, which would otherwise force the approximation path.
            return len(encoding.encode(text, disallowed_special=()))
        except Exception as e:
            self.logger.warning("Tiktoken counting failed, using approximation: %s", e)
            return self._approximate_tokens(text)

    def count_tokens_google(self, text: str, gemini_model=None) -> int:
        """
        Count tokens for Google Gemini models

        Args:
            text: Text to count tokens for; empty or None counts as 0 tokens
            gemini_model: Gemini model instance for accurate counting; it must
                expose ``count_tokens(text)`` returning an object with a
                ``total_tokens`` attribute

        Returns:
            Number of tokens (approximated if no model is given or it fails)
        """
        if not text:
            return 0

        # Compare against None explicitly so a model object that happens to
        # be falsy is still used.
        if gemini_model is not None:
            try:
                return gemini_model.count_tokens(text).total_tokens
            except Exception as e:
                self.logger.warning("Gemini token counting failed, using approximation: %s", e)

        return self._approximate_tokens(text)

    def _get_tiktoken_encoding(self, model_name: Optional[str] = None):
        """Get tiktoken encoding with caching.

        Falls back to the default encoding when the model-specific lookup
        fails; the fallback is cached under the requested model's key so the
        failing lookup is not repeated.
        """
        cache_key = model_name or 'default'

        cached = self._tiktoken_cache.get(cache_key)
        if cached is not None:
            return cached

        encoding = None
        if model_name:
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except Exception:
                # Model could not be resolved: use the default encoding below.
                encoding = None
        if encoding is None:
            encoding = tiktoken.get_encoding(self._DEFAULT_ENCODING)

        self._tiktoken_cache[cache_key] = encoding
        return encoding

    def _approximate_tokens(self, text: str) -> int:
        """Approximate token count (1 token ≈ 4 characters for most models).

        Returns 0 for empty or None text so the fallback path cannot raise.
        """
        if not text:
            return 0
        return len(text) // self._CHARS_PER_TOKEN


class CostCalculator:
    """Utility class for calculating API costs across providers.

    Each pricing table maps a model name to its 'input' and 'output' price in
    USD per 1,000 tokens (the inline comments quote the equivalent price per
    1M tokens). Every table also carries a 'default' entry that is used when
    a model name cannot be matched.
    """

    OPENAI_PRICING = {
        'gpt-5': {'input': 0.00125, 'output': 0.01},  # $1.25 / 1M tokens input, $10.00 / 1M tokens output
        'gpt-5-mini': {'input': 0.00025, 'output': 0.002},  # $0.25 / 1M tokens input, $2.00 / 1M tokens output
        'gpt-4o': {'input': 0.0025, 'output': 0.01},  # $2.50 / 1M tokens input, $10.00 / 1M tokens output
        'gpt-4o-mini': {'input': 0.00015, 'output': 0.0006},
        'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
        'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
        'o3': {'input': 0.002, 'output': 0.008},  # $2.00 / 1M tokens input, $8.00 / 1M tokens output
        'o3-pro': {'input': 0.02, 'output': 0.08},  # $20.00 / 1M tokens input, $80.00 / 1M tokens output
        'o3-mini': {'input': 0.003, 'output': 0.012},
        'o1': {'input': 0.015, 'output': 0.06},  # $15.00 / 1M tokens input, $60.00 / 1M tokens output
        'o1-pro': {'input': 0.15, 'output': 0.6},  # $150.00 / 1M tokens input, $600.00 / 1M tokens output
        'o1-mini': {'input': 0.003, 'output': 0.012},
        'default': {'input': 0.01, 'output': 0.03}
    }

    GOOGLE_PRICING = {
        'gemini-2.5-pro': {'input': 0.00125, 'output': 0.01},  # $1.25 / 1M tokens input, $10.00 / 1M tokens output
        'gemini-2.5-flash': {'input': 0.0003, 'output': 0.0025},  # $0.30 / 1M tokens input, $2.50 / 1M tokens output
        'gemini-1.5-pro': {'input': 0.00125, 'output': 0.005},
        'gemini-1.5-flash': {'input': 0.00025, 'output': 0.001},
        'default': {'input': 0.001, 'output': 0.003}
    }

    ANTHROPIC_PRICING = {
        'claude-opus-4-1': {'input': 0.015, 'output': 0.075},  # $15.00 / 1M tokens input, $75.00 / 1M tokens output
        'claude-opus-4': {'input': 0.015, 'output': 0.075},
        'claude-sonnet-4': {'input': 0.003, 'output': 0.015},  # $3.00 / 1M tokens input, $15.00 / 1M tokens output
        'claude-sonnet-3.7': {'input': 0.003, 'output': 0.015},
        'claude-sonnet-3.5': {'input': 0.003, 'output': 0.015},
        'claude-haiku-3': {'input': 0.00025, 'output': 0.00125},  # $0.25 / 1M tokens input, $1.25 / 1M tokens output
        'claude-3-5-haiku': {'input': 0.0008, 'output': 0.004},
        'claude-haiku-3.5': {'input': 0.0008, 'output': 0.004},
        'default': {'input': 0.003, 'output': 0.015}
    }

    @classmethod
    def calculate_openai_cost(cls, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate cost in USD for OpenAI models"""
        return cls._calculate_cost(cls.OPENAI_PRICING, model_name, input_tokens, output_tokens)

    @classmethod
    def calculate_google_cost(cls, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate cost in USD for Google models"""
        return cls._calculate_cost(cls.GOOGLE_PRICING, model_name, input_tokens, output_tokens)

    @classmethod
    def calculate_anthropic_cost(cls, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate cost in USD for Anthropic Claude models"""
        return cls._calculate_cost(cls.ANTHROPIC_PRICING, model_name, input_tokens, output_tokens)

    @classmethod
    def _calculate_cost(cls, pricing_table: Dict, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """Generic cost calculation.

        Args:
            pricing_table: One of the ``*_PRICING`` tables (prices per 1K tokens)
            model_name: Model identifier used to select the price entry
            input_tokens: Number of prompt tokens
            output_tokens: Number of completion tokens

        Returns:
            Combined input and output cost
        """
        pricing = cls._get_pricing_for_model(pricing_table, model_name)
        # Table prices are per 1,000 tokens, hence the division.
        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (output_tokens / 1000) * pricing['output']
        return input_cost + output_cost

    @classmethod
    def _get_pricing_for_model(cls, pricing_table: Dict, model_name: str) -> Dict[str, float]:
        """Get pricing for a specific model with fallback logic.

        Resolution order: exact match, case-insensitive exact match, longest
        matching key prefix, then the table's 'default' entry. A missing or
        empty model name resolves to 'default'.
        """
        if not model_name:
            return pricing_table['default']

        # Exact match
        if model_name in pricing_table:
            return pricing_table[model_name]

        model_lower = model_name.lower()
        if model_lower in pricing_table:
            return pricing_table[model_lower]

        # Prefix match: several keys can be prefixes of the same name (e.g.
        # 'gpt-4o' and 'gpt-4o-mini' for 'gpt-4o-mini-2024-07-18'), so take
        # the longest one instead of whichever appears first in the table.
        matching_keys = [
            model_key for model_key in pricing_table
            if model_key != 'default' and model_lower.startswith(model_key)
        ]
        if matching_keys:
            return pricing_table[max(matching_keys, key=len)]

        # Default pricing
        return pricing_table['default']


class RequestLogger:
    """Utility class for logging API requests consistently.

    All output is emitted at DEBUG level on a logger named
    ``"<provider_name>RequestLogger"``.
    """

    def __init__(self, provider_name: str):
        self.logger = logging.getLogger(f"{provider_name}RequestLogger")

    def log_request(self, prompt: str, response: str, metadata: Dict[str, Any]):
        """Log API request details.

        Args:
            prompt: Prompt sent to the provider; only its length is logged
            response: Response received; only its length is logged
            metadata: Request details read with the keys 'model', 'endpoint',
                'input_tokens', 'output_tokens', 'total_tokens', 'cost' and
                'time_taken'; missing keys fall back to 'unknown' or 0
        """
        # Skip all formatting work when DEBUG is off. This also keeps
        # malformed metadata from raising inside a call that would not have
        # logged anything anyway.
        if not self.logger.isEnabledFor(logging.DEBUG):
            return

        details = metadata or {}
        debug = self.logger.debug

        debug("Model: %s", details.get('model', 'unknown'))
        debug("Endpoint: %s", details.get('endpoint', 'unknown'))
        debug("Input tokens: %s", details.get('input_tokens', 0))
        debug("Output tokens: %s", details.get('output_tokens', 0))
        debug("Total tokens: %s", details.get('total_tokens', 0))
        debug("Cost: $%.4f", self._as_float(details.get('cost', 0)))
        debug("Time taken: %.2fs", self._as_float(details.get('time_taken', 0)))

        # A missing prompt/response is reported as zero length, not an error.
        debug("Prompt length: %d chars", len(prompt or ""))
        debug("Response length: %d chars", len(response or ""))

    @staticmethod
    def _as_float(value: Any) -> float:
        """Coerce a metadata value to float, using 0.0 when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0