import os
import json
import base64
import platformdirs
from typing import List, Union, Dict, Any, TypeVar
from openai import AzureOpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from epc_aw.models.formatters import QueryAnalysis

from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes

# Generic type variable bounded by a (forward-referenced) "BaseModel" class;
# the bound is given as a string, so BaseModel need not be importable here.
T = TypeVar("T", bound="BaseModel")

def validate_structured_output_model(model_string: str) -> bool:
    """Return True when the deployment name suggests structured-output support.

    This is a case-insensitive substring heuristic: any name containing
    "gpt-4" qualifies. Azure deployment names are chosen by the user, so the
    result is only as reliable as the naming convention (NOTE(review): confirm
    the deployment's actual model supports structured outputs).
    """
    lowered_name = model_string.lower()
    return "gpt-4" in lowered_name

def validate_chat_model(model_string: str) -> bool:
    """Return True when the deployment name contains "gpt" (case-insensitive)."""
    lowered_name = model_string.lower()
    return "gpt" in lowered_name

def validate_reasoning_model(model_string: str) -> bool:
    """Report whether the deployment is treated as a reasoning model.

    This module treats no Azure deployment as a reasoning model, so the
    answer is always False and ``model_string`` is not inspected.
    """
    treated_as_reasoning = False
    return treated_as_reasoning

def validate_pro_reasoning_model(model_string: str) -> bool:
    """Report whether the deployment is treated as a "pro" reasoning model.

    No Azure deployment is treated as one here, so this always returns False
    regardless of ``model_string``.
    """
    treated_as_pro_reasoning = False
    return treated_as_pro_reasoning

class ChatAzureOpenAI(EngineLM, CachedEngine):
    """
    Azure OpenAI API implementation of the EngineLM interface.

    Requests are sent to the deployment named by ``model_string`` through an
    ``AzureOpenAI`` client. ``generate`` (also reached by calling the instance)
    accepts a prompt string, or -- when ``is_multimodal`` is set -- a list of
    text / image-bytes parts. Responses are optionally cached through the
    ``CachedEngine`` helpers ``_check_cache`` / ``_add_to_cache``.
    """
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string: str = "gpt-4",
        use_cache: bool = False,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        is_multimodal: bool = False,
        **kwargs
    ):
        """
        Initialize the Azure OpenAI engine.

        Credentials come from the environment: ``AZURE_OPENAI_API_KEY`` and
        ``AZURE_OPENAI_ENDPOINT`` are required; ``AZURE_OPENAI_API_VERSION`` is
        optional and defaults to ``"2024-12-01-preview"``.

        Args:
            model_string: The name of the Azure OpenAI deployment. It is sent as
                ``model=`` on every request and is also used to infer
                capabilities (chat / structured output / reasoning).
            use_cache: Whether to cache responses. When True, the cache lives at
                ``<user_cache_dir("agentflow")>/cache_azure_openai_<model_string>.db``
                and an ``image_cache`` directory is created next to it.
            system_prompt: Default system prompt, used when a call does not
                pass a truthy ``system_prompt`` of its own.
            is_multimodal: Whether list (text + image bytes) content is accepted.
            **kwargs: Stored as ``self.default_kwargs``. They are NOT passed to
                the ``AzureOpenAI`` client, and this class never reads them.

        Raises:
            ValueError: If ``AZURE_OPENAI_API_KEY`` or ``AZURE_OPENAI_ENDPOINT``
                is not set.
        """
        self.model_string = model_string
        self.use_cache = use_cache
        self.system_prompt = system_prompt
        self.is_multimodal = is_multimodal

        # Set model capabilities (name-based heuristics from the module helpers)
        self.support_structured_output = validate_structured_output_model(self.model_string)
        self.is_chat_model = validate_chat_model(self.model_string)
        self.is_reasoning_model = validate_reasoning_model(self.model_string)
        self.is_pro_reasoning_model = validate_pro_reasoning_model(self.model_string)

        # Set up caching if enabled
        if self.use_cache:
            root = platformdirs.user_cache_dir("agentflow")
            cache_path = os.path.join(root, f"cache_azure_openai_{model_string}.db")
            self.image_cache_dir = os.path.join(root, "image_cache")
            os.makedirs(self.image_cache_dir, exist_ok=True)
            super().__init__(cache_path=cache_path)

        # Validate required environment variables
        if not os.getenv("AZURE_OPENAI_API_KEY"):
            raise ValueError("Please set the AZURE_OPENAI_API_KEY environment variable.")
        if not os.getenv("AZURE_OPENAI_ENDPOINT"):
            raise ValueError("Please set the AZURE_OPENAI_ENDPOINT environment variable.")

        # Initialize Azure OpenAI client
        self.client = AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        )

        # Kept for callers that inspect it; not consumed by this class.
        self.default_kwargs = kwargs

    def __call__(self, prompt, **kwargs):
        """
        Handle direct calls to the instance (e.g., model(prompt)).
        Forwards the call to the generate method.
        """
        return self.generate(prompt, **kwargs)

    def _format_content(self, content: List[Union[str, bytes]]) -> List[Dict[str, Any]]:
        """Convert mixed content into OpenAI chat "content parts".

        - ``str`` items become ``{"type": "text", "text": ...}`` parts.
        - ``bytes`` items are treated as images and base64-encoded into a
          ``data:image/<type>;base64,...`` URL with ``detail="auto"``. Bytes whose
          type ``get_image_type_from_bytes`` does not recognize (falsy result)
          are skipped.
        - ``dict`` items that contain a ``"type"`` key are passed through as-is.
        - Any other item is dropped.
        """
        formatted_content = []
        for item in content:
            if isinstance(item, str):
                formatted_content.append({"type": "text", "text": item})
            elif isinstance(item, bytes):
                image_type = get_image_type_from_bytes(item)
                if image_type:
                    base64_image = base64.b64encode(item).decode('utf-8')
                    formatted_content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{image_type};base64,{base64_image}",
                            "detail": "auto"
                        }
                    })
            elif isinstance(item, dict) and "type" in item:
                # Already formatted content
                formatted_content.append(item)
        return formatted_content

    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
        """Generate a response for ``content``.

        Args:
            content: A prompt string or, only when ``is_multimodal`` is set, a
                list of text / image-bytes / pre-formatted content parts.
            system_prompt: Overrides ``self.system_prompt`` for this call when
                truthy.
            **kwargs: Forwarded to ``_generate_text`` / ``_generate_multimodal``
                (``temperature``, ``max_tokens``, ``top_p``, ``response_format``).

        Returns:
            The model's text; the parsed object when structured output is used;
            or, on any failure, a dict with ``"error"`` (exception class name),
            ``"message"`` and ``"details"`` keys. Failures are also printed.

        API calls are retried (see ``_generate_with_retry``) before an error is
        reported. The retry sits inside this error boundary: if it wrapped this
        method, the ``except`` below would swallow every failure and the retry
        would never run.
        """
        try:
            if isinstance(content, str):
                return self._generate_with_retry(
                    self._generate_text, content, system_prompt=system_prompt, **kwargs
                )
            if isinstance(content, list):
                # Checked before the retry loop: this error is deterministic, so
                # retrying it would only add backoff delay.
                if not self.is_multimodal:
                    raise NotImplementedError(f"Multimodal generation is not supported for {self.model_string}.")
                return self._generate_with_retry(
                    self._generate_multimodal, content, system_prompt=system_prompt, **kwargs
                )
            # Previously fell through and returned None silently.
            raise TypeError(
                f"Unsupported content type {type(content).__name__!r}; expected str or list."
            )
        except Exception as e:
            print(f"Error in generate method: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {e.args}")
            return {
                "error": type(e).__name__,
                "message": str(e),
                "details": getattr(e, 'args', None)
            }

    @retry(
        wait=wait_random_exponential(min=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def _generate_with_retry(self, generate_fn, content, **kwargs):
        """Call ``generate_fn(content, **kwargs)``, making up to 5 attempts.

        Between attempts tenacity waits a randomized exponential backoff
        (``min=1``, ``max=5`` seconds). ``reraise=True`` makes the final
        attempt's original exception propagate instead of tenacity's
        ``RetryError``, so ``generate`` reports the real error type.
        """
        return generate_fn(content, **kwargs)

    @staticmethod
    def _cache_key(sys_prompt_arg: str, body: str, response_format: Any = None) -> str:
        """Build the cache key for one request.

        Without ``response_format`` the key is the plain concatenation
        ``sys_prompt_arg + body``, so previously written cache entries remain
        valid. When ``response_format`` is given, its name (or repr) is appended.
        This prevents a cached plain-text reply from being served to a
        structured-output request, and vice versa.
        """
        key = sys_prompt_arg + body
        if response_format is not None:
            format_name = getattr(response_format, "__name__", None) or repr(response_format)
            key += f"\n[response_format={format_name}]"
        return key

    def _complete(
        self,
        user_content: Union[str, List[Dict[str, Any]]],
        sys_prompt_arg: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        response_format: Any,
    ) -> Any:
        """Send one chat request and return the extracted response.

        Shared by ``_generate_text`` (``user_content`` is the prompt string) and
        ``_generate_multimodal`` (``user_content`` is the formatted part list).

        Dispatch:
          * chat model + structured-output support + ``response_format`` given:
            ``client.beta.chat.completions.parse``; returns ``message.parsed``.
          * non-chat reasoning model: user-only messages with
            ``max_completion_tokens`` and ``reasoning_effort="medium"``; returns
            ``"Token limit exceeded"`` when ``finish_reason == "length"``.
          * otherwise: ``client.chat.completions.create``; returns
            ``message.content``. This also covers chat models when structured
            output is unsupported or not requested, in which case
            ``response_format`` is ignored.
        """
        messages = [
            {"role": "system", "content": sys_prompt_arg},
            {"role": "user", "content": user_content},
        ]

        # Chat models with structured output format
        if self.is_chat_model and self.support_structured_output and response_format is not None:
            response = self.client.beta.chat.completions.parse(
                model=self.model_string,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                response_format=response_format,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None
            )
            return response.choices[0].message.parsed

        # Reasoning models: currently only supports base response. Only
        # reachable for non-chat models; every chat model takes a branch above
        # or the standard branch below.
        if not self.is_chat_model and self.is_reasoning_model:
            print(f"Using reasoning model: {self.model_string}")
            response = self.client.chat.completions.create(
                model=self.model_string,
                # Reasoning requests carry only the user message (no system role).
                messages=[{"role": "user", "content": user_content}],
                max_completion_tokens=max_tokens,
                reasoning_effort="medium",
                frequency_penalty=0,
                presence_penalty=0,
                stop=None
            )
            choice = response.choices[0]
            # Workaround for handling length finish reason
            if getattr(choice, "finish_reason", None) == "length":
                return "Token limit exceeded"
            return choice.message.content

        # Standard chat completion (also the fallback for other model types)
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None
        )
        return response.choices[0].message.content

    def _generate_text(
        self,
        prompt: str,
        system_prompt: str = None,
        temperature: float = 0,
        max_tokens: int = 4000,
        top_p: float = 0.99,
        response_format: dict = None,
        **kwargs,
    ) -> Any:
        """
        Generate a response to a single text prompt.

        Args:
            prompt: The user message.
            system_prompt: Per-call system prompt; falls back to
                ``self.system_prompt`` when falsy.
            temperature, max_tokens, top_p: Sampling settings sent with the
                request (not part of the cache key).
            response_format: Structured-output schema passed to ``parse`` when
                the model supports it.
            **kwargs: Accepted and ignored.

        Returns:
            The response text, or the parsed object for structured output.
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        cache_key = None
        if self.use_cache:
            cache_key = self._cache_key(sys_prompt_arg, prompt, response_format)
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        response = self._complete(
            prompt, sys_prompt_arg, temperature, max_tokens, top_p, response_format
        )

        # Cache the response if caching is enabled
        if self.use_cache:
            self._add_to_cache(cache_key, response)

        return response

    def _generate_multimodal(
        self,
        content: List[Union[str, bytes]],
        system_prompt: str = None,
        temperature: float = 0,
        max_tokens: int = 4000,
        top_p: float = 0.99,
        response_format: dict = None,
        **kwargs,
    ) -> Any:
        """
        Generate a response from multiple input types (text and images).

        ``content`` is converted with ``_format_content``. The cache key uses
        the JSON dump of the formatted parts, so image bytes take part via
        their base64 encoding. Other arguments behave as in ``_generate_text``.

        Raises:
            ValueError: If this engine was not created with ``is_multimodal``.
        """
        if not self.is_multimodal:
            raise ValueError("Multimodal input is not supported by this model.")

        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        formatted_content = self._format_content(content)

        cache_key = None
        if self.use_cache:
            cache_key = self._cache_key(
                sys_prompt_arg, json.dumps(formatted_content), response_format
            )
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        response_content = self._complete(
            formatted_content, sys_prompt_arg, temperature, max_tokens, top_p, response_format
        )

        # Cache the response if caching is enabled
        if self.use_cache:
            self._add_to_cache(cache_key, response_content)

        return response_content
