from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, MessagesState, StateGraph
from recipe.langgraph_agent.chat_model import (
    ChatModel,
    MaxTokenExceededError,
    convert_to_agent_output,
)
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
async def call_api(text: str) -> float:
    """Return a placeholder score in ``[0.0, 0.9]`` derived from the text length.

    The score is the last decimal digit of ``len(text)`` divided by ten; the
    body performs no I/O.
    """
    last_digit = len(text) % 10
    return last_digit / 10
async def call_model(state: MessagesState, config: RunnableConfig):
    """Graph node that asks the configured chat model for the next message.

    The model and its sampling parameters are read from
    ``config["configurable"]`` (keys ``"model"`` and ``"sampling_params"``).

    Returns:
        ``{"messages": [reply]}`` with the model's reply, or
        ``{"messages": []}`` when the model raises ``MaxTokenExceededError``.
    """
    settings = config["configurable"]
    chat_model = settings["model"]
    params = settings["sampling_params"]
    try:
        reply = await chat_model.ainvoke(state["messages"], sampling_params=params)
    except MaxTokenExceededError:
        # Report an empty update instead of propagating the error.
        return {"messages": []}
    return {"messages": [reply]}
async def eval_node(state: MessagesState, config: RunnableConfig):
    """Graph node that scores the latest message and reports the score as a tool message.

    The text to score is the ``content`` attribute of the last message in
    ``state["messages"]`` (or ``str(message)`` when it has no such attribute).
    ``config`` is accepted for the node signature but not used.

    Returns:
        ``{"messages": [msg]}`` where ``msg`` is a tool-message dict whose
        content is ``"Score: <score>"``.
    """
    last_message = state["messages"][-1]
    answer_text = getattr(last_message, "content", str(last_message))
    score = await call_api(answer_text)

    # A tool message needs a ``tool_call_id`` to be coerced into a message
    # object. Reuse the id of the evaluated message's first tool call when it
    # has one; otherwise fall back to a fixed id.
    tool_calls = getattr(last_message, "tool_calls", None) or []
    tool_call_id = next((call["id"] for call in tool_calls if call.get("id")), "eval")

    return {
        "messages": [
            {
                "type": "tool",
                "content": f"Score: {score}",
                "tool_call_id": tool_call_id,
            }
        ]
    }
def should_continue(
    state: MessagesState, config: RunnableConfig
) -> Literal["agent", "eval", END]:
    """Pick the next graph node from the type of the latest message.

    * Latest message is not a tool message: return ``"eval"``.
    * Latest message is a tool message: return ``END`` once the number of
      ``"ai"`` messages has reached ``max_assistant_turns`` (read from
      ``config["configurable"]``; a falsy limit disables the check),
      otherwise return ``"agent"``.

    Raises:
        IndexError: if ``state["messages"]`` is empty.
    """
    max_assistant_turns = config["configurable"]["max_assistant_turns"]
    messages = state["messages"]

    if messages[-1].type != "tool":
        return "eval"

    num_assistant_turns = sum(1 for message in messages if message.type == "ai")
    if max_assistant_turns and num_assistant_turns >= max_assistant_turns:
        return END
    return "agent"
class EvaluatedAgentLoop(AgentLoopBase):
    """Agent loop that alternates a chat-model node with a scoring node.

    The LangGraph workflow is compiled once per class by :meth:`init_class`
    and stored on ``cls.graph``; :meth:`run` drives one conversation through it.
    """

    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        """Build the shared graph the first time this is called.

        Later calls return immediately because ``cls._class_initialized`` is
        already set (the flag is presumably declared on ``AgentLoopBase`` --
        confirm). ``config``, ``tokenizer`` and ``kwargs`` are not used here.
        """
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level EvaluatedAgentLoop initialization")
        cls.graph = cls.build_graph()

    @classmethod
    def build_graph(cls) -> StateGraph:
        """Assemble and compile the agent/eval workflow.

        Nodes: ``"agent"`` (``call_model``, the entry point) and ``"eval"``
        (``eval_node``). After ``"agent"`` the graph goes to ``"eval"``, or
        ends when the agent produced no "ai" message. After ``"eval"``,
        ``should_continue`` picks ``"agent"`` or ``END``.

        Returns:
            The result of ``workflow.compile()``.
        """

        def route_after_agent(state: MessagesState, config: RunnableConfig) -> str:
            # call_model returns an empty update when the model raises
            # MaxTokenExceededError, so the latest message is then not an "ai"
            # message. should_continue could answer "agent" for that state,
            # which is not a valid target of this edge, so stop instead.
            messages = state["messages"]
            if not messages or messages[-1].type != "ai":
                return END
            return should_continue(state, config)

        workflow = StateGraph(MessagesState)
        workflow.add_node("agent", call_model)
        workflow.add_node("eval", eval_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            route_after_agent,
            {
                "eval": "eval",
                END: END,
            },
        )
        workflow.add_conditional_edges(
            "eval",
            should_continue,
            {
                "agent": "agent",
                END: END,
            },
        )
        return workflow.compile()

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run one conversation through the compiled graph.

        Args:
            sampling_params: Passed to the model node through the graph config.
            **kwargs: Must contain ``"raw_prompt"``, the initial messages.

        Returns:
            The final graph messages converted by ``convert_to_agent_output``
            with the rollout's ``response_length``.

        Raises:
            KeyError: if ``"raw_prompt"`` is missing from ``kwargs``.
        """
        messages = list(kwargs["raw_prompt"])
        rollout = self.config.actor_rollout_ref.rollout
        multi_turn = rollout.multi_turn

        # The model name is the last two components of the configured model path.
        model_path = self.config.actor_rollout_ref.model.path
        model_name = "/".join(model_path.split("/")[-2:])

        model = ChatModel(
            model=model_name,
            client=self.server_manager,
            tokenizer=self.tokenizer,
            max_tokens=rollout.response_length,
            max_parallel_calls=multi_turn.max_parallel_calls,
            tool_parser=multi_turn.format,
        )
        model = model.bind_tools(self.tools, tool_choice="any")

        # Everything the graph nodes and routers need is read from "configurable".
        config = {
            "configurable": {
                "model": model,
                "sampling_params": sampling_params,
                "max_user_turns": multi_turn.max_user_turns,
                "max_assistant_turns": multi_turn.max_assistant_turns,
            }
        }
        state = await self.graph.ainvoke(input={"messages": messages}, config=config)
        return convert_to_agent_output(state["messages"], rollout.response_length)