from src.llm.llms import get_llm, get_embedding
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Shared two-turn prompt (system greeting + "Knock knock.") used by the LLM smoke checks.
messages = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        ("user", "Knock knock."),
    ]
)

def test_get_llm_default():
    """Check that ``get_llm`` raises ``ValueError`` for the name ``random_model``."""
    import pytest

    unsupported_model = "random_model"
    expected_error = "Model random_model is not supported"
    with pytest.raises(ValueError, match=expected_error):
        get_llm(unsupported_model)
    
def llm_knock_test(llm):
    """Run the module-level knock-knock prompt through ``llm`` and check the reply.

    The prompt is piped into ``llm`` and then into a ``StrOutputParser``; the
    check passes when the parsed output contains "there", ignoring case.
    """
    pipeline = messages | llm | StrOutputParser()
    reply = pipeline.invoke({})
    lowered_reply = reply.lower()
    assert 'there' in lowered_reply, "Knock knock response not found in the output"
    
def langchain_embedding_test(embedding):
    """Embed a fixed sample sentence and sanity-check the returned vector.

    ``embedding`` must expose ``embed_query(text)``. The result has to be a
    ``list`` and must contain at least one element.
    """
    vector = embedding.embed_query("This is a test sentence.")
    assert isinstance(vector, list), "Expected embedding vector to be a list"
    assert len(vector) >= 1, "Expected non-empty embedding vector"
    
def test_get_llm_openai_gpt_4o():
    """Build the ``openai:gpt-4o`` LLM and run the knock-knock check on it."""
    model = get_llm("openai:gpt-4o")
    assert model is not None, "Expected LLM instance for supported model"
    llm_knock_test(model)
    
def test_get_embedding_openai_text_embedding_3_large():
    """Build the ``openai:text-embedding-3-large`` embedding and run the embedding check."""
    langchain_embedding_test(get_embedding("openai:text-embedding-3-large"))
    
def test_get_llm_groq_llama_3_3():
    """Build the ``groq:llama-3.3-70b-versatile`` LLM and run the knock-knock check on it."""
    model = get_llm("groq:llama-3.3-70b-versatile")
    assert model is not None, "Expected LLM instance for supported model"
    llm_knock_test(model)

def test_get_llm_gateway_gpt_5():
    """Build the ``gateway:routeway-discount/gpt-5`` LLM and run the knock-knock check on it."""
    model = get_llm("gateway:routeway-discount/gpt-5")
    assert model is not None, "Expected LLM instance for supported model"
    llm_knock_test(model)

def test_get_llm_gateway_openai_gpt_4_1():
    """Build the ``gateway:openai/gpt-4.1`` LLM and run the knock-knock check on it."""
    model = get_llm("gateway:openai/gpt-4.1")
    assert model is not None, "Expected LLM instance for supported model"
    llm_knock_test(model)

def test_get_llm_openai_o3():
    """Build the ``openai:o3`` LLM, check its temperature is 1.0, then run the knock-knock check."""
    model = get_llm("openai:o3")
    assert model is not None, "Expected LLM instance for supported model"

    observed_temperature = model.temperature
    assert observed_temperature == 1.0, "o3 model should use temperature=1.0"

    llm_knock_test(model)

def test_temperature_settings():
    """Check the temperature configured on the LLMs returned for two OpenAI models.

    ``openai:o3`` is expected to report ``temperature == 1.0`` and
    ``openai:gpt-4o`` is expected to report ``temperature == 0.0``.
    """
    # (model name, expected temperature, failure message), checked in order.
    expectations = (
        ("openai:o3", 1.0, "o3 should use temperature=1.0"),
        ("openai:gpt-4o", 0.0, "gpt-4o should use temperature=0.0"),
    )
    for model_name, expected_temperature, failure_message in expectations:
        assert get_llm(model_name).temperature == expected_temperature, failure_message