{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b",
   "metadata": {
    "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b"
   },
   "source": [
    "# OpenEnv Wordle with GRPO using TRL\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb)\n",
    "\n",
    "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
    "\n",
    "\n",
    "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a model that learns to **play Wordle**, a word-guessing game, through interaction and reinforcement.\n",
    "\n",
    "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project!  \n",
    "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview)  \n",
    "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
    "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n",
    "\n",
    "\n",
    "An **agentic environment** is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error.\n",
    "In this case, the agent interacts with the **Wordle** environment through the [**OpenEnv**](https://github.com/meta-pytorch/OpenEnv) framework, which standardizes multi-agent and RL-style text environments.\n",
    "\n",
    "[Wordle](https://en.wikipedia.org/wiki/Wordle) is a popular word puzzle where the player must guess a secret five-letter word within six tries.  \n",
    "After each guess, feedback indicates whether each letter is:\n",
    "- 🟩 **Correct and in the right position**\n",
    "- 🟨 **Present but in the wrong position**\n",
    "- ⬛ **Not in the word**\n",
    "\n",
    "This feedback loop makes Wordle a perfect environment for **RL with LLMs**, where the goal is to maximize the probability of guessing the correct word efficiently.\n",
    "\n",
    "\n",
    "We'll fine-tune a model using **GRPO** (Group Relative Policy Optimization) via TRL.  \n",
    "The agent will:\n",
    "1. Generate guesses based on the game state and feedback.\n",
    "2. Receive structured feedback from the environment after each guess.\n",
    "3. Learn to improve its guessing strategy over time through reward signals.\n",
    "\n",
    "\n",
    "## Install dependencies\n",
    "\n",
    "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**.  \n",
    "We'll also install the **OpenEnv** framework (for the environment), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4812fbf-3f61-481e-9a64-95277eada9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -Uq trl git+https://github.com/meta-pytorch/OpenEnv.git trackio vllm==0.11.2 bitsandbytes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148",
   "metadata": {
    "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148"
   },
   "source": [
    "### Log in to Hugging Face\n",
    "\n",
    "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21756ac0-78b2-495d-8137-28dfa9faae6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "rpFT3PxHT5Uc",
   "metadata": {
    "id": "rpFT3PxHT5Uc"
   },
   "source": [
    "## Initialize the Environment\n",
    "\n",
    "Let's begin by setting up the environment that will be used during training.  \n",
    "For this task, we'll rely on the **TextArena** environment from **OpenEnv**, which exposes a familiar Gymnasium-style API (`reset()`, `step()`, etc.) to simplify interaction.\n",
    "\n",
    "In this example, we'll connect to the hosted environment at [burtenshaw/textarena](https://huggingface.co/spaces/burtenshaw/textarena).  \n",
    "For production use or custom configurations, we **strongly recommend** running the environment locally via Docker. The hosted versions on the Hub currently have limited concurrency support, so duplicating the Space to your own account is the preferred approach in those cases.\n",
    "\n",
    "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "rZimqp1UTIV_",
   "metadata": {},
   "outputs": [],
   "source": [
    "from envs.textarena_env import TextArenaEnv\n",
    "\n",
    "textarena_url = \"https://burtenshaw-textarena.hf.space\" # Duplicate the Space and update this!\n",
    "env = TextArenaEnv(base_url=textarena_url)\n",
    "# textarena_url = \"burtenshaw/textarena\"\n",
    "# env = TextArenaEnv.from_hub(repo_id=textarena_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hARwiQm8ehw3",
   "metadata": {
    "id": "hARwiQm8ehw3"
   },
   "source": [
    "## Init model and tokenizer\n",
    "\n",
    "We'll use [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B), a lightweight instruction-tuned model that works well for quick experiments.  \n",
    "Despite its small size, it can still learn interesting strategies during fine-tuning.  \n",
    "If you have stronger hardware, you can easily scale up to larger models.\n",
    "\n",
    "We'll load the **tokenizer** (needed for text processing) here.  \n",
    "The **model** itself will be handled internally by TRL during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lR7usp2Dd-JK",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_name = \"Qwen/Qwen3-1.7B\" #\"Qwen/Qwen2.5-0.5B-Instruct\" # \"Qwen/Qwen3-0.6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0oojh2i0ey88",
   "metadata": {
    "id": "0oojh2i0ey88"
   },
   "source": [
    "## Rollout function with helpers\n",
    "\n",
    "The **rollout function** defines how the agent interacts with the environment during GRPO training.\n",
    "It's responsible for generating model completions, collecting feedback (rewards), and returning all necessary information for optimization.\n",
    "\n",
    "In this setup:\n",
    "- The function is called automatically by the **GRPOTrainer** during each training step.  \n",
    "- It uses the trainer's built-in `generate_rollout_completions()` method for efficient generation with vLLM in colocate mode.\n",
    "- Each rollout represents a full interaction loop. The model guesses, receives feedback from Wordle, and updates based on reward signals.\n",
    "\n",
    "The rewards track different aspects of the agent's performance. Helper functions (like `rollout_once`) handle one episode of interaction, keeping the main `rollout_func` clean and modular.\n",
    "\n",
    "This modular approach allows GRPO to efficiently sample, evaluate, and improve the model's guessing strategy through reinforcement learning.\n",
    "\n",
    "First, we define the `system_prompt` that guides the model's behavior as an expert Wordle solver with strategic reasoning and structured responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "QlUHqvZV6ytz",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title System prompt (click to expand)\n",
    "system_prompt = \"\"\"\n",
    "You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.\n",
    "\n",
    "## GAME RULES\n",
    "\n",
    "1. The target is a 5-letter English word\n",
    "2. You have 6 attempts to guess the correct word\n",
    "3. After each guess, you receive color-coded feedback:\n",
    "   - GREEN: Letter is correct and in the correct position\n",
    "   - YELLOW: Letter is in the word but in the wrong position\n",
    "   - GRAY: Letter is not in the word at all\n",
    "4. All guesses must be valid 5-letter English words\n",
    "5. You cannot reuse a word you've already guessed\n",
    "\n",
    "## RESPONSE FORMAT\n",
    "\n",
    "Only respond with your next guess in square brackets, e.g., [crane].\n",
    "\n",
    "Format:\n",
    "```\n",
    "[guess]\n",
    "```\n",
    "\n",
    "## STRATEGIC APPROACH\n",
    "\n",
    "Do not repeat the same guess twice.\n",
    "\n",
    "### Opening Strategy\n",
    "- Start with words rich in common vowels (A, E, I, O, U) and consonants (R, S, T, L, N)\n",
    "- Optimal starters: CRANE, SLATE, STARE, AROSE, IRATE\n",
    "- Prioritize words that test the most common letters in different positions\n",
    "\n",
    "### Mid-Game Strategy\n",
    "- Use confirmed GREEN letters in their correct positions\n",
    "- Place YELLOW letters in different positions than where they appeared\n",
    "- Eliminate GRAY letters from consideration\n",
    "- If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.)\n",
    "- Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S\n",
    "\n",
    "### Vowel Placement\n",
    "- Most 5-letter words have 2 vowels\n",
    "- Common patterns: vowel-consonant-vowel (like CRANE) or consonant-vowel-vowel-consonant-vowel (like QUEUE)\n",
    "- If you have 1-2 vowels confirmed, consider where the others might be\n",
    "\n",
    "### Advanced Tactics\n",
    "- Use \"sacrificial\" guesses to test multiple new letters if you have attempts to spare\n",
    "- Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's)\n",
    "- Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint\n",
    "- Consider less common letters (Q, X, Z, J) only when you've eliminated the most common options\n",
    "\n",
    "### Common Pitfalls to Avoid\n",
    "- Don't reuse letters marked GRAY (eliminated letters)\n",
    "- Don't place YELLOW letters in the same position they appeared\n",
    "- Don't ignore confirmed GREEN letters in future guesses\n",
    "- Don't guess words that contradict known information\n",
    "\n",
    "## EXAMPLES\n",
    "\n",
    "### Example 1: Opening Guess\n",
    "\"Starting with a word that tests common vowels and consonants in varied positions.\"\n",
    "[crane]\n",
    "\n",
    "### Example 2: After Receiving Feedback\n",
    "Previous guess: CRANE\n",
    "Feedback: C=gray, R=yellow, A=green, N=gray, E=yellow\n",
    "\n",
    "\"A is confirmed in position 2. R and E are in the word but need different positions. C and N are eliminated. I'll try a word with A in position 2, and test R and E in new positions along with common letters like S and T.\"\n",
    "[spare]\n",
    "\n",
    "### Example 3: Narrowing Down\n",
    "Previous guesses: CRANE (C=gray, R=yellow, A=green, N=gray, E=yellow), SPARE (S=gray, P=gray, A=green, R=green, E=green)\n",
    "Feedback summary: _ARE_ with R in position 4, A in position 2, E in position 5\n",
    "\n",
    "\"I have _AR E_ confirmed. Positions 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. Testing with TARED.\"\n",
    "[tared]\n",
    "\n",
    "### Example 4: Final Deduction\n",
    "Previous feedback shows: _ARED with position 1 unknown and all common consonants tested\n",
    "\n",
    "\"Only position 1 remains. I've eliminated S, P, C, N. Common starting consonants left are B, F, G, H. BARED is a common word.\"\n",
    "[bared]\n",
    "\n",
    "## LETTER FREQUENCY REFERENCE\n",
    "\n",
    "Most common letters in 5-letter words (in order):\n",
    "S, E, A, O, R, I, L, T, N, U, D, Y, C, P, M, H, G, B, K, F\n",
    "\n",
    "Most common starting letters:\n",
    "S, C, B, T, P, A, F, G, D, M\n",
    "\n",
    "Most common ending letters:\n",
    "E, Y, T, S, R, L, N, D\n",
    "\n",
    "## IMPORTANT CONSTRAINTS\n",
    "\n",
    "- Use lowercase only\n",
    "- One guess per response\n",
    "- Must be exactly 5 letters\n",
    "- Must be a real English word from standard dictionaries\n",
    "- Never repeat a previous guess\n",
    "- Always include brief reasoning before your guess\n",
    "\n",
    "## YOUR GOAL\n",
    "\n",
    "Solve the Wordle in as few guesses as possible by strategically using feedback to eliminate impossible words and narrow down the solution space efficiently.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "rUOAm7o-kJ5U",
   "metadata": {
    "id": "rUOAm7o-kJ5U"
   },
   "source": [
    "Now, let's define the `rollout_func`:\n",
    "\n",
    "This function orchestrates the interaction between the model and the Wordle environment. For each prompt in the batch, it runs the episode interaction, collecting rewards and model outputs for GRPO optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8a9e7a62-fff9-4caa-9500-dd278f49ec0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout_func(prompts, trainer=None):\n",
    "    \"\"\"\n",
    "    Rollout function for GRPO training with environment interaction.\n",
    "\n",
    "    This function is called by GRPOTrainer to generate completions and compute rewards.\n",
    "    In colocate mode, it uses trainer.generate_rollout_completions() for inference.\n",
    "\n",
    "    Args:\n",
    "        prompts: List of prompts to generate from\n",
    "        trainer: GRPOTrainer instance containing context and configuration\n",
    "\n",
    "    Returns:\n",
    "        Dictionary with prompt_ids, completion_ids, logprobs, and reward signals\n",
    "    \"\"\"\n",
    "    episode_prompt_ids = []\n",
    "    episode_completion_ids = []\n",
    "    episode_logprobs = []\n",
    "    correctness_rewards = []\n",
    "    green_rewards = []\n",
    "    yellow_rewards = []\n",
    "    repetition_rewards = []\n",
    "\n",
    "    for prompt_text in prompts:\n",
    "        episode = rollout_once(\n",
    "            trainer=trainer,\n",
    "            env=env,\n",
    "            tokenizer=tokenizer,\n",
    "            dataset_prompt=prompt_text,\n",
    "            system_prompt=system_prompt,\n",
    "            max_turns=6,\n",
    "        )\n",
    "        episode_prompt_ids.append(episode[\"prompt_ids\"])\n",
    "        episode_completion_ids.append(episode[\"completion_ids\"])\n",
    "        episode_logprobs.append(episode[\"logprobs\"])\n",
    "        correctness_rewards.append(episode[\"correct_reward\"])\n",
    "        green_rewards.append(episode[\"green_reward\"])\n",
    "        yellow_rewards.append(episode[\"yellow_reward\"])\n",
    "        repetition_rewards.append(episode[\"repetition_reward\"])\n",
    "\n",
    "    return {\n",
    "        \"prompt_ids\": episode_prompt_ids,\n",
    "        \"completion_ids\": episode_completion_ids,\n",
    "        \"logprobs\": episode_logprobs,\n",
    "        \"correct_reward\": correctness_rewards,\n",
    "        \"green_reward\": green_rewards,\n",
    "        \"yellow_reward\": yellow_rewards,\n",
    "        \"repetition_reward\": repetition_rewards,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "mJ4D8zvAkQLh",
   "metadata": {
    "id": "mJ4D8zvAkQLh"
   },
   "source": [
    "### Define `rollout_once`\n",
    "\n",
    "The `rollout_once` function runs **one full interaction loop** between the model and the Wordle environment using the trainer's generation method.  \n",
    "It executes a mini episode of gameplay, from generating a guess to receiving and processing feedback.\n",
    "\n",
    "Here's the step-by-step breakdown:\n",
    "\n",
    "1. **Environment reset:** Start a new game session and initialize the observation.  \n",
    "2. **Prompt construction:** Combine the system prompt, current state, and user messages to form the model input.  \n",
    "3. **Generation:** Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's guess efficiently.  \n",
    "4. **Feedback extraction:** Parse the environment's response using helpers like `extract_guess()` and `extract_wordle_feedback()`.  \n",
    "5. **Reward calculation:** Compute rewards based on correctness, green/yellow feedback, and repetition penalty.\n",
    "6. **Return structured rollout data:** Includes prompt/completion IDs, logprobs, and all computed reward components.\n",
    "\n",
    "This modular design ensures that each episode can be processed independently while still providing rich feedback for the **GRPO training loop**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipython-input-1881463685.py:4: UserWarning: You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n",
      "  from trl.experimental.openenv import generate_rollout_completions\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "from envs.textarena_env import TextArenaAction\n",
    "from envs.textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback\n",
    "from trl.experimental.openenv import generate_rollout_completions\n",
    "\n",
    "\n",
    "def rollout_once(trainer, env, tokenizer, dataset_prompt, system_prompt, max_turns):\n",
    "    \"\"\"\n",
    "    Execute one full Wordle episode with the model.\n",
    "\n",
    "    This function uses generate_rollout_completions() instead of manual vLLM handling,\n",
    "    making the code cleaner and more maintainable.\n",
    "    \"\"\"\n",
    "    result = env.reset()\n",
    "    observation = result.observation\n",
    "\n",
    "    prompt_ids = []\n",
    "    completion_ids = []\n",
    "    logprobs = []\n",
    "    raw_rewards = []\n",
    "    green_scores = []\n",
    "    yellow_scores = []\n",
    "    repetition_scores = []\n",
    "    correct_scores = []\n",
    "    guess_counts = defaultdict(int)\n",
    "\n",
    "    for _turn in range(max_turns):\n",
    "        # when the game is over the environment will return a done=True\n",
    "        if result.done:\n",
    "            break\n",
    "\n",
    "        # set up the prompt for the model\n",
    "        base_prompt = observation.prompt or dataset_prompt\n",
    "        user_prompt = make_user_prompt(base_prompt, observation.messages)\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ]\n",
    "        prompt_text = tokenizer.apply_chat_template(\n",
    "            messages,\n",
    "            add_generation_prompt=True,\n",
    "            tokenize=False,\n",
    "            enable_thinking=False,\n",
    "        )\n",
    "\n",
    "        # Generate using trainer's built-in method (much cleaner!)\n",
    "        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n",
    "        prompt_ids.extend(rollout_outputs[\"prompt_ids\"])\n",
    "        completion_ids.extend(rollout_outputs[\"completion_ids\"])\n",
    "        logprobs.extend(rollout_outputs[\"logprobs\"])\n",
    "        completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n",
    "            rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n",
    "        )\n",
    "\n",
    "        # extract the guess from the completion\n",
    "        guess = extract_guess(completion_text)\n",
    "\n",
    "        # step the environment with the guess\n",
    "        result = env.step(TextArenaAction(message=guess))\n",
    "        raw_rewards.append(float(result.reward or 0.0))\n",
    "        observation = result.observation\n",
    "        correct_score = float(result.reward or 0.0)\n",
    "        feedback = extract_wordle_feedback(observation)\n",
    "\n",
    "        # Update guess counts\n",
    "        previous_occurrences = guess_counts[guess]\n",
    "        repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts))\n",
    "        guess_counts[guess] += 1\n",
    "\n",
    "        # calculate custom reward signals from the feedback\n",
    "        if not feedback:\n",
    "            green_score = 0.0\n",
    "            yellow_score = 0.0\n",
    "        else:\n",
    "            green_count, yellow_count = extract_feedback_counts(feedback)\n",
    "            green_score = green_count / 5.0\n",
    "            yellow_score = yellow_count / 5.0\n",
    "\n",
    "        repetition_scores.append(repetition_score)\n",
    "        green_scores.append(green_score)\n",
    "        yellow_scores.append(yellow_score)\n",
    "        correct_scores.append(correct_score)\n",
    "\n",
    "    correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)\n",
    "\n",
    "    return {\n",
    "        \"prompt_ids\": prompt_ids,\n",
    "        \"completion_ids\": completion_ids,\n",
    "        \"logprobs\": logprobs,\n",
    "        \"raw_rewards\": raw_rewards,\n",
    "        \"correct_reward\": correct_reward_value,\n",
    "        \"green_reward\": green_scores[-1] if green_scores else 0.0,\n",
    "        \"yellow_reward\": yellow_scores[-1] if yellow_scores else 0.0,\n",
    "        \"repetition_reward\": repetition_scores[-1] if repetition_scores else 0.0,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cipvIDzcoF3C",
   "metadata": {
    "id": "cipvIDzcoF3C"
   },
   "source": [
    "### Helper functions\n",
    "\n",
    "Supporting utilities used in `rollout_once`:\n",
    "\n",
    "- **`make_user_prompt`**: builds the user prompt combining the base text and previous game messages.\n",
    "- **`format_history`**: formats the conversation log for consistent context.\n",
    "- **`scale_repetition_score`**: applies a penalty when guesses are repeated to encourage exploration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bVeKfbaK7C4z",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Helpers definition (click to expand)\n",
    "def make_user_prompt(prompt_text, messages):\n",
    "    \"\"\"Builds a structured user prompt combining the task description and message history\"\"\"\n",
    "    history = format_history(messages)\n",
    "    prompt_section = prompt_text.strip() if prompt_text.strip() else \"Wordle-v0\"\n",
    "    history_section = history if history else \"[PROMPT] Awaiting first feedback.\"\n",
    "    return (\n",
    "        f\"Game prompt:\\n{prompt_section}\\n\\n\"\n",
    "        f\"Conversation so far:\\n{history_section}\\n\\n\"\n",
    "        \"Reply with your next guess enclosed in square brackets.\"\n",
    "    )\n",
    "\n",
    "def format_history(messages):\n",
    "    \"\"\"Formats the message history with tags for clear conversational context\"\"\"\n",
    "    lines = []\n",
    "    for message in messages:\n",
    "        tag = message.category or \"MESSAGE\"\n",
    "        content = message.content.strip()\n",
    "        if not content:\n",
    "            continue\n",
    "        lines.append(f\"[{tag}] {content}\")\n",
    "    return \"\\n\".join(lines)\n",
    "\n",
    "def scale_repetition_score(previous_occurrences, max_occurrences):\n",
    "    \"\"\"Scale the repetition score based on the number of previous occurrences from 0 to 1\"\"\"\n",
    "    if max_occurrences == 0:\n",
    "        return 0.0\n",
    "    return (max_occurrences - previous_occurrences) / max_occurrences"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "i3G0x0RheYkL",
   "metadata": {
    "id": "i3G0x0RheYkL"
   },
   "source": [
    "## Define reward functions\n",
    "\n",
    "To guide the agent's learning process, we define simple reward functions that map the feedback from the environment into numeric signals.  \n",
    "Each function corresponds to a specific aspect of the **Wordle** game:\n",
    "\n",
    "- ✅ **`reward_correct`**: rewards the model when it guesses the correct word.  \n",
    "- 🟩 **`reward_greens`**: rewards letters correctly placed (green feedback).  \n",
    "- 🟨 **`reward_yellows`**: rewards letters that are present but misplaced (yellow feedback).  \n",
    "- 🔁 **`reward_repetition`**: rewards diverse guessing by scoring based on guess uniqueness.\n",
    "\n",
    "These functions return lists of float values that the **GRPOTrainer** uses during optimization.  \n",
    "By combining them, the model learns to balance correctness, information gathering, and exploration in its guessing strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "61e454d1-9abc-42a6-868c-a24e9801ac44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reward_correct(completions, **kwargs):\n",
    "    rewards = kwargs.get(\"correct_reward\") if kwargs else None\n",
    "    if rewards is None:\n",
    "        return [0.0 for _ in completions]\n",
    "    return [float(r) for r in rewards]\n",
    "\n",
    "\n",
    "def reward_greens(completions, **kwargs):\n",
    "    rewards = kwargs.get(\"green_reward\") if kwargs else None\n",
    "    if rewards is None:\n",
    "        return [0.0 for _ in completions]\n",
    "    return [float(r) for r in rewards]\n",
    "\n",
    "\n",
    "def reward_yellows(completions, **kwargs):\n",
    "    rewards = kwargs.get(\"yellow_reward\") if kwargs else None\n",
    "    if rewards is None:\n",
    "        return [0.0 for _ in completions]\n",
    "    return [float(r) for r in rewards]\n",
    "\n",
    "\n",
    "def reward_repetition(completions, **kwargs):\n",
    "    rewards = kwargs.get(\"repetition_reward\") if kwargs else None\n",
    "    if rewards is None:\n",
    "        return [0.0 for _ in completions]\n",
    "    return [float(r) for r in rewards]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "RN5VkehojyOJ",
   "metadata": {
    "id": "RN5VkehojyOJ"
   },
   "source": [
    "## Create dataset\n",
    "\n",
    "We create a dataset with repeated prompts to control the number of training episodes.  \n",
    "Each entry in the dataset triggers one rollout episode during training. The `dataset_prompt` provides the initial instruction to the model before each game starts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "deab8040-9b51-4c52-befe-e48578cdbb53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "dataset_size = 1000\n",
    "dataset_prompt = \"Play Wordle like an expert.\"\n",
    "\n",
    "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "DnR90-D66Fm_",
   "metadata": {
    "id": "DnR90-D66Fm_"
   },
   "source": [
    "## Set GRPO Config\n",
    "\n",
    "Next, we define the **GRPOConfig**, which controls all key training parameters.  \n",
    "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20ac9371-af1a-4b9e-b678-33d6a3bf07cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import GRPOConfig\n",
    "\n",
    "output_dir = \"wordle-grpo-Qwen3-1.7B\"\n",
    "\n",
    "grpo_config = GRPOConfig(\n",
    "    # Training schedule / optimization\n",
    "    num_train_epochs = 1,                 # Number of full dataset passes\n",
    "    learning_rate = 5e-6,                 # Learning rate for the optimizer\n",
    "    gradient_accumulation_steps = 64,     # Accumulate gradients over multiple steps\n",
    "    per_device_train_batch_size = 1,      # Batch size per GPU (number of prompts processed together)\n",
    "    warmup_steps = 20,                    # Steps for learning rate warmup\n",
    "\n",
    "    # GRPO configuration\n",
    "    num_generations = 2,                  # Number of rollout episodes per prompt (for variance reduction)\n",
    "    max_completion_length = 8,            # Maximum tokens generated per model response\n",
    "    max_prompt_length = 1400,             # Maximum prompt length in tokens (Wordle prompts ~1296 tokens)\n",
    "\n",
    "    # vLLM configuration\n",
    "    use_vllm = True,                      # Enable vLLM for faster inference during rollouts\n",
    "    vllm_mode = \"colocate\",               # Run vLLM in colocate mode (same process as training)\n",
    "    vllm_gpu_memory_utilization = 0.1,    # Fraction of GPU memory reserved for vLLM inference\n",
    "\n",
    "    # Logging / reporting\n",
    "    output_dir = output_dir,              # Directory for checkpoints and logs\n",
    "    report_to=\"trackio\",                  # Experiment tracking tool (integrates with HF Spaces)\n",
    "    trackio_space_id = output_dir,        # HF Space where experiment tracking will be saved\n",
    "    logging_steps = 1,                    # Log metrics every N steps\n",
    "    save_steps = 10,                      # Interval for saving checkpoints\n",
    "\n",
    "    # Memory optimization\n",
    "    gradient_checkpointing = True,        # Enable activation recomputation to save memory\n",
    "    gradient_checkpointing_kwargs = {\"use_reentrant\": False},  # Use non-reentrant checkpointing\n",
    "\n",
    "    # Hub integration\n",
    "    push_to_hub = True,                  # Set True to automatically push model to Hugging Face Hub\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "Mrs9bAr06H2G",
   "metadata": {
    "id": "Mrs9bAr06H2G"
   },
   "source": [
    "## Create `GRPOTrainer` and start training\n",
    "\n",
    "Now we initialize the `GRPOTrainer`, which manages the entire reinforcement learning loop.\n",
    "\n",
    "It takes the model, tokenizer, reward functions, rollout function, and dataset defined earlier.  \n",
    "The trainer coordinates the interaction between the model and the environment, applies the reward signals, and updates the policy.\n",
    "\n",
    "Finally, we call `trainer.train()` to start the fine-tuning process and let the model learn to play Wordle through feedback and iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import GRPOTrainer\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model_name,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        reward_correct,\n",
    "        reward_greens,\n",
    "        reward_yellows,\n",
    "        reward_repetition,\n",
    "    ],\n",
    "    train_dataset=dataset,\n",
    "    args=grpo_config,\n",
    "    rollout_func=rollout_func,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HkDpR4dH4VxK",
   "metadata": {
    "id": "HkDpR4dH4VxK"
   },
   "source": [
    "Show memory stats before training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "hxr5Rv0wVu_P",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.\n",
      "10.516 GB of memory reserved.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "gpu_stats = torch.cuda.get_device_properties(0)\n",
    "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
    "\n",
    "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
    "print(f\"{start_gpu_memory} GB of memory reserved.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "U1Oyh63J4UPV",
   "metadata": {
    "id": "U1Oyh63J4UPV"
   },
   "source": [
    "And train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "55c52596-4082-405b-b626-4b0401c2ce9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151645}.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Trackio project initialized: huggingface\n",
      "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/wordle-grpo-Qwen3-1.7B-dataset\n",
      "* Creating new space: https://huggingface.co/spaces/sergiopaniego/wordle-grpo-Qwen3-1.7B\n",
      "* View dashboard by going to: https://sergiopaniego-wordle-grpo-Qwen3-1.7B.hf.space/\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://sergiopaniego-wordle-grpo-Qwen3-1.7B.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Created new run: sergiopaniego-1763727287\n",
      "INFO 11-21 12:14:47 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='31' max='31' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [31/31 1:25:09, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.008300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.001900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.015100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.008700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.009800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.006700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.006100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.004400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>-0.002100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.007500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.008400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.008000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.007800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>-0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>-0.003200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>-0.006000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>-0.008300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>-0.011000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>-0.004200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>-0.001700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>-0.004100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>-0.011600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>-0.006400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>-0.009100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.003200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.005100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>-0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.011500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>-0.010500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>-0.006400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 11-21 12:16:45 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:19:33 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:22:23 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:25:11 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:27:59 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:30:47 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:33:36 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:36:24 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:39:12 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:42:38 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:45:41 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:48:28 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:51:17 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:54:05 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:56:52 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 12:59:08 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:01:36 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:04:24 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:06:43 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:10:09 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:12:22 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:14:22 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:17:12 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:19:13 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:22:01 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:24:52 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:27:41 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:30:32 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:33:22 [block_pool.py:292] Successfully reset prefix cache\n",
      "INFO 11-21 13:37:30 [block_pool.py:292] Successfully reset prefix cache\n",
      "* Run finished. Uploading logs to Trackio (please wait...)\n"
     ]
    }
   ],
   "source": [
    "trainer_stats = trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "o-hEO4oK4ZXr",
   "metadata": {
    "id": "o-hEO4oK4ZXr"
   },
   "source": [
    "Show memory stats after training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "zuHTwuxAVp8p",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5231.7046 seconds used for training.\n",
      "87.2 minutes used for training.\n",
      "Peak reserved memory = 36.68 GB.\n",
      "Peak reserved memory for training = 26.164 GB.\n",
      "Peak reserved memory % of max memory = 92.727 %.\n",
      "Peak reserved memory for training % of max memory = 66.143 %.\n"
     ]
    }
   ],
   "source": [
    "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n",
    "used_percentage = round(used_memory / max_memory * 100, 3)\n",
    "training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)\n",
    "\n",
    "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
    "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
    "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
    "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n",
    "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
    "print(f\"Peak reserved memory for training % of max memory = {training_memory_percentage} %.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e8053fa4a524b03842a23f987f0b09b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing Files (0 / 0)      : |          |  0.00B /  0.00B            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b42c51a14384ea0b222a665ae0f35dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "New Data Upload               : |          |  0.00B /  0.00B            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df766eead3e94399918561f47e4c94c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...n3-1.7B/training_args.bin: 100%|##########| 7.31kB / 7.31kB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "685bd447dd6d4e4b883e22311bca1982",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...Qwen3-1.7B/tokenizer.json: 100%|##########| 11.4MB / 11.4MB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af0d3d7a321c493c9a67eb7e7f9167d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...adapter_model.safetensors: 100%|##########| 25.7MB / 25.7MB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ce0cdfe83624a458a2205925f88f4f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...0002-of-00002.safetensors:   2%|2         | 41.9MB / 1.91GB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a3170d4bcbb4e40b97921cf57fb9d3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...0001-of-00002.safetensors:   1%|          | 33.5MB / 4.97GB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No files have been modified since last commit. Skipping to prevent empty commit.\n",
      "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c9234157def4dddb7c08a49d9c83d4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing Files (0 / 0)      : |          |  0.00B /  0.00B            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e49dbc0f9a741a28b53930ece8de736",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "New Data Upload               : |          |  0.00B /  0.00B            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff14a1ce9dcf4250add61cfb9ae262f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...n3-1.7B/training_args.bin: 100%|##########| 7.31kB / 7.31kB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b733c562d7432a8a18fee70f6a0248",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...Qwen3-1.7B/tokenizer.json: 100%|##########| 11.4MB / 11.4MB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0539d5ee34234fb08cb93996fa7a26ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...0001-of-00002.safetensors:   1%|          | 41.9MB / 4.97GB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed426ec7315e405e976465fdf34f0eb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...0002-of-00002.safetensors:   2%|1         | 33.5MB / 1.91GB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a5073ad35954e4a96c80f3fedf91bc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  ...adapter_model.safetensors: 100%|##########| 25.7MB / 25.7MB            "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No files have been modified since last commit. Skipping to prevent empty commit.\n",
      "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
     ]
    },
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B/commit/b81b548867ab35601d3bda845ed5e18147550e30', commit_message='End of training', commit_description='', oid='b81b548867ab35601d3bda845ed5e18147550e30', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/wordle-grpo-Qwen3-1.7B'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.close()\n",
    "trainer.save_model(output_dir)\n",
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "wQyVb1nAxWld",
   "metadata": {
    "id": "wQyVb1nAxWld"
   },
   "source": [
    "## Load the Fine-Tuned Model and Run Inference\n",
    "\n",
    "Now let's test our fine-tuned model by loading the **adapter** and running **inference**.  \n",
    "We begin by loading the **base model**, attaching the adapter, and obtaining the final fine-tuned model ready for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "JcTeeSBXxWWF",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n",
      "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n",
      "You are not authenticated with the Hugging Face Hub in this notebook.\n",
      "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "281b1cf074fd4d60bb754906a0764865",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e129fb465f1a41c1bdf2495d14143458",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_name = \"sergiopaniego/wordle-grpo-Qwen3-1.7B\" # Replace with your HF username or organization\n",
    "\n",
    "fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, dtype=\"auto\", device_map=\"auto\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ZZ-8K433gKK",
   "metadata": {
    "id": "5ZZ-8K433gKK"
   },
   "source": [
    "Now that we have the fine-tuned model loaded, we can start playing Wordle.  \n",
    "To make this easier, we'll define a reusable function so we can play multiple rounds.  \n",
    "The function implements the same logic we explored earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "hUFkr5aEYaKf",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_TURNS=6\n",
    "\n",
    "def play_wordle(env, model, tokenizer):\n",
    "    result = env.reset()\n",
    "    observation = result.observation\n",
    "\n",
    "    print(\"📜 Initial Prompt:\\n\" + observation.prompt)\n",
    "\n",
    "    for turn in range(MAX_TURNS):\n",
    "        if result.done:\n",
    "            break\n",
    "\n",
    "        user_prompt = make_user_prompt(observation.prompt, observation.messages)\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ]\n",
    "        prompt_text = tokenizer.apply_chat_template(\n",
    "            messages,\n",
    "            add_generation_prompt=True,\n",
    "            tokenize=False,\n",
    "            enable_thinking=False,\n",
    "        )\n",
    "\n",
    "        model_inputs = tokenizer([prompt_text], return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "        generated_ids = model.generate(\n",
    "            **model_inputs,\n",
    "            max_new_tokens=512\n",
    "        )\n",
    "        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
    "\n",
    "        # Decode and extract model response\n",
    "        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
    "        guess = extract_guess(generated_text)\n",
    "\n",
    "        print(f\"\\n🎯 Turn {turn}: model replied with -> {generated_text}\")\n",
    "        print(f\"   Parsed guess: {guess}\")\n",
    "\n",
    "        result = env.step(TextArenaAction(message=guess))\n",
    "        observation = result.observation\n",
    "\n",
    "        print(\"   Feedback messages:\")\n",
    "        for message in observation.messages:\n",
    "            print(f\"     [{message.category}] {message.content}\")\n",
    "\n",
    "    print(\"\\n✅ Game finished\")\n",
    "    print(f\"   Reward: {result.reward}\")\n",
    "    print(f\"   Done: {result.done}\")"
   ]
  },
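  {
   "cell_type": "markdown",
   "id": "wordle-scoring-sketch",
   "metadata": {},
   "source": [
    "To help read the feedback rows that the environment prints, here is a minimal sketch of the standard Wordle scoring rule described in the game prompt (`G`, `Y`, `X`). It is an illustration only: `score_guess` is a hypothetical helper, the environment's own implementation is not shown in this notebook and may treat repeated letters differently, and the secret word `shirt` is made up for the example.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def score_guess(guess, secret):\n",
    "    guess, secret = guess.lower(), secret.lower()\n",
    "    if len(guess) != len(secret):\n",
    "        raise ValueError(\"guess and secret must have the same length\")\n",
    "    marks = [\"X\"] * len(guess)\n",
    "    unmatched = Counter()\n",
    "    # First pass: mark exact matches and count the secret's unmatched letters.\n",
    "    for i, (g, s) in enumerate(zip(guess, secret)):\n",
    "        if g == s:\n",
    "            marks[i] = \"G\"\n",
    "        else:\n",
    "            unmatched[s] += 1\n",
    "    # Second pass: a non-green letter is yellow only while copies remain.\n",
    "    for i, g in enumerate(guess):\n",
    "        if marks[i] != \"G\" and unmatched[g] > 0:\n",
    "            marks[i] = \"Y\"\n",
    "            unmatched[g] -= 1\n",
    "    return \" \".join(marks)\n",
    "\n",
    "print(score_guess(\"crane\", \"shirt\"))  # X Y X X X\n",
    "print(score_guess(\"spare\", \"shirt\"))  # G X X G X\n",
    "```\n",
    "\n",
    "Counting the unmatched letters first keeps a repeated letter in the guess from being marked yellow more often than it appears in the secret."
   ]
  },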
  {
   "cell_type": "markdown",
   "id": "MjIxHOHK4PVe",
   "metadata": {
    "id": "MjIxHOHK4PVe"
   },
   "source": [
    "Let's play the game!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "JjOzWexUXmfW",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "📜 Initial Prompt:\n",
      "You are Player 0 in Wordle.\n",
      "A secret 5-letter word has been chosen. You have 6 attempts to guess it.\n",
      "For each guess, wrap your word in square brackets (e.g., [apple]).\n",
      "Feedback for each letter will be given as follows:\n",
      "  - G (green): correct letter in the correct position\n",
      "  - Y (yellow): letter exists in the word but in the wrong position\n",
      "  - X (wrong): letter is not in the word\n",
      "Enter your guess to begin.\n",
      "\n",
      "🎯 Turn 0: model replied with -> [crane]\n",
      "   Parsed guess: [crane]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [crane]\n",
      "     [MESSAGE] Player 0 submitted [crane].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "You have 5 guesses left.\n",
      "\n",
      "🎯 Turn 1: model replied with -> [spare]\n",
      "   Parsed guess: [spare]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [spare]\n",
      "     [MESSAGE] Player 0 submitted [spare].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "You have 4 guesses left.\n",
      "\n",
      "🎯 Turn 2: model replied with -> [spare]\n",
      "   Parsed guess: [spare]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [spare]\n",
      "     [MESSAGE] Player 0 submitted [spare].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "You have 3 guesses left.\n",
      "\n",
      "🎯 Turn 3: model replied with -> [spare]\n",
      "   Parsed guess: [spare]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [spare]\n",
      "     [MESSAGE] Player 0 submitted [spare].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "You have 2 guesses left.\n",
      "\n",
      "🎯 Turn 4: model replied with -> [spare]\n",
      "   Parsed guess: [spare]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [spare]\n",
      "     [MESSAGE] Player 0 submitted [spare].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "You have 1 guesses left.\n",
      "\n",
      "🎯 Turn 5: model replied with -> [spare]\n",
      "   Parsed guess: [spare]\n",
      "   Feedback messages:\n",
      "     [MESSAGE] [spare]\n",
      "     [MESSAGE] Player 0 submitted [spare].\n",
      "Feedback:\n",
      "C R A N E\n",
      "X Y X X X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "S P A R E\n",
      "G X X G X\n",
      "\n",
      "You have 0 guesses left.\n",
      "     [MESSAGE] The game ended in a draw. Reason: Turn limit reached.\n",
      "\n",
      "✅ Game finished\n",
      "   Reward: 0.0\n",
      "   Done: True\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    play_wordle(env, fine_tuned_model, tokenizer)\n",
    "finally:\n",
    "    env.close()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
