from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import HumanMessage
from ..llms import get_llm
from ...utils.settings import settings
from .traffic_rule_checker import traffic_rule_agent
from .traffic_accident_retriever import traffic_accident_agent

# Shared chat model, used both by the supervisor router and by analyze_node.
llm = get_llm(settings.app.llm["main"])

class AgentState(MessagesState):
    """Graph state: the shared message list plus the supervisor's routing decision."""

    # Node name chosen by `supervisor` (or END); consumed by the conditional edge.
    next: str

def rule_node(state: AgentState):
    """
    Process the traffic rule checking agent.

    The query is the most recent ``HumanMessage`` that was not written by a
    worker node (workers tag their outputs with ``name`` set to their node
    name). If no such message exists, the last message is used.

    Returns:
        A state update appending the agent's ``result`` as a
        ``HumanMessage`` named ``"rule_agent"``.
    """
    messages = state['messages']
    # Once any worker has run, messages[-1] is that worker's output rather than
    # the user's request, so search backwards for the latest non-worker message.
    query = next(
        (
            m.content
            for m in reversed(messages)
            if isinstance(m, HumanMessage) and m.name not in members
        ),
        messages[-1].content,
    )
    response = traffic_rule_agent.invoke({"query": query})
    final_response = [HumanMessage(content=f"{response['result']}", name="rule_agent")]
    return {"messages": final_response}

def accident_node(state: AgentState):
    """
    Process the traffic accident retrieval agent.

    The scene is the most recent ``HumanMessage`` that was not written by a
    worker node (workers tag their outputs with ``name`` set to their node
    name). If no such message exists, the last message is used.

    Returns:
        A state update appending the agent's ``consequences`` as a
        ``HumanMessage`` named ``"accident_agent"``.
    """
    messages = state['messages']
    # Once any worker has run, messages[-1] is that worker's output rather than
    # the user's scene description, so search backwards for the latest
    # non-worker message.
    scene = next(
        (
            m.content
            for m in reversed(messages)
            if isinstance(m, HumanMessage) and m.name not in members
        ),
        messages[-1].content,
    )
    response = traffic_accident_agent.invoke({"scene": scene})
    final_response = [HumanMessage(content=f"{response['consequences']}", name="accident_agent")]
    return {"messages": final_response}


def analyze_node(state: AgentState):
    """
    Process the accident analysis agent.

    Sends the whole message history to the main LLM. Per the supervisor
    prompt, this worker gives driving suggestions based on the rule and
    accident agents' outputs. The model's reply is appended, tagged as
    ``analyzer_agent``.
    """
    reply = llm.invoke(state['messages'])
    return {
        "messages": [
            HumanMessage(content=reply.content, name="analyzer_agent"),
        ]
    }


# Worker node names the supervisor may route to; "FINISH" tells it to stop.
members = ["rule_agent", "accident_agent", "analyzer_agent"]
options = [*members, "FINISH"]

from typing import Literal
from typing_extensions import TypedDict

class Router(TypedDict):
    """Worker to route to next. If no workers needed, route to FINISH."""

    # The class docstring above is used by LangChain's structured output as
    # the schema description shown to the model, so keep it instructive.
    # `Literal[tuple(options)]` is equivalent to `Literal[*options]` but does
    # not require Python 3.11+ starred-subscript syntax.
    next: Literal[tuple(options)]


def supervisor(state: AgentState):
    """
    Ask the main LLM which worker should act next.

    The model receives a system prompt describing the workers, followed by
    the conversation so far, and must answer with a ``Router`` object.

    Returns:
        A state update ``{"next": name}`` where ``name`` is the chosen
        worker's node name, or ``END`` when the model answers ``FINISH``.
    """
    # Parenthesised implicit concatenation yields a single str. Wrapping the
    # literal in braces would build a one-element *set* instead, and the
    # f-prefix is required for the worker list to be interpolated.
    system_prompt = (
        f"You are a supervisor agent tasked with managing a conversation between the following workers: {', '.join(members)}. \n\n"
        "Each worker has a specific role:\n"
        "- rule_agent: Checks if the current traffic situation violates any rules.\n"
        "- accident_agent: Retrieves information about potential accidents in the current scene.\n"
        "- analyzer_agent: Give driving suggestions for ego vehicle based on the output of rule agent and accident agent.\n\n"
        "Given the following user request, respond with the worker to act next."
        " Each worker will perform a task and respond with their results and status."
        " When finished, respond with FINISH."
    )

    messages = [{"role": "system", "content": system_prompt}] + state["messages"]
    response = llm.with_structured_output(Router).invoke(messages)
    next_ = response["next"]

    # "FINISH" is only the model-facing label; the graph terminates on END.
    if next_ == "FINISH":
        next_ = END

    return {"next": next_}


builder = StateGraph(AgentState)

# Register the router first, then each worker node.
builder.add_node("supervisor", supervisor)
for _node_name, _node_fn in (
    ("rule_agent", rule_node),
    ("accident_agent", accident_node),
    ("analyzer_agent", analyze_node),
):
    builder.add_node(_node_name, _node_fn)

# Every worker hands control back to the supervisor when it finishes.
for member in members:
    builder.add_edge(member, "supervisor")

# The supervisor writes state["next"]; follow it to the chosen node (or END).
builder.add_conditional_edges("supervisor", lambda state: state["next"])
builder.add_edge(START, "supervisor")

graph = builder.compile()