"""
Unified Google Gemini API handler
"""

import time
import logging
from typing import Dict, Any, Tuple, Optional
import google.generativeai as genai
from .base_provider import BaseProvider
from .utils import TokenCounter, CostCalculator, RequestLogger


class GoogleHandler(BaseProvider):
    """Handler for all Google Gemini models.

    Wraps a ``genai.GenerativeModel`` and exposes ``invoke`` / ``count_tokens``.
    The attributes ``api_key``, ``model_name``, ``temperature``, ``max_tokens``
    and ``logger`` are read from ``self`` but not set in this class; they are
    presumably provided by ``BaseProvider.__init__`` (confirm against the base
    class).
    """

    # Safety settings sent with every request. Every listed harm category uses
    # the BLOCK_NONE threshold.
    SAFETY_SETTINGS = [
        {
            "category": "HARM_CATEGORY_HARASSMENT",
            "threshold": "BLOCK_NONE"
        },
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_NONE"
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_NONE"
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_NONE"
        }
    ]

    # Sampling defaults used when neither the config nor the caller gives a value.
    DEFAULT_TOP_P = 1.0
    DEFAULT_TOP_K = 1

    # Numeric finish_reason codes this handler reacts to. In the Gemini API
    # enum 2 is MAX_TOKENS and 3/4/5 are SAFETY/RECITATION/OTHER; all three of
    # the latter are reported with the same "safety filters" message.
    _FINISH_REASON_MAX_TOKENS = 2
    _FINISH_REASONS_BLOCKED = (3, 4, 5)

    def __init__(self, config: Dict[str, Any]):
        """Initialize the Google handler.

        Args:
            config: Provider configuration. Besides whatever the base class
                consumes, the optional keys ``top_p`` and ``top_k`` are read
                here.
        """
        super().__init__(config)

        # Configure the API key (module-global setting of the genai SDK)
        genai.configure(api_key=self.api_key)

        # Remember the configured sampling parameters so that invoke() can use
        # them as per-call defaults instead of silently resetting them.
        self._default_top_p = config.get('top_p', self.DEFAULT_TOP_P)
        self._default_top_k = config.get('top_k', self.DEFAULT_TOP_K)

        # Initialize the model
        self.model = genai.GenerativeModel(
            model_name=self.model_name,
            generation_config={
                'temperature': self.temperature,
                'top_p': self._default_top_p,
                'top_k': self._default_top_k,
                'max_output_tokens': self.max_tokens,
            },
            safety_settings=self.SAFETY_SETTINGS
        )

        # Initialize utilities
        self.token_counter = TokenCounter()
        self.request_logger = RequestLogger('Google')

    def invoke(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> Tuple[str, Dict]:
        """
        Invoke Google Gemini model

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt; when given it is prepended
                to ``prompt`` separated by a blank line
            **kwargs: Per-call overrides. Recognised keys: ``temperature``,
                ``top_p``, ``top_k`` and ``max_tokens``; anything else is ignored.

        Returns:
            Tuple of (response_text, metadata). ``metadata`` holds
            ``input_tokens``, ``output_tokens``, ``total_tokens``, ``cost``,
            ``model``, ``endpoint`` and ``time_taken`` (seconds).

        Raises:
            ValueError: If the response contains no usable text (token limit
                reached, blocked, or otherwise empty).
            Exception: Any error raised by the API call is logged and re-raised.
        """
        start_time = time.time()

        # Combine system prompt and user prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"
        else:
            full_prompt = prompt

        # Per-call generation config: explicit kwargs win, otherwise fall back
        # to the values this handler was configured with.
        generation_config = {
            'temperature': kwargs.get('temperature', self.temperature),
            'top_p': kwargs.get('top_p', self._default_top_p),
            'top_k': kwargs.get('top_k', self._default_top_k),
            'max_output_tokens': kwargs.get('max_tokens', self.max_tokens),
        }

        try:
            # Make API call
            response = self.model.generate_content(
                full_prompt,
                generation_config=generation_config,
                safety_settings=self.SAFETY_SETTINGS
            )

            response_text = self._extract_text(response, generation_config)

            # Calculate tokens (Gemini provides usage metadata)
            if hasattr(response, 'usage_metadata'):
                input_tokens = response.usage_metadata.prompt_token_count
                output_tokens = response.usage_metadata.candidates_token_count
                total_tokens = response.usage_metadata.total_token_count
            else:
                # Fallback to approximation
                input_tokens = self.token_counter.count_tokens_google(full_prompt, self.model)
                output_tokens = self.token_counter.count_tokens_google(response_text, self.model)
                total_tokens = input_tokens + output_tokens

            # Calculate cost
            cost = CostCalculator.calculate_google_cost(self.model_name, input_tokens, output_tokens)

            # Build metadata
            metadata = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'cost': cost,
                'model': self.model_name,
                'endpoint': 'generateContent',
                'time_taken': time.time() - start_time
            }

            # Log the request
            self.request_logger.log_request(full_prompt, response_text, metadata)

            return response_text, metadata

        except Exception as e:
            self.logger.error(f"Gemini API call failed: {e}")
            raise

    def _extract_text(self, response: Any, generation_config: Dict[str, Any]) -> str:
        """Return the response text, or raise ValueError if there is none."""
        candidates = response.candidates
        content = candidates[0].content if candidates else None

        # Require actual parts, not just a content object: a candidate whose
        # content has no parts would otherwise reach `response.text` and lose
        # the finish-reason diagnostics produced below.
        if content and content.parts:
            return response.text

        self._raise_for_empty_response(response, generation_config)

    def _raise_for_empty_response(self, response: Any, generation_config: Dict[str, Any]) -> None:
        """Log and raise a ValueError explaining why the response has no text."""
        candidates = response.candidates
        if candidates:
            finish_reason = candidates[0].finish_reason
            safety_ratings = candidates[0].safety_ratings
        else:
            finish_reason = "UNKNOWN"
            safety_ratings = []

        # Enum members expose `.name`; plain values fall back to their string form.
        reason_name = getattr(finish_reason, 'name', None) or str(finish_reason)

        if (finish_reason == self._FINISH_REASON_MAX_TOKENS
                or reason_name in ("MAX_TOKENS", "STOP_REASON_MAX_TOKENS")):
            # No parts were returned, so there is no partial text to salvage.
            error_msg = (f"Response exceeded max token limit (finish_reason={finish_reason}). "
                         f"Current max_tokens: {generation_config.get('max_output_tokens', self.max_tokens)}. "
                         f"Consider increasing max_tokens in config.yaml")
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        if finish_reason in self._FINISH_REASONS_BLOCKED or "SAFETY" in reason_name:
            error_msg = f"Response blocked by Gemini safety filters (finish_reason={finish_reason})"
            # Compare by name: the probability is an enum, so a direct
            # comparison with the string "NEGLIGIBLE" would never be equal.
            blocked_categories = [
                r.category for r in safety_ratings
                if getattr(r.probability, 'name', r.probability) != "NEGLIGIBLE"
            ]
            if blocked_categories:
                error_msg += f" - Blocked categories: {blocked_categories}"
            self.logger.warning(error_msg)
            raise ValueError(error_msg)

        error_msg = f"Gemini response error (finish_reason={finish_reason})"
        if not candidates:
            # With no candidates the prompt itself may have been rejected;
            # surface the block reason when the response carries one.
            prompt_feedback = getattr(response, 'prompt_feedback', None)
            block_reason = getattr(prompt_feedback, 'block_reason', None)
            if block_reason:
                error_msg += f" - prompt block_reason: {block_reason}"
        self.logger.warning(error_msg)
        raise ValueError(error_msg)

    def count_tokens(self, text: str) -> int:
        """Count tokens for this provider"""
        return self.token_counter.count_tokens_google(text, self.model)