{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b",
   "metadata": {},
   "source": [
    "# Agent Training with GRPO using TRL\n",
    "\n",
    "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
    "\n",
    "\n",
    "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a language model to act as an **agent**. One that learns to reason, interact with external tools, and improve through reinforcement.\n",
    "\n",
    "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project!  \n",
    "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview)  \n",
    "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
    "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n",
    "\n",
    "\n",
    "TRL supports training agents that can use external tools as part of their decision process.  \n",
    "In this notebook, the agent has access to the **BioGRID database**, which it can query using **read-only SQL commands** to retrieve biological interaction data. The model learns when and how to use tools based on rewards.\n",
    "\n",
    "We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. The agent will:\n",
    "\n",
    "1. Generate tool call to query the database if needed.\n",
    "2. Receive the tool response and add it it to the context.\n",
    "3. Learn to improve its tool usage and general capabilities over time through reward signals.\n",
    "\n",
    "## Install dependencies\n",
    "\n",
    "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**.  \n",
    "We'll also install **trackio** (for logging and monitoring training runs), **vLLM** (for efficient generation), and **jmespath** (needed for the tools capabilities)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4812fbf-3f61-481e-9a64-95277eada9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -Uq \"trl[vllm]\" git+https://github.com/huggingface/transformers.git trackio jmespath "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148",
   "metadata": {},
   "source": [
    "### Log in to Hugging Face\n",
    "\n",
    "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21756ac0-78b2-495d-8137-28dfa9faae6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "KVGklspLYlmz",
   "metadata": {},
   "source": [
    "## Create the database for the tool\n",
    "\n",
    "For this example, we will use the [BioGRID database](https://thebiogrid.org/), a curated resource containing **protein, genetic, and chemical interaction data**.  We've already compiled and uploaded it to the Hub at [qgallouedec/biogrid](https://huggingface.co/datasets/qgallouedec/biogrid). The dataset is loaded and converted into an sqlite database.\n",
    "\n",
    "> 💡 We remove spaces in the column names to easen the model work. In real-world deployments, you may keep your original column names and rely on the agent to reason about them. Here, we simplify the schema to make training smoother."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rRzPMhfXBLkF",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load dataset\n",
    "biogrid_dataset = load_dataset(\"qgallouedec/biogrid\", split=\"train\")\n",
    "df = biogrid_dataset.to_pandas()\n",
    "\n",
    "# Normalize column names: remove spaces, replace with underscores\n",
    "df.columns = [c.replace(\" \", \"_\") for c in df.columns]\n",
    "\n",
    "# Save to SQLite\n",
    "conn = sqlite3.connect(\"biogrid.db\")\n",
    "try:\n",
    "    df.to_sql(\"interactions\", conn, if_exists=\"replace\", index=False)\n",
    "    print(f\"biogrid.db created. Rows stored: {len(df)}\")\n",
    "finally:\n",
    "    conn.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pSSGvLbmZyC2",
   "metadata": {},
   "source": [
    "## Load the QA dataset\n",
    "\n",
    "The training objective is to fine-tune a model to answer gene-related questions. The model should learn to use the database query tool to retrieve factual information when needed.\n",
    "\n",
    "We'll define a formatting function for each sample, adding instructions about the database and how to call it. The model must answer with **yes** or **no**. Let's implement the `format_example` function.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "asrv7LbaD71C",
   "metadata": {},
   "outputs": [],
   "source": [
    "import textwrap\n",
    "\n",
    "def format_example(example):\n",
    "    question = example[\"question\"]\n",
    "    preamble = textwrap.dedent(\"\"\"\\\n",
    "    You have access to the BioGRID SQLite database.\n",
    "    Use SQL queries to retrieve only the information needed to answer the question.\n",
    "\n",
    "    Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,\n",
    "    and each entry can contain multiple gene names or synonyms separated by '|', for example:\n",
    "    'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'\n",
    "    So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.\n",
    "\n",
    "    If the database schema is unclear or you are unsure about column names:\n",
    "    - First inspect the schema with `PRAGMA table_info(interactions);`\n",
    "    - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`\n",
    "\n",
    "    Otherwise, directly query the required data.\n",
    "\n",
    "    Final answer must be enclosed in stars, e.g. *Yes* or *No*.\n",
    "    Facts:\n",
    "    - The NCBI Taxonomy identifier for humans is taxid:9606.\n",
    "    \"\"\")\n",
    "    content = f\"{preamble}\\nQuestion: {question}\"\n",
    "    prompt = [{\"role\": \"user\", \"content\": content}]\n",
    "    return {\"prompt\": prompt}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UMnHXYZla_EO",
   "metadata": {},
   "source": [
    "Now, let's load the database and call the previous function.  \n",
    "For simplicity, we will only use questions that start with **“Does the gene…”**.  \n",
    "In a real use case, the full dataset can be used.\n",
    "\n",
    "The QA dataset is available on the [Hub](https://huggingface.co/datasets/qgallouedec/biogrid_qa)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jEs12KqwDnVl",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"qgallouedec/biogrid_qa\", split=\"train\")\n",
    "dataset = dataset.filter(\n",
    "    lambda example: example[\"question\"].startswith(\"Does the gene \")\n",
    ")  # keep only simple questions for example\n",
    "dataset = dataset.map(format_example, remove_columns=[\"question\"])\n",
    "\n",
    "train_dataset = dataset\n",
    "eval_dataset = None  # No eval by default, can be added if needed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m4GRjbHycM5L",
   "metadata": {},
   "source": [
    "## Create tool for the agent\n",
    "\n",
    "The `query_biogrid` function is the tool the model will use to query the database and retrieve factual information.  \n",
    "Each tool must be a standard Python function with **type-hinted arguments and return types**, and a **Google-style docstring** describing its purpose, parameters, and return value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "nLMH7hahGTyO",
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import contextmanager\n",
    "import signal\n",
    "\n",
    "@contextmanager\n",
    "def timeout(seconds):\n",
    "    \"\"\"Context manager that raises TimeoutError if execution exceeds time limit.\"\"\"\n",
    "\n",
    "    def timeout_handler(signum, frame):\n",
    "        raise TimeoutError(f\"Operation timed out after {seconds} seconds\")\n",
    "\n",
    "    signal.signal(signal.SIGALRM, timeout_handler)\n",
    "    signal.alarm(seconds)\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        signal.alarm(0)\n",
    "\n",
    "def query_biogrid(sql_command: str) -> list[tuple]:\n",
    "    \"\"\"\n",
    "    Execute a read-only SQL command on the BioGRID database.\n",
    "\n",
    "    BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.\n",
    "\n",
    "    Args:\n",
    "        sql_command: The SQL command to execute.\n",
    "\n",
    "    Returns:\n",
    "        A list of tuples containing the query results.\n",
    "    \"\"\"\n",
    "    with timeout(5):\n",
    "        conn = sqlite3.connect(\"file:biogrid.db?mode=ro\", uri=True)\n",
    "        cursor = conn.cursor()\n",
    "        try:\n",
    "            cursor.execute(sql_command)\n",
    "            results = cursor.fetchall()\n",
    "        finally:\n",
    "            conn.close()\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "GiHtooTwci3B",
   "metadata": {},
   "source": [
    "## Define reward functions\n",
    "\n",
    "To guide the agent during training, we define a few simple reward functions:\n",
    "\n",
    "- **`query_reward`**: evaluates the model’s query strategy — penalizes more than two queries, penalizes generic database scans, and rewards use of `WHERE` and evidence supporting the final answer.\n",
    "- **`correctness_reward`**: rewards Yes/No predictions that match the expected answer.\n",
    "- **`structure_reward`**: rewards a proper assistant structure (tool call → response → optional explanation).\n",
    "\n",
    "Each function returns a list of floats used by the **GRPOTrainer** during optimization.  \n",
    "Combined, they encourage effective tool use and factual answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sXyqC6cJGe3L",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def query_reward(completions, answer, **kwargs):\n",
    "    \"\"\"\n",
    "    Reward query strategy:\n",
    "    - Penalize more than 2 queries\n",
    "    - Penalize generic queries (LIMIT 1 / PRAGMA)\n",
    "    - Reward usage of WHERE\n",
    "    - Reward evidence supporting the final answer\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "\n",
    "    for completion, ans in zip(completions, answer, strict=False):\n",
    "        reward = 0.0\n",
    "        sql_queries = []\n",
    "        tool_results = []\n",
    "\n",
    "        # collect all SQL queries and tool results\n",
    "        for turn in completion:\n",
    "            if turn.get(\"tool_calls\"):\n",
    "                for call in turn[\"tool_calls\"]:\n",
    "                    sql = call[\"function\"][\"arguments\"].get(\"sql_command\", \"\").lower()\n",
    "                    sql_queries.append(sql)\n",
    "            if turn.get(\"role\") == \"tool\" and turn.get(\"content\"):\n",
    "                tool_results.append(turn[\"content\"])\n",
    "\n",
    "        # --- penalize too many queries ---\n",
    "        if len(sql_queries) > 3:\n",
    "            reward -= 1.5\n",
    "\n",
    "        # --- check query quality ---\n",
    "        where_count = 0\n",
    "        for q in sql_queries:\n",
    "            if \"limit 1\" in q:\n",
    "                reward -= 1.0\n",
    "            if \" where \" not in q:\n",
    "                reward -= 0.5\n",
    "            else:\n",
    "                where_count += 1\n",
    "        reward += min(where_count, 3) * 0.4  # small bonus for WHERE usage\n",
    "\n",
    "        # --- evidence check: do queries support the answer? ---\n",
    "        combined_results = []\n",
    "        error_detected = False\n",
    "\n",
    "        for res in tool_results:\n",
    "            if isinstance(res, dict) and \"error\" in res:\n",
    "                error_detected = True\n",
    "            elif isinstance(res, list):\n",
    "                combined_results.extend(res)\n",
    "\n",
    "        # if error detected, penalize heavily\n",
    "        if error_detected:\n",
    "            reward -= 2.0\n",
    "        elif len(sql_queries) == 0:\n",
    "            reward -= 1.5\n",
    "        else:\n",
    "            has_hits = len(combined_results) > 0\n",
    "            correct_answer = ans.lower()\n",
    "            if (has_hits and correct_answer == \"yes\") or (not has_hits and correct_answer == \"no\"):\n",
    "                reward += 2.0\n",
    "            else:\n",
    "                reward -= 1.5\n",
    "\n",
    "        rewards.append(reward)\n",
    "\n",
    "    return rewards\n",
    "\n",
    "\n",
    "def correctness_reward(completions, answer, **kwargs):\n",
    "    \"\"\"\n",
    "    Reward Yes/No correctness.\n",
    "    Model must provide final answer enclosed in stars — *yes* or *no*.\n",
    "    Does not reward informal yes/no buried in text.\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "    for completion, ans in zip(completions, answer, strict=False):\n",
    "        raw = completion[-1][\"content\"].lower()\n",
    "\n",
    "        # detect form *yes* or *no*\n",
    "        match = re.search(r\"\\*(yes|no)\\*\", raw)\n",
    "        guess = match.group(1) if match else None\n",
    "\n",
    "        reward = 0.0\n",
    "\n",
    "        if guess is None:\n",
    "            reward -= 0.5  # invalid format\n",
    "        elif guess == ans.lower():\n",
    "            reward += 0.6  # correct under required format\n",
    "        else:\n",
    "            reward -= 1.0  # wrong answer\n",
    "\n",
    "        rewards.append(reward)\n",
    "\n",
    "    return rewards\n",
    "\n",
    "\n",
    "def structure_reward(completions, **kwargs):\n",
    "    \"\"\"\n",
    "    Reward proper assistant structure.\n",
    "    Encourages a logical sequence: tool call + response + optional extra content.\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "\n",
    "    for completion in completions:\n",
    "        has_call = False\n",
    "        has_response = False\n",
    "        has_other = False\n",
    "\n",
    "        for turn in completion:\n",
    "            role = turn.get(\"role\")\n",
    "            if role == \"assistant\" and turn.get(\"tool_calls\"):\n",
    "                has_call = True\n",
    "            elif role == \"tool\":\n",
    "                has_response = True\n",
    "            else:\n",
    "                content = turn.get(\"content\")\n",
    "                if content and content.strip() not in [\"\", \"<think>\"]:\n",
    "                    has_other = True\n",
    "\n",
    "        # Reward sequences\n",
    "        if has_call and has_response:\n",
    "            if has_other:\n",
    "                reward = 0.1\n",
    "            else:\n",
    "                reward = 0.05  # still positive even without extra text\n",
    "        elif has_call and not has_response:\n",
    "            reward = -0.15\n",
    "        else:\n",
    "            reward = 0.0  # neutral if no call\n",
    "\n",
    "        rewards.append(reward)\n",
    "\n",
    "    return rewards\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "zcgkrKtTb4T9",
   "metadata": {},
   "source": [
    "## Set GRPO Config\n",
    "\n",
    "Next, we define the **GRPOConfig**, which controls the main training parameters.  \n",
    "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "t4ifJsNLElIN",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import GRPOConfig\n",
    "\n",
    "output_dir = \"grpo_biogrid_qwen_3g-1.7b\"\n",
    "\n",
    "grpo_config = GRPOConfig(\n",
    "    # Training schedule / optimization\n",
    "    max_steps=400,                                              # Max number of training steps\n",
    "    chat_template_kwargs = {\"enable_thinking\": False},          # Disable thinking to reduce token generation\n",
    "\n",
    "    # GRPO configuration\n",
    "    max_completion_length = 1024,                               # Maximum tokens generated per model response\n",
    "\n",
    "    # vLLM configuration\n",
    "    use_vllm = True,                                            # Enable vLLM for faster inference during rollouts\n",
    "    vllm_mode = \"colocate\",                                     # Run vLLM in colocate mode (same process as training)\n",
    "    vllm_enable_sleep_mode=False,\n",
    "\n",
    "    # Logging / reporting\n",
    "    output_dir = output_dir,                                    # Directory for checkpoints and logs\n",
    "    report_to=\"trackio\",                                        # Experiment tracking tool (integrates with HF Spaces)\n",
    "    trackio_space_id = output_dir,                              # HF Space where experiment tracking will be saved\n",
    "    save_steps = 10,                                            # Interval for saving checkpoints\n",
    "    log_completions = True,\n",
    "\n",
    "    # Memory optimization\n",
    "    gradient_checkpointing = True,                              # Enable activation recomputation to save memory\n",
    "    gradient_checkpointing_kwargs = {\"use_reentrant\": False},   # Use non-reentrant checkpointing\n",
    "\n",
    "    # Hub integration\n",
    "    push_to_hub = True,                                         # Set True to automatically push model to Hugging Face Hub\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34I-Q2MJuf42",
   "metadata": {},
   "source": [
    "## Create `GRPOTrainer` and Start Training\n",
    "\n",
    "Next, we initialize the **`GRPOTrainer`**, which handles the full reinforcement learning loop.\n",
    "\n",
    "It receives the model name, reward functions, tool(s), and dataset defined earlier.  \n",
    "\n",
    "Finally, we call `trainer.train()` to begin fine-tuning, allowing the model to learn how to query the database effectively through iterative feedback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "IysntAUOFvRn",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import GRPOTrainer\n",
    "\n",
    "model_name=\"Qwen/Qwen3-1.7B\"\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model_name,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    tools=[query_biogrid],\n",
    "    reward_funcs=[correctness_reward, structure_reward, query_reward],\n",
    "    args=grpo_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "r_qJ5UwLuzCG",
   "metadata": {},
   "source": [
    "Show memory stats before training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "DusT8JUaGmA6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "gpu_stats = torch.cuda.get_device_properties(0)\n",
    "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
    "\n",
    "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
    "print(f\"{start_gpu_memory} GB of memory reserved.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "OTPkiz3fu0lp",
   "metadata": {},
   "source": [
    "And train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "NwI3buPOFMFk",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_stats = trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ITnLBLcTu2-p",
   "metadata": {},
   "source": [
    "Show memory stats after training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ftek6m4-GncK",
   "metadata": {},
   "outputs": [],
   "source": [
    "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
    "used_percentage = round(used_memory / max_memory * 100, 3)\n",
    "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
    "\n",
    "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
    "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
    "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
    "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
    "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
    "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "O6LAwznKu7mc",
   "metadata": {},
   "source": [
    "Let's save the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "idVgnNS1MWPr",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(output_dir)\n",
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707318cb",
   "metadata": {},
   "source": [
    "## Load the fine-tuned model and run inference using `smolagents`\n",
    "\n",
    "After fine-tuning the model with **GRPO (TRL)** for tool calling, we can test it at inference time using **`smolagents`**, a lightweight library for running multi-step agents.\n",
    "\n",
    "`smolagents` handles the agent loop for us:\n",
    "- Detecting tool calls generated by the model\n",
    "- Executing the corresponding tools (e.g. database queries)\n",
    "- Feeding the results back to the model until a final answer is produced\n",
    "\n",
    "> **Note**  \n",
    "> Using an agent framework is optional. The fine-tuned model can also be used directly with `transformers` by manually controlling the inference loop and executing the tools outside the model.\n",
    "> Agent frameworks are especially useful when the number of steps or tool calls is not fixed.\n",
    "\n",
    "We start by installing the required package:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aab7fd5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/huggingface/smolagents.git"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24453572",
   "metadata": {},
   "source": [
    "We will use the `CodeAgent` class from `smolagents` to instantiate our agent.  \n",
    "First, we need to define the tool the agent can use. This is done using the `@tool` decorator.\n",
    "\n",
    "As shown below, the tool definition is **exactly the same** as the one used during GRPO training with TRL. This consistency is important: the model was trained to emit calls following this schema, and at inference time the agent simply executes the corresponding Python function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adcbbafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from smolagents import tool\n",
    "\n",
    "@tool\n",
    "def query_biogrid(sql_command: str) -> list[tuple]:\n",
    "    \"\"\"\n",
    "    Execute a read-only SQL query on the BioGRID database.\n",
    "\n",
    "    BioGRID is a curated biological database that compiles protein, genetic,\n",
    "    and chemical interactions from multiple organisms.\n",
    "\n",
    "    Args:\n",
    "        sql_command: A read-only SQL query to execute.\n",
    "\n",
    "    Returns:\n",
    "        A list of tuples containing the query results.\n",
    "    \"\"\"\n",
    "    with timeout(5):\n",
    "        conn = sqlite3.connect(\n",
    "            \"file:biogrid.db?mode=ro\",\n",
    "            uri=True,\n",
    "        )\n",
    "        cursor = conn.cursor()\n",
    "        try:\n",
    "            cursor.execute(sql_command)\n",
    "            results = cursor.fetchall()\n",
    "        finally:\n",
    "            conn.close()\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59721ad2",
   "metadata": {},
   "source": [
    "Now we can instantiate the agent using our fine-tuned model and the database tool defined above.\n",
    "We wrap the model with `TransformersModel` and pass both the model and the tool when creating the `CodeAgent`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ed8d00",
   "metadata": {},
   "outputs": [],
   "source": [
    "from smolagents import TransformersModel, CodeAgent\n",
    "\n",
    "model = TransformersModel(model_id=\"sergiopaniego/grpo_biogrid_qwen_3g-1.7b\", apply_chat_template_kwargs={\"enable_thinking\": False})\n",
    "\n",
    "# Create an agent with query_biogrid as tool\n",
    "agent = CodeAgent(tools=[query_biogrid], model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57ba9462",
   "metadata": {},
   "source": [
    "Finally, we run the agent by passing the full prompt (including the instruction preamble and the question), exactly as it was used during training. This ensures the agent operates under the same context and assumptions learned with GRPO, allowing it to correctly decide when to query the database and how to format the final answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a3cdf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = agent.run(train_dataset[0]['prompt'][0]['content'])\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
