"""
Model handling utilities for different LLM providers.
Supports OpenAI, Anthropic, and Google Gemini models with both LiteLLM and direct API access.
"""

import os
import json
import requests
import litellm
from anthropic import Anthropic
from openai import OpenAI
from typing import List, Dict, Any, Optional, Tuple


class ModelConfig:
    """Handles model configuration and provider detection for LiteLLM and direct clients.

    Every member is a static method; the class only groups the helpers.
    Provider names used here are ``"openai"``, ``"anthropic"``, ``"google"``
    and ``"vllm"``.
    """

    @staticmethod
    def detect_provider(model_name: str) -> str:
        """Detect the provider based on model name.

        An explicit ``openai/``, ``anthropic/`` or ``google/`` prefix wins.
        Otherwise the lower-cased name is matched against known patterns;
        anything unrecognised falls back to ``"openai"``.

        Args:
            model_name: Model identifier, optionally prefixed with
                ``<provider>/``.

        Returns:
            One of ``"openai"``, ``"anthropic"``, ``"google"`` or ``"vllm"``.
        """
        model_lower = model_name.lower()

        # Handle explicit provider prefixes first. Unknown prefixes (for
        # example an organisation name) fall through to auto-detection.
        if "/" in model_lower:
            provider_prefix = model_lower.split("/", 1)[0]
            if provider_prefix in ("openai", "anthropic", "google"):
                return provider_prefix

        # Auto-detect based on model name patterns
        if model_lower.startswith("gpt-"):
            return "openai"
        if model_lower.startswith(("claude", "anthropic")):
            return "anthropic"
        if model_lower.startswith(("gemini", "google")):
            return "google"
        if (
            model_lower.startswith(("vllm", "qwen", "llama", "hermes", "mistral"))
            or "/qwen" in model_lower
            or "qwen/" in model_lower
        ):
            return "vllm"

        # Default to OpenAI for backward compatibility
        return "openai"

    @staticmethod
    def requires_direct_client(model_name: str) -> bool:
        """Determine if a model requires direct client instead of LiteLLM.

        Returns True when the name contains ``anthropic.claude`` and
        ``ANTHROPIC_BASE_URL`` contains ``floodgate``, or when the name starts
        with ``gemini`` and ``GOOGLE_BASE_URL`` contains ``floodgate``.
        """
        model_lower = model_name.lower()
        if "anthropic.claude" in model_lower and "floodgate" in os.environ.get("ANTHROPIC_BASE_URL", ""):
            return True
        if model_lower.startswith("gemini") and "floodgate" in os.environ.get("GOOGLE_BASE_URL", ""):
            return True
        # Add other conditions here as needed
        return False

    @staticmethod
    def validate_provider_keys(provider: str, model_name: Optional[str] = None) -> None:
        """Validate that required API keys are available for the provider.

        Args:
            provider: Provider name as returned by ``detect_provider``.
                Unknown providers are not validated.
            model_name: Optional model name; when given, Anthropic and Google
                models that go through a floodgate base URL with a direct
                client skip the API key check.

        Raises:
            ValueError: If a required environment variable is unset or empty.
        """
        required_keys = {
            "openai": ["OPENAI_API_KEY"],
            "anthropic": ["ANTHROPIC_API_KEY"],
            "google": ["GOOGLE_API_KEY"],
            "vllm": [],  # No API key required for local vLLM
        }
        # Providers that may authenticate through a floodgate base URL
        # instead of an API key, and the variable holding that URL.
        base_url_vars = {
            "anthropic": "ANTHROPIC_BASE_URL",
            "google": "GOOGLE_BASE_URL",
        }

        base_url_var = base_url_vars.get(provider)
        if base_url_var and model_name:
            base_url = os.environ.get(base_url_var, "")
            if "floodgate" in base_url and ModelConfig.requires_direct_client(model_name):
                return  # Skip validation for floodgate setup

        # A variable that is set but empty is as unusable as a missing one,
        # so report it here rather than letting the first request fail.
        missing_keys = [key for key in required_keys.get(provider, []) if not os.environ.get(key)]
        if missing_keys:
            raise ValueError(f"Missing required environment variables for {provider}: {missing_keys}")

    @staticmethod
    def get_model_params(model_name: str, provider: str) -> Dict[str, Any]:
        """Get provider-specific model parameters.

        Returns:
            A new dict containing ``model`` and ``temperature``; for the
            ``"vllm"`` provider it also carries ``top_p``, ``max_tokens`` and
            an ``extra_body`` with sampling/template options.
        """
        if model_name in ("gpt-5", "openai/gpt-5"):
            return {
                "model": model_name,
                "temperature": 1,
            }

        if provider == "vllm":
            # Use slightly higher temperature for Qwen models for more natural responses
            temperature = 0.3 if "qwen" in model_name.lower() else 0.0
            return {
                "model": model_name,
                "temperature": temperature,
                "top_p": 0.8,
                "max_tokens": 8192,
                "extra_body": {
                    "repetition_penalty": 1.05,
                    "chat_template_kwargs": {"enable_thinking": True},
                },
            }

        return {
            "model": model_name,
            "temperature": 0.0,
        }

    @staticmethod
    def create_anthropic_client() -> "Anthropic":
        """Create an Anthropic client, prioritizing API key over base_url setup.

        When ``ANTHROPIC_API_KEY`` is unset or empty and ``ANTHROPIC_BASE_URL``
        contains ``floodgate``, two warnings are printed because the internal
        authentication path is not available in this code. In every case
        without a key the client is built with ``api_key=None``.
        """
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if api_key:
            return Anthropic(api_key=api_key)

        # If no API key, check if we have a floodgate setup
        base_url = os.environ.get("ANTHROPIC_BASE_URL")
        if base_url and "floodgate" in base_url:
            print("Warning: No ANTHROPIC_API_KEY found, defaulting to base_url setup")
            # The internal token-based authentication was removed for
            # anonymization, so this path can only warn and fall through.
            print(
                "Warning: Internal auth not available, falling back to API key: "
                "Internal auth method removed for anonymization"
            )

        # Final fallback to API key (None or empty at this point)
        return Anthropic(api_key=api_key)


class GeminiClient:
    """Direct client for Google Gemini models via floodgate or standard API.

    Accepts OpenAI-style messages and tools, translates them to the Gemini
    ``generateContent`` request format, and wraps the reply in a small object
    exposing ``choices[0].message`` and ``usage`` like an OpenAI response.
    """

    def __init__(self, model_name: str):
        """Resolve the endpoint URL and credential for ``model_name``.

        Reads ``GOOGLE_BASE_URL``; when it contains ``floodgate`` a bearer
        token is requested, otherwise ``GOOGLE_API_KEY`` is used.

        Raises:
            ValueError: If the floodgate token cannot be obtained.
        """
        self.model_name = model_name
        self.base_url = os.environ.get("GOOGLE_BASE_URL", "")

        if "floodgate" in self.base_url:
            # Use floodgate setup
            self.token = self._get_floodgate_token()
            self.url = f"{self.base_url}/v1/publishers/google/models/{model_name}:generateContent"
        else:
            # Standard Google API setup
            self.token = os.environ.get("GOOGLE_API_KEY")
            # The model name must be interpolated; without the f-prefix the
            # literal text "{model_name}" ended up in the request path.
            # NOTE(review): this non-floodgate branch still targets a
            # floodgate host - confirm the intended endpoint.
            self.url = (
                "https://floodgate.g.apple.com/api/gemini/v1/publishers/google/models/"
                f"{model_name}:generateContent"
            )

    def _get_floodgate_token(self) -> str:
        """Get floodgate authentication token.

        Raises:
            ValueError: Always, in this code: the internal authentication
                call was removed for anonymization.
        """
        try:
            # Internal authentication command removed for anonymization
            raise Exception("Internal auth method removed for anonymization")
        except Exception as e:
            raise ValueError(f"Failed to get floodgate auth token: {e}") from e

    def _convert_openai_tools_to_gemini(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert OpenAI tool format to Gemini function declarations format.

        Only entries whose ``type`` is ``"function"`` are converted. Returns
        an empty list when there is nothing to declare, otherwise a
        one-element list holding all ``function_declarations``.
        """
        if not tools:
            return []

        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]

        if not function_declarations:
            return []
        return [{"function_declarations": function_declarations}]

    @staticmethod
    def _parse_tool_arguments(arguments: Any) -> Any:
        """Decode a tool call's ``arguments`` into the value sent as ``args``.

        JSON strings are decoded; a blank string or ``None`` becomes ``{}``;
        anything else (for example an already-decoded dict) is passed through.
        """
        if isinstance(arguments, str):
            return json.loads(arguments) if arguments.strip() else {}
        return {} if arguments is None else arguments

    def _convert_messages_to_gemini(self, messages: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]:
        """Convert standard message format to Gemini format.

        Args:
            messages: OpenAI-style messages with roles ``system``, ``user``,
                ``assistant`` and ``tool``; other roles are ignored.

        Returns:
            ``(system_message, contents)``. The system message is the content
            of the last ``system`` entry (empty string if none). Assistant
            messages become ``model`` entries, and each run of consecutive
            ``tool`` messages becomes a single ``function`` entry.
        """
        system_message = ""
        gemini_contents = []
        # tool_call_id -> function name, learned from assistant tool calls so
        # that tool replies without a "name" field can still be attributed.
        call_names = {}
        previous_role = None

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "system":
                system_message = content
            elif role == "user":
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})
            elif role == "assistant":
                tool_calls = msg.get("tool_calls")
                if tool_calls:
                    # Text is optional when the assistant only called tools.
                    parts = [{"text": content}] if content else []
                    for tool_call in tool_calls:
                        function = tool_call["function"]
                        call_id = tool_call.get("id")
                        if call_id:
                            call_names[call_id] = function["name"]
                        parts.append({
                            "functionCall": {
                                "name": function["name"],
                                "args": self._parse_tool_arguments(function.get("arguments")),
                            }
                        })
                else:
                    parts = [{"text": content}]
                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                name = msg.get("name") or call_names.get(msg.get("tool_call_id"), "")
                part = {
                    "functionResponse": {
                        "name": name,
                        "response": {"result": msg.get("content", "")},
                    }
                }
                if previous_role == "tool":
                    # Consecutive tool responses share one "function" entry.
                    gemini_contents[-1]["parts"].append(part)
                else:
                    gemini_contents.append({"role": "function", "parts": [part]})

            previous_role = role

        return system_message, gemini_contents

    def _convert_gemini_response_to_openai_format(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert Gemini response to OpenAI-compatible format.

        Only the first candidate is used. Text parts are concatenated into
        ``content``; each ``functionCall`` part becomes a tool call with a
        positional id (``call_0``, ``call_1``, ...) and JSON-encoded arguments.

        Returns:
            A dict with ``role``, ``content`` and ``tool_calls`` (``None``
            when the model called no function).
        """
        candidates = response_data.get("candidates", [])
        if not candidates:
            return {"role": "assistant", "content": "No response generated", "tool_calls": None}

        text_chunks = []
        tool_calls = []
        for part in candidates[0].get("content", {}).get("parts", []):
            if "text" in part:
                text_chunks.append(part["text"])
            elif "functionCall" in part:
                func_call = part["functionCall"]
                tool_calls.append({
                    "id": f"call_{len(tool_calls)}",
                    "type": "function",
                    "function": {
                        "name": func_call["name"],
                        "arguments": json.dumps(func_call.get("args", {})),
                    },
                })

        return {
            "role": "assistant",
            "content": "".join(text_chunks),  # Keep empty string instead of converting to None
            "tool_calls": tool_calls or None,
        }

    def _request_target(self) -> Tuple[str, Dict[str, str]]:
        """Return the ``(url, headers)`` pair for one request.

        The URL is built fresh on every call so that ``self.url`` is never
        modified; appending the key in place made every call after the first
        send a URL with repeated ``?key=`` suffixes.
        """
        headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'API-Example',
        }
        if "floodgate" in self.base_url:
            headers['Authorization'] = f'Bearer {self.token}'
            return self.url, headers
        # For standard Google API, the key travels in the query string
        return f"{self.url}?key={self.token}", headers

    def create_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        response_format: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> Any:
        """Create a completion using Gemini API.

        Args:
            messages: OpenAI-style conversation messages.
            tools: Optional OpenAI-style tool definitions.
            response_format: Accepted for interface compatibility but not
                sent; JSON output is left to the prompt to avoid conflicts
                with tool calling.
            **kwargs: ``temperature`` (default 0.0) and ``max_tokens``
                (default 4096) are honoured; other keys are ignored.

        Returns:
            An object with ``choices[0].message`` (attributes ``role``,
            ``content``, ``tool_calls``) and ``usage`` (``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``).

        Raises:
            RuntimeError: If the API answers with a status other than 200.
        """
        system_message, gemini_contents = self._convert_messages_to_gemini(messages)
        gemini_tools = self._convert_openai_tools_to_gemini(tools) if tools else []

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", 0.0),
                "maxOutputTokens": kwargs.get("max_tokens", 4096),  # Increased limit for longer responses
                "stopSequences": ["\n\n\n\n"],  # Stop excessive newlines
            },
        }
        if system_message:
            request_body["systemInstruction"] = {"parts": [{"text": system_message}]}
        if gemini_tools:
            request_body["tools"] = gemini_tools

        url, headers = self._request_target()
        response = requests.post(url, headers=headers, json=request_body)
        if response.status_code != 200:
            raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")

        response_data = response.json()
        converted_response = self._convert_gemini_response_to_openai_format(response_data)

        # Build throwaway classes whose attributes mirror the OpenAI response
        # shape, so callers can use attribute access on the result.
        usage_metadata = response_data.get("usageMetadata", {})
        mock_response = type('MockResponse', (), {
            'choices': [type('Choice', (), {
                'message': type('Message', (), converted_response)()
            })()],
            'usage': type('Usage', (), {
                'prompt_tokens': usage_metadata.get('promptTokenCount', 0),
                'completion_tokens': usage_metadata.get('candidatesTokenCount', 0),
                'total_tokens': usage_metadata.get('totalTokenCount', 0)
            })()
        })()

        return mock_response


class ModelHandler:
    """Unified model handler that can work with different providers.

    Requests are routed to one of three backends: a direct OpenAI-compatible
    client for vLLM models, the direct Gemini client when one was created, or
    LiteLLM for everything else. Every path returns the same
    ``(response_message, usage_stats)`` pair.
    """

    def __init__(self, model_name: str, **model_kwargs):
        """Detect the provider, validate credentials and create clients.

        Args:
            model_name: Model identifier passed to the backend.
            **model_kwargs: Overrides merged on top of the provider's default
                model parameters.

        Raises:
            ValueError: If credential validation for the provider fails.
        """
        self.model_name = model_name
        self.provider = ModelConfig.detect_provider(model_name)
        self.use_direct_client = ModelConfig.requires_direct_client(model_name)

        # Validate provider keys
        ModelConfig.validate_provider_keys(self.provider, model_name)

        # Get model parameters
        self.model_params = ModelConfig.get_model_params(model_name, self.provider)
        self.model_params.update(model_kwargs)

        # Initialize clients
        self.anthropic_client = None
        self.gemini_client = None
        self.vllm_client = None

        if self.use_direct_client:
            if self.provider == "anthropic":
                self.anthropic_client = ModelConfig.create_anthropic_client()
            elif self.provider == "google":
                self.gemini_client = GeminiClient(model_name)

        # Always use direct OpenAI client for vLLM models
        if self.provider == "vllm":
            self.vllm_client = OpenAI(
                api_key="EMPTY",
                base_url="http://localhost:8000/v1"
            )

    @staticmethod
    def _usage_stats(response: Any) -> Dict[str, Any]:
        """Read token counts from ``response.usage``, defaulting each to 0."""
        usage = getattr(response, "usage", None)
        return {
            field: getattr(usage, field, 0) if usage is not None else 0
            for field in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

    def _default_params(self) -> Dict[str, Any]:
        """Return the stored model parameters without the ``model`` entry."""
        return {k: v for k, v in self.model_params.items() if k != "model"}

    def _complete_with_vllm(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        kwargs: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the request through the direct OpenAI-compatible vLLM client."""
        defaults = self._default_params()
        default_extra_body = defaults.pop("extra_body", None) or {}

        completion_kwargs = {
            "model": self.model_name,
            "messages": messages,
            **defaults,
            **kwargs,
        }

        if tools:
            completion_kwargs["tools"] = tools
            completion_kwargs["tool_choice"] = "auto"

        # Merge extra_body so per-call values win over the defaults instead
        # of being silently replaced by them.
        extra_body = {**default_extra_body, **(kwargs.get("extra_body") or {})}
        if extra_body:
            completion_kwargs["extra_body"] = extra_body

        # Note: Some vLLM servers may not support response_format, so we rely on prompt instructions

        response = self.vllm_client.chat.completions.create(**completion_kwargs)

        message = response.choices[0].message
        content = message.content

        # Clean up thinking tags for Qwen models to improve JSON parsing
        # Note: We still count all tokens (including thinking) in usage stats
        if "qwen" in self.model_name.lower() and content:
            import re
            # Remove <think>...</think> tags but keep the actual response
            cleaned_content = re.sub(r'<think>.*?</think>\s*', '', content, flags=re.DOTALL).strip()
            if cleaned_content:  # Only use cleaned content if it's not empty
                content = cleaned_content

        # NOTE(review): tool_calls is passed through as returned by the
        # client, while the LiteLLM path returns model_dump() output -
        # confirm callers handle both shapes.
        response_message = {
            "role": message.role,
            "content": content,
            "tool_calls": getattr(message, 'tool_calls', None)
        }
        return response_message, self._usage_stats(response)

    def _complete_with_gemini(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        kwargs: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the request through the direct Gemini client."""
        # Merge into one dict first: unpacking the defaults and kwargs
        # separately into the call raises TypeError as soon as a per-call
        # argument (e.g. temperature) overrides a default.
        call_kwargs = {**self._default_params(), **kwargs}
        response = self.gemini_client.create_completion(
            messages=messages,
            tools=tools,
            **call_kwargs
        )

        if hasattr(response, 'choices') and response.choices:
            message = response.choices[0].message
            response_message = {
                "role": message.role,
                "content": message.content,
                "tool_calls": getattr(message, 'tool_calls', None)
            }
        else:
            response_message = {"role": "assistant", "content": "No response", "tool_calls": None}

        return response_message, self._usage_stats(response)

    def _complete_with_litellm(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        kwargs: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the request through LiteLLM."""
        completion_kwargs = {
            "model": self.model_name,
            "messages": messages,
            "tools": tools,
            "tool_choice": "auto" if tools else None,
            **self._default_params(),
            **kwargs,
        }

        response = litellm.completion(**completion_kwargs)
        response_message = response.choices[0].message.model_dump()
        return response_message, self._usage_stats(response)

    def create_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Create a completion and return (response_message, usage_stats).

        Args:
            messages: OpenAI-style conversation messages.
            tools: Optional OpenAI-style tool definitions.
            **kwargs: Per-call parameters; they take precedence over the
                handler's stored model parameters.

        Returns:
            ``response_message`` is a dict with at least ``role``, ``content``
            and ``tool_calls``; ``usage_stats`` holds ``prompt_tokens``,
            ``completion_tokens`` and ``total_tokens``.

        Raises:
            NotImplementedError: If the model requires the direct Anthropic
                client, which is not implemented here.
        """
        if self.provider == "vllm":
            return self._complete_with_vllm(messages, tools, kwargs)

        if self.use_direct_client and self.gemini_client:
            return self._complete_with_gemini(messages, tools, kwargs)

        if self.use_direct_client and self.anthropic_client:
            # This would need to be refactored to extract the Anthropic handling
            raise NotImplementedError("Direct Anthropic handling needs to be extracted from ToolAgent")

        # vLLM models never reach this point; they return in the first branch.
        return self._complete_with_litellm(messages, tools, kwargs)
