import logging
from . import backend_deepseek, backend_openai
from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md

logger = logging.getLogger("comind")


def determine_provider(model: str) -> str:
    """Map a model identifier to the provider key that should serve it.

    The decision is made purely from the model-name prefix, checked in order:

    * ``gpt-``, ``o1-``, ``o4-`` -> ``"openai"``
    * ``claude-``                -> ``"anthropic"``
    * ``gemini-``                -> ``"gdm"``
    * ``deepseek-``              -> ``"deepseek"``
    * anything else              -> ``"openrouter"``

    Note that a returned provider key is not guaranteed to have a backend
    registered in this module's dispatch tables.

    Args:
        model (str): Model identifier, e.g. ``"gpt-4-turbo"``.

    Returns:
        str: The provider key.
    """
    prefix_table = (
        (("gpt-", "o1-", "o4-"), "openai"),
        (("claude-",), "anthropic"),
        (("gemini-",), "gdm"),
        (("deepseek-",), "deepseek"),
    )
    for prefixes, provider_name in prefix_table:
        # str.startswith accepts a tuple and matches if any prefix matches.
        if model.startswith(prefixes):
            return provider_name
    # Every model without a recognised prefix is handled by OpenRouter.
    return "openrouter"


# Dispatch tables keyed by the provider string returned by determine_provider().
# Only the providers listed below have a backend wired up; any other provider
# key has no entry.

# Single-turn entry points: called with a system message and a user message.
provider_to_query_func = dict(
    deepseek=backend_deepseek.query,
    openai=backend_openai.query,
)

# Multi-turn entry points: called with a full ``messages`` list (see Conversation).
provider_to_raw_query_func = dict(
    deepseek=backend_deepseek.raw_query,
    openai=backend_openai.raw_query,
)

def query(
    system_message: PromptType | None,
    user_message: PromptType | None,
    model: str,
    temperature: float | None = None,
    max_tokens: int | None = None,
    func_spec: FunctionSpec | None = None,
    convert_system_to_user: bool = False,
    **model_kwargs,
) -> OutputType:
    """
    General LLM query for various backends with a single system and user message.
    Supports function calling for some backends.

    Args:
        system_message (PromptType | None): Uncompiled system message (will generate a message following the OpenAI/Anthropic format)
        user_message (PromptType | None): Uncompiled user message (will generate a message following the OpenAI/Anthropic format)
        model (str): string identifier for the model to use (e.g. "gpt-4-turbo")
        temperature (float | None, optional): Temperature to sample at. Defaults to the model-specific default.
        max_tokens (int | None, optional): Maximum number of tokens to generate. Defaults to the model-specific max tokens.
        func_spec (FunctionSpec | None, optional): Optional FunctionSpec object defining a function call. If given, the return value will be a dict.
        convert_system_to_user (bool, optional): Forwarded unchanged to the backend's ``query``.
            Presumably asks the backend to send the system prompt with the user role
            (for models without system-role support); confirm in the backend.
        **model_kwargs: Extra keyword arguments forwarded to the backend. The explicit
            ``model``, ``temperature`` and ``max_tokens`` arguments override same-named keys.

    Returns:
        OutputType: A string completion if func_spec is None, otherwise a dict with the function call details.

    Raises:
        KeyError: If the provider derived from ``model`` has no backend registered in
            ``provider_to_query_func``.
    """

    # Explicit arguments win over anything passed through **model_kwargs.
    model_kwargs = model_kwargs | {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    logger.info("---Querying model---", extra={"verbose": True})
    # Falsy prompts (None, "", empty containers) are treated as absent.
    system_message = compile_prompt_to_md(system_message) if system_message else None
    if system_message:
        logger.info(f"system: {system_message}", extra={"verbose": True})
    user_message = compile_prompt_to_md(user_message) if user_message else None
    if user_message:
        logger.info(f"user: {user_message}", extra={"verbose": True})
    if func_spec:
        logger.info(f"function spec: {func_spec.to_dict()}", extra={"verbose": True})

    provider = determine_provider(model)
    try:
        query_func = provider_to_query_func[provider]
    except KeyError:
        # determine_provider can return providers (e.g. "anthropic", "gdm",
        # "openrouter") with no registered backend. Name the model and the
        # supported providers instead of surfacing a bare provider key; keep the
        # KeyError type so existing callers that catch it still work.
        raise KeyError(
            f"No backend registered for provider {provider!r} (model {model!r}); "
            f"supported providers: {sorted(provider_to_query_func)}"
        ) from None

    # The backend returns a 5-tuple; only the output is surfaced to callers.
    # Based on the original local names, the other fields are presumably request
    # time, input/output token counts and backend info; confirm in the backends.
    output, _req_time, _in_tok_count, _out_tok_count, _info = query_func(
        system_message=system_message,
        user_message=user_message,
        func_spec=func_spec,
        convert_system_to_user=convert_system_to_user,
        **model_kwargs,
    )
    logger.info(f"response: {output}", extra={"verbose": True})
    logger.info("---Query complete---", extra={"verbose": True})

    return output

class Conversation:
    """Stateful multi-turn chat with a single model.

    Messages accumulate in ``self.messages`` as ``{"role": ..., "content": ...}``
    dicts. The whole history is sent to the provider's ``raw_query`` on every
    :meth:`query` call, and the reply is appended as an ``"assistant"`` turn.
    """

    def __init__(
        self,
        model: str,
        temperature: float | None = None,
        max_tokens: int | None = None,
        **model_kwargs,
    ):
        """
        Args:
            model (str): string identifier for the model to use (e.g. "gpt-4-turbo")
            temperature (float | None, optional): Temperature forwarded to the backend.
            max_tokens (int | None, optional): Max tokens forwarded to the backend.
            **model_kwargs: Extra keyword arguments forwarded to the backend on every
                query. The explicit ``model``, ``temperature`` and ``max_tokens``
                override same-named keys.

        Raises:
            KeyError: If the provider derived from ``model`` has no backend registered
                in ``provider_to_raw_query_func``.
        """
        self.messages: list[dict] = []
        self.model = model

        # Explicit arguments win over anything passed through **model_kwargs.
        self.model_kwargs = model_kwargs | {
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        provider = determine_provider(model)
        try:
            self.query_func = provider_to_raw_query_func[provider]
        except KeyError:
            # determine_provider can return providers with no registered backend;
            # report the model and supported providers, keeping the KeyError type.
            raise KeyError(
                f"No backend registered for provider {provider!r} (model {model!r}); "
                f"supported providers: {sorted(provider_to_raw_query_func)}"
            ) from None

    def add_message(
        self,
        system_message: PromptType | None = None,
        user_message: PromptType | None = None,
        convert_system_to_user: bool = False,
    ) -> None:
        """Compile and append a system and/or user message to the history.

        Falsy messages (None, "", empty containers) are skipped. When both are
        given, the system message is appended before the user message.

        Args:
            system_message (PromptType | None, optional): Uncompiled system prompt.
            user_message (PromptType | None, optional): Uncompiled user prompt.
            convert_system_to_user (bool, optional): If True, the system prompt is
                appended with role ``"user"`` instead of ``"system"``.
        """
        if system_message:
            system_message = compile_prompt_to_md(system_message)
            role = "user" if convert_system_to_user else "system"
            self.messages.append({"role": role, "content": system_message})
        if user_message:
            user_message = compile_prompt_to_md(user_message)
            self.messages.append({"role": "user", "content": user_message})

    def pop_message(self) -> dict | None:
        """Remove and return the most recent message; return None if the history is empty."""
        return self.messages.pop() if self.messages else None

    def query(
        self,
        func_spec: FunctionSpec | None = None,
    ):
        """Send the full history to the backend and record the reply.

        Args:
            func_spec (FunctionSpec | None, optional): Optional function-call spec
                forwarded to the backend.

        Returns:
            The backend's output, which is also appended to ``self.messages`` as an
            ``"assistant"`` message. If the backend raises, the history is unchanged.
        """
        output, _req_time, _in_tok_count, _out_tok_count, _info = self.query_func(
            messages=self.messages,
            func_spec=func_spec,
            **self.model_kwargs,
        )

        # NOTE(review): when func_spec is given the output may be a dict rather
        # than a string (as documented for the module-level query()); storing it
        # verbatim as message content may not be accepted on the next call.
        # Confirm against the backend's raw_query.
        self.messages.append({"role": "assistant", "content": output})
        return output