import os
import io
import base64
from copy import deepcopy
from typing import Optional, Union

import numpy as np
from PIL import Image
import openai
import litellm
import instructor
from instructor import Mode
from langsmith import traceable
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel


def _array_to_jpeg_bytes(image: np.ndarray) -> bytes:
    """Encode a numpy image array as JPEG and return the encoded bytes.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        The JPEG-encoded image as a byte string.
    """
    pil_image = Image.fromarray(image)
    # JPEG cannot store every PIL mode (notably anything with an alpha
    # channel such as RGBA/LA, or palette images); saving those raises
    # OSError. Convert them to RGB first so e.g. RGBA screenshots encode
    # instead of crashing. Modes the JPEG writer accepts are left untouched.
    if pil_image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        pil_image = pil_image.convert('RGB')
    buffer = io.BytesIO()
    pil_image.save(buffer, format='JPEG')
    # getvalue() returns the whole buffer regardless of the stream position.
    return buffer.getvalue()


class OpenAIWrapper:
    """Client for an OpenAI-compatible chat endpoint with optional structured output.

    The underlying ``openai.OpenAI`` client is wrapped by ``instructor`` in
    JSON mode. Credentials and endpoint are read from the ``OPENAI_API_KEY``
    and ``OPENAI_URL`` environment variables.
    """

    def __init__(self, model_name, max_retry: int = 5, temperature: float = 0.0,
                 reasoning_effort: Optional[str] = None):
        """Create the client and register logging hooks.

        Args:
            model_name: Model identifier sent with every request.
            max_retry: ``max_retries`` passed to instructor for structured requests.
            temperature: Sampling temperature sent with every request.
            reasoning_effort: If truthy, sent as ``reasoning_effort`` with every request.

        Raises:
            KeyError: If ``OPENAI_API_KEY`` or ``OPENAI_URL`` is not set.
        """
        # Metadata
        self.model_name = model_name
        self.max_retry = max_retry
        self.temperature = temperature
        self.reasoning_effort = reasoning_effort

        # LLM client
        raw_client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url=os.environ['OPENAI_URL'])
        self.client = instructor.from_openai(raw_client, mode=Mode.JSON)

        # Register the hooks exactly once. Registering them inside predict_mm
        # would add another pair of handlers on every call, so the n-th call
        # would log each response/error n times.
        self.response_before_parse = None
        self.client.on("completion:response", self._log_completion_response)
        self.client.on("completion:error", self._log_completion_error)

        # Set global configurations for litellm
        litellm.logging = True
        litellm.set_verbose = True

    def _log_completion_response(self, response) -> None:
        """Hook: remember the raw message content of the latest completion."""
        self.response_before_parse = response.choices[0].message.content

    def _log_completion_error(self, error) -> None:
        """Hook: print the error together with the last raw response seen."""
        print(f"Error on completion: {error}")
        print(f"Model raw response: {self.response_before_parse}")

    @classmethod
    def encode_image(cls, image: np.ndarray) -> str:
        """Return ``image`` JPEG-encoded and base64-encoded as a UTF-8 string."""
        return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

    @traceable(run_type="chain", name="agent_run")
    def predict_mm(
        self,
        text_prompt: Optional[str] = None,
        images: Optional[list[np.ndarray]] = None,
        output_format: Optional[type[BaseModel]] = None) -> Union[
        tuple[BaseModel, ChatCompletion], tuple[str, ChatCompletion]]:
        """Send one user message (text plus optional images) and return the reply.

        Args:
            text_prompt: Text placed in the first content block of the message.
            images: Images appended as base64 JPEG ``image_url`` blocks, in order.
                ``None`` (the default) or an empty list sends no images.
            output_format: Pydantic model class to parse the reply into. When
                ``None`` the raw message content is returned instead.

        Returns:
            ``(response, completion)``. ``response`` is an ``output_format``
            instance when ``output_format`` is given, otherwise the content of
            the first choice's message. ``completion`` is the completion
            object returned by the client.
        """
        final_messages = [{'role': 'user', 'content': [{'type': 'text', 'text': text_prompt}]}]

        if images:
            image_blocks = [{
                'type': 'image_url',
                'image_url': {
                    'url': f'data:image/jpeg;base64,{self.encode_image(image)}',
                },
            } for image in images]
            final_messages[-1]['content'].extend(image_blocks)

        # Forget the previous call's raw response so an error raised before
        # any completion arrives does not report stale content.
        self.response_before_parse = None

        # Prepare common kwargs
        common_kwargs = {
            'model': self.model_name,
            'messages': final_messages,
            'temperature': self.temperature
        }

        # Add reasoning_effort if specified (for GPT-5)
        if self.reasoning_effort:
            common_kwargs['reasoning_effort'] = self.reasoning_effort

        if output_format is not None:
            response, completion = self.client.chat.completions.create_with_completion(
                response_model=output_format,
                max_retries=self.max_retry,
                **common_kwargs
            )
            assert isinstance(response, output_format)
        else:
            # No response model: the client returns the completion itself.
            completion = self.client.chat.completions.create(
                response_model=None,
                max_retries=1,
                **common_kwargs
            )
            response = completion.choices[0].message.content
        return response, completion


class DoubaoWrapper:
    """Client for Doubao models served from the Volcengine Ark endpoint.

    The underlying ``openai.OpenAI`` client points at
    ``https://ark.cn-beijing.volces.com/api/v3`` and is wrapped by
    ``instructor`` in JSON mode.
    """

    def __init__(self, model_name, max_retry: int = 5, temperature: float = 0.0):
        """Create the client.

        Args:
            model_name: Model identifier sent with every request.
            max_retry: ``max_retries`` passed to instructor for structured requests.
            temperature: Sampling temperature sent with every request.
        """
        # Metadata
        self.model_name = model_name
        self.max_retry = max_retry
        self.temperature = temperature

        # SECURITY(review): a literal API key is committed in this file. Prefer
        # the ARK_API_KEY environment variable; the literal is kept only as a
        # fallback so existing deployments keep working. Rotate the key and
        # delete the fallback once ARK_API_KEY is set everywhere.
        api_key = os.environ.get('ARK_API_KEY', '098be844-30f7-4f06-83d6-25d3d8a2ca25')

        # LLM client
        raw_client = openai.OpenAI(base_url='https://ark.cn-beijing.volces.com/api/v3',
                                   api_key=api_key)
        self.client = instructor.from_openai(raw_client, mode=Mode.JSON)

        # Set global configurations for litellm
        litellm.logging = True
        litellm.set_verbose = True

    @classmethod
    def encode_image(cls, image: np.ndarray) -> str:
        """Return ``image`` JPEG-encoded and base64-encoded as a UTF-8 string."""
        return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

    @traceable(run_type="chain", name="agent_run")
    def predict_mm(
        self,
        text_prompt: Optional[str] = None,
        images: Optional[list[np.ndarray]] = None,
        output_format: Optional[type[BaseModel]] = None) -> Union[
        tuple[BaseModel, ChatCompletion], tuple[str, ChatCompletion]]:
        """Send one user message (text plus optional images) and return the reply.

        Args:
            text_prompt: Text placed in the first content block of the message.
            images: Images appended as base64 JPEG ``image_url`` blocks, in order.
                ``None`` (the default) or an empty list sends no images.
            output_format: Pydantic model class to parse the reply into. When
                ``None`` the raw message content is returned instead.

        Returns:
            ``(response, completion)``. ``response`` is an ``output_format``
            instance when ``output_format`` is given, otherwise the content of
            the first choice's message. ``completion`` is the completion
            object returned by the client.
        """
        final_messages = [{'role': 'user', 'content': [{'type': 'text', 'text': text_prompt}]}]

        if images:
            image_blocks = [{
                'type': 'image_url',
                'image_url': {
                    'url': f'data:image/jpeg;base64,{self.encode_image(image)}',
                },
            } for image in images]
            final_messages[-1]['content'].extend(image_blocks)

        # Arguments shared by the structured and the raw request.
        common_kwargs = {
            'model': self.model_name,
            'messages': final_messages,
            'temperature': self.temperature
        }

        if output_format is not None:
            response, completion = self.client.chat.completions.create_with_completion(
                response_model=output_format,
                max_retries=self.max_retry,
                **common_kwargs
            )
            assert isinstance(response, output_format)
        else:
            # No response model: the client returns the completion itself.
            completion = self.client.chat.completions.create(
                response_model=None,
                max_retries=1,
                **common_kwargs
            )
            response = completion.choices[0].message.content
        return response, completion


def get_llm_wrapper(model_name: str) -> Union[OpenAIWrapper, DoubaoWrapper]:
    """Build the wrapper instance for a user-facing model name.

    Resolution order:
      1. ``gpt-5-<effort>...`` where ``<effort>`` is ``low``, ``medium`` or
         ``high`` -> ``OpenAIWrapper('gpt-5', reasoning_effort=<effort>)``.
      2. A known short alias -> ``OpenAIWrapper`` with the full model id.
      3. Names starting with ``doubao`` -> ``DoubaoWrapper``.
      4. Anything else -> ``OpenAIWrapper`` with the name unchanged.

    Args:
        model_name: Model name or alias.

    Returns:
        A newly constructed wrapper.
    """
    effort_levels = ('low', 'medium', 'high')
    # Short alias -> model id passed to OpenAIWrapper.
    aliases = {
        'gpt-4.1': 'gpt-4.1-2025-04-14',
        'gpt-4o-1120': 'gpt-4o-2024-11-20',
        'claude-sonnet-4-5': 'claude-sonnet-4-5-20250929-thinking',  # proxy
        'gemini-2.5-pro': 'gemini-2.5-pro-thinking',  # proxy
    }

    # Handle GPT-5 with reasoning effort suffix
    if model_name.startswith('gpt-5-'):
        # The 'gpt-5-' prefix guarantees at least three '-'-separated fields.
        # Names such as 'gpt-5-preview' or 'gpt-5-turbo' simply fail the
        # membership test below and fall through.
        third_field = model_name.split('-')[2]
        if third_field in effort_levels:
            return OpenAIWrapper('gpt-5', reasoning_effort=third_field)

    resolved = aliases.get(model_name)
    if resolved is not None:
        return OpenAIWrapper(resolved)

    wrapper_cls = DoubaoWrapper if model_name.startswith('doubao') else OpenAIWrapper
    return wrapper_cls(model_name)