"""
Gemini Token Tracking Callback
Tracks token usage from Gemini API responses
"""

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class GeminiTokenTrackingCallback(BaseCallbackHandler):
    """Callback to track token usage from Gemini API responses.

    Counters accumulate over every ``on_llm_end`` call until ``reset()``:

    - ``input_tokens`` / ``prompt_tokens``: running sum of
      ``prompt_token_count`` (both names always receive the same value).
    - ``output_tokens`` / ``completion_tokens``: running sum of
      ``candidates_token_count`` (both names always receive the same value).
    - ``reasoning_tokens``: running sum of ``total - (prompt + candidates)``
      per usage record, clamped at zero.
    - ``total_tokens``: running sum of ``total_token_count`` as reported.
    """

    def __init__(self):
        super().__init__()
        self.input_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.reasoning_tokens = 0  # For Gemini: reasoning = total - (prompt + candidates)
        self.total_tokens = 0

    def on_llm_end(self, *args, **kwargs):
        """Track token usage when the LLM finishes.

        The response is taken from the first positional argument or, failing
        that, from the ``response`` keyword argument. Usage is looked up in:

        1. ``llm_output['response'].response_metadata['gemini_response']``
           and every ``generation.message.response_metadata['gemini_response']``
           when the response is an ``LLMResult``;
        2. ``response.response_metadata['gemini_response']`` otherwise;
        3. ``response.usage_metadata`` directly, when it exposes
           ``prompt_token_count``.

        Several of these locations may point at the very same Gemini response
        or usage object, so each object is counted at most once per call.
        """
        response = args[0] if args else kwargs.get('response')
        if response is None:
            return

        # Objects already counted during this call, keyed by id(). The values
        # keep the objects alive so an id cannot be recycled mid-call.
        seen = {}

        if isinstance(response, LLMResult):
            llm_output = getattr(response, 'llm_output', None)
            if llm_output and 'response' in llm_output:
                gemini_response = self._gemini_response_from(llm_output['response'])
                if gemini_response:
                    self._extract_gemini_tokens(gemini_response, seen)

            for gen_list in (getattr(response, 'generations', None) or []):
                for gen in gen_list:
                    gemini_response = self._gemini_response_from(
                        getattr(gen, 'message', None)
                    )
                    if gemini_response:
                        self._extract_gemini_tokens(gemini_response, seen)
        else:
            # Direct invocation result (e.g. a custom wrapper response).
            gemini_response = self._gemini_response_from(response)
            if gemini_response:
                self._extract_gemini_tokens(gemini_response, seen)

        # The response itself may carry Gemini-style usage metadata. Only
        # metadata exposing ``prompt_token_count`` is treated as such.
        usage = getattr(response, 'usage_metadata', None)
        if usage and (
            hasattr(usage, 'prompt_token_count')
            or (isinstance(usage, dict) and 'prompt_token_count' in usage)
        ):
            self._accumulate_usage(usage, seen)

    @staticmethod
    def _gemini_response_from(holder):
        """Return ``holder.response_metadata['gemini_response']`` or ``None``.

        ``None`` is returned when ``holder`` has no ``response_metadata``
        attribute or when that attribute is empty/``None``.
        """
        metadata = getattr(holder, 'response_metadata', None)
        if not metadata:
            return None
        return metadata.get('gemini_response')

    @staticmethod
    def _first_sighting(obj, seen):
        """Record ``obj`` in ``seen``; return ``False`` if it was already there.

        Comparison is by object identity. With ``seen=None`` no tracking is
        done and every object counts as new.
        """
        if seen is None:
            return True
        key = id(obj)
        if key in seen:
            return False
        seen[key] = obj
        return True

    @staticmethod
    def _token_count(usage, field):
        """Read ``field`` from a dict or object; missing or ``None`` becomes 0."""
        if isinstance(usage, dict):
            value = usage.get(field)
        else:
            value = getattr(usage, field, None)
        return value or 0

    def _accumulate_usage(self, usage_metadata, seen=None):
        """Add one usage record (object or dict) to the running counters.

        Args:
            usage_metadata: Object or dict with ``prompt_token_count``,
                ``candidates_token_count`` and ``total_token_count``.
            seen: Optional per-call registry used to skip a usage object that
                has already been counted.
        """
        if not self._first_sighting(usage_metadata, seen):
            return

        prompt_tokens = self._token_count(usage_metadata, 'prompt_token_count')
        candidates_tokens = self._token_count(usage_metadata, 'candidates_token_count')
        total_tokens = self._token_count(usage_metadata, 'total_token_count')

        # Reasoning tokens are whatever the total contains beyond prompt and
        # candidates; never let an inconsistent record drive it negative.
        reasoning_tokens = max(total_tokens - (prompt_tokens + candidates_tokens), 0)

        self.prompt_tokens += prompt_tokens
        self.completion_tokens += candidates_tokens
        self.input_tokens += prompt_tokens
        self.output_tokens += candidates_tokens
        self.reasoning_tokens += reasoning_tokens
        self.total_tokens += total_tokens

    def _extract_gemini_tokens(self, gemini_response, _seen=None):
        """
        Extract token counts from a Gemini response and accumulate them.

        Gemini structure:
        - usage_metadata.prompt_token_count: Input tokens
        - usage_metadata.candidates_token_count: Output tokens
        - usage_metadata.total_token_count: Total tokens
        - reasoning_tokens = total - (prompt + candidates)

        Args:
            gemini_response: Object with a ``usage_metadata`` attribute, or a
                dict with a ``'usage_metadata'`` key. Anything else is ignored.
            _seen: Optional per-call registry (see ``on_llm_end``) used to
                avoid counting the same response or usage object twice.
        """
        if not self._first_sighting(gemini_response, _seen):
            return

        # Handle both object and dict representations
        usage_metadata = None
        if hasattr(gemini_response, 'usage_metadata'):
            usage_metadata = gemini_response.usage_metadata
        elif isinstance(gemini_response, dict) and 'usage_metadata' in gemini_response:
            usage_metadata = gemini_response['usage_metadata']

        if usage_metadata:
            self._accumulate_usage(usage_metadata, _seen)

    def get_stats(self):
        """Get token statistics.

        Returns:
            dict with ``input_tokens``, ``output_tokens``, ``prompt_tokens``,
            ``completion_tokens``, ``reasoning_tokens`` and ``total_tokens``.
            ``total_tokens`` is recomputed here as input + output + reasoning
            rather than read from ``self.total_tokens``.
        """
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "reasoning_tokens": self.reasoning_tokens,
            "total_tokens": self.input_tokens + self.output_tokens + self.reasoning_tokens
        }

    def reset(self):
        """Reset all counters to zero."""
        self.input_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.reasoning_tokens = 0
        self.total_tokens = 0

