from langchain_community.retrievers import ArxivRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langchain_community.tools.tavily_search import TavilySearchResults
from typing import List
from typing_extensions import TypedDict
from langchain.schema import Document
from pprint import pprint
from langchain.schema import Document
from langgraph.graph import END, StateGraph, START
import re, os
from os.path import isdir

# API_URL = 'http://localhost:8001/v1'
# model_name_r = "qwen2.5:32b-instruct-q8_0"
# llm = ChatOpenAI(model_name=model_name_r, base_url=API_URL, api_key='EMPTY')
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # The text must be passed as ``description=``: the first positional
    # argument of ``Field`` is the *default value*, which would make the field
    # optional and silently default it to the description sentence.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

def init_crag_llm_state(api_url, model_name):
    """Build the module-wide chat model and relevance grader.

    Stores the endpoint and model name in the ``API_URL`` / ``model_name_r``
    globals, creates a ``ChatOpenAI`` client for them (with the placeholder
    API key ``'EMPTY'``) and derives the structured-output grader from it.

    Args:
        api_url: Base URL handed to ``ChatOpenAI`` as ``base_url``.
        model_name: Model identifier handed to ``ChatOpenAI`` as ``model_name``.

    Returns:
        The newly created chat model (also stored in the global ``llm``).
    """
    global API_URL, model_name_r, llm, structured_llm_grader

    API_URL, model_name_r = api_url, model_name
    llm = ChatOpenAI(
        model_name=model_name_r,
        base_url=API_URL,
        api_key='EMPTY',
    )
    # Grader that answers with a ``GradeDocuments`` instance.
    structured_llm_grader = llm.with_structured_output(GradeDocuments)
    return llm

def set_crag_llm(llm_model):
    """Install an already-constructed chat model as the module-wide LLM.

    Also rebuilds the global structured-output grader from that model so that
    both globals always refer to the same underlying LLM.

    Args:
        llm_model: Chat model exposing ``with_structured_output``.

    Returns:
        The model that was installed (the same object as ``llm_model``).
    """
    global llm, structured_llm_grader

    # Grader that answers with a ``GradeDocuments`` instance.
    grader = llm_model.with_structured_output(GradeDocuments)
    llm, structured_llm_grader = llm_model, grader
    return llm

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the question being answered (replaced by the rewritten
            question once the query has been transformed)
        generation: LLM generation
        web_search: "Yes" / "No" flag telling whether a web search is needed
        documents: list of retrieved documents
    """

    question: str
    generation: str
    # The graph stores the strings "Yes" / "No" here, not a boolean.
    web_search: str
    # Entries are ``Document`` objects (their ``page_content`` is read later).
    documents: List["Document"]

class InformationRetrieverAgent():
    """Corrective-RAG agent wired together as a LangGraph state machine.

    The graph retrieves chunks from a persisted Chroma store, asks an LLM to
    grade each chunk for relevance and generates an answer from the relevant
    ones.  When no chunk is graded relevant, the question is rewritten and the
    result of a web search is used as context instead.

    The module-level ``llm`` and ``structured_llm_grader`` globals must be set
    (see ``init_crag_llm_state`` / ``set_crag_llm``) before instantiation.
    """

    _EMBEDDING_MODEL_NAME = "Alibaba-NLP/gte-large-en-v1.5"
    _COLLECTION_NAME = "rag_chroma"

    # Matches a line that consists solely of a (possibly numbered) References /
    # Bibliography heading, plus everything after it.  Anchoring on a heading
    # line keeps an ordinary sentence containing the word "references" from
    # truncating the rest of the document.
    _REFERENCES_RE = re.compile(
        r"^[ \t]*(?:(?:\d+|[IVXLC]+)\.?[ \t]*)?(?:references|bibliography)[ \t]*:?[ \t]*$[\s\S]*",
        flags=re.IGNORECASE | re.MULTILINE,
    )

    def __init__(self, topic:str):
        """Build the retriever, the LLM chains and the compiled graph.

        Args:
            topic: Arxiv search query used when the vector store has to be built.

        Raises:
            RuntimeError: If the module-level LLM globals have not been set.
        """
        # Fail fast: without this check the missing globals only surface as a
        # NameError after the (expensive) vector store has been built.
        missing = [
            name for name in ("llm", "structured_llm_grader")
            if name not in globals()
        ]
        if missing:
            raise RuntimeError(
                "LLM not initialised (missing: {}); call init_crag_llm_state() "
                "or set_crag_llm() first.".format(", ".join(missing))
            )

        self.topic = topic
        self.web_search_tool = TavilySearchResults(max_results=5)
        self.retriever = self._setup_arxiv_ret(self.topic)
        self.retrieval_grader = self._create_grader()
        self.rag_chain = self._create_rag_chain()
        self.question_rewriter = self._create_question_rewriter()
        self.graph = self.init_langraph()

    @classmethod
    def _strip_references(cls, text):
        """Return ``text`` cut off at its References/Bibliography heading, if any."""
        return cls._REFERENCES_RE.sub("", text)

    @staticmethod
    def _is_relevant_grade(grade):
        """Return True when the grader's answer means 'yes' (case/quote tolerant)."""
        return str(grade).strip().strip("'\".").lower() == "yes"

    @staticmethod
    def _format_docs(documents):
        """Join the text of ``documents`` into one context string for the prompt."""
        return "\n\n".join(
            str(getattr(doc, "page_content", doc)) for doc in documents
        )

    @classmethod
    def _make_embeddings(cls):
        """Create the HuggingFace embedding model used for the vector store."""
        return HuggingFaceEmbeddings(
            model_name=cls._EMBEDDING_MODEL_NAME,
            model_kwargs={"trust_remote_code": True},
        )

    def _load_source_documents(self, inp):
        """Fetch Arxiv documents for ``inp`` plus local PDFs, references removed."""
        arxiv_retriever = ArxivRetriever(
            load_max_docs=10,
            get_full_documents=True,
        )
        arxiv_docs = arxiv_retriever.invoke(inp)

        local_pdf_dir = "/home/{}/Knowledge".format(os.getenv("USER"))
        os.makedirs(local_pdf_dir, exist_ok=True)
        local_pdf_paths = [
            os.path.join(local_pdf_dir, file)
            for file in os.listdir(local_pdf_dir)
            if file.endswith(".pdf")
        ]
        local_docs = []
        for pdf_path in local_pdf_paths:
            print(f"Loading and processing local PDF file: {pdf_path}")
            local_docs.extend(PyPDFLoader(pdf_path).load())

        # Combine Arxiv and local documents
        all_docs = arxiv_docs + local_docs
        for doc in all_docs:
            doc.page_content = self._strip_references(doc.page_content)
        return all_docs

    def _setup_arxiv_ret(self, inp):
        """Return a retriever over the persisted Chroma collection.

        If the store directory is empty, the collection is built from the Arxiv
        results for ``inp`` and the PDFs in the local ``Knowledge`` directory;
        otherwise the existing collection is loaded and ``inp`` is only printed.

        Args:
            inp: Arxiv search query used when the store has to be built.

        Returns:
            A retriever configured with ``k=8``.
        """
        vectorstore_path = '/home/{}/chroma_db'.format(os.getenv("USER"))
        os.makedirs(vectorstore_path, exist_ok=True)
        print(inp)

        if os.listdir(vectorstore_path):
            print("Vectorstore found. Loading existing vectorstore...")
            vectorstore = Chroma(
                collection_name=self._COLLECTION_NAME,
                persist_directory=vectorstore_path,
                embedding_function=self._make_embeddings(),
            )
        else:
            print("Vectorstore not found. Building vectorstore...")
            all_docs = self._load_source_documents(inp)

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
            doc_splits = text_splitter.split_documents(all_docs)

            # Create vector store
            vectorstore = Chroma.from_documents(
                documents=doc_splits,
                collection_name=self._COLLECTION_NAME,
                embedding=self._make_embeddings(),
                persist_directory=vectorstore_path,
            )
            vectorstore.persist()
        return vectorstore.as_retriever(search_kwargs={"k": 8})

    def _create_grader(self):
        """Return the chain that grades one document against one question."""
        system_prompt = (
            "You are a grader assessing relevance of a retrieved document to a user question. \n"
            "If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n"
            "Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."
        )
        grade_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "Retrieved document: \n\n {document} \n\n Question: {question}"),
        ])
        return grade_prompt | structured_llm_grader

    def _create_rag_chain(self):
        """Return the chain that answers a question from a context string."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question."),
            ("human", "Context: {context} \n\n Question: {question}"),
        ])
        return prompt | llm | StrOutputParser()

    def _create_question_rewriter(self):
        """Return the chain that rewrites a question for web search."""
        # Prompt text is kept exactly as before (including its whitespace).
        system_prompt = (
            "You a question re-writer that converts an input question to a better version that is optimized \n \n"
            "     for web search. Look at the input and try to reason about the underlying semantic intent / meaning. "
            "ONLY RESPOND WITH THE REWRITED QUESTION NOTHING ELSE"
        )

        rewriter_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question"),
        ])

        return rewriter_prompt | llm | StrOutputParser()

    def retrieve(self, state: GraphState):
        """
        Retrieve documents

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, documents, that contains retrieved documents
        """
        print("---RETRIEVE---")
        question = state["question"]

        # Retrieval
        documents = self.retriever.invoke(question)
        return {"documents": documents, "question": question}

    def grade_documents(self, state: GraphState):
        """
        Determines whether the retrieved documents are relevant to the question.

        A document whose grading raises is reported and dropped.  The
        ``web_search`` flag is "Yes" only when no document was kept.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates documents key with only filtered relevant documents
        """

        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        # Score each doc
        filtered_docs = []
        for d in documents:
            try:
                score = self.retrieval_grader.invoke(
                    {"question": question, "document": d.page_content}
                )
                grade = score.binary_score
            except Exception as e:
                # Best effort: a failed grading must not abort the whole run.
                print(f"Document failed with error: {e}")
                continue

            if self._is_relevant_grade(grade):
                print("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(d)
            else:
                print("---GRADE: DOCUMENT NOT RELEVANT---")

        web_search = "No" if filtered_docs else "Yes"
        return {"documents": filtered_docs, "question": question, "web_search": web_search}

    def transform_query(self,state: GraphState):
        """
        Transform the query to produce a better question.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates question key with a re-phrased question
        """

        print("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]

        # Re-write question
        better_question = self.question_rewriter.invoke({"question": question})
        return {"documents": documents, "question": better_question}

    def web_search(self, state: GraphState):
        """
        Web search based on the re-phrased question.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates documents key with appended web results
        """

        print("---WEB SEARCH---")
        question = state["question"]
        # Work on a copy so the list held by the incoming state is not mutated.
        documents = list(state["documents"])

        # Web search
        print(question)
        results = self.web_search_tool.invoke({"query": question})
        print(results)

        if isinstance(results, str):
            # The search tool may hand back an error message instead of a list
            # of result dicts; treat that as "no web results".
            print(f"Web search returned no usable results: {results}")
            snippets = []
        else:
            snippets = [r.get('content', '') for r in results if isinstance(r, dict)]

        web_text = "\n".join(snippets)
        if web_text:
            documents.append(Document(page_content=web_text))

        return {"documents": documents, "question": question}

    def generate(self, state:GraphState):
        """
        Generate answer

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, generation, that contains LLM generation
        """
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # RAG generation: pass the documents' text, not the repr of the list,
        # so the prompt is not polluted with escaped newlines and metadata.
        context = self._format_docs(documents)
        generation = self.rag_chain.invoke({"context": context, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def decide_to_generate(self, state: GraphState):
        """
        Determines whether to generate an answer, or re-generate a question.

        Args:
            state (dict): The current graph state

        Returns:
            str: Binary decision for next node to call
        """

        print("---ASSESS GRADED DOCUMENTS---")

        if state["web_search"] == "Yes":
            # All documents have been filtered out by the relevance check,
            # so we re-generate a new query.
            print(
                "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
            )
            return "transform_query"

        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

    def init_langraph(self):
        """Assemble and compile the LangGraph workflow.

        Flow: retrieve -> grade_documents -> (generate | transform_query ->
        web_search_node -> generate) -> END.
        """
        agent = StateGraph(GraphState)
        agent.add_node("retrieve", self.retrieve)
        agent.add_node("grade_documents", self.grade_documents)
        agent.add_node("generate", self.generate)
        agent.add_node("transform_query", self.transform_query)
        agent.add_node("web_search_node", self.web_search)

        # Define graph edges.
        agent.add_edge(START, "retrieve")
        agent.add_edge("retrieve", "grade_documents")
        agent.add_conditional_edges("grade_documents", self.decide_to_generate, {
            "transform_query": "transform_query",
            "generate": "generate"
        })
        agent.add_edge("transform_query", "web_search_node")
        agent.add_edge("web_search_node", "generate")
        agent.add_edge("generate", END)

        return agent.compile()

    def run(self, input: str):
        """Run the compiled graph for one question.

        Args:
            input: The question to answer.

        Returns:
            The final graph state returned by ``graph.invoke``.
        """
        inputs = {
            "question": input,
            "generation": "",
            "web_search": "Yes",
            "documents": []  # Start with an empty list; will be populated by 'retrieve'
        }
        print(f"\n--- RUNNING AGENT FOR QUESTION: '{input}' ---")
        final_state = self.graph.invoke(inputs)
        print("--- AGENT RUN COMPLETE ---")
        return final_state

''' 
if __name__ == "__main__":
    search_topic = "Correct"
    agent = InformationRetrieverAgent(search_topic)
    input_question = "What is the best way to determine the state entities for kripke structures?"
    final_output = agent.run(input_question)
    print("Final Generation Output:")
    pprint(final_output)
    '''