"""
Token tracking utilities for comprehensive monitoring of LLM token usage.

This module provides utilities to track and analyze token usage from LiteLLM responses,
including thinking/reasoning tokens from reasoning models like GPT-5 and o1.
"""

from typing import Dict, Any, List, Optional


class TokenTracker:
    """
    Tracks and analyzes token usage from LiteLLM completion responses.

    Handles both thinking models (GPT-5, o1) that produce reasoning tokens
    and standard models (GPT-4o) that only produce visible completion tokens.

    State:
        turn_metrics: One metrics dict per recorded turn, in insertion order.
        cumulative_tokens: Running totals over all recorded turns.
    """

    def __init__(self) -> None:
        self.turn_metrics: List[Dict[str, Any]] = []
        self.cumulative_tokens: Dict[str, int] = {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_reasoning_tokens": 0,
            "total_visible_tokens": 0,
            "total_tokens": 0,
            "turns_count": 0
        }

    @staticmethod
    def _count(obj: Any, name: str) -> Any:
        """Read numeric attribute ``name`` from ``obj``; missing or None becomes 0."""
        return getattr(obj, name, 0) or 0

    def extract_token_metrics(self, response, turn_number: int, model_name: str) -> Dict[str, Any]:
        """
        Extract comprehensive token metrics from a LiteLLM response.

        Fields that are absent or explicitly ``None`` on the response are
        reported as zero so the result can always be summed safely.

        Args:
            response: LiteLLM completion response object
            turn_number: Current conversation turn number (1-indexed)
            model_name: Name of the model used

        Returns:
            Dictionary containing token metrics for this response. When the
            response is falsy or has no ``usage`` attribute, an all-zero
            metrics dictionary is returned instead.
        """
        if not response or not hasattr(response, 'usage'):
            return self._empty_metrics(turn_number, model_name)

        usage = response.usage

        # Base token counts (None-safe: a None count would break the
        # arithmetic below and the cumulative totals).
        prompt_tokens = self._count(usage, 'prompt_tokens')
        completion_tokens = self._count(usage, 'completion_tokens')
        total_tokens = getattr(usage, 'total_tokens', None)
        if total_tokens is None:
            # Fall back to the sum only when no total was reported at all.
            total_tokens = prompt_tokens + completion_tokens

        # Extract reasoning tokens (thinking models only)
        reasoning_tokens = 0
        completion_details = getattr(usage, 'completion_tokens_details', None)
        if completion_details:
            reasoning_tokens = self._count(completion_details, 'reasoning_tokens')

        # Calculate visible tokens (what user actually sees)
        visible_tokens = completion_tokens - reasoning_tokens

        # Calculate test-time compute ratio
        test_time_compute_ratio = (reasoning_tokens / completion_tokens) if completion_tokens > 0 else 0.0

        # Extract cached tokens (for long-context analysis)
        cached_tokens = 0
        prompt_details = getattr(usage, 'prompt_tokens_details', None)
        if prompt_details:
            cached_tokens = self._count(prompt_details, 'cached_tokens')

        # Extract cost information. Both the hidden-params container and the
        # cost value itself may be None; treat either as "no cost".
        cost = 0.0
        hidden_params = getattr(response, '_hidden_params', None)
        if hidden_params is not None and 'response_cost' in hidden_params:
            cost = hidden_params['response_cost'] or 0.0

        return {
            "turn_number": turn_number,
            "model_name": model_name,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "reasoning_tokens": reasoning_tokens,
            "visible_tokens": visible_tokens,
            "total_tokens": total_tokens,
            "test_time_compute_ratio": test_time_compute_ratio,
            "cached_tokens": cached_tokens,
            "cost": cost,
            "is_thinking_model": reasoning_tokens > 0
        }

    def add_turn_metrics(self, response, turn_number: int, model_name: str) -> Dict[str, Any]:
        """
        Add token metrics for a conversation turn and update cumulative totals.

        Args:
            response: LiteLLM completion response object
            turn_number: Current conversation turn number (1-indexed)
            model_name: Name of the model used

        Returns:
            Token metrics for this turn
        """
        turn_metrics = self.extract_token_metrics(response, turn_number, model_name)

        # Store turn metrics
        self.turn_metrics.append(turn_metrics)

        # Update cumulative totals
        totals = self.cumulative_tokens
        totals["total_prompt_tokens"] += turn_metrics["prompt_tokens"]
        totals["total_completion_tokens"] += turn_metrics["completion_tokens"]
        totals["total_reasoning_tokens"] += turn_metrics["reasoning_tokens"]
        totals["total_visible_tokens"] += turn_metrics["visible_tokens"]
        totals["total_tokens"] += turn_metrics["total_tokens"]
        totals["turns_count"] = len(self.turn_metrics)

        return turn_metrics

    def get_conversation_summary(self) -> Dict[str, Any]:
        """
        Get comprehensive token usage summary for the entire conversation.

        Returns:
            Dictionary containing conversation-level token analytics: the
            cumulative totals, the overall reasoning ratio, total cost, the
            fraction of turns that produced reasoning tokens, a snapshot list
            of the per-turn metrics, and per-turn averages.
        """
        if not self.turn_metrics:
            return self._empty_conversation_summary()

        totals = self.cumulative_tokens
        total_completion = totals["total_completion_tokens"]
        total_reasoning = totals["total_reasoning_tokens"]
        turns_count = totals["turns_count"]

        # Reasoning ratio over the whole conversation (token-weighted)
        avg_reasoning_ratio = (total_reasoning / total_completion) if total_completion > 0 else 0.0

        # Total cost; tolerate turns whose stored cost is missing or None
        total_cost = sum(turn.get("cost", 0) or 0 for turn in self.turn_metrics)

        # Fraction of turns that produced reasoning tokens
        thinking_turns = sum(1 for turn in self.turn_metrics if turn.get("is_thinking_model", False))
        thinking_model_usage = thinking_turns / len(self.turn_metrics)

        def per_turn(value):
            # turns_count is only refreshed by add_turn_metrics, so guard
            # against it still being zero.
            return value / turns_count if turns_count > 0 else 0

        return {
            **totals,
            "avg_reasoning_ratio": avg_reasoning_ratio,
            "total_cost": total_cost,
            "thinking_model_usage": thinking_model_usage,
            # Shallow copy so the summary is a snapshot, consistent with the
            # copied totals, and does not grow as later turns are recorded.
            "per_turn_metrics": list(self.turn_metrics),
            "conversation_efficiency": {
                "tokens_per_turn": per_turn(totals["total_tokens"]),
                "reasoning_tokens_per_turn": per_turn(total_reasoning),
                "visible_tokens_per_turn": per_turn(totals["total_visible_tokens"])
            }
        }

    def get_turn_summary(self, turn_number: int) -> Optional[Dict[str, Any]]:
        """
        Get token metrics for a specific turn.

        Args:
            turn_number: Turn number to retrieve (1-indexed)

        Returns:
            Token metrics for the first recorded turn with that number,
            or None if not found
        """
        for turn in self.turn_metrics:
            if turn["turn_number"] == turn_number:
                return turn
        return None

    def _empty_metrics(self, turn_number: int, model_name: str) -> Dict[str, Any]:
        """Return empty metrics structure for failed responses."""
        return {
            "turn_number": turn_number,
            "model_name": model_name,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "reasoning_tokens": 0,
            "visible_tokens": 0,
            "total_tokens": 0,
            "test_time_compute_ratio": 0.0,
            "cached_tokens": 0,
            "cost": 0.0,
            "is_thinking_model": False
        }

    def _empty_conversation_summary(self) -> Dict[str, Any]:
        """Return empty conversation summary structure."""
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_reasoning_tokens": 0,
            "total_visible_tokens": 0,
            "total_tokens": 0,
            "turns_count": 0,
            "avg_reasoning_ratio": 0.0,
            "total_cost": 0.0,
            "thinking_model_usage": 0.0,
            "per_turn_metrics": [],
            "conversation_efficiency": {
                "tokens_per_turn": 0,
                "reasoning_tokens_per_turn": 0,
                "visible_tokens_per_turn": 0
            }
        }


def extract_token_metrics_from_response(response) -> Dict[str, Any]:
    """
    Extract token metrics from one LiteLLM response without tracking state.

    A throwaway TokenTracker is used purely for its extraction logic; nothing
    is recorded. The metrics are labelled as turn 1 of model "unknown".

    Args:
        response: LiteLLM completion response object

    Returns:
        Dictionary containing token metrics
    """
    return TokenTracker().extract_token_metrics(response, 1, "unknown")