import os
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_groq import ChatGroq
from typing import Optional

from ..utils.settings import settings
from ..utils.log import logger

def get_llm(model_id: str) -> BaseChatModel:
    """Build a chat model for a ``"<platform>:<model_name>"`` identifier.

    Args:
        model_id: Identifier such as ``"openai:<model>"``. It must appear in
            ``settings.supported_models.llm``.

    Returns:
        A ``ChatOpenAI`` instance for the ``gateway`` and ``openai`` platforms,
        or a ``ChatGroq`` instance for ``groq``.

    Raises:
        ValueError: If the model is not supported, the id has no ``:``
            separator, the platform is unknown, or a gateway model is
            requested without ``LLM_GATEWAY_API_KEY`` set.
    """
    if model_id not in settings.supported_models.llm:
        supported = ', '.join(settings.supported_models.llm)
        logger.debug(f"Supported models: {supported}")
        raise ValueError(f"Model {model_id} is not supported. Supported models are: {supported}")

    # partition() never raises. A malformed id (e.g. a config entry missing the
    # "platform:" prefix) therefore gets a clear message instead of an opaque
    # tuple-unpacking error.
    platform, sep, model_name = model_id.partition(":")
    if not sep:
        raise ValueError(f"Model id {model_id!r} must have the form '<platform>:<model_name>'")

    if platform == "gateway":
        env_api_key = os.getenv("LLM_GATEWAY_API_KEY")
        if not env_api_key:
            raise ValueError("LLM_GATEWAY_API_KEY environment variable is required for gateway models")
        gateway_settings = settings.llm_settings['gateway']
        # The key is passed unmodified. The OpenAI client adds the "Bearer"
        # auth scheme itself, so no manual prefixing is needed.
        return ChatOpenAI(
            model=model_name,
            temperature=gateway_settings['temperature'],
            api_key=env_api_key,
            base_url=gateway_settings['base_url'],
        )

    if platform == "openai":
        # These models only accept the default temperature (1.0), so the
        # configured openai temperature is ignored for them.
        if model_name in ("o3", "o3-mini", "gpt-5"):
            temperature = 1.0
        else:
            temperature = settings.llm_settings['openai']['temperature']
        # No api_key is passed here; ChatOpenAI resolves credentials itself.
        return ChatOpenAI(
            model=model_name,
            temperature=temperature,
        )

    if platform == "groq":
        return ChatGroq(
            model=model_name,
            temperature=settings.llm_settings['groq']['temperature'],
        )

    raise ValueError(f"Unsupported platform: {platform}")
    
def get_embedding(model_id: str) -> "Optional[OpenAIEmbeddings]":
    """Build an embedding model for a ``"<platform>:<model_name>"`` identifier.

    Unlike ``get_llm``, an unsupported ``model_id`` is not an error here. The
    supported list is logged at debug level and ``None`` is returned, so
    callers must handle a ``None`` result.

    Args:
        model_id: Identifier such as ``"openai:<embedding-model>"``. It is
            checked against ``settings.supported_models.embedding``.

    Returns:
        An ``OpenAIEmbeddings`` instance for the ``openai`` platform, or
        ``None`` if ``model_id`` is not in the supported embedding list.

    Raises:
        ValueError: If the id has no ``:`` separator or its platform is not
            ``openai``.
    """
    if model_id not in settings.supported_models.embedding:
        logger.debug(f"Supported embedding models: {', '.join(settings.supported_models.embedding)}")
        return None

    # partition() never raises. A malformed id gets a clear message instead of
    # an opaque tuple-unpacking error.
    platform, sep, model_name = model_id.partition(":")
    if not sep:
        raise ValueError(f"Model id {model_id!r} must have the form '<platform>:<model_name>'")

    if platform == "openai":
        return OpenAIEmbeddings(
            model=model_name,
        )

    raise ValueError(f"Unsupported embedding platform: {platform}")