{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "fU6KXDe-mNjD",
      "metadata": {
        "id": "fU6KXDe-mNjD"
      },
      "source": [
        "# Qwen: Direct Projection Training with Paraphrase Coherence (TriviaQA)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cva0_WZcmNjF",
      "metadata": {
        "id": "cva0_WZcmNjF"
      },
      "source": [
        "This Colab trains a **Qwen causal LM** with a **direct projection** objective:\n",
        "\n",
        "- **KL to frozen baseline** keeps the model near its initial weights $\\omega_0$ (the frozen reference model).\n",
        "- **Pairwise JS** improves **coherence across paraphrases** of the same question.\n",
        "- Uses **REINFORCE** estimators with baselines, multi-sample averaging, and log-sum-exp numerics."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Q3RRjK20mNjF",
      "metadata": {
        "id": "Q3RRjK20mNjF"
      },
      "source": [
        "## 0) Runtime & GPU check"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "AIguI7ArakvF",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AIguI7ArakvF",
        "outputId": "c88e78e8-35a1-428c-dadb-56ed3cd22c71"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]\n",
            "OS: Linux-6.6.105+-x86_64-with-glibc2.35\n",
            "Arch: x86_64\n",
            "Torch: 2.9.0+cu126\n",
            "CUDA (toolkit from Torch): 12.6\n",
            "CUDA available: True\n",
            "GPU: NVIDIA A100-SXM4-40GB\n",
            "SM capability: (8, 0)\n",
            "VRAM (GB): 39.56\n"
          ]
        }
      ],
      "source": [
        "# Environment report: Python / OS / Torch / CUDA versions and GPU details.\n",
        "import platform\n",
        "import sys\n",
        "\n",
        "import torch\n",
        "\n",
        "print(\"Python:\", sys.version)\n",
        "print(\"OS:\", platform.platform())\n",
        "print(\"Arch:\", platform.machine())\n",
        "print(\"Torch:\", torch.__version__)\n",
        "print(\"CUDA (toolkit from Torch):\", torch.version.cuda)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"SM capability:\", torch.cuda.get_device_capability(0))\n",
        "    print(\"VRAM (GB):\", round(torch.cuda.get_device_properties(0).total_memory / 2**30, 2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "zEE6ahLV5LZX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zEE6ahLV5LZX",
        "outputId": "fef2daa8-1c9c-402d-9ad9-faf2418e651c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading: https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installing: flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installed flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "# 1) Force flash-attn to use a prebuilt wheel or fail fast\n",
        "os.environ[\"FLASH_ATTENTION_SKIP_CUDA_BUILD\"] = \"TRUE\"\n",
        "\n",
        "# 2) Download the wheel for:\n",
        "#    - flash_attn 2.8.3\n",
        "#    - CUDA 12.6 (cu126)\n",
        "#    - torch 2.9\n",
        "#    - Python 3.12 (cp312)\n",
        "# NOTE: the wheel only works if the runtime matches ALL of the above (compare with the\n",
        "# version report printed by the previous cell); otherwise pick a matching release asset.\n",
        "wheel_url = \"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\"\n",
        "wheel_name = wheel_url.split(\"/\")[-1]\n",
        "\n",
        "print(\"Downloading:\", wheel_url)\n",
        "subprocess.run([\"wget\", \"-q\", wheel_url, \"-O\", wheel_name], check=True)\n",
        "\n",
        "print(\"Installing:\", wheel_name)\n",
        "# Run pip through the kernel's own interpreter so the wheel lands in the environment\n",
        "# this notebook imports from (a bare `pip` on PATH may belong to another Python).\n",
        "subprocess.run(\n",
        "    [sys.executable, \"-m\", \"pip\", \"install\", \"--no-dependencies\", \"--upgrade\", wheel_name],\n",
        "    check=True,\n",
        ")\n",
        "\n",
        "print(\"Installed\", wheel_name)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rgVu_ySz5ReL",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rgVu_ySz5ReL",
        "outputId": "d05bd078-ba18-4cf9-c430-032874036a8f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "flash_attn version: 2.8.3\n"
          ]
        }
      ],
      "source": [
        "import flash_attn\n",
        "\n",
        "# Confirm the wheel installed above is importable and report its version\n",
        "# (falls back to \"unknown\" when the package does not expose __version__).\n",
        "print(\"flash_attn version:\", getattr(flash_attn, \"__version__\", \"unknown\"))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qQ3UlHbC5dLW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qQ3UlHbC5dLW",
        "outputId": "32ff88a5-2fe6-4f79-ac75-fcb9bcb02c5e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Torch: 2.9.0+cu126\n",
            "CUDA available: True\n",
            "Device: NVIDIA A100-SXM4-40GB\n",
            "Compute capability: (8, 0)\n",
            "FlashAttention forward pass OK.\n",
            "Output shape: torch.Size([2, 128, 8, 64])\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from flash_attn import flash_attn_func\n",
        "\n",
        "# Smoke test: one FlashAttention-2 forward pass on small random fp16 tensors.\n",
        "# FA2 kernels are CUDA-only, so fail early with a clear message on a CPU runtime\n",
        "# instead of erroring inside torch.cuda.get_device_name() below.\n",
        "if not torch.cuda.is_available():\n",
        "    raise RuntimeError(\"CUDA is not available; the FlashAttention smoke test needs a GPU runtime.\")\n",
        "\n",
        "device = \"cuda\"\n",
        "\n",
        "print(\"Torch:\", torch.__version__)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "print(\"Device:\", torch.cuda.get_device_name(0))\n",
        "print(\"Compute capability:\", torch.cuda.get_device_capability(0))\n",
        "\n",
        "# Small test config\n",
        "B = 2      # batch size\n",
        "L = 128    # sequence length\n",
        "H = 8      # number of heads\n",
        "D = 64     # head dimension (must be <= 256 for FA2)\n",
        "\n",
        "# Q, K, V shapes: (batch, seqlen, nheads, headdim)\n",
        "q = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "k = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "v = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "\n",
        "try:\n",
        "    out = flash_attn_func(\n",
        "        q, k, v,\n",
        "        dropout_p=0.0,\n",
        "        softmax_scale=None,\n",
        "        causal=False,\n",
        "    )\n",
        "    print(\"FlashAttention forward pass OK.\")\n",
        "    print(\"Output shape:\", out.shape)\n",
        "except Exception as e:\n",
        "    # Reported rather than raised: this cell is a diagnostic.\n",
        "    print(\"FlashAttention FAILED on GPU:\")\n",
        "    print(type(e).__name__, \":\", e)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "uUeRWvsdQv1C",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uUeRWvsdQv1C",
        "outputId": "e00e90e6-01de-4434-a22b-a085f219168e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive: WORK_PATH, checkpoints and caches configured below live under /content/drive.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "u7T3ljRumNjG",
      "metadata": {
        "id": "u7T3ljRumNjG"
      },
      "source": [
        "## 1) Install packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2snW-PemNjG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "d2snW-PemNjG",
        "outputId": "1df5725d-10ad-46a6-f841-82a48eb10f2e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.8 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n"
          ]
        }
      ],
      "source": [
        "# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
        "%pip -q install --upgrade pip\n",
        "%pip -q install -U \"transformers>=4.43.0\" \"accelerate>=0.33.0\" \"datasets>=2.19.0\" \"bitsandbytes>=0.43.0\" \"peft>=0.11.0\" sentence-transformers einops\n",
        "# tqdm needs no separate install: transformers and datasets already depend on it.\n",
        "# flash-attn is installed from a prebuilt wheel in section 0 (building from source is slow):\n",
        "# %pip -q install flash-attn --no-build-isolation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bBwLq3tAZBfd",
      "metadata": {
        "id": "bBwLq3tAZBfd"
      },
      "outputs": [],
      "source": [
        "%pip -q install -U \"trl>=0.8.0\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "IviYIdL0mNjG",
      "metadata": {
        "id": "IviYIdL0mNjG"
      },
      "source": [
        "## 2) Imports & global config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QGO8iNsFmNjG",
      "metadata": {
        "id": "QGO8iNsFmNjG"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from dataclasses import dataclass\n",
        "from typing import List, Tuple\n",
        "import torch, math, random, numpy as np\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "from datasets import load_dataset\n",
        "from sentence_transformers import SentenceTransformer, util as sbert_util\n",
        "\n",
        "SEED = 42\n",
        "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "# Choose model (smaller first for Colab T4/L4). Other models used with this notebook:\n",
        "#   \"meta-llama/Llama-3.2-3B-Instruct\", \"microsoft/Phi-3-mini-4k-instruct\"\n",
        "TRAINABLE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "BASELINE_MODEL_NAME  = TRAINABLE_MODEL_NAME\n",
        "\n",
        "USE_4BIT = True\n",
        "# Compute dtype: bf16 only where the GPU supports it (e.g. not on a T4), else fp16; fp32 on CPU.\n",
        "_BNB_COMPUTE_DTYPE = (\n",
        "    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
        "    else (torch.float16 if torch.cuda.is_available() else torch.float32)\n",
        ")\n",
        "BNB_CONFIG = BitsAndBytesConfig(\n",
        "    load_in_4bit=USE_4BIT,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=_BNB_COMPUTE_DTYPE,\n",
        ")\n",
        "\n",
        "# \"auto\" places the baseline on GPU (fine with >=24GB VRAM); use \"cpu\" to save VRAM.\n",
        "BASELINE_DEVICE_MAP = \"auto\"\n",
        "\n",
        "MAX_NEW_TOKENS = 512\n",
        "GEN_TEMPERATURE = 0.7\n",
        "TOP_P = 0.9\n",
        "TOP_K = 50\n",
        "\n",
        "PARAPHRASES_PER_QUESTION = 2     # total per class incl. original\n",
        "NUM_CLASSES_FOR_TRAIN = 10000\n",
        "SAMPLES_PER_PAIR = 16\n",
        "JS_WEIGHT = 1.0\n",
        "KL_WEIGHT = 0.1\n",
        "\n",
        "LR = 2e-5\n",
        "WEIGHT_DECAY = 0.01\n",
        "GRAD_CLIP = 1.0\n",
        "STEPS = 50\n",
        "BATCH_SIZE_CLASSES = 32\n",
        "LOG_EVERY = 20\n",
        "\n",
        "# --- Google Drive paths ---\n",
        "WORK_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-mode-jan16/weight-1e-2\"\n",
        "# Adapter checkpoint dir (containing adapter_config.json) to resume from; False = start fresh.\n",
        "# Previous runs:\n",
        "#   f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-majority-dec29/weight-1e-2/checkpoint-1600\"\n",
        "#   f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-majority-dec29/weight-1/checkpoint-1600\"\n",
        "CKPT_LOAD_PATH = False\n",
        "COHERENCE_DATA_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct/\"\n",
        "os.makedirs(WORK_PATH, exist_ok=True)\n",
        "# NOTE(review): pinned to the Qwen2.5-3B paraphrase cache regardless of TRAINABLE_MODEL_NAME - confirm this is intended.\n",
        "CACHE_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZMmHw__UNCjX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZMmHw__UNCjX",
        "outputId": "35c1a0f7-721e-4777-9c0e-766188135dec"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Show the contents of the adapter checkpoint to resume from, if one is configured.\n",
        "# Guarded in Python: a shell `!ls $CKPT_LOAD_PATH` would run `ls False` when resuming is disabled.\n",
        "if CKPT_LOAD_PATH:\n",
        "    print(sorted(os.listdir(CKPT_LOAD_PATH)))\n",
        "else:\n",
        "    print(\"CKPT_LOAD_PATH is not set; starting from the base model.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c7DuQugEeSaD",
      "metadata": {
        "id": "c7DuQugEeSaD"
      },
      "source": [
        "## Utils"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Ul00h3tomNjH",
      "metadata": {
        "id": "Ul00h3tomNjH"
      },
      "source": [
        "## 3) Load tokenizer, trainable QLoRA model, and frozen baseline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "T5xYHUDKds80",
      "metadata": {
        "id": "T5xYHUDKds80"
      },
      "outputs": [],
      "source": [
        "# @title\n",
        "# === utils_loaders.py ===\n",
        "import os, gc, torch, platform\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "\n",
        "# ---------- Helpers ----------\n",
        "\n",
        "def gpu_supports_bf16() -> bool:\n",
        "    \"\"\"Return True only when a CUDA device is present and it supports bfloat16.\"\"\"\n",
        "    if not torch.cuda.is_available():\n",
        "        return False\n",
        "    return torch.cuda.is_bf16_supported()\n",
        "\n",
        "def pick_dtype():\n",
        "    \"\"\"Preferred weight/compute dtype: bf16 if supported, else fp16 on CUDA, else fp32 (CPU).\"\"\"\n",
        "    if gpu_supports_bf16():\n",
        "        return torch.bfloat16\n",
        "    if torch.cuda.is_available():\n",
        "        return torch.float16\n",
        "    return torch.float32\n",
        "\n",
        "def pick_attn_impl() -> str:\n",
        "    \"\"\"\n",
        "    Choose the attention implementation passed to from_pretrained.\n",
        "\n",
        "    FlashAttention-2 is used only on Ampere+ (SM >= 80) AND when the flash_attn package\n",
        "    can be found; otherwise SDPA. Requesting flash_attention_2 without flash_attn makes\n",
        "    transformers error out instead of falling back.\n",
        "    \"\"\"\n",
        "    import importlib.util\n",
        "\n",
        "    if not torch.cuda.is_available():\n",
        "        return \"sdpa\"\n",
        "    sm_major, _ = torch.cuda.get_device_capability(0)\n",
        "    has_flash_attn = importlib.util.find_spec(\"flash_attn\") is not None\n",
        "    impl = \"flash_attention_2\" if (sm_major >= 8 and has_flash_attn) else \"sdpa\"\n",
        "    print(f\"using {impl}\")\n",
        "    return impl\n",
        "\n",
        "def hard_cleanup(*names_to_del, also_empty_cache=True):\n",
        "    \"\"\"\n",
        "    Delete large globals by name (if present), run gc, and empty CUDA cache.\n",
        "\n",
        "    globals() is the namespace this function was defined in (the notebook's user\n",
        "    namespace when defined in a cell). Removing a name only drops that one reference;\n",
        "    memory is freed only if nothing else (e.g. an optimizer) still holds the object.\n",
        "    \"\"\"\n",
        "    g = globals()\n",
        "    for n in names_to_del:\n",
        "        # pop() with a default is a no-op for absent names (no bare except needed)\n",
        "        g.pop(n, None)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available() and also_empty_cache:\n",
        "        torch.cuda.empty_cache()\n",
        "        torch.cuda.synchronize()\n",
        "\n",
        "def make_bnb_config(load_in_4bit=True):\n",
        "    \"\"\"\n",
        "    Build a bitsandbytes quantization config.\n",
        "\n",
        "    load_in_4bit=True -> 4-bit NF4 with double quantization and pick_dtype() compute dtype.\n",
        "    load_in_4bit=False -> 8-bit (the bnb_4bit_* options then do not apply).\n",
        "    \"\"\"\n",
        "    # Lazy import to avoid dependency unless used\n",
        "    from transformers import BitsAndBytesConfig\n",
        "    return BitsAndBytesConfig(\n",
        "        load_in_4bit=load_in_4bit,\n",
        "        load_in_8bit=not load_in_4bit,  # exactly one of the two modes is enabled\n",
        "        bnb_4bit_compute_dtype=pick_dtype(),\n",
        "        bnb_4bit_use_double_quant=True,\n",
        "        bnb_4bit_quant_type=\"nf4\",\n",
        "    )\n",
        "\n",
        "def setup_tokenizer(model_id: str, use_fast=True):\n",
        "    \"\"\"\n",
        "    Load the tokenizer for model_id, configured for batched decoder-only generation:\n",
        "    a pad token is guaranteed to exist, and padding/truncation both happen on the left.\n",
        "    \"\"\"\n",
        "    tok = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, trust_remote_code=True)\n",
        "    if tok.pad_token is None:\n",
        "        # No pad token defined: reuse EOS. Assumes callers pass the attention_mask so\n",
        "        # padded positions are ignored.\n",
        "        tok.pad_token = tok.eos_token\n",
        "    # Left padding keeps each prompt's last real token at the end, where generation\n",
        "    # continues; left truncation drops the start of over-long prompts.\n",
        "    tok.padding_side = \"left\"\n",
        "    tok.truncation_side = \"left\"\n",
        "    return tok\n",
        "\n",
        "# ---------- Trainable (QLoRA) ----------\n",
        "\n",
        "def load_trainable_model(\n",
        "    model_id: str,\n",
        "    use_4bit: bool = True,\n",
        "    lora_r: int = 8,\n",
        "    lora_alpha: int = 16,\n",
        "    lora_dropout: float = 0.05,\n",
        "    extra_target_modules: list[str] | None = None,\n",
        "    device_map: str | dict = \"auto\",\n",
        "):\n",
        "    \"\"\"\n",
        "    Load a (by default 4-bit, QLoRA-ready) model + tokenizer and wrap it with PEFT LoRA.\n",
        "    Returns: (tokenizer, peft_model)\n",
        "\n",
        "    Args:\n",
        "      model_id: Hugging Face model id.\n",
        "      use_4bit: load the base in 4-bit NF4; False loads it unquantized in pick_dtype().\n",
        "      lora_r, lora_alpha, lora_dropout: LoRA hyperparameters.\n",
        "      extra_target_modules: module names added to the auto-detected LoRA targets.\n",
        "      device_map: forwarded to from_pretrained.\n",
        "\n",
        "    Special case:\n",
        "      - For microsoft/Phi-3* models, we force trust_remote_code=False\n",
        "        to avoid the DynamicCache.seen_tokens AttributeError.\n",
        "    \"\"\"\n",
        "\n",
        "    # --- tokenizer ---\n",
        "    tok = setup_tokenizer(model_id)\n",
        "\n",
        "    # --- quantization config ---\n",
        "    quant_cfg = make_bnb_config(load_in_4bit=use_4bit) if use_4bit else None\n",
        "\n",
        "    # --- pick trust_remote_code depending on model ---\n",
        "    # Phi-3 has outdated remote modeling code that breaks with recent transformers.\n",
        "    is_phi3 = \"phi-3\" in model_id.lower()\n",
        "    trust_remote = not is_phi3\n",
        "\n",
        "    # --- load base model ---\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_id,\n",
        "        quantization_config=quant_cfg,\n",
        "        torch_dtype=pick_dtype(),\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=trust_remote,\n",
        "        attn_implementation=pick_attn_impl(),  # decided by pick_attn_impl(); no automatic fallback\n",
        "    )\n",
        "\n",
        "    # Prepare for k-bit training -- only on the quantized path. On an unquantized model\n",
        "    # prepare_model_for_kbit_training casts every fp16/bf16 weight to fp32, doubling memory.\n",
        "    if use_4bit:\n",
        "        model = prepare_model_for_kbit_training(model)\n",
        "\n",
        "    # Auto-detect common projection module names if not provided\n",
        "    targets = set(extra_target_modules or [])\n",
        "    for name, _ in model.named_modules():\n",
        "        low = name.lower().split(\".\")[-1]\n",
        "        if low in {\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"}:\n",
        "            targets.add(low)\n",
        "    if not targets:\n",
        "        targets = {\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\"}\n",
        "\n",
        "    lora_cfg = LoraConfig(\n",
        "        r=lora_r,\n",
        "        lora_alpha=lora_alpha,\n",
        "        target_modules=sorted(targets),\n",
        "        lora_dropout=lora_dropout,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "    )\n",
        "    model = get_peft_model(model, lora_cfg)\n",
        "\n",
        "    # Required for gradient checkpointing with PEFT\n",
        "    model.config.use_cache = False\n",
        "\n",
        "    model.enable_input_require_grads()\n",
        "    model.gradient_checkpointing_enable()\n",
        "\n",
        "    model.print_trainable_parameters()\n",
        "\n",
        "    return tok, model\n",
        "\n",
        "\n",
        "# ---------- Frozen Baseline (inference) ----------\n",
        "\n",
        "def load_baseline_model(\n",
        "    model_id: str,\n",
        "    force_cpu_embedder: bool = True,\n",
        "    use_4bit: bool = False,\n",
        "    device_map: str | dict = \"auto\",\n",
        "    max_memory: dict | None = None,\n",
        "    allow_offload_folder: str | None = None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Load a frozen inference model with the fastest attention available for your GPU.\n",
        "    Returns: (tokenizer, baseline_model)\n",
        "\n",
        "    Args:\n",
        "      model_id: Hugging Face model id.\n",
        "      force_cpu_embedder: currently unused; kept so existing calls keep working.\n",
        "      use_4bit: load in 4-bit NF4 instead of unquantized pick_dtype().\n",
        "      device_map, max_memory: forwarded to from_pretrained.\n",
        "      allow_offload_folder: if set, created and used as offload_folder for weights\n",
        "        that do not fit in max_memory.\n",
        "\n",
        "    Phi-3 models are loaded with trust_remote_code=False (same rule as load_trainable_model):\n",
        "    their remote modeling code breaks with recent transformers.\n",
        "    \"\"\"\n",
        "    tok = setup_tokenizer(model_id)\n",
        "\n",
        "    is_phi3 = \"phi-3\" in model_id.lower()\n",
        "    kwargs = dict(\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=not is_phi3,\n",
        "        attn_implementation=pick_attn_impl(),\n",
        "    )\n",
        "\n",
        "    if use_4bit:\n",
        "        kwargs[\"quantization_config\"] = make_bnb_config(load_in_4bit=True)\n",
        "    else:\n",
        "        kwargs[\"torch_dtype\"] = pick_dtype()\n",
        "\n",
        "    if max_memory is not None:\n",
        "        kwargs[\"max_memory\"] = max_memory\n",
        "    if allow_offload_folder is not None:\n",
        "        os.makedirs(allow_offload_folder, exist_ok=True)\n",
        "        kwargs[\"offload_folder\"] = allow_offload_folder\n",
        "        kwargs[\"offload_state_dict\"] = True\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()\n",
        "\n",
        "    # Freeze to be explicit\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad_(False)\n",
        "\n",
        "    return tok, model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "FswzqYDPdwoH",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "FswzqYDPdwoH",
        "outputId": "d20e2bfe-8cfb-4137-874f-a6bcff1b8b74"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1c03c32372444f3f8fda7b2c2255b0ce",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7d0d16b05cb14031aba40fcb8cc755ee",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "de5ed31066ea41798f1b59823774a4e1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "merges.txt: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "23f936f606db40ddb040062f9daf26c1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "using flash_attention_2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b19c5fa3910e4480a0753115369ba927",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "`torch_dtype` is deprecated! Use `dtype` instead!\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3540aa7a9aac48ad818fea289ee0444c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d916a6f4a546415088cca4779619043c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "970c113c8cb74542885a2973c5b395c9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e022d15c5d5645508ec0089274defb76",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "60adbcf729aa4d44bd0f54472025984b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b3b2fa28dbfb4699a37958d1c09ef7fd",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "trainable params: 14,966,784 || all params: 3,100,905,472 || trainable%: 0.4827\n"
          ]
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\n# 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\\nimport os\\nos.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\\n\\n# 1) Clean any old models from VRAM\\n\\n# hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\\n\\n# 3) Load FROZEN baseline model for inference/paraphrasing\\nBASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\\n# If VRAM is tight, leave some headroom and allow offload\\nmax_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\\ntokenizer, baseline_model = load_baseline_model(\\n    BASELINE_MODEL_NAME,\\n    use_4bit=False,               # can set True to save VRAM\\n    device_map=BASELINE_DEVICE_MAP,\\n    max_memory=max_mem,           # optional\\n    allow_offload_folder=\"./offload\"  # optional\\n)\\n\\n# 4) Embedder (keep on CPU unless you really need GPU)\\nfrom sentence_transformers import SentenceTransformer\\nembedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\\n# embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\\n'"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# 2) Load fresh TRAINABLE (QLoRA) model\n",
        "\n",
        "\n",
        "\n",
        "# --- Assumes your helper builds a 4-bit base + attaches LoRA with the same hyperparams ---\n",
        "# (use exactly the same args you trained with)\n",
        "tokenizer, trainable_model = load_trainable_model(\n",
        "    TRAINABLE_MODEL_NAME,\n",
        "    use_4bit=True,           # QLoRA default\n",
        "    lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "\n",
        "'''\n",
        "import copy\n",
        "baseline_model = copy.deepcopy(trainable_model).eval()\n",
        "for p in baseline_model.parameters():\n",
        "    p.requires_grad_(False)\n",
        "'''\n",
        "# from sentence_transformers import SentenceTransformer\n",
        "# embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# # embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\n",
        "\n",
        "\n",
        "\n",
        "'''\n",
        "# 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\n",
        "import os\n",
        "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\n",
        "\n",
        "# 1) Clean any old models from VRAM\n",
        "\n",
        "# hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\n",
        "\n",
        "# 3) Load FROZEN baseline model for inference/paraphrasing\n",
        "BASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "# If VRAM is tight, leave some headroom and allow offload\n",
        "max_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\n",
        "tokenizer, baseline_model = load_baseline_model(\n",
        "    BASELINE_MODEL_NAME,\n",
        "    use_4bit=False,               # can set True to save VRAM\n",
        "    device_map=BASELINE_DEVICE_MAP,\n",
        "    max_memory=max_mem,           # optional\n",
        "    allow_offload_folder=\"./offload\"  # optional\n",
        ")\n",
        "\n",
        "# 4) Embedder (keep on CPU unless you really need GPU)\n",
        "from sentence_transformers import SentenceTransformer\n",
        "embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tKvuvb_gBq3k",
      "metadata": {
        "id": "tKvuvb_gBq3k"
      },
      "outputs": [],
      "source": [
        "\n",
        "if CKPT_LOAD_PATH:\n",
        "    # Now load the trained adapter weights into this PEFT model\n",
        "    ADAPTER_DIR = CKPT_LOAD_PATH  # folder containing adapter_config.json\n",
        "    print(\"loading from: \", ADAPTER_DIR)\n",
        "    try:\n",
        "        # Preferred: load adapter into the existing PeftModel.\n",
        "        # is_trainable=True keeps the adapter unfrozen, since training resumes below.\n",
        "        trainable_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\", is_trainable=True)\n",
        "        trainable_model.set_adapter(\"default\")\n",
        "    except AttributeError:\n",
        "        # Fallback: load the adapter state dict directly. load_peft_weights reads\n",
        "        # adapter_model.safetensors (or adapter_model.bin); torch.load cannot parse safetensors.\n",
        "        from peft.utils import load_peft_weights, set_peft_model_state_dict\n",
        "        state = load_peft_weights(ADAPTER_DIR, device=\"cpu\")\n",
        "        set_peft_model_state_dict(trainable_model, state)\n",
        "        # make sure only LoRA params are trainable if you plan to keep training\n",
        "        for n, p in trainable_model.named_parameters():\n",
        "            p.requires_grad = (\"lora_\" in n)\n",
        "\n",
        "    trainable_model.train()  # if you’re resuming training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "glPKBu0Uo3sX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "glPKBu0Uo3sX",
        "outputId": "86594f2f-9bd4-4791-d27c-335c8338511f"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\n_, baseline_model = load_trainable_model(\\n    TRAINABLE_MODEL_NAME,\\n    use_4bit=True,           # QLoRA default\\n    lora_r=8, lora_alpha=16, lora_dropout=0.05,\\n    device_map=\"auto\",\\n)\\n\\nfor p in baseline_model.parameters():\\n     p.requires_grad_(False)\\n'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Disabled alternative (kept for reference): this cell evaluates to a bare string,\n",
        "# so nothing runs. If re-enabled, it builds `baseline_model` via load_trainable_model()\n",
        "# (4-bit, LoRA r=8/alpha=16/dropout=0.05) and freezes every parameter.\n",
        "'''\n",
        "_, baseline_model = load_trainable_model(\n",
        "    TRAINABLE_MODEL_NAME,\n",
        "    use_4bit=True,           # QLoRA default\n",
        "    lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "for p in baseline_model.parameters():\n",
        "     p.requires_grad_(False)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B6OVE7wD_awp",
      "metadata": {
        "id": "B6OVE7wD_awp"
      },
      "outputs": [],
      "source": [
        "# Disabled: commented-out alternatives for building `baseline_model` from a saved\n",
        "# checkpoint (a LoRA-wrapped load_trainable_model() copy, or a direct\n",
        "# AutoModelForCausalLM.from_pretrained with bnb_config). Nothing in this cell executes.\n",
        "# BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "\n",
        "\n",
        "# _, baseline_model = load_trainable_model(\n",
        "#     TRAINABLE_MODEL_NAME,\n",
        "#     use_4bit=True,           # QLoRA default\n",
        "#     lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "#     device_map=\"auto\",\n",
        "# )\n",
        "\n",
        "\n",
        "# baseline_model = AutoModelForCausalLM.from_pretrained(\n",
        "#         BASELINE_PATH, device_map=\"auto\", quantization_config=bnb_config\n",
        "#     )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "UNCMXFLbCJ6F",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UNCMXFLbCJ6F",
        "outputId": "156204d9-1091-4dc0-92f7-e97cfa19298b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\nBASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\\nif BASELINE_PATH:\\n    # Now load the baseline adapter weights into this PEFT model\\n    ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\\n    try:\\n        # Preferred: load adapter into the existing PeftModel\\n        baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\\n        baseline_model.set_adapter(\"default\")\\n    except AttributeError:\\n        # Fallback for older PEFT versions: load the adapter state dict directly\\n        from peft.utils import set_peft_model_state_dict\\n        import torch, os\\n        state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\\n                          map_location=\"cpu\")  # if safetensors, use safetensors.load_file\\n        set_peft_model_state_dict(baseline_model, state)\\n        # make sure only LoRA params are trainable if you plan to keep training\\n        for n, p in baseline_model.named_parameters():\\n            p.requires_grad = (\"lora_\" in n)\\n\\n    baseline_model.eval()\\n'"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Disabled (the cell evaluates to a bare string): load a saved LoRA adapter into baseline_model.\n",
        "# NOTE(review): if re-enabled, the AttributeError fallback calls torch.load on an\n",
        "# .safetensors file, which torch.load cannot read; use safetensors.torch.load_file.\n",
        "# NOTE(review): that fallback also sets requires_grad=True on the baseline's LoRA\n",
        "# params, although the baseline is presumably meant to stay frozen.\n",
        "'''\n",
        "BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "if BASELINE_PATH:\n",
        "    # Now load the baseline adapter weights into this PEFT model\n",
        "    ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\n",
        "    try:\n",
        "        # Preferred: load adapter into the existing PeftModel\n",
        "        baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\n",
        "        baseline_model.set_adapter(\"default\")\n",
        "    except AttributeError:\n",
        "        # Fallback for older PEFT versions: load the adapter state dict directly\n",
        "        from peft.utils import set_peft_model_state_dict\n",
        "        import torch, os\n",
        "        state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\n",
        "                          map_location=\"cpu\")  # if safetensors, use safetensors.load_file\n",
        "        set_peft_model_state_dict(baseline_model, state)\n",
        "        # make sure only LoRA params are trainable if you plan to keep training\n",
        "        for n, p in baseline_model.named_parameters():\n",
        "            p.requires_grad = (\"lora_\" in n)\n",
        "\n",
        "    baseline_model.eval()\n",
        "'''"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ChcYKgQMmNjI",
      "metadata": {
        "id": "ChcYKgQMmNjI"
      },
      "source": [
        "## 4) Load GSM8K and build paraphrase classes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "CKweB9VZLnN2",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CKweB9VZLnN2",
        "outputId": "c2a4ab60-e6b9-4e88-be2c-55f990604d4a"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c8458003eeaa47638f32403b24abc0af",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3fdd9139aa784ecb9d968909d28a0356",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7edefb25b8464012a30b18c5ffd1b095",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "58a22218c5674d2ab3a5366a8f171ac4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0761e14585c8468290afa74840dcd127",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# GSM8K \"main\" config: train/test splits with `question` and `answer` columns.\n",
        "ds = load_dataset(\"gsm8k\", \"main\")\n",
        "train_ds, test_ds = ds[\"train\"], ds[\"test\"]\n",
        "\n",
        "# Alternative: hold out 5% of train as a validation split.\n",
        "# split = ds[\"train\"].train_test_split(test_size=0.05, seed=42)\n",
        "# train_ds = split[\"train\"]\n",
        "# val_ds   = split[\"test\"]   # your “validation”\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yWg7PjfnRTx7",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yWg7PjfnRTx7",
        "outputId": "6d867e3d-af33-4366-9771-a325d55122e5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 7473\n",
            "    })\n",
            "    test: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 1319\n",
            "    })\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "# Show split sizes and column names.\n",
        "print(ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZyDraTzXfx9S",
      "metadata": {
        "id": "ZyDraTzXfx9S"
      },
      "outputs": [],
      "source": [
        "import unicodedata\n",
        "\n",
        "\n",
        "def to_ascii_numbers(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Normalize Unicode numerals in `text` to ASCII digits, character by character.\n",
        "\n",
        "    - Characters with a Unicode digit value (e.g. full-width １９３０, circled ①)\n",
        "      become that single digit (１９３０ -> 1930, ① -> 1).\n",
        "    - Other characters whose Unicode numeric value is a whole number\n",
        "      (e.g. the Roman numeral Ⅻ) become its decimal form (12).\n",
        "    - Everything else, including fractions such as ½, is kept unchanged.\n",
        "\n",
        "    Finally, a leading `Paraphrase: ` prefix is removed from the result.\n",
        "    \"\"\"\n",
        "    def _ascii_for(ch: str) -> str:\n",
        "        try:\n",
        "            return str(unicodedata.digit(ch))\n",
        "        except (TypeError, ValueError):\n",
        "            pass  # no digit value; try the general numeric value next\n",
        "        try:\n",
        "            value = unicodedata.numeric(ch)\n",
        "        except (TypeError, ValueError):\n",
        "            return ch\n",
        "        return str(int(value)) if float(value).is_integer() else ch\n",
        "\n",
        "    converted = \"\".join(_ascii_for(ch) for ch in text)\n",
        "    return converted.removeprefix(\"Paraphrase: \")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "321VfgHiTxzV",
      "metadata": {
        "id": "321VfgHiTxzV"
      },
      "outputs": [],
      "source": [
        "from tqdm.auto import tqdm\n",
        "import math, torch, numpy as np\n",
        "\n",
        "# Paraphrase-generation knobs; lower BATCH_QUESTIONS if you run out of GPU memory.\n",
        "PARA_PER_QUESTION = max(0, PARAPHRASES_PER_QUESTION - 1)  # paraphrases per question besides the original, e.g. 2\n",
        "BATCH_QUESTIONS = 64             # questions per generation batch; try 32–128\n",
        "MAX_NEW_TOKENS_PARAPHRASE = 128  # cap on paraphrase length in new tokens; shorter is faster\n",
        "USE_GREEDY = True                # greedy decoding is faster and cleaner for paraphrasing\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZBryMmXepKoN",
      "metadata": {
        "id": "ZBryMmXepKoN"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import torch\n",
        "from transformers import StoppingCriteria, StoppingCriteriaList\n",
        "\n",
        "# If the Unicode-digit normalizer from an earlier cell has not been run in this\n",
        "# kernel, fall back to an identity function so paraphrase cleanup still works.\n",
        "if \"to_ascii_numbers\" not in globals():\n",
        "    def to_ascii_numbers(s: str) -> str:\n",
        "        return s\n",
        "\n",
        "class StopOnTokenSeq(StoppingCriteria):\n",
        "    \"\"\"\n",
        "    Per-row stopping criterion: a row is finished once its token ids end with\n",
        "    the tokenization of any string in `stop_strings`.\n",
        "\n",
        "    Returns a bool tensor of shape (batch,) so generate() finishes only the rows\n",
        "    that hit a stop sequence. (A single Python bool would mark the whole batch\n",
        "    done as soon as any one row matched.)\n",
        "\n",
        "    Caveat: stop strings are tokenized in isolation; inside generated text the\n",
        "    tokenizer may split them differently, in which case they will not match.\n",
        "    \"\"\"\n",
        "    def __init__(self, tok, stop_strings):\n",
        "        self.tok = tok\n",
        "        encoded = [tok.encode(s, add_special_tokens=False) for s in stop_strings]\n",
        "        # A stop string that encodes to no tokens can never be matched; drop it.\n",
        "        self.stop_ids = [seq for seq in encoded if seq]\n",
        "        self.maxlen = max((len(x) for x in self.stop_ids), default=1)\n",
        "\n",
        "    def __call__(self, input_ids, scores, **kwargs):\n",
        "        # Only the last `maxlen` tokens of each row can hold a trailing stop sequence.\n",
        "        tails = input_ids[:, -self.maxlen:].tolist()\n",
        "        done = [\n",
        "            any(len(tail) >= len(seq) and tail[-len(seq):] == seq for seq in self.stop_ids)\n",
        "            for tail in tails\n",
        "        ]\n",
        "        return torch.tensor(done, dtype=torch.bool, device=input_ids.device)\n",
        "\n",
        "\n",
        "def paraphrase_batch_chat(qs, k=1, *, use_greedy=True, temperature=0.3, top_p=0.3):\n",
        "    \"\"\"\n",
        "    Qwen2/2.5 chat-friendly batch paraphraser using the global `baseline_model`/`tokenizer`.\n",
        "    - Uses LEFT padding (required for decoder-only batched generation).\n",
        "    - Generation ends at tokenizer.eos_token_id (or after MAX_NEW_TOKENS_PARAPHRASE).\n",
        "    - Decodes with specials skipped to avoid <|endoftext|>/<|im_end|> noise.\n",
        "    - Keeps only the first line of each completion; an empty result (or a bare\n",
        "      role header) falls back to the original question text.\n",
        "\n",
        "    Args:\n",
        "        qs: questions to paraphrase; an empty list returns [].\n",
        "        k: paraphrases per question (>= 1).\n",
        "        use_greedy: greedy decoding when True (the default). generate() rejects\n",
        "            greedy decoding with num_return_sequences > 1, so sampling is used\n",
        "            whenever k > 1.\n",
        "        temperature, top_p: sampling parameters; only passed when sampling.\n",
        "\n",
        "    Returns: List[List[str]] of shape [len(qs)][k]\n",
        "    \"\"\"\n",
        "    assert k >= 1\n",
        "    if not qs:\n",
        "        return []\n",
        "    # Greedy + num_return_sequences > 1 is rejected by generate(); sample instead.\n",
        "    do_sample = (not use_greedy) or k > 1\n",
        "\n",
        "    # Named system_prompt (not `sys`) to avoid shadowing the sys module.\n",
        "    system_prompt = (\n",
        "      \"You paraphrase GSM8K math word problems.\\n\"\n",
        "      \"Rephrase without changing meaning.\\n\"\n",
        "      \"Preserve all names, quantities, units, and numbers exactly.\\n\"\n",
        "      \"Do NOT solve the problem.\\n\"\n",
        "      \"Output ONLY the paraphrased question, one line, no quotes, no prefix.\"\n",
        "    )\n",
        "\n",
        "    # --- Build prompts (once per question) ---\n",
        "    rendered = []\n",
        "    for q in qs:\n",
        "        msgs = [\n",
        "            {\"role\": \"system\", \"content\": system_prompt},\n",
        "            {\"role\": \"user\", \"content\": f\"Paraphrase the following question.\\n{q}\"},\n",
        "        ]\n",
        "        rendered.append(tokenizer.apply_chat_template(\n",
        "            msgs, tokenize=False, add_generation_prompt=True\n",
        "        ))\n",
        "\n",
        "    # --- Encode with LEFT padding for decoder-only models ---\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(rendered, return_tensors=\"pt\", padding=True, truncation=True).to(baseline_model.device)\n",
        "    finally:\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    B = len(qs)\n",
        "    # With left padding every prompt ends at the same column, so a single cut\n",
        "    # index separates prompt from continuation for all B*k output rows.\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- Generation config (Qwen) ---\n",
        "    EOT_ID = tokenizer.eos_token_id\n",
        "    assert isinstance(EOT_ID, int) and EOT_ID >= 0, \"Could not resolve <|im_end|> id\"\n",
        "    # Compare with None: a pad id of 0 is valid and must not fall through to EOS.\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else EOT_ID\n",
        "\n",
        "    # Respect external bad list, but make sure it doesn't block EOS\n",
        "    try:\n",
        "        bad_words_ids = bad if (isinstance(bad, list) and bad) else None\n",
        "        if bad_words_ids:\n",
        "            assert all(EOT_ID not in seq for seq in bad_words_ids), \"bad_words_ids blocks <|im_end|>\"\n",
        "    except NameError:\n",
        "        bad_words_ids = None\n",
        "\n",
        "    # Sampling knobs are only passed when actually sampling.\n",
        "    sampling_kwargs = {\"temperature\": temperature, \"top_p\": top_p} if do_sample else {}\n",
        "\n",
        "    # # Stop if the model begins the next user message (Qwen header)\n",
        "    # stopping = StoppingCriteriaList([\n",
        "    #     StopOnTokenSeq(tokenizer, [\"<|im_start|>user\", \"\\nuser\"])\n",
        "    # ])\n",
        "\n",
        "    with torch.no_grad():\n",
        "        out = baseline_model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=MAX_NEW_TOKENS_PARAPHRASE,\n",
        "            min_new_tokens=0,                    # don't force past EOS\n",
        "            do_sample=do_sample,\n",
        "            num_beams=1,\n",
        "            num_return_sequences=k,\n",
        "            bad_words_ids=bad_words_ids,\n",
        "            eos_token_id=EOT_ID,                 # Qwen chat EOT\n",
        "            pad_token_id=pad_id,\n",
        "            # stopping_criteria=stopping,\n",
        "            return_dict_in_generate=True,\n",
        "            **sampling_kwargs,\n",
        "        )\n",
        "\n",
        "    # --- Slice off continuations (after full padded input) ---\n",
        "    seqs = out.sequences                                # [B*k, cut_idx + new_len]\n",
        "    cont_tokens = [seqs[row, cut_idx:] for row in range(B * k)]\n",
        "\n",
        "    # --- Decode: skip specials to hide <|endoftext|>/<|im_end|> ---\n",
        "    outs = tokenizer.batch_decode(cont_tokens, skip_special_tokens=True)\n",
        "\n",
        "    # --- Cleanup: keep the first line; drop a naked role header if it slipped through ---\n",
        "    ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "    def clean(s: str) -> str:\n",
        "        s = s.split(\"\\n\", 1)[0].strip()\n",
        "        if ROLE_ONLY.match(s):\n",
        "            s = \"\"\n",
        "        return to_ascii_numbers(s)\n",
        "\n",
        "    outs = [clean(o) for o in outs]                     # len = B*k\n",
        "\n",
        "    # Regroup in consecutive chunks of k (assumes generate() emits the k\n",
        "    # sequences of each prompt contiguously). Empty paraphrases fall back to\n",
        "    # the original question text.\n",
        "    per_q, ptr = [], 0\n",
        "    for q in qs:\n",
        "        group = outs[ptr:ptr + k]\n",
        "        group = [g if g else q for g in group]\n",
        "        per_q.append(group)\n",
        "        ptr += k\n",
        "\n",
        "    return per_q"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bPEoICWTndw_",
      "metadata": {
        "id": "bPEoICWTndw_"
      },
      "outputs": [],
      "source": [
        "# --- Answer generation for Qwen-style chat, batched & left-padded ---\n",
        "\n",
        "import re, math, torch\n",
        "from typing import List, Tuple\n",
        "\n",
        "# Matches a line that consists only of a chat role name (e.g. `assistant:`).\n",
        "_ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "\n",
        "# GSM8K solver instructions: brief step-by-step reasoning, then the final\n",
        "# answer repeated on its own line as `#### <number>`.\n",
        "SYSTEM_PROMPT = (\n",
        "    \"You solve grade-school math word problems.\\n\"\n",
        "    \"Think step by step. Keep the reasoning brief.\\n\"\n",
        "    \"Then repeat the final answer on a new line in exactly this format:\\n\"\n",
        "    \"#### <number>\"\n",
        ")\n",
        "\n",
        "def _clean_one_line(s: str) -> str:\n",
        "    \"\"\"Return the stripped first line of `s`, or an empty string if that line is only a chat role header.\"\"\"\n",
        "    first_line = s.partition(\"\\n\")[0].strip()\n",
        "    return \"\" if _ROLE_ONLY.match(first_line) else first_line\n",
        "\n",
        "def gsm8k_messages(question: str,\n",
        "                   *,\n",
        "                   system_prompt = SYSTEM_PROMPT):\n",
        "    \"\"\"Build the [system, user] chat messages asking the model to solve `question`.\"\"\"\n",
        "    user_turn = f\"{question}\\n\\nSolve step by step, then give the final answer.\"\n",
        "    return [\n",
        "        dict(role=\"system\", content=system_prompt),\n",
        "        dict(role=\"user\", content=user_turn),\n",
        "    ]\n",
        "\n",
        "\n",
        "def format_prompt_instruct(q: str,\n",
        "                           *,\n",
        "                           system_prompt = SYSTEM_PROMPT,\n",
        "                           tokenizer=None) -> str:\n",
        "    \"\"\"\n",
        "    Render `q` as a GSM8K chat prompt string: the messages from gsm8k_messages()\n",
        "    passed through the tokenizer's chat template, with the generation prompt\n",
        "    appended. Uses the notebook-global `tokenizer` when none is given.\n",
        "    \"\"\"\n",
        "    chat_tokenizer = globals()[\"tokenizer\"] if tokenizer is None else tokenizer\n",
        "    return chat_tokenizer.apply_chat_template(\n",
        "        gsm8k_messages(q, system_prompt=system_prompt),\n",
        "        tokenize=False,\n",
        "        add_generation_prompt=True,\n",
        "    )\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8jHm5ZuAHSaM",
      "metadata": {
        "id": "8jHm5ZuAHSaM"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "from typing import List, Tuple, Optional\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "def _extract_gsm8k_final(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Return the final answer from a GSM8K-style solution.\n",
        "\n",
        "    Scans lines from the bottom up. The first line that starts (after optional\n",
        "    whitespace) with `####` followed by a number yields that number, with\n",
        "    thousands separators removed; an optional `$` before the number and trailing\n",
        "    punctuation after it are tolerated (`#### $1,250.` -> 1250). A `####` line\n",
        "    followed by whitespace (or nothing) but no number yields \"NA\", as does a\n",
        "    text with no such line at all.\n",
        "    \"\"\"\n",
        "    for line in reversed(text.splitlines()):\n",
        "        # optional $, then digits (with optional thousands commas and decimal part),\n",
        "        # followed by whitespace, punctuation or end of line\n",
        "        m = re.match(r\"^\\s*####\\s*\\$?\\s*([-+]?\\d[\\d,]*(?:\\.\\d+)?)(?=[\\s.,;:!?)]|$)\", line)\n",
        "        if m:\n",
        "            return m.group(1).replace(\",\", \"\")\n",
        "        # a #### line with no number: treat the answer as missing and stop\n",
        "        if re.match(r\"^\\s*####(?:\\s|$)\", line):\n",
        "            return \"NA\"\n",
        "    return \"NA\"\n",
        "\n",
        "\n",
        "def answer_batch_chat(\n",
        "    qs: List[str],\n",
        "    *,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Answer each question in `qs` with one batched generate() call.\n",
        "\n",
        "    Prompts are rendered with format_prompt_instruct() and LEFT-padded. Each\n",
        "    completion is reduced to its GSM8K final answer with _extract_gsm8k_final(),\n",
        "    so every element is a number string, or \"NA\" when the completion has no\n",
        "    usable `#### <number>` line.\n",
        "\n",
        "    Args:\n",
        "        qs: questions to answer; an empty list returns [] without calling the model.\n",
        "        model: defaults to the global `baseline_model`.\n",
        "        tokenizer: defaults to the global `tokenizer`.\n",
        "        max_new_tokens: generation budget per answer (the default is MAX_NEW_TOKENS\n",
        "            as it was when this cell ran).\n",
        "        temperature, top_p: sampling parameters; set to None when use_greedy=True.\n",
        "        use_greedy: greedy decoding instead of sampling.\n",
        "\n",
        "    Returns:\n",
        "        List[str] aligned with `qs`.\n",
        "    \"\"\"\n",
        "    if not qs:\n",
        "        return []\n",
        "\n",
        "    # Fallback to globals if not provided\n",
        "    if model is None:\n",
        "        model = baseline_model\n",
        "    if tokenizer is None:\n",
        "        tokenizer = globals()[\"tokenizer\"]\n",
        "\n",
        "    rendered = [format_prompt_instruct(q, tokenizer=tokenizer) for q in qs]\n",
        "\n",
        "    # --- Encode with LEFT padding (the tokenizer's setting is restored afterwards) ---\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(\n",
        "            rendered,\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        ).to(next(model.parameters()).device)\n",
        "    finally:\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    # With left padding every prompt ends at the same column, so one cut index\n",
        "    # separates prompt from continuation for all rows.\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- EOS / PAD handling ---\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id\n",
        "\n",
        "    # --- Generation ---\n",
        "    with torch.inference_mode():\n",
        "        out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            min_new_tokens=0,\n",
        "            do_sample=not use_greedy,\n",
        "            temperature=(temperature if not use_greedy else None),\n",
        "            top_p=(top_p if not use_greedy else None),\n",
        "            num_beams=1,\n",
        "            num_return_sequences=1,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=pad_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "\n",
        "    # --- Slice off the prompt and decode just the continuation ---\n",
        "    seqs = out.sequences\n",
        "    cont = [seqs[i, cut_idx:] for i in range(len(qs))]\n",
        "    outs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n",
        "\n",
        "    # _extract_gsm8k_final always returns a string (\"NA\" when no final answer\n",
        "    # is found), so there is no free-text fallback path.\n",
        "    return [_extract_gsm8k_final(o.strip()) for o in outs]\n",
        "\n",
        "\n",
        "def answer_originals_and_paraphrases(\n",
        "    qs: List[str],\n",
        "    paras_per_q: List[List[str]],\n",
        "    *,\n",
        "    micro_bs: int = 64,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> Tuple[List[str], List[List[str]]]:\n",
        "    \"\"\"\n",
        "    Answer every original question and every paraphrase of it.\n",
        "\n",
        "    Originals are answered first, then all paraphrases as one flattened list.\n",
        "    Both passes call answer_batch_chat() on chunks of at most `micro_bs`\n",
        "    prompts to bound memory use per generate() call.\n",
        "\n",
        "    Returns:\n",
        "        (answers_for_originals, answers_for_paraphrases), where\n",
        "        answers_for_paraphrases[i] is aligned with paras_per_q[i].\n",
        "    \"\"\"\n",
        "    shared_kwargs = dict(\n",
        "        model=model,\n",
        "        tokenizer=tokenizer,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "        use_greedy=use_greedy,\n",
        "    )\n",
        "\n",
        "    def _answer_in_chunks(prompts: List[str]) -> List[str]:\n",
        "        # Each generate() call sees at most micro_bs prompts.\n",
        "        collected: List[str] = []\n",
        "        for start in range(0, len(prompts), micro_bs):\n",
        "            collected += answer_batch_chat(prompts[start:start + micro_bs], **shared_kwargs)\n",
        "        return collected\n",
        "\n",
        "    ans_orig = _answer_in_chunks(qs)\n",
        "    ans_paras_flat = _answer_in_chunks([p for group in paras_per_q for p in group])\n",
        "\n",
        "    # Cut the flat answer list back into per-question groups.\n",
        "    ans_paras: List[List[str]] = []\n",
        "    offset = 0\n",
        "    for size in map(len, paras_per_q):\n",
        "        ans_paras.append(ans_paras_flat[offset:offset + size])\n",
        "        offset += size\n",
        "\n",
        "    return ans_orig, ans_paras\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6hWe785vn5Uw",
      "metadata": {
        "id": "6hWe785vn5Uw"
      },
      "outputs": [],
      "source": [
        "def _to_text(x):\n",
        "    \"\"\"\n",
        "    Coerce a reference-answer value of unknown shape into one string.\n",
        "\n",
        "    None -> empty string; str -> unchanged; list -> non-None items converted\n",
        "    recursively and joined with a semicolon and a space; dict -> the value of\n",
        "    the first present key among text/answer/final/final_answer/label/gold/target,\n",
        "    converted recursively; numbers, dicts without those keys and any other\n",
        "    type -> str(x).\n",
        "    \"\"\"\n",
        "    if x is None:\n",
        "        return \"\"\n",
        "    if isinstance(x, str):\n",
        "        return x\n",
        "    if isinstance(x, list):\n",
        "        # join multiple references; tweak the separator as you like\n",
        "        return \"; \".join(_to_text(item) for item in x if item is not None)\n",
        "    if isinstance(x, dict):\n",
        "        for key in (\"text\", \"answer\", \"final\", \"final_answer\", \"label\", \"gold\", \"target\"):\n",
        "            if key in x:\n",
        "                return _to_text(x[key])\n",
        "    # ints/floats, dicts with none of the keys above, and everything else\n",
        "    return str(x)\n",
        "\n",
        "def fetch_ground_truths(train_ds, idxs):\n",
        "    \"\"\"\n",
        "    Return the reference answer text for each index in `idxs`.\n",
        "\n",
        "    Each record `train_ds[i]` is probed for the first present field among\n",
        "    answer/answers/final_answer/final/label/target. Its value (or None if no\n",
        "    field is present) is converted with _to_text().\n",
        "    \"\"\"\n",
        "    candidate_fields = (\"answer\", \"answers\", \"final_answer\", \"final\", \"label\", \"target\")\n",
        "\n",
        "    def _reference(rec):\n",
        "        value = next((rec[f] for f in candidate_fields if f in rec), None)\n",
        "        return _to_text(value)\n",
        "\n",
        "    return [_reference(train_ds[i]) for i in idxs]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "t_pao7CWT7Kg",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "t_pao7CWT7Kg",
        "outputId": "aa264796-122d-4d67-8353-49dfbb2ab707"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [0]\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n\\nparas_per_q = paraphrase_batch_chat(qs, k=k_tmp)\\n\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ]
          },
          "execution_count": 24,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Optional sanity check: paraphrase one training question, then answer both versions.\n",
        "# Disabled by default so Restart & Run All stays fast and quiet; set the flag to True to run it.\n",
        "RUN_PARAPHRASE_SANITY_CHECK = False\n",
        "\n",
        "if RUN_PARAPHRASE_SANITY_CHECK:\n",
        "    PARAPHRASES_PER_QUESTION_tmp = 2\n",
        "    k_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\n",
        "\n",
        "    # Keep the indices so we can fetch ground truths\n",
        "    idxs = [0]\n",
        "    qs = [train_ds[idx][\"question\"] for idx in idxs]\n",
        "\n",
        "    paras_per_q = paraphrase_batch_chat(qs, k=k_tmp)\n",
        "\n",
        "    # Model answers for the originals and for each paraphrase\n",
        "    answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "    # Ground truths for the originals\n",
        "    gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "    print(\"=== Test inputs ===\")\n",
        "    for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "        print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "        print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "    print(\"\\n=== Paraphrases + Answers ===\")\n",
        "    for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "        print(f\"[Q{i}]\")\n",
        "        print(f\"  Original answer: {a_q}\")\n",
        "        print(f\"  Ground truth   : {gt}\")\n",
        "        if not paras:\n",
        "            print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "        else:\n",
        "            for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "                print(f\"  {j}. {p}\")\n",
        "                print(f\"     ↳ Answer: {ap}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "lejUKL25Sf3U",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lejUKL25Sf3U",
        "outputId": "e442d735-f609-44ca-fe78-cb5a30da3c7f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CACHE_PATH:  /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\n"
          ]
        }
      ],
      "source": [
        "import os, json, time, hashlib, pathlib\n",
        "\n",
        "\n",
        "# Location of the paraphrase cache (JSON Lines); CACHE_PATH itself is set earlier in the notebook.\n",
        "print(\"CACHE_PATH: \", CACHE_PATH)\n",
        "\n",
        "import os, time\n",
        "\n",
        "from pathlib import Path\n",
        "def _ensure_parent_dir(path: str | Path) -> Path:\n",
        "    \"\"\"Create the parent directory of `path` if it is missing and return `path` as a Path.\n",
        "\n",
        "    Logs only when a directory is actually created, and logs the directory (not the file).\n",
        "    \"\"\"\n",
        "    p = Path(path)\n",
        "    if not p.parent.exists():\n",
        "        print(\"Making directory: \", p.parent)\n",
        "    p.parent.mkdir(parents=True, exist_ok=True)  # no-op when it already exists\n",
        "    return p\n",
        "\n",
        "\n",
        "# `save_paraphrase_cache` is defined once, further down in this cell. A second, earlier\n",
        "# definition used to live here, but it was always shadowed by the later one, so it was dead code.\n",
        "\n",
        "\n",
        "import uuid, os\n",
        "\n",
        "def save_paraphrase_cache_atomic(path, classes_all, classes_usable, meta):\n",
        "    \"\"\"Save the cache to a temp file next to `path`, then swap it into place with os.replace.\n",
        "\n",
        "    Readers never see a half-written file at `path`. If writing or the rename fails, the temp\n",
        "    file is deleted and the exception propagates, leaving any previous cache untouched.\n",
        "    \"\"\"\n",
        "    p = _ensure_parent_dir(path)\n",
        "    tmp = p.with_name(p.name + f\".tmp-{uuid.uuid4().hex}\")  # same dir → safe rename\n",
        "    try:\n",
        "        save_paraphrase_cache(tmp, classes_all, classes_usable, meta)\n",
        "        os.replace(tmp, p)                             # atomic on POSIX\n",
        "    except BaseException:\n",
        "        tmp.unlink(missing_ok=True)  # don't leave .tmp-* files behind (e.g. after Ctrl-C)\n",
        "        raise\n",
        "\n",
        "\n",
        "def meta_compatible(current_meta: dict, cached_meta: dict | None) -> bool:\n",
        "    \"\"\"Return True iff `cached_meta` is present and agrees with `current_meta` on every output-shaping setting.\n",
        "\n",
        "    Bookkeeping keys such as `hash` / `created_utc` are not compared. A setting missing\n",
        "    from both dicts counts as equal (None == None).\n",
        "    \"\"\"\n",
        "    if not cached_meta:\n",
        "        return False  # no (or empty) meta: can't vouch for the cache\n",
        "    output_knobs = (\n",
        "        \"train_len\", \"per_class\", \"paraphraser_model\",\n",
        "        \"top_p\", \"top_k\", \"temperature\", \"max_new_tokens\",\n",
        "        \"dedup_thresh\", \"embedder\",\n",
        "    )\n",
        "    mismatched = [knob for knob in output_knobs if current_meta.get(knob) != cached_meta.get(knob)]\n",
        "    return not mismatched\n",
        "\n",
        "\n",
        "def _meta_hash(**kw):\n",
        "    \"\"\"Short (10 hex chars) SHA-1 fingerprint of the settings that affect paraphrases.\n",
        "\n",
        "    Bookkeeping fields `created_utc` and `hash` are excluded, so identical settings always\n",
        "    give the same fingerprint (folding in the save timestamp made every hash unique).\n",
        "    \"\"\"\n",
        "    settings = {k: v for k, v in kw.items() if k not in (\"created_utc\", \"hash\")}\n",
        "    key = json.dumps(settings, sort_keys=True).encode()\n",
        "    return hashlib.sha1(key).hexdigest()[:10]\n",
        "\n",
        "def save_paraphrase_cache(path, classes_all, classes_usable, meta: dict):\n",
        "    \"\"\"Write paraphrase classes to `path` as JSON Lines.\n",
        "\n",
        "    Line 1: {\"__meta__\": meta} with `hash` (fingerprint of the settings) and `created_utc` added.\n",
        "    Then one line per class: {\"idx\", \"size\", \"prompts\", \"usable\"}, usable = 1 iff >= 2 prompts.\n",
        "    `classes_usable` is only used for the summary message. Missing parent folders are created.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(path)\n",
        "    path.parent.mkdir(parents=True, exist_ok=True)  # callers may target a folder that doesn't exist yet\n",
        "    meta = dict(meta)  # copy so the caller's META is not mutated\n",
        "    # Fingerprint only the settings: drop stale bookkeeping and hash *before* stamping the time,\n",
        "    # so identical settings always produce the same hash.\n",
        "    meta.pop(\"hash\", None)\n",
        "    meta.pop(\"created_utc\", None)\n",
        "    meta[\"hash\"] = _meta_hash(**meta)\n",
        "    meta[\"created_utc\"] = int(time.time())\n",
        "\n",
        "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta}) + \"\\n\")\n",
        "        # classes_all is a list[list[str]]\n",
        "        for i, cls in enumerate(classes_all):\n",
        "            rec = {\n",
        "                \"idx\": i,\n",
        "                \"size\": len(cls),\n",
        "                \"prompts\": cls,\n",
        "                \"usable\": int(len(cls) >= 2)\n",
        "            }\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    print(f\"Saved cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "\n",
        "def load_paraphrase_cache(path, require_hash: str | None = None):\n",
        "    \"\"\"Read a cache written by `save_paraphrase_cache`.\n",
        "\n",
        "    Returns (classes_all, classes_usable, meta), or (None, None, None) if `path` does not exist.\n",
        "    A class is usable when its record has a truthy `usable` flag and >= 2 prompts.\n",
        "    Blank lines are skipped. If `require_hash` is given and the stored meta hash differs (or the\n",
        "    file has no meta line), a warning is printed; the data is still returned.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(path)\n",
        "    if not path.exists():\n",
        "        return None, None, None\n",
        "    classes_all, classes_usable = [], []\n",
        "    meta = None\n",
        "    with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "        for line in f:\n",
        "            if not line.strip():\n",
        "                continue  # tolerate blank / trailing empty lines\n",
        "            obj = json.loads(line)\n",
        "            if \"__meta__\" in obj:\n",
        "                meta = obj[\"__meta__\"]\n",
        "                continue\n",
        "            cls = obj[\"prompts\"]\n",
        "            classes_all.append(cls)\n",
        "            if obj.get(\"usable\") and len(cls) >= 2:\n",
        "                classes_usable.append(cls)\n",
        "    if require_hash and (meta is None or meta.get(\"hash\") != require_hash):\n",
        "        print(\"⚠️ Cache meta hash mismatch; rebuilding may be safer.\")\n",
        "    print(f\"Loaded cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "    return classes_all, classes_usable, meta"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B07v0ZOKuI3F",
      "metadata": {
        "id": "B07v0ZOKuI3F"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "def _dedup_paraphrase_class(paras, dedup_thresh):\n",
        "    \"\"\"Keep paras[0] plus every later prompt whose max cosine similarity to the prompts kept so far is < dedup_thresh.\n",
        "\n",
        "    Uses the notebook globals `embedder` (`.encode`) and `sbert_util` (`cos_sim`), plus numpy as `np`.\n",
        "    \"\"\"\n",
        "    embs = embedder.encode(\n",
        "        paras, convert_to_tensor=True,\n",
        "        normalize_embeddings=True, batch_size=64\n",
        "    )\n",
        "    keep = [0]\n",
        "    for j in range(1, len(paras)):\n",
        "        sims = sbert_util.cos_sim(embs[j], embs[keep]).cpu().numpy().flatten()\n",
        "        if float(np.max(sims)) < dedup_thresh:\n",
        "            keep.append(j)\n",
        "    return [paras[j] for j in keep]\n",
        "\n",
        "\n",
        "def build_classes_full_batched(\n",
        "    per_class=PARAPHRASES_PER_QUESTION,\n",
        "    dedup_thresh=0.98,\n",
        "    do_dedup=False,   # <- new flag\n",
        "    cache_path: str | None = None,\n",
        "    flush_every: int = 1000,\n",
        "    generate_fn=paraphrase_batch_chat,\n",
        "):\n",
        "    \"\"\"\n",
        "    Build one paraphrase class per row of the global `train_ds` using batched generation.\n",
        "\n",
        "    A class is [original question] + the paraphrases returned by\n",
        "    `generate_fn(questions, k=per_class - 1)` (no generation call when per_class <= 1).\n",
        "    With `do_dedup=True`, paraphrases too similar (cosine >= `dedup_thresh`) to an\n",
        "    already-kept prompt are dropped. Classes with >= 2 prompts are \"usable\".\n",
        "\n",
        "    Checkpointing: if `cache_path` holds a cache whose meta is compatible with the\n",
        "    global `META`, generation resumes after the cached rows. Progress is saved every\n",
        "    `flush_every` rows, at the end, and before any exception (Ctrl-C, CUDA OOM, …) is re-raised.\n",
        "\n",
        "    Reads globals: train_ds, BATCH_QUESTIONS, META (+ embedder, sbert_util, np when deduping).\n",
        "    Returns (classes_all, classes_usable).\n",
        "    \"\"\"\n",
        "    k = max(0, per_class - 1)\n",
        "    n_total = len(train_ds)\n",
        "\n",
        "    # --- try to resume from a compatible cache ---\n",
        "    start_at = 0\n",
        "    classes_all, classes_usable = [], []\n",
        "    cached_all, cached_usable, cached_meta = None, None, None\n",
        "    if cache_path:\n",
        "        try:\n",
        "            cached_all, cached_usable, cached_meta = load_paraphrase_cache(cache_path)\n",
        "        except Exception as exc:  # unreadable/corrupt cache: report it instead of failing silently\n",
        "            print(f\"[Resume] Could not read cache ({exc!r}); ignoring it.\")\n",
        "\n",
        "    if (\n",
        "        cache_path\n",
        "        and cached_all is not None\n",
        "        and 0 < len(cached_all) <= n_total\n",
        "        and meta_compatible(META, cached_meta)\n",
        "    ):\n",
        "        classes_all = list(cached_all)\n",
        "        classes_usable = list(cached_usable or [])\n",
        "        start_at = len(classes_all)\n",
        "        print(f\"[Resume] Compatible cache found: {start_at}/{n_total} rows; resuming…\")\n",
        "    elif cache_path:\n",
        "        print(\"[Resume] No compatible partial cache; starting fresh.\")\n",
        "\n",
        "    # --- the embedder is only needed for dedup; move it to GPU if possible (best effort) ---\n",
        "    if do_dedup:\n",
        "        try:\n",
        "            embedder.to(\"cuda\")\n",
        "        except Exception:\n",
        "            pass  # CPU encoding still works, just slower\n",
        "\n",
        "    pbar = tqdm(total=n_total, desc=\"Building (batched) classes\", leave=True)\n",
        "    pbar.update(start_at)\n",
        "    since_last_flush = 0\n",
        "\n",
        "    try:\n",
        "        for start in range(start_at, n_total, BATCH_QUESTIONS):\n",
        "            end = min(start + BATCH_QUESTIONS, n_total)\n",
        "            # Slicing an HF dataset returns a dict of column lists.\n",
        "            qs = train_ds[start:end][\"question\"]\n",
        "\n",
        "            # ---- batched paraphrase generation ----\n",
        "            paras_per_q = generate_fn(qs, k=k) if k > 0 else [[] for _ in qs]\n",
        "            if len(paras_per_q) != len(qs):\n",
        "                # A length mismatch would silently shift every later class off its train_ds row.\n",
        "                raise ValueError(\n",
        "                    f\"generate_fn returned {len(paras_per_q)} paraphrase lists for {len(qs)} questions\"\n",
        "                )\n",
        "\n",
        "            # ---- one class per question, in train_ds order ----\n",
        "            for q, paras in zip(qs, paras_per_q):\n",
        "                candidates = [q] + list(paras)\n",
        "                cls = _dedup_paraphrase_class(candidates, dedup_thresh) if do_dedup else candidates\n",
        "                classes_all.append(cls)\n",
        "                if len(cls) >= 2:\n",
        "                    classes_usable.append(cls)\n",
        "\n",
        "            processed = end - start\n",
        "            since_last_flush += processed\n",
        "            pbar.update(processed)\n",
        "            pbar.set_postfix({\"usable\": len(classes_usable)})\n",
        "\n",
        "            # ---- periodic checkpoint ----\n",
        "            if cache_path and since_last_flush >= flush_every:\n",
        "                save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "                since_last_flush = 0\n",
        "\n",
        "        # final save\n",
        "        if cache_path:\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "\n",
        "    except BaseException as exc:\n",
        "        # Persist partial progress on Ctrl-C *and* on crashes so a rerun resumes where it stopped.\n",
        "        # Safe mid-batch too: classes are appended strictly in train_ds row order.\n",
        "        if cache_path:\n",
        "            reason = \"Interrupt\" if isinstance(exc, KeyboardInterrupt) else type(exc).__name__\n",
        "            print(f\"\\n[{reason}] Saving partial progress…\")\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "        raise\n",
        "    finally:\n",
        "        pbar.close()\n",
        "\n",
        "    return classes_all, classes_usable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QM6pwSm3TAyW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QM6pwSm3TAyW",
        "outputId": "8d8b653c-7343-4dad-8603-f78b205ae907"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loaded cache: /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl  (7473 classes, usable=7473)\n",
            "Total classes: 7473; usable: 7473\n"
          ]
        }
      ],
      "source": [
        "# Describe the paraphrase-generation settings so caches are consistent\n",
        "META = {\n",
        "    \"train_len\": len(train_ds),\n",
        "    \"per_class\": PARAPHRASES_PER_QUESTION,\n",
        "    \"paraphraser_model\": BASELINE_MODEL_NAME,\n",
        "    \"top_p\": TOP_P, \"top_k\": TOP_K, \"temperature\": GEN_TEMPERATURE,\n",
        "    \"max_new_tokens\": 48,\n",
        "    \"dedup_thresh\": 0.98,\n",
        "    \"embedder\": \"sentence-transformers/all-MiniLM-L6-v2\",\n",
        "}\n",
        "\n",
        "# Try to load the cache first. Reuse it only if it covers every training row AND was built\n",
        "# with the same generation settings; otherwise (re)build, resuming from a compatible partial cache.\n",
        "paraphrase_classes_all, paraphrase_classes, meta = load_paraphrase_cache(CACHE_PATH)\n",
        "cache_complete = paraphrase_classes_all is not None and len(paraphrase_classes_all) == len(train_ds)\n",
        "cache_settings_match = meta_compatible(META, meta)\n",
        "need_build = not (cache_complete and cache_settings_match)\n",
        "\n",
        "if need_build:\n",
        "    if cache_complete:\n",
        "        print(\"Cached paraphrases were built with different settings — rebuilding…\")\n",
        "    else:\n",
        "        print(\"No valid cache found — building now…\")\n",
        "    _ensure_parent_dir(CACHE_PATH)\n",
        "    paraphrase_classes_all, paraphrase_classes = build_classes_full_batched(\n",
        "        per_class=PARAPHRASES_PER_QUESTION,\n",
        "        dedup_thresh=0.98,\n",
        "        cache_path=CACHE_PATH,      # enables resume & periodic saves\n",
        "        flush_every=1000,           # save every 1000 rows (tweak as you like)\n",
        "        generate_fn=paraphrase_batch_chat,\n",
        "    )\n",
        "\n",
        "print(f\"Total classes: {len(paraphrase_classes_all)}; usable: {len(paraphrase_classes)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GSdAyKG47uP8",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GSdAyKG47uP8",
        "outputId": "b13c1716-abe8-4104-87da-44912597d7b2"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [1000, 1300]\\nparas_per_q = []\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n  paras_per_q.append([paraphrase_classes[idx][1]])\\nprint(qs)\\nprint(paras_per_q)\\n'"
            ]
          },
          "execution_count": 28,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Optional preview of *cached* paraphrases for a few training rows (no generation needed).\n",
        "# Defines `idxs`, `qs`, `paras_per_q` for the answer check in the next cell. Disabled by default.\n",
        "RUN_CACHED_PARAPHRASE_PREVIEW = False\n",
        "\n",
        "if RUN_CACHED_PARAPHRASE_PREVIEW:\n",
        "    # Keep the indices so we can fetch ground truths\n",
        "    idxs = [1000, 1300]\n",
        "    qs = [train_ds[idx][\"question\"] for idx in idxs]\n",
        "    # Index the FULL class list: `paraphrase_classes` keeps only usable classes, so its positions\n",
        "    # need not line up with train_ds rows. [1:2] = first paraphrase (empty if the class has none).\n",
        "    paras_per_q = [paraphrase_classes_all[idx][1:2] for idx in idxs]\n",
        "    print(qs)\n",
        "    print(paras_per_q)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "gdXynjcC9BGw",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gdXynjcC9BGw",
        "outputId": "77868b80-e6da-44b7-b8a7-0df4317c670b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ]
          },
          "execution_count": 29,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Optional: answer the questions/paraphrases chosen in the previous cell and compare with ground truth.\n",
        "# Needs `qs`, `paras_per_q` and `idxs` from that cell. Disabled by default.\n",
        "RUN_PARAPHRASE_ANSWER_CHECK = False\n",
        "\n",
        "if RUN_PARAPHRASE_ANSWER_CHECK:\n",
        "    # Model answers for the originals and for each paraphrase\n",
        "    answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "    # Ground truths for the originals\n",
        "    gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "    print(\"=== Test inputs ===\")\n",
        "    for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "        print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "        print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "    print(\"\\n=== Paraphrases + Answers ===\")\n",
        "    for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "        print(f\"[Q{i}]\")\n",
        "        print(f\"  Original answer: {a_q}\")\n",
        "        print(f\"  Ground truth   : {gt}\")\n",
        "        if not paras:\n",
        "            print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "        else:\n",
        "            for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "                print(f\"  {j}. {p}\")\n",
        "                print(f\"     ↳ Answer: {ap}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SfSfLvyWQ9aB",
      "metadata": {
        "id": "SfSfLvyWQ9aB"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): global debug switch; nothing in this part of the notebook reads it — confirm which later cells use it.\n",
        "FLAG_DEBUG = True"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ywbhUshonDfJ",
      "metadata": {
        "id": "ywbhUshonDfJ"
      },
      "source": [
        "## Evaluations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B5VVKuJrDVKo",
      "metadata": {
        "id": "B5VVKuJrDVKo"
      },
      "outputs": [],
      "source": [
        "import re, math\n",
        "import torch\n",
        "from typing import Dict, Any\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "def _wilson_ci(k: int, n: int, level: float = 0.95):\n",
        "    \"\"\"Wilson score interval for the binomial proportion k/n at confidence `level`.\n",
        "\n",
        "    Returns (lo, hi) clipped to [0, 1]; (0.0, 0.0) when n == 0.\n",
        "    Raises ValueError if `level` is not strictly between 0 and 1.\n",
        "    \"\"\"\n",
        "    from statistics import NormalDist  # local import: stdlib, not imported by this cell\n",
        "\n",
        "    if n == 0:\n",
        "        return (0.0, 0.0)\n",
        "    if not 0.0 < level < 1.0:\n",
        "        raise ValueError(f\"level must be in (0, 1), got {level}\")\n",
        "    # Two-sided critical value, e.g. level=0.95 -> z ≈ 1.96 (it used to be hard-coded, ignoring `level`).\n",
        "    z = NormalDist().inv_cdf(0.5 + level / 2)\n",
        "    phat = k / n\n",
        "    denom = 1 + (z*z)/n\n",
        "    center = (phat + (z*z)/(2*n)) / denom\n",
        "    half = (z / denom) * math.sqrt((phat*(1-phat) + (z*z)/(4*n)) / n)\n",
        "    return max(0.0, center - half), min(1.0, center + half)\n",
        "\n",
        "\n",
        "\n",
        "def _gsm8k_exact_match(pred, gold) -> bool:\n",
        "    \"\"\"True iff stripped `pred` equals stripped `gold` and neither is the \"NA\" sentinel; None is treated as empty.\"\"\"\n",
        "    p = (pred or \"\").strip()\n",
        "    g = (gold or \"\").strip()\n",
        "    return p == g and p != \"NA\" and g != \"NA\"\n",
        "\n",
        "\n",
        "def evaluate_gsm8k(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"eval\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of sample Q/pred/gold to print total\n",
        "    seed: int | None = None,  # fix the sampled subset (compare models on the same questions)\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Exact-match accuracy on a random subset of a GSM8K-style dataset, with a tqdm progress bar.\n",
        "\n",
        "    Args:\n",
        "      model, tokenizer: forwarded to `answer_batch_chat` (with max_new_tokens, use_greedy).\n",
        "      dataset: indexable rows with \"question\" and optionally \"answer\"; the gold\n",
        "        string is `_extract_gsm8k_final(row[\"answer\"])`.\n",
        "      limit: number of rows to sample (None = every row).\n",
        "      batch_size: questions per generation call (must be >= 1).\n",
        "      name: label for the progress bar and printed lines.\n",
        "      show_samples: total number of Q/pred/gold triples to print.\n",
        "      seed: if set, rows are sampled with `random.Random(seed)`; if None the global\n",
        "        `random` state is used.\n",
        "\n",
        "    A prediction is correct when it equals the gold after stripping and neither is \"NA\".\n",
        "    The model's train/eval mode is restored even if generation is interrupted.\n",
        "\n",
        "    Returns a dict with preds, the sample (indices/questions/golds), N, ACC,\n",
        "    ACC_CI (95% Wilson interval), ACC_CI_Level and pred_NA (count of \"NA\" predictions).\n",
        "    \"\"\"\n",
        "    import random  # local import: this cell does not import it\n",
        "\n",
        "    if batch_size < 1:\n",
        "        raise ValueError(f\"batch_size must be >= 1, got {batch_size}\")\n",
        "\n",
        "    N_total = len(dataset)\n",
        "    N = min(int(limit), N_total) if limit is not None else N_total\n",
        "\n",
        "    rng = random.Random(seed) if seed is not None else random\n",
        "    sampled_indices = rng.sample(range(N_total), k=N)\n",
        "\n",
        "    # Collect questions + golds\n",
        "    qs, gold_list = [], []\n",
        "    for i in sampled_indices:\n",
        "        ex = dataset[i]\n",
        "        qs.append(ex[\"question\"])\n",
        "        gold_list.append(_extract_gsm8k_final(ex.get(\"answer\", \"\")))\n",
        "\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    preds = []\n",
        "    running_correct = 0\n",
        "    printed = 0\n",
        "    num_batches = (N + batch_size - 1) // batch_size\n",
        "    pbar = tqdm(total=num_batches, desc=f\"Evaluating {name}\", dynamic_ncols=True)\n",
        "    try:\n",
        "        for start in range(0, N, batch_size):\n",
        "            end = min(start + batch_size, N)\n",
        "            batch_qs = qs[start:end]\n",
        "\n",
        "            batch_preds = answer_batch_chat(\n",
        "                batch_qs,\n",
        "                model=model,\n",
        "                tokenizer=tokenizer,\n",
        "                max_new_tokens=max_new_tokens,\n",
        "                use_greedy=use_greedy,\n",
        "            )\n",
        "            preds.extend(batch_preds)\n",
        "\n",
        "            # Running accuracy for the progress bar (final numbers are recomputed below).\n",
        "            running_correct += sum(\n",
        "                _gsm8k_exact_match(p, gold_list[start + j]) for j, p in enumerate(batch_preds)\n",
        "            )\n",
        "            running_n = len(preds)\n",
        "            running_acc = running_correct / running_n if running_n else 0.0\n",
        "            pbar.set_postfix({\n",
        "                \"N\": running_n,\n",
        "                \"ACC%\": f\"{running_acc*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(1)\n",
        "\n",
        "            # Optional: print a few samples total across run\n",
        "            if printed < show_samples:\n",
        "                for j in range(min(show_samples - printed, len(batch_qs))):\n",
        "                    idx = start + j\n",
        "                    q_preview = batch_qs[j].replace(\"\\n\", \" \")\n",
        "                    if len(q_preview) > 160:\n",
        "                        q_preview = q_preview[:160] + \"...\"\n",
        "                    pbar.write(f\"[{name}] idx={idx} Q: {q_preview}\")\n",
        "                    pbar.write(f\"[{name}]        pred: {batch_preds[j]}\")\n",
        "                    pbar.write(f\"[{name}]        gold: {gold_list[idx]}\")\n",
        "                    printed += 1\n",
        "    finally:\n",
        "        # Always restore the caller's mode: an interrupted eval must not leave a training model in eval().\n",
        "        pbar.close()\n",
        "        model.train(was_training)\n",
        "\n",
        "    # Final exact-match accuracy\n",
        "    n = len(preds)\n",
        "    correct = sum(_gsm8k_exact_match(p, g) for p, g in zip(preds, gold_list))\n",
        "    pred_na = sum((p or \"\").strip() == \"NA\" for p in preds)\n",
        "    acc_mean = correct / n if n else 0.0\n",
        "    acc_lo, acc_hi = _wilson_ci(correct, n, level=0.95)\n",
        "    pred_na_frac = (pred_na / n) if n else 0.0\n",
        "\n",
        "    print(\n",
        "        f\"[{name}] N={n}  \"\n",
        "        f\"ACC: {acc_mean*100:.2f} and ACC confidence \"\n",
        "        f\"[{acc_lo*100:.2f}, {acc_hi*100:.2f}]  |  \"\n",
        "        f\"pred NA: {pred_na}/{n} ({pred_na_frac*100:.2f}%)  \"\n",
        "    )\n",
        "\n",
        "    return {\n",
        "        \"preds\": preds,\n",
        "        \"sample\": {\n",
        "            \"indices\": sampled_indices,\n",
        "            \"questions\": qs,\n",
        "            \"golds\": gold_list,\n",
        "        },\n",
        "        \"N\": n,\n",
        "        \"ACC\": acc_mean,\n",
        "        \"ACC_CI\": (acc_lo, acc_hi),\n",
        "        \"ACC_CI_Level\": 0.95,\n",
        "        \"pred_NA\": pred_na,\n",
        "    }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "t0RWnAfBnCCv",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 512,
          "referenced_widgets": [
            "3e4094c0ed3041afa63d26fc05b6f17d",
            "aa4ba36a2be8423bb9c23cd1834b3c15",
            "2c400043327a4fa7b10b5a8f527df0cd",
            "e4a308877b7d4fb283f15b6613814519",
            "acf8d4bc4e4b4cf285ea39855fd6c75f",
            "9b871c411c2b4809bb1d4ba9b8b650ab",
            "bfc80c383c86441b9aa68e7246922580",
            "50c82dcbe2f14801b3cbdf412f91b094",
            "eab965c99d6d445db3a182b128a0fce5",
            "0224a2b6fc59439e9771f36f7f30176a",
            "19bb218724e741248d79f0d5cbeee526"
          ]
        },
        "id": "t0RWnAfBnCCv",
        "outputId": "4b43dad1-2499-4e01-b8db-4088f1f131c1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluation limit: 1319\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3e4094c0ed3041afa63d26fc05b6f17d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating qwen3b greedy trained:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-2493726714.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Evaluation limit:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_limit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m res = evaluate_gsm8k(\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrainable_model\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mtokenizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-4002072956.py\u001b[0m in \u001b[0;36mevaluate_gsm8k\u001b[0;34m(model, tokenizer, dataset, limit, batch_size, max_new_tokens, name, use_greedy, show_samples)\u001b[0m\n\u001b[1;32m     62\u001b[0m         \u001b[0mbatch_qs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mqs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m         batch_preds = answer_batch_chat(\n\u001b[0m\u001b[1;32m     65\u001b[0m             \u001b[0mbatch_qs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m             \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-3707647594.py\u001b[0m in \u001b[0;36manswer_batch_chat\u001b[0;34m(qs, model, tokenizer, max_new_tokens, temperature, top_p, use_greedy)\u001b[0m\n\u001b[1;32m     69\u001b[0m     \u001b[0;31m# --- Generation ---\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minference_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m         out = model.generate(\n\u001b[0m\u001b[1;32m     72\u001b[0m             \u001b[0;34m**\u001b[0m\u001b[0menc\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m             \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/peft/peft_model.py\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   2046\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_enable_peft_forward_hooks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2047\u001b[0m                     \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspecial_peft_forward_args\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2048\u001b[0;31m                     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2049\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2050\u001b[0m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    118\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    122\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)\u001b[0m\n\u001b[1;32m   2564\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2565\u001b[0m         \u001b[0;31m# 9. Call generation mode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2566\u001b[0;31m         result = decoding_method(\n\u001b[0m\u001b[1;32m   2567\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2568\u001b[0m             \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py\u001b[0m in \u001b[0;36m_sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[1;32m   2787\u001b[0m                 \u001b[0mis_prefill\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2788\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2789\u001b[0;31m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2790\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2791\u001b[0m             \u001b[0;31m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1773\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1777\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1784\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1785\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1788\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    916\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict_passed\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    917\u001b[0m             \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict_passed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 918\u001b[0;31m         \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    919\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    920\u001b[0m             \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_tuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m    447\u001b[0m         \u001b[0;34m\"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    448\u001b[0m         ```\"\"\"\n\u001b[0;32m--> 449\u001b[0;31m         outputs: BaseModelOutputWithPast = self.model(\n\u001b[0m\u001b[1;32m    450\u001b[0m             \u001b[0minput_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    451\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1773\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1777\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1784\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1785\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1788\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1070\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1071\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1072\u001b[0;31m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1073\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0moriginal_exception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1074\u001b[0m                 \u001b[0;31m# If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m    382\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mdecoder_layer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_hidden_layers\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 384\u001b[0;31m             hidden_states = decoder_layer(\n\u001b[0m\u001b[1;32m    385\u001b[0m                 \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    386\u001b[0m                 \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcausal_mask_mapping\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdecoder_layer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention_type\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     93\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gradient_checkpointing_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1773\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1777\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1784\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1785\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1788\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/utils/deprecation.py\u001b[0m in \u001b[0;36mwrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    170\u001b[0m                 \u001b[0mwarnings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFutureWarning\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    171\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    174\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m    232\u001b[0m         \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_layernorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    233\u001b[0m         \u001b[0;31m# Self Attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m         hidden_states, _ = self.self_attn(\n\u001b[0m\u001b[1;32m    235\u001b[0m             \u001b[0mhidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    236\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1773\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1777\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1784\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1785\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1788\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/utils/deprecation.py\u001b[0m in \u001b[0;36mwrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    170\u001b[0m                 \u001b[0mwarnings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFutureWarning\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    171\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    174\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m    167\u001b[0m             \u001b[0mattention_interface\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mALL_ATTENTION_FUNCTIONS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_attn_implementation\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m         attn_output, attn_weights = attention_interface(\n\u001b[0m\u001b[1;32m    170\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    171\u001b[0m             \u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/integrations/flash_attention.py\u001b[0m in \u001b[0;36mflash_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, dropout, scaling, sliding_window, softcap, **kwargs)\u001b[0m\n\u001b[1;32m     64\u001b[0m         \u001b[0mis_causal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_causal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 66\u001b[0;31m     attn_output = _flash_attention_forward(\n\u001b[0m\u001b[1;32m     67\u001b[0m         \u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     68\u001b[0m         \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/modeling_flash_attention_utils.py\u001b[0m in \u001b[0;36m_flash_attention_forward\u001b[0;34m(query_states, key_states, value_states, attention_mask, query_length, is_causal, dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic, cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, target_dtype, implementation, **kwargs)\u001b[0m\n\u001b[1;32m    614\u001b[0m             \u001b[0mcu_seq_lens_k\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcu_seq_lens_k\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    615\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 616\u001b[0;31m         out_unpad = flash_varlen_fn(\n\u001b[0m\u001b[1;32m    617\u001b[0m             \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m             \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/flash_attn/flash_attn_interface.py\u001b[0m in \u001b[0;36mflash_attn_varlen_func\u001b[0;34m(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, block_table)\u001b[0m\n\u001b[1;32m   1441\u001b[0m             \u001b[0mpattern\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnegative\u001b[0m \u001b[0mmeans\u001b[0m \u001b[0mthat\u001b[0m \u001b[0mlocation\u001b[0m \u001b[0mwas\u001b[0m \u001b[0mdropped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnonnegative\u001b[0m \u001b[0mmeans\u001b[0m \u001b[0mit\u001b[0m \u001b[0mwas\u001b[0m \u001b[0mkept\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1442\u001b[0m     \"\"\"\n\u001b[0;32m-> 1443\u001b[0;31m     return FlashAttnVarlenFunc.apply(\n\u001b[0m\u001b[1;32m   1444\u001b[0m         \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1445\u001b[0m         \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    579\u001b[0m             \u001b[0;31m# See NOTE: [functorch vjp and autograd interaction]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    580\u001b[0m             \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_functorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munwrap_dead_wrappers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 581\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    582\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    583\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_setup_ctx_defined\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/flash_attn/flash_attn_interface.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_softmax, block_table, is_grad_enabled)\u001b[0m\n\u001b[1;32m    923\u001b[0m             \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m8\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mhead_size_og\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    924\u001b[0m             \u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m8\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mhead_size_og\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 925\u001b[0;31m         out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(\n\u001b[0m\u001b[1;32m    926\u001b[0m             \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    927\u001b[0m             \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_ops.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1253\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_has_torchbind_op_overload\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0m_must_dispatch_in_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1254\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0m_call_overload_packet_from_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1255\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1257\u001b[0m     \u001b[0;31m# TODO: use this to make a __dir__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py\u001b[0m in \u001b[0;36mbackend_impl\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    342\u001b[0m                     \u001b[0;32mdef\u001b[0m \u001b[0mbackend_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m                         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend_fns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdevice_type\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    344\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    345\u001b[0m                         \u001b[0;32mdef\u001b[0m \u001b[0mget_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_compile.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     51\u001b[0m                 \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dynamo_disable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisable_fn\u001b[0m  \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdisable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py\u001b[0m in \u001b[0;36m_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1042\u001b[0m                 \u001b[0m_maybe_set_eval_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_callback_from_stance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1043\u001b[0m                 \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m                     \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1045\u001b[0m                 \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1046\u001b[0m                     \u001b[0mset_eval_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    374\u001b[0m                         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_init_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    375\u001b[0m                     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 376\u001b[0;31m                         \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    377\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    378\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend_fns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdevice_type\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrapped_fn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/flash_attn/flash_attn_interface.py\u001b[0m in \u001b[0;36m_flash_attn_varlen_forward\u001b[0;34m(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size_left, window_size_right, softcap, alibi_slopes, return_softmax, block_table, leftpad_k, seqused_k, zero_tensors)\u001b[0m\n\u001b[1;32m    163\u001b[0m ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n\u001b[1;32m    164\u001b[0m     \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mmaybe_contiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(\n\u001b[0m\u001b[1;32m    166\u001b[0m         \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    167\u001b[0m         \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "# Greedy eval of the trained model over the whole test split\n",
        "# (without brief reasoning instruction).\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    name=\"qwen3b greedy trained\",\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "K0A9sh_uklZ5",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 138,
          "referenced_widgets": [
            "1eac4a3feaaf4a7eb20cae82b5a78ba4",
            "c2ce751a1ad44d4d9e26cc7b5dd5ddfc",
            "3acfe564bc134c2d9ae871bb568b203f",
            "8e00c17408e44c95939b06e5e9954441",
            "27b7a1b99aa142999397e2437123f27e",
            "26932873ef3e4e439700e62540a54a98",
            "9cc921e5ec544e07832501b57d9f2f42",
            "6fd2c60356224d9c841a1fff4feaf6a4",
            "fc4d8b814f454a7c8515dfc3f4f957d0",
            "9383fc235dee4ee19bd47c2e0df567a6",
            "9dc8c60675634ebda305d807cc36e97e"
          ]
        },
        "id": "K0A9sh_uklZ5",
        "outputId": "350fd8b4-e11f-4032-ad41-75178400297e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluation limit: 1319\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1eac4a3feaaf4a7eb20cae82b5a78ba4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating qwen3b greedy trained:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[qwen3b greedy trained] N=1319  ACC: 75.13 and ACC confidence [72.73, 77.39]  |  pred NA: 56/1319 (4.25%)  \n"
          ]
        }
      ],
      "source": [
        "# Greedy eval of the trained model (without brief reasoning instruction) with a\n",
        "# larger generation budget. The budget stays local (MAX_NEW_TOKENS is not touched)\n",
        "# and is part of the run name, so this result is distinguishable from the\n",
        "# MAX_NEW_TOKENS run of the same model.\n",
        "eval_max_new_tokens = 1024\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=eval_max_new_tokens,\n",
        "    name=f\"qwen3b greedy trained (max_new_tokens={eval_max_new_tokens})\",\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "jBUdpc4iYQXY",
      "metadata": {
        "id": "jBUdpc4iYQXY"
      },
      "outputs": [],
      "source": [
        "# Greedy eval of the trained model (run label: step 2400),\n",
        "# without brief reasoning instruction, over the whole test split.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model, tokenizer=tokenizer,\n",
        "    dataset=test_ds, limit=eval_limit,\n",
        "    batch_size=64, max_new_tokens=MAX_NEW_TOKENS,\n",
        "    use_greedy=True,\n",
        "    name=\"qwen3b greedy trained step 2400\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "iXoXLquSv2Kf",
      "metadata": {
        "id": "iXoXLquSv2Kf"
      },
      "outputs": [],
      "source": [
        "# Greedy eval of the trained model (run label: weight 0.02),\n",
        "# without brief reasoning instruction, over the whole test split.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    name=\"qwen3b greedy trained weight 0.02\",\n",
        "    limit=eval_limit,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    batch_size=64,\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "JcRhFYe5JO68",
      "metadata": {
        "id": "JcRhFYe5JO68"
      },
      "outputs": [],
      "source": [
        "# Greedy eval of the trained model (run label: weight 0.01),\n",
        "# without brief reasoning instruction, over the whole test split.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model, tokenizer=tokenizer, dataset=test_ds,\n",
        "    limit=eval_limit, batch_size=64, max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"qwen3b greedy trained weight 0.01\", use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "PEp1QblOHIbN",
      "metadata": {
        "id": "PEp1QblOHIbN"
      },
      "outputs": [],
      "source": [
        "# Greedy eval over the whole test split (run label: qwen3b greedy)\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    name=\"qwen3b greedy\",\n",
        "    use_greedy=True,\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "O1w08WfY85lE",
      "metadata": {
        "id": "O1w08WfY85lE"
      },
      "outputs": [],
      "source": [
        "# Greedy eval over the whole test split (run label: phi3 greedy).\n",
        "# NOTE(review): the model evaluated is whatever `trainable_model` currently holds;\n",
        "# the name is only a label.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model, tokenizer=tokenizer, dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"phi3 greedy\",\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SbXDCT4_gUhc",
      "metadata": {
        "id": "SbXDCT4_gUhc"
      },
      "outputs": [],
      "source": [
        "# Greedy eval over the whole test split (run label: phi3 greedy).\n",
        "# NOTE(review): identical to the previous cell, including the run name.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    name=\"phi3 greedy\",\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sWQC9mJ-msED",
      "metadata": {
        "id": "sWQC9mJ-msED"
      },
      "outputs": [],
      "source": [
        "# Greedy eval with a 1024-token generation budget (run label: phi3 greedy 1024).\n",
        "# NOTE: this rebinds the global MAX_NEW_TOKENS, so later cells reading it get\n",
        "# 1024 until another cell resets it.\n",
        "MAX_NEW_TOKENS = 1024\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model, tokenizer=tokenizer, dataset=test_ds,\n",
        "    limit=eval_limit, batch_size=64, max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"phi3 greedy 1024\", use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "h9Juy1GI6_9U",
      "metadata": {
        "id": "h9Juy1GI6_9U"
      },
      "outputs": [],
      "source": [
        "# Greedy eval over the whole test split (run label: llama greedy baseline).\n",
        "# NOTE(review): evaluates whatever `trainable_model` currently holds.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    name=\"llama greedy baseline\",\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    batch_size=64,\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ypXVQeE_zA1V",
      "metadata": {
        "id": "ypXVQeE_zA1V"
      },
      "outputs": [],
      "source": [
        "# Optionally release the Colab VM once the evaluations above are finished.\n",
        "# Guarded by a flag: an unconditional unassign() disconnects the runtime during\n",
        "# \"Run all\", so the Best-of-N and coherence sections below would never run.\n",
        "# Source - https://stackoverflow.com/a\n",
        "# Posted by Ayo Reis\n",
        "# Retrieved 2025-12-30, License - CC BY-SA 4.0\n",
        "UNASSIGN_RUNTIME = False  # set True to disconnect and release the GPU at this point\n",
        "\n",
        "if UNASSIGN_RUNTIME:\n",
        "    from google.colab import runtime\n",
        "    runtime.unassign()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "BbHwE2vv1Xu5",
      "metadata": {
        "id": "BbHwE2vv1Xu5"
      },
      "source": [
        "### Best of N"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "C-BdL7ax1WXq",
      "metadata": {
        "id": "C-BdL7ax1WXq"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "from typing import Tuple\n",
        "\n",
        "def _mean_logprob_of_generated(\n",
        "    model, input_ids: torch.Tensor, attention_mask: torch.Tensor, input_len: int\n",
        ") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Compute mean log-prob over generated tokens for each sequence in the batch.\n",
        "    input_ids: [B, T] full prompt+generation\n",
        "    attention_mask: [B, T]\n",
        "    input_len: scalar length of the prompt (assumed same across the batch)\n",
        "    Returns: [B] mean log-probability per sequence (higher is better).\n",
        "    \"\"\"\n",
        "    # Forward once to score tokens\n",
        "    out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
        "    logits = out.logits  # [B, T, V]\n",
        "    logprobs = torch.log_softmax(logits, dim=-1)\n",
        "\n",
        "    # We want log p(y_t | x, y_<t>) for generated tokens.\n",
        "    # Token at position t (0-index) is predicted by logits at position t-1.\n",
        "    # So we gather from logprobs[:, :-1] using input_ids[:, 1:].\n",
        "    pred_logprobs = logprobs[:, :-1].gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # [B, T-1]\n",
        "\n",
        "    B, T = input_ids.shape\n",
        "    gen_start = input_len  # first generated token index in input_ids\n",
        "    # Sequence lengths may differ due to early EOS; rely on attention_mask to know actual used tokens.\n",
        "    # We'll mask everything before gen_start and after the last real token.\n",
        "    valid_mask = attention_mask[:, 1:]  # [B, T-1] tokens that exist (shifted)\n",
        "    gen_mask = torch.zeros_like(valid_mask)\n",
        "    gen_mask[:, gen_start:] = 1  # mark generated region (shifted space)\n",
        "\n",
        "    mask = (valid_mask * gen_mask).to(pred_logprobs.dtype)  # [B, T-1]\n",
        "    # Sum and average over generated tokens actually present\n",
        "    token_counts = mask.sum(dim=-1).clamp_min(1.0)  # avoid /0\n",
        "    seq_logprob_sum = (pred_logprobs * mask).sum(dim=-1)  # [B]\n",
        "    mean_logprob = seq_logprob_sum / token_counts  # [B]\n",
        "    return mean_logprob  # higher is better (less negative)\n",
        "\n",
        "\n",
        "def _pick_first_line(text: str) -> str:\n",
        "    return text.strip().split(\"\\n\")[0].strip()\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def batched_generate_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    questions: List[str],\n",
        "    best_of: int = 8,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Best-of-N sampling: for each question, draw `best_of` samples, score each by mean token log-prob,\n",
        "    and return the single highest-prob answer (first line) per question.\n",
        "\n",
        "    The scoring pass reuses the tokenizer's prompt attention mask, so prompt padding stays masked\n",
        "    exactly as during sampling. In the generated part every EOS position is masked, because EOS is\n",
        "    also passed as the pad token for finished samples.\n",
        "    \"\"\"\n",
        "    assert best_of >= 1\n",
        "    model_device = next(model.parameters()).device\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    results = []\n",
        "\n",
        "    for i in range(0, len(questions), batch_size):\n",
        "        batch_q = questions[i:i+batch_size]\n",
        "        prompts = [format_prompt(q) for q in batch_q]\n",
        "        enc = tokenizer(prompts, padding=True, return_tensors=\"pt\").to(model_device)\n",
        "\n",
        "        input_len = enc[\"input_ids\"].shape[1]\n",
        "        # Sample best_of sequences per prompt in a single call\n",
        "        gen_out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            do_sample=True,\n",
        "            temperature=temperature,\n",
        "            top_p=top_p,\n",
        "            num_return_sequences=best_of,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=eos_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "        # sequences: [B*best_of, input_len + gen_len]; the best_of samples of a prompt\n",
        "        # are contiguous (the .view(B, best_of, -1) below relies on this order)\n",
        "        seqs = gen_out.sequences\n",
        "        # Prompt part: the tokenizer's own mask, repeated per sample in the same order.\n",
        "        # Forcing it to all-ones would make the scorer attend to prompt padding.\n",
        "        prompt_attn = enc[\"attention_mask\"].repeat_interleave(best_of, dim=0)\n",
        "        # Generated part: EOS doubles as the pad token, so every EOS position is masked.\n",
        "        gen_attn = (seqs[:, input_len:] != eos_id).to(prompt_attn.dtype)\n",
        "        attn = torch.cat([prompt_attn, gen_attn], dim=1)\n",
        "\n",
        "        # Score each sample by mean log-prob over the generated continuation\n",
        "        mean_lp = _mean_logprob_of_generated(model, seqs, attn, input_len)  # [B*best_of]\n",
        "\n",
        "        # Group back per original example\n",
        "        B = len(batch_q)\n",
        "        seqs_per_q = seqs.view(B, best_of, -1)\n",
        "        mean_lp_per_q = mean_lp.view(B, best_of)\n",
        "\n",
        "        # Choose argmax per question\n",
        "        best_idx = torch.argmax(mean_lp_per_q, dim=1)  # [B]\n",
        "        # Decode just the generated part (then take first line)\n",
        "        for b in range(B):\n",
        "            chosen = seqs_per_q[b, best_idx[b].item(), :]\n",
        "            gen_only = chosen[input_len:]  # tokens after the prompt\n",
        "            text = tokenizer.decode(gen_only, skip_special_tokens=True)\n",
        "            results.append(_pick_first_line(text))\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def evaluate_model_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 1000,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    best_of: int = 8,\n",
        "    name: str = \"model-bestofN\",\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> Dict:\n",
        "    \"\"\"\n",
        "    Evaluate with Best-of-N sampling (highest mean token log-prob per question).\n",
        "    Returns the same dict schema as `evaluate_model`.\n",
        "\n",
        "    Args:\n",
        "      dataset: iterable of examples with \"question\" and \"answer\" fields.\n",
        "      limit: evaluate at most this many examples, in dataset order (<= 0 evaluates none).\n",
        "      temperature, top_p: sampling settings forwarded to `batched_generate_bestof`;\n",
        "        the defaults equal that function's defaults.\n",
        "\n",
        "    Returns:\n",
        "      dict with N, EM and F1 (percent), F1_CI (95% normal-approximation interval),\n",
        "      F1_CI_Level, preds and golds. With no examples the metrics are NaN.\n",
        "    \"\"\"\n",
        "    # Build eval slice; the limit is checked before appending so limit <= 0 takes nothing\n",
        "    qs, gold_list = [], []\n",
        "    for ex in dataset:\n",
        "        if len(qs) >= limit:\n",
        "            break\n",
        "        qs.append(ex[\"question\"])\n",
        "        gold_list.append(canonical_answers(ex[\"answer\"]))\n",
        "\n",
        "    preds = batched_generate_bestof(\n",
        "        model,\n",
        "        tokenizer,\n",
        "        qs,\n",
        "        best_of=best_of,\n",
        "        batch_size=batch_size,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "    )\n",
        "    assert len(preds) == len(gold_list)\n",
        "\n",
        "    ems, f1s = [], []\n",
        "    for p, g in zip(preds, gold_list):\n",
        "        em, f1 = em_and_f1(p, g)\n",
        "        ems.append(em); f1s.append(f1)\n",
        "\n",
        "    n = len(f1s)\n",
        "    if n == 0:\n",
        "        # np.mean([]) would warn and yield NaN anyway; make the empty case explicit.\n",
        "        em_mean = f1_mean = f1_lo = f1_hi = float(\"nan\")\n",
        "    else:\n",
        "        em_mean = float(np.mean(ems)) * 100.0\n",
        "        f1_mean_raw = float(np.mean(f1s))\n",
        "        f1_mean = f1_mean_raw * 100.0\n",
        "        # 95% normal-approx CI for F1 (collapses to the point estimate for one example)\n",
        "        if n > 1:\n",
        "            z = 1.96\n",
        "            se = float(np.std(f1s, ddof=1)) / np.sqrt(n)\n",
        "            f1_lo = max(0.0, (f1_mean_raw - z * se) * 100.0)\n",
        "            f1_hi = min(100.0, (f1_mean_raw + z * se) * 100.0)\n",
        "        else:\n",
        "            f1_lo = f1_hi = f1_mean\n",
        "\n",
        "    print(f\"[{name}] N={len(preds)}  Best-of={best_of}  EM: {em_mean:.2f}  F1: {f1_mean:.2f}  \"\n",
        "          f\"F1 95% CI: [{f1_lo:.2f}, {f1_hi:.2f}]\")\n",
        "\n",
        "    return {\n",
        "        \"N\": len(preds),\n",
        "        \"EM\": em_mean,\n",
        "        \"F1\": f1_mean,\n",
        "        \"F1_CI\": (f1_lo, f1_hi),\n",
        "        \"F1_CI_Level\": 0.95,\n",
        "        \"preds\": preds,\n",
        "        \"golds\": gold_list,\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GHtbvCx3rK87",
      "metadata": {
        "id": "GHtbvCx3rK87"
      },
      "outputs": [],
      "source": [
        "# Decoding settings for the greedy baseline evaluation\n",
        "MAX_NEW_TOKENS = 24\n",
        "BATCH_SIZE = 64\n",
        "USE_GREEDY = True\n",
        "\n",
        "# Evaluate on the entire validation split\n",
        "EVAL_LIMIT = len(val_ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "el9lZDDl1hSH",
      "metadata": {
        "id": "el9lZDDl1hSH"
      },
      "outputs": [],
      "source": [
        "# Greedy decoding with the baseline model on the validation split\n",
        "res_base_greedy = evaluate_model(\n",
        "    baseline_model,\n",
        "    tokenizer,\n",
        "    val_ds,\n",
        "    limit=EVAL_LIMIT,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"baseline_greedy\",\n",
        ")\n",
        "\n",
        "# Spot-check the first three predictions against their gold answers\n",
        "for i in range(3):\n",
        "    example = val_ds[i]\n",
        "    print(\"\\nQ:\", example[\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(example[\"answer\"])[:5])\n",
        "    print(\"Baseline Greedy:\", res_base_greedy[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "RQTjmepQ-jiM",
      "metadata": {
        "id": "RQTjmepQ-jiM"
      },
      "outputs": [],
      "source": [
        "# Best-of-N sampling with the baseline model (N = BEST_OF samples per question).\n",
        "# Small batch: each prompt expands into BEST_OF sampled sequences.\n",
        "BATCH_SIZE = 2\n",
        "BEST_OF = 24\n",
        "res_base_bestof = evaluate_model_bestof(\n",
        "    baseline_model,\n",
        "    tokenizer,\n",
        "    val_ds,\n",
        "    limit=EVAL_LIMIT,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    best_of=BEST_OF,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=f\"baseline_bestof{BEST_OF}\",\n",
        ")\n",
        "\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline Best-of-N:\", res_base_bestof[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rWY0OHvH1xLP",
      "metadata": {
        "id": "rWY0OHvH1xLP"
      },
      "outputs": [],
      "source": [
        "# Greedy decoding with the trained model on the full validation split.\n",
        "# evaluate_model is not passed a decoding mode, so it presumably reads the global\n",
        "# USE_GREEDY (the sampling cell below sets it to False); set it here so re-running\n",
        "# this cell is always greedy. TODO confirm against evaluate_model.\n",
        "USE_GREEDY = True\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "\n",
        "trainable_model.eval()  # inference mode for decoding\n",
        "res_trained_greedy = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_greedy\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained Greedy:\", res_trained_greedy[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VnX-WgxQ-BXZ",
      "metadata": {
        "id": "VnX-WgxQ-BXZ"
      },
      "outputs": [],
      "source": [
        "# Best-of-24 sampling with the trained model; small batches because each\n",
        "# prompt expands into 24 sampled sequences\n",
        "BATCH_SIZE = 2\n",
        "res_trained_bestof = evaluate_model_bestof(\n",
        "    trainable_model, tokenizer, val_ds,\n",
        "    limit=EVAL_LIMIT,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    best_of=24,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"trainable_bestof24\",\n",
        ")\n",
        "\n",
        "# Show the first five predictions next to their gold answers\n",
        "for i in range(5):\n",
        "    example = val_ds[i]\n",
        "    print(\"\\nQ:\", example[\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(example[\"answer\"])[:5])\n",
        "    print(\"Trained Best-of-N:\", res_trained_bestof[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Pb94eMijnT2q",
      "metadata": {
        "id": "Pb94eMijnT2q"
      },
      "outputs": [],
      "source": [
        "# Sampling (non-greedy) decoding with the trained model.\n",
        "# evaluate_model is not passed a decoding mode; it presumably reads the global\n",
        "# USE_GREEDY set here. TODO confirm against evaluate_model.\n",
        "USE_GREEDY = False\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "\n",
        "trainable_model.eval()\n",
        "# Stored under its own name so the greedy results in res_trained_greedy are not overwritten.\n",
        "res_trained_sampling = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained sampling:\", res_trained_sampling[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Kxghp_lE_1sp",
      "metadata": {
        "id": "Kxghp_lE_1sp"
      },
      "outputs": [],
      "source": [
        "# Sampling (non-greedy) decoding with the baseline model, for comparison.\n",
        "# USE_GREEDY is set explicitly (evaluate_model presumably reads the global) so this\n",
        "# cell does not depend on the previous cell having run. TODO confirm.\n",
        "USE_GREEDY = False\n",
        "# Stored as res_base_sampling: reusing res_trained_greedy would clobber the trained\n",
        "# model's results with baseline results.\n",
        "res_base_sampling = evaluate_model(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline sampling:\", res_base_sampling[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "uO0cn0IUP6UH",
      "metadata": {
        "id": "uO0cn0IUP6UH"
      },
      "source": [
        "## Check Coherence"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5DGKZyvoRTeJ",
      "metadata": {
        "id": "5DGKZyvoRTeJ"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import torch\n",
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "\n",
        "def _class_agrees(preds_for_class):\n",
        "    # If every prediction is NA, count as agreement (consistent failure)\n",
        "    ps = [(p or \"\").strip() for p in preds_for_class]\n",
        "    if all(p == \"NA\" for p in ps):\n",
        "        return True\n",
        "\n",
        "    vals = [p for p in ps if p != \"NA\"]\n",
        "    if len(vals) < 2:\n",
        "        return False\n",
        "    return len(set(vals)) == 1\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of class samples to print total\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Evaluate agreement across paraphrases.\n",
        "\n",
        "    Args:\n",
        "      dataset: list-like of paraphrase classes, each is List[str]\n",
        "               e.g., classes_usable where each element is [orig, para1, para2, ...]\n",
        "      limit: number of classes to evaluate\n",
        "      batch_size: generation batch size (over paraphrase questions)\n",
        "      max_new_tokens: passed to answer_batch_chat\n",
        "      use_greedy: greedy vs sampling\n",
        "\n",
        "    Prints:\n",
        "      [{name}] N={num_classes}  AGREE: {mean:.2f} and AGREE confidence [lo, hi]\n",
        "\n",
        "    Returns:\n",
        "      {\n",
        "        \"preds\": per_class_preds,   # List[List[str]] answers for each class item\n",
        "        \"sample\": {\"indices\": sampled_indices, \"questions\": qs, \"golds\": None},\n",
        "        \"N\": num_classes,\n",
        "        \"AGREE\": agree_mean,\n",
        "        \"AGREE_CI\": (lo, hi),\n",
        "        \"AGREE_CI_Level\": 0.95,\n",
        "      }\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []  # store the paraphrases per class for debugging\n",
        "\n",
        "        num_batches = (n_classes + batch_size - 1) // batch_size\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for idx in sampled_indices:\n",
        "            cls = dataset[idx]  # List[str] paraphrases (includes original)\n",
        "            qs_sample.append(cls)\n",
        "\n",
        "            # Generate in chunks if class is large (so we respect batch_size)\n",
        "            cls_preds: List[str] = []\n",
        "            for start in range(0, len(cls), batch_size):\n",
        "                sub_qs = cls[start:start + batch_size]\n",
        "                sub_preds = answer_batch_chat(\n",
        "                    sub_qs,\n",
        "                    model=model,\n",
        "                    tokenizer=tokenizer,\n",
        "                    max_new_tokens=max_new_tokens,\n",
        "                    use_greedy=use_greedy,\n",
        "                )\n",
        "                cls_preds.extend(sub_preds)\n",
        "\n",
        "            per_class_preds.append(cls_preds)\n",
        "\n",
        "            agrees = _class_agrees(cls_preds)\n",
        "            if agrees:\n",
        "                agree_yes += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(1)\n",
        "\n",
        "            # Optional: print a few class samples total\n",
        "            if show_samples > 0 and printed < show_samples:\n",
        "                pbar.write(f\"[{name}] class_idx={idx} size={len(cls)} agrees={agrees}\")\n",
        "                # show first 2 questions and all preds\n",
        "                for j in range(min(2, len(cls))):\n",
        "                    q_preview = cls[j].replace(\"\\n\", \" \")\n",
        "                    if len(q_preview) > 160:\n",
        "                        q_preview = q_preview[:160] + \"...\"\n",
        "                    pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                pbar.write(f\"  preds: {cls_preds}\")\n",
        "                printed += 1\n",
        "\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sa-QOGDcEJhd",
      "metadata": {
        "id": "sa-QOGDcEJhd"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "import torch\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,          # number of CLASSES per batch (not prompts)\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Batched agreement eval when each class has ~2 items (orig + 1 paraphrase).\n",
        "\n",
        "    batch_size here means \"number of classes per batch\".\n",
        "    Each batch does 1 generation call over ~2*batch_size prompts.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []\n",
        "\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for start in range(0, n_classes, batch_size):\n",
        "            end = min(start + batch_size, n_classes)\n",
        "            cls_batch = [dataset[i] for i in range(start, end)]  # list of classes\n",
        "\n",
        "            # Record questions for the sample output\n",
        "            qs_sample.extend(cls_batch)\n",
        "\n",
        "            # Flatten prompts: [c0_q0, c0_q1, c1_q0, c1_q1, ...]\n",
        "            flat_prompts: List[str] = []\n",
        "            lengths: List[int] = []\n",
        "            for cls in cls_batch:\n",
        "                # cls is [orig, para] (or sometimes only [orig])\n",
        "                lengths.append(len(cls))\n",
        "                flat_prompts.extend(cls)\n",
        "\n",
        "            # One big batched generation call\n",
        "            flat_preds = answer_batch_chat(\n",
        "                flat_prompts,\n",
        "                model=model,\n",
        "                tokenizer=tokenizer,\n",
        "                max_new_tokens=max_new_tokens,\n",
        "                use_greedy=use_greedy,\n",
        "            )\n",
        "\n",
        "            # Unflatten back into per-class preds\n",
        "            offset = 0\n",
        "            for bi, L in enumerate(lengths):\n",
        "                cls_preds = flat_preds[offset:offset + L]\n",
        "                offset += L\n",
        "\n",
        "                per_class_preds.append(cls_preds)\n",
        "\n",
        "                agrees = _class_agrees(cls_preds)\n",
        "                if agrees:\n",
        "                    agree_yes += 1\n",
        "\n",
        "                # Optional sample prints (total across run)\n",
        "                if show_samples > 0 and printed < show_samples:\n",
        "                    class_idx = start + bi\n",
        "                    pbar.write(f\"[{name}] class_idx={class_idx} size={L} agrees={agrees}\")\n",
        "                    for j in range(min(2, L)):\n",
        "                        q_preview = cls_batch[bi][j].replace(\"\\n\", \" \")\n",
        "                        if len(q_preview) > 160:\n",
        "                            q_preview = q_preview[:160] + \"...\"\n",
        "                        pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                    pbar.write(f\"  preds: {cls_preds}\")\n",
        "                    printed += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(end - start)\n",
        "\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "pnyq4rJrRjAe",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 243,
          "referenced_widgets": [
            "b6c0617396d44837ac2bc00893b812d1",
            "39b9ffa02eb84f85a76156cbc1c8a88b",
            "7f8b51a3fe4343f39fa6ddc435410a3a",
            "3a3884b6b2c045c2ae33bba4c0c462b4",
            "51096091a9594b3cbba9e1f06c439f65",
            "346936dae8d0460b88f574168fea2209",
            "a7531cf6b513470da02333b1420e5c9c",
            "da8588237dec4ae9877dc1ebb1236e48",
            "59f2130e9c564172b17f10efa43f72d6",
            "3345258ab7fd4dc085c81ca3eacb35dd",
            "441473c91daf4aeea3a91dc460512b2a"
          ]
        },
        "id": "pnyq4rJrRjAe",
        "outputId": "bb932ce5-2731-471c-d6c8-3f2fd568ee2e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b6c0617396d44837ac2bc00893b812d1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Paraphrase agreement qwen3b:   0%|          | 0/7473 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[qwen3b] class_idx=0 size=2 agrees=True\n",
            "  q0: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n",
            "  q1: In April, Natalia sold clips to 48 of her friends, and in May she sold half as many. How many clips did Natalia sell in total between April and May?\n",
            "  preds: ['72', '72']\n",
            "[qwen3b] class_idx=1 size=2 agrees=True\n",
            "  q0: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\n",
            "  q1: Weng earns $12 per hour from babysitting. She worked for 50 minutes yesterday. How much did she earn?\n",
            "  preds: ['10', '10']\n",
            "[qwen3b] N=7473  AGREE: 67.14 and AGREE confidence [66.06, 68.19]\n"
          ]
        }
      ],
      "source": [
        "coherence_eval_limit = len(paraphrase_classes)\n",
        "\n",
        "# paraphrase_classes: list of paraphrase classes built earlier, each [original, paraphrase, ...]\n",
        "# batch_size counts classes per generation call (~2 prompts per class here).\n",
        "agree_res = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=paraphrase_classes,\n",
        "    limit=coherence_eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=256,\n",
        "    name=\"qwen3b\",\n",
        "    use_greedy=True,\n",
        "    show_samples=2,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "XUChbdyqVsLQ",
      "metadata": {
        "id": "XUChbdyqVsLQ"
      },
      "source": [
        "## GRPO Implementation"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3HnZFHkuHZVx",
      "metadata": {
        "id": "3HnZFHkuHZVx"
      },
      "source": [
        "### GRPO Main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dx7osq9fVvm1",
      "metadata": {
        "id": "dx7osq9fVvm1"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "from torch.utils.data import IterableDataset\n",
        "\n",
        "class ParaphrasePairIterableFormatted(IterableDataset):\n",
        "    \"\"\"\n",
        "    Infinite stream. Each \"pair\" yields exactly two consecutive items:\n",
        "      (q0, is_original=1) then (q1, is_original=0)\n",
        "    \"\"\"\n",
        "    def __init__(self, classes_usable, tokenizer, format_prompt_instruct, system_prompt, seed=0):\n",
        "        self.classes = classes_usable          # list[list[str]]; cls[0] is original\n",
        "        self.tok = tokenizer\n",
        "        self.format_prompt_instruct = format_prompt_instruct\n",
        "        self.system_prompt = system_prompt\n",
        "        self.rng = random.Random(seed)\n",
        "\n",
        "    def __iter__(self):\n",
        "        while True:\n",
        "            cls = self.rng.choice(self.classes)\n",
        "            q0 = cls[0]\n",
        "            q1 = self.rng.choice(cls[1:]) if len(cls) > 1 else q0\n",
        "            pair_id = self.rng.randrange(2**63)\n",
        "\n",
        "            p0 = self.format_prompt_instruct(q0, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "            p1 = self.format_prompt_instruct(q1, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "\n",
        "            yield {\"prompt\": p0, \"is_original\": 1, \"pair_id\": pair_id}\n",
        "            yield {\"prompt\": p1, \"is_original\": 0, \"pair_id\": pair_id}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1snZI84RW7ml",
      "metadata": {
        "id": "1snZI84RW7ml"
      },
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "'''\n",
        "def agreement_reward_fn(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def agreement_reward_(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "'''\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def pair_agreement_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    # pair_id and is_original should be aligned with (prompts, completions)\n",
        "    # TRL typically repeats prompt-level columns across generations.\n",
        "\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        a0 = [answers[i] for i in idx0 if answers[i] != \"NA\"]\n",
        "        a1 = [answers[i] for i in idx1 if answers[i] != \"NA\"]\n",
        "\n",
        "        # Only reward if BOTH sides have at least one non-NA answer\n",
        "        if len(a0) == 0 or len(a1) == 0:\n",
        "            continue\n",
        "\n",
        "        c0, c1 = Counter(a0), Counter(a1)\n",
        "        d0, d1 = len(a0), len(a1)\n",
        "\n",
        "        # Reward each completion by \"how often the other paraphrase produced the same answer\"\n",
        "        for i in idx0:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c1.get(a, 0) / d1\n",
        "        for i in idx1:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c0.get(a, 0) / d0\n",
        "\n",
        "    return rewards"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sj52WrVt5JSe",
      "metadata": {
        "id": "sj52WrVt5JSe"
      },
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "def pair_majority_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    \"\"\"\n",
        "    Reward = 1.0 iff the completion's extracted answer is in the (tied) majority vote\n",
        "    within its side (original vs paraphrase) for the same pair_id. Otherwise 0.0.\n",
        "\n",
        "    Separation between pairs: majority is computed independently per pair_id,\n",
        "    and also independently per side (is_original==1 vs ==0) within each pair.\n",
        "    \"\"\"\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # Extract answers (same as before)\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "\n",
        "    # Group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        # Split indices by side\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        # Majority vote within each side (including \"NA\", matching your current behavior)\n",
        "        for side_idxs in (idx0, idx1):\n",
        "            if not side_idxs:\n",
        "                continue\n",
        "\n",
        "            side_answers = [answers[i] for i in side_idxs]\n",
        "            counts = Counter(side_answers)\n",
        "            max_count = max(counts.values())\n",
        "            majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "            for i in side_idxs:\n",
        "                rewards[i] = 1.0 if answers[i] in majority_set else 0.0\n",
        "\n",
        "            # ensure NA stays 0\n",
        "            for i in side_idxs:\n",
        "                if answers[i] == \"NA\":\n",
        "                    rewards[i] = 0.0\n",
        "\n",
        "    return rewards\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "xkJgG7HfW9Vn",
      "metadata": {
        "id": "xkJgG7HfW9Vn"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "from collections.abc import Mapping\n",
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "\n",
        "class OrigOnlyKLGRPOTrainer(GRPOTrainer):\n",
        "    def _prepare_inputs(self, inputs):\n",
        "        \"\"\"\n",
        "        Run GRPOTrainer's preparation, then attach prompt-level pair metadata.\n",
        "\n",
        "        Reads \"is_original\" and \"pair_id\" from each example dict in `inputs`,\n",
        "        calls super()._prepare_inputs(inputs), and stores two long tensors of\n",
        "        length B_prompts = prepared[\"prompt_ids\"].shape[0] as\n",
        "        prepared[\"is_original_prompt\"] and prepared[\"pair_id_prompt\"].\n",
        "        They stay at prompt level here; compute_loss expands them to rollout rows.\n",
        "\n",
        "        Aligning a metadata vector of length n to B_prompts:\n",
        "          - n == B_prompts: used as-is.\n",
        "          - n == 1: broadcast to B_prompts.\n",
        "          - n == B_prompts * K: one entry per prompt is kept. The layout is chosen\n",
        "            jointly from BOTH vectors: \"grouped\" ([p0]*K + [p1]*K + ...) when every\n",
        "            length-K run is constant in both, otherwise \"interleaved\" ([p0..pB-1]\n",
        "            repeated K times) when both are K copies of their first B_prompts\n",
        "            entries. One vector alone can fit both layouts -- e.g. is_original\n",
        "            [1,1,0,0,1,1,0,0] with B_prompts=4, K=2 reads as [1,0,1,0] (grouped)\n",
        "            or [1,1,0,0] (interleaved) -- and pair_id, which normally differs\n",
        "            between pairs, breaks the tie.\n",
        "\n",
        "        NOTE(review): assumes the rows of `prepared` follow the order of `inputs`.\n",
        "        If the installed TRL buffers, shuffles or splits the generation batch\n",
        "        inside _prepare_inputs, this metadata will not line up with `prepared`\n",
        "        rows -- confirm against your TRL version.\n",
        "\n",
        "        Raises:\n",
        "          ValueError: if the metadata cannot be aligned to B_prompts.\n",
        "        \"\"\"\n",
        "        # inputs is list[dict]\n",
        "        is_orig_list = [ex[\"is_original\"] for ex in inputs]\n",
        "        pair_id_list = [ex[\"pair_id\"] for ex in inputs]\n",
        "\n",
        "        prepared = super()._prepare_inputs(inputs)\n",
        "\n",
        "        device = prepared[\"prompt_ids\"].device\n",
        "        B_prompts = prepared[\"prompt_ids\"].shape[0]   # prompt-level count (usually 2*PAIRS_PER_STEP)\n",
        "\n",
        "        is_orig_t = torch.as_tensor(is_orig_list, device=device).view(-1).long()\n",
        "        pair_id_t = torch.as_tensor(pair_id_list, device=device).view(-1).long()\n",
        "\n",
        "        def grouped_fits(t, K):\n",
        "            # [p0 x K, p1 x K, ...]: every row of the (B_prompts, K) view is constant.\n",
        "            rows = t.view(B_prompts, K)\n",
        "            return bool((rows == rows[:, :1]).all())\n",
        "\n",
        "        def interleaved_fits(t, K):\n",
        "            # [p0..pB-1] repeated K times: every row of the (K, B_prompts) view equals row 0.\n",
        "            rows = t.view(K, B_prompts)\n",
        "            return bool((rows == rows[:1]).all())\n",
        "\n",
        "        def align_to_B(t):\n",
        "            \"\"\"Align one metadata vector to length B_prompts.\"\"\"\n",
        "            n = t.numel()\n",
        "            if n == B_prompts:\n",
        "                return t\n",
        "\n",
        "            if n > B_prompts:\n",
        "                # already expanded to rollout level -> compress to one entry per prompt\n",
        "                if n % B_prompts != 0:\n",
        "                    raise ValueError(f\"Cannot compress metadata: got {n}, need multiple of {B_prompts}\")\n",
        "                K = n // B_prompts\n",
        "                if grouped_fits(is_orig_t, K) and grouped_fits(pair_id_t, K):\n",
        "                    return t[::K]\n",
        "                if interleaved_fits(is_orig_t, K) and interleaved_fits(pair_id_t, K):\n",
        "                    return t[:B_prompts]\n",
        "                raise ValueError(\n",
        "                    f\"Cannot compress metadata: {n} rows for {B_prompts} prompts are neither \"\n",
        "                    \"grouped per prompt nor interleaved across prompts\"\n",
        "                )\n",
        "\n",
        "            # (rare) scalar -> broadcast\n",
        "            if n == 1:\n",
        "                return t.repeat(B_prompts)\n",
        "\n",
        "            raise ValueError(f\"Cannot align metadata: got {n}, need {B_prompts}\")\n",
        "\n",
        "        prepared[\"is_original_prompt\"] = align_to_B(is_orig_t)\n",
        "        prepared[\"pair_id_prompt\"] = align_to_B(pair_id_t)\n",
        "\n",
        "        # keep prompt-level values here; expand later in compute_loss to match rollouts\n",
        "        return prepared\n",
        "\n",
        "    def _get_logps(self, model, input_ids, attention_mask, logits_to_keep):\n",
        "        \"\"\"\n",
        "        Return per-token log-probs of the completion tokens across TRL versions.\n",
        "\n",
        "        Uses self._get_per_token_logps when it exists; otherwise calls\n",
        "        self._get_per_token_logps_and_entropies. The optional keyword arguments\n",
        "        of the latter may differ between TRL releases, so only those the installed\n",
        "        method accepts are forwarded (entropy off, no image inputs); passing an\n",
        "        unsupported keyword such as num_images would raise TypeError.\n",
        "\n",
        "        Returns:\n",
        "          The log-prob tensor: out[0] for a tuple result, out[\"logps\"] for a dict.\n",
        "        \"\"\"\n",
        "        # TRL versions differ: some have _get_per_token_logps, newer use _get_per_token_logps_and_entropies\n",
        "        if hasattr(self, \"_get_per_token_logps\"):\n",
        "            return self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)\n",
        "\n",
        "        import inspect\n",
        "\n",
        "        fn = self._get_per_token_logps_and_entropies\n",
        "        optional_kwargs = {\n",
        "            \"compute_entropy\": False,\n",
        "            \"pixel_values\": None,\n",
        "            \"image_grid_thw\": None,\n",
        "            \"num_images\": None,\n",
        "        }\n",
        "        try:\n",
        "            params = inspect.signature(fn).parameters\n",
        "        except (TypeError, ValueError):\n",
        "            params = None  # signature unavailable: forward all optional kwargs\n",
        "        if params is not None and not any(\n",
        "            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()\n",
        "        ):\n",
        "            optional_kwargs = {k: v for k, v in optional_kwargs.items() if k in params}\n",
        "\n",
        "        out = fn(model, input_ids, attention_mask, logits_to_keep, **optional_kwargs)\n",
        "        # out can be (logps, entropies) or {\"logps\":..., \"entropies\":...}\n",
        "        if isinstance(out, tuple):\n",
        "            return out[0]\n",
        "        return out[\"logps\"]\n",
        "\n",
        "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
        "        import torch\n",
        "\n",
        "        if return_outputs:\n",
        "            raise ValueError(\"GRPOTrainer does not support return_outputs=True\")\n",
        "\n",
        "        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n",
        "        completion_ids, completion_mask = inputs[\"completion_ids\"], inputs[\"completion_mask\"]\n",
        "\n",
        "        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n",
        "        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n",
        "\n",
        "        logits_to_keep = completion_ids.size(1)\n",
        "        per_token_logps = self._get_logps(model, input_ids, attention_mask, logits_to_keep)  # [N, T]\n",
        "\n",
        "        # ---- dimensions ----\n",
        "        N = per_token_logps.shape[0]  # rollout-level rows\n",
        "        denom = completion_mask.sum(dim=1).clamp_min(1)  # [N]\n",
        "\n",
        "        # ---- advantages: align to N ----\n",
        "        advantages = inputs[\"advantages\"].view(-1)  # could be [N] or [B]\n",
        "        if advantages.numel() != N:\n",
        "            if N % advantages.numel() == 0:\n",
        "                advantages = advantages.repeat_interleave(N // advantages.numel())\n",
        "            else:\n",
        "                raise ValueError(f\"advantages has {advantages.numel()} elems but need {N}\")\n",
        "        advantages = advantages.to(per_token_logps.device)\n",
        "\n",
        "        # ---- PPO/GRPO ratio term ----\n",
        "        if \"old_per_token_logps\" in inputs:\n",
        "            ratio = torch.exp(per_token_logps - inputs[\"old_per_token_logps\"])\n",
        "        else:\n",
        "            ratio = torch.exp(per_token_logps - per_token_logps.detach())\n",
        "\n",
        "        per_token_pg = ratio * advantages.unsqueeze(1)  # [N, T]\n",
        "\n",
        "        beta = float(getattr(self, \"beta\", getattr(self.args, \"beta\", 0.0)))\n",
        "\n",
        "        # ---- KL term (masked to originals only) ----\n",
        "        is_orig = None\n",
        "        per_token_kl = 0.0\n",
        "\n",
        "        if beta != 0.0:\n",
        "            ref_per_token_logps = inputs[\"ref_per_token_logps\"]\n",
        "            per_token_kl_raw = (\n",
        "                torch.exp(ref_per_token_logps - per_token_logps)\n",
        "                - (ref_per_token_logps - per_token_logps)\n",
        "                - 1\n",
        "            )  # [N, T]\n",
        "\n",
        "            # prompt-level mask produced in _prepare_inputs\n",
        "            is_orig_prompt = inputs.get(\"is_original_prompt\", None)\n",
        "            if is_orig_prompt is None:\n",
        "                raise ValueError(\"Missing inputs['is_original_prompt'] from _prepare_inputs().\")\n",
        "\n",
        "            is_orig_prompt = torch.as_tensor(is_orig_prompt, device=per_token_logps.device).float().view(-1)  # [B]\n",
        "            B = is_orig_prompt.numel()\n",
        "\n",
        "            if N % B != 0:\n",
        "                raise ValueError(f\"Cannot expand is_original_prompt: N={N}, B={B}\")\n",
        "            K = N // B\n",
        "\n",
        "            # Detect whether rollout rows are interleaved-by-prompt (first B rows are all different prompts)\n",
        "            def _is_interleaved_by_prompt(prompt_ids, B):\n",
        "                reps = []\n",
        "                for i in range(min(B, prompt_ids.size(0))):\n",
        "                    row = prompt_ids[i]\n",
        "                    if not any(torch.equal(row, r) for r in reps):\n",
        "                        reps.append(row)\n",
        "                return len(reps) == B\n",
        "\n",
        "            interleaved = _is_interleaved_by_prompt(prompt_ids, B)\n",
        "\n",
        "            if interleaved:\n",
        "                # [p0,p1,...,pB-1, p0,p1,...] repeated K times\n",
        "                is_orig = is_orig_prompt.repeat(K)  # [N]\n",
        "            else:\n",
        "                # [p0...p0, p1...p1, ...] each repeated K times\n",
        "                is_orig = is_orig_prompt.repeat_interleave(K)  # [N]\n",
        "\n",
        "            # Mask KL to originals only\n",
        "            per_token_kl = per_token_kl_raw * is_orig.unsqueeze(1)  # [N, T]\n",
        "\n",
        "            # ---- logging: masked KL only (as requested) ----\n",
        "            with torch.no_grad():\n",
        "                # prompt-level counts\n",
        "                is_p = is_orig_prompt.float().view(-1)\n",
        "                p_n1 = int((is_p == 1).sum().item())\n",
        "                p_n0 = int((is_p == 0).sum().item())\n",
        "\n",
        "                # rollout-level counts (after expansion)\n",
        "                r_n1 = int((is_orig == 1).sum().item())\n",
        "                r_n0 = int((is_orig == 0).sum().item())\n",
        "\n",
        "            self.log({\n",
        "                \"dbg/prompt_n1\": p_n1,\n",
        "                \"dbg/prompt_n0\": p_n0,\n",
        "                \"dbg/rollout_n1\": r_n1,\n",
        "                \"dbg/rollout_n0\": r_n0,\n",
        "            })\n",
        "            with torch.no_grad():\n",
        "                # mean KL over completion tokens for each sample\n",
        "                kl_per_sample = (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)\n",
        "\n",
        "                # raw mean KL (all samples)\n",
        "                kl_all = kl_per_sample.mean()\n",
        "\n",
        "                # orig-only mean KL (only samples with is_orig==1)\n",
        "                orig_mask = is_orig.bool()\n",
        "                kl_orig = kl_per_sample[orig_mask].mean() if orig_mask.any() else torch.tensor(0.0, device=kl_all.device)\n",
        "\n",
        "            self.log({\n",
        "                \"kl_all\": kl_all.detach().float().cpu().item(),\n",
        "                \"kl_orig_only\": kl_orig.detach().float().cpu().item(),\n",
        "            })\n",
        "\n",
        "        # ---- final loss ----\n",
        "        per_token_loss = -(per_token_pg - beta * per_token_kl)  # [N, T]\n",
        "        loss = ((per_token_loss * completion_mask).sum(dim=1) / denom).mean()\n",
        "\n",
        "        return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "i5QbPheK3EXX",
      "metadata": {
        "id": "i5QbPheK3EXX"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class GSM8KEvalCallback(TrainerCallback):\n",
        "    \"\"\"Print cached train stats and run an eval every `log_every` global steps.\n",
        "\n",
        "    The eval is triggered from `on_log`, which fires once per `trainer.log(...)`\n",
        "    call. Two guards keep it to a single run per qualifying step:\n",
        "      * `_last_eval_step` -- several `trainer.log` calls can share one global\n",
        "        step (e.g. a `compute_loss` that calls `self.log`), and each would\n",
        "        otherwise start another full eval.\n",
        "      * `_logging_eval` -- logging the eval metrics re-enters `on_log` at the\n",
        "        same step; without this flag that would recurse into eval indefinitely.\n",
        "\n",
        "    Args:\n",
        "        trainer: trainer whose `.model` is evaluated and whose `.log` receives metrics.\n",
        "        eval_fn: called as `eval_fn(model=..., **eval_kwargs)`; must return a dict.\n",
        "        log_every: positive global-step interval between evals.\n",
        "        eval_kwargs: extra keyword arguments forwarded to `eval_fn`.\n",
        "        reward_key, kl_key: log keys whose latest values are cached and printed.\n",
        "            NOTE(review): the default kl_key 'kl_all_raw' is not one of the keys the\n",
        "            compute_loss in this notebook logs ('kl_all', 'kl_orig_only'); pass it explicitly.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `log_every` is not positive.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, trainer, *, eval_fn, log_every: int, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_all_raw\"):\n",
        "        if log_every <= 0:\n",
        "            raise ValueError(f\"log_every must be a positive int, got {log_every!r}\")\n",
        "        self.trainer = trainer\n",
        "        self.eval_fn = eval_fn\n",
        "        self.log_every = log_every\n",
        "        self.eval_kwargs = eval_kwargs\n",
        "        self.reward_key = reward_key\n",
        "        self.kl_key = kl_key\n",
        "        self._last_reward = None\n",
        "        self._last_kl = None\n",
        "        self._last_eval_step = -1     # no eval has run yet\n",
        "        self._logging_eval = False    # True while we emit our own eval log\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        # Ignore the re-entrant call caused by our own trainer.log(...) below.\n",
        "        if self._logging_eval:\n",
        "            return control\n",
        "        logs = logs or {}\n",
        "        # cache latest values when they appear\n",
        "        if self.reward_key in logs:\n",
        "            self._last_reward = logs[self.reward_key]\n",
        "        if self.kl_key in logs:\n",
        "            self._last_kl = logs[self.kl_key]\n",
        "\n",
        "        step = int(state.global_step)\n",
        "        if step == 0 or (step % self.log_every) != 0 or step == self._last_eval_step:\n",
        "            return control\n",
        "        self._last_eval_step = step\n",
        "\n",
        "        # main process only\n",
        "        if hasattr(self.trainer, \"is_world_process_zero\") and not self.trainer.is_world_process_zero():\n",
        "            return control\n",
        "\n",
        "        def fmt(x):\n",
        "            # non-numeric values (e.g. strings) would break the :.6g format spec\n",
        "            if x is None:\n",
        "                return \"NA\"\n",
        "            return f\"{x:.6g}\" if isinstance(x, (float, int)) else str(x)\n",
        "\n",
        "        print(f\"[step {step}] {self.reward_key}={fmt(self._last_reward)}  {self.kl_key}={fmt(self._last_kl)}\")\n",
        "\n",
        "        # run eval; restore train mode even if eval_fn raises\n",
        "        model = self.trainer.model\n",
        "        was_training = model.training\n",
        "        model.eval()\n",
        "        try:\n",
        "            with torch.no_grad():\n",
        "                metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        finally:\n",
        "            if was_training:\n",
        "                model.train()\n",
        "\n",
        "        # trainer.log dispatches on_log again at this same step; flag it so we skip it.\n",
        "        self._logging_eval = True\n",
        "        try:\n",
        "            self.trainer.log({f\"gsm8k/{k}\": v for k, v in metrics.items()})\n",
        "        finally:\n",
        "            self._logging_eval = False\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51XzkxBxBJNa",
      "metadata": {
        "id": "51XzkxBxBJNa"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class PrintAndEvalOnSave(TrainerCallback):\n",
        "    \"\"\"On every checkpoint save, print the latest training stats and run one eval.\n",
        "\n",
        "    `on_log` remembers the most recent value seen for each watched key\n",
        "    (reward, KL, and the rollout n1/n0 counters). `on_save` prints those values,\n",
        "    calls `eval_fn(model=trainer.model, **eval_kwargs)` under `torch.no_grad()`\n",
        "    with the model in eval mode, and logs the returned metrics as `gsm8k/<name>`.\n",
        "    Each global step is evaluated at most once, and only on the main process.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, trainer, *, eval_fn, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_orig_only\",\n",
        "                 n1_key=\"dbg/rollout_n1\", n0_key=\"dbg/rollout_n0\"):\n",
        "        self.trainer, self.eval_fn, self.eval_kwargs = trainer, eval_fn, eval_kwargs\n",
        "\n",
        "        # Log keys to watch.\n",
        "        self.reward_key, self.kl_key = reward_key, kl_key\n",
        "        self.n1_key, self.n0_key = n1_key, n0_key\n",
        "\n",
        "        # Most recent cached value per watched key (None until first seen).\n",
        "        self._last_reward = self._last_kl = None\n",
        "        self._last_n1 = self._last_n0 = None\n",
        "\n",
        "        # -1 means no checkpoint has been evaluated yet.\n",
        "        self._last_eval_step = -1\n",
        "        # True while we emit our own eval log, so on_log ignores it.\n",
        "        self._logging_eval = False\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        \"\"\"Cache the latest value of each watched key found in `logs`.\"\"\"\n",
        "        if not self._logging_eval:\n",
        "            incoming = logs or {}\n",
        "            for key, slot in (\n",
        "                (self.reward_key, \"_last_reward\"),\n",
        "                (self.kl_key, \"_last_kl\"),\n",
        "                (self.n1_key, \"_last_n1\"),\n",
        "                (self.n0_key, \"_last_n0\"),\n",
        "            ):\n",
        "                if key in incoming:\n",
        "                    setattr(self, slot, incoming[key])\n",
        "        return control\n",
        "\n",
        "    def on_save(self, args, state, control, **kwargs):\n",
        "        \"\"\"Print cached stats and evaluate once per saved checkpoint.\"\"\"\n",
        "        step = int(state.global_step)\n",
        "        if step in (0, self._last_eval_step):\n",
        "            return control\n",
        "        self._last_eval_step = step\n",
        "\n",
        "        is_main = not hasattr(self.trainer, \"is_world_process_zero\") or self.trainer.is_world_process_zero()\n",
        "        if not is_main:\n",
        "            return control\n",
        "\n",
        "        def fmt(value):\n",
        "            if value is None:\n",
        "                return \"NA\"\n",
        "            if isinstance(value, (float, int)):\n",
        "                return f\"{value:.6g}\"\n",
        "            return str(value)\n",
        "\n",
        "        header = f\"[save step {step}]\"\n",
        "        fields = \"  \".join([\n",
        "            f\"{self.reward_key}={fmt(self._last_reward)}\",\n",
        "            f\"{self.kl_key}={fmt(self._last_kl)}\",\n",
        "            f\"n1={fmt(self._last_n1)}\",\n",
        "            f\"n0={fmt(self._last_n0)}\",\n",
        "        ])\n",
        "        print(f\"{header} {fields}\")\n",
        "\n",
        "        # One eval per checkpoint, in eval mode and without autograd.\n",
        "        model = self.trainer.model\n",
        "        restore_train = model.training\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        if restore_train:\n",
        "            model.train()\n",
        "\n",
        "        # trainer.log re-dispatches on_log; the flag keeps that call out of the cache.\n",
        "        self._logging_eval = True\n",
        "        try:\n",
        "            self.trainer.log({f\"gsm8k/{name}\": value for name, value in metrics.items()})\n",
        "        finally:\n",
        "            self._logging_eval = False\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "P8qP661M7LAo",
      "metadata": {
        "id": "P8qP661M7LAo"
      },
      "outputs": [],
      "source": [
        "### Llama and qwen params\n",
        "# Hyperparameters, GRPO config, trainer construction and the eval-on-save callback.\n",
        "\n",
        "from trl import GRPOConfig\n",
        "\n",
        "# ---- sampling / objective ----\n",
        "K = 32                 # rollouts (completions) sampled per prompt\n",
        "PAIRS_PER_STEP = 4     # paraphrase pairs (q0, q1) per device\n",
        "KL_WEIGHT = 0.05       # beta for the KL-to-reference term\n",
        "\n",
        "# ---- schedule ----\n",
        "LOG_EVERY = 8\n",
        "SAVE_EVERY = 20 * LOG_EVERY    # checkpoint cadence; PrintAndEvalOnSave evaluates at each save\n",
        "EVAL_LIMIT = len(test_ds)      # evaluate on the whole test set\n",
        "TOTAL_STEPS = 200 * LOG_EVERY  # training length in optimizer steps\n",
        "\n",
        "MAX_NEW_TOKENS = 512           # shared by training rollouts and eval generation\n",
        "\n",
        "train_dataset = ParaphrasePairIterableFormatted(\n",
        "    classes_usable=paraphrase_classes,\n",
        "    tokenizer=tokenizer,\n",
        "    format_prompt_instruct=format_prompt_instruct,\n",
        "    system_prompt=SYSTEM_PROMPT,\n",
        "    seed=0,\n",
        ")\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "    report_to=[],\n",
        "\n",
        "    # Batch geometry.\n",
        "    # NOTE(review): in TRL's GRPO, per_device_train_batch_size appears to count\n",
        "    # completions rather than prompts; with generation_batch_size = 8 * K each\n",
        "    # generation round (8 prompts x K rollouts) is then consumed over several\n",
        "    # optimizer steps. Confirm against the installed TRL version.\n",
        "    per_device_train_batch_size=2 * PAIRS_PER_STEP,\n",
        "    num_generations=K,\n",
        "    generation_batch_size=(2 * PAIRS_PER_STEP) * K,   # must be divisible by num_generations\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "\n",
        "    # Optimisation.\n",
        "    max_steps=TOTAL_STEPS,            # required: the IterableDataset has no __len__\n",
        "    learning_rate=1e-5,\n",
        "    beta=KL_WEIGHT,\n",
        "\n",
        "    # Data loading. dispatch_batches=False is a workaround for a\n",
        "    # str-concatenation crash hit when batches were dispatched.\n",
        "    dataloader_num_workers=0,\n",
        "    accelerator_config={\n",
        "        \"dispatch_batches\": False,\n",
        "        \"split_batches\": False,\n",
        "    },\n",
        "\n",
        "    # Checkpointing and logging, both every SAVE_EVERY steps.\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=SAVE_EVERY,\n",
        "    save_total_limit=30,\n",
        "    save_safetensors=True,\n",
        "    logging_steps=SAVE_EVERY,\n",
        ")\n",
        "\n",
        "trainer = OrigOnlyKLGRPOTrainer(\n",
        "    model=trainable_model,\n",
        "    args=grpo_args,\n",
        "    train_dataset=train_dataset,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=pair_majority_reward_fn,\n",
        ")\n",
        "\n",
        "# Forwarded to evaluate_gsm8k at every checkpoint save.\n",
        "eval_kwargs = {\n",
        "    \"tokenizer\": tokenizer,\n",
        "    \"dataset\": test_ds,\n",
        "    \"limit\": EVAL_LIMIT,\n",
        "    \"batch_size\": 64,\n",
        "    \"max_new_tokens\": MAX_NEW_TOKENS,\n",
        "    \"name\": \"trained_grpo\",\n",
        "}\n",
        "\n",
        "trainer.add_callback(\n",
        "    PrintAndEvalOnSave(\n",
        "        trainer,\n",
        "        eval_fn=evaluate_gsm8k,\n",
        "        eval_kwargs=eval_kwargs,\n",
        "        reward_key=\"reward\",\n",
        "        kl_key=\"kl_orig_only\",\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "H0sUgLjkaahX",
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "ca928a7ea4c842c1a9213fa94729fec0",
            "4b6686f1b535452fa6b2aa22a4e6b701",
            "9d6bb83db6ce499eaa2ca07bb07eca6c",
            "7c7af70ab0b048428942d2042b6b619e",
            "9d723354a4344bdf8fd331a1446a2b8f",
            "a4549b715203429f85d3e152e1f6dfde",
            "a6bad0fbc40d41199de894615150b714",
            "67849f803a08422fa357bd27a11bf964",
            "75daa99c284d4012bf0df1f87999fbfb",
            "c1fba98d6fc148398dcdf8379e934d18",
            "2c22e8e41b804702821dac54e8572221",
            "c2bd8dccff074a0ca8a9844319aba0ce",
            "b8b781185a874943adab7f09ae19411c",
            "df6bebf829a54f86a8ffdb1c1358ca82",
            "7b9a9cfb02a74fb087e2eb3e128db7eb",
            "6d1c8f83e798474da7ff43fa69045b6b",
            "0f45cb9c7061417bafe13fe40f54b663",
            "815fce21c79d43aaa335dd620804ec16",
            "e98817ec7f25489698654cbf0c5e050a",
            "dd6c0cd3aec4480fb0f1f3e742575c1b",
            "86b165dbf93541adb4a12f655d849c17",
            "2d055629ae454fde9581e1a14e25c117",
            "bf8a24d686b24da584fd5443076f81af",
            "d885506b579b4e3cbf49f6f94e798929",
            "b40d1de4d655496b9da43aafbe598e78",
            "c70061a9ed754c2da95b745fb6506983",
            "b43925fcc9d94f9ab8d2082e7d9dc4f3",
            "239a3be7a04444cd8ff61cc41ce6063e",
            "4512f7a4b9134f4ab867da0c37642100",
            "e6643458ec20491da4545e733e7b5206",
            "c61a70f21b7842e394e0245c3339bb20",
            "dc8a5f641d2d4c358838598fd89cfc9b",
            "18ec27986e1d497695c66dd1cd4142a7",
            "6a2b3e4accb3479bbf46c199927ad393",
            "a9c3bba89d8b436e85dcf2c81bb35a84",
            "aa45e9b80f3a46c3bc4e549322614387",
            "5c7202bf9cd04c0a97a42bec080d448b",
            "59169c85f9774adbb8325c9d11ad7b00",
            "78a5366783a940afba4940c56349b7fb",
            "22a1c8806adf4710a02d5694485f42c1",
            "2ce513bcfcd7429d9fe6872d177bf38e",
            "1596780773c64a6097256a02988375b7",
            "08634ce86bd240e78a75864098528d88",
            "322f7a07b1de4532b31980509f132743",
            "8089993e890249a885f48b20d127fd0d",
            "fffac9209fa24b4ea9728ec45b900bbe",
            "facb9d35f00a493880724c3caef39e0f",
            "3a65e1adf5a44fd0986e3011e33bf64c",
            "44590ee9685940dabf80cef9de6c36db",
            "53847574e82447c1ad184ce452e58b5b",
            "bd8a0dcb544e4e96b75c06fce618e747",
            "4d9917046a984adf821475a945ce25f9",
            "157ce59e61c241ca97f75fcf0da58a6c",
            "ea78979948144c6e800e16d0ac934bf5",
            "ceb55c622676492d81007c4de64ddb20",
            "9c48059cb2054bbc875de066654e191a",
            "b5893de63b934a4186596a42fcea8030",
            "004da1ed97a44382bde90d6c5c244a2f",
            "6ee861348c174dd4af158c10e200ddaa",
            "871a90e79fc04112bbb72d54d8ae8250",
            "5010496fa3ec42b3b4f3e2d088fb38d3",
            "f2dd806f220443e38f2217ed88454632",
            "17528517e5d54f2cb86c90206b1c9f57",
            "ddabfacb1370443b9e38726aad49bbb3",
            "2a617d3a24c74fa1a092123d7425ba4c",
            "af7f958bdf324b8b874212b929f1b30e",
            "7f64771024e6484d864e2eb0678f4991",
            "2f42c28f32794f19b3134e6e5dd79a25",
            "9f4d36fc4f5c43d69b7d1c2cabbdced3",
            "719a9e80b5be4b59b282aa4f189aea5a",
            "d3e0ef7490d84c4b9c9f9b8aa2d0bacc",
            "d1e977c54c0f48bf9bc787436da92f52",
            "b0ac144780df465ab6db88d2d57e5174",
            "f6bb7b85de724c7499493fa088886a28",
            "3e446973a3dc4785ae37c485753cd964",
            "517c75a17f03426caaa7ba64f106abde",
            "3058e0053b9d4278b0aaff3d42a15ef9",
            "fa2d6c5f05064aaaadb361b62c978702",
            "ac1e15245d2d4f3f8662a9dae4e1da2c",
            "74969bf267b248918164a67ddd94551b",
            "d2e36fee086b42a59bac40117f0ba4bc",
            "4343230de097491ebd98919615ca7c1f",
            "5feae3252dcc45e58c65c5a114980ed6",
            "018572d4a0684e3e8a338a0e9ef942de",
            "f4c3f1cea2ab422bbe6df48e34f87950",
            "5c0761880e5f4438be927b9ab3c272ad",
            "8214948990d845e286f1dad4739a0b5b",
            "4af918c43f6543a5bbfd43d2fa231160",
            "26be7c6dbeca4509b1652e34a12386a8",
            "13fe92f630ff49ef80a57765f3b9c36e"
          ]
        },
        "id": "H0sUgLjkaahX",
        "outputId": "60160c26-f95d-43a2-b9d2-13265690fe53"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1601' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1600/1600 7:06:22, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1440</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 160] reward=0.980469  kl_orig_only=0.00121649  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ca928a7ea4c842c1a9213fa94729fec0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 78.70 and ACC confidence [76.40, 80.82]  |  pred NA: 5/1319 (0.38%)  \n",
            "[save step 320] reward=0.984375  kl_orig_only=0.00294133  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c2bd8dccff074a0ca8a9844319aba0ce",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 80.14 and ACC confidence [77.90, 82.20]  |  pred NA: 13/1319 (0.99%)  \n",
            "[save step 480] reward=0.992188  kl_orig_only=0.00464991  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bf8a24d686b24da584fd5443076f81af",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.91 and ACC confidence [77.66, 81.98]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 640] reward=0.996094  kl_orig_only=0.00456677  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6a2b3e4accb3479bbf46c199927ad393",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.15 and ACC confidence [76.88, 81.26]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 800] reward=0.984375  kl_orig_only=0.00353526  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8089993e890249a885f48b20d127fd0d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.83 and ACC confidence [77.58, 81.91]  |  pred NA: 10/1319 (0.76%)  \n",
            "[save step 960] reward=0.996094  kl_orig_only=0.00338572  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9c48059cb2054bbc875de066654e191a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.61 and ACC confidence [77.35, 81.69]  |  pred NA: 7/1319 (0.53%)  \n",
            "[save step 1120] reward=0.988281  kl_orig_only=0.00300359  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7f64771024e6484d864e2eb0678f4991",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 80.82 and ACC confidence [78.61, 82.85]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 1280] reward=0.992188  kl_orig_only=0.00200457  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "fa2d6c5f05064aaaadb361b62c978702",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 80.59 and ACC confidence [78.37, 82.64]  |  pred NA: 5/1319 (0.38%)  \n",
            "[save step 1440] reward=0.996094  kl_orig_only=0.00600658  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "26be7c6dbeca4509b1652e34a12386a8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 80.36 and ACC confidence [78.13, 82.42]  |  pred NA: 10/1319 (0.76%)  \n",
            "[save step 1600] reward=0.996094  kl_orig_only=0.00569447  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "13fe92f630ff49ef80a57765f3b9c36e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Majority-vote reward with KL weight 0.05 (KL_WEIGHT).\n",
        "# `CKPT_LOAD_PATH or None` resumes only when a checkpoint path is set; the\n",
        "# trailing ';' keeps the TrainOutput repr out of the cell output.\n",
        "trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH or None);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "LXTC3ufn7PZZ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "9ac05b0723ed4e13bfe6016993cf798a",
            "990fa215bf844e3b8ea8bc9bda7e1a3c",
            "b9dfc499a7904ad0bcb6626f93487019",
            "6c7383eab16b4c2385ab1c300b8423ee",
            "3bbaeb287b5f4086b7c3e1474a34972c",
            "9f1cadc5d79649b19bc6a6941c3245e2",
            "b0161588dac440cda9382e2482b32219",
            "68f6dd3d216d4a3d8ed8f0a915fccc27",
            "157c4560897f4de3959667ae8483508e",
            "a88a028b745843a7aa4e495d70cfd270",
            "0a863dbca2044fbaad16705cd419fdb9",
            "43eb6d953311464391ee454aa2c3e279",
            "fb3ae36336974394913c44c99e5de061",
            "35c453987de54f8491bec86a0fe062dc",
            "b8afce26ae8b4305bff9ed6688f3ddfe",
            "8ada8281b3ca4a2bbfc9bab988ec3df4",
            "ea74d8f5c7bb42d794872a048875ce93",
            "2bbbedc99ea347e59b70a59a30186297",
            "1f6bfba559be428480eb83c73e78315a",
            "4783394409ba45e28e6872fe74c7b7da",
            "18717620d990478b9e58ce52fe77dcdc",
            "84531fd148ad4bcfac95d43da9913159",
            "4b6c966e6bcb462f9817cc37eadafc85",
            "25ff08bb4bbb4c31a9d9c5b5f08bb508",
            "725029a332d2467eb966e0b88fe16d2c",
            "d25f806fab61436f8f5d3d5a16f5e0f1",
            "bdc935dbdb1f4036b46e58fa2d80ddbf",
            "4e7998fb8f6345a681a868445e9aff4c",
            "2db9147dcbd340cabcb514a7a023cdc4",
            "0d163a21883f448dbad9464fb66467e4",
            "5bc1e955e9c34d7fba8c9de399ae1ac0",
            "f06449b580ac4e849538a00e2df37a81",
            "01806afad25045508bd4eb9cc25fd085",
            "7d11610cf9ee4892b4056495f794b08c",
            "07a3b12d03cd4371a9698cbfe8173cfe",
            "b4ea0408ecdd4d8ca466755833c4c5c8",
            "b4eb097c85964ca9bfc73c6bf91cbaf9",
            "3fafd0243abd4a528e686e5c5d8418bd",
            "172069fe12fd4550ba24ee98b0f86343",
            "272ada2d7569455ba07407c4471146cc",
            "d801cfe586a84f2894c185c9f39aba38",
            "95b79511c0ee42e7b2ca30b63ea484b0",
            "3044700dd4774e2aba7aed8f70a39888",
            "1a4508b1ac9c47919bb1565d636fa3a8",
            "d567c31ca55149979e15a11e2d996581",
            "32a472a8ac6c4c598f2b6426144a960a",
            "13962aacdf40481e82271020cfee3ea4",
            "390bc2be57a343198a673543d41539d0",
            "fb02ac5aa1cd4b8b8ee68a4592911456",
            "6681c8573ffe47479b45516da9d97175",
            "8618259c164645a8b84cd465b6bfd3f8",
            "d2bdcf3ef69b4d078d0f7b55bbe14422",
            "f64dddc0ccdc49f6b94e191c0107f510",
            "f00948e6be484b6b84e2496c32227eef",
            "1b329a73adf94f2ea7675ae99cd01fef",
            "d867cd98dbf14053961a64c6efeb26f9",
            "dfe054ec515940549cfcbd9b7581b67d",
            "7028269d58964755babbc4e93c1d482b",
            "5bfa9a9eb18c4c8abcce1f1c9d2d0208",
            "a3449a2e296043d0abeba3982d24e839",
            "0559847792ee4f358dff48a64861ceff",
            "daa3a060f2cc43aab105c38476c26b57",
            "e3989b46ac8340e588fdaaaabf9b87e4",
            "680210bb7bc64f0fb96b3cd11f955463",
            "639b552ed05248f9b7c72d73c14da9cd",
            "695c75b1751944bf8feb55d1026e3e46",
            "afb5e52a41534152a3db972a9bc94c54",
            "7f312f978adb474db9ce49c2f3f6153b",
            "447f11db2a344b6abcf549f941628737",
            "d340cb25e52b43caa6c98b8733cc4852",
            "6753f3afa9284919825e60c1f71fd853",
            "e3712025a45545dcb040599a739cedbc",
            "53246caf2dc6425498f634664d914162",
            "aafbaa511943448cad086cccccf763d3",
            "3749db91fd034b37b2f44330a017ae54",
            "fc5ba29bb98f466e999b28549b750c14",
            "7d41fb48052e4afbb6ba066568edc4e0",
            "79abb262b1e442a99b23b327f9547c21",
            "5e59bd31893d465e860cfb6b6cc0b3ba",
            "e93f4d72fdab45ce9b21e42e430702b6",
            "e39ab8735bc14a11a8becb6d8b199a03",
            "876b7a1f4fd142f7966996549c5e55eb",
            "2b35f4a7911e425ba562c2435913f72d",
            "7976edc1125d4076a10b58cb1978e1db",
            "973d042d851042da996b747d3da804b1",
            "550be0f6d12e4c60886d0d5035ba6b3d",
            "d2127f38f39d426e9a4728ecae2d2738",
            "7583761f254c4a0fa19ef9c42de93cbd",
            "2888c9cb30564fa893a7b043ae4f90db",
            "282d3372db4f40589438b08fd38c8b58",
            "30f0fcc7ae7d493cb2b69232e6456fba",
            "8ec215d4d667480996ed8da053e945cc",
            "2718473403834fee877163873d336309",
            "8b1dd16d783d4c31a9e2a41386c72a1b",
            "b3c7ade6ac214c12b7be4427a4119c72",
            "f2bae048b4034efea7db8b2f6fb4f526",
            "2b9d943d33e445d7aa16e31bbd148f7c",
            "c67785ae392f4507acf38816eab3e5a9",
            "6499bd4e0b3e499981c5e2cf182f1ba0",
            "80261814dc5846e88c72e0e36390c7a8",
            "16faf06fe337402cb1d11e7ae06df0bd",
            "b96e91dc71a54ec7b641598f13349382",
            "10823f8c532b4bb69c49bb50a752b4dc",
            "571d5fe9dd29472eabca54537d244354",
            "bc684009a1ac42d3bc55815b41f56d88",
            "fbe0e0a348ad49c5be519e189b2e7b5c",
            "ce5b87c5dce045a88a4bf7885b2fb750",
            "a11f4f065a784736aa0cc9f5fdcf2afd",
            "641de6da851248b6b01b712f3020e5bd",
            "10930da0ec404970a5a815f0fdeb3a92"
          ]
        },
        "id": "LXTC3ufn7PZZ",
        "outputId": "cb38a1f4-6932-4ac8-d814-3081808b27d5"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1473' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1473/1600 7:02:00 < 36:26, 0.06 it/s, Epoch 0.92/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1440</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 160] reward=0.648438  kl_orig_only=0.00597381  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9ac05b0723ed4e13bfe6016993cf798a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 77.48 and ACC confidence [75.15, 79.66]  |  pred NA: 10/1319 (0.76%)  \n",
            "[save step 320] reward=0.609375  kl_orig_only=0.00480385  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "43eb6d953311464391ee454aa2c3e279",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.91 and ACC confidence [77.66, 81.98]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 480] reward=0.679688  kl_orig_only=0.0499871  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4b6c966e6bcb462f9817cc37eadafc85",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 77.71 and ACC confidence [75.39, 79.87]  |  pred NA: 23/1319 (1.74%)  \n",
            "[save step 640] reward=0.570312  kl_orig_only=0.00436745  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7d11610cf9ee4892b4056495f794b08c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 77.79 and ACC confidence [75.46, 79.95]  |  pred NA: 12/1319 (0.91%)  \n",
            "[save step 800] reward=0.585938  kl_orig_only=0.00528257  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d567c31ca55149979e15a11e2d996581",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.45 and ACC confidence [77.19, 81.55]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 960] reward=0.65625  kl_orig_only=0.0106137  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d867cd98dbf14053961a64c6efeb26f9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.61 and ACC confidence [77.35, 81.69]  |  pred NA: 10/1319 (0.76%)  \n",
            "[save step 1120] reward=0.601562  kl_orig_only=0.00532593  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "afb5e52a41534152a3db972a9bc94c54",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.15 and ACC confidence [76.88, 81.26]  |  pred NA: 5/1319 (0.38%)  \n",
            "[save step 1280] reward=0.703125  kl_orig_only=0.00470573  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "79abb262b1e442a99b23b327f9547c21",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.68 and ACC confidence [77.43, 81.77]  |  pred NA: 7/1319 (0.53%)  \n",
            "[save step 1440] reward=0.679688  kl_orig_only=0.010951  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2888c9cb30564fa893a7b043ae4f90db",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.00 and ACC confidence [76.72, 81.11]  |  pred NA: 6/1319 (0.45%)  \n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1600' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1600/1600 7:45:56, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1440</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 1600] reward=0.59375  kl_orig_only=0.00600542  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "80261814dc5846e88c72e0e36390c7a8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.61 and ACC confidence [77.35, 81.69]  |  pred NA: 7/1319 (0.53%)  \n"
          ]
        }
      ],
      "source": [
        "# Resume from CKPT_LOAD_PATH when it is set (truthy); otherwise start a fresh run.\n",
        "# NOTE(review): `trainer` is (re)assigned by the GRPO setup cell that appears below this one,\n",
        "# so under Restart & Run All this cell uses whatever `trainer` exists at this point - confirm order.\n",
        "trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH or None)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "k5M1jnQdXDxC",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k5M1jnQdXDxC",
        "outputId": "8f0fe89a-4491-40c2-a7da-d5b6344bf349"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using masked KL penalty.\n"
          ]
        }
      ],
      "source": [
        "# GRPO trainer setup (TRL). This cell only builds `trainer`; training is started by a\n",
        "# separate `trainer.train(...)` call, and evaluation is run outside this cell.\n",
        "# Names this cell expects from earlier cells:\n",
        "#   - trainable_model, tokenizer\n",
        "#   - paraphrase_classes, ParaphrasePairIterableFormatted (paired paraphrase stream)\n",
        "#   - agreement_reward_fn (reward over completions)\n",
        "#   - grpo_tokenizing_collator\n",
        "#   - SAMPLES_PER_PAIR, LOG_EVERY, MAX_NEW_TOKENS, KL_WEIGHT, WORK_PATH\n",
        "#   - optionally MaskedKLGRPOTrainer (masked-KL subclass; falls back to trl.GRPOTrainer)\n",
        "#\n",
        "# Notes:\n",
        "# - The paired stream is an IterableDataset (per its name), so max_steps bounds training.\n",
        "# - Logging integrations are disabled with report_to=\"none\"; the WANDB_DISABLED env var\n",
        "#   previously set here is deprecated in transformers and emitted a warning.\n",
        "import math\n",
        "from trl import GRPOConfig\n",
        "try:\n",
        "    # Prefer the masked-KL subclass when an earlier cell defined it.\n",
        "    TrainerCls = MaskedKLGRPOTrainer\n",
        "    print(\"Using masked KL penalty.\")\n",
        "except NameError:\n",
        "    from trl import GRPOTrainer\n",
        "    TrainerCls = GRPOTrainer\n",
        "\n",
        "# --- dataset for TRL training (paired stream) ---\n",
        "train_dataset = ParaphrasePairIterableFormatted(paraphrase_classes, tokenizer, seed=42)\n",
        "\n",
        "# --- GRPO config ---\n",
        "K = SAMPLES_PER_PAIR  # rollouts (completions) per prompt\n",
        "\n",
        "# ----------------------------\n",
        "# Step budget: TOTAL_STEPS bounds the whole run; a checkpoint is saved every CHUNK_STEPS.\n",
        "# ----------------------------\n",
        "TOTAL_STEPS = LOG_EVERY * 200\n",
        "CHUNK_STEPS = LOG_EVERY * 20\n",
        "\n",
        "# Evaluation knobs. Not read in this cell; presumably consumed by an eval cell/callback - TODO confirm.\n",
        "EVAL_LIMIT = 1000\n",
        "EVAL_BS = 64\n",
        "\n",
        "DO_PARAPHRASE_EVAL = False\n",
        "PARA_EVAL_LIMIT = 200\n",
        "PARA_EVAL_CLASS_BS = 16   # classes per batch (prompts ~ 2*bs if each class has 2 items)\n",
        "\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "    # NOTE(review): recent TRL counts *completions* (not prompts) per device here; the (q0, q1)\n",
        "    # pair is fixed by generation_batch_size = 2 * K below - confirm against the installed TRL.\n",
        "    per_device_train_batch_size=2,\n",
        "    gradient_accumulation_steps=1,\n",
        "    generation_batch_size=K * 2,         # 2 prompts x K rollouts; must be divisible by num_generations\n",
        "    num_generations=K,                   # K rollouts per prompt\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "    learning_rate=1e-5,\n",
        "    beta=KL_WEIGHT,                      # KL-penalty coefficient\n",
        "    logging_steps=LOG_EVERY,\n",
        "    save_steps=CHUNK_STEPS,\n",
        "    max_steps=TOTAL_STEPS,               # fixed total budget; nothing in this cell overrides it\n",
        "    report_to=\"none\",                    # replaces the deprecated WANDB_DISABLED env var\n",
        ")\n",
        "\n",
        "trainer = TrainerCls(\n",
        "    model=trainable_model,\n",
        "    args=grpo_args,\n",
        "    train_dataset=train_dataset,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=agreement_reward_fn,\n",
        ")\n",
        "\n",
        "# Set collator AFTER init (works with your GRPOTrainer signature)\n",
        "trainer.data_collator = lambda feats: grpo_tokenizing_collator(\n",
        "    feats, tokenizer, max_prompt_length=2048\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "m2ssljPjv5gT",
      "metadata": {
        "id": "m2ssljPjv5gT"
      },
      "source": [
        "### Debug"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "oBWStm0RxM1n",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oBWStm0RxM1n",
        "outputId": "c628f178-7254-48d3-b019-6a5ff5bc3540"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-761637024.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# After building trainer, grab one batch from its dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_dataloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "# Smoke test: pull the first training batch the same way the Trainer does, run a single\n",
        "# loss computation on it, and print the scalar diagnostics the loss reports.\n",
        "# NOTE(review): the saved run of this cell failed inside next(iter(dl)) with\n",
        "# \"TypeError: Can only concatenate tensors but got <class 'str'>\" (accelerate could not\n",
        "# concatenate string fields while assembling the batch), and stock TRL\n",
        "# GRPOTrainer.compute_loss raises ValueError when return_outputs=True -- both need\n",
        "# fixing upstream before this cell can succeed.\n",
        "dl = trainer.get_train_dataloader()\n",
        "batch = next(iter(dl))\n",
        "loss, out = trainer.compute_loss(trainer.model, batch, return_outputs=True)\n",
        "\n",
        "# (printed label, key in `out`) in display order.\n",
        "_diagnostics = (\n",
        "    (\"beta:\", \"beta\"),\n",
        "    (\"kl_mean_all:\", \"kl_mean_all\"),\n",
        "    (\"kl_orig_mean:\", \"kl_orig_mean\"),\n",
        "    (\"kl_para_mean:\", \"kl_para_mean\"),\n",
        "    (\"masked_kl:\", \"masked_kl\"),\n",
        ")\n",
        "for _label, _key in _diagnostics:\n",
        "    print(_label, float(out[_key]))\n",
        "print(\"loss:\", float(loss.detach().cpu()))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "q-p8g9fkv7gq",
      "metadata": {
        "id": "q-p8g9fkv7gq"
      },
      "outputs": [],
      "source": [
        "import itertools\n",
        "from collections.abc import Mapping\n",
        "\n",
        "import torch\n",
        "\n",
        "def debug_kl_mask_once(trainer, train_dataset, tokenizer, K: int, num_prompts: int = 2):\n",
        "    \"\"\"\n",
        "    Run one compute_loss call on a small hand-built batch and print KL stats split by is_original.\n",
        "\n",
        "    Args:\n",
        "        trainer: object exposing ``model`` and ``compute_loss(model, inputs, return_outputs=True)``\n",
        "            returning ``(loss, outputs)``, where ``outputs`` is a mapping holding a per-sample KL\n",
        "            vector under \"kl\", \"kl_div\", \"per_sample_kl\" or \"kl_per_sample\".\n",
        "        train_dataset: iterable of records with \"prompt\" and \"is_original\" keys.\n",
        "        tokenizer: unused; kept so existing call sites keep working.\n",
        "        K: completions per prompt; each prompt's is_original flag is repeated K times.\n",
        "        num_prompts: records taken from train_dataset (default 2, i.e. the q0/q1 pair).\n",
        "\n",
        "    Returns:\n",
        "        dict with kl_key, kl_orig_mean / kl_para_mean (None when that group is empty),\n",
        "        masked_kl and loss as Python floats.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: train_dataset yields fewer than num_prompts records, or the KL vector\n",
        "            is not 1-D with num_prompts * K entries.\n",
        "        TypeError: compute_loss returned outputs that are not a mapping.\n",
        "        KeyError: none of the known KL keys is present in the outputs.\n",
        "\n",
        "    NOTE(review): stock TRL GRPOTrainer.compute_loss raises\n",
        "    ValueError(\"The GRPOTrainer does not support returning outputs\") whenever\n",
        "    return_outputs=True, so the trainer's compute_loss override must assemble its own\n",
        "    outputs rather than forwarding return_outputs=True to super().compute_loss.\n",
        "    NOTE(review): inputs are raw prompt strings, not the tokenized/generated batch the\n",
        "    trainer normally prepares before compute_loss -- confirm the override accepts them.\n",
        "    \"\"\"\n",
        "    # Take only the first num_prompts records; islice returns a short list instead of\n",
        "    # leaking a bare StopIteration when the dataset runs out.\n",
        "    records = list(itertools.islice(iter(train_dataset), num_prompts))\n",
        "    if len(records) < num_prompts:\n",
        "        raise ValueError(f\"train_dataset yielded {len(records)} record(s); need {num_prompts}\")\n",
        "    prompts = [rec[\"prompt\"] for rec in records]\n",
        "    is_orig = torch.tensor([rec[\"is_original\"] for rec in records])\n",
        "    inputs = {\"prompt\": prompts, \"is_original\": is_orig}\n",
        "\n",
        "    loss, outputs = trainer.compute_loss(trainer.model, inputs, return_outputs=True)\n",
        "\n",
        "    if not isinstance(outputs, Mapping):\n",
        "        raise TypeError(f\"Expected mapping outputs from compute_loss, got {type(outputs).__name__}\")\n",
        "    kl_candidates = (\"kl\", \"kl_div\", \"per_sample_kl\", \"kl_per_sample\")\n",
        "    kl_key = next((k for k in kl_candidates if k in outputs), None)\n",
        "    if kl_key is None:\n",
        "        raise KeyError(f\"Couldn't find per-sample KL in outputs keys: {list(outputs.keys())}\")\n",
        "\n",
        "    # float32 on CPU: avoids low-precision means and keeps the device consistent with the mask.\n",
        "    kl = torch.as_tensor(outputs[kl_key]).detach().float().cpu()\n",
        "    # One flag per completion: prompt i's flag repeated K times (prompt-major order,\n",
        "    # assumed to match the trainer's KL layout -- TODO confirm).\n",
        "    mask = is_orig.float().repeat_interleave(K).cpu()\n",
        "    if kl.dim() != 1 or kl.numel() != mask.numel():\n",
        "        raise ValueError(\n",
        "            f\"per-sample KL has shape {tuple(kl.shape)}; expected ({mask.numel()},) \"\n",
        "            f\"= num_prompts ({num_prompts}) * K ({K})\"\n",
        "        )\n",
        "\n",
        "    # Split KL by prompt type.\n",
        "    kl_orig = kl[mask == 1]\n",
        "    kl_para = kl[mask == 0]\n",
        "    kl_orig_mean = float(kl_orig.mean()) if kl_orig.numel() else None\n",
        "    kl_para_mean = float(kl_para.mean()) if kl_para.numel() else None\n",
        "\n",
        "    print(\"KL key:\", kl_key, \"shape:\", tuple(kl.shape))\n",
        "    print(\"KL mean (orig):\", kl_orig_mean)\n",
        "    print(\"KL mean (para):\", kl_para_mean)\n",
        "\n",
        "    if \"masked_kl\" in outputs:\n",
        "        masked_kl = float(torch.as_tensor(outputs[\"masked_kl\"]).detach().float().cpu())\n",
        "        print(\"masked_kl (trainer):\", masked_kl)\n",
        "    else:\n",
        "        # Mean KL over original-prompt completions only; clamp avoids 0/0 with no originals.\n",
        "        masked_kl = float((kl * mask).sum() / mask.sum().clamp_min(1.0))\n",
        "        print(\"masked_kl (recomputed):\", masked_kl)\n",
        "\n",
        "    loss_value = float(torch.as_tensor(loss).detach().float().cpu())\n",
        "    print(\"loss:\", loss_value)\n",
        "\n",
        "    return {\n",
        "        \"kl_key\": kl_key,\n",
        "        \"kl_orig_mean\": kl_orig_mean,\n",
        "        \"kl_para_mean\": kl_para_mean,\n",
        "        \"masked_kl\": masked_kl,\n",
        "        \"loss\": loss_value,\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fr6HpoQjv9QY",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fr6HpoQjv9QY",
        "outputId": "573da719-a9b0-48d7-cdc3-c38e51ca5446"
      },
      "outputs": [
        {
          "ename": "ValueError",
          "evalue": "The GRPOTrainer does not support returning outputs",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-1024889168.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdebug_kl_mask_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mSAMPLES_PER_PAIR\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/tmp/ipython-input-1970023020.py\u001b[0m in \u001b[0;36mdebug_kl_mask_once\u001b[0;34m(trainer, train_dataset, tokenizer, K)\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0;31m# Call compute_loss with outputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m     \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0;31m# Find KL vector key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-276940471.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, **kwargs)\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0;31m# Run the normal GRPO forward to get outputs that include per-sample stats.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0;31m# We rely on GRPOTrainer producing a dict-like outputs containing KL.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0;31m# inputs[\"is_original\"]: shape [batch_prompts] (here batch_prompts should be 2)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/extras/profiling.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     96\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mprofiling_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/trainer/grpo_trainer.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m   2104\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2105\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2106\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The GRPOTrainer does not support returning outputs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2107\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_liger_kernel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2108\u001b[0m             \u001b[0;31m# Compute the loss using the liger grpo loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mValueError\u001b[0m: The GRPOTrainer does not support returning outputs"
          ]
        }
      ],
      "source": [
        "# One-off diagnostic: KL split by prompt type (original vs paraphrase) for one small batch.\n",
        "debug_kl_mask_once(\n",
        "    trainer,\n",
        "    train_dataset,\n",
        "    tokenizer,\n",
        "    K=SAMPLES_PER_PAIR,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sx5zM-iWwAJR",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "sx5zM-iWwAJR",
        "outputId": "34036967-7614-447c-9d61-5bc4162034de"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-3887887193.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;31m# Train up to cur_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresume\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mresume\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   2323\u001b[0m                 \u001b[0mhf_hub_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_progress_bars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2324\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2325\u001b[0;31m             return inner_training_loop(\n\u001b[0m\u001b[1;32m   2326\u001b[0m                 \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2327\u001b[0m                 \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2616\u001b[0m                 \u001b[0mupdate_step\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2617\u001b[0m                 \u001b[0mnum_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient_accumulation_steps\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mupdate_step\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtotal_updates\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mremainder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2618\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_batch_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2619\u001b[0m                 \u001b[0;31m# Store the number of batches for current gradient accumulation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2620\u001b[0m                 \u001b[0;31m# This is used to correctly scale the loss when the last accumulation step has fewer batches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mget_batch_samples\u001b[0;34m(self, epoch_iterator, num_batches, device)\u001b[0m\n\u001b[1;32m   5652\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5653\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5654\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   5655\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5656\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "\n",
        "resume = None\n",
        "cur_target = 0\n",
        "while cur_target < TOTAL_STEPS:\n",
        "    cur_target = min(cur_target + CHUNK_STEPS, TOTAL_STEPS)\n",
        "\n",
        "    # Update trainer's total step target.\n",
        "    # HF Trainer uses args.max_steps as the global target; calling train() again will continue.\n",
        "    trainer.args.max_steps = cur_target\n",
        "\n",
        "    # Train up to cur_target\n",
        "    trainer.train(resume_from_checkpoint = resume)\n",
        "    resume = True\n",
        "\n",
        "    # --- Evaluation (your functions) ---\n",
        "    # 1) GSM8K ACC\n",
        "    _ = evaluate_gsm8k(\n",
        "        model=trainer.model,\n",
        "        tokenizer=tokenizer,\n",
        "        dataset=test_ds,\n",
        "        limit=EVAL_LIMIT,\n",
        "        batch_size=EVAL_BS,\n",
        "        max_new_tokens=MAX_NEW_TOKENS,\n",
        "        name=f\"grpo_step{cur_target}\",\n",
        "        use_greedy=True,\n",
        "    )\n",
        "\n",
        "    # 2) Optional paraphrase agreement (fast batched version)\n",
        "    if DO_PARAPHRASE_EVAL:\n",
        "        _ = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "            model=trainer.model,\n",
        "            tokenizer=tokenizer,\n",
        "            dataset=paraphrase_classes,\n",
        "            limit=PARA_EVAL_LIMIT,\n",
        "            batch_size=PARA_EVAL_CLASS_BS,\n",
        "            max_new_tokens=MAX_NEW_TOKENS,\n",
        "            name=f\"agree_step{cur_target}\",\n",
        "            use_greedy=True,\n",
        "            show_samples=0,\n",
        "        )\n",
        "\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "SD17xtwHtI6d",
      "metadata": {
        "id": "SD17xtwHtI6d"
      },
      "source": [
        "## Load/Generated Answer-Level Partitions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "V5ZfIEeatNHc",
      "metadata": {
        "id": "V5ZfIEeatNHc"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from typing import List\n",
        "from collections import defaultdict\n",
        "\n",
        "# Assume something like:\n",
        "# from sentence_transformers import util as sbert_util\n",
        "# and 'embedder' is a global SentenceTransformer\n",
        "def cluster_embeddings_online(\n",
        "    embs: torch.Tensor,           # [K, D]\n",
        "    sim_threshold: float = 0.95,\n",
        ") -> List[int]:\n",
        "    \"\"\"\n",
        "    Simple online clustering:\n",
        "      - Start first cluster with embs[0]\n",
        "      - For each next embedding:\n",
        "            if cos_sim_to_any_center >= threshold -> join that cluster\n",
        "            else create a new cluster\n",
        "    Returns a list cluster_ids[k] in {0, 1, ...}\n",
        "    \"\"\"\n",
        "    K, D = embs.shape\n",
        "    centers = []          # list of [D] tensors\n",
        "    counts = []           # num points in each cluster\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k : k + 1]          # [1, D] for util.cos_sim\n",
        "\n",
        "        best_cid = None\n",
        "        best_sim = -1.0\n",
        "\n",
        "        for cid, center in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e, center.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # join existing cluster\n",
        "            cluster_ids.append(best_cid)\n",
        "            # update running mean center\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e.squeeze(0)) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # start new cluster\n",
        "            cid = len(centers)\n",
        "            centers.append(e.squeeze(0))\n",
        "            counts.append(1)\n",
        "            cluster_ids.append(cid)\n",
        "\n",
        "    return cluster_ids\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def cluster_answers_for_question(\n",
        "    answers: List[str],\n",
        "    *,\n",
        "    question: str,\n",
        "    embedder,\n",
        "    sim_threshold: float = 0.95,\n",
        "):\n",
        "    \"\"\"\n",
        "    Returns:\n",
        "      cluster_ids: [K] int\n",
        "      centers:     [C, D] tensor of cluster centers\n",
        "    \"\"\"\n",
        "    texts = [f\"Q: {question}\\nA: {a}\" for a in answers]\n",
        "    embs = embedder.encode(texts, convert_to_tensor=True, show_progress_bar=False)  # [K, D]\n",
        "\n",
        "    K, D = embs.shape\n",
        "    centers = []\n",
        "    counts = []\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k]  # [D]\n",
        "        best_cid = None\n",
        "        best_sim = -1.0\n",
        "\n",
        "        for cid, c in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e.unsqueeze(0), c.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # join existing cluster\n",
        "            cluster_ids.append(best_cid)\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # create new cluster\n",
        "            cid = len(centers)\n",
        "            centers.append(e)\n",
        "            counts.append(1)\n",
        "            cluster_ids.append(cid)\n",
        "\n",
        "    centers_tensor = torch.stack(centers, dim=0)  # [C, D]\n",
        "    return cluster_ids, centers_tensor\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2izcDe6yyIxP",
      "metadata": {
        "id": "2izcDe6yyIxP"
      },
      "outputs": [],
      "source": [
        "import json, time, hashlib, random, pathlib\n",
        "from collections import Counter\n",
        "from typing import List, Dict, Any\n",
        "\n",
        "\n",
        "\n",
        "COH_CACHE_PATH = f\"cache/triviaqa_cluster.jsonl\"\n",
        "COH_CACHE_PATH = os.path.join(COHERENCE_DATA_PATH, COH_CACHE_PATH)\n",
        "COH_CACHE_PATH\n",
        "\n",
        "def _coh_meta_hash(meta: dict) -> str:\n",
        "    key = json.dumps(meta, sort_keys=True).encode(\"utf-8\")\n",
        "    return hashlib.sha1(key).hexdigest()[:10]\n",
        "\n",
        "\n",
        "def _coh_meta_compatible(current: dict, cached: dict) -> bool:\n",
        "    keys = [\n",
        "        \"train_size\",\n",
        "        \"n_questions\",\n",
        "        \"samples_per_x\",\n",
        "        \"baseline_model_id\",\n",
        "        \"clustering_version\",\n",
        "        \"seed\",\n",
        "    ]\n",
        "    return all(current.get(k) == cached.get(k) for k in keys)\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def _sample_baseline_answers_for_x(\n",
        "    x: str,\n",
        "    *,\n",
        "    model,\n",
        "    tokenizer,\n",
        "    K: int,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    ys = answer_batch_chat(\n",
        "        [x] * K,\n",
        "        model=model,\n",
        "        tokenizer=tokenizer,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "    )\n",
        "    return ys\n",
        "\n",
        "\n",
        "def build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    *,\n",
        "    baseline_model,\n",
        "    tokenizer,\n",
        "    embedder,\n",
        "    paraphrase_classes_all = paraphrase_classes_all,              # <--- NEW: list per train_idx\n",
        "    cache_path: str = COH_CACHE_PATH,\n",
        "    n_questions: int = 600,\n",
        "    samples_per_x: int = 32,\n",
        "    sim_threshold: float = 0.95,\n",
        "    seed: int = 0,\n",
        ") -> List[Dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Coherence via paraphrases:\n",
        "      For each chosen train_idx:\n",
        "        - Get original question q0 and paraphrases q1..qM-1 from paraphrase_classes_all[train_idx]\n",
        "        - Generate 'samples_per_x' answers for EACH of these questions.\n",
        "          Total rollouts = samples_per_x * M.\n",
        "        - Cluster ALL answers together into semantic classes z.\n",
        "        - nu0(z | x): empirical distribution over clusters using only answers to q0.\n",
        "        - nu_coh(z | x): empirical distribution over clusters using all answers (q0 + paraphrases).\n",
        "        - ratio(z) = nu_coh(z) / nu0(z) for z with nu0(z) > 0.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(cache_path)\n",
        "    meta_current = {\n",
        "        \"train_size\": len(train_ds),\n",
        "        \"n_questions\": n_questions,\n",
        "        \"samples_per_x\": samples_per_x,\n",
        "        \"baseline_model_id\": getattr(baseline_model, \"name_or_path\", \"unknown\"),\n",
        "        \"clustering_version\": f\"cluster_sim_{sim_threshold}_with_paraphrases\",\n",
        "        \"seed\": seed,\n",
        "    }\n",
        "\n",
        "    # ---------- Try loading existing cache ----------\n",
        "    if path.exists():\n",
        "        records: List[Dict[str, Any]] = []\n",
        "        cached_meta = None\n",
        "        with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "            for line in f:\n",
        "                obj = json.loads(line)\n",
        "                if \"__meta__\" in obj:\n",
        "                    cached_meta = obj[\"__meta__\"]\n",
        "                else:\n",
        "                    records.append(obj)\n",
        "\n",
        "        if cached_meta is not None and _coh_meta_compatible(meta_current, cached_meta):\n",
        "            print(f\"[coh] Loaded clustered coherence set from {path}\")\n",
        "            print(f\"[coh] {len(records)} questions, \"\n",
        "                  f\"{cached_meta['samples_per_x']} samples each (per question variant)\")\n",
        "            return records\n",
        "        else:\n",
        "            print(\"[coh] Cache exists but metadata mismatch; rebuilding.\")\n",
        "\n",
        "    # ---------- Need to build a new coherence set ----------\n",
        "    print(\"[coh] Building new clustered coherence set (with paraphrases)...\")\n",
        "    path.parent.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    rng = random.Random(seed)\n",
        "    all_idxs = list(range(len(train_ds)))\n",
        "    chosen_idxs = rng.sample(all_idxs, n_questions)\n",
        "\n",
        "    records: List[Dict[str, Any]] = []\n",
        "\n",
        "    for i, train_idx in enumerate(chosen_idxs):\n",
        "        ex = train_ds[train_idx]\n",
        "\n",
        "        # 1) Original question + paraphrases\n",
        "        q_variants = paraphrase_classes_all[train_idx]\n",
        "        if not q_variants:\n",
        "            # Fallback: if somehow no paraphrases exist, just use the original question\n",
        "            q0 = ex[\"question\"]\n",
        "            q_variants = [q0]\n",
        "        else:\n",
        "            # Ensure index 0 is the original as you described\n",
        "            q0 = q_variants[0]\n",
        "\n",
        "        # 2) Generate answers for each question variant\n",
        "        all_answers: List[str] = []\n",
        "        answer_src: List[int] = []  # which variant index produced each answer (0 = original)\n",
        "\n",
        "        for j, q_text in enumerate(q_variants):\n",
        "            ys = _sample_baseline_answers_for_x(\n",
        "                q_text,\n",
        "                model=baseline_model,\n",
        "                tokenizer=tokenizer,\n",
        "                K=samples_per_x,\n",
        "            )\n",
        "            all_answers.extend(ys)\n",
        "            answer_src.extend([j] * samples_per_x)\n",
        "\n",
        "        # 3) CLUSTERING-BASED SEMANTICS over ALL answers\n",
        "        #    (cluster IDs shared between original+paraphrases)\n",
        "        cluster_ids, centers = cluster_answers_for_question(\n",
        "            all_answers,\n",
        "            question=q0,             # cluster function can ignore this if it's answer-only\n",
        "            embedder=embedder,\n",
        "            sim_threshold=sim_threshold,\n",
        "        )\n",
        "\n",
        "        z_all = [f\"c{cid}\" for cid in cluster_ids]  # semantics E(y, x) for all rollouts\n",
        "\n",
        "        # 3a) nu_coh(z | x): empirical distribution over ALL rollouts (original + paraphrases)\n",
        "        counts_all = Counter(z_all)\n",
        "        total_all = float(len(z_all))\n",
        "        nu_coh = {z: c / total_all for z, c in counts_all.items()}\n",
        "\n",
        "        # 3b) nu0(z | x): distribution using ONLY answers to the original question (variant index 0)\n",
        "        z_orig = [z for z, src in zip(z_all, answer_src) if src == 0]\n",
        "        counts_orig = Counter(z_orig)\n",
        "        total_orig = float(len(z_orig))  # should be samples_per_x\n",
        "        nu0 = {z: c / total_orig for z, c in counts_orig.items()}\n",
        "\n",
        "        # 3c) ratio(z) = nu_coh(z) / nu0(z) where nu0(z) > 0\n",
        "        ratio = {}\n",
        "        for z, p0 in nu0.items():\n",
        "            if p0 > 0.0:\n",
        "                p_coh = nu_coh.get(z, 0.0)\n",
        "                ratio[z] = p_coh / p0\n",
        "\n",
        "        rec = {\n",
        "            \"train_idx\": train_idx,\n",
        "            \"question\": q0,                # original question\n",
        "            \"questions\": q_variants,       # original + paraphrases\n",
        "            \"answers\": all_answers,        # flattened, length = samples_per_x * len(q_variants)\n",
        "            \"answer_src\": answer_src,      # variant index for each answer\n",
        "            \"z_list\": z_all,               # cluster ID per answer\n",
        "            \"nu0\": nu0,                    # distribution from original-only rollouts\n",
        "            \"nu_coh\": nu_coh,              # distribution from all (original + paraphrases)\n",
        "            \"ratio\": ratio,                # nu_coh(z) / nu0(z) for z with nu0(z) > 0\n",
        "            \"centers\": centers.cpu().tolist(),\n",
        "        }\n",
        "        records.append(rec)\n",
        "\n",
        "        if (i + 1) % 50 == 0:\n",
        "            print(f\"[coh] processed {i+1}/{n_questions} base questions\")\n",
        "\n",
        "    # ---------- Save cache ----------\n",
        "    meta_to_save = dict(meta_current)\n",
        "    meta_to_save[\"created_utc\"] = int(time.time())\n",
        "    meta_to_save[\"hash\"] = _coh_meta_hash(meta_to_save)\n",
        "\n",
        "    # Uncomment this block when you're ready to persist to disk\n",
        "\n",
        "    tmp = path.with_suffix(path.suffix + \".tmp\")\n",
        "    with tmp.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta_to_save}) + \"\\n\")\n",
        "        for rec in records:\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    tmp.replace(path)\n",
        "\n",
        "    print(f\"[coh] Saved clustered coherence set to {path}\")\n",
        "\n",
        "\n",
        "    return records\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "DrkEk66cDZ4i",
      "metadata": {
        "id": "DrkEk66cDZ4i"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from typing import List, Union, Iterable, Optional\n",
        "\n",
        "class LMEmbedder:\n",
        "    \"\"\"\n",
        "    Generic wrapper to use any HF LM (Qwen, LLaMA, Phi-3, etc.)\n",
        "    as a sentence embedder with a SentenceTransformer-like interface.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model,\n",
        "        tokenizer,\n",
        "        max_length: int = 128,\n",
        "        normalize: bool = True,\n",
        "        device: Optional[Union[str, torch.device]] = None,\n",
        "        batch_size: int = 32,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model: HF model (e.g. AutoModelForCausalLM / LlamaForCausalLM / Phi3ForCausalLM, possibly PEFT-wrapped)\n",
        "            tokenizer: matching tokenizer\n",
        "            max_length: truncation length for inputs\n",
        "            normalize: whether to L2-normalize embeddings\n",
        "            device: override device; if None, use model.device\n",
        "            batch_size: batching size for .encode()\n",
        "        \"\"\"\n",
        "        self.model = model\n",
        "        self.tokenizer = tokenizer\n",
        "        self.max_length = max_length\n",
        "        self.normalize = normalize\n",
        "        self.batch_size = batch_size\n",
        "\n",
        "        # Ensure we can access hidden states\n",
        "        if hasattr(self.model.config, \"output_hidden_states\"):\n",
        "            self.model.config.output_hidden_states = True\n",
        "\n",
        "        self.model.eval()\n",
        "\n",
        "        # Figure out device\n",
        "        if device is None:\n",
        "            # works for most HF models, including PEFT-wrapped\n",
        "            try:\n",
        "                self.device = next(self.model.parameters()).device\n",
        "            except StopIteration:\n",
        "                self.device = torch.device(\"cpu\")\n",
        "        else:\n",
        "            self.device = torch.device(device)\n",
        "\n",
        "        # Make sure we have a pad_token_id\n",
        "        if self.tokenizer.pad_token_id is None:\n",
        "            # common choice: use EOS as pad\n",
        "            if self.tokenizer.eos_token_id is not None:\n",
        "                self.tokenizer.pad_token = self.tokenizer.eos_token\n",
        "                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def encode(\n",
        "        self,\n",
        "        texts: Union[str, List[str], Iterable[str]],\n",
        "        convert_to_tensor: bool = True,\n",
        "        show_progress_bar: bool = False,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        SentenceTransformer-style API.\n",
        "\n",
        "        Args:\n",
        "            texts: single string or list/iterable of strings\n",
        "            convert_to_tensor: if True, returns a torch.Tensor [N, D]\n",
        "                               else returns a list of numpy arrays\n",
        "            show_progress_bar: ignored here, but kept for API compatibility\n",
        "\n",
        "        Returns:\n",
        "            embeddings: [N, D] tensor or list of np.array\n",
        "        \"\"\"\n",
        "        if isinstance(texts, str):\n",
        "            texts = [texts]\n",
        "        else:\n",
        "            texts = list(texts)\n",
        "\n",
        "        all_embs = []\n",
        "        bs = self.batch_size\n",
        "\n",
        "        for i in range(0, len(texts), bs):\n",
        "            batch_texts = texts[i : i + bs]\n",
        "\n",
        "            enc = self.tokenizer(\n",
        "                batch_texts,\n",
        "                padding=True,\n",
        "                truncation=True,\n",
        "                max_length=self.max_length,\n",
        "                return_tensors=\"pt\",\n",
        "            )\n",
        "            input_ids = enc[\"input_ids\"].to(self.device)\n",
        "            attention_mask = enc[\"attention_mask\"].to(self.device)\n",
        "\n",
        "            out = self.model(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                output_hidden_states=True,\n",
        "            )\n",
        "\n",
        "            # last hidden state: [B, T, D]\n",
        "            # some models return .last_hidden_state, some in hidden_states[-1]\n",
        "            if hasattr(out, \"hidden_states\") and out.hidden_states is not None:\n",
        "                hidden = out.hidden_states[-1]\n",
        "            else:\n",
        "                # fallback to last_hidden_state if hidden_states is not present\n",
        "                hidden = out.last_hidden_state\n",
        "\n",
        "            # mean-pool over non-pad tokens\n",
        "            mask = attention_mask.unsqueeze(-1)  # [B, T, 1]\n",
        "            masked_hidden = hidden * mask\n",
        "            lengths = mask.sum(dim=1).clamp(min=1)  # [B, 1]\n",
        "            emb = masked_hidden.sum(dim=1) / lengths  # [B, D]\n",
        "\n",
        "            if self.normalize:\n",
        "                emb = F.normalize(emb, p=2, dim=1)\n",
        "\n",
        "            all_embs.append(emb.cpu())\n",
        "\n",
        "        embs = torch.cat(all_embs, dim=0)  # [N, D]\n",
        "\n",
        "        if convert_to_tensor:\n",
        "            return embs\n",
        "        else:\n",
        "            return [e.numpy() for e in embs]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Zp-xhGcSyXRq",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 404
        },
        "id": "Zp-xhGcSyXRq",
        "outputId": "503eaed4-b348-42d1-ce84-873811d4939b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "28501ff67bde42a6a9410c55c8977bb9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "994a0db0c63545118bce1b33958c1614",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cc3f6cd9667a49edafbab2449e02b4e5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0d5d447849144930962682f67ed4b1cb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4288dfb5897345768b2378e8d159f539",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f1e6da22975748af9609388a2c5d5798",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ce24fe0613244bc6887926753798d897",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a593af02a7cf489f98067e067d0c5a38",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.txt: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e3507cdb84464004b6add70103ccfe09",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c05e536a07e44073895cc90f77f7655d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a6837fefcfe9491787b137373a41995e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[coh] Loaded clustered coherence set from /content/drive/MyDrive/coherence-answer/Qwen/Qwen2.5-3B-Instruct/instruct/cache/triviaqa_cluster.jsonl\n",
            "[coh] 6400 questions, 16 samples each (per question variant)\n"
          ]
        }
      ],
      "source": [
        "from sentence_transformers import SentenceTransformer\n",
        "\n",
        "embedder = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n",
        "\n",
        "# embedder = LMEmbedder(\n",
        "#     model=baseline_model,\n",
        "#     tokenizer=tokenizer,\n",
        "#     max_length=128,\n",
        "#     normalize=True,\n",
        "#     batch_size=32,\n",
        "# )\n",
        "\n",
        "coh_records = build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    baseline_model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    embedder = embedder,\n",
        "    n_questions=32 * 200,\n",
        "    samples_per_x=16,\n",
        "    cache_path=COH_CACHE_PATH,\n",
        "    seed = 42\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ikGAzn8nFtkS",
      "metadata": {
        "id": "ikGAzn8nFtkS"
      },
      "outputs": [],
      "source": [
        "\n",
        "import json\n",
        "\n",
        "for rec in coh_records[:1]:\n",
        "    temp_rec = {k: v for k, v in rec.items() if k != 'centers'}\n",
        "    print(json.dumps(temp_rec, indent=2))\n",
        "\n",
        "    # Fetch and print ground truth\n",
        "    question_idx = rec[\"train_idx\"]\n",
        "    ground_truth_obj = train_ds[question_idx][\"answer\"]\n",
        "    print(f\"  Ground Truth Answers (train_idx={question_idx}): {ground_truth_obj}\")\n",
        "    print(\"\\n\" + \"=\"*80 + \"\\n\") # Separator for readability\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "aJqwV-W114QE",
      "metadata": {
        "id": "aJqwV-W114QE"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def E_cluster_for_known_x(\n",
        "    y: str,\n",
        "    x: str,\n",
        "    rec: dict,        # one coherence record for this x\n",
        "    embedder,\n",
        "    min_sim: float = 0.0,  # or e.g. 0.7 if you want a floor\n",
        ") -> str:\n",
        "    \"\"\"\n",
        "    Map a new completion y for a *known* question x\n",
        "    onto one of the existing clusters in rec.\n",
        "    Returns a cluster label like 'c0', 'c1', ...\n",
        "    \"\"\"\n",
        "    # 1) embed the (x, y) pair the same way as during clustering\n",
        "    text = f\"Q: {x}\\nA: {y}\"\n",
        "    emb = embedder.encode([text], convert_to_tensor=True, show_progress_bar=False)[0]  # [D]\n",
        "\n",
        "    # 2) load centers\n",
        "    centers = torch.tensor(rec[\"centers\"], dtype=emb.dtype, device=emb.device)  # [C, D]\n",
        "\n",
        "    # 3) cosine similarity to each center\n",
        "    sims = torch.nn.functional.cosine_similarity(\n",
        "        emb.unsqueeze(0), centers, dim=1\n",
        "    )  # [C]\n",
        "\n",
        "    best_cid = int(torch.argmax(sims).item())\n",
        "    best_sim = float(sims[best_cid].item())\n",
        "\n",
        "    if best_sim < min_sim:\n",
        "        # Optional: treat as 'other' cluster; depends on your design.\n",
        "        # For many WSFT setups it’s simpler to just force assignment.\n",
        "        pass\n",
        "\n",
        "    return f\"c{best_cid}\"\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0jXMFGTV1607",
      "metadata": {
        "id": "0jXMFGTV1607"
      },
      "outputs": [],
      "source": [
        "# Build a lookup once after loading coh_records\n",
        "coh_by_idx = {rec[\"train_idx\"]: rec for rec in coh_records}\n",
        "\n",
        "def E_for_training_example(train_idx: int, x: str, y: str) -> str:\n",
        "    rec = coh_by_idx[train_idx]\n",
        "    return E_cluster_for_known_x(y, x, rec, embedder)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5_3wH5NkftWo",
      "metadata": {
        "id": "5_3wH5NkftWo"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Iterable, Tuple, Optional\n",
        "\n",
        "\n",
        "from dataclasses import dataclass\n",
        "\n",
        "@dataclass\n",
        "class WSFTExample:\n",
        "    x: str\n",
        "    y: str\n",
        "    w: float\n",
        "    train_idx: int\n",
        "    sample_idx: int   # 0..K-1 within that question\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def make_wsft_examples_from_coh_records(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    # By default, \"coh\" means use x0 and x1 only (original + 1 paraphrase).\n",
        "    # If you want to include more paraphrases, pass e.g. use_variants=None to use all.\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Build WSFT training tuples (x, y, w) from coherence cluster records.\n",
        "\n",
        "    E(y) is implemented by using the stored cluster id z in rec[\"z_list\"].\n",
        "\n",
        "    v0(z|x): empirical distribution of cluster labels among answers generated for *that* x (variant-specific).\n",
        "    v_coh(z|x): empirical distribution of cluster labels among answers generated for x0 and x1 (by default).\n",
        "\n",
        "    w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x)\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]          # [x0, x1, ...]\n",
        "        answers: List[str] = rec[\"answers\"]              # flattened\n",
        "        answer_src: List[int] = rec[\"answer_src\"]        # which x_j produced each answer\n",
        "        z_list: List[str] = rec[\"z_list\"]                # E(y): cluster id per answer\n",
        "\n",
        "        if not questions or not answers:\n",
        "            continue\n",
        "\n",
        "        # Decide which variants define \"coherence\" set (default: only x0 and x1).\n",
        "        if use_variants is None:\n",
        "            coh_vars: Iterable[int] = range(len(questions))\n",
        "        else:\n",
        "            coh_vars = tuple(v for v in use_variants if v < len(questions))\n",
        "            if len(coh_vars) == 0:\n",
        "                coh_vars = (0,)  # fallback\n",
        "\n",
        "        coh_var_set = set(coh_vars)\n",
        "\n",
        "        # ---- v_coh over answers from selected variants (default: x0+x1) ----\n",
        "        coh_z = [z for z, src in zip(z_list, answer_src) if src in coh_var_set]\n",
        "        if len(coh_z) == 0:\n",
        "            continue\n",
        "        coh_counts = Counter(coh_z)\n",
        "        coh_total = float(len(coh_z))\n",
        "        v_coh = {z: c / coh_total for z, c in coh_counts.items()}\n",
        "\n",
        "        # ---- v0 per variant x_j (only for variants we include) ----\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j in coh_var_set:\n",
        "            z_j = [z for z, src in zip(z_list, answer_src) if src == j]\n",
        "            if len(z_j) == 0:\n",
        "                continue\n",
        "            cnt = Counter(z_j)\n",
        "            tot = float(len(z_j))\n",
        "            v0_by_var[j] = {z: c / tot for z, c in cnt.items()}\n",
        "\n",
        "        # ---- build WSFT examples ----\n",
        "        per_var_k = defaultdict(int)  # sample_idx within that x_j\n",
        "\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            if src not in coh_var_set:\n",
        "                continue\n",
        "            if src not in v0_by_var:\n",
        "                continue  # no stats for this variant\n",
        "\n",
        "            v0_z = max(v0_by_var[src].get(z, 0.0), eps)\n",
        "            vcoh_z = v_coh.get(z, 0.0)\n",
        "\n",
        "            w = (1.0 - alpha) + alpha * (vcoh_z / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            wsft_examples.append(\n",
        "                WSFTExample(\n",
        "                    x=questions[src],           # x is either x0 or x1 (or more if you allow)\n",
        "                    y=y,\n",
        "                    w=float(w),\n",
        "                    train_idx=train_idx,\n",
        "                    sample_idx=int(per_var_k[src]),\n",
        "                )\n",
        "            )\n",
        "            per_var_k[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_hU9FbRClOjZ",
      "metadata": {
        "id": "_hU9FbRClOjZ"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Optional, Tuple, Iterable\n",
        "\n",
        "def make_wsft_examples_from_coh_records_majority(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Build WSFT training tuples (x, y, w) from coherence cluster records.\n",
        "\n",
        "    E(y) is implemented by using the stored cluster id z in rec[\"z_list\"].\n",
        "\n",
        "    v0(z|x): empirical distribution of cluster labels among answers generated for *that* x (variant-specific).\n",
        "    v_coh(z|x): MAJORITY VOTE over cluster labels among answers from the selected variants\n",
        "               (default: variants 0 and 1). Implemented as a degenerate distribution:\n",
        "                 v_coh(z|x) = 1 if z == z*, else 0,\n",
        "               where z* is the most frequent cluster among those answers.\n",
        "\n",
        "    w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x)\n",
        "      = (1-alpha) + alpha * (1 / v0(z|x)) if z==z*, else (1-alpha)\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    def _z_numeric(z: str) -> int:\n",
        "        # deterministic tie-break (prefer smaller cluster id)\n",
        "        try:\n",
        "            if isinstance(z, str) and z.startswith(\"c\"):\n",
        "                return int(z[1:])\n",
        "            return int(z)\n",
        "        except Exception:\n",
        "            return 0\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]\n",
        "        answers: List[str] = rec[\"answers\"]\n",
        "        answer_src: List[int] = rec[\"answer_src\"]\n",
        "        z_list: List[str] = rec[\"z_list\"]\n",
        "\n",
        "        if not questions or not answers:\n",
        "            continue\n",
        "\n",
        "        # Decide which variants define the \"coherence\" majority vote set.\n",
        "        if use_variants is None:\n",
        "            coh_vars: Iterable[int] = range(len(questions))\n",
        "        else:\n",
        "            coh_vars = tuple(v for v in use_variants if v < len(questions))\n",
        "            if len(coh_vars) == 0:\n",
        "                coh_vars = (0,)\n",
        "\n",
        "        coh_var_set = set(coh_vars)\n",
        "\n",
        "        # ---- majority vote z* over answers from selected variants ----\n",
        "        coh_z = [z for z, src in zip(z_list, answer_src) if src in coh_var_set]\n",
        "        if len(coh_z) == 0:\n",
        "            continue\n",
        "\n",
        "        coh_counts = Counter(coh_z)\n",
        "        # pick max count; tie-break by smaller numeric cluster id for determinism\n",
        "        z_star = max(coh_counts.items(), key=lambda kv: (kv[1], -_z_numeric(kv[0])))[0]\n",
        "\n",
        "        # ---- v0 per variant x_j ----\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j in coh_var_set:\n",
        "            z_j = [z for z, src in zip(z_list, answer_src) if src == j]\n",
        "            if len(z_j) == 0:\n",
        "                continue\n",
        "            cnt = Counter(z_j)\n",
        "            tot = float(len(z_j))\n",
        "            v0_by_var[j] = {z: c / tot for z, c in cnt.items()}\n",
        "\n",
        "        # ---- build WSFT examples ----\n",
        "        per_var_k = defaultdict(int)\n",
        "\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            if src not in coh_var_set:\n",
        "                continue\n",
        "            if src not in v0_by_var:\n",
        "                continue\n",
        "\n",
        "            v0_z = max(v0_by_var[src].get(z, 0.0), eps)\n",
        "            vcoh_z = 1.0 if z == z_star else 0.0  # majority vote distribution\n",
        "\n",
        "            w = (1.0 - alpha) + alpha * (vcoh_z / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            wsft_examples.append(\n",
        "                WSFTExample(\n",
        "                    x=questions[src],\n",
        "                    y=y,\n",
        "                    w=float(w),\n",
        "                    train_idx=train_idx,\n",
        "                    sample_idx=int(per_var_k[src]),\n",
        "                )\n",
        "            )\n",
        "            per_var_k[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4P4XzkJNf6oG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4P4XzkJNf6oG",
        "outputId": "59e001c4-6903-49ff-80dd-9c05dc3c958b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[wsft] Built 204800 WSFT examples (alpha=0.5, use_variants=(0, 1))\n"
          ]
        }
      ],
      "source": [
        "alpha = 0.5\n",
        "wsft_examples = make_wsft_examples_from_coh_records_majority(coh_records, alpha=alpha)  # uses x0+x1 only\n",
        "# or include all paraphrases in v_coh:\n",
        "# wsft_examples_all = make_wsft_examples_from_coh_records(coh_records, alpha=alpha, use_variants=None)\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Q3RRjK20mNjF",
        "IviYIdL0mNjG",
        "Ul00h3tomNjH",
        "ChcYKgQMmNjI",
        "SD17xtwHtI6d"
      ],
      "gpuType": "A100",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "004da1ed97a44382bde90d6c5c244a2f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_17528517e5d54f2cb86c90206b1c9f57",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ddabfacb1370443b9e38726aad49bbb3",
            "value": 21
          }
        },
        "01806afad25045508bd4eb9cc25fd085": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "018572d4a0684e3e8a338a0e9ef942de": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0224a2b6fc59439e9771f36f7f30176a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0559847792ee4f358dff48a64861ceff": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "07a3b12d03cd4371a9698cbfe8173cfe": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_172069fe12fd4550ba24ee98b0f86343",
            "placeholder": "​",
            "style": "IPY_MODEL_272ada2d7569455ba07407c4471146cc",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "08634ce86bd240e78a75864098528d88": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0a863dbca2044fbaad16705cd419fdb9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0d163a21883f448dbad9464fb66467e4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0f45cb9c7061417bafe13fe40f54b663": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "10823f8c532b4bb69c49bb50a752b4dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_641de6da851248b6b01b712f3020e5bd",
            "placeholder": "​",
            "style": "IPY_MODEL_10930da0ec404970a5a815f0fdeb3a92",
            "value": " 21/21 [30:08&lt;00:00, 86.36s/it, N=1319, ACC%=79.61, greedy=1]"
          }
        },
        "10930da0ec404970a5a815f0fdeb3a92": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "13962aacdf40481e82271020cfee3ea4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d2bdcf3ef69b4d078d0f7b55bbe14422",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f64dddc0ccdc49f6b94e191c0107f510",
            "value": 21
          }
        },
        "157c4560897f4de3959667ae8483508e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "157ce59e61c241ca97f75fcf0da58a6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "1596780773c64a6097256a02988375b7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "16faf06fe337402cb1d11e7ae06df0bd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bc684009a1ac42d3bc55815b41f56d88",
            "placeholder": "​",
            "style": "IPY_MODEL_fbe0e0a348ad49c5be519e189b2e7b5c",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "172069fe12fd4550ba24ee98b0f86343": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "17528517e5d54f2cb86c90206b1c9f57": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "18717620d990478b9e58ce52fe77dcdc": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "18ec27986e1d497695c66dd1cd4142a7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "19bb218724e741248d79f0d5cbeee526": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1a4508b1ac9c47919bb1565d636fa3a8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1b329a73adf94f2ea7675ae99cd01fef": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1eac4a3feaaf4a7eb20cae82b5a78ba4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_c2ce751a1ad44d4d9e26cc7b5dd5ddfc",
              "IPY_MODEL_3acfe564bc134c2d9ae871bb568b203f",
              "IPY_MODEL_8e00c17408e44c95939b06e5e9954441"
            ],
            "layout": "IPY_MODEL_27b7a1b99aa142999397e2437123f27e"
          }
        },
        "1f6bfba559be428480eb83c73e78315a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "22a1c8806adf4710a02d5694485f42c1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "239a3be7a04444cd8ff61cc41ce6063e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "25ff08bb4bbb4c31a9d9c5b5f08bb508": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4e7998fb8f6345a681a868445e9aff4c",
            "placeholder": "​",
            "style": "IPY_MODEL_2db9147dcbd340cabcb514a7a023cdc4",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "26932873ef3e4e439700e62540a54a98": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2718473403834fee877163873d336309": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "272ada2d7569455ba07407c4471146cc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "27b7a1b99aa142999397e2437123f27e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "282d3372db4f40589438b08fd38c8b58": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8b1dd16d783d4c31a9e2a41386c72a1b",
            "placeholder": "​",
            "style": "IPY_MODEL_b3c7ade6ac214c12b7be4427a4119c72",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "2888c9cb30564fa893a7b043ae4f90db": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_282d3372db4f40589438b08fd38c8b58",
              "IPY_MODEL_30f0fcc7ae7d493cb2b69232e6456fba",
              "IPY_MODEL_8ec215d4d667480996ed8da053e945cc"
            ],
            "layout": "IPY_MODEL_2718473403834fee877163873d336309"
          }
        },
        "2a617d3a24c74fa1a092123d7425ba4c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2b35f4a7911e425ba562c2435913f72d": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2b9d943d33e445d7aa16e31bbd148f7c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "2bbbedc99ea347e59b70a59a30186297": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2c22e8e41b804702821dac54e8572221": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2c400043327a4fa7b10b5a8f527df0cd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_50c82dcbe2f14801b3cbdf412f91b094",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_eab965c99d6d445db3a182b128a0fce5",
            "value": 0
          }
        },
        "2ce513bcfcd7429d9fe6872d177bf38e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2d055629ae454fde9581e1a14e25c117": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2db9147dcbd340cabcb514a7a023cdc4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2f42c28f32794f19b3134e6e5dd79a25": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d1e977c54c0f48bf9bc787436da92f52",
            "placeholder": "​",
            "style": "IPY_MODEL_b0ac144780df465ab6db88d2d57e5174",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "3044700dd4774e2aba7aed8f70a39888": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3058e0053b9d4278b0aaff3d42a15ef9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "30f0fcc7ae7d493cb2b69232e6456fba": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f2bae048b4034efea7db8b2f6fb4f526",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2b9d943d33e445d7aa16e31bbd148f7c",
            "value": 21
          }
        },
        "322f7a07b1de4532b31980509f132743": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "32a472a8ac6c4c598f2b6426144a960a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6681c8573ffe47479b45516da9d97175",
            "placeholder": "​",
            "style": "IPY_MODEL_8618259c164645a8b84cd465b6bfd3f8",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "3345258ab7fd4dc085c81ca3eacb35dd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "346936dae8d0460b88f574168fea2209": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "35c453987de54f8491bec86a0fe062dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1f6bfba559be428480eb83c73e78315a",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4783394409ba45e28e6872fe74c7b7da",
            "value": 21
          }
        },
        "3749db91fd034b37b2f44330a017ae54": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "390bc2be57a343198a673543d41539d0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f00948e6be484b6b84e2496c32227eef",
            "placeholder": "​",
            "style": "IPY_MODEL_1b329a73adf94f2ea7675ae99cd01fef",
            "value": " 21/21 [28:12&lt;00:00, 79.46s/it, N=1319, ACC%=79.45, greedy=1]"
          }
        },
        "39b9ffa02eb84f85a76156cbc1c8a88b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_346936dae8d0460b88f574168fea2209",
            "placeholder": "​",
            "style": "IPY_MODEL_a7531cf6b513470da02333b1420e5c9c",
            "value": "Paraphrase agreement qwen3b: 100%"
          }
        },
        "3a3884b6b2c045c2ae33bba4c0c462b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3345258ab7fd4dc085c81ca3eacb35dd",
            "placeholder": "​",
            "style": "IPY_MODEL_441473c91daf4aeea3a91dc460512b2a",
            "value": " 7473/7473 [2:12:43&lt;00:00,  1.12s/it, classes=7473, AGREE%=67.14, greedy=1]"
          }
        },
        "3a65e1adf5a44fd0986e3011e33bf64c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ea78979948144c6e800e16d0ac934bf5",
            "placeholder": "​",
            "style": "IPY_MODEL_ceb55c622676492d81007c4de64ddb20",
            "value": " 21/21 [28:34&lt;00:00, 75.93s/it, N=1319, ACC%=79.83, greedy=1]"
          }
        },
        "3acfe564bc134c2d9ae871bb568b203f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6fd2c60356224d9c841a1fff4feaf6a4",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_fc4d8b814f454a7c8515dfc3f4f957d0",
            "value": 21
          }
        },
        "3bbaeb287b5f4086b7c3e1474a34972c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "3e4094c0ed3041afa63d26fc05b6f17d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_aa4ba36a2be8423bb9c23cd1834b3c15",
              "IPY_MODEL_2c400043327a4fa7b10b5a8f527df0cd",
              "IPY_MODEL_e4a308877b7d4fb283f15b6613814519"
            ],
            "layout": "IPY_MODEL_acf8d4bc4e4b4cf285ea39855fd6c75f"
          }
        },
        "3e446973a3dc4785ae37c485753cd964": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "3fafd0243abd4a528e686e5c5d8418bd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "4343230de097491ebd98919615ca7c1f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "43eb6d953311464391ee454aa2c3e279": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_fb3ae36336974394913c44c99e5de061",
              "IPY_MODEL_35c453987de54f8491bec86a0fe062dc",
              "IPY_MODEL_b8afce26ae8b4305bff9ed6688f3ddfe"
            ],
            "layout": "IPY_MODEL_8ada8281b3ca4a2bbfc9bab988ec3df4"
          }
        },
        "441473c91daf4aeea3a91dc460512b2a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "44590ee9685940dabf80cef9de6c36db": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "447f11db2a344b6abcf549f941628737": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_aafbaa511943448cad086cccccf763d3",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3749db91fd034b37b2f44330a017ae54",
            "value": 21
          }
        },
        "4512f7a4b9134f4ab867da0c37642100": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4783394409ba45e28e6872fe74c7b7da": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4af918c43f6543a5bbfd43d2fa231160": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4b6686f1b535452fa6b2aa22a4e6b701": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a4549b715203429f85d3e152e1f6dfde",
            "placeholder": "​",
            "style": "IPY_MODEL_a6bad0fbc40d41199de894615150b714",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "4b6c966e6bcb462f9817cc37eadafc85": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_25ff08bb4bbb4c31a9d9c5b5f08bb508",
              "IPY_MODEL_725029a332d2467eb966e0b88fe16d2c",
              "IPY_MODEL_d25f806fab61436f8f5d3d5a16f5e0f1"
            ],
            "layout": "IPY_MODEL_bdc935dbdb1f4036b46e58fa2d80ddbf"
          }
        },
        "4d9917046a984adf821475a945ce25f9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4e7998fb8f6345a681a868445e9aff4c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5010496fa3ec42b3b4f3e2d088fb38d3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "50c82dcbe2f14801b3cbdf412f91b094": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "51096091a9594b3cbba9e1f06c439f65": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "517c75a17f03426caaa7ba64f106abde": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "53246caf2dc6425498f634664d914162": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "53847574e82447c1ad184ce452e58b5b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "550be0f6d12e4c60886d0d5035ba6b3d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "571d5fe9dd29472eabca54537d244354": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "59169c85f9774adbb8325c9d11ad7b00": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "59f2130e9c564172b17f10efa43f72d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5bc1e955e9c34d7fba8c9de399ae1ac0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5bfa9a9eb18c4c8abcce1f1c9d2d0208": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_639b552ed05248f9b7c72d73c14da9cd",
            "placeholder": "​",
            "style": "IPY_MODEL_695c75b1751944bf8feb55d1026e3e46",
            "value": " 21/21 [29:44&lt;00:00, 81.87s/it, N=1319, ACC%=79.61, greedy=1]"
          }
        },
        "5c0761880e5f4438be927b9ab3c272ad": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5c7202bf9cd04c0a97a42bec080d448b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_08634ce86bd240e78a75864098528d88",
            "placeholder": "​",
            "style": "IPY_MODEL_322f7a07b1de4532b31980509f132743",
            "value": " 21/21 [28:30&lt;00:00, 76.55s/it, N=1319, ACC%=79.15, greedy=1]"
          }
        },
        "5e59bd31893d465e860cfb6b6cc0b3ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2b35f4a7911e425ba562c2435913f72d",
            "placeholder": "​",
            "style": "IPY_MODEL_7976edc1125d4076a10b58cb1978e1db",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "5feae3252dcc45e58c65c5a114980ed6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "639b552ed05248f9b7c72d73c14da9cd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "641de6da851248b6b01b712f3020e5bd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6499bd4e0b3e499981c5e2cf182f1ba0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6681c8573ffe47479b45516da9d97175": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6753f3afa9284919825e60c1f71fd853": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "67849f803a08422fa357bd27a11bf964": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "680210bb7bc64f0fb96b3cd11f955463": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "68f6dd3d216d4a3d8ed8f0a915fccc27": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "695c75b1751944bf8feb55d1026e3e46": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6a2b3e4accb3479bbf46c199927ad393": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a9c3bba89d8b436e85dcf2c81bb35a84",
              "IPY_MODEL_aa45e9b80f3a46c3bc4e549322614387",
              "IPY_MODEL_5c7202bf9cd04c0a97a42bec080d448b"
            ],
            "layout": "IPY_MODEL_59169c85f9774adbb8325c9d11ad7b00"
          }
        },
        "6c7383eab16b4c2385ab1c300b8423ee": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a88a028b745843a7aa4e495d70cfd270",
            "placeholder": "​",
            "style": "IPY_MODEL_0a863dbca2044fbaad16705cd419fdb9",
            "value": " 21/21 [29:45&lt;00:00, 87.47s/it, N=1319, ACC%=77.48, greedy=1]"
          }
        },
        "6d1c8f83e798474da7ff43fa69045b6b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "6ee861348c174dd4af158c10e200ddaa": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2a617d3a24c74fa1a092123d7425ba4c",
            "placeholder": "​",
            "style": "IPY_MODEL_af7f958bdf324b8b874212b929f1b30e",
            "value": " 21/21 [28:39&lt;00:00, 90.68s/it, N=1319, ACC%=79.61, greedy=1]"
          }
        },
        "6fd2c60356224d9c841a1fff4feaf6a4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7028269d58964755babbc4e93c1d482b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e3989b46ac8340e588fdaaaabf9b87e4",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_680210bb7bc64f0fb96b3cd11f955463",
            "value": 21
          }
        },
        "719a9e80b5be4b59b282aa4f189aea5a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_517c75a17f03426caaa7ba64f106abde",
            "placeholder": "​",
            "style": "IPY_MODEL_3058e0053b9d4278b0aaff3d42a15ef9",
            "value": " 21/21 [27:48&lt;00:00, 75.09s/it, N=1319, ACC%=80.82, greedy=1]"
          }
        },
        "725029a332d2467eb966e0b88fe16d2c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0d163a21883f448dbad9464fb66467e4",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5bc1e955e9c34d7fba8c9de399ae1ac0",
            "value": 21
          }
        },
        "74969bf267b248918164a67ddd94551b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f4c3f1cea2ab422bbe6df48e34f87950",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5c0761880e5f4438be927b9ab3c272ad",
            "value": 0
          }
        },
        "7583761f254c4a0fa19ef9c42de93cbd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "75daa99c284d4012bf0df1f87999fbfb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "78a5366783a940afba4940c56349b7fb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7976edc1125d4076a10b58cb1978e1db": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "79abb262b1e442a99b23b327f9547c21": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_5e59bd31893d465e860cfb6b6cc0b3ba",
              "IPY_MODEL_e93f4d72fdab45ce9b21e42e430702b6",
              "IPY_MODEL_e39ab8735bc14a11a8becb6d8b199a03"
            ],
            "layout": "IPY_MODEL_876b7a1f4fd142f7966996549c5e55eb"
          }
        },
        "7b9a9cfb02a74fb087e2eb3e128db7eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_86b165dbf93541adb4a12f655d849c17",
            "placeholder": "​",
            "style": "IPY_MODEL_2d055629ae454fde9581e1a14e25c117",
            "value": " 21/21 [30:05&lt;00:00, 78.32s/it, N=1319, ACC%=80.14, greedy=1]"
          }
        },
        "7c7af70ab0b048428942d2042b6b619e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c1fba98d6fc148398dcdf8379e934d18",
            "placeholder": "​",
            "style": "IPY_MODEL_2c22e8e41b804702821dac54e8572221",
            "value": " 21/21 [30:20&lt;00:00, 87.48s/it, N=1319, ACC%=78.70, greedy=1]"
          }
        },
        "7d11610cf9ee4892b4056495f794b08c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_07a3b12d03cd4371a9698cbfe8173cfe",
              "IPY_MODEL_b4ea0408ecdd4d8ca466755833c4c5c8",
              "IPY_MODEL_b4eb097c85964ca9bfc73c6bf91cbaf9"
            ],
            "layout": "IPY_MODEL_3fafd0243abd4a528e686e5c5d8418bd"
          }
        },
        "7d41fb48052e4afbb6ba066568edc4e0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7f312f978adb474db9ce49c2f3f6153b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e3712025a45545dcb040599a739cedbc",
            "placeholder": "​",
            "style": "IPY_MODEL_53246caf2dc6425498f634664d914162",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "7f64771024e6484d864e2eb0678f4991": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_2f42c28f32794f19b3134e6e5dd79a25",
              "IPY_MODEL_9f4d36fc4f5c43d69b7d1c2cabbdced3",
              "IPY_MODEL_719a9e80b5be4b59b282aa4f189aea5a"
            ],
            "layout": "IPY_MODEL_d3e0ef7490d84c4b9c9f9b8aa2d0bacc"
          }
        },
        "7f8b51a3fe4343f39fa6ddc435410a3a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_da8588237dec4ae9877dc1ebb1236e48",
            "max": 7473,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_59f2130e9c564172b17f10efa43f72d6",
            "value": 7473
          }
        },
        "80261814dc5846e88c72e0e36390c7a8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_16faf06fe337402cb1d11e7ae06df0bd",
              "IPY_MODEL_b96e91dc71a54ec7b641598f13349382",
              "IPY_MODEL_10823f8c532b4bb69c49bb50a752b4dc"
            ],
            "layout": "IPY_MODEL_571d5fe9dd29472eabca54537d244354"
          }
        },
        "8089993e890249a885f48b20d127fd0d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_fffac9209fa24b4ea9728ec45b900bbe",
              "IPY_MODEL_facb9d35f00a493880724c3caef39e0f",
              "IPY_MODEL_3a65e1adf5a44fd0986e3011e33bf64c"
            ],
            "layout": "IPY_MODEL_44590ee9685940dabf80cef9de6c36db"
          }
        },
        "815fce21c79d43aaa335dd620804ec16": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8214948990d845e286f1dad4739a0b5b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "84531fd148ad4bcfac95d43da9913159": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8618259c164645a8b84cd465b6bfd3f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "86b165dbf93541adb4a12f655d849c17": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "871a90e79fc04112bbb72d54d8ae8250": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "876b7a1f4fd142f7966996549c5e55eb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "8ada8281b3ca4a2bbfc9bab988ec3df4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "8b1dd16d783d4c31a9e2a41386c72a1b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8e00c17408e44c95939b06e5e9954441": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9383fc235dee4ee19bd47c2e0df567a6",
            "placeholder": "​",
            "style": "IPY_MODEL_9dc8c60675634ebda305d807cc36e97e",
            "value": " 21/21 [32:26&lt;00:00, 82.95s/it, N=1319, ACC%=75.13, greedy=1]"
          }
        },
        "8ec215d4d667480996ed8da053e945cc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c67785ae392f4507acf38816eab3e5a9",
            "placeholder": "​",
            "style": "IPY_MODEL_6499bd4e0b3e499981c5e2cf182f1ba0",
            "value": " 21/21 [29:58&lt;00:00, 76.60s/it, N=1319, ACC%=79.00, greedy=1]"
          }
        },
        "9383fc235dee4ee19bd47c2e0df567a6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "95b79511c0ee42e7b2ca30b63ea484b0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "973d042d851042da996b747d3da804b1": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "990fa215bf844e3b8ea8bc9bda7e1a3c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9f1cadc5d79649b19bc6a6941c3245e2",
            "placeholder": "​",
            "style": "IPY_MODEL_b0161588dac440cda9382e2482b32219",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "9ac05b0723ed4e13bfe6016993cf798a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_990fa215bf844e3b8ea8bc9bda7e1a3c",
              "IPY_MODEL_b9dfc499a7904ad0bcb6626f93487019",
              "IPY_MODEL_6c7383eab16b4c2385ab1c300b8423ee"
            ],
            "layout": "IPY_MODEL_3bbaeb287b5f4086b7c3e1474a34972c"
          }
        },
        "9b871c411c2b4809bb1d4ba9b8b650ab": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9c48059cb2054bbc875de066654e191a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b5893de63b934a4186596a42fcea8030",
              "IPY_MODEL_004da1ed97a44382bde90d6c5c244a2f",
              "IPY_MODEL_6ee861348c174dd4af158c10e200ddaa"
            ],
            "layout": "IPY_MODEL_871a90e79fc04112bbb72d54d8ae8250"
          }
        },
        "9cc921e5ec544e07832501b57d9f2f42": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9d6bb83db6ce499eaa2ca07bb07eca6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_67849f803a08422fa357bd27a11bf964",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_75daa99c284d4012bf0df1f87999fbfb",
            "value": 21
          }
        },
        "9d723354a4344bdf8fd331a1446a2b8f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "9dc8c60675634ebda305d807cc36e97e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9f1cadc5d79649b19bc6a6941c3245e2": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9f4d36fc4f5c43d69b7d1c2cabbdced3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f6bb7b85de724c7499493fa088886a28",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3e446973a3dc4785ae37c485753cd964",
            "value": 21
          }
        },
        "a11f4f065a784736aa0cc9f5fdcf2afd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "a3449a2e296043d0abeba3982d24e839": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "a4549b715203429f85d3e152e1f6dfde": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a6bad0fbc40d41199de894615150b714": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a7531cf6b513470da02333b1420e5c9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a88a028b745843a7aa4e495d70cfd270": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a9c3bba89d8b436e85dcf2c81bb35a84": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_78a5366783a940afba4940c56349b7fb",
            "placeholder": "​",
            "style": "IPY_MODEL_22a1c8806adf4710a02d5694485f42c1",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "aa45e9b80f3a46c3bc4e549322614387": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2ce513bcfcd7429d9fe6872d177bf38e",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1596780773c64a6097256a02988375b7",
            "value": 21
          }
        },
        "aa4ba36a2be8423bb9c23cd1834b3c15": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9b871c411c2b4809bb1d4ba9b8b650ab",
            "placeholder": "​",
            "style": "IPY_MODEL_bfc80c383c86441b9aa68e7246922580",
            "value": "Evaluating qwen3b greedy trained:   0%"
          }
        },
        "aafbaa511943448cad086cccccf763d3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ac1e15245d2d4f3f8662a9dae4e1da2c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5feae3252dcc45e58c65c5a114980ed6",
            "placeholder": "​",
            "style": "IPY_MODEL_018572d4a0684e3e8a338a0e9ef942de",
            "value": "Evaluating trained_grpo:   0%"
          }
        },
        "acf8d4bc4e4b4cf285ea39855fd6c75f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "af7f958bdf324b8b874212b929f1b30e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "afb5e52a41534152a3db972a9bc94c54": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7f312f978adb474db9ce49c2f3f6153b",
              "IPY_MODEL_447f11db2a344b6abcf549f941628737",
              "IPY_MODEL_d340cb25e52b43caa6c98b8733cc4852"
            ],
            "layout": "IPY_MODEL_6753f3afa9284919825e60c1f71fd853"
          }
        },
        "b0161588dac440cda9382e2482b32219": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b0ac144780df465ab6db88d2d57e5174": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b3c7ade6ac214c12b7be4427a4119c72": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b40d1de4d655496b9da43aafbe598e78": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e6643458ec20491da4545e733e7b5206",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c61a70f21b7842e394e0245c3339bb20",
            "value": 21
          }
        },
        "b43925fcc9d94f9ab8d2082e7d9dc4f3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "b4ea0408ecdd4d8ca466755833c4c5c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d801cfe586a84f2894c185c9f39aba38",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_95b79511c0ee42e7b2ca30b63ea484b0",
            "value": 21
          }
        },
        "b4eb097c85964ca9bfc73c6bf91cbaf9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3044700dd4774e2aba7aed8f70a39888",
            "placeholder": "​",
            "style": "IPY_MODEL_1a4508b1ac9c47919bb1565d636fa3a8",
            "value": " 21/21 [29:18&lt;00:00, 79.74s/it, N=1319, ACC%=77.79, greedy=1]"
          }
        },
        "b5893de63b934a4186596a42fcea8030": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5010496fa3ec42b3b4f3e2d088fb38d3",
            "placeholder": "​",
            "style": "IPY_MODEL_f2dd806f220443e38f2217ed88454632",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "b6c0617396d44837ac2bc00893b812d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_39b9ffa02eb84f85a76156cbc1c8a88b",
              "IPY_MODEL_7f8b51a3fe4343f39fa6ddc435410a3a",
              "IPY_MODEL_3a3884b6b2c045c2ae33bba4c0c462b4"
            ],
            "layout": "IPY_MODEL_51096091a9594b3cbba9e1f06c439f65"
          }
        },
        "b8afce26ae8b4305bff9ed6688f3ddfe": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_18717620d990478b9e58ce52fe77dcdc",
            "placeholder": "​",
            "style": "IPY_MODEL_84531fd148ad4bcfac95d43da9913159",
            "value": " 21/21 [29:43&lt;00:00, 84.98s/it, N=1319, ACC%=79.91, greedy=1]"
          }
        },
        "b8b781185a874943adab7f09ae19411c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0f45cb9c7061417bafe13fe40f54b663",
            "placeholder": "​",
            "style": "IPY_MODEL_815fce21c79d43aaa335dd620804ec16",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "b96e91dc71a54ec7b641598f13349382": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ce5b87c5dce045a88a4bf7885b2fb750",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_a11f4f065a784736aa0cc9f5fdcf2afd",
            "value": 21
          }
        },
        "b9dfc499a7904ad0bcb6626f93487019": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_68f6dd3d216d4a3d8ed8f0a915fccc27",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_157c4560897f4de3959667ae8483508e",
            "value": 21
          }
        },
        "bc684009a1ac42d3bc55815b41f56d88": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bd8a0dcb544e4e96b75c06fce618e747": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bdc935dbdb1f4036b46e58fa2d80ddbf": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "bf8a24d686b24da584fd5443076f81af": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d885506b579b4e3cbf49f6f94e798929",
              "IPY_MODEL_b40d1de4d655496b9da43aafbe598e78",
              "IPY_MODEL_c70061a9ed754c2da95b745fb6506983"
            ],
            "layout": "IPY_MODEL_b43925fcc9d94f9ab8d2082e7d9dc4f3"
          }
        },
        "bfc80c383c86441b9aa68e7246922580": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c1fba98d6fc148398dcdf8379e934d18": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c2bd8dccff074a0ca8a9844319aba0ce": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b8b781185a874943adab7f09ae19411c",
              "IPY_MODEL_df6bebf829a54f86a8ffdb1c1358ca82",
              "IPY_MODEL_7b9a9cfb02a74fb087e2eb3e128db7eb"
            ],
            "layout": "IPY_MODEL_6d1c8f83e798474da7ff43fa69045b6b"
          }
        },
        "c2ce751a1ad44d4d9e26cc7b5dd5ddfc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_26932873ef3e4e439700e62540a54a98",
            "placeholder": "​",
            "style": "IPY_MODEL_9cc921e5ec544e07832501b57d9f2f42",
            "value": "Evaluating qwen3b greedy trained: 100%"
          }
        },
        "c61a70f21b7842e394e0245c3339bb20": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c67785ae392f4507acf38816eab3e5a9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c70061a9ed754c2da95b745fb6506983": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_dc8a5f641d2d4c358838598fd89cfc9b",
            "placeholder": "​",
            "style": "IPY_MODEL_18ec27986e1d497695c66dd1cd4142a7",
            "value": " 21/21 [28:58&lt;00:00, 74.91s/it, N=1319, ACC%=79.91, greedy=1]"
          }
        },
        "ca928a7ea4c842c1a9213fa94729fec0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4b6686f1b535452fa6b2aa22a4e6b701",
              "IPY_MODEL_9d6bb83db6ce499eaa2ca07bb07eca6c",
              "IPY_MODEL_7c7af70ab0b048428942d2042b6b619e"
            ],
            "layout": "IPY_MODEL_9d723354a4344bdf8fd331a1446a2b8f"
          }
        },
        "ce5b87c5dce045a88a4bf7885b2fb750": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ceb55c622676492d81007c4de64ddb20": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d1e977c54c0f48bf9bc787436da92f52": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d2127f38f39d426e9a4728ecae2d2738": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d25f806fab61436f8f5d3d5a16f5e0f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f06449b580ac4e849538a00e2df37a81",
            "placeholder": "​",
            "style": "IPY_MODEL_01806afad25045508bd4eb9cc25fd085",
            "value": " 21/21 [31:20&lt;00:00, 81.14s/it, N=1319, ACC%=77.71, greedy=1]"
          }
        },
        "d2bdcf3ef69b4d078d0f7b55bbe14422": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d2e36fee086b42a59bac40117f0ba4bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8214948990d845e286f1dad4739a0b5b",
            "placeholder": "​",
            "style": "IPY_MODEL_4af918c43f6543a5bbfd43d2fa231160",
            "value": " 0/21 [00:00&lt;?, ?it/s]"
          }
        },
        "d340cb25e52b43caa6c98b8733cc4852": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fc5ba29bb98f466e999b28549b750c14",
            "placeholder": "​",
            "style": "IPY_MODEL_7d41fb48052e4afbb6ba066568edc4e0",
            "value": " 21/21 [30:08&lt;00:00, 84.39s/it, N=1319, ACC%=79.15, greedy=1]"
          }
        },
        "d3e0ef7490d84c4b9c9f9b8aa2d0bacc": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "d567c31ca55149979e15a11e2d996581": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_32a472a8ac6c4c598f2b6426144a960a",
              "IPY_MODEL_13962aacdf40481e82271020cfee3ea4",
              "IPY_MODEL_390bc2be57a343198a673543d41539d0"
            ],
            "layout": "IPY_MODEL_fb02ac5aa1cd4b8b8ee68a4592911456"
          }
        },
        "d801cfe586a84f2894c185c9f39aba38": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d867cd98dbf14053961a64c6efeb26f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_dfe054ec515940549cfcbd9b7581b67d",
              "IPY_MODEL_7028269d58964755babbc4e93c1d482b",
              "IPY_MODEL_5bfa9a9eb18c4c8abcce1f1c9d2d0208"
            ],
            "layout": "IPY_MODEL_a3449a2e296043d0abeba3982d24e839"
          }
        },
        "d885506b579b4e3cbf49f6f94e798929": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_239a3be7a04444cd8ff61cc41ce6063e",
            "placeholder": "​",
            "style": "IPY_MODEL_4512f7a4b9134f4ab867da0c37642100",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "da8588237dec4ae9877dc1ebb1236e48": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "daa3a060f2cc43aab105c38476c26b57": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "dc8a5f641d2d4c358838598fd89cfc9b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dd6c0cd3aec4480fb0f1f3e742575c1b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ddabfacb1370443b9e38726aad49bbb3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "df6bebf829a54f86a8ffdb1c1358ca82": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e98817ec7f25489698654cbf0c5e050a",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_dd6c0cd3aec4480fb0f1f3e742575c1b",
            "value": 21
          }
        },
        "dfe054ec515940549cfcbd9b7581b67d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0559847792ee4f358dff48a64861ceff",
            "placeholder": "​",
            "style": "IPY_MODEL_daa3a060f2cc43aab105c38476c26b57",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "e3712025a45545dcb040599a739cedbc": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e3989b46ac8340e588fdaaaabf9b87e4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e39ab8735bc14a11a8becb6d8b199a03": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d2127f38f39d426e9a4728ecae2d2738",
            "placeholder": "​",
            "style": "IPY_MODEL_7583761f254c4a0fa19ef9c42de93cbd",
            "value": " 21/21 [30:25&lt;00:00, 80.55s/it, N=1319, ACC%=79.68, greedy=1]"
          }
        },
        "e4a308877b7d4fb283f15b6613814519": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0224a2b6fc59439e9771f36f7f30176a",
            "placeholder": "​",
            "style": "IPY_MODEL_19bb218724e741248d79f0d5cbeee526",
            "value": " 0/21 [00:00&lt;?, ?it/s]"
          }
        },
        "e6643458ec20491da4545e733e7b5206": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e93f4d72fdab45ce9b21e42e430702b6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_973d042d851042da996b747d3da804b1",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_550be0f6d12e4c60886d0d5035ba6b3d",
            "value": 21
          }
        },
        "e98817ec7f25489698654cbf0c5e050a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea74d8f5c7bb42d794872a048875ce93": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea78979948144c6e800e16d0ac934bf5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "eab965c99d6d445db3a182b128a0fce5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "f00948e6be484b6b84e2496c32227eef": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f06449b580ac4e849538a00e2df37a81": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f2bae048b4034efea7db8b2f6fb4f526": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f2dd806f220443e38f2217ed88454632": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f4c3f1cea2ab422bbe6df48e34f87950": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f64dddc0ccdc49f6b94e191c0107f510": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "f6bb7b85de724c7499493fa088886a28": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fa2d6c5f05064aaaadb361b62c978702": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ac1e15245d2d4f3f8662a9dae4e1da2c",
              "IPY_MODEL_74969bf267b248918164a67ddd94551b",
              "IPY_MODEL_d2e36fee086b42a59bac40117f0ba4bc"
            ],
            "layout": "IPY_MODEL_4343230de097491ebd98919615ca7c1f"
          }
        },
        "facb9d35f00a493880724c3caef39e0f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4d9917046a984adf821475a945ce25f9",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_157ce59e61c241ca97f75fcf0da58a6c",
            "value": 21
          }
        },
        "fb02ac5aa1cd4b8b8ee68a4592911456": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "fb3ae36336974394913c44c99e5de061": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ea74d8f5c7bb42d794872a048875ce93",
            "placeholder": "​",
            "style": "IPY_MODEL_2bbbedc99ea347e59b70a59a30186297",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "fbe0e0a348ad49c5be519e189b2e7b5c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "fc4d8b814f454a7c8515dfc3f4f957d0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "fc5ba29bb98f466e999b28549b750c14": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fffac9209fa24b4ea9728ec45b900bbe": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_53847574e82447c1ad184ce452e58b5b",
            "placeholder": "​",
            "style": "IPY_MODEL_bd8a0dcb544e4e96b75c06fce618e747",
            "value": "Evaluating trained_grpo: 100%"
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}