import os

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.tools import tool
from langgraph.graph import START, END, StateGraph

from utils.llm_utils import LLM_Factory

from rag_assistant.rag_utils import deduplicate_and_format_sources, tavily_search, format_sources, perplexity_search, duckduckgo_search, finalize_summary, route_research
from rag_assistant.rag_state import SummaryState, SummaryStateInput, SummaryStateOutput
from rag_assistant.rag_prompts import query_writer_instructions, summarizer_instructions, reflection_instructions
from dotenv import load_dotenv
from utils.utils import HiddenPrints

# Nodes
def generate_query(state: SummaryState):
    """Generate the initial web-search query for ``state.research_topic``.

    The LLM is asked for structured output matching a small JSON schema
    (``query``, ``aspect``, ``rationale``). Only ``query`` is written back to the
    graph state; the other fields exist to make the model reason about its choice.

    Args:
        state: Current graph state. Reads ``research_topic``,
            ``input_token_count`` and ``output_token_count``.

    Returns:
        dict: State update with ``search_query`` plus the accumulated
        ``input_token_count`` / ``output_token_count``.

    Raises:
        ValueError: If the model's output could not be parsed into the schema or
            does not contain a ``query`` field.
    """
    # Format the prompt
    query_writer_instructions_formatted = query_writer_instructions.format(research_topic=state.research_topic)

    # Instantiate agent
    llm = LLM_Factory.initialize_llm()

    json_schema = {
        "title": "search_query",
        "description": "Schema for generating a search query for web research.",
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query to be used for web research.",
            },
            "aspect": {
                "type": "string",
                "description": "The specific aspect or focus of the query.",
            },
            "rationale": {
                "type": "string",
                "description": "The reason or rationale behind the query.",
            },
        },
        "required": ["query", "aspect", "rationale"],
    }
    # include_raw=True returns {"raw", "parsed", "parsing_error"}: the raw message
    # carries token usage, and a parse failure yields parsed=None instead of raising.
    structured_llm = llm.with_structured_output(schema=json_schema, include_raw=True)

    result = structured_llm.invoke(
        [SystemMessage(content=query_writer_instructions_formatted),
         HumanMessage(content="Generate a query for web search:")]
    )

    parsed = result["parsed"]
    if not isinstance(parsed, dict) or "query" not in parsed:
        # Surface the real cause instead of a bare TypeError/KeyError on subscripting.
        parsing_error = result.get("parsing_error")
        raise ValueError(
            f"Query writer did not return a valid 'search_query' object (parsed={parsed!r})"
        ) from (parsing_error if isinstance(parsing_error, BaseException) else None)

    # usage_metadata is optional on AIMessage (some providers omit it); count 0 then.
    usage = result["raw"].usage_metadata or {}
    input_token_count = state.input_token_count + usage.get("input_tokens", 0)
    output_token_count = state.output_token_count + usage.get("output_tokens", 0)

    return {"search_query": parsed["query"], "input_token_count": input_token_count, "output_token_count": output_token_count}

def web_research(state: SummaryState):
    """Run one web search for ``state.search_query`` and record the results.

    The backend is chosen by the ``SEARCH_API`` environment variable
    ("tavily", "perplexity" or "duckduckgo"; default "duckduckgo", compared
    case-insensitively). Environment variables are first loaded from a ``.env``
    file in the parent of this module's directory.

    Args:
        state: Current graph state. Reads ``search_query`` and
            ``research_loop_count``.

    Returns:
        dict: ``sources_gathered`` (one-element list of formatted sources),
        ``research_loop_count`` incremented by one, and ``web_research_results``
        (one-element list holding the deduplicated, formatted search text).

    Raises:
        ValueError: If ``SEARCH_API`` names an unsupported backend.
    """
    # Load environment variables from the .env file in the parent directory
    load_dotenv(os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env'))

    # Normalise so that e.g. "Tavily" or " duckduckgo " still select a backend.
    search_api = os.getenv("SEARCH_API", "duckduckgo").strip().lower()

    # Search the web; HiddenPrints keeps the search helpers' console output quiet.
    with HiddenPrints():
        if search_api == "tavily":
            search_results = tavily_search(state.search_query, include_raw_content=True, max_results=1)
            search_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1000, include_raw_content=True)
        elif search_api == "perplexity":
            search_results = perplexity_search(state.search_query, state.research_loop_count)
            search_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1000, include_raw_content=False)
        elif search_api == "duckduckgo":
            # NOTE: full-page fetching is hard-coded on; FETCH_FULL_PAGE is not consulted.
            search_results = duckduckgo_search(state.search_query, max_results=3, fetch_full_page=True)
            search_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1000, include_raw_content=False)
        else:
            raise ValueError(
                f"Unsupported search API: {search_api!r} "
                "(expected one of 'tavily', 'perplexity', 'duckduckgo')"
            )

    return {"sources_gathered": [format_sources(search_results)], "research_loop_count": state.research_loop_count + 1, "web_research_results": [search_str]}

def summarize_sources(state: SummaryState):
    """Fold the newest web-research result into the running summary.

    Builds a tagged prompt from the research topic, the latest entry of
    ``state.web_research_results`` and (when one exists) the current
    ``running_summary``. It then asks the LLM, under ``summarizer_instructions``,
    to produce the updated summary.

    Args:
        state: Current graph state. Reads ``running_summary``,
            ``web_research_results``, ``research_topic`` and the token counters.

    Returns:
        dict: ``running_summary`` set to the model's reply content, plus the
        accumulated ``input_token_count`` / ``output_token_count``.
    """
    previous_summary = state.running_summary
    latest_results = state.web_research_results[-1]

    # The topic section is always present; the remaining sections depend on
    # whether this is the first pass or an update of an existing summary.
    prompt_text = f"<User Input> \n {state.research_topic} \n <User Input>\n\n"
    if not previous_summary:
        prompt_text += f"<Search Results> \n {latest_results} \n <Search Results>"
    else:
        prompt_text += (
            f"<Existing Summary> \n {previous_summary} \n <Existing Summary>\n\n"
            f"<New Search Results> \n {latest_results} \n <New Search Results>"
        )

    model = LLM_Factory.initialize_llm()
    conversation = [
        SystemMessage(content=summarizer_instructions),
        HumanMessage(content=prompt_text),
    ]
    reply = model.invoke(conversation)

    usage = reply.usage_metadata
    return {
        "running_summary": reply.content,
        "input_token_count": state.input_token_count + usage["input_tokens"],
        "output_token_count": state.output_token_count + usage["output_tokens"],
    }

def reflect_on_summary(state: SummaryState):
    """Identify a knowledge gap in the running summary and propose a follow-up query.

    The LLM returns structured output with ``knowledge_gap`` and
    ``follow_up_query``. Only the follow-up query is written back, replacing
    ``search_query`` for the next research iteration.

    Args:
        state: Current graph state. Reads ``research_topic``,
            ``running_summary`` and the token counters.

    Returns:
        dict: ``search_query`` set to the follow-up query, plus the accumulated
        ``input_token_count`` / ``output_token_count``.

    Raises:
        ValueError: If the model's output could not be parsed into the schema or
            does not contain a ``follow_up_query`` field.
    """
    llm = LLM_Factory.initialize_llm()

    json_schema = {
        "title": "follow_up_query",
        "description": "Schema for identifying knowledge gaps and generating a follow-up query.",
        "type": "object",
        "properties": {
            "knowledge_gap": {
                "type": "string",
                "description": "The identified gap in the current knowledge or summary.",
            },
            "follow_up_query": {
                "type": "string",
                "description": "A follow-up query to address the identified knowledge gap.",
            },
        },
        "required": ["knowledge_gap", "follow_up_query"],
    }

    # include_raw=True returns {"raw", "parsed", "parsing_error"}: the raw message
    # carries token usage, and a parse failure yields parsed=None instead of raising.
    structured_llm = llm.with_structured_output(schema=json_schema, include_raw=True)

    result = structured_llm.invoke(
        [SystemMessage(content=reflection_instructions.format(research_topic=state.research_topic)),
         HumanMessage(content=f"Identify a knowledge gap and generate a follow-up web search query based on our existing knowledge: {state.running_summary}")]
    )

    parsed = result["parsed"]
    if not isinstance(parsed, dict) or "follow_up_query" not in parsed:
        # Surface the real cause instead of a bare TypeError/KeyError on subscripting.
        parsing_error = result.get("parsing_error")
        raise ValueError(
            f"Reflection agent did not return a valid 'follow_up_query' object (parsed={parsed!r})"
        ) from (parsing_error if isinstance(parsing_error, BaseException) else None)

    # usage_metadata is optional on AIMessage (some providers omit it); count 0 then.
    usage = result["raw"].usage_metadata or {}
    input_token_count = state.input_token_count + usage.get("input_tokens", 0)
    output_token_count = state.output_token_count + usage.get("output_tokens", 0)

    # Update search query with follow-up query
    return {"search_query": parsed["follow_up_query"], "input_token_count": input_token_count, "output_token_count": output_token_count}

# Expose the web-search research graph as a LangChain tool.
# The @tool decorator turns this function into a callable LangChain StructuredTool
# whose description is taken from the docstring.
# response_format="content_and_artifact" means the function returns a (content, artifact)
# tuple: the summary string goes back to the calling model as content, while the
# token-count metadata stays accessible to the caller as the artifact.
@tool(response_format="content_and_artifact")
def rag_assistant(research_topic: str) -> str:
    """Perform an in-depth web search on a given topic and return a summarized result.

    Args:
        research_topic (str): The topic to research using web search APIs like DuckDuckGo. The web search has trouble with acronyms, write this query in clear natural language.

    This function invokes a team of agents that iteratively:
    - Generate a search query.
    - Perform web research.
    - Summarize the gathered sources.
    - Reflect on the summary for improvements.
    - Finalize the summary.

    Returns:
        str: A summary of the research findings along with a list of sources.
    """
    # NOTE: the docstring above is the tool description shown to the calling model,
    # so treat edits to it as behavior changes. The `-> str` annotation describes the
    # content half of the returned (content, artifact) tuple.

    graph_builder = StateGraph(SummaryState, input=SummaryStateInput, output=SummaryStateOutput)

    # Register every node under its graph name (insertion order is preserved).
    node_table = {
        "generate_query": generate_query,
        "web_research": web_research,
        "summarize_sources": summarize_sources,
        "reflect_on_summary": reflect_on_summary,
        "finalize_summary": finalize_summary,
    }
    for node_name, node_fn in node_table.items():
        graph_builder.add_node(node_name, node_fn)

    # Fixed linear pipeline: START -> query -> search -> summarize -> reflect.
    pipeline = [START, "generate_query", "web_research", "summarize_sources", "reflect_on_summary"]
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        graph_builder.add_edge(upstream, downstream)

    # route_research (from rag_assistant.rag_utils) picks the node that follows reflection.
    graph_builder.add_conditional_edges("reflect_on_summary", route_research)
    graph_builder.add_edge("finalize_summary", END)

    final_state = graph_builder.compile().invoke({"research_topic": research_topic})

    summary_text = final_state["running_summary"]
    token_usage = {
        "input_token_count": final_state["input_token_count"],
        "output_token_count": final_state["output_token_count"],
    }
    return summary_text, token_usage