import base64
import io
import os
from typing import Optional, Union

import anthropic
import instructor
import litellm
import numpy as np
import openai
from PIL import Image
from instructor import Mode
from langsmith import traceable
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel

from autorpa.utils.llm_logger import log_llm_call


def _array_to_jpeg_bytes(image: np.ndarray) -> bytes:
    """Encode a numpy image array as an in-memory JPEG file.

    The array is handed to ``PIL.Image.fromarray`` unchanged, so its dtype and
    shape must be something Pillow can interpret (presumably uint8 ``HxW`` or
    ``HxWxC`` screenshots — not validated here).

    JPEG cannot store an alpha channel, and Pillow's JPEG writer raises
    ``OSError`` for RGBA/LA images, so those are converted to RGB first (the
    alpha channel is dropped).

    Args:
        image: Pixel array to encode.

    Returns:
        The complete JPEG file contents.
    """
    pil_image = Image.fromarray(image)
    if pil_image.mode in ('RGBA', 'LA'):
        # Drop alpha: JPEG has no transparency and Pillow refuses to write it.
        pil_image = pil_image.convert('RGB')
    buffer = io.BytesIO()
    pil_image.save(buffer, format='JPEG')
    # getvalue() returns the whole buffer regardless of the current position.
    return buffer.getvalue()


class OpenAIWrapper:
    """Multimodal chat wrapper around an OpenAI-compatible endpoint, patched by instructor.

    Structured output (``output_format``) is parsed by instructor in ``Mode.MD_JSON``.
    Without ``output_format``, the text of the first choice is returned.
    """

    def __init__(self, model_name, max_retry: int = 5, temperature: float = 0.0, reasoning_effort: Optional[str] = None, enable_logging: bool = True):
        """Create the instructor-patched OpenAI client.

        Args:
            model_name: Model identifier sent as ``model`` in every request.
            max_retry: instructor ``max_retries`` used for structured-output calls.
            temperature: Sampling temperature sent with every request.
            reasoning_effort: Optional ``reasoning_effort`` value; only sent when truthy.
            enable_logging: When True, every call's inputs are recorded via ``log_llm_call``.

        Reads ``OPENAI_API_KEY`` (required; ``KeyError`` if missing) and
        ``OPENAI_URL`` (optional) from the environment. Also switches on
        litellm's global logging/verbose flags.
        """
        # Metadata
        self.model_name = model_name
        self.max_retry = max_retry
        self.temperature = temperature
        self.reasoning_effort = reasoning_effort
        self.enable_logging = enable_logging

        # Raw text of the most recent completion, captured by the response hook so
        # it can be printed when instructor fails to parse it.
        self.response_before_parse = None

        # LLM client. An unset OPENAI_URL must reach the SDK as None: the SDK only
        # falls back to its default endpoint for None, and '' is not a usable URL.
        self.base_url = os.environ.get('OPENAI_URL', '')
        self.client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url=self.base_url or None)
        self.client = instructor.from_openai(self.client, mode=Mode.MD_JSON)

        # Register hooks exactly once. Registering them inside predict_mm would add
        # another pair of handlers on every call, so they would pile up.
        self.client.on("completion:response", self._log_completion_response)
        self.client.on("completion:error", self._log_completion_error)

        # Set global configurations for litellm
        litellm.logging = True
        litellm.set_verbose = True

    def _log_completion_response(self, response):
        """instructor hook: remember the raw message text before parsing."""
        self.response_before_parse = response.choices[0].message.content

    def _log_completion_error(self, error):
        """instructor hook: report a failed completion together with the raw text."""
        print(f"Error occurred during completion: {error}")
        print(f"Model's raw response: {self.response_before_parse}")

    @classmethod
    def encode_image(cls, image: np.ndarray) -> str:
        """Return ``image`` as a base64-encoded JPEG string."""
        return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

    @traceable(run_type="chain", name="agent_run")
    def predict_mm(
        self,
        user_prompt: str,
        images: Optional[list[np.ndarray]] = None,
        system_prompt: Optional[str] = None,
        output_format: Optional[type[BaseModel]] = None) -> Union[
        tuple[BaseModel, ChatCompletion], tuple[str, ChatCompletion]]:
        """Send a text prompt plus optional images and return the model's answer.

        Args:
            user_prompt: Text of the user message.
            images: Images attached as JPEG data URLs. ``None`` entries are
                skipped with a warning; ``None``/empty means text-only.
            system_prompt: Optional instructions, sent as a ``developer`` message
                when non-blank.
            output_format: Optional pydantic model class. When given, instructor
                parses the reply into it (with up to ``max_retry`` retries).

        Returns:
            ``(parsed_model, completion)`` when ``output_format`` is given,
            otherwise ``(message_text, completion)``.
        """
        images = [] if images is None else images

        # Build user message content
        user_content = [{'type': 'text', 'text': user_prompt}]

        if images:
            # Filter out None values from images
            none_count = sum(1 for img in images if img is None)
            if none_count > 0:
                print(f"⚠️  Warning: {none_count}/{len(images)} images are None and will be skipped")

            valid_images = [img for img in images if img is not None]
            if valid_images:
                user_content.extend({
                    'type': 'image_url',
                    'image_url': {
                        'url': f'data:image/jpeg;base64,{self.encode_image(image)}',
                    },
                } for image in valid_images)
            else:
                print("⚠️  Warning: All images are None, sending text-only prompt")

        final_messages = []
        if system_prompt and system_prompt.strip():
            # NOTE(review): non-OpenAI backends behind a proxy (e.g. the gemini alias)
            # may not accept the 'developer' role — confirm if such calls fail.
            final_messages.append({'role': 'developer', 'content': system_prompt})
        final_messages.append({'role': 'user', 'content': user_content})
        common_kwargs = {
            'model': self.model_name,
            'messages': final_messages,
            'temperature': self.temperature
        }

        # Log LLM call input to local log
        if self.enable_logging:
            log_llm_call(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                images=images,
                model_name=self.model_name,
                call_type='predict_mm',
                additional_info={
                    'temperature': self.temperature,
                    'reasoning_effort': self.reasoning_effort,
                    'output_format': output_format.__name__ if output_format else None,
                    'max_retry': self.max_retry
                }
            )

        # Forget the previous call's raw text; the response hook refills it.
        self.response_before_parse = None

        # Add reasoning_effort if specified (for GPT-5).
        # NOTE(review): reasoning models may reject a non-default temperature —
        # confirm against the target endpoint.
        if self.reasoning_effort:
            common_kwargs['reasoning_effort'] = self.reasoning_effort

        if output_format is not None:
            response, completion = self.client.create_with_completion(
                response_model=output_format,
                max_retries=self.max_retry,
                **common_kwargs
            )
        else:
            # create_with_completion needs a response model; with response_model=None
            # instructor's create returns the raw ChatCompletion instead.
            completion = self.client.chat.completions.create(
                response_model=None,
                max_retries=1,
                **common_kwargs
            )
            response = completion.choices[0].message.content
        return response, completion


class AnthropicWrapper:
    """Multimodal wrapper around the Anthropic Messages API, patched by instructor.

    Structured output (``output_format``) is parsed by instructor in
    ``Mode.ANTHROPIC_JSON``. Without ``output_format``, the concatenated text
    blocks of the reply are returned.
    """

    def __init__(self, model_name, max_retry: int = 5, temperature: float = 0.0, enable_logging: bool = True):
        """Create the raw Anthropic client and its instructor-patched twin.

        Args:
            model_name: Model identifier sent as ``model`` in every request.
            max_retry: instructor ``max_retries`` used for structured-output calls.
            temperature: Sampling temperature sent with every request.
            enable_logging: When True, every call's inputs are recorded via ``log_llm_call``.

        Reads ``CLAUDE_API_KEY`` (required; ``KeyError`` if missing) and
        ``OPENAI_URL`` (optional; a trailing ``/v1`` or ``/v1/`` is stripped) from
        the environment. Also switches on litellm's global logging/verbose flags.
        """
        import re

        # Metadata
        self.model_name = model_name
        self.max_retry = max_retry
        self.temperature = temperature
        self.enable_logging = enable_logging

        # LLM client. The Anthropic SDK adds its own '/v1/...' request paths, hence
        # the suffix strip. An unset URL must reach the SDK as None so it falls back
        # to its default endpoint; '' is not a usable URL.
        base_url = os.environ.get('OPENAI_URL', '')
        self.base_url = re.sub(r'/v1/?$', '', base_url)
        self.raw_client = anthropic.Anthropic(api_key=os.environ['CLAUDE_API_KEY'], base_url=self.base_url or None)
        self.client = instructor.from_anthropic(self.raw_client, mode=Mode.ANTHROPIC_JSON)

        # Set global configurations for litellm
        litellm.logging = True
        litellm.set_verbose = True

    @classmethod
    def encode_image(cls, image: np.ndarray) -> str:
        """Return ``image`` as a base64-encoded JPEG string."""
        return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

    @traceable(run_type="chain", name="agent_run")
    def predict_mm(
        self,
        user_prompt: str,
        images: Optional[list[np.ndarray]] = None,
        system_prompt: Optional[str] = None,
        output_format: Optional[type[BaseModel]] = None) -> Union[
        tuple[BaseModel, ChatCompletion], tuple[str, ChatCompletion]]:
        """Send a text prompt plus optional images and return Claude's answer.

        Args:
            user_prompt: Text of the user message.
            images: Images attached as base64 JPEG ``image`` blocks. ``None``
                entries are skipped with a warning; ``None``/empty means text-only.
            system_prompt: Optional instructions, passed as the top-level
                ``system`` parameter when non-blank.
            output_format: Optional pydantic model class. When given, instructor
                parses the reply into it (with up to ``max_retry`` retries).

        Returns:
            ``(parsed_model, completion)`` when ``output_format`` is given,
            otherwise ``(text, completion)``. Here ``completion`` is the Anthropic
            SDK's message object, despite the ``ChatCompletion`` annotation.
        """
        images = [] if images is None else images

        # Build user message content
        user_content = [{'type': 'text', 'text': user_prompt}]

        if images:
            # Filter out None values from images
            none_count = sum(1 for img in images if img is None)
            if none_count > 0:
                print(f"⚠️  Warning: {none_count}/{len(images)} images are None and will be skipped")

            valid_images = [img for img in images if img is not None]
            if valid_images:
                user_content.extend({
                    'type': 'image',
                    'source': {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        # Raw base64, no data:image/... prefix
                        "data": self.encode_image(image),
                    },
                } for image in valid_images)
            else:
                print("⚠️  Warning: All images are None, sending text-only prompt")

        # Log LLM call input to local log
        if self.enable_logging:
            log_llm_call(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                images=images,
                model_name=self.model_name,
                call_type='predict_mm',
                additional_info={
                    'temperature': self.temperature,
                    'output_format': output_format.__name__ if output_format else None,
                    'max_retry': self.max_retry
                }
            )

        common_kwargs = {
            'model': self.model_name,
            'messages': [{'role': 'user', 'content': user_content}],
            'temperature': self.temperature,
            'max_tokens': 8000,
        }

        # Anthropic takes the system prompt as a parameter, not as a message.
        if system_prompt and system_prompt.strip():
            common_kwargs['system'] = system_prompt

        if output_format is not None:
            response, completion = self.client.create_with_completion(
                response_model=output_format,
                max_retries=self.max_retry,
                **common_kwargs
            )
        else:
            # Plain text: no parsing needed, so call the raw SDK client. The
            # instructor client has no `responses` attribute in this mode, and
            # create_with_completion requires a response model.
            completion = self.raw_client.messages.create(**common_kwargs)
            response = ''.join(
                block.text for block in completion.content if block.type == 'text'
            )
        return response, completion


class DoubaoWrapper:
    """Multimodal wrapper around Volcengine Ark's OpenAI-compatible endpoint (Doubao models).

    Structured output (``output_format``) is parsed by instructor in ``Mode.JSON``.
    Without ``output_format``, the text of the first choice is returned.
    """

    def __init__(self, model_name, max_retry: int = 5, temperature: float = 0.0, enable_logging: bool = True):
        """Create the instructor-patched client for the Ark endpoint.

        Args:
            model_name: Model identifier sent as ``model`` in every request.
            max_retry: instructor ``max_retries`` used for structured-output calls.
            temperature: Sampling temperature sent with every request.
            enable_logging: When True, every call's inputs are recorded via ``log_llm_call``.

        Reads ``ARK_API_KEY`` (required; ``KeyError`` if missing) from the
        environment. Also switches on litellm's global logging/verbose flags.
        """
        # Metadata
        self.model_name = model_name
        self.max_retry = max_retry
        self.temperature = temperature
        self.enable_logging = enable_logging

        # LLM client. The API key comes from the environment, like the other
        # wrappers' keys, instead of being committed to source.
        # NOTE(review): a key was previously hard-coded in this file; treat it as
        # leaked and rotate it.
        self.base_url = 'https://ark.cn-beijing.volces.com/api/v3'
        self.client = openai.OpenAI(base_url=self.base_url,
                                    api_key=os.environ['ARK_API_KEY'])
        self.client = instructor.from_openai(self.client, mode=Mode.JSON)

        # Doubao API typically supports system role
        self.supports_system_role = True

        # Set global configurations for litellm
        litellm.logging = True
        litellm.set_verbose = True

    @classmethod
    def encode_image(cls, image: np.ndarray) -> str:
        """Return ``image`` as a base64-encoded JPEG string."""
        return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

    @traceable(run_type="chain", name="agent_run")
    def predict_mm(
        self,
        user_prompt: str,
        images: Optional[list[np.ndarray]] = None,
        system_prompt: Optional[str] = None,
        output_format: Optional[type[BaseModel]] = None) -> Union[
        tuple[BaseModel, ChatCompletion], tuple[str, ChatCompletion]]:
        """Send a text prompt plus optional images and return the model's answer.

        Args:
            user_prompt: Text of the user message.
            images: Images attached as JPEG data URLs. ``None`` entries are
                skipped with a warning; ``None``/empty means text-only.
            system_prompt: Optional instructions. They are sent as a ``system``
                message when ``supports_system_role`` is True; otherwise they are
                prepended to the user prompt.
            output_format: Optional pydantic model class. When given, instructor
                parses the reply into it (with up to ``max_retry`` retries).

        Returns:
            ``(parsed_model, completion)`` when ``output_format`` is given,
            otherwise ``(message_text, completion)``.
        """
        images = [] if images is None else images

        # If the API doesn't support the system role, fold the system prompt into
        # the user prompt and clear it so it is not added as a message.
        if system_prompt and system_prompt.strip() and not self.supports_system_role:
            user_prompt = f"{system_prompt}\n\n{user_prompt}"
            system_prompt = None

        # Build user message content
        user_content = [{'type': 'text', 'text': user_prompt}]

        if images:
            # Filter out None values from images
            none_count = sum(1 for img in images if img is None)
            if none_count > 0:
                print(f"⚠️  Warning: {none_count}/{len(images)} images are None and will be skipped")

            valid_images = [img for img in images if img is not None]
            if valid_images:
                user_content.extend({
                    'type': 'image_url',
                    'image_url': {
                        'url': f'data:image/jpeg;base64,{self.encode_image(image)}',
                    },
                } for image in valid_images)
            else:
                print("⚠️  Warning: All images are None, sending text-only prompt")

        # Build messages: optional system message, then the user message.
        final_messages = []
        if system_prompt and system_prompt.strip() and self.supports_system_role:
            final_messages.append({'role': 'system', 'content': system_prompt})
        final_messages.append({'role': 'user', 'content': user_content})

        # Log LLM call input to local log
        if self.enable_logging:
            log_llm_call(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                images=images,
                model_name=self.model_name,
                call_type='predict_mm',
                additional_info={
                    'temperature': self.temperature,
                    'output_format': output_format.__name__ if output_format else None,
                    'max_retry': self.max_retry
                }
            )

        common_kwargs = {
            'model': self.model_name,
            'messages': final_messages,
            'temperature': self.temperature
        }

        if output_format is not None:
            response, completion = self.client.chat.completions.create_with_completion(
                response_model=output_format,
                max_retries=self.max_retry,
                **common_kwargs
            )
            assert isinstance(response, output_format)
        else:
            # With response_model=None, instructor returns the raw ChatCompletion.
            completion = self.client.chat.completions.create(
                response_model=None,
                max_retries=1,
                **common_kwargs
            )
            response = completion.choices[0].message.content
        return response, completion


def get_llm_wrapper(model_name: str, enable_logging: bool = True) -> OpenAIWrapper | AnthropicWrapper | DoubaoWrapper:
    """Build the client wrapper for a short model alias.

    Aliases:
        - ``gpt-5[-variant]-<effort>`` with effort in minimal/low/medium/high:
          OpenAIWrapper for ``gpt-5[-variant]`` with that ``reasoning_effort``
          (e.g. ``gpt-5-low`` -> ``gpt-5``, ``gpt-5-mini-high`` -> ``gpt-5-mini``).
        - ``gpt-4.1`` / ``gpt-4o-1120``: OpenAIWrapper on pinned dated snapshots.
        - ``claude-sonnet-4-5``: AnthropicWrapper on the dated snapshot.
        - ``gemini-2.5-pro``: OpenAIWrapper on ``gemini-2.5-pro-thinking`` via the proxy.
        - ``doubao*``: DoubaoWrapper with the name unchanged.
        - anything else: OpenAIWrapper with the name unchanged.

    Args:
        model_name: Alias or literal model name.
        enable_logging: Forwarded to the wrapper; enables ``log_llm_call`` logging.

    Returns:
        A wrapper instance exposing ``predict_mm``.
    """
    # GPT-5 family names carrying a trailing reasoning-effort suffix.
    if model_name.startswith('gpt-5-') and model_name not in ('gpt-5-preview', 'gpt-5-turbo'):
        base_model, _, effort = model_name.rpartition('-')
        if effort in ('minimal', 'low', 'medium', 'high'):
            return OpenAIWrapper(base_model, reasoning_effort=effort, enable_logging=enable_logging)

    if model_name == 'gpt-4.1':
        return OpenAIWrapper('gpt-4.1-2025-04-14', enable_logging=enable_logging)

    if model_name == 'gpt-4o-1120':
        return OpenAIWrapper('gpt-4o-2024-11-20', enable_logging=enable_logging)

    if model_name == 'claude-sonnet-4-5':
        return AnthropicWrapper('claude-sonnet-4-5-20250929', enable_logging=enable_logging)  # proxy

    if model_name == 'gemini-2.5-pro':
        return OpenAIWrapper(model_name='gemini-2.5-pro-thinking', enable_logging=enable_logging)  # proxy

    if model_name.startswith('doubao'):
        return DoubaoWrapper(model_name, enable_logging=enable_logging)

    return OpenAIWrapper(model_name, enable_logging=enable_logging)