import os
from langchain_neo4j import Neo4jGraph
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.language_models.chat_models import BaseChatModel

from ...llm.agent.agent_prompt import traffic_scene_prompt
from ...utils.settings import settings
from .constants import chat_mode_settings
from .QA_integration import get_neo4j_retriever, create_document_retriever_chain, retrieve_documents


# Module-level Neo4j connection, created once at import time and used by
# Neo4jRetriever. Connection details are read from the project settings; the
# password is taken from the NEO4J_PASSWORD environment variable instead.
_neo4j_settings = settings.database['neo4j']
graph = Neo4jGraph(
    # Where to connect and which database to use.
    url=_neo4j_settings['uri'],
    database=_neo4j_settings['database'],
    # Credentials.
    username=_neo4j_settings['username'],
    password=os.getenv('NEO4J_PASSWORD'),
    # Behaviour flags passed through to Neo4jGraph unchanged.
    sanitize=True,
    refresh_schema=True,
)

class Neo4jRetriever(BaseRetriever):
    """Retriever that looks up documents for a traffic-scene description in Neo4j.

    The query is rendered through ``traffic_scene_prompt``, run through a
    document-retriever chain built on the module-level ``graph``, and the
    results are returned ordered by their query similarity score.
    """

    # Chat model handed to ``create_document_retriever_chain``.
    llm: BaseChatModel

    def _get_relevant_documents(self, query: str) -> list[Document]:
        """Return the retrieved documents for *query*, most similar first.

        Args:
            query: Text substituted for the ``traffic_scene`` variable of
                ``traffic_scene_prompt``.

        Returns:
            New ``Document`` objects carrying the retrieved page content. Each
            one's metadata holds a ``similarity`` entry (the document's
            ``query_similarity_score`` state value, or 0 when absent) merged
            with the retrieved document's own metadata.
        """

        def similarity_of(item):
            # Documents lacking a score are treated as having similarity 0.
            return item.state.get("query_similarity_score", 0)

        prompt_messages = traffic_scene_prompt.invoke({"traffic_scene": query}).to_messages()

        # A fresh retriever and chain are assembled for every call; the empty
        # ``document_names`` list is passed through as-is.
        neo4j_retriever = get_neo4j_retriever(
            graph=graph,
            chat_mode_settings=chat_mode_settings,
            document_names=[],
        )
        retrieval_chain = create_document_retriever_chain(self.llm, neo4j_retriever)

        # The second element of the returned pair (the transformed question)
        # is not needed here.
        retrieved, _ = retrieve_documents(retrieval_chain, prompt_messages)

        ranked = sorted(retrieved, key=similarity_of, reverse=True)

        # ``similarity`` goes first so that an identically named key in the
        # retrieved document's own metadata takes precedence over it.
        return [
            Document(
                page_content=item.page_content,
                metadata={"similarity": similarity_of(item), **item.metadata},
            )
            for item in ranked
        ]
           
