"""
Gemini LangChain-Compatible Wrapper
Wraps Gemini API to work with LangChain's message format
"""

from typing import List, Dict, Any, Optional
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage

import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold


class LangChainCompatibleResponse:
    """Minimal LangChain-style response returned by ``GeminiLangChainWrapper.invoke``.

    Attributes:
        content: Generated text (empty string when Gemini returned none).
        response_metadata: Dict holding the raw Gemini response under the
            ``"gemini_response"`` key.
    """

    def __init__(self, content: str, gemini_response: Any):
        self.content = content
        self.response_metadata = {
            "gemini_response": gemini_response
        }

    def __repr__(self) -> str:
        return f"{type(self).__name__}(content={self.content!r})"


class GeminiLangChainWrapper:
    """Wrapper for Gemini API that is compatible with LangChain's interface.

    ``invoke`` accepts a list of LangChain messages and returns a
    ``LangChainCompatibleResponse`` exposing ``.content`` and
    ``.response_metadata``.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-pro"):
        """
        Initialize Gemini wrapper

        Args:
            api_key: Gemini API key (passed to ``genai.configure``)
            model_name: Model name (default: "gemini-2.5-pro")
        """
        self.api_key = api_key
        self.model_name = model_name
        genai.configure(api_key=api_key)

        # Not used by invoke(), which builds a per-call model so it can apply
        # the system instruction, generation config and safety settings.
        # Kept as public attributes for callers that may rely on them.
        self.text_model = genai.GenerativeModel(model_name)
        self.vision_model = genai.GenerativeModel(model_name)  # Same model for vision

    def invoke(self, messages: List[BaseMessage], config: Optional[Dict] = None, **kwargs):
        """
        Invoke Gemini API with LangChain messages

        Args:
            messages: List of LangChain messages (HumanMessage, AIMessage, SystemMessage)
            config: Optional config dict. A "temperature" key overrides the
                temperature kwarg; a "callbacks" entry (list of handlers or an
                object exposing ``.handlers``) gets ``on_llm_end`` notifications.
            **kwargs: Additional arguments; only "temperature" (default 0.7) is read

        Returns:
            LangChainCompatibleResponse with .content and .response_metadata
        """
        # Convert LangChain messages to Gemini format
        prompt_text, system_instruction, has_images = self._convert_messages_to_gemini(messages)

        # kwargs provide the temperature; a value in config takes precedence.
        temperature = kwargs.get("temperature", 0.7)
        if config and "temperature" in config:
            temperature = config["temperature"]

        # Disable Gemini's content blocking for every listed harm category.
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

        # A fresh model per call: system instruction and config vary per request.
        model = genai.GenerativeModel(
            self.model_name,
            system_instruction=system_instruction or None,
            generation_config=GenerationConfig(temperature=temperature),
            safety_settings=safety_settings,
        )

        # Multimodal requests need a list of parts; text-only ones a single string.
        if has_images:
            content = self._prepare_multimodal_content(messages)
        else:
            content = prompt_text

        gemini_response = model.generate_content(content)
        langchain_response = self._convert_gemini_response_to_langchain(gemini_response)

        # .get() also tolerates an explicit None / empty callbacks entry.
        if config and config.get("callbacks"):
            self._notify_callbacks(config["callbacks"], langchain_response)

        return langchain_response

    @staticmethod
    def _notify_callbacks(callbacks: Any, langchain_response: "LangChainCompatibleResponse") -> None:
        """Call ``on_llm_end`` on each callback handler, best-effort.

        Handler errors are printed and swallowed so a faulty callback cannot
        break the LLM call.
        """
        from langchain_core.outputs import LLMResult, Generation

        # NOTE(review): Generation may not declare a `message` field, in which
        # case it is ignored; confirm whether downstream handlers need it.
        llm_result = LLMResult(
            generations=[[Generation(text=langchain_response.content, message=langchain_response)]],
            llm_output={'response': langchain_response}
        )

        # Accept a plain list of handlers or a manager object exposing .handlers.
        handlers = getattr(callbacks, "handlers", callbacks)

        for callback in handlers:
            if hasattr(callback, 'on_llm_end'):
                try:
                    # Pass both LLMResult and direct response.
                    # NOTE(review): handlers whose first parameter is named
                    # `response` will reject this call; the error is printed below.
                    callback.on_llm_end(llm_result, response=langchain_response)
                except Exception as e:
                    print(f"⚠️ Error in callback: {e}")

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content (str or list of parts) to text.

        Text parts of a list are joined with blank lines; non-dict items are
        stringified; image parts are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        pieces.append(item.get("text", ""))
                else:
                    pieces.append(str(item))
            return "\n\n".join(piece for piece in pieces if piece)
        return str(content) if content else ""

    def _convert_messages_to_gemini(self, messages: List[BaseMessage]) -> tuple:
        """
        Convert LangChain messages to Gemini format

        All SystemMessages are merged (blank-line separated) into a single
        system instruction. Human text and "[Assistant]: "-prefixed AI turns
        are joined with blank lines into the prompt. Image parts are not part
        of the text; they only set ``has_images`` so the caller switches to
        ``_prepare_multimodal_content``.

        Returns:
            (prompt_text, system_instruction, has_images); system_instruction
            is None when no non-empty SystemMessage is present.
        """
        prompt_parts: List[str] = []
        system_parts: List[str] = []
        has_images = False

        for message in messages:
            if isinstance(message, SystemMessage):
                system_text = self._content_to_text(message.content)
                if system_text:
                    system_parts.append(system_text)
            elif isinstance(message, HumanMessage):
                if isinstance(message.content, str):
                    prompt_parts.append(message.content)
                elif isinstance(message.content, list):
                    # Multimodal content
                    for item in message.content:
                        if not isinstance(item, dict):
                            prompt_parts.append(str(item))
                        elif item.get("type") == "text":
                            prompt_parts.append(item.get("text", ""))
                        elif item.get("type") == "image_url":
                            has_images = True
            elif isinstance(message, AIMessage):
                # Prior assistant turns are inlined as labelled text.
                ai_text = self._content_to_text(message.content)
                if ai_text:
                    prompt_parts.append(f"[Assistant]: {ai_text}")

        prompt_text = "\n\n".join(prompt_parts)
        system_instruction = "\n\n".join(system_parts) or None
        return prompt_text, system_instruction, has_images

    @staticmethod
    def _image_part_from_item(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Decode an ``image_url`` content item into a Gemini inline-data part.

        Only base64 ``data:image/...`` URIs are supported. Returns
        ``{"mime_type": ..., "data": bytes}``, or None after printing a
        warning when the URL is unsupported or cannot be decoded.
        """
        import base64

        image_url = item.get("image_url", {})
        url = image_url.get("url", "") if isinstance(image_url, dict) else image_url

        if not isinstance(url, str) or not url.startswith("data:image"):
            # Regular URL - would need to fetch
            print(f"⚠️ Non-data URI image URL not supported: {url}")
            return None

        # Parse data URI: data:image/png;base64,<data>
        header, separator, data = url.partition(",")
        if not separator:
            print("⚠️ Error decoding base64 image: data URI has no ',' separator")
            return None
        mime_type = header.split(";")[0].split(":")[1]

        try:
            image_data = base64.b64decode(data)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            print(f"⚠️ Error decoding base64 image: {e}")
            return None

        return {"mime_type": mime_type, "data": image_data}

    def _prepare_multimodal_content(self, messages: List[BaseMessage]) -> List:
        """
        Prepare multimodal content for Gemini API

        Mirrors ``_convert_messages_to_gemini`` (same human text and
        "[Assistant]: " AI turns) but keeps each piece as a separate part and
        inserts decoded images in place. SystemMessages are excluded; they
        are sent as the model's system instruction.

        Returns:
            List of text strings and {"mime_type", "data"} image dicts
        """
        content: List[Any] = []

        for message in messages:
            if isinstance(message, HumanMessage):
                if isinstance(message.content, str):
                    content.append(message.content)
                elif isinstance(message.content, list):
                    for item in message.content:
                        if not isinstance(item, dict):
                            content.append(str(item))
                        elif item.get("type") == "text":
                            content.append(item.get("text", ""))
                        elif item.get("type") == "image_url":
                            image_part = self._image_part_from_item(item)
                            if image_part is not None:
                                content.append(image_part)
            elif isinstance(message, AIMessage):
                ai_text = self._content_to_text(message.content)
                if ai_text:
                    content.append(f"[Assistant]: {ai_text}")

        return content

    def _convert_gemini_response_to_langchain(self, gemini_response) -> LangChainCompatibleResponse:
        """
        Convert Gemini response to LangChain-compatible format

        Concatenates the text of the first candidate's parts, falling back to
        ``gemini_response.text`` (which the SDK may raise ValueError for,
        e.g. when no text part exists).

        Returns:
            LangChainCompatibleResponse with .content (str) and .response_metadata (dict)
        """
        text = ""
        candidates = getattr(gemini_response, "candidates", None)
        if candidates:
            cand_content = getattr(candidates[0], "content", None)

            if cand_content:
                parts = getattr(cand_content, "parts", None)
                if parts:
                    text = "".join(
                        part.text for part in parts if getattr(part, "text", None)
                    )
                elif hasattr(cand_content, "text"):
                    text = cand_content.text

            # Fallback to response.text if available
            if not text:
                try:
                    text = gemini_response.text
                except (ValueError, AttributeError):
                    pass

        return LangChainCompatibleResponse(text or "", gemini_response)

