from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from recipe.langgraph_agent.example.codebase_tool_node import StatefulToolNode
from data.repo_server.server import get_repo_simple
from recipe.langgraph_agent.chat_model import (
    ChatModel,
    MaxTokenExceededError,
    convert_to_agent_output,
)
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
async def call_model(state: MessagesState, config: RunnableConfig):
    """Graph node: ask the configured chat model for the next message.

    Reads ``model`` and ``sampling_params`` from ``config["configurable"]`` and
    invokes the model on the full conversation in ``state["messages"]``.

    Returns:
        A state update ``{"messages": [reply]}``. If the model raises
        ``MaxTokenExceededError``, it returns ``{"messages": []}``, so no new message is
        appended.
    """
    settings = config["configurable"]
    llm = settings["model"]
    params = settings["sampling_params"]
    try:
        reply = await llm.ainvoke(state["messages"], sampling_params=params)
    except MaxTokenExceededError:
        # Token budget exhausted: add nothing to the conversation.
        return {"messages": []}
    return {"messages": [reply]}
def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
    """Route after the ``agent`` node: run the tool node next, or finish.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.
        config: Runnable config. ``configurable["max_assistant_turns"]`` caps the
            number of AI messages in the conversation; a falsy value disables the cap.

    Returns:
        ``"tools"`` when the last message is an AI message that requests tool calls
        and the assistant-turn budget is not exhausted; otherwise ``END``.
    """
    max_assistant_turns = config["configurable"]["max_assistant_turns"]
    messages = state["messages"]
    if not messages:
        return END
    last_message = messages[-1]
    # `call_model` appends nothing when the model raises MaxTokenExceededError, so the
    # last message may be a tool result or even the original prompt. Only AI messages
    # carry `tool_calls`; other message types (e.g. human) may not define it at all.
    if last_message.type != "ai":
        return END
    num_assistant_turns = sum(1 for message in messages if message.type == "ai")
    if max_assistant_turns and num_assistant_turns >= max_assistant_turns:
        return END
    if not last_message.tool_calls:
        return END
    return "tools"
class ReactAgentLoop(AgentLoopBase):
    """ReAct agent loop that alternates model calls with a stateful tool node.

    The LangGraph workflow is compiled once per class in ``init_class`` and stored
    on ``cls.graph``. ``build_graph`` and ``run`` read ``cls.tools``, but this class
    never assigns it. Presumably a subclass or the base class provides it;
    TODO confirm.
    """

    # Marker inside the first prompt message. The text after its last occurrence is
    # used as the repository id.
    _REPO_ID_SEPARATOR = "-+-+-+-+-+-+-+-+-+-+"

    # Minimum LangGraph super-step budget for one episode.
    _MIN_RECURSION_LIMIT = 50

    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        """Build and cache the shared graph the first time the class is initialized.

        ``config`` and ``tokenizer`` are accepted for interface compatibility but are
        not used here. ``cls._class_initialized`` is presumably declared on
        ``AgentLoopBase``; verify.
        """
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ReactAgentLoop initialization")
        cls.graph = cls.build_graph()

    @classmethod
    def build_graph(cls) -> StateGraph:
        """Compile the ``agent`` <-> ``tools`` workflow.

        ``agent`` (``call_model``) is the entry node. After it, ``should_continue``
        routes to ``tools`` or ``END``. ``tools`` (a ``StatefulToolNode`` over
        ``cls.tools``) always routes back to ``agent``.

        NOTE(review): ``workflow.compile()`` returns a compiled graph, not the
        ``StateGraph`` builder that the return annotation names.
        """
        workflow = StateGraph(MessagesState)
        workflow.add_node("agent", call_model)
        workflow.add_node("tools", StatefulToolNode(cls.tools))
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                END: END,
            },
        )
        workflow.add_edge("tools", "agent")
        return workflow.compile()

    @classmethod
    def _get_repo_id(cls, messages):
        """Return the text after the last separator in ``messages[0]["content"]``.

        If the separator is absent, the whole content string is returned.
        """
        prompt = messages[0]["content"]
        return prompt.split(cls._REPO_ID_SEPARATOR)[-1]

    @staticmethod
    def _get_entry_point(repo_id):
        """Look up ``repo_id`` and return ``(entry_file, entry_content, repo_path)``."""
        result_dict = get_repo_simple(repo_id)
        return result_dict["entry_file"], result_dict["entry_content"], result_dict["repo_path"]

    def _get_teacher(self):
        """Return ``self.config.trainer.teacher``, or ``None`` if it cannot be read."""
        try:
            return self.config.trainer.teacher
        except Exception:
            # Optional setting, so any ordinary error falls back to None. A bare
            # `except:` would also swallow KeyboardInterrupt, SystemExit and
            # asyncio.CancelledError.
            return None

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run one agent episode and convert the final conversation to an output.

        Args:
            sampling_params: Forwarded to the model on every turn via the graph config.
            **kwargs: Must contain ``raw_prompt``, the initial messages. The first
                message's ``"content"`` carries the repository id after
                ``_REPO_ID_SEPARATOR``.

        Returns:
            ``convert_to_agent_output`` applied to the final message list and the
            rollout's ``response_length``.
        """
        messages = list(kwargs["raw_prompt"])
        model_path = self.config.actor_rollout_ref.model.path
        # The model name is the last two components of the model path.
        model_name = "/".join(model_path.split("/")[-2:])
        rollout = self.config.actor_rollout_ref.rollout
        teacher = self._get_teacher()

        model = ChatModel(
            model=model_name,
            client=self.server_manager,
            tokenizer=self.tokenizer,
            max_tokens=rollout.response_length,
            max_parallel_calls=rollout.multi_turn.max_parallel_calls,
            tool_parser=rollout.multi_turn.format,
        )
        model = model.bind_tools(self.tools, tool_choice="any")

        repo_id = self._get_repo_id(messages)
        # NOTE(review): get_repo_simple is a synchronous call inside this coroutine. If it
        # does blocking I/O, it stalls every rollout sharing the event loop. Consider
        # asyncio.to_thread once its thread-safety is confirmed.
        entry_point, entry_content, repo_path = self._get_entry_point(repo_id)

        max_assistant_turns = rollout.multi_turn.max_assistant_turns
        # Each assistant turn can cost two LangGraph super-steps (agent + tools). A fixed
        # limit would hit LangGraph's recursion limit before a larger configured turn
        # budget is used up, so scale the limit with the budget.
        recursion_limit = self._MIN_RECURSION_LIMIT
        if max_assistant_turns:
            recursion_limit = max(recursion_limit, 2 * max_assistant_turns + 2)

        config = {
            "recursion_limit": recursion_limit,
            "configurable": {
                "repo_id": repo_id,
                "entry_file": entry_point,
                "entry_content": entry_content,
                "repo_path": repo_path,
                "model": model,
                "sampling_params": sampling_params,
                "max_user_turns": rollout.multi_turn.max_user_turns,
                "max_assistant_turns": max_assistant_turns,
                "teacher": teacher,
            },
        }
        state = await self.graph.ainvoke(input={"messages": messages}, config=config)
        return convert_to_agent_output(state["messages"], rollout.response_length)