{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train ReAct agent with code sandbox\n",
    "\n",
    "In this tutorial, we demonstrate how to train a [ReAct](https://arxiv.org/abs/2210.03629) agent to solve math problems with a code sandbox.\n",
    "\n",
    "The agent works as follows:\n",
    "1. Given a math problem, the agent first queries the LLM to generate a response and tool calls, which are Python code snippets to be executed in the sandbox.\n",
    "2. If there is a tool call, the agent executes the Python code in the code sandbox.\n",
    "3. After code execution, the agent gets the result from the sandbox and appends it to the chat history.\n",
    "4. The agent queries the LLM again, repeating until there are no tool calls or the maximum context length is reached.\n",
    "\n",
    "\n",
    "<figure>\n",
    "  <img src=\"https://langchain-ai.github.io/langgraph/agents/assets/agent.png\" alt=\"ReAct\" width=\"400\">\n",
    "  <figcaption style=\"font-style: italic; color: #666;\">\n",
    "    source: <a href=\"https://langchain-ai.github.io/langgraph/agents/overview/\" target=\"_blank\">LangGraph</a>\n",
    "  </figcaption>\n",
    "</figure>"
   ]
  },
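  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The four steps above can be sketched in plain Python before any real model is involved. The cell below is a minimal, hypothetical illustration, not verl's implementation. `fake_llm` and `fake_sandbox` are stand-ins invented for this sketch. Tool calls use a simplified `{\"code\": ...}` format instead of the OpenAI schema. `max_turns` approximates the maximum-context-length stop condition with a simple turn limit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-ins for the LLM and the sandbox; neither is part of verl.\n",
    "def fake_llm(messages):\n",
    "    \"\"\"Return a (simplified) tool call on the first turn, then a final answer.\"\"\"\n",
    "    if not any(m[\"role\"] == \"tool\" for m in messages):\n",
    "        return {\"role\": \"assistant\", \"content\": \"\", \"tool_calls\": [{\"code\": \"print(2 + 3)\"}]}\n",
    "    return {\"role\": \"assistant\", \"content\": f\"The answer is {messages[-1]['content'].strip()}.\"}\n",
    "\n",
    "\n",
    "def fake_sandbox(code):\n",
    "    \"\"\"Pretend to run `code`; a real sandbox would execute it in isolation.\"\"\"\n",
    "    return \"5\\n\" if code == \"print(2 + 3)\" else \"\"\n",
    "\n",
    "\n",
    "def react_loop(question, max_turns=4):\n",
    "    messages = [{\"role\": \"user\", \"content\": question}]\n",
    "    for _ in range(max_turns):  # stand-in for the max context length check\n",
    "        reply = fake_llm(messages)  # step 1: query the LLM\n",
    "        messages.append(reply)\n",
    "        tool_calls = reply.get(\"tool_calls\")\n",
    "        if not tool_calls:  # step 4: stop when there is no tool call\n",
    "            break\n",
    "        for call in tool_calls:  # step 2: execute each tool call\n",
    "            output = fake_sandbox(call[\"code\"])\n",
    "            messages.append({\"role\": \"tool\", \"content\": output})  # step 3: append the result\n",
    "    return messages\n",
    "\n",
    "\n",
    "history = react_loop(\"What is 2 + 3?\")\n",
    "print(history[-1][\"content\"])"
   ]
  },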
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Prerequisites\n",
    "\n",
    "To run the examples in this notebook, you need to install the verl package first.\n",
    "```bash\n",
    "git clone https://github.com/volcengine/verl\n",
    "cd verl\n",
    "pip install -e .\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-10-16 23:20:11,956\tINFO worker.py:2004 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n",
      "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py:2052: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import asyncio\n",
    "import sys\n",
    "import tempfile\n",
    "import os\n",
    "import socket\n",
    "import json\n",
    "\n",
    "import requests\n",
    "import ray\n",
    "import fastapi\n",
    "import uvicorn\n",
    "from starlette.requests import Request\n",
    "from starlette.responses import JSONResponse\n",
    "from pprint import pprint\n",
    "\n",
    "import verl\n",
    "\n",
    "ray.init()\n",
    "verl_config_dir = os.path.join(os.path.dirname(verl.__file__), \"trainer/config\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purposes, we use Qwen/Qwen3-1.7B as the LLM. First, let's download the model and dataset used in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyarrow.parquet as pq\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "snapshot_download(\n",
    "    repo_id=\"verl-team/lighteval-MATH-preprocessed\",\n",
    "    repo_type=\"dataset\",\n",
    "    local_dir=os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed\"),\n",
    ")\n",
    "snapshot_download(\n",
    "    repo_id=\"Qwen/Qwen3-1.7B\",\n",
    "    repo_type=\"model\",\n",
    "    local_dir=os.path.expanduser(\"~/Qwen/Qwen3-1.7B\"),\n",
    ")\n",
    "\n",
    "model_path = os.path.expanduser(\"~/Qwen/Qwen3-1.7B\")\n",
    "train_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/train.parquet\")\n",
    "test_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/test.parquet\")\n",
    "\n",
    "# Keep only the first 100 test examples so that validation stays fast.\n",
    "test = pq.read_table(test_file)\n",
    "test_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/test_100.parquet\")\n",
    "pq.write_table(test[:100], test_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "verl supports both vLLM and SGLang rollout servers for high-performance inference. This tutorial has been tested with both, so you can choose either one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "rollout_name = \"???\"  # set to \"vllm\" or \"sglang\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Basic tool call\n",
    "To begin, let's see how to make a basic tool call in verl, using the example from [Transformers tool use](https://huggingface.co/docs/transformers/main/chat_extras#tool-use). To use a tool in verl, we define a tool class that inherits from `BaseTool` and implements the following methods:\n",
    "- `get_openai_tool_schema`: return the schema of the tool in `OpenAIFunctionToolSchema` format.\n",
    "- `execute`: execute the tool with the given parameters and return a tuple of a `ToolResponse`, a float reward, and a dict of metrics (the `WeatherTool` below returns `0` and `{}` for the last two)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"type\": \"function\",\n",
      "  \"function\": {\n",
      "    \"name\": \"get_current_temperature\",\n",
      "    \"description\": \"Get current temperature at a location.\",\n",
      "    \"parameters\": {\n",
      "      \"type\": \"object\",\n",
      "      \"properties\": {\n",
      "        \"location\": {\n",
      "          \"type\": \"string\",\n",
      "          \"description\": \"The location to get the temperature for, in the format \\\"City, State, Country\\\".\"\n",
      "        },\n",
      "        \"unit\": {\n",
      "          \"type\": \"string\",\n",
      "          \"description\": \"The unit to return the temperature in. Defaults to \\\"celsius\\\".\",\n",
      "          \"enum\": [\n",
      "            \"celsius\",\n",
      "            \"fahrenheit\"\n",
      "          ]\n",
      "        }\n",
      "      },\n",
      "      \"required\": [\n",
      "        \"location\"\n",
      "      ]\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "from transformers.utils import get_json_schema\n",
    "from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse\n",
    "\n",
    "\n",
    "class WeatherTool(BaseTool):\n",
    "    def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n",
    "        \"\"\"Get current temperature at a location.\n",
    "\n",
    "        Args:\n",
    "            location: The location to get the temperature for, in the format \"City, State, Country\".\n",
    "            unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n",
    "\n",
    "        Returns:\n",
    "            the temperature, the location, and the unit in a dict\n",
    "        \"\"\"\n",
    "        return {\n",
    "            \"temperature\": 26.1,\n",
    "            \"location\": location,\n",
    "            \"unit\": unit,\n",
    "        }\n",
    "\n",
    "    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n",
    "        schema = get_json_schema(self.get_current_temperature)\n",
    "        return OpenAIFunctionToolSchema(**schema)\n",
    "\n",
    "    async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[ToolResponse, float, dict]:\n",
    "        try:\n",
    "            result = self.get_current_temperature(**parameters)\n",
    "            return ToolResponse(text=json.dumps(result)), 0, {}\n",
    "        except Exception as e:\n",
    "            return ToolResponse(text=str(e)), 0, {}\n",
    "\n",
    "\n",
    "weather_tool = WeatherTool(config={}, tool_schema=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's launch a standalone rollout server without the hybrid engine (which is heavier to start) to test the basic tool call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize_config_dir\n",
    "from verl.workers.rollout.replica import get_rollout_replica_class\n",
    "\n",
    "with initialize_config_dir(config_dir=verl_config_dir):\n",
    "    config = compose(\n",
    "        config_name=\"ppo_trainer\",\n",
    "        overrides=[\n",
    "            \"actor_rollout_ref.rollout.name=\" + rollout_name,\n",
    "            \"actor_rollout_ref.rollout.mode=async\",\n",
    "            \"actor_rollout_ref.rollout.tensor_model_parallel_size=1\",\n",
    "            \"actor_rollout_ref.model.path=\" + model_path,\n",
    "            \"actor_rollout_ref.rollout.response_length=4096\",\n",
    "            \"actor_rollout_ref.rollout.skip_tokenizer_init=False\",\n",
    "            \"+actor_rollout_ref.rollout.engine_kwargs.vllm.enable_auto_tool_choice=True\",\n",
    "            \"+actor_rollout_ref.rollout.engine_kwargs.vllm.tool_call_parser=hermes\",\n",
    "            \"+actor_rollout_ref.rollout.engine_kwargs.sglang.tool_call_parser=qwen25\",\n",
    "        ],\n",
    "    )\n",
    "\n",
    "rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name)\n",
    "rollout_server = rollout_server_class(\n",
    "    replica_rank=0,\n",
    "    config=config.actor_rollout_ref.rollout,\n",
    "    model_config=config.actor_rollout_ref.model,\n",
    ")\n",
    "\n",
    "await rollout_server.init_standalone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we can query the LLM with the OpenAI client. Note that we need to pass the tool schema to the server so that it can guide the LLM in generating tool calls. As the output shows, the LLM correctly generates a tool call to get the temperature in Paris."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'content': \"Hey, what's the temperature in Paris right now?\", 'role': 'user'},\n",
      " {'role': 'assistant',\n",
      "  'tool_calls': [{'function': {'arguments': '{\"location\": \"Paris, France\"}',\n",
      "                               'name': 'get_current_temperature'},\n",
      "                  'id': 'call_b10bdde504a0411690e96b55',\n",
      "                  'index': -1,\n",
      "                  'type': 'function'}]}]\n"
     ]
    }
   ],
   "source": [
    "from openai import AsyncOpenAI\n",
    "\n",
    "client = AsyncOpenAI(\n",
    "    api_key=\"dummy\",\n",
    "    base_url=f\"http://{rollout_server._server_address}/v1\",\n",
    ")\n",
    "\n",
    "messages = [{\"role\": \"user\", \"content\": \"Hey, what's the temperature in Paris right now?\"}]\n",
    "completion = await client.chat.completions.create(\n",
    "    model=config.actor_rollout_ref.model.path,\n",
    "    messages=messages,\n",
    "    tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n",
    "    extra_body={\n",
    "        \"chat_template_kwargs\": {\"enable_thinking\": False},\n",
    "    },\n",
    ")\n",
    "\n",
    "message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n",
    "messages.append(message)\n",
    "pprint(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can execute the tool call with the arguments generated by the LLM to get the temperature in Paris."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "text='{\"temperature\": 26.1, \"location\": \"Paris, France\", \"unit\": \"celsius\"}' image=None video=None\n"
     ]
    }
   ],
   "source": [
    "args = json.loads(message[\"tool_calls\"][0][\"function\"][\"arguments\"])\n",
    "tool_response, _, _ = await weather_tool.execute(\"\", args)\n",
    "print(tool_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we add the tool response to the chat history and query the LLM again. With the tool response available, the LLM can generate a final answer for the user."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'content': \"Hey, what's the temperature in Paris right now?\", 'role': 'user'},\n",
      " {'role': 'assistant',\n",
      "  'tool_calls': [{'function': {'arguments': '{\"location\": \"Paris, France\"}',\n",
      "                               'name': 'get_current_temperature'},\n",
      "                  'id': 'call_b10bdde504a0411690e96b55',\n",
      "                  'index': -1,\n",
      "                  'type': 'function'}]},\n",
      " {'content': '{\"temperature\": 26.1, \"location\": \"Paris, France\", \"unit\": '\n",
      "             '\"celsius\"}',\n",
      "  'role': 'tool'},\n",
      " {'content': 'The current temperature in Paris is 26.1°C.',\n",
      "  'role': 'assistant'}]\n"
     ]
    }
   ],
   "source": [
    "messages.append({\"role\": \"tool\", \"content\": tool_response.text})\n",
    "completion = await client.chat.completions.create(\n",
    "    model=config.actor_rollout_ref.model.path,\n",
    "    messages=messages,\n",
    "    tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n",
    "    extra_body={\n",
    "        \"chat_template_kwargs\": {\"enable_thinking\": False},\n",
    "    },\n",
    ")\n",
    "\n",
    "message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n",
    "messages.append(message)\n",
    "pprint(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Advanced tool call with code sandbox\n",
    "\n",
    "Now, let's look at a more realistic example: tool calling with a code sandbox, which is widely used in real-world applications.\n",
    "\n",
    "### 3.1 Implement a naive code sandbox\n",
    "\n",
    "To execute Python code snippets generated by the LLM, we need a code sandbox environment. In this tutorial, we implement a very naive code sandbox: a FastAPI HTTP server with a `/run_code` endpoint. The server works as follows:\n",
    "1. Receive an HTTP request and write the Python code snippet to a temp file.\n",
    "2. Spawn a subprocess to execute the code and capture its stdout and stderr.\n",
    "3. Return the stdout and stderr of the subprocess as the HTTP response.\n",
    "\n",
    "> 🚨 **WARNING:** This naive code sandbox is for demonstration purposes only; do not use it in production. Use Docker or Kata containers for stronger isolation and security restrictions."
   ]
  },
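  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps can also be sketched synchronously with only the standard library. This is handy for checking the logic outside Ray and FastAPI. `run_code_once` is a hypothetical helper, not part of verl. Its `timeout` argument is an addition that the sandbox below does not have: it is one way to stop generated code that never terminates, because `subprocess.run` kills the child when the timeout expires. The `\"Timeout\"` status string is likewise made up for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def run_code_once(code: str, timeout: float = 10.0) -> dict:\n",
    "    \"\"\"Hypothetical synchronous version of the sandbox's three steps.\"\"\"\n",
    "    # 1. Write the snippet to a temp file, closing the descriptor returned by mkstemp.\n",
    "    fd, path = tempfile.mkstemp(suffix=\".py\", prefix=\"temp_code\")\n",
    "    with os.fdopen(fd, \"w\") as f:\n",
    "        f.write(code)\n",
    "    try:\n",
    "        # 2. Run it in a subprocess and capture stdout/stderr as text.\n",
    "        proc = subprocess.run([sys.executable, path], capture_output=True, text=True, timeout=timeout)\n",
    "        status, stdout, stderr, return_code = \"Finished\", proc.stdout, proc.stderr, proc.returncode\n",
    "    except subprocess.TimeoutExpired:\n",
    "        # \"Timeout\" is a status made up for this sketch.\n",
    "        status, stdout, stderr, return_code = \"Timeout\", \"\", \"Execution timed out\", None\n",
    "    finally:\n",
    "        os.unlink(path)\n",
    "    # 3. Return a payload shaped like the sandbox's JSON response.\n",
    "    return {\n",
    "        \"status\": \"Success\" if return_code == 0 else \"Failed\",\n",
    "        \"run_result\": {\"status\": status, \"stdout\": stdout, \"stderr\": stderr, \"return_code\": return_code},\n",
    "    }\n",
    "\n",
    "\n",
    "print(run_code_once(\"print(6 * 7)\")[\"run_result\"][\"stdout\"])"
   ]
  },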
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote(num_cpus=1)\n",
    "class Sandbox:\n",
    "    \"\"\"Sandbox to execute python code.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.address = ray._private.services.get_node_ip_address()\n",
    "        self.port = self._get_free_port()\n",
    "        asyncio.create_task(self._start_fastapi_server())\n",
    "\n",
    "    async def code_execution(self, request: Request):\n",
    "        request_json = await request.json()\n",
    "        code = request_json[\"code\"]\n",
    "        # print(f\"execute code:\\n{code}\")\n",
    "\n",
    "        fd, temp_file = tempfile.mkstemp(suffix=\".py\", prefix=\"temp_code\", dir=None, text=True)\n",
    "        # Reuse the descriptor returned by mkstemp so that it is closed after writing.\n",
    "        with os.fdopen(fd, \"w\") as f:\n",
    "            f.write(code)\n",
    "\n",
    "        try:\n",
    "            process = await asyncio.create_subprocess_exec(\n",
    "                sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE\n",
    "            )\n",
    "\n",
    "            stdout, stderr = await process.communicate()\n",
    "\n",
    "            response = {\n",
    "                \"status\": \"Success\" if process.returncode == 0 else \"Failed\",\n",
    "                \"run_result\": {\n",
    "                    \"status\": \"Finished\",\n",
    "                    \"stdout\": stdout.decode(),\n",
    "                    \"stderr\": stderr.decode(),\n",
    "                    \"return_code\": process.returncode,\n",
    "                },\n",
    "            }\n",
    "            return JSONResponse(content=response)\n",
    "        finally:\n",
    "            try:\n",
    "                os.unlink(temp_file)\n",
    "            except OSError:\n",
    "                pass\n",
    "\n",
    "    def _get_free_port(self):\n",
    "        with socket.socket() as sock:\n",
    "            sock.bind((\"\", 0))\n",
    "            return sock.getsockname()[1]\n",
    "\n",
    "    async def _start_fastapi_server(self):\n",
    "        app = fastapi.FastAPI()\n",
    "        app.router.add_api_route(\"/run_code\", self.code_execution, methods=[\"POST\"])\n",
    "\n",
    "        config = uvicorn.Config(app, host=[\"::\", \"0.0.0.0\"], port=self.port, log_level=\"warning\")\n",
    "        server = uvicorn.Server(config)\n",
    "        await server.serve()\n",
    "\n",
    "    async def get_server_address(self) -> str:\n",
    "        \"\"\"Get FastAPI server address.\"\"\"\n",
    "        return f\"{self.address}:{self.port}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sandbox = Sandbox.remote()\n",
    "sandbox_address = ray.get(sandbox.get_server_address.remote())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Define sandbox tool\n",
    "\n",
    "As in the previous section, we define a tool for the code sandbox. In the `execute` method, we send the code snippet to the sandbox via an HTTP request and return its output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"type\": \"function\",\n",
      "  \"function\": {\n",
      "    \"name\": \"code_interpreter\",\n",
      "    \"description\": \"Execute the code in the sandbox.\",\n",
      "    \"parameters\": {\n",
      "      \"type\": \"object\",\n",
      "      \"properties\": {\n",
      "        \"code\": {\n",
      "          \"type\": \"string\",\n",
      "          \"description\": \"The code to be executed.\"\n",
      "        }\n",
      "      },\n",
      "      \"required\": [\n",
      "        \"code\"\n",
      "      ]\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "import aiohttp\n",
    "\n",
    "\n",
    "class SandboxTool(BaseTool):\n",
    "    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n",
    "        super().__init__(config, tool_schema)\n",
    "        # Different models may use different code fences, e.g. ```python or ```py.\n",
    "        self.code_pattern = re.compile(r\"```(?:python|py)?(.*?)```\", re.DOTALL)\n",
    "\n",
    "    async def code_interpreter(self, code: str) -> str:\n",
    "        \"\"\"Execute the code in the sandbox.\n",
    "\n",
    "        Args:\n",
    "            code: The code to be executed.\n",
    "\n",
    "        Returns:\n",
    "            str: The output of the code execution.\n",
    "        \"\"\"\n",
    "        async with aiohttp.ClientSession() as session:\n",
    "            async with session.post(\n",
    "                self.config.get(\"sandbox_fusion_url\"),\n",
    "                json={\"code\": code},\n",
    "            ) as resp:\n",
    "                resp.raise_for_status()\n",
    "                result = await resp.json()\n",
    "                stdout, stderr = result[\"run_result\"][\"stdout\"], result[\"run_result\"][\"stderr\"]\n",
    "                return stdout + stderr\n",
    "\n",
    "    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n",
    "        schema = get_json_schema(self.code_interpreter)\n",
    "        return OpenAIFunctionToolSchema(**schema)\n",
    "\n",
    "    async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[ToolResponse, float, dict]:\n",
    "        code = parameters[\"code\"]\n",
    "        matches = self.code_pattern.findall(code)\n",
    "        if matches:\n",
    "            code = matches[0].strip()\n",
    "\n",
    "        # NOTE: Some scripts do not explicitly print their result, so we wrap the last non-empty line in print().\n",
    "        # A better approach is to SFT the model to print results by default; we skip the SFT stage in this tutorial.\n",
    "        lines = code.split(\"\\n\")\n",
    "        for i, line in reversed(list(enumerate(lines))):\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            if not line.startswith(\"print\"):\n",
    "                lines[i] = f\"print({line})\"\n",
    "            break\n",
    "        code = \"\\n\".join(lines)\n",
    "\n",
    "        result = await self.code_interpreter(code)\n",
    "        return ToolResponse(text=result), 0.0, {}\n",
    "\n",
    "\n",
    "sandbox_tool = SandboxTool(config={\"sandbox_fusion_url\": f\"http://{sandbox_address}/run_code\"}, tool_schema=None)"
   ]
  },
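  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `print(...)` wrapping in `execute` is a line-based heuristic, so it can misfire. A trailing assignment such as `x = 1` becomes `print(x = 1)`, which raises a `TypeError`. An indented last line (for example inside a loop) also loses its indentation. One possible alternative is sketched below as a hypothetical helper; `wrap_last_expr` is not a verl API. It parses the snippet with Python's `ast` module and wraps only a trailing top-level expression. It assumes Python 3.9+ for `ast.unparse`, which drops comments and normalizes formatting when the snippet is rewritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "\n",
    "def wrap_last_expr(code: str) -> str:\n",
    "    \"\"\"Wrap the final top-level expression statement in print(); leave everything else as is.\"\"\"\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "    except SyntaxError:\n",
    "        return code  # let the sandbox report the syntax error back to the LLM\n",
    "    if not tree.body or not isinstance(tree.body[-1], ast.Expr):\n",
    "        return code  # empty snippet, or it ends with an assignment, loop, etc.\n",
    "    value = tree.body[-1].value\n",
    "    is_print = isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id == \"print\"\n",
    "    if is_print:\n",
    "        return code\n",
    "    call = ast.Call(func=ast.Name(id=\"print\", ctx=ast.Load()), args=[value], keywords=[])\n",
    "    tree.body[-1] = ast.Expr(value=call)\n",
    "    ast.fix_missing_locations(tree)\n",
    "    return ast.unparse(tree)\n",
    "\n",
    "\n",
    "print(wrap_last_expr(\"from sympy import symbols, solve\\nx = symbols('x')\\nroots = solve(x**2 + x - 6, x)\\nroots\"))"
   ]
  },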
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's execute valid code and check that the response contains its stdout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(ToolResponse(text='sqrt(3)\\n', image=None, video=None), 0.0, {})\n"
     ]
    }
   ],
   "source": [
    "code = \"\"\"```py\n",
    "import sympy\n",
    "\n",
    "print(sympy.sqrt(3))\n",
    "```\"\"\"\n",
    "\n",
    "print(await sandbox_tool.execute(instance_id=\"\", parameters={\"code\": code}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, let's execute invalid code and check that the response contains its stderr. The error message is important: it tells the LLM how to fix the code in the next turn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(ToolResponse(text='Traceback (most recent call last):\\n  File \"/tmp/temp_code3e2f638_.py\", line 2, in <module>\\n    print(sympy.sqrt(3))\\n          ^^^^^\\nNameError: name \\'sympy\\' is not defined\\n', image=None, video=None), 0.0, {})\n"
     ]
    }
   ],
   "source": [
    "code_invalid = \"\"\"\n",
    "print(sympy.sqrt(3))\n",
    "\"\"\"\n",
    "\n",
    "print(await sandbox_tool.execute(instance_id=\"\", parameters={\"code\": code_invalid}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Test sandbox tool\n",
    "\n",
    "Now, we can test the sandbox tool with a real math problem. In this tutorial, we use the preprocessed copy downloaded above, which is based on the [DigitalLearningGmbH/MATH-lighteval](https://huggingface.co/datasets/DigitalLearningGmbH/MATH-lighteval) dataset. It consists of problems from mathematics competitions, including the AMC 10, AMC 12, AIME, and more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebd09c8816b140a59a879e5a5e218950",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"parquet\", data_files=test_file)[\"train\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For debugging purposes, we can implement the ReAct agent as a simple loop. RL training involves more subtle issues and corner cases, so verl provides a built-in ReAct agent loop, which is discussed in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No tool calls, finish_reason: stop\n"
     ]
    }
   ],
   "source": [
    "messages = dataset[\"prompt\"][0]\n",
    "\n",
    "while True:\n",
    "    # 1. Chat with the model\n",
    "    completion = await client.chat.completions.create(\n",
    "        model=config.actor_rollout_ref.model.path,\n",
    "        messages=messages,\n",
    "        tools=[sandbox_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n",
    "        extra_body={\n",
    "            \"chat_template_kwargs\": {\"enable_thinking\": False},\n",
    "        },\n",
    "    )\n",
    "\n",
    "    message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n",
    "    messages.append(message)\n",
    "\n",
    "    # 2. Call tools\n",
    "    finish_reason = completion.choices[0].finish_reason\n",
    "    if finish_reason != \"tool_calls\":\n",
    "        print(f\"No tool calls, finish_reason: {finish_reason}\")\n",
    "        break\n",
    "\n",
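    "    try:\n",
    "        tool_call = completion.choices[0].message.tool_calls[0]\n",
    "        args = json.loads(tool_call.function.arguments)\n",
    "        result, _, _ = await sandbox_tool.execute(\"\", args)\n",
    "    except Exception as e:\n",
    "        # Feed the error back to the model instead of reusing a stale (or undefined) result.\n",
    "        print(f\"Error: {e}\")\n",
    "        result = ToolResponse(text=f\"Error: {e}\")\n",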
    "\n",
    "    # 3. Add tool response to messages\n",
    "    messages.append(\n",
    "        {\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": result.text,\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'content': \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have? Let's think step by step and output the final answer within \\\\boxed{}.\",\n",
       "  'role': 'user'},\n",
       " {'content': \"To determine the number of vertical asymptotes for the function $ y = \\\\frac{2}{x^2 + x - 6} $, we need to find the values of $ x $ where the denominator equals zero, as these points are where the function is undefined and potentially where it has vertical asymptotes.\\n\\nThe denominator is $ x^2 + x - 6 $. To find the vertical asymptotes, we need to solve the equation:\\n\\n$$ x^2 + x - 6 = 0 $$\\n\\nThis is a quadratic equation, and we can solve it using the quadratic formula:\\n\\n$$ x = \\\\frac{-b \\\\pm \\\\sqrt{b^2 - 4ac}}{2a} $$\\n\\nwhere $ a = 1 $, $ b = 1 $, and $ c = -6 $. Let's solve this equation to find the values of $ x $ where the denominator is zero, which will give us the vertical asymptotes.\",\n",
       "  'role': 'assistant',\n",
       "  'tool_calls': [{'id': 'call_4d873672ff8445159e4e5e45',\n",
       "    'function': {'arguments': '{\"code\": \"from sympy import symbols, solve\\\\nx = symbols(\\'x\\')\\\\nroots = solve(x**2 + x - 6, x)\\\\nroots\"}',\n",
       "     'name': 'code_interpreter'},\n",
       "    'type': 'function',\n",
       "    'index': -1}]},\n",
       " {'role': 'tool', 'content': '[-3, 2]\\n'},\n",
       " {'content': 'The roots of the equation $ x^2 + x - 6 = 0 $ are $ x = -3 $ and $ x = 2 $. These are the values of $ x $ where the denominator is zero, which means the function $ y = \\\\frac{2}{x^2 + x - 6} $ is undefined at these points. \\n\\nSince the denominator is zero at these values, the function has vertical asymptotes at $ x = -3 $ and $ x = 2 $. Therefore, the graph of the function has two vertical asymptotes.\\n\\nThe final answer is $\\\\boxed{2}$.',\n",
       "  'role': 'assistant'}]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the ReAct agent properly queries the LLM, executes the sandbox tool call, and finally generates the answer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. End-to-end training with tool agent loop\n",
    "\n",
    "After the tool has been implemented and tested, we can run end-to-end RL training to teach the model to use the tool properly. To simplify agentic RL training, verl provides the [Agent Loop](https://verl.readthedocs.io/en/latest/advance/agent_loop.html) abstraction, which allows users to define custom agent loops such as:\n",
    "- Search agent\n",
    "- Math agent\n",
    "- SWE agent\n",
    "- GUI agent\n",
    "- ...\n",
    "\n",
    "For ease of use, verl provides two predefined agent loops:\n",
    "- SingleTurnAgentLoop: single-turn conversation without tool calling\n",
    "- ToolAgentLoop: multi-turn conversation with tool calling and interaction\n",
    "\n",
    "To use ToolAgentLoop, users only need to provide a tool configuration in a JSON/YAML file. In the configuration file, specify the following fields for each tool:\n",
    "- class_name: fully qualified class name of the tool, used to dynamically load the custom tool class\n",
    "- config: the `config` dict passed to the tool's constructor (as in `SandboxTool(config=...)` above)\n",
    "\n",
    "Let's dump our sandbox tool configuration to a JSON file:"
   ]
  },
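  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what a fully qualified class name means, the cell below resolves one with `importlib`. `load_class` is a hypothetical helper that only shows the general mechanism; it is not verl's actual loader. Because the name is resolved by importing a module, `class_name` presumably has to point at an importable module (here a `sandbox.py` file containing `SandboxTool`) rather than at a class defined only inside this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "\n",
    "def load_class(qualified_name: str):\n",
    "    \"\"\"Resolve a name like 'package.module.ClassName' to the class object (hypothetical helper).\"\"\"\n",
    "    module_name, _, class_name = qualified_name.rpartition(\".\")\n",
    "    if not module_name:\n",
    "        raise ValueError(f\"expected 'module.ClassName', got {qualified_name!r}\")\n",
    "    module = importlib.import_module(module_name)\n",
    "    return getattr(module, class_name)\n",
    "\n",
    "\n",
    "# A standard-library class keeps the example self-contained; a tool entry would use e.g. \"sandbox.SandboxTool\".\n",
    "cls = load_class(\"collections.OrderedDict\")\n",
    "print(cls.__name__, dict(cls(a=1)))"
   ]
  },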
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-10-16 23:07:16,868\tINFO worker.py:2004 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
     ]
    }
   ],
   "source": [
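    "# Shut down Ray to release the standalone rollout server; the next remote call starts a fresh local Ray instance.\n",
    "ray.shutdown()\n",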
    "\n",
    "sandbox = Sandbox.remote()\n",
    "sandbox_address = ray.get(sandbox.get_server_address.remote())\n",
    "\n",
    "tool_config = {\n",
    "    \"tools\": [\n",
    "        {\n",
    "            \"class_name\": \"sandbox.SandboxTool\",\n",
    "            \"config\": {\n",
    "                \"type\": \"native\",\n",
    "                \"sandbox_fusion_url\": f\"http://{sandbox_address}/run_code\",\n",
    "            },\n",
    "        },\n",
    "    ],\n",
    "}\n",
    "\n",
    "tool_config_path = \"tool_config.json\"\n",
    "with open(tool_config_path, \"w\") as f:\n",
    "    json.dump(tool_config, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize_config_dir\n",
    "\n",
    "# version_base=None opts into the current Hydra defaults and silences the version_base UserWarning\n",
    "with initialize_config_dir(config_dir=verl_config_dir, version_base=None):\n",
    "    config = compose(\n",
    "        config_name=\"ppo_trainer\",\n",
    "        overrides=[\n",
    "            \"algorithm.adv_estimator=grpo\",\n",
    "            \"data.train_files=\" + train_file,\n",
    "            \"data.val_files=\" + test_file,\n",
    "            \"data.return_raw_chat=True\",\n",
    "            \"data.train_batch_size=32\",\n",
    "            \"data.max_prompt_length=1024\",\n",
    "            \"data.max_response_length=1024\",\n",
    "            \"+data.apply_chat_template_kwargs.enable_thinking=False\",\n",
    "            # actor related\n",
    "            \"actor_rollout_ref.model.path=\" + model_path,\n",
    "            \"actor_rollout_ref.actor.ppo_mini_batch_size=8\",\n",
    "            \"actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8\",\n",
    "            \"actor_rollout_ref.actor.fsdp_config.param_offload=True\",\n",
    "            \"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\",\n",
    "            # rollout related\n",
    "            \"actor_rollout_ref.rollout.name=\" + rollout_name,\n",
    "            \"actor_rollout_ref.rollout.mode=async\",\n",
    "            \"actor_rollout_ref.rollout.tensor_model_parallel_size=1\",\n",
    "            \"actor_rollout_ref.rollout.n=8\",\n",
    "            \"actor_rollout_ref.rollout.multi_turn.tool_config_path=\" + tool_config_path,\n",
    "            \"actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent\",\n",
    "            \"actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8\",\n",
    "            # trainer related\n",
    "            \"trainer.val_before_train=True\",\n",
    "            \"trainer.log_val_generations=10\",\n",
    "            \"trainer.n_gpus_per_node=8\",\n",
    "            \"trainer.test_freq=-1\",\n",
    "            \"trainer.total_training_steps=5\",\n",
    "            \"trainer.logger=['console','tensorboard', 'wandb']\",\n",
    "            \"trainer.project_name=verl\",\n",
    "            \"trainer.experiment_name=\" + os.path.basename(model_path),\n",
    "        ],\n",
    "    )"
   ]
  },
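  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the overrides above, the batch sizes can be multiplied out. The sketch below assumes verl's convention that `ppo_mini_batch_size` counts prompts and is scaled by `rollout.n` before being split across GPUs; verify this against the verl version you run.\n",
    "\n",
    "```python\n",
    "# Values copied from the overrides above\n",
    "train_batch_size = 32     # prompts per training step\n",
    "rollout_n = 8             # sampled responses per prompt (the GRPO group size)\n",
    "ppo_mini_batch_size = 8   # prompts per PPO mini-batch (assumed to be scaled by rollout_n)\n",
    "n_gpus = 8\n",
    "micro_batch_per_gpu = 8\n",
    "\n",
    "responses_per_step = train_batch_size * rollout_n\n",
    "responses_per_mini_batch = ppo_mini_batch_size * rollout_n\n",
    "updates_per_step = responses_per_step // responses_per_mini_batch\n",
    "mini_batch_per_gpu = responses_per_mini_batch // n_gpus\n",
    "grad_accum_steps = mini_batch_per_gpu // micro_batch_per_gpu\n",
    "\n",
    "print(responses_per_step, updates_per_step, mini_batch_per_gpu, grad_accum_steps)  # 256 4 8 1\n",
    "```\n",
    "\n",
    "Under these assumptions, each step generates 256 trajectories and performs 4 optimizer updates, each on 64 trajectories (8 per GPU in a single micro-batch, so no gradient accumulation)."
   ]
  },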
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from verl.trainer.main_ppo import main\n",
    "\n",
    "main(config)"
   ]
  },
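  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `algorithm.adv_estimator=grpo`, each response is scored relative to the other `rollout.n` responses sampled for the same prompt, so no value model is needed. The sketch below shows the usual group-normalized outcome advantage, (reward - group mean) / (group std + eps). It is a plain-Python illustration, and verl's implementation may differ in details such as the std estimator, `eps`, or optional normalization flags.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "\n",
    "def grpo_advantages(rewards, group_ids, eps=1e-6):\n",
    "    # Collect the rewards of all responses sampled for the same prompt\n",
    "    groups = {}\n",
    "    for r, g in zip(rewards, group_ids):\n",
    "        groups.setdefault(g, []).append(r)\n",
    "    # Per-group mean and sample standard deviation (0.0 for a single response)\n",
    "    stats = {}\n",
    "    for g, vals in groups.items():\n",
    "        std = statistics.stdev(vals) if len(vals) > 1 else 0.0\n",
    "        stats[g] = (statistics.mean(vals), std)\n",
    "    # Normalize each reward against its own group's statistics\n",
    "    return [(r - stats[g][0]) / (stats[g][1] + eps) for r, g in zip(rewards, group_ids)]\n",
    "\n",
    "\n",
    "# Two prompts with four responses each; reward 1.0 marks a correct final answer\n",
    "rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n",
    "group_ids = [0, 0, 0, 0, 1, 1, 1, 1]\n",
    "print([round(a, 3) for a in grpo_advantages(rewards, group_ids)])\n",
    "# [0.866, -0.866, -0.866, 0.866, 0.0, 0.0, 0.0, 0.0]\n",
    "```\n",
    "\n",
    "Groups in which every response receives the same reward, like the second prompt here, get zero advantage and therefore contribute no policy-gradient signal."
   ]
  },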
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demo purposes, we only train for 5 steps. You can verify the training process by checking the following wandb metrics:\n",
    "- num_turns: min/max/mean chat conversation turns in each step.\n",
    "- critic rewards: min/max/mean critic rewards in each step.\n",
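    "\n",
    "Both metrics are per-step summaries over the batch. A minimal sketch of that aggregation, assuming a plain list of per-trajectory values (the `summarize` helper and the `name/stat` key format are illustrative, not verl's logging code):\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "\n",
    "def summarize(name, values):\n",
    "    # Reduce one step's per-trajectory values to min/max/mean entries\n",
    "    return {\n",
    "        f'{name}/min': min(values),\n",
    "        f'{name}/max': max(values),\n",
    "        f'{name}/mean': statistics.mean(values),\n",
    "    }\n",
    "\n",
    "\n",
    "# e.g. four trajectories that used 2, 4, 4 and 6 conversation turns\n",
    "print(summarize('num_turns', [2, 4, 4, 6]))\n",
    "```\n",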
    "\n",
    "For more realistic agentic RL training, please refer to our recipes:\n",
    "- [retool](https://github.com/volcengine/verl/tree/main/recipe/retool): implementation of the paper [ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536)\n",
    "- [collabllm](https://github.com/volcengine/verl/tree/main/recipe/collabllm): implementation of the paper [CollabLLM: From Passive Responders to Active Collaborators](https://arxiv.org/pdf/2502.00640)\n",
    "- [deepeyes](https://github.com/volcengine/verl/tree/main/recipe/deepeyes): implementation of the paper [DeepEyes: Incentivizing \"Thinking with Images\" via Reinforcement Learning](https://arxiv.org/abs/2505.14362)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
