{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lSR2nwdJg962"
      },
      "source": [
        "# Fine-Tune FunctionGemma using Hugging Face TRL and OpenEnv\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb)\n",
        "\n",
        "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
        "\n",
        "This guide describes the process of fine-tuning [FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) by Google DeepMind in the [BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) environment provided by OpenEnv, using Hugging Face TRL. The steps covered include:\n",
        "\n",
        "* What is GRPO and OpenEnv\n",
        "* Set up dependencies for training\n",
        "* Initialize OpenEnv's BrowserGym environment\n",
        "* Create the rollout function with helpers\n",
        "* Define the reward functions\n",
        "* Load the custom dataset\n",
        "* Fine-tune using TRL and the GRPOTrainer\n",
        "* Load the fine-tuned model and run inference\n",
        "\n",
        "> Note: The guide is designed to run on Google Colaboratory with access to an NVIDIA A100 GPU (40GB) using FunctionGemma. The workflow can be adapted to other GPU configurations, models, or environments."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "duXYuR6Cu_na"
      },
      "source": [
        "## What is GRPO and OpenEnv\n",
        "\n",
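        "Group Relative Policy Optimization ([GRPO](https://huggingface.co/papers/2402.03300)) is a reinforcement learning post-training method widely used to fine-tune large language models. For each prompt, GRPO samples a group of completions, scores them with reward functions, and updates the policy according to how each completion's reward compares with the rest of its group. Because the group serves as the baseline, GRPO does not need the separate value (critic) model that methods such as PPO require, which reduces memory and compute.\n",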
        "\n",
        "[OpenEnv](https://meta-pytorch.org/OpenEnv) provides a standard interface for interacting with agentic execution environments using simple Gymnasium-style APIs, such as `step()`, `reset()`, and `state()`. These APIs facilitate reinforcement learning training loops by allowing models to interact with environments in a structured manner. OpenEnv also offers tools for environment creators to build isolated, secure, and deployable environments that can be shared via common protocols like HTTP or packaged in Docker.\n",
        "\n",
        "Together, OpenEnv supplies interactive tasks and their rewards, and GRPO turns those rewards into policy updates without a separate critic model. This makes it practical to fine-tune models on controlled, interactive tasks with modest resources."
      ]
    },
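    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The group-relative idea behind GRPO can be illustrated in a few lines. The cell below is a minimal, simplified sketch of the advantage computation described in the GRPO paper: each completion's reward is compared with the mean reward of its group and divided by the group's standard deviation. The function name `group_relative_advantages` and the `eps` value are assumptions for this example. TRL's `GRPOTrainer` uses its own implementation with additional options. If every rollout in a group receives the same reward (for example, all four episodes fail), all advantages are zero and that group provides no reward-driven learning signal."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import statistics\n",
        "\n",
        "def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:\n",
        "    \"\"\"Simplified GRPO-style advantages for one group of completions sharing a prompt.\"\"\"\n",
        "    mean = statistics.fmean(rewards)\n",
        "    std = statistics.pstdev(rewards)  # population std; implementations differ on this detail\n",
        "    # Completions that beat the group average get positive advantages, worse ones negative\n",
        "    return [(r - mean) / (std + eps) for r in rewards]\n",
        "\n",
        "# Example: 4 rollouts of the same prompt, only one of which completed the task\n",
        "advantages = group_relative_advantages([1.0, 0.0, 0.0, 0.0])\n",
        "print([round(a, 3) for a in advantages])"
      ]
    },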
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cpSAQkzKmv50"
      },
      "source": [
        "## Set up dependencies for training\n",
        "\n",
        "Install the required libraries, including Hugging Face TRL for fine-tuning and OpenEnv for reinforcement learning environments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c-2drnj5BP56"
      },
      "outputs": [],
      "source": [
        "!pip install -Uq \"trl[vllm]\" git+https://github.com/meta-pytorch/OpenEnv.git openenv_core liger-kernel trackio"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Inxeq6ZGpRno"
      },
      "source": [
        "A valid Hugging Face token is required to save the fine-tuned model. In Google Colab, the token can be securely accessed through Colab secrets. Otherwise, it can be provided directly in the login method. Ensure the token has write permissions to allow uploading the model to the Hugging Face Hub during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C4q5UVu3BP57"
      },
      "outputs": [],
      "source": [
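        "from huggingface_hub import login\n",
        "\n",
        "# Log in to the Hugging Face Hub\n",
        "try:\n",
        "    from google.colab import userdata  # Only available inside Google Colab\n",
        "    hf_token = userdata.get('HF_TOKEN')  # Read the token from Colab secrets\n",
        "except ImportError:\n",
        "    hf_token = None  # Outside Colab: set the token here, or leave None to be prompted\n",
        "\n",
        "login(hf_token)"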
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O3kr38TGm_hb"
      },
      "source": [
        "## Initialize OpenEnv's BrowserGym environment\n",
        "\n",
        "External environments can guide the fine-tuning of LLMs for function calling by providing interactive feedback that enhances performance on task-specific behaviors.\n",
        "\n",
        "[BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with [MiniWoB++](https://github.com/Farama-Foundation/miniwob-plusplus) and evaluation on more complex, realistic tasks with [WebArena](https://github.com/web-arena-x/webarena), [VisualWebArena](https://github.com/web-arena-x/visualwebarena), or [WorkArena](https://github.com/ServiceNow/WorkArena). This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.\n",
        "\n",
        "BrowserGym supports both LLM and VLM training. It exposes textual observations (such as the DOM and the accessibility tree) as well as visual observations (screenshots), so the observation type can be chosen to match the model. This guide uses only the text-based accessibility tree and focuses on a simple web task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. This example uses the remote environment [burtenshaw/browsergym-v2](https://huggingface.co/spaces/burtenshaw/browsergym-v2).\n",
        "\n",
        "> Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "clDs-WQlBP57"
      },
      "outputs": [],
      "source": [
        "from envs.browsergym_env import BrowserGymEnv\n",
        "space_url = \"https://burtenshaw-browsergym-v2.hf.space\"\n",
        "\n",
        "client = BrowserGymEnv(base_url=space_url)"
      ]
    },
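    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Every interaction with `client` follows the Gymnasium-style pattern that `rollout_once` relies on later. `reset()` returns a result whose `observation` holds the task state, and each `step()` returns a new result with `observation`, `reward`, and `done`. To illustrate that loop without calling the remote Space, the cell below uses a hypothetical offline stand-in. `ToyClickTestEnv` and its result classes are invented for this example and are not part of OpenEnv. They mimic only the fields this notebook reads. The real client's `step()` takes a `BrowserGymAction`, whereas the toy `step()` takes the action string directly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "\n",
        "# Hypothetical offline stand-ins for illustration only; they are NOT OpenEnv classes.\n",
        "@dataclass\n",
        "class ToyObservation:\n",
        "    goal: str\n",
        "    axtree_txt: str\n",
        "\n",
        "@dataclass\n",
        "class ToyStepResult:\n",
        "    observation: ToyObservation\n",
        "    reward: float = 0.0\n",
        "    done: bool = False\n",
        "\n",
        "class ToyClickTestEnv:\n",
        "    \"\"\"Offline toy version of a click-test task: clicking bid '13' solves it.\"\"\"\n",
        "\n",
        "    def _observe(self) -> ToyObservation:\n",
        "        return ToyObservation(goal=\"Click the button.\", axtree_txt=\"[13] button 'Click Me!'\")\n",
        "\n",
        "    def reset(self) -> ToyStepResult:\n",
        "        return ToyStepResult(observation=self._observe())\n",
        "\n",
        "    def step(self, action_str: str) -> ToyStepResult:\n",
        "        solved = action_str == \"click('13')\"\n",
        "        return ToyStepResult(observation=self._observe(), reward=1.0 if solved else 0.0, done=solved)\n",
        "\n",
        "# Same reset/step loop shape as rollout_once, with scripted actions instead of a model\n",
        "toy_env = ToyClickTestEnv()\n",
        "toy_result = toy_env.reset()\n",
        "for action in [\"noop()\", \"click('13')\"]:\n",
        "    if toy_result.done:\n",
        "        break\n",
        "    toy_result = toy_env.step(action)\n",
        "    print(action, toy_result.reward, toy_result.done)"
      ]
    },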
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EqfDavDQnD_5"
      },
      "source": [
        "## Create the rollout function with helpers\n",
        "\n",
        "The rollout function defines how the agent interacts with the environment during GRPO training. It generates model outputs, collects feedback in the form of rewards, and returns the information required for optimization.\n",
        "\n",
        "In this setup:\n",
        "- The function is invoked automatically by the GRPOTrainer (introduced later), which orchestrates the training loop and handles policy updates.\n",
        "- It uses `generate_rollout_completions()` from `trl.experimental.openenv`, which generates outputs through the trainer's vLLM engine. vLLM is a high-performance inference engine for large language models, and its integration in TRL keeps rollout generation and reward collection fast during fine-tuning.\n",
        "- Each rollout is a complete interaction loop in which the model acts and receives feedback from the environment. The trainer then uses the collected rewards to update the policy.\n",
        "\n",
        "Rewards capture various aspects of the agent's performance. Helper functions, such as `rollout_once`, manage individual episodes, keeping the main `rollout_func` clean, modular, and reusable.\n",
        "\n",
        "This modular structure allows GRPO to efficiently sample, evaluate, and refine the model's behavior through reinforcement learning.\n",
        "\n",
        "Before executing rollouts, a system prompt (`SYSTEM_PROMPT`) is defined to instruct the model on how to interact with the environment. It lists the available BrowserGym actions (such as `click`, `fill`, `send_keys`, and `scroll`) and describes how page elements are presented. It also requires the model to reply with exactly one action per step and no explanations, which keeps interactions consistent and easy to parse."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ItCXS6H0BP58"
      },
      "outputs": [],
      "source": [
        "# @title System prompt (click to expand)\n",
        "SYSTEM_PROMPT = \"\"\"You control a web browser through BrowserGym actions.\n",
        "You must complete the given web task by interacting with the page.\n",
        "\n",
        "Available actions:\n",
        "- noop() - Do nothing\n",
        "- click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
        "- fill(bid, text) - Fill input field with text\n",
        "- send_keys(text) - Send keyboard input\n",
        "- scroll(direction) - Scroll up/down\n",
        "\n",
        "The page structure shows elements as: [bid] element_type 'element_text'\n",
        "For example: [13] button 'Click Me!' means bid='13'\n",
        "\n",
        "Reply with exactly ONE action on a single line, e.g.:\n",
        "click('13')\n",
        "fill('42', 'hello world')\n",
        "noop()\n",
        "\n",
        "Do not include explanations or multiple actions.\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vi1rFey39GUl"
      },
      "source": [
        "The `rollout_func` orchestrates the interaction between the model and the remote BrowserGym environment. For each prompt in the batch, it executes a complete episode using the `rollout_once` function, collecting model outputs and rewards for GRPO optimization.\n",
        "\n",
        "The parameter `max_steps` defines the maximum number of environment steps the model can take within a single episode. It ensures that episodes terminate even if the task is not completed, which keeps training efficient. This episode limit is unrelated to the `max_steps` argument of `GRPOConfig` used later, which sets the number of training steps.\n",
        "\n",
        "For each episode, `rollout_once` tracks prompt and completion IDs, log probabilities, and both step-wise and final rewards. `rollout_func` then gathers the token IDs, log probabilities, and final reward of every episode. It returns them in a structured format that the trainer uses for policy updates."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CgHd5CFBBP58"
      },
      "outputs": [],
      "source": [
        "from trl import GRPOTrainer\n",
        "\n",
        "max_steps = 10  # Maximum number of environment steps per episode (distinct from GRPOConfig's max_steps)\n",
        "\n",
        "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n",
        "    episode_prompt_ids: list[list[int]] = []\n",
        "    episode_completion_ids: list[list[int]] = []\n",
        "    episode_logprobs: list[list[float]] = []\n",
        "    completion_rewards: list[float] = []\n",
        "\n",
        "    print(f\"\\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)\")\n",
        "\n",
        "    for i, prompt_text in enumerate(prompts):\n",
        "        print(f\"[DEBUG] Processing prompt {i + 1}/{len(prompts)}\")\n",
        "        episode = rollout_once(\n",
        "            trainer=trainer,\n",
        "            env=client,\n",
        "            tokenizer=trainer.processing_class,\n",
        "            dataset_prompt=prompt_text,\n",
        "            max_steps=max_steps,\n",
        "        )\n",
        "        episode_prompt_ids.append(episode[\"prompt_ids\"])\n",
        "        episode_completion_ids.append(episode[\"completion_ids\"])\n",
        "        episode_logprobs.append(episode[\"logprobs\"])\n",
        "        completion_rewards.append(episode[\"completion_reward\"])\n",
        "\n",
        "    return {\n",
        "        \"prompt_ids\": episode_prompt_ids,\n",
        "        \"completion_ids\": episode_completion_ids,\n",
        "        \"logprobs\": episode_logprobs,\n",
        "        \"completion_reward\": completion_rewards,\n",
        "    }"
      ]
    },
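    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Every list returned by `rollout_func` must hold one entry per input prompt, because the trainer pairs each prompt with its token IDs, log probabilities, and reward. When modifying the rollout logic, a quick check can catch misaligned outputs early. The `check_rollout_output` helper below is hypothetical and not part of TRL. It covers only the keys this notebook returns and assumes one log probability per completion token, which is how `rollout_once` accumulates them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def check_rollout_output(output: dict[str, list], num_prompts: int) -> None:\n",
        "    \"\"\"Hypothetical sanity check: every returned list must hold one entry per prompt.\"\"\"\n",
        "    for key in (\"prompt_ids\", \"completion_ids\", \"logprobs\", \"completion_reward\"):\n",
        "        if key not in output:\n",
        "            raise KeyError(f\"missing key: {key}\")\n",
        "        if len(output[key]) != num_prompts:\n",
        "            raise ValueError(f\"{key} has {len(output[key])} entries, expected {num_prompts}\")\n",
        "    # Within an episode, each completion token should have exactly one log probability\n",
        "    for ids, lps in zip(output[\"completion_ids\"], output[\"logprobs\"]):\n",
        "        if len(ids) != len(lps):\n",
        "            raise ValueError(\"completion_ids and logprobs differ in length\")\n",
        "\n",
        "# Example with a dummy batch of two episodes\n",
        "dummy = {\n",
        "    \"prompt_ids\": [[1, 2], [1, 2]],\n",
        "    \"completion_ids\": [[5, 6, 7], [8]],\n",
        "    \"logprobs\": [[-0.1, -0.2, -0.3], [-0.5]],\n",
        "    \"completion_reward\": [1.0, 0.0],\n",
        "}\n",
        "check_rollout_output(dummy, num_prompts=2)\n",
        "print(\"rollout output looks consistent\")"
      ]
    },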
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ioUHdIxr9ZQO"
      },
      "source": [
        "### Define `rollout_once`\n",
        "\n",
        "The `rollout_once` function runs one complete episode between the model and the BrowserGym environment, from generating each action to receiving feedback and computing rewards.\n",
        "\n",
        "Here's the step-by-step breakdown:\n",
        "\n",
        "1. Environment reset: Start a new BrowserGym session and initialize the observation.\n",
        "2. Prompt construction: Combine the system prompt, environment observation (text-only via the accessibility tree), and any relevant errors or state information to form the model input.\n",
        "3. Generation: Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's action efficiently with vLLM.\n",
        "4. Action parsing and execution: Interpret the model's output and execute the corresponding BrowserGym action (e.g., `click`, `fill`, `scroll`).\n",
        "5. Reward calculation: Track step-wise rewards provided by the environment and compute completion rewards based on task success or failure.\n",
        "6. Return structured rollout data: the prompt/completion IDs, log probabilities, step rewards, and the final reward for the episode.\n",
        "\n",
        "This modular design allows each episode to be processed independently while providing rich feedback for the GRPO training loop, supporting both task completion and intermediate reward shaping."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y8Ml47SYBP58"
      },
      "outputs": [],
      "source": [
        "from trl.experimental.openenv import generate_rollout_completions\n",
        "from envs.browsergym_env import BrowserGymAction\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "def rollout_once(\n",
        "    trainer: GRPOTrainer,\n",
        "    env: BrowserGymEnv,\n",
        "    tokenizer: AutoTokenizer,\n",
        "    dataset_prompt: str,\n",
        "    max_steps: int,\n",
        ") -> dict[str, list]:\n",
        "    \"\"\"Run one episode and collect training data (text-only, no screenshots).\"\"\"\n",
        "    result = env.reset()\n",
        "    observation = result.observation\n",
        "\n",
        "    prompt_ids: list[int] = []\n",
        "    completion_ids: list[int] = []\n",
        "    logprobs: list[float] = []\n",
        "    step_rewards: list[float] = []\n",
        "    completion_rewards: list[float] = []\n",
        "\n",
        "    for step_num in range(max_steps):\n",
        "        if result.done:\n",
        "            break\n",
        "\n",
        "        # Create prompt from observation (text-only using accessibility tree)\n",
        "        goal = observation.goal or dataset_prompt\n",
        "        axtree = observation.axtree_txt or \"\"\n",
        "        error = observation.error if observation.last_action_error else \"\"\n",
        "\n",
        "        user_prompt = make_user_prompt(goal, step_num, axtree, error)\n",
        "        messages = [\n",
        "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
        "            {\"role\": \"user\", \"content\": user_prompt},\n",
        "        ]\n",
        "        prompt_text = tokenizer.apply_chat_template(\n",
        "            messages,\n",
        "            add_generation_prompt=True,\n",
        "            tokenize=False,\n",
        "        )\n",
        "\n",
        "        # Generate action with vLLM\n",
        "        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n",
        "        prompt_ids.extend(rollout_outputs[\"prompt_ids\"])\n",
        "        completion_ids.extend(rollout_outputs[\"completion_ids\"])\n",
        "        logprobs.extend(rollout_outputs[\"logprobs\"])\n",
        "\n",
        "        completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n",
        "            rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n",
        "        )\n",
        "\n",
        "        # Parse and execute action\n",
        "        action_str = parse_action(completion_text)\n",
        "\n",
        "        print(f\"Step {step_num + 1}: {action_str}\")\n",
        "\n",
        "        # Take action in environment\n",
        "        result = env.step(BrowserGymAction(action_str=action_str))\n",
        "        observation = result.observation\n",
        "\n",
        "        # Track rewards\n",
        "        step_reward = float(result.reward or 0.0)\n",
        "        step_rewards.append(step_reward)\n",
        "\n",
        "        # Reward shaping: success is most important\n",
        "        if result.done and step_reward > 0:\n",
        "            completion_rewards.append(1.0)  # Task completed successfully\n",
        "        elif result.done and step_reward == 0:\n",
        "            completion_rewards.append(0.0)  # Task failed\n",
        "        else:\n",
        "            completion_rewards.append(step_reward)  # Intermediate reward\n",
        "\n",
        "    # Final reward is based on task completion\n",
        "    final_reward = completion_rewards[-1] if completion_rewards else 0.0\n",
        "\n",
        "    return {\n",
        "        \"prompt_ids\": prompt_ids,\n",
        "        \"completion_ids\": completion_ids,\n",
        "        \"logprobs\": logprobs,\n",
        "        \"step_rewards\": step_rewards,\n",
        "        \"completion_reward\": final_reward,\n",
        "    }"
      ]
    },
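    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reward-shaping branch inside `rollout_once` can be read as a small pure function. A finished episode with a positive environment reward counts as success (1.0), and a finished episode with zero reward counts as failure (0.0). Otherwise the raw step reward is kept, and the episode's final reward is whichever value was recorded last. The hypothetical `shape_reward` and `final_episode_reward` helpers below restate that logic outside the environment loop so it can be checked in isolation. The training code does not call them. For episodes truncated by `max_steps`, the final reward is simply the last intermediate step reward."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def shape_reward(step_reward: float, done: bool) -> float:\n",
        "    \"\"\"Restates the per-step shaping rule used in rollout_once.\"\"\"\n",
        "    if done and step_reward > 0:\n",
        "        return 1.0  # Task completed successfully\n",
        "    if done and step_reward == 0:\n",
        "        return 0.0  # Episode ended without success\n",
        "    return step_reward  # Intermediate (or negative terminal) reward passes through\n",
        "\n",
        "def final_episode_reward(steps: list[tuple[float, bool]]) -> float:\n",
        "    \"\"\"Final reward = shaped reward of the last executed step (0.0 for an empty episode).\"\"\"\n",
        "    shaped = [shape_reward(r, d) for r, d in steps]\n",
        "    return shaped[-1] if shaped else 0.0\n",
        "\n",
        "# An episode that fails twice and then solves the task on the third step\n",
        "print(final_episode_reward([(0.0, False), (0.0, False), (1.0, True)]))\n",
        "# An episode that hits max_steps without finishing keeps its last intermediate reward\n",
        "print(final_episode_reward([(0.0, False)] * 10))"
      ]
    },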
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MDJKMQ__8qzj"
      },
      "source": [
        "### Helper functions\n",
        "\n",
        "Supporting utilities used in `rollout_once`:\n",
        "\n",
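        "- `make_user_prompt`: builds the user prompt from the current step number, the task goal, the error raised by the previous action (if any), and the page's accessibility tree, truncated to 2,000 characters.\n",
        "- `parse_action`: returns the first line of the model response that contains parentheses as the BrowserGym action, falling back to `noop()` if none is found."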
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GG4ba41PBP58"
      },
      "outputs": [],
      "source": [
        "# @title Helpers (click to expand)\n",
        "def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = \"\") -> str:\n",
        "    \"\"\"Create user prompt from observation.\"\"\"\n",
        "    prompt_parts = [f\"Step {step_num + 1}\"]\n",
        "\n",
        "    if goal:\n",
        "        prompt_parts.append(f\"Goal: {goal}\")\n",
        "\n",
        "    if error:\n",
        "        prompt_parts.append(f\"Previous action error: {error}\")\n",
        "\n",
        "    # Include accessibility tree (truncated for context)\n",
        "    if axtree:\n",
        "        max_len = 2000\n",
        "        axtree_truncated = axtree[:max_len] + \"...\" if len(axtree) > max_len else axtree\n",
        "        prompt_parts.append(f\"Page structure:\\n{axtree_truncated}\")\n",
        "\n",
        "    prompt_parts.append(\"What action do you take?\")\n",
        "\n",
        "    return \"\\n\\n\".join(prompt_parts)\n",
        "\n",
        "\n",
        "def parse_action(response_text: str) -> str:\n",
        "    \"\"\"Parse BrowserGym action from model response.\"\"\"\n",
        "    # Extract first line that looks like an action\n",
        "    for line in response_text.strip().split(\"\\n\"):\n",
        "        line = line.strip()\n",
        "        if \"(\" in line and \")\" in line:\n",
        "            return line\n",
        "\n",
        "    # Fallback to noop if no valid action found\n",
        "    return \"noop()\""
      ]
    },
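    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`parse_action` is deliberately permissive: any line containing both parentheses is sent to the environment. This is why free-form lines such as \"I will use the action `click()` to click the button.\" appear as actions in the training log. If stricter parsing is preferred, one option is a regular expression that accepts only a whole-line call to one of the actions listed in `SYSTEM_PROMPT`. The `parse_action_strict` helper below is a hypothetical alternative and is not what the run in this notebook uses. Swapping it in would change which outputs reach the environment, and therefore the rewards. Matching is case-sensitive, so `Click('13')` falls back to `noop()`, and arguments that themselves contain parentheses are rejected."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Hypothetical stricter alternative to parse_action (not used in this notebook's run).\n",
        "# Accepts only a whole line that calls one of the actions listed in SYSTEM_PROMPT,\n",
        "# with no nested parentheses inside the arguments.\n",
        "ALLOWED_ACTIONS = (\"noop\", \"click\", \"fill\", \"send_keys\", \"scroll\")\n",
        "ACTION_PATTERN = re.compile(r\"(?:\" + \"|\".join(ALLOWED_ACTIONS) + r\")\\([^()]*\\)\")\n",
        "\n",
        "def parse_action_strict(response_text: str) -> str:\n",
        "    \"\"\"Return the first line that is exactly an allowed action call, else 'noop()'.\"\"\"\n",
        "    for line in response_text.strip().split(\"\\n\"):\n",
        "        candidate = line.strip().strip(\"`\")  # Tolerate inline-code backticks around the call\n",
        "        if ACTION_PATTERN.fullmatch(candidate):\n",
        "            return candidate\n",
        "    return \"noop()\"\n",
        "\n",
        "for text in [\"click('13')\", \"Click('13')\", \"I will use the action `click()` to click the button.\"]:\n",
        "    print(repr(text), \"->\", parse_action_strict(text))"
      ]
    },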
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oek3JhcWnKhw"
      },
      "source": [
        "## Define the reward functions\n",
        "\n",
        "Reward functions quantify the model's performance in the environment and guide the GRPO optimization process.\n",
        "\n",
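        "In this setup, the `reward_completion` function assigns rewards based on task completion. `GRPOTrainer` forwards the extra fields returned by `rollout_func` (here, `completion_reward`) to the reward functions as keyword arguments. `reward_completion` therefore only needs to read each episode's final reward, which indicates whether the agent completed the task. If no reward information is available, it defaults to zero.\n",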
        "\n",
        "This modular approach allows additional reward functions to be added easily, enabling more granular feedback such as intermediate progress, efficiency, or correctness of actions, depending on the task requirements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WxkXaz5aBP59"
      },
      "outputs": [],
      "source": [
        "def reward_completion(completions: list[str], **kwargs) -> list[float]:\n",
        "    \"\"\"Reward for task completion.\"\"\"\n",
        "    rewards = kwargs.get(\"completion_reward\") if kwargs else None\n",
        "    if rewards is None:\n",
        "        return [0.0 for _ in completions]\n",
        "    return [float(r) for r in rewards]"
      ]
    },
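    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As mentioned above, further reward functions can be combined with `reward_completion`. The cell below sketches one possible efficiency reward under explicit assumptions. It expects `rollout_func` to be extended to also return a hypothetical `num_steps` list (the number of environment steps each episode took), which the notebook as written does not do, and `reward_efficiency` itself is invented for this example. Successful episodes earn more the fewer steps they need, and failed episodes earn nothing. Such a function could be passed as `reward_funcs=[reward_completion, reward_efficiency]`. The trainer sums the outputs of multiple reward functions, optionally weighted via the `reward_weights` argument of `GRPOConfig`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "MAX_EPISODE_STEPS = 10  # Should match the max_steps value used by rollout_func\n",
        "\n",
        "def reward_efficiency(completions: list[str], **kwargs) -> list[float]:\n",
        "    \"\"\"Hypothetical efficiency reward: success scaled down linearly with the number of steps used.\"\"\"\n",
        "    successes = kwargs.get(\"completion_reward\")\n",
        "    steps = kwargs.get(\"num_steps\")  # Hypothetical extra field that rollout_func would need to return\n",
        "    if successes is None or steps is None:\n",
        "        return [0.0 for _ in completions]\n",
        "    return [\n",
        "        (1.0 - (n - 1) / MAX_EPISODE_STEPS) if s > 0 else 0.0\n",
        "        for s, n in zip(successes, steps)\n",
        "    ]\n",
        "\n",
        "print(reward_efficiency([\"\", \"\", \"\"], completion_reward=[1.0, 1.0, 0.0], num_steps=[1, 10, 10]))"
      ]
    },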
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "66ZsrLplm07U"
      },
      "source": [
        "## Load the custom dataset\n",
        "\n",
        "The dataset is constructed with repeated prompts to control the total number of training episodes.\n",
        "\n",
        "Each entry in the dataset triggers a single rollout episode during training. The `dataset_prompt` serves as the task instruction for each episode. In `rollout_once`, it is used only as a fallback when the environment's observation does not provide its own goal."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UX6jUjxaBP59"
      },
      "outputs": [],
      "source": [
        "from datasets import Dataset\n",
        "\n",
        "dataset_prompt = \"Complete the web task successfully.\"\n",
        "dataset_size = 1000\n",
        "\n",
        "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-mvka-96m3I7"
      },
      "source": [
        "## Fine-tune using TRL and the GRPOTrainer\n",
        "\n",
        "The next step is to define the GRPOConfig, which sets all key training parameters.\n",
        "\n",
        "This configuration determines how the model interacts with vLLM, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TZ34a1h-BP59"
      },
      "outputs": [],
      "source": [
        "from trl import GRPOConfig\n",
        "output_dir = \"browsergym-grpo-functiongemma-270m-it\"\n",
        "\n",
        "grpo_config = GRPOConfig(\n",
        "    # num_train_epochs=1,                                     # Number of times to iterate over the full dataset (use for full training runs)\n",
        "    max_steps=100,                                            # Total number of training steps (for shorter runs/testing). For full training runs, use `num_train_epochs` instead\n",
        "    learning_rate=5e-6,                                       # Learning rate for the optimizer\n",
        "    warmup_steps=10,                                          # Number of steps to linearly increase learning rate at the start of training\n",
        "\n",
        "    per_device_train_batch_size=1,                            # Number of samples per device per step\n",
        "    num_generations=4,                                        # Number of completions to generate per prompt\n",
        "    generation_batch_size=4,                                  # Batch size used during generation (must be divisible by num_generations)\n",
        "    max_completion_length=32,                                 # Maximum length of generated completions\n",
        "\n",
        "    use_vllm=True,                                            # Use vLLM engine for fast inference\n",
        "    vllm_mode=\"colocate\",                                     # vLLM mode: \"colocate\" runs generation on the same GPU as training\n",
        "    vllm_gpu_memory_utilization=0.1,                          # Fraction of GPU memory allocated to vLLM\n",
        "\n",
        "    output_dir=str(output_dir),                               # Directory where checkpoints, logs, and outputs will be saved\n",
        "    logging_steps=1,                                          # Log metrics every N steps\n",
        "    report_to=\"trackio\",                                      # Report training metrics to Trackio\n",
        "    trackio_space_id=output_dir,                              # HF Space where the experiment tracking will be saved\n",
        "    push_to_hub=True,                                         # Push the trained model to the Hugging Face Hub\n",
        "\n",
        "    use_liger_kernel=True,                                    # Enable Liger kernel optimizations for faster training\n",
        "    gradient_checkpointing=True,                              # Save memory by recomputing activations during backpropagation\n",
        "    gradient_checkpointing_kwargs={\"use_reentrant\": False},   # Use non-reentrant checkpointing (recommended by PyTorch; avoids a warning)\n",
        ")\n"
      ]
    },
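    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The generation settings in this configuration are linked by simple arithmetic. Completions are generated in groups of `num_generations` that share one prompt, so `generation_batch_size` must be a multiple of `num_generations`. With the values above, each generation batch contains one unique prompt repeated four times, and each of those four entries becomes one BrowserGym episode. This is consistent with the `rollout_func called with 4 prompts` line in the training log. The cell below only reproduces this arithmetic, assuming a single-GPU setup. It does not query TRL."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Values copied from the GRPOConfig above (single-GPU assumption)\n",
        "num_generations = 4\n",
        "generation_batch_size = 4\n",
        "\n",
        "# Each group of num_generations completions shares one prompt, so the batch must split evenly\n",
        "assert generation_batch_size % num_generations == 0, \"generation_batch_size must be divisible by num_generations\"\n",
        "\n",
        "unique_prompts_per_generation = generation_batch_size // num_generations\n",
        "episodes_per_generation = generation_batch_size  # One BrowserGym episode per entry passed to rollout_func\n",
        "\n",
        "print(f\"{unique_prompts_per_generation} unique prompt(s), {episodes_per_generation} episodes per generation batch\")"
      ]
    },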
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a1taGmD--0Y4"
      },
      "source": [
        "The next step is to initialize the GRPOTrainer, which manages the complete reinforcement learning loop.\n",
        "\n",
        "It receives the model name, the reward functions, the rollout function, the dataset, and the configuration defined earlier. From the model name, the trainer automatically loads the model and tokenizer. It then coordinates interactions between the model and the environment, applies the defined reward signals, and updates the policy during training.\n",
        "\n",
        "Finally, calling `trainer.train()` starts the fine-tuning process, enabling the model to progressively improve its performance through iterative interaction and reinforcement learning.\n",
        "\n",
        "> Note: The training pipeline uses approximately 10.6 GB of GPU VRAM and can be adapted to different hardware configurations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "En43o4NZBP59"
      },
      "outputs": [],
      "source": [
        "model_name = \"google/functiongemma-270m-it\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "referenced_widgets": [
            "047d386e54704add95edd4beace781d7"
          ]
        },
        "id": "k8-SvqJcBP59",
        "outputId": "6a4d9276-fc91-4217-d3a2-51a18d222338"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3830121904.py:1: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n",
            "  trainer = GRPOTrainer(\n",
            "The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
            "`torch_dtype` is deprecated! Use `dtype` instead!\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "047d386e54704add95edd4beace781d7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 4/4 [00:00<00:00, 19.64it/s]\n"
          ]
        }
      ],
      "source": [
        "trainer = GRPOTrainer(\n",
        "    model=model_name,\n",
        "    reward_funcs=[reward_completion],\n",
        "    train_dataset=dataset,\n",
        "    args=grpo_config,\n",
        "    rollout_func=rollout_func,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e1PrBB7gBP59",
        "outputId": "61740a89-228c-4b3c-8e59-b4a3eb972c03"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 2, 'pad_token_id': 0}.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "* Trackio project initialized: huggingface\n",
            "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/browsergym-grpo-functiongemma-270m-it-dataset\n",
            "* Creating new space: https://huggingface.co/spaces/sergiopaniego/browsergym-grpo-functiongemma-270m-it\n",
            "* View dashboard by going to: https://sergiopaniego-browsergym-grpo-functiongemma-270m-it.hf.space/\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div><iframe src=\"https://sergiopaniego-browsergym-grpo-functiongemma-270m-it.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "* Created new run: sergiopaniego-1765969078\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: Click 'click(bid) - Click element with BrowserGym ID (the number in brackets\n",
            "Step 8: I will use the action `click()` to click the button.\n",
            "Step 9: noop()\n",
            "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: Clicks ('13')\n",
            "Step 4: I will click 'Click Me!' using action 'click(bid)' on page 'Click Test Task' using a bid of '13'.\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: I will use the 'click(bid)' action.\n",
            "Step 2: mouse_click(bid)\n",
            "Step 3: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 4: Add action 'click(bid)' to Step 4.\n",
            "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 9: noop()\n",
            "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: Click('13')\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:liger_kernel.transformers.model.gemma3:It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.\n",
            "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7095: UserWarning: \n",
            "Online softmax is disabled on the fly since Inductor decides to\n",
            "split the reduction. Cut an issue to PyTorch if this is an\n",
            "important use case and you want to speed it up with online\n",
            "softmax.\n",
            "\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [100/100 35:02, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>-0.877900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>1965.894400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>-0.830900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>10.616100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>2.320100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>1.887500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>-0.691600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>-0.764400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>51</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>52</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>53</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>54</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>55</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>56</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>57</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>58</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>59</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>60</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>61</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>62</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>63</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>64</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>65</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>66</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>67</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>68</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>69</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>70</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>71</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>72</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>73</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>74</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>75</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>76</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>77</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>78</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>79</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>80</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>81</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>82</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>83</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>84</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>85</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>86</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>87</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>88</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>89</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>90</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>91</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>92</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>93</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>94</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>95</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>96</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>97</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>98</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>99</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>100</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: Clicks ('13')\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: noop()\n",
            "Step 2: I will use action: click(bid) to click the button.\n",
            "Step 3: Yes, I can handle this. I will use the `click()` action to click the button.\n",
            "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 9: noop()\n",
            "Step 10: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 8: noop()\n",
            "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 10: Pass the button ID ('Click Me!') to the action \"click('bid')\".\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: I will click the button by emitting `click(bid)` and `fill(bid, text)` simultaneously.\n",
            "Step 6: noop()\n",
            "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: - Noop()\n",
            "Step 2: noop()\n",
            "Step 3: -noop()\n",
            "Step 4: noop()\n",
            "Step 5: Click('13')\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: noop()\n",
            "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: Complete action: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: I will use the action 'click('bid') to click the button.\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: I call action Click (bid) on the page.\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: Oops()\n",
            "Step 2: noop()\n",
            "Step 3: fill(bid, text)\n",
            "Step 4: noop()\n",
            "Step 5: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: def click_button_on_page():\n",
            "Step 2: noop()\n",
            "Step 3: click(bid)\n",
            "Step 4: Click('13')\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: noop()\n",
            "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 3: noop()\n",
            "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 6: I will click the button 'Click Me!' by using the action `click(bid)` and emitting a bid of 13.\n",
            "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: noop()\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: `click(bid)` - No action\n",
            "Step 2: - Noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: noop()\n",
            "Step 10: I will click the button 'Click Me!' using the action 'click(bid)'.\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: Complete action: click(bid)\n",
            "Step 10: noop()\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: noop()\n",
            "Step 2: I will perform action 1: click('13') to complete the action.\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: noop()\n",
            "Step 2: noop()\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 8: noop()\n",
            "Step 9: Click ('13')\n",
            "Step 10: Add action 'fill(bid, text) - Send keyboard input' to perform the click.\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: noop()\n",
            "Step 2: Click('click(bid) - Bid')\n",
            "Step 3: noop()\n",
            "Step 4: noop()\n",
            "Step 5: noop()\n",
            "Step 6: noop()\n",
            "Step 7: noop()\n",
            "Step 8: noop()\n",
            "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
            "Step 10: noop()\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "\n",
            "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
            "[DEBUG] Processing prompt 1/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 2/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 3/4\n",
            "Step 1: click('13')\n",
            "[DEBUG] Processing prompt 4/4\n",
            "Step 1: click('13')\n",
            "* Run finished. Uploading logs to Trackio (please wait...)\n"
          ]
        }
      ],
      "source": [
        "trainer_stats = trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BZj4IG9ZBAix"
      },
      "source": [
        "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "referenced_widgets": [
            "244ced1920694dbaae9bf98065b4f01d",
            "e3769ae107554c9ba38c1e491b15bf4e",
            "6d5b8bff73474faeb1d1b438fb4e8cec",
            "9f952f8eb63b42e4b38711737da5461e",
            "bd12780895064467b5be14e2ec3df114",
            "d1261c1083a74dca877e6eece6395d73",
            "999744cacd6a4fb08a1d4977ce2f06fd",
            "faa5e0fb4ee244689c0f9eef9902acf7",
            "6403bed2cd984ba18f74f416748c64e4",
            "38be017369524e2eb22050e7a0a18ec5",
            "b0720a4a2df948308011d4d87a288426",
            "889ca2520f4d446daf2e6ed16ce11d2e"
          ]
        },
        "id": "9oOBgEWeBP59",
        "outputId": "76bef375-fc6b-4fdd-a296-549a9b109b11"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "244ced1920694dbaae9bf98065b4f01d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Processing Files (0 / 0)      : |          |  0.00B /  0.00B            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e3769ae107554c9ba38c1e491b15bf4e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "New Data Upload               : |          |  0.00B /  0.00B            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6d5b8bff73474faeb1d1b438fb4e8cec",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9f952f8eb63b42e4b38711737da5461e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bd12780895064467b5be14e2ec3df114",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d1261c1083a74dca877e6eece6395d73",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...270m-it/model.safetensors:   4%|3         | 41.9MB / 1.07GB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "No files have been modified since last commit. Skipping to prevent empty commit.\n",
            "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "999744cacd6a4fb08a1d4977ce2f06fd",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Processing Files (0 / 0)      : |          |  0.00B /  0.00B            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "faa5e0fb4ee244689c0f9eef9902acf7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "New Data Upload               : |          |  0.00B /  0.00B            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6403bed2cd984ba18f74f416748c64e4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "38be017369524e2eb22050e7a0a18ec5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b0720a4a2df948308011d4d87a288426",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...270m-it/model.safetensors:   3%|3         | 33.5MB / 1.07GB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "889ca2520f4d446daf2e6ed16ce11d2e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB            "
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "No files have been modified since last commit. Skipping to prevent empty commit.\n",
            "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
          ]
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it/commit/a17de133c28ca7fddfcb2694c32f2791de5ddbe6', commit_message='End of training', commit_description='', oid='a17de133c28ca7fddfcb2694c32f2791de5ddbe6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/browsergym-grpo-functiongemma-270m-it'), pr_revision=None, pr_num=None)"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.save_model(output_dir)\n",
        "trainer.push_to_hub()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "talmc8b7nPXJ"
      },
      "source": [
        "## Load the Fine-Tuned Model and Run Inference\n",
        "\n",
        "The fine-tuned model is loaded to perform inference and evaluate its behavior on the target task.  \n",
        "In this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the *click* task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "referenced_widgets": [
            "c3879b716f37442a87d51b8414fe8c48"
          ]
        },
        "id": "iIDiaGVlBP5-",
        "outputId": "4dc0e365-e89f-40ba-b391-74c7efdc932d"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c3879b716f37442a87d51b8414fe8c48",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/1.07G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "\n",
        "model_name = \"sergiopaniego/browsergym-grpo-functiongemma-270m-it\" # Replace with your HF username or organization\n",
        "\n",
        "fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, dtype=\"auto\", device_map=\"auto\")\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyT-vudO5ekj"
      },
      "source": [
        "With the fine-tuned model loaded, testing can be conducted on the BrowserGym environment.\n",
        "To streamline evaluation, a reusable function is defined that executes multiple rounds of the task.\n",
        "This function follows the same interaction logic as used during training, generating model actions from observations, executing them in the environment, and printing the results step by step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "doAEIf5IBP5-"
      },
      "outputs": [],
      "source": [
        "def test_click_in_browsergym(env, model, tokenizer):\n",
        "    result = env.reset()\n",
        "    observation = result.observation\n",
        "\n",
        "    for step_num in range(max_steps):\n",
        "        if result.done:\n",
        "            break\n",
        "\n",
        "        # Create prompt from observation (text-only using accessibility tree)\n",
        "        goal = observation.goal or dataset_prompt\n",
        "        axtree = observation.axtree_txt or \"\"\n",
        "        error = observation.error if observation.last_action_error else \"\"\n",
        "\n",
        "        user_prompt = make_user_prompt(goal, step_num, axtree, error)\n",
        "        messages = [\n",
        "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
        "            {\"role\": \"user\", \"content\": user_prompt},\n",
        "        ]\n",
        "        prompt_text = tokenizer.apply_chat_template(\n",
        "            messages,\n",
        "            add_generation_prompt=True,\n",
        "            tokenize=False,\n",
        "        )\n",
        "\n",
        "        # Generate action\n",
        "        prompt_text = tokenizer.apply_chat_template(\n",
        "            messages,\n",
        "            add_generation_prompt=True,\n",
        "            tokenize=False,\n",
        "            enable_thinking=False,\n",
        "        )\n",
        "\n",
        "        model_inputs = tokenizer([prompt_text], return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "        generated_ids = model.generate(\n",
        "            **model_inputs,\n",
        "            max_new_tokens=512\n",
        "        )\n",
        "        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
        "\n",
        "        # Decode and extract model response\n",
        "        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
        "\n",
        "        action_str = parse_action(generated_text)\n",
        "        print(f\"Step {step_num + 1}: {action_str}\")\n",
        "\n",
        "        # Take action in environment\n",
        "        result = env.step(BrowserGymAction(action_str=action_str))\n",
        "        observation = result.observation"
      ]
    },
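    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `test_click_in_browsergym` needs a running BrowserGym server and the trained model, its control flow can be hard to inspect on its own. The cell below is a minimal sketch of the same loop (reset, build the inputs, pick an action, step, stop on `done` or after `max_steps`) run against a hypothetical in-memory environment. `StubClickEnv`, `StubObservation`, `StubResult` and `run_episode` are illustrative stand-ins, not OpenEnv or BrowserGym APIs, and the plain-Python `policy` replaces the prompt construction, `model.generate` and `parse_action` steps. Unlike the real client, the stub's `step` takes the action string directly instead of a `BrowserGymAction`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch with hypothetical stand-ins: none of these classes come from OpenEnv or BrowserGym.\n",
        "from dataclasses import dataclass\n",
        "from typing import Callable, List, Tuple\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class StubObservation:\n",
        "    # Only the observation fields read by test_click_in_browsergym\n",
        "    goal: str\n",
        "    axtree_txt: str\n",
        "    last_action_error: bool = False\n",
        "    error: str = \"\"\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class StubResult:\n",
        "    # Shape of what env.reset() / env.step() return: observation, reward, done\n",
        "    observation: StubObservation\n",
        "    reward: float = 0.0\n",
        "    done: bool = False\n",
        "\n",
        "\n",
        "class StubClickEnv:\n",
        "    \"\"\"Episode ends with reward 1.0 once the target element id is clicked.\"\"\"\n",
        "\n",
        "    def __init__(self, target_id: str):\n",
        "        self.target_id = target_id\n",
        "\n",
        "    def _observe(self, error: str = \"\") -> StubObservation:\n",
        "        return StubObservation(\n",
        "            goal=\"Click the button.\",\n",
        "            axtree_txt=f\"[{self.target_id}] button 'Submit'\",\n",
        "            last_action_error=bool(error),\n",
        "            error=error,\n",
        "        )\n",
        "\n",
        "    def reset(self) -> StubResult:\n",
        "        return StubResult(observation=self._observe())\n",
        "\n",
        "    def step(self, action_str: str) -> StubResult:\n",
        "        # The real client expects BrowserGymAction(action_str=...); the stub takes the string\n",
        "        success = action_str == f\"click('{self.target_id}')\"\n",
        "        error = \"\" if success else \"Clicked the wrong element.\"\n",
        "        return StubResult(observation=self._observe(error), reward=1.0 if success else 0.0, done=success)\n",
        "\n",
        "\n",
        "# policy(goal, axtree, step_num, error) -> action string\n",
        "Policy = Callable[[str, str, int, str], str]\n",
        "\n",
        "\n",
        "def run_episode(env, policy: Policy, max_steps: int = 10) -> Tuple[List[str], float]:\n",
        "    \"\"\"Same control flow as test_click_in_browsergym; returns the actions and the final reward.\"\"\"\n",
        "    result = env.reset()\n",
        "    actions: List[str] = []\n",
        "    for step_num in range(max_steps):\n",
        "        if result.done:\n",
        "            break\n",
        "        obs = result.observation\n",
        "        # Pass the previous action's error back to the policy only when the env flagged one\n",
        "        error = obs.error if obs.last_action_error else \"\"\n",
        "        action_str = policy(obs.goal, obs.axtree_txt, step_num, error)\n",
        "        actions.append(action_str)\n",
        "        result = env.step(action_str)\n",
        "    return actions, result.reward\n",
        "\n",
        "\n",
        "# Usage: a correct click ends the episode after one step; a wrong one exhausts max_steps\n",
        "print(run_episode(StubClickEnv(\"13\"), lambda goal, axtree, step, err: \"click('13')\", max_steps=5))\n",
        "print(run_episode(StubClickEnv(\"13\"), lambda goal, axtree, step, err: \"click('7')\", max_steps=3))"
      ]
    },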
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9QvGD8f8CQx1"
      },
      "source": [
        "The `test_click_in_browsergym` function is called to run a full evaluation of the fine-tuned model on the BrowserGym *click* task.  \n",
        "\n",
        "The environment client is safely closed after testing using a `try/finally` block, ensuring that all resources are released even if an error occurs during execution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z77wlVb6BP5-",
        "outputId": "ed4ad094-1529-4cc7-8274-2782784efe2d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Step 1: click('13')\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "    test_click_in_browsergym(client, fine_tuned_model, tokenizer)\n",
        "finally:\n",
        "    client.close()"
      ]
    },
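    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The call above runs a single episode, so one printed action says little about how reliably the model solves the task. The cell below is a minimal sketch of a multi-episode evaluation that, assuming a binary reward (1.0 on success, 0.0 otherwise), averages the final reward over several `reset()` calls and keeps the same `try/finally` cleanup. `OneClickStubEnv`, `evaluate_success_rate` and `first_id_policy` are hypothetical in-memory stand-ins, not OpenEnv or BrowserGym APIs. With the real environment, the policy would wrap the prompt, `model.generate` and `parse_action` steps of `test_click_in_browsergym`, `env.step` would receive a `BrowserGymAction`, and a new client would be needed because the cell above closes `client`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch with hypothetical stand-ins: none of these names come from OpenEnv or BrowserGym.\n",
        "import re\n",
        "from types import SimpleNamespace\n",
        "from typing import Callable, Sequence\n",
        "\n",
        "\n",
        "class OneClickStubEnv:\n",
        "    \"\"\"Each episode ends after one action, with reward 1.0 if the episode's target id was clicked.\"\"\"\n",
        "\n",
        "    def __init__(self, target_ids: Sequence[str]):\n",
        "        self.target_ids = list(target_ids)  # target element id for each successive episode\n",
        "        self.episode = -1\n",
        "        self.closed = False\n",
        "\n",
        "    def _target(self) -> str:\n",
        "        return self.target_ids[self.episode % len(self.target_ids)]\n",
        "\n",
        "    def reset(self):\n",
        "        self.episode += 1\n",
        "        obs = SimpleNamespace(axtree_txt=f\"[{self._target()}] button 'OK'\")\n",
        "        return SimpleNamespace(observation=obs, reward=0.0, done=False)\n",
        "\n",
        "    def step(self, action_str: str):\n",
        "        success = action_str == f\"click('{self._target()}')\"\n",
        "        obs = SimpleNamespace(axtree_txt=\"\")\n",
        "        return SimpleNamespace(observation=obs, reward=1.0 if success else 0.0, done=True)\n",
        "\n",
        "    def close(self):\n",
        "        self.closed = True\n",
        "\n",
        "\n",
        "def evaluate_success_rate(env, policy: Callable[[str], str], num_episodes: int = 5, max_steps: int = 10) -> float:\n",
        "    \"\"\"Mean final reward over num_episodes; the env is closed even if an episode raises.\"\"\"\n",
        "    total_reward = 0.0\n",
        "    try:\n",
        "        for _ in range(num_episodes):\n",
        "            result = env.reset()\n",
        "            for _ in range(max_steps):\n",
        "                if result.done:\n",
        "                    break\n",
        "                result = env.step(policy(result.observation.axtree_txt))\n",
        "            total_reward += result.reward\n",
        "    finally:\n",
        "        env.close()\n",
        "    return total_reward / max(num_episodes, 1)\n",
        "\n",
        "\n",
        "def first_id_policy(axtree: str) -> str:\n",
        "    # Toy policy: click the first [id] found in the accessibility-tree text\n",
        "    match = re.search(r\"\\[(\\w+)\\]\", axtree)\n",
        "    return f\"click('{match.group(1)}')\" if match else \"noop()\"\n",
        "\n",
        "\n",
        "# Usage: a fixed click('13') policy succeeds on 3 of the 4 targets below\n",
        "stub_env = OneClickStubEnv([\"13\", \"13\", \"7\", \"13\"])\n",
        "print(evaluate_success_rate(stub_env, lambda axtree: \"click('13')\", num_episodes=4))\n",
        "print(stub_env.closed)"
      ]
    },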
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wHydP-ZVCcYK"
      },
      "source": [
        "## Summary and Next Steps\n",
        "\n",
        "This tutorial demonstrated how to fine-tune a FunctionGemma model using TRL, GRPO, and the BrowserGym environment from OpenEnv. Check out the following docs next:\n",
        "\n",
        "- Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
        "- Learn how to [fine-tune Gemma for vision tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora).\n",
        "- Learn how to [full model fine-tune using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune).\n",
        "- Learn how to [fine-tune Gemma using Hugging Face Transformers with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora).  \n",
        "- Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
        "- Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
        "- Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
