{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "fU6KXDe-mNjD",
      "metadata": {
        "id": "fU6KXDe-mNjD"
      },
      "source": [
        "# Qwen: Direct Projection Training with Paraphrase Coherence (TriviaQA)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cva0_WZcmNjF",
      "metadata": {
        "id": "cva0_WZcmNjF"
      },
      "source": [
        "This Colab trains a **causal LM** (chosen via `TRAINABLE_MODEL_NAME`, e.g. Qwen, Llama, or Phi-3) with a **direct projection** objective:\n",
        "\n",
        "- **KL to frozen baseline** keeps the model near $\\omega_0$.\n",
        "- **Pairwise JS** improves **coherence across paraphrases** of the same question.\n",
        "- Uses **REINFORCE** estimators with baselines, multi-sample averaging, and log-sum-exp numerics."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Q3RRjK20mNjF",
      "metadata": {
        "id": "Q3RRjK20mNjF"
      },
      "source": [
        "## 0) Runtime & GPU check"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "AIguI7ArakvF",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AIguI7ArakvF",
        "outputId": "0261e064-5287-4203-dee3-993aef4d6872"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]\n",
            "OS: Linux-6.6.105+-x86_64-with-glibc2.35\n",
            "Arch: x86_64\n",
            "Torch: 2.9.0+cu126\n",
            "CUDA (toolkit from Torch): 12.6\n",
            "CUDA available: True\n",
            "GPU: NVIDIA A100-SXM4-80GB\n",
            "SM capability: (8, 0)\n",
            "VRAM (GB): 79.32\n"
          ]
        }
      ],
      "source": [
        "import platform, torch, subprocess, sys\n",
        "import os, torch\n",
        "\n",
        "# Report the interpreter, platform, and PyTorch/CUDA setup of this runtime.\n",
        "cuda_ok = torch.cuda.is_available()\n",
        "for label, value in (\n",
        "    (\"Python:\", sys.version),\n",
        "    (\"OS:\", platform.platform()),\n",
        "    (\"Arch:\", platform.machine()),\n",
        "    (\"Torch:\", torch.__version__),\n",
        "    (\"CUDA (toolkit from Torch):\", torch.version.cuda),\n",
        "    (\"CUDA available:\", cuda_ok),\n",
        "):\n",
        "    print(label, value)\n",
        "\n",
        "# GPU details can only be queried when a CUDA device is present.\n",
        "if cuda_ok:\n",
        "    gpu_props = torch.cuda.get_device_properties(0)\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"SM capability:\", torch.cuda.get_device_capability(0))\n",
        "    print(\"VRAM (GB):\", round(gpu_props.total_memory / 2**30, 2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "zEE6ahLV5LZX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zEE6ahLV5LZX",
        "outputId": "10478c16-705e-4c0d-fab3-c8c527aad62b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading: https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installing: flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installed flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "# 1) Force flash-attn to use a prebuilt wheel or fail fast\n",
        "os.environ[\"FLASH_ATTENTION_SKIP_CUDA_BUILD\"] = \"TRUE\"\n",
        "\n",
        "# 2) Download the wheel for:\n",
        "#    - flash_attn 2.8.3\n",
        "#    - CUDA 12.6 (cu126)\n",
        "#    - torch 2.9\n",
        "#    - Python 3.12 (cp312)\n",
        "# NOTE: the wheel tag is hard-coded, so it must match the Python / torch / CUDA\n",
        "# versions of the running kernel; update the URL if the runtime changes.\n",
        "wheel_url = \"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\"\n",
        "wheel_name = wheel_url.split(\"/\")[-1]\n",
        "\n",
        "print(\"Downloading:\", wheel_url)\n",
        "subprocess.run([\"wget\", \"-q\", wheel_url, \"-O\", wheel_name], check=True)\n",
        "\n",
        "print(\"Installing:\", wheel_name)\n",
        "# Run pip through the kernel's own interpreter so the wheel lands in the\n",
        "# environment this notebook imports from, not whichever `pip` is first on PATH.\n",
        "subprocess.run(\n",
        "    [sys.executable, \"-m\", \"pip\", \"install\", \"--no-dependencies\", \"--upgrade\", wheel_name],\n",
        "    check=True,\n",
        ")\n",
        "\n",
        "print(\"Installed\", wheel_name)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rgVu_ySz5ReL",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rgVu_ySz5ReL",
        "outputId": "49845b60-f33b-45bb-f60f-fdfbd786117a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "flash_attn version: 2.8.3\n"
          ]
        }
      ],
      "source": [
        "import flash_attn\n",
        "\n",
        "# Show the installed flash-attn version; \"unknown\" is printed when the module has no __version__.\n",
        "flash_attn_version = getattr(flash_attn, \"__version__\", \"unknown\")\n",
        "print(\"flash_attn version:\", flash_attn_version)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qQ3UlHbC5dLW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qQ3UlHbC5dLW",
        "outputId": "1f7ce443-09e1-4950-9c1f-d520c65a9bb8"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Torch: 2.9.0+cu126\n",
            "CUDA available: True\n",
            "Device: NVIDIA A100-SXM4-80GB\n",
            "Compute capability: (8, 0)\n",
            "FlashAttention forward pass OK.\n",
            "Output shape: torch.Size([2, 128, 8, 64])\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from flash_attn import flash_attn_func\n",
        "\n",
        "device = \"cuda\"\n",
        "\n",
        "# Environment summary printed before the smoke test.\n",
        "print(\"Torch:\", torch.__version__)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "print(\"Device:\", torch.cuda.get_device_name(0))\n",
        "print(\"Compute capability:\", torch.cuda.get_device_capability(0))\n",
        "\n",
        "# Small test config: batch size, sequence length, number of heads,\n",
        "# head dimension (must be <= 256 for FA2)\n",
        "B, L, H, D = 2, 128, 8, 64\n",
        "\n",
        "# Q, K, V shapes: (batch, seqlen, nheads, headdim); tensors are drawn in q, k, v order.\n",
        "q, k, v = (\n",
        "    torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "    for _ in range(3)\n",
        ")\n",
        "\n",
        "# Run one non-causal forward pass; any failure is reported instead of raised.\n",
        "try:\n",
        "    out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)\n",
        "    print(\"FlashAttention forward pass OK.\")\n",
        "    print(\"Output shape:\", out.shape)\n",
        "except Exception as e:\n",
        "    print(\"FlashAttention FAILED on GPU:\")\n",
        "    print(type(e).__name__, \":\", e)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "uUeRWvsdQv1C",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uUeRWvsdQv1C",
        "outputId": "1b6741de-f509-4131-ab56-795cbae5e265"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "\n",
        "# Mount Google Drive at the location the /content/drive/... paths in this notebook expect.\n",
        "DRIVE_MOUNT_POINT = \"/content/drive\"\n",
        "drive.mount(DRIVE_MOUNT_POINT)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "u7T3ljRumNjG",
      "metadata": {
        "id": "u7T3ljRumNjG"
      },
      "source": [
        "## 1) Install packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2snW-PemNjG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "d2snW-PemNjG",
        "outputId": "c18a511c-0bfa-4eec-881b-56635863f008"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.8 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m73.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n"
          ]
        }
      ],
      "source": [
        "# `%pip` installs into the environment of the running kernel; `!pip` may hit a different interpreter.\n",
        "%pip install -q --upgrade pip\n",
        "%pip install -q -U \"transformers>=4.43.0\" \"accelerate>=0.33.0\" \"datasets>=2.19.0\" \"bitsandbytes>=0.43.0\" \"peft>=0.11.0\" sentence-transformers einops\n",
        "# Optional: flash-attn (if wheel exists for your GPU)\n",
        "# %pip install -q flash-attn --no-build-isolation\n",
        "%pip install tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bBwLq3tAZBfd",
      "metadata": {
        "id": "bBwLq3tAZBfd"
      },
      "outputs": [],
      "source": [
        "# `%pip` installs into the environment of the running kernel.\n",
        "%pip install -q -U \"trl>=0.8.0\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "IviYIdL0mNjG",
      "metadata": {
        "id": "IviYIdL0mNjG"
      },
      "source": [
        "## 2) Imports & global config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QGO8iNsFmNjG",
      "metadata": {
        "id": "QGO8iNsFmNjG"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import os  # needed by os.makedirs below; previously only available via earlier cells\n",
        "import random\n",
        "from dataclasses import dataclass\n",
        "from typing import List, Tuple\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "from datasets import load_dataset\n",
        "from sentence_transformers import SentenceTransformer, util as sbert_util\n",
        "\n",
        "# Seed Python, NumPy, and PyTorch (CPU and, when present, CUDA) RNGs.\n",
        "SEED = 42\n",
        "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "# Choose model (smaller first for Colab T4/L4)\n",
        "# Alternatives previously used here: \"microsoft/Phi-3-mini-4k-instruct\", \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "TRAINABLE_MODEL_NAME = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
        "BASELINE_MODEL_NAME  = TRAINABLE_MODEL_NAME\n",
        "\n",
        "# 4-bit NF4 quantization settings (double quantization; bf16 compute on GPU, fp32 on CPU).\n",
        "USE_4BIT = True\n",
        "BNB_CONFIG = BitsAndBytesConfig(\n",
        "    load_in_4bit=USE_4BIT,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
        ")\n",
        "\n",
        "# \"cpu\" to save VRAM; set to \"auto\" if you have >=24GB\n",
        "BASELINE_DEVICE_MAP = \"auto\"\n",
        "\n",
        "# Generation settings\n",
        "MAX_NEW_TOKENS = 512\n",
        "GEN_TEMPERATURE = 0.7\n",
        "TOP_P = 0.9\n",
        "TOP_K = 50\n",
        "\n",
        "# Objective settings\n",
        "PARAPHRASES_PER_QUESTION = 2     # total per class incl. original\n",
        "NUM_CLASSES_FOR_TRAIN = 10000\n",
        "SAMPLES_PER_PAIR = 16\n",
        "JS_WEIGHT = 1.0\n",
        "KL_WEIGHT = 0.02\n",
        "\n",
        "# Optimization settings\n",
        "LR = 2e-5\n",
        "WEIGHT_DECAY = 0.01\n",
        "GRAD_CLIP = 1.0\n",
        "STEPS = 50\n",
        "BATCH_SIZE_CLASSES = 32\n",
        "LOG_EVERY = 20\n",
        "\n",
        "# Google Drive locations (WORK_PATH is created below if it does not exist).\n",
        "# NOTE(review): the directory name says \"weight-5e-2\" while KL_WEIGHT is 0.02 and\n",
        "# JS_WEIGHT is 1.0 -- confirm which weight the name refers to.\n",
        "WORK_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-mode-jan17/weight-5e-2\"\n",
        "# Checkpoint path, or False when unset (presumably \"do not resume\" -- confirm against the loading cell).\n",
        "# Paths used previously:\n",
        "#   f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-majority-dec29/weight-1e-2/checkpoint-1600\"\n",
        "#   f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-majority-dec29/weight-1/checkpoint-1600\"\n",
        "CKPT_LOAD_PATH = False\n",
        "COHERENCE_DATA_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct/\"\n",
        "os.makedirs(WORK_PATH, exist_ok=True)\n",
        "# NOTE(review): this cache path is fixed to the Qwen2.5-3B-Instruct folder regardless of\n",
        "# TRAINABLE_MODEL_NAME -- confirm that sharing it across models is intended.\n",
        "CACHE_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZMmHw__UNCjX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZMmHw__UNCjX",
        "outputId": "c13c8556-cb8b-4b4c-bbd2-03e7f3d57e77"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ls: cannot access 'False': No such file or directory\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "\n",
        "# List the checkpoint directory only when a path is configured and exists.\n",
        "# CKPT_LOAD_PATH may be False, which a shell `ls` would treat as a literal file name.\n",
        "if CKPT_LOAD_PATH and os.path.isdir(CKPT_LOAD_PATH):\n",
        "    print(\"\\n\".join(sorted(os.listdir(CKPT_LOAD_PATH))))\n",
        "else:\n",
        "    print(f\"No checkpoint directory to list (CKPT_LOAD_PATH={CKPT_LOAD_PATH!r}).\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c7DuQugEeSaD",
      "metadata": {
        "id": "c7DuQugEeSaD"
      },
      "source": [
        "## Utils"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Ul00h3tomNjH",
      "metadata": {
        "id": "Ul00h3tomNjH"
      },
      "source": [
        "## 3) Load tokenizer, trainable QLoRA model, and frozen baseline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "T5xYHUDKds80",
      "metadata": {
        "id": "T5xYHUDKds80"
      },
      "outputs": [],
      "source": [
        "# @title\n",
        "# === utils_loaders.py ===\n",
        "import os, gc, torch, platform\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "\n",
        "# ---------- Helpers ----------\n",
        "\n",
        "def gpu_supports_bf16() -> bool:\n",
        "    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
        "\n",
        "def pick_dtype():\n",
        "    return torch.bfloat16 if gpu_supports_bf16() else (torch.float16 if torch.cuda.is_available() else torch.float32)\n",
        "\n",
        "def pick_attn_impl() -> str:\n",
        "    \"\"\"\n",
        "    FlashAttention-2 only on Ampere+ (SM >= 80). Otherwise use SDPA.\n",
        "    \"\"\"\n",
        "    if not torch.cuda.is_available():\n",
        "        return \"sdpa\"\n",
        "    sm_major, _ = torch.cuda.get_device_capability(0)\n",
        "    print (\"using flash_attention_2\") if sm_major >= 8 else print (\"using sdpa\")\n",
        "    return \"flash_attention_2\" if sm_major >= 8 else \"sdpa\"\n",
        "\n",
        "def hard_cleanup(*names_to_del, also_empty_cache=True):\n",
        "    \"\"\"\n",
        "    Delete large globals by name (if present), run gc, and empty CUDA cache.\n",
        "    \"\"\"\n",
        "    g = globals()\n",
        "    for n in names_to_del:\n",
        "        if n in g:\n",
        "            try: del g[n]\n",
        "            except: pass\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available() and also_empty_cache:\n",
        "        torch.cuda.empty_cache()\n",
        "        torch.cuda.synchronize()\n",
        "\n",
        "def make_bnb_config(load_in_4bit=True):\n",
        "    # Lazy import to avoid dependency unless used\n",
        "    from transformers import BitsAndBytesConfig\n",
        "    return BitsAndBytesConfig(\n",
        "        load_in_4bit=load_in_4bit,\n",
        "        load_in_8bit=not load_in_4bit and True or False,   # if you ever want 8-bit, flip flags accordingly\n",
        "        bnb_4bit_compute_dtype=pick_dtype(),\n",
        "        bnb_4bit_use_double_quant=True,\n",
        "        bnb_4bit_quant_type=\"nf4\",\n",
        "    )\n",
        "\n",
        "def setup_tokenizer(model_id: str, use_fast=True):\n",
        "    tok = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, trust_remote_code=True)\n",
        "    # padding & truncation for decoder-only batched gen\n",
        "    if tok.pad_token is None:\n",
        "        # Prefer a dedicated pad token if it exists, else fall back to eos\n",
        "        tok.pad_token = tok.eos_token\n",
        "    tok.padding_side = \"left\"\n",
        "    tok.truncation_side = \"left\"\n",
        "    return tok\n",
        "\n",
        "# ---------- Trainable (QLoRA) ----------\n",
        "\n",
        "def load_trainable_model(\n",
        "    model_id: str,\n",
        "    use_4bit: bool = True,\n",
        "    lora_r: int = 8,\n",
        "    lora_alpha: int = 16,\n",
        "    lora_dropout: float = 0.05,\n",
        "    extra_target_modules: list[str] | None = None,\n",
        "    device_map: str | dict = \"auto\",\n",
        "):\n",
        "    \"\"\"\n",
        "    Load the base model (optionally 4-bit quantized), prepare it for k-bit\n",
        "    training, and wrap it with LoRA adapters.\n",
        "\n",
        "    Args:\n",
        "        model_id: model identifier passed to the tokenizer and model loaders.\n",
        "        use_4bit: build a quantization config via make_bnb_config when True;\n",
        "            otherwise no quantization config is passed.\n",
        "        lora_r, lora_alpha, lora_dropout: forwarded to LoraConfig.\n",
        "        extra_target_modules: extra LoRA target names, merged with the\n",
        "            projection layers detected in the model.\n",
        "        device_map: forwarded to from_pretrained.\n",
        "\n",
        "    Returns:\n",
        "        (tokenizer, peft_model)\n",
        "\n",
        "    Special case:\n",
        "        Model ids containing \"phi-3\" (case-insensitive) are loaded with\n",
        "        trust_remote_code=False to avoid the DynamicCache.seen_tokens\n",
        "        AttributeError.\n",
        "    \"\"\"\n",
        "    tokenizer = setup_tokenizer(model_id)\n",
        "\n",
        "    # A quantization config is only built when 4-bit loading is requested.\n",
        "    quantization = None\n",
        "    if use_4bit:\n",
        "        quantization = make_bnb_config(load_in_4bit=use_4bit)\n",
        "\n",
        "    # Phi-3 has outdated remote modeling code that breaks with recent transformers,\n",
        "    # so remote code is disabled for it and enabled for everything else.\n",
        "    allow_remote_code = \"phi-3\" not in model_id.lower()\n",
        "\n",
        "    base_model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_id,\n",
        "        quantization_config=quantization,\n",
        "        torch_dtype=pick_dtype(),\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=allow_remote_code,\n",
        "        attn_implementation=pick_attn_impl(),\n",
        "    )\n",
        "    base_model = prepare_model_for_kbit_training(base_model)\n",
        "\n",
        "    # LoRA targets: caller-supplied names plus every known projection layer\n",
        "    # whose (lower-cased) leaf module name appears in the model.\n",
        "    known_projections = {\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"}\n",
        "    lora_targets = set(extra_target_modules or [])\n",
        "    leaf_names = (module_name.lower().split(\".\")[-1] for module_name, _ in base_model.named_modules())\n",
        "    lora_targets.update(leaf for leaf in leaf_names if leaf in known_projections)\n",
        "    if not lora_targets:\n",
        "        # Nothing supplied and nothing detected: fall back to the attention projections.\n",
        "        lora_targets = {\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"}\n",
        "\n",
        "    adapter_config = LoraConfig(\n",
        "        r=lora_r,\n",
        "        lora_alpha=lora_alpha,\n",
        "        target_modules=sorted(lora_targets),\n",
        "        lora_dropout=lora_dropout,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "    )\n",
        "    peft_model = get_peft_model(base_model, adapter_config)\n",
        "\n",
        "    # Required for gradient checkpointing with PEFT\n",
        "    peft_model.config.use_cache = False\n",
        "    peft_model.enable_input_require_grads()\n",
        "    peft_model.gradient_checkpointing_enable()\n",
        "\n",
        "    peft_model.print_trainable_parameters()\n",
        "\n",
        "    return tokenizer, peft_model\n",
        "\n",
        "\n",
        "# ---------- Frozen Baseline (inference) ----------\n",
        "\n",
        "def load_baseline_model(\n",
        "    model_id: str,\n",
        "    force_cpu_embedder: bool = True,\n",
        "    use_4bit: bool = False,\n",
        "    device_map: str | dict = \"auto\",\n",
        "    max_memory: dict | None = None,\n",
        "    allow_offload_folder: str | None = None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Load a frozen, eval-mode model for inference.\n",
        "\n",
        "    Args:\n",
        "        model_id: model identifier passed to the tokenizer and model loaders.\n",
        "        force_cpu_embedder: currently unused by this function; kept so\n",
        "            existing callers that pass it keep working.\n",
        "        use_4bit: load with the 4-bit config from make_bnb_config when True;\n",
        "            otherwise load in the dtype chosen by pick_dtype().\n",
        "        device_map: forwarded to from_pretrained.\n",
        "        max_memory: optional per-device memory limits, forwarded when given.\n",
        "        allow_offload_folder: optional folder for offloaded weights; it is\n",
        "            created if missing and state-dict offloading is enabled.\n",
        "\n",
        "    Returns:\n",
        "        (tokenizer, baseline_model) with every parameter's requires_grad set to False.\n",
        "    \"\"\"\n",
        "    tok = setup_tokenizer(model_id)\n",
        "\n",
        "    # Same rule as load_trainable_model: Phi-3's remote modeling code is outdated,\n",
        "    # so remote code is disabled for it and enabled for every other model.\n",
        "    trust_remote = \"phi-3\" not in model_id.lower()\n",
        "\n",
        "    kwargs = dict(\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=trust_remote,\n",
        "        attn_implementation=pick_attn_impl(),\n",
        "    )\n",
        "\n",
        "    # Quantized loading and an explicit dtype are mutually exclusive here.\n",
        "    if use_4bit:\n",
        "        kwargs[\"quantization_config\"] = make_bnb_config(load_in_4bit=True)\n",
        "    else:\n",
        "        kwargs[\"torch_dtype\"] = pick_dtype()\n",
        "\n",
        "    if max_memory is not None:\n",
        "        kwargs[\"max_memory\"] = max_memory\n",
        "    if allow_offload_folder is not None:\n",
        "        os.makedirs(allow_offload_folder, exist_ok=True)\n",
        "        kwargs[\"offload_folder\"] = allow_offload_folder\n",
        "        kwargs[\"offload_state_dict\"] = True\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()\n",
        "\n",
        "    # Freeze to be explicit\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad_(False)\n",
        "\n",
        "    return tok, model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "FswzqYDPdwoH",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "referenced_widgets": [
            "5c2a8e737abb41cd861753909ecc4f1b",
            "2a7970cb9ff44c3b8f9e370223327d49",
            "feca0d3b91d241c0a81824c57a28d6df",
            "ff754a2a323741b585a73b55b63cd219",
            "9e1fae8fcf3545fcb8dfcf9fd14d5890",
            "e2f37aea6afd4ede81a240588f1d88dd",
            "4068211b5c1b48f59253923e3237ba25",
            "b118a18d02b24044853c086278044d88",
            "4410990ea8f346c39703ff02ca7996af",
            "65597fbb2ccc43d282a1f8366e94f72f",
            "d02a4ead5c7b434f913ec8e70550187d",
            "66263c8eadfd474c8a67a11b30055ddb",
            "22eed0d1eef643c4b9f5d5f02a3a5bd9",
            "d14709117ec147ea9547b72fd3bc3dbb",
            "feacc7ad63334712a7f7e08b8d45624c",
            "9a38e5b3eb314071bed4f955632605ec",
            "e879a2550d514432bccd7f8f76169a56",
            "b226dc2cdc844e37852fb2c98950812e",
            "f97c79e9765847a49b5b5c987ed7e835",
            "848ec43f0e5148bbb452f89d29870b30",
            "af63a7c7c8dc4a68a6436e9b02c7b428",
            "5e2d2cc4bcc6407ea39a4fb70d10155c",
            "2f32d607fbcf4eb39ea3962a8d9bc520",
            "92f7ecd8d8fa4f429a78e1499cc7f400",
            "ac2c082b53ac4e1fbb5d78b9ef80b81b",
            "52a3413726ce4ae79ec9bad9b972f01d",
            "2d01798b5c3d49a5b056f2ef6bfb6c25",
            "9ca9bdc58f3a42e395905139385b9e9b",
            "7134e35138e64efea8abff67100ae40e",
            "8ffee91d96b34a13bf448178a109ce79",
            "ea7672d0767f480696aa09f075d4772d",
            "164133cae7374a258b3d217b115bc379",
            "6b0436f2c9ac4e78b8b456bc4c73afae",
            "345807be67e54d2ebd9bdadebf8992b2",
            "a79a2eb11be6486a9e42993f1e8ba2d4",
            "279630ffc6a544e196f8d0fd4e5ead06",
            "b1aed9b47c5842b6afa5d91461728259",
            "9f014d810d5e4be892ed0fff1cd9183c",
            "d58e83c0ede843a6bfca39fb18753489",
            "e01eb59e7a9f4c059cd08ba8d2c8957a",
            "c3e13b43249e49b9b270a1f9aecb15d8",
            "f94d24a0eee24f8bb99afd296a6fb884",
            "248768b4966d45318f0feaf8da35411b",
            "95d87cd4cde34ba4a449a98eb27cf169",
            "253871439f644e97815714dcf14a55ae",
            "d4a34bbaf6264706a6a79953649fd31e",
            "e48f6ef7549d4cc9bc3ea6acff8ba375",
            "bca11efb3a754b39b055445f470ae89f",
            "686d32adbffc4600b4ea5e9ed755f620",
            "d2dfb8df5c27478eaa1d2c674a62faa5",
            "1d5500bb6f31454293e8d922a1ca1101",
            "d18c1c5e9c884514bb794373d78caa2e",
            "b828683ae02c44aaab389538a41e67ea",
            "3cb48f1eabd543afb0cefd44b004c348",
            "5419d0335507421a9a7a8ae28697a679",
            "6885f9a52efd4f9b9e65d43b062f3103",
            "aceb980633344130a3086cbd0b45fa9b",
            "ae4d8432117d4928bdc3f16392749120",
            "4ef3316dded5415589b49e2f2cf5988a",
            "5649585455c74d7a88664b6c6dfbe575",
            "04f3bcff5fd84c988a1c67154e3d3460",
            "53f36048a2764289a49ea26692000911",
            "637ca7d0597942a0ab719e2920334174",
            "8c19165f390e44f284e696ad9c5eecca",
            "eb5c01168ef64d85873e5726e45be03b",
            "34a4977199cb4d05ab2fbc8c3076100d",
            "d1fd251b122f4b7eb9d0887adb94d8c4",
            "7a1088b57ed4412283e917b2312f3799",
            "f10ef23bdd034cbea4c5151f16977351",
            "2405965d92fd41dc96360500b2aa72dd",
            "20e493de23184f508bfed4832b3ae345",
            "4e98e8d97cf34aaa9142e96c70bf3bfa",
            "87a4daada34f4cc493b83fa9cba200f0",
            "9a0b62601c9b4658a68d67f365fc21b5",
            "6bddf26b59eb4a948457b1cc7d5a4044",
            "ae6c402cf7c1407bad8c8d0e31572bf5",
            "bb359cf0ce3a4981a53cff4535022f69",
            "1d2ca94bc22e499cba0ef83b22541e6d",
            "e691a50ec860483f964bd98319327d1b",
            "5abe002f92ed43b38fccb3c02eeb1d3b",
            "b1b744c870a54f26ba7161beda026379",
            "4e0a7ad1a6454b57b1aaffa91bd7fe50",
            "8b9c5f0a2b1d4a539de9c7bafac96146",
            "f79777359f684bddbdf5172dd200d737",
            "6dbf455b1bb441b1bf4eede9bec60111",
            "cbb8f7741cec45f9aad53b04913313fa",
            "5c159ca87657449aa0a49d5fcd056b2c",
            "b5f4573a6b314e879b2e315f64a543dc",
            "7750c11e1e194ff0890a57e56e27e263",
            "5a8b0227548d43bcaec232d96adf734e",
            "1e1871367f7d4094af05db9d37f04607",
            "195a11df03dd4d70ba54464f8976b073",
            "dc410ad99a5a44cdb6b9762bd713afff",
            "d3e708079ef64dc6b7b3f466f834b7be",
            "786cdc0611a34b8c854f68b8b2e09e9c",
            "be406211e41b44cf83e862e30e81ab75",
            "1c7ce4adb8da4386b26a63396da0a371",
            "0aee1ef79f444de2b9f339171e632de0",
            "0039028878bb4068950bc57dbefbdf58",
            "8fa409bfc803409b8fd72b6139524acf",
            "a05dfcf3316640c59b68fab1d32f44b7",
            "288f07b9303a4673a3f47abf3a61c216",
            "5c358655cd494ee8bce41d6cbe5b0634",
            "d3b8ff18c54c43fd976997a87a90dd57",
            "6ad1432aadfb42329a77a17954efeda5",
            "2eed566299d74471a737b836dff92cf0",
            "53c53f1199644f7f9cd06ac35f96d83e",
            "8c1ced63f1bc49ada2525d473fc9a29f",
            "7829a146b0ff44f9ac880cd6af96107c",
            "b62990369a7847daa3d23cf2ee3b6f53"
          ]
        },
        "collapsed": true,
        "id": "FswzqYDPdwoH",
        "outputId": "1860c0f2-1557-4762-ccab-b79021510023"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "5c2a8e737abb41cd861753909ecc4f1b"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "66263c8eadfd474c8a67a11b30055ddb"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "2f32d607fbcf4eb39ea3962a8d9bc520"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "using flash_attention_2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "345807be67e54d2ebd9bdadebf8992b2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "`torch_dtype` is deprecated! Use `dtype` instead!\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "253871439f644e97815714dcf14a55ae"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "6885f9a52efd4f9b9e65d43b062f3103"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d1fd251b122f4b7eb9d0887adb94d8c4"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "1d2ca94bc22e499cba0ef83b22541e6d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7750c11e1e194ff0890a57e56e27e263"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8fa409bfc803409b8fd72b6139524acf"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "trainable params: 12,156,928 || all params: 3,224,906,752 || trainable%: 0.3770\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n# 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\\nimport os\\nos.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\\n\\n# 1) Clean any old models from VRAM\\n\\n# hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\\n\\n# 3) Load FROZEN baseline model for inference/paraphrasing\\nBASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\\n# If VRAM is tight, leave some headroom and allow offload\\nmax_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\\ntokenizer, baseline_model = load_baseline_model(\\n    BASELINE_MODEL_NAME,\\n    use_4bit=False,               # can set True to save VRAM\\n    device_map=BASELINE_DEVICE_MAP,\\n    max_memory=max_mem,           # optional\\n    allow_offload_folder=\"./offload\"  # optional\\n)\\n\\n# 4) Embedder (keep on CPU unless you really need GPU)\\nfrom sentence_transformers import SentenceTransformer\\nembedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\\n# embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 11
        }
      ],
      "source": [
        "# 2) Load fresh TRAINABLE (QLoRA) model\n",
        "\n",
        "\n",
        "\n",
        "# Build the trainable policy with load_trainable_model() (defined elsewhere in this notebook).\n",
        "# Presumably the helper loads a 4-bit base model and attaches LoRA adapters -- confirm against it.\n",
        "# Use exactly the same arguments you trained with, so that a saved adapter checkpoint\n",
        "# still fits when it is loaded back into this model.\n",
        "# NOTE: this rebinds the global `tokenizer` as well as `trainable_model`.\n",
        "tokenizer, trainable_model = load_trainable_model(\n",
        "    TRAINABLE_MODEL_NAME,\n",
        "    use_4bit=True,           # QLoRA default\n",
        "    lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "\n",
        "'''\n",
        "import copy\n",
        "baseline_model = copy.deepcopy(trainable_model).eval()\n",
        "for p in baseline_model.parameters():\n",
        "    p.requires_grad_(False)\n",
        "'''\n",
        "# from sentence_transformers import SentenceTransformer\n",
        "# embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# # embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\n",
        "\n",
        "\n",
        "\n",
        "'''\n",
        "# 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\n",
        "import os\n",
        "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\n",
        "\n",
        "# 1) Clean any old models from VRAM\n",
        "\n",
        "# hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\n",
        "\n",
        "# 3) Load FROZEN baseline model for inference/paraphrasing\n",
        "BASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "# If VRAM is tight, leave some headroom and allow offload\n",
        "max_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\n",
        "tokenizer, baseline_model = load_baseline_model(\n",
        "    BASELINE_MODEL_NAME,\n",
        "    use_4bit=False,               # can set True to save VRAM\n",
        "    device_map=BASELINE_DEVICE_MAP,\n",
        "    max_memory=max_mem,           # optional\n",
        "    allow_offload_folder=\"./offload\"  # optional\n",
        ")\n",
        "\n",
        "# 4) Embedder (keep on CPU unless you really need GPU)\n",
        "from sentence_transformers import SentenceTransformer\n",
        "embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tKvuvb_gBq3k",
      "metadata": {
        "id": "tKvuvb_gBq3k"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Optionally resume from a saved LoRA adapter checkpoint.\n",
        "# CKPT_LOAD_PATH (defined elsewhere) is a folder containing adapter_config.json plus the\n",
        "# adapter weights; a falsy value skips loading and leaves the fresh adapter untouched.\n",
        "if CKPT_LOAD_PATH:\n",
        "    ADAPTER_DIR = CKPT_LOAD_PATH  # folder containing adapter_config.json\n",
        "    print(\"loading from: \", ADAPTER_DIR)\n",
        "    try:\n",
        "        # Preferred: load the saved weights into the existing PeftModel's \"default\" adapter\n",
        "        trainable_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\n",
        "        trainable_model.set_adapter(\"default\")\n",
        "    except AttributeError:\n",
        "        # Fallback when the model object has no load_adapter/set_adapter: read the adapter\n",
        "        # state dict and copy it in. load_peft_weights() reads adapter_model.safetensors\n",
        "        # (or the legacy adapter_model.bin) from the folder; torch.load() cannot parse a\n",
        "        # .safetensors file.\n",
        "        from peft.utils import load_peft_weights, set_peft_model_state_dict\n",
        "        state = load_peft_weights(ADAPTER_DIR, device=\"cpu\")\n",
        "        set_peft_model_state_dict(trainable_model, state)\n",
        "        # make sure only LoRA params are trainable if you plan to keep training\n",
        "        for n, p in trainable_model.named_parameters():\n",
        "            p.requires_grad = (\"lora_\" in n)\n",
        "\n",
        "    trainable_model.train()  # training mode, because we are resuming training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "glPKBu0Uo3sX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "glPKBu0Uo3sX",
        "outputId": "e13c7505-4e97-477d-9ce3-0bfc1c94143e"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n_, baseline_model = load_trainable_model(\\n    TRAINABLE_MODEL_NAME,\\n    use_4bit=True,           # QLoRA default\\n    lora_r=8, lora_alpha=16, lora_dropout=0.05,\\n    device_map=\"auto\",\\n)\\n\\nfor p in baseline_model.parameters():\\n     p.requires_grad_(False)\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 13
        }
      ],
      "source": [
        "'''\n",
        "_, baseline_model = load_trainable_model(\n",
        "    TRAINABLE_MODEL_NAME,\n",
        "    use_4bit=True,           # QLoRA default\n",
        "    lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "for p in baseline_model.parameters():\n",
        "     p.requires_grad_(False)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B6OVE7wD_awp",
      "metadata": {
        "id": "B6OVE7wD_awp"
      },
      "outputs": [],
      "source": [
        "# BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "\n",
        "\n",
        "# _, baseline_model = load_trainable_model(\n",
        "#     TRAINABLE_MODEL_NAME,\n",
        "#     use_4bit=True,           # QLoRA default\n",
        "#     lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "#     device_map=\"auto\",\n",
        "# )\n",
        "\n",
        "\n",
        "# baseline_model = AutoModelForCausalLM.from_pretrained(\n",
        "#         BASELINE_PATH, device_map=\"auto\", quantization_config=bnb_config\n",
        "#     )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "UNCMXFLbCJ6F",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UNCMXFLbCJ6F",
        "outputId": "a201a4dd-3a24-4571-c794-d470a9f5533a"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\nBASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\\nif BASELINE_PATH:\\n    # Now load the baseline adapter weights into this PEFT model\\n    ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\\n    try:\\n        # Preferred: load adapter into the existing PeftModel\\n        baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\\n        baseline_model.set_adapter(\"default\")\\n    except AttributeError:\\n        # Fallback for older PEFT versions: load the adapter state dict directly\\n        from peft.utils import set_peft_model_state_dict\\n        import torch, os\\n        state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\\n                          map_location=\"cpu\")  # if safetensors, use safetensors.load_file\\n        set_peft_model_state_dict(baseline_model, state)\\n        # make sure only LoRA params are trainable if you plan to keep training\\n        for n, p in baseline_model.named_parameters():\\n            p.requires_grad = (\"lora_\" in n)\\n\\n    baseline_model.eval()\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 15
        }
      ],
      "source": [
        "'''\n",
        "BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "if BASELINE_PATH:\n",
        "    # Now load the baseline adapter weights into this PEFT model\n",
        "    ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\n",
        "    try:\n",
        "        # Preferred: load adapter into the existing PeftModel\n",
        "        baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\n",
        "        baseline_model.set_adapter(\"default\")\n",
        "    except AttributeError:\n",
        "        # Fallback for older PEFT versions: load the adapter state dict directly\n",
        "        from peft.utils import set_peft_model_state_dict\n",
        "        import torch, os\n",
        "        state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\n",
        "                          map_location=\"cpu\")  # if safetensors, use safetensors.load_file\n",
        "        set_peft_model_state_dict(baseline_model, state)\n",
        "        # make sure only LoRA params are trainable if you plan to keep training\n",
        "        for n, p in baseline_model.named_parameters():\n",
        "            p.requires_grad = (\"lora_\" in n)\n",
        "\n",
        "    baseline_model.eval()\n",
        "'''"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ChcYKgQMmNjI",
      "metadata": {
        "id": "ChcYKgQMmNjI"
      },
      "source": [
        "## 4) Load GSM8K and build paraphrase classes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "CKweB9VZLnN2",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "referenced_widgets": [
            "a4518227f933454e83361482e7f5ecc1",
            "0cd442287f6f4520b5b5200c2f9758e3",
            "fe0c6587aee94e05a7bec607117b40b3",
            "859155a87335429c90010ff0d313c6bc",
            "858ab81b0e8742418817e79af380d4b3",
            "875250b80ba64b72a98c07e39a39a474",
            "5d81c35f053a43e8bf9c2fb40ab139eb",
            "9c51453c5d604047aec2aa5402aadf95",
            "5face42daebf44cbb6492e2657f034f0",
            "6eca88023c744926a3be2f7a28d77eca",
            "c5890ad41ae041c68ee0d150217e4245",
            "6ca09a109c37482fad60132dacf7f627",
            "1f53a10fd5244fa1b7ee5cf085b1c0db",
            "cb419fb7a1f84e598581226a5558c9a5",
            "c96bd872ae1b4b0e93a94a975e3add57",
            "76170c900810451381851319bd759780",
            "a717000b640b4351aede226cc9e6bcbc",
            "500b78f96bc74f1f95a22f526aad88e5",
            "a5ee3532f896434db2e7d7c7e00035fb",
            "705462646414442bbb77892f1c3aba30",
            "fd9ed38e06914c92a86023d7c7660304",
            "e89fa9ee7c6f4e94b8732c79e32c32b3",
            "9a21e4b6dee74fb4830a0c1103b3656b",
            "80644f0914fa4f10b831463170a2a3d8",
            "d1bcd0983a7f44298171ac6379bec20c",
            "094fc444f90949d193cfb80ff5810374",
            "54f317390144492a9b5991e5e6bec878",
            "0084018313a44300bcadddb278a45cab",
            "d0fc20b0afdf4abf919252ffb947c7a6",
            "a7d8dc630cac47909571bc3518f0af91",
            "b1c05bc0cbec4aecb1c47daff4780786",
            "0fa15707dd78402e94311d212cd4055c",
            "4eb99c2017ac46a88e68d0188997fdd6",
            "bf3c2eb9670640db8b482673a5828d2d",
            "bc7495bab4c449abb4491ffdfa0eb8ab",
            "6d855ebb31314cc88a702c53ba3e9062",
            "42c7e5f00cff4f39ac1c93aece0fd68c",
            "029c3e9269ce4a218a0f40867c4d6a03",
            "1b8551e2775447cbb01e95dbc687a8cc",
            "9ed581154c1244b2a2a26a974ca49631",
            "28c5646e07eb4bc99ae1c7f5626700fd",
            "02786c3cab8841f098e39e3ad6b8cc2b",
            "8276467f67c745588905fcb23411ac0f",
            "11bc6eb594db45109e2b4b31b901178f",
            "cf819f8b33384f9da580a5ba1b7984ae",
            "640fd24e01e54a3382480567db216dd8",
            "646e9ed927ad44ca8d012482319f33b5",
            "addca7f2ee384ceca6f112a599a645fe",
            "9f15c60d32f14a62b18056a1ddab3f43",
            "204bdd9a7cbb44eeb96023a43c950203",
            "87dcd3131094476aa0820406d2dcb45c",
            "e26f0d7f97074887a0fd2eb73adb0cfc",
            "2a9feac629564491905da452e00a627f",
            "24571193f2e74ff0af863b54ef390554",
            "c3f69e194e9e4383b62396bc70162936"
          ]
        },
        "id": "CKweB9VZLnN2",
        "outputId": "9391ab32-9d7f-4c2b-934f-b02468f5fb12"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a4518227f933454e83361482e7f5ecc1"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "6ca09a109c37482fad60132dacf7f627"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "9a21e4b6dee74fb4830a0c1103b3656b"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "bf3c2eb9670640db8b482673a5828d2d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "cf819f8b33384f9da580a5ba1b7984ae"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# GSM8K, \"main\" configuration.\n",
        "ds = load_dataset(\"gsm8k\", \"main\")\n",
        "\n",
        "# The whole train split is used for training; no validation set is held out here.\n",
        "# To carve one out, split the train set instead:\n",
        "#   split = ds[\"train\"].train_test_split(test_size=0.05, seed=42)\n",
        "#   train_ds, val_ds = split[\"train\"], split[\"test\"]\n",
        "train_ds, test_ds = ds[\"train\"], ds[\"test\"]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yWg7PjfnRTx7",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yWg7PjfnRTx7",
        "outputId": "79743745-d5e6-426a-fe3a-6f880ee583f8"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 7473\n",
            "    })\n",
            "    test: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 1319\n",
            "    })\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "# Show the available splits with their columns and row counts.\n",
        "print(ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZyDraTzXfx9S",
      "metadata": {
        "id": "ZyDraTzXfx9S"
      },
      "outputs": [],
      "source": [
        "import unicodedata\n",
        "\n",
        "def to_ascii_numbers(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Replace Unicode numeric characters in `text` with ASCII digits.\n",
        "\n",
        "    Every character is handled on its own:\n",
        "    - a character with a Unicode digit value becomes that single ASCII digit\n",
        "      (full-width １９３０ → 1930, circled ① → 1);\n",
        "    - otherwise, a character whose numeric value is a whole number becomes that\n",
        "      integer written in ASCII (Roman numeral Ⅻ → 12);\n",
        "    - anything else, including fractions such as ½, is kept unchanged.\n",
        "\n",
        "    After the conversion a leading \"Paraphrase: \" prefix, if present, is removed.\n",
        "    \"\"\"\n",
        "    def _ascii_form(ch: str) -> str:\n",
        "        # Characters that carry a Unicode digit value map to one ASCII digit.\n",
        "        try:\n",
        "            return str(unicodedata.digit(ch))\n",
        "        except (TypeError, ValueError):\n",
        "            pass\n",
        "        # Other numeric characters are converted only when their value is a whole number.\n",
        "        try:\n",
        "            value = unicodedata.numeric(ch)\n",
        "        except (TypeError, ValueError):\n",
        "            return ch\n",
        "        return str(int(value)) if float(value).is_integer() else ch\n",
        "\n",
        "    converted = \"\".join(_ascii_form(ch) for ch in text)\n",
        "    return converted.removeprefix(\"Paraphrase: \")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "321VfgHiTxzV",
      "metadata": {
        "id": "321VfgHiTxzV"
      },
      "outputs": [],
      "source": [
        "# Paraphrase-generation knobs (the suggested ranges below were written for a T4 GPU).\n",
        "# Paraphrases to generate per question: one fewer than PARAPHRASES_PER_QUESTION (set elsewhere),\n",
        "# never below 0 -- presumably the original question fills the remaining slot (TODO confirm).\n",
        "PARA_PER_QUESTION = max(0, PARAPHRASES_PER_QUESTION - 1)  # e.g., 2\n",
        "BATCH_QUESTIONS = 64   # try 32–128 on T4; lower if you OOM\n",
        "# Passed as max_new_tokens to generate() in paraphrase_batch_chat().\n",
        "MAX_NEW_TOKENS_PARAPHRASE = 128  # shorter is faster; you can try 16–32\n",
        "USE_GREEDY = True      # greedy is faster & cleaner for paraphrasing\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "import math, torch, numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZBryMmXepKoN",
      "metadata": {
        "id": "ZBryMmXepKoN"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import torch\n",
        "from transformers import StoppingCriteria, StoppingCriteriaList\n",
        "\n",
        "# Probe for to_ascii_numbers (defined in another cell). If that name does not exist yet,\n",
        "# install an identity stand-in so paraphrase_batch_chat() below can still call it.\n",
        "try:\n",
        "    to_ascii_numbers\n",
        "except NameError:\n",
        "    def to_ascii_numbers(s: str) -> str:\n",
        "        # Fallback: return the text unchanged.\n",
        "        return s\n",
        "\n",
        "class StopOnTokenSeq(StoppingCriteria):\n",
        "    \"\"\"\n",
        "    Stopping criterion that fires when a row ends with a tokenized stop string.\n",
        "\n",
        "    Each stop string is encoded once (without special tokens) at construction time.\n",
        "    NOTE(review): __call__ returns one bool for the whole batch -- True as soon as ANY\n",
        "    row matches -- so a single finished row presumably ends generation for every row;\n",
        "    confirm against the installed transformers version before relying on it in batches.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, tok, stop_strings):\n",
        "        self.tok = tok\n",
        "        encoded = []\n",
        "        for text in stop_strings:\n",
        "            encoded.append(tok.encode(text, add_special_tokens=False))\n",
        "        self.stop_ids = encoded\n",
        "        # Length in tokens of the longest stop sequence (1 when there are none).\n",
        "        self.maxlen = max(map(len, encoded), default=1)\n",
        "\n",
        "    def __call__(self, input_ids, scores, **kwargs):\n",
        "        window = 6 * self.maxlen  # tail window, wider than the longest stop sequence\n",
        "        for row in input_ids.tolist():\n",
        "            tail = row[-min(len(row), window):]\n",
        "            hit = any(\n",
        "                len(tail) >= len(seq) and tail[-len(seq):] == seq\n",
        "                for seq in self.stop_ids\n",
        "            )\n",
        "            if hit:\n",
        "                return True\n",
        "        return False\n",
        "\n",
        "\n",
        "def paraphrase_batch_chat(qs, k=1):\n",
        "    \"\"\"\n",
        "    Qwen2/2.5 chat-friendly batch paraphraser.\n",
        "\n",
        "    Renders one chat prompt per question, generates `k` continuations per prompt with\n",
        "    the global `baseline_model` / `tokenizer`, keeps the first line of each continuation\n",
        "    and passes it through to_ascii_numbers().\n",
        "\n",
        "    - Uses LEFT padding (required for decoder-only batched generation).\n",
        "    - k == 1 decodes greedily. k > 1 samples (temperature=0.3, top_p=0.3), because greedy\n",
        "      decoding without beam search cannot return more than one sequence per prompt.\n",
        "    - Decodes with specials skipped to avoid <|endoftext|> noise.\n",
        "    - A paraphrase that comes back empty is replaced by the original question.\n",
        "\n",
        "    Args:\n",
        "        qs: sequence of question strings.\n",
        "        k: number of paraphrases per question (must be >= 1).\n",
        "\n",
        "    Returns: List[List[str]] of shape [len(qs)][k]\n",
        "    \"\"\"\n",
        "    assert k >= 1\n",
        "    if len(qs) == 0:\n",
        "        return []  # nothing to tokenize or generate\n",
        "\n",
        "    system_prompt = (\n",
        "      \"You paraphrase GSM8K math word problems.\\n\"\n",
        "      \"Rephrase without changing meaning.\\n\"\n",
        "      \"Preserve all names, quantities, units, and numbers exactly.\\n\"\n",
        "      \"Do NOT solve the problem.\\n\"\n",
        "      \"Output ONLY the paraphrased question, one line, no quotes, no prefix.\"\n",
        "    )\n",
        "\n",
        "    # --- Build prompts (once per question) ---\n",
        "    rendered = []\n",
        "    for q in qs:\n",
        "        msgs = [\n",
        "            {\"role\": \"system\", \"content\": system_prompt},\n",
        "            {\"role\": \"user\", \"content\": f\"Paraphrase the following question.\\n{q}\"},\n",
        "        ]\n",
        "        rendered.append(tokenizer.apply_chat_template(\n",
        "            msgs, tokenize=False, add_generation_prompt=True\n",
        "        ))\n",
        "\n",
        "    # --- Encode with LEFT padding for decoder-only models ---\n",
        "    # The tokenizer is shared, so its padding side is restored even if encoding fails.\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(rendered, return_tensors=\"pt\", padding=True, truncation=True).to(baseline_model.device)\n",
        "    finally:\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    B = len(qs)\n",
        "    # With left padding every prompt ends at the same column, so one cut index\n",
        "    # separates prompt from continuation for all rows.\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- Generation config (Qwen) ---\n",
        "    EOT_ID = tokenizer.eos_token_id\n",
        "    assert isinstance(EOT_ID, int) and EOT_ID >= 0, \"Could not resolve <|im_end|> id\"\n",
        "    # Compare against None explicitly: a pad id of 0 is valid and must not be replaced.\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else EOT_ID\n",
        "\n",
        "    # Respect an optional global `bad` list, but make sure it doesn't block EOS\n",
        "    try:\n",
        "        bad_words_ids = bad if (isinstance(bad, list) and bad) else None\n",
        "        if bad_words_ids:\n",
        "            assert all(EOT_ID not in seq for seq in bad_words_ids), \"bad_words_ids blocks <|im_end|>\"\n",
        "    except NameError:\n",
        "        bad_words_ids = None\n",
        "\n",
        "    # Greedy decoding yields exactly one sequence per prompt; asking for several\n",
        "    # requires sampling. Sampling knobs are set to None for greedy runs.\n",
        "    do_sample = k > 1\n",
        "\n",
        "    # Optional: pass stopping_criteria=StoppingCriteriaList([StopOnTokenSeq(tokenizer,\n",
        "    # [\"<|im_start|>user\", \"\\nuser\"])]) to also stop at the next user header.\n",
        "    with torch.no_grad():\n",
        "        out = baseline_model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=MAX_NEW_TOKENS_PARAPHRASE,\n",
        "            min_new_tokens=0,                    # don't force past EOS\n",
        "            do_sample=do_sample,\n",
        "            temperature=(0.3 if do_sample else None),\n",
        "            top_p=(0.3 if do_sample else None),\n",
        "            num_beams=1,\n",
        "            num_return_sequences=k,\n",
        "            bad_words_ids=bad_words_ids,\n",
        "            eos_token_id=EOT_ID,                 # Qwen chat EOT\n",
        "            pad_token_id=pad_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "\n",
        "    # --- Slice off continuations (after full padded input) ---\n",
        "    seqs = out.sequences                                # [B*k, cut_idx + new_len]\n",
        "    cont_tokens = [seqs[row, cut_idx:] for row in range(B * k)]\n",
        "\n",
        "    # --- Decode: skip specials to hide <|endoftext|>/<|im_end|> ---\n",
        "    outs = tokenizer.batch_decode(cont_tokens, skip_special_tokens=True)\n",
        "\n",
        "    # --- Cleanup: first line only; drop a naked role header if it slipped through ---\n",
        "    ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "    def clean(s: str) -> str:\n",
        "        s = s.split(\"\\n\", 1)[0].strip()\n",
        "        if ROLE_ONLY.match(s):\n",
        "            s = \"\"\n",
        "        return to_ascii_numbers(s)\n",
        "\n",
        "    outs = [clean(o) for o in outs]                     # len = B*k\n",
        "\n",
        "    # Regroup the flat list into k paraphrases per question (rows for one prompt are\n",
        "    # consecutive); an empty paraphrase falls back to the original question text.\n",
        "    per_q, ptr = [], 0\n",
        "    for q in qs:\n",
        "        group = outs[ptr:ptr + k]\n",
        "        group = [g if g else q for g in group]\n",
        "        per_q.append(group)\n",
        "        ptr += k\n",
        "\n",
        "    return per_q"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bPEoICWTndw_",
      "metadata": {
        "id": "bPEoICWTndw_"
      },
      "outputs": [],
      "source": [
        "# --- Answer generation for Qwen-style chat, batched & left-padded ---\n",
        "\n",
        "import re, math, torch\n",
        "from typing import List, Tuple\n",
        "\n",
        "# Compiled once at import time: matches a line that is nothing but a chat role name\n",
        "# (\"assistant\" or \"user\", case-insensitive) with an optional trailing colon.\n",
        "_ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "# Default system prompt for answer generation. It asks for the final answer on its own\n",
        "# line as \"#### <number>\", the format _extract_gsm8k_final() looks for.\n",
        "SYSTEM_PROMPT = (\n",
        "                  \"You solve grade-school math word problems.\\n\"\n",
        "                  \"Think step by step. Keep the reasoning brief.\\n\"\n",
        "                  \"Then repeat the final answer on a new line in exactly this format:\\n\"\n",
        "                  \"#### <number>\"\n",
        "                )\n",
        "\n",
        "def _clean_one_line(s: str) -> str:\n",
        "    \"\"\"Return the first line of `s`, stripped; an empty string if it is only a role header.\"\"\"\n",
        "    first_line, _, _ = s.partition(\"\\n\")\n",
        "    first_line = first_line.strip()\n",
        "    if _ROLE_ONLY.match(first_line):\n",
        "        return \"\"\n",
        "    return first_line\n",
        "\n",
        "def gsm8k_messages(question: str,\n",
        "                   *,\n",
        "                   system_prompt = SYSTEM_PROMPT):\n",
        "    \"\"\"\n",
        "    Build the chat messages for one question: a system message carrying\n",
        "    `system_prompt`, followed by a user message with the question and a\n",
        "    request to solve it step by step.\n",
        "    \"\"\"\n",
        "    system_msg = {\"role\": \"system\", \"content\": system_prompt}\n",
        "    user_msg = {\n",
        "        \"role\": \"user\",\n",
        "        \"content\": f\"{question}\\n\\nSolve step by step, then give the final answer.\",\n",
        "    }\n",
        "    return [system_msg, user_msg]\n",
        "\n",
        "\n",
        "def format_prompt_instruct(q: str,\n",
        "                           *,\n",
        "                           system_prompt = SYSTEM_PROMPT,\n",
        "                           tokenizer=None) -> str:\n",
        "    \"\"\"\n",
        "    Render question `q` as a chat prompt string ready for generation.\n",
        "\n",
        "    The messages come from gsm8k_messages() and are rendered with the tokenizer's chat\n",
        "    template (not tokenized, generation prompt appended). When `tokenizer` is None the\n",
        "    module-level `tokenizer` is used.\n",
        "    \"\"\"\n",
        "    tok = globals()[\"tokenizer\"] if tokenizer is None else tokenizer\n",
        "    return tok.apply_chat_template(\n",
        "        gsm8k_messages(q, system_prompt=system_prompt),\n",
        "        tokenize=False,\n",
        "        add_generation_prompt=True,\n",
        "    )\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8jHm5ZuAHSaM",
      "metadata": {
        "id": "8jHm5ZuAHSaM"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "from typing import List, Tuple, Optional\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "def _extract_gsm8k_final(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Return the number on the last \"####\" line of `text`, or \"NA\".\n",
        "\n",
        "    Lines are scanned from the bottom up (leading spaces are allowed before the marker):\n",
        "    - a line where a number directly follows \"####\" and is itself followed by whitespace\n",
        "      or end of line yields that number as a string. Thousands separators are accepted\n",
        "      and removed (\"#### 1,234\" -> \"1234\");\n",
        "    - a line where \"####\" is followed by whitespace or end of line but no usable number\n",
        "      ends the scan with \"NA\";\n",
        "    - any other line is skipped.\n",
        "    \"NA\" is also returned when no line matches. The function never returns None.\n",
        "    \"\"\"\n",
        "    number_line = re.compile(\n",
        "        r\"^\\s*####\\s*([-+]?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?)(?:\\s|$)\"\n",
        "    )\n",
        "    for line in reversed(text.splitlines()):\n",
        "        m = number_line.match(line)\n",
        "        if m:\n",
        "            # Drop thousands separators so \"1,234\" and \"1234\" compare equal.\n",
        "            return m.group(1).replace(\",\", \"\")\n",
        "        # a \"####\" marker without a usable number: treat as missing and stop\n",
        "        if re.match(r\"^\\s*####(?:\\s|$)\", line):\n",
        "            return \"NA\"\n",
        "    return \"NA\"\n",
        "\n",
        "\n",
        "def answer_batch_chat(\n",
        "    qs: List[str],\n",
        "    *,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Generate one answer per input question, in a single batched call.\n",
        "\n",
        "    Prompts are rendered with format_prompt_instruct(), left-padded and passed\n",
        "    to model.generate(); only the continuation is decoded and then reduced to\n",
        "    its final answer with _extract_gsm8k_final().\n",
        "\n",
        "    Args:\n",
        "        qs: Questions to answer.\n",
        "        model: Model to generate with; falls back to the global baseline_model.\n",
        "        tokenizer: Tokenizer to use; falls back to the global tokenizer.\n",
        "        max_new_tokens: Generation budget per question.\n",
        "        temperature, top_p: Sampling knobs; not used when use_greedy is True.\n",
        "        use_greedy: If True, decode greedily instead of sampling.\n",
        "\n",
        "    Returns:\n",
        "        A list with exactly one entry per question, in input order. An entry\n",
        "        is \"NA\" when no final answer could be extracted.\n",
        "    \"\"\"\n",
        "    # Nothing to generate; also avoids handing the tokenizer an empty batch.\n",
        "    if not qs:\n",
        "        return []\n",
        "\n",
        "    # Fallback to globals if not provided\n",
        "    if model is None:\n",
        "        model = baseline_model\n",
        "    if tokenizer is None:\n",
        "        tokenizer = globals()[\"tokenizer\"]\n",
        "\n",
        "    # --- Build rendered prompts using format_prompt_instruct ---\n",
        "    rendered = [\n",
        "        format_prompt_instruct(q, tokenizer=tokenizer)\n",
        "        for q in qs\n",
        "    ]\n",
        "\n",
        "    # --- Encode with LEFT padding ---\n",
        "    # Left padding puts every prompt flush against the generated tokens, so a\n",
        "    # single column index separates prompt from continuation for all rows.\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(\n",
        "            rendered,\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        ).to(next(model.parameters()).device)\n",
        "    finally:\n",
        "        # Restore the caller's padding side even if tokenization fails.\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- EOS / PAD handling ---\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id\n",
        "\n",
        "    # --- Generation ---\n",
        "    with torch.inference_mode():\n",
        "        out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            min_new_tokens=0,\n",
        "            do_sample=not use_greedy,\n",
        "            temperature=(temperature if not use_greedy else None),\n",
        "            top_p=(top_p if not use_greedy else None),\n",
        "            num_beams=1,\n",
        "            num_return_sequences=1,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=pad_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "\n",
        "    # --- Slice off the prompt and decode just the continuation ---\n",
        "    seqs = out.sequences\n",
        "    cont = [seqs[i, cut_idx:] for i in range(len(qs))]\n",
        "    outs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n",
        "\n",
        "    # --- Cleanup: keep only the extracted final answer ---\n",
        "    # Always emit exactly one entry per question so callers can pair results\n",
        "    # with their inputs by position; a missing extraction becomes \"NA\" instead\n",
        "    # of being dropped (which would shift every later answer).\n",
        "    cleaned = []\n",
        "    for o in outs:\n",
        "        gsm = _extract_gsm8k_final(o.strip())\n",
        "        cleaned.append(gsm if gsm is not None else \"NA\")\n",
        "    return cleaned\n",
        "\n",
        "\n",
        "def answer_originals_and_paraphrases(\n",
        "    qs: List[str],\n",
        "    paras_per_q: List[List[str]],\n",
        "    *,\n",
        "    micro_bs: int = 64,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> Tuple[List[str], List[List[str]]]:\n",
        "    \"\"\"\n",
        "    Answer every original question and every paraphrase of it.\n",
        "\n",
        "    Generation goes through answer_batch_chat() in chunks of at most\n",
        "    micro_bs prompts. Originals are answered first, then all paraphrases as\n",
        "    one flattened stream that is cut back into per-question groups.\n",
        "\n",
        "    Returns:\n",
        "        (original_answers, paraphrase_answers), where paraphrase_answers[i]\n",
        "        lines up with paras_per_q[i].\n",
        "    \"\"\"\n",
        "    gen_kwargs = dict(\n",
        "        model=model,\n",
        "        tokenizer=tokenizer,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "        use_greedy=use_greedy,\n",
        "    )\n",
        "\n",
        "    def _answer_in_chunks(prompts: List[str]) -> List[str]:\n",
        "        # One generation call per micro-batch; results keep prompt order.\n",
        "        collected: List[str] = []\n",
        "        for lo in range(0, len(prompts), micro_bs):\n",
        "            collected += answer_batch_chat(prompts[lo : lo + micro_bs], **gen_kwargs)\n",
        "        return collected\n",
        "\n",
        "    # Originals first, then the flattened paraphrases.\n",
        "    ans_orig = _answer_in_chunks(qs)\n",
        "    flat_answers = _answer_in_chunks([p for group in paras_per_q for p in group])\n",
        "\n",
        "    # Cut the flat answer stream back into one list per question.\n",
        "    ans_paras: List[List[str]] = []\n",
        "    offset = 0\n",
        "    for group in paras_per_q:\n",
        "        upto = offset + len(group)\n",
        "        ans_paras.append(flat_answers[offset:upto])\n",
        "        offset = upto\n",
        "\n",
        "    return ans_orig, ans_paras\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6hWe785vn5Uw",
      "metadata": {
        "id": "6hWe785vn5Uw"
      },
      "outputs": [],
      "source": [
        "def _to_text(x):\n",
        "    \"\"\"\n",
        "    Flatten a reference answer into a plain string.\n",
        "\n",
        "    None becomes \"\", strings pass through, lists are converted item by item\n",
        "    (skipping None) and joined with \"; \", and dicts are unwrapped through the\n",
        "    first well-known answer key they contain. Anything else is str()-ed.\n",
        "    \"\"\"\n",
        "    if x is None:\n",
        "        return \"\"\n",
        "    if isinstance(x, str):\n",
        "        return x\n",
        "    if isinstance(x, list):\n",
        "        # join multiple references; tweak as you like\n",
        "        parts = [_to_text(item) for item in x if item is not None]\n",
        "        return \"; \".join(parts)\n",
        "    if isinstance(x, dict):\n",
        "        preferred = (\"text\", \"answer\", \"final\", \"final_answer\", \"label\", \"gold\", \"target\")\n",
        "        hit = next((name for name in preferred if name in x), None)\n",
        "        if hit is not None:\n",
        "            return _to_text(x[hit])\n",
        "    # numbers, dicts without a known key, and any other object\n",
        "    return str(x)\n",
        "\n",
        "def fetch_ground_truths(train_ds, idxs):\n",
        "    \"\"\"\n",
        "    Return the reference answer, as text, for each index in idxs.\n",
        "\n",
        "    For every record the first field present among the common answer field\n",
        "    names is used; a record with none of them yields \"\".\n",
        "    \"\"\"\n",
        "    candidate_fields = (\"answer\", \"answers\", \"final_answer\", \"final\", \"label\", \"target\")\n",
        "\n",
        "    def _gold_for(row):\n",
        "        # Try common fields in order\n",
        "        field = next((name for name in candidate_fields if name in row), None)\n",
        "        return _to_text(None if field is None else row[field])\n",
        "\n",
        "    return [_gold_for(train_ds[i]) for i in idxs]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "t_pao7CWT7Kg",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "t_pao7CWT7Kg",
        "outputId": "fc791bfb-0601-489a-ad66-d84e29fc7e25"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [0]\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n\\nparas_per_q = paraphrase_batch_chat(qs, k=k_tmp)\\n\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 24
        }
      ],
      "source": [
        "# --- Disabled smoke test: the code below is kept inside a string literal, so running this cell executes none of it ---\n",
        "'''\n",
        "PARAPHRASES_PER_QUESTION_tmp = 2\n",
        "k_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\n",
        "\n",
        "# Keep the indices so we can fetch ground truths\n",
        "qs = []\n",
        "idxs = [0]\n",
        "for idx in idxs:\n",
        "  qs.append(train_ds[idx][\"question\"])\n",
        "\n",
        "paras_per_q = paraphrase_batch_chat(qs, k=k_tmp)\n",
        "\n",
        "# --- NEW: get model answers ---\n",
        "answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "# --- NEW: fetch ground-truths for the originals ---\n",
        "gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "# ---- Print nicely ----\n",
        "print(\"=== Test inputs ===\")\n",
        "for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "    print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "    print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "print(\"\\n=== Paraphrases + Answers ===\")\n",
        "for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "    print(f\"[Q{i}]\")\n",
        "    print(f\"  Original answer: {a_q}\")\n",
        "    print(f\"  Ground truth   : {gt}\")\n",
        "    if not paras:\n",
        "        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "    else:\n",
        "        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "            print(f\"  {j}. {p}\")\n",
        "            print(f\"     ↳ Answer: {ap}\")\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "lejUKL25Sf3U",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lejUKL25Sf3U",
        "outputId": "78917c78-86f2-4b8c-b56d-9d0e9c4715cb"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CACHE_PATH:  /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\n"
          ]
        }
      ],
      "source": [
        "import os, json, time, hashlib, pathlib\n",
        "\n",
        "\n",
        "print(\"CACHE_PATH: \", CACHE_PATH)\n",
        "\n",
        "import os, time\n",
        "\n",
        "from pathlib import Path\n",
        "def _ensure_parent_dir(path: str | Path) -> Path:\n",
        "    \"\"\"\n",
        "    Make sure the directory that will contain `path` exists.\n",
        "\n",
        "    Args:\n",
        "        path: File path whose parent directory is required.\n",
        "\n",
        "    Returns:\n",
        "        `path` as a pathlib.Path.\n",
        "    \"\"\"\n",
        "    p = Path(path)\n",
        "    parent = p.parent\n",
        "    # Only announce real work, and name the directory rather than the file:\n",
        "    # this helper runs on every cache save, so an unconditional message would\n",
        "    # claim to create a directory that already exists.\n",
        "    if not parent.is_dir():\n",
        "        print(\"Making directory: \", parent)\n",
        "    parent.mkdir(parents=True, exist_ok=True)\n",
        "    return p\n",
        "\n",
        "\n",
        "# NOTE(review): `save_paraphrase_cache` is defined a second time further down in\n",
        "# this same cell. Python keeps the later definition, so this version is shadowed\n",
        "# and never runs when the cell executes top to bottom.\n",
        "def save_paraphrase_cache(path, classes_all, classes_usable, meta: dict):\n",
        "    path = _ensure_parent_dir(path)                    # NEW: ensure folder exists\n",
        "    meta = dict(meta)\n",
        "    meta[\"created_utc\"] = int(time.time())\n",
        "    meta[\"hash\"] = _meta_hash(**meta)\n",
        "\n",
        "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        # first line: metadata header; then one JSON record per class\n",
        "        f.write(json.dumps({\"__meta__\": meta}) + \"\\n\")\n",
        "        for i, cls in enumerate(classes_all):\n",
        "            rec = {\"idx\": i, \"size\": len(cls), \"prompts\": cls, \"usable\": int(len(cls) >= 2)}\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    print(f\"Saved cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "\n",
        "\n",
        "import uuid, os\n",
        "\n",
        "def save_paraphrase_cache_atomic(path, classes_all, classes_usable, meta):\n",
        "    \"\"\"\n",
        "    Write the cache to a temporary sibling file, then rename it over `path`.\n",
        "\n",
        "    Readers therefore see either the previous complete cache or the new one.\n",
        "    Arguments are forwarded unchanged to save_paraphrase_cache().\n",
        "    \"\"\"\n",
        "    p = _ensure_parent_dir(path)                       # NEW\n",
        "    tmp = p.with_name(p.name + f\".tmp-{uuid.uuid4().hex}\")  # same dir → safe rename\n",
        "    try:\n",
        "        save_paraphrase_cache(tmp, classes_all, classes_usable, meta)\n",
        "        os.replace(tmp, p)                             # atomic on POSIX\n",
        "    finally:\n",
        "        # After a successful replace the temp name is gone and this is a no-op;\n",
        "        # after a failure or interrupt it removes the half-written temp file.\n",
        "        tmp.unlink(missing_ok=True)\n",
        "\n",
        "\n",
        "def meta_compatible(current_meta: dict, cached_meta: dict | None) -> bool:\n",
        "    \"\"\"\n",
        "    Tell whether a cache's metadata matches the current generation settings.\n",
        "\n",
        "    Only the knobs listed below are compared; every other entry of either\n",
        "    dict is ignored. A missing or empty cached_meta is never compatible.\n",
        "    \"\"\"\n",
        "    if not cached_meta:\n",
        "        return False\n",
        "    # Be strict about the knobs that affect outputs\n",
        "    strict_knobs = (\n",
        "        \"train_len\", \"per_class\", \"paraphraser_model\",\n",
        "        \"top_p\", \"top_k\", \"temperature\", \"max_new_tokens\",\n",
        "        \"dedup_thresh\", \"embedder\",\n",
        "    )\n",
        "    for knob in strict_knobs:\n",
        "        if current_meta.get(knob) != cached_meta.get(knob):\n",
        "            return False\n",
        "    return True\n",
        "\n",
        "\n",
        "def _meta_hash(**kw):\n",
        "    \"\"\"Short fingerprint of the given settings: first 10 hex chars of the SHA-1 of their key-sorted JSON.\"\"\"\n",
        "    # tie cache to settings that affect paraphrases\n",
        "    canonical = json.dumps(kw, sort_keys=True)\n",
        "    digest = hashlib.sha1(canonical.encode())\n",
        "    return digest.hexdigest()[:10]\n",
        "\n",
        "def save_paraphrase_cache(path, classes_all, classes_usable, meta: dict):\n",
        "    \"\"\"\n",
        "    Write the paraphrase classes to `path` as JSON Lines.\n",
        "\n",
        "    The first line is {\"__meta__\": ...}; every following line describes one\n",
        "    class with its index, size, prompts and a 0/1 `usable` flag (size >= 2).\n",
        "\n",
        "    Args:\n",
        "        path: Destination file; its parent directory is created if missing.\n",
        "        classes_all: list[list[str]], one list of prompts per class.\n",
        "        classes_usable: Only its length is used, for the summary message.\n",
        "        meta: Generation settings; copied, never mutated.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(path)\n",
        "    path.parent.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # Fingerprint the settings only. Hashing the timestamp (or a hash carried\n",
        "    # over from a previously loaded meta) would give a different value on every\n",
        "    # save, so it could never be matched by load_paraphrase_cache(require_hash=...).\n",
        "    settings = {k: v for k, v in meta.items() if k not in (\"created_utc\", \"hash\")}\n",
        "    meta = dict(settings)\n",
        "    meta[\"created_utc\"] = int(time.time())\n",
        "    meta[\"hash\"] = _meta_hash(**settings)\n",
        "\n",
        "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta}) + \"\\n\")\n",
        "        # classes_all is a list[list[str]]\n",
        "        for i, cls in enumerate(classes_all):\n",
        "            rec = {\n",
        "                \"idx\": i,\n",
        "                \"size\": len(cls),\n",
        "                \"prompts\": cls,\n",
        "                \"usable\": int(len(cls) >= 2)\n",
        "            }\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    print(f\"Saved cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "\n",
        "def load_paraphrase_cache(path, require_hash: str | None = None):\n",
        "    \"\"\"\n",
        "    Read a paraphrase cache written by save_paraphrase_cache().\n",
        "\n",
        "    Args:\n",
        "        path: JSON Lines cache file.\n",
        "        require_hash: Optional expected meta hash. A mismatch, or a cache with\n",
        "            no meta line, only prints a warning; the data is still returned.\n",
        "\n",
        "    Returns:\n",
        "        (classes_all, classes_usable, meta), or (None, None, None) when the\n",
        "        file does not exist. classes_usable holds the classes flagged usable\n",
        "        that also have at least two prompts; meta is None if the file has no\n",
        "        \"__meta__\" line.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(path)\n",
        "    if not path.exists():\n",
        "        return None, None, None\n",
        "    classes_all, classes_usable = [], []\n",
        "    meta = None\n",
        "    with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "        for line in f:\n",
        "            # Tolerate blank lines (e.g. a stray trailing newline) instead of\n",
        "            # failing inside json.loads.\n",
        "            if not line.strip():\n",
        "                continue\n",
        "            obj = json.loads(line)\n",
        "            if \"__meta__\" in obj:\n",
        "                meta = obj[\"__meta__\"]\n",
        "                continue\n",
        "            cls = obj[\"prompts\"]\n",
        "            classes_all.append(cls)\n",
        "            if obj.get(\"usable\") and len(cls) >= 2:\n",
        "                classes_usable.append(cls)\n",
        "    # A cache without any meta line cannot satisfy a required hash either.\n",
        "    if require_hash and (meta or {}).get(\"hash\") != require_hash:\n",
        "        print(\"⚠️ Cache meta hash mismatch; rebuilding may be safer.\")\n",
        "    print(f\"Loaded cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "    return classes_all, classes_usable, meta"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B07v0ZOKuI3F",
      "metadata": {
        "id": "B07v0ZOKuI3F"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "def build_classes_full_batched(\n",
        "    per_class=PARAPHRASES_PER_QUESTION,\n",
        "    dedup_thresh=0.98,\n",
        "    do_dedup=False,   # <- new flag\n",
        "    cache_path: str | None = None,\n",
        "    flush_every: int = 1000,\n",
        "    generate_fn=paraphrase_batch_chat,\n",
        "):\n",
        "    \"\"\"\n",
        "    Build one paraphrase class per row of train_ds using batched generation.\n",
        "\n",
        "    Each class is [original question] + its paraphrases. Progress is\n",
        "    checkpointed to `cache_path`, and a compatible partial cache is resumed.\n",
        "\n",
        "    Args:\n",
        "        per_class: Target class size; per_class - 1 paraphrases are requested.\n",
        "        dedup_thresh: With do_dedup, a paraphrase is dropped when its cosine\n",
        "            similarity to an already kept prompt reaches this value.\n",
        "        do_dedup: Enable the embedding-based near-duplicate filter.\n",
        "        cache_path: JSONL checkpoint file; None disables caching and resume.\n",
        "        flush_every: Save a checkpoint after at least this many new rows.\n",
        "        generate_fn: Callable (questions, k=...) -> one paraphrase list per question.\n",
        "\n",
        "    Returns:\n",
        "        (classes_all, classes_usable); classes_usable keeps the classes with\n",
        "        at least two prompts.\n",
        "    \"\"\"\n",
        "    k = max(0, per_class - 1)\n",
        "    n_total = len(train_ds)\n",
        "\n",
        "    # --- try to resume from a compatible cache ---\n",
        "    start_at = 0\n",
        "    classes_all, classes_usable = [], []\n",
        "    cached = (None, None, None)\n",
        "\n",
        "    if cache_path:\n",
        "        try:\n",
        "            cached = load_paraphrase_cache(cache_path)\n",
        "        except Exception as exc:\n",
        "            # An unreadable cache is not fatal, but say so: silently starting\n",
        "            # over would hide a corrupted checkpoint.\n",
        "            print(f\"[Resume] Could not read cache {cache_path!r}: {exc!r}\")\n",
        "            cached = (None, None, None)\n",
        "\n",
        "    cached_all, cached_usable, cached_meta = cached\n",
        "    if (\n",
        "        cache_path\n",
        "        and cached_all is not None\n",
        "        and 0 < len(cached_all) <= n_total\n",
        "        and meta_compatible(META, cached_meta)\n",
        "    ):\n",
        "        classes_all = list(cached_all)\n",
        "        classes_usable = list(cached_usable or [])\n",
        "        start_at = len(classes_all)\n",
        "        print(f\"[Resume] Compatible cache found: {start_at}/{n_total} rows; resuming…\")\n",
        "    else:\n",
        "        if cache_path:\n",
        "            print(\"[Resume] No compatible partial cache; starting fresh.\")\n",
        "\n",
        "    # --- Put embedder on GPU if possible ---\n",
        "    # The embedder is only used by the dedup filter, so leave it alone (and\n",
        "    # do not require it to exist) when dedup is off.\n",
        "    if do_dedup:\n",
        "        try:\n",
        "            embedder.to(\"cuda\")\n",
        "        except Exception as exc:\n",
        "            print(f\"[Embedder] Could not move embedder to GPU ({exc!r}); using its current device.\")\n",
        "\n",
        "    pbar = tqdm(total=n_total, desc=\"Building (batched) classes\", leave=True)\n",
        "    pbar.update(start_at)\n",
        "    since_last_flush = 0\n",
        "\n",
        "    try:\n",
        "        for start in range(start_at, n_total, BATCH_QUESTIONS):\n",
        "            end = min(start + BATCH_QUESTIONS, n_total)\n",
        "            batch = train_ds[start:end]\n",
        "            qs = batch[\"question\"] # <-- list of questions\n",
        "            # Fast column slice (HF datasets return list for a column slice)\n",
        "            # qs = train_ds[\"question\"][start:end]\n",
        "\n",
        "            # ---- fast batched paraphrase generation ----\n",
        "            if k > 0:\n",
        "                paras_per_q = generate_fn(qs, k=k)\n",
        "            else:\n",
        "                paras_per_q = [[] for _ in qs]\n",
        "\n",
        "            # ---- dedup per question (batched embeddings for speed) ----\n",
        "            for i, q in enumerate(qs):\n",
        "                paras = [q] + paras_per_q[i]\n",
        "\n",
        "                if do_dedup:\n",
        "                    embs = embedder.encode(\n",
        "                        paras, convert_to_tensor=True,\n",
        "                        normalize_embeddings=True, batch_size=64\n",
        "                    )\n",
        "\n",
        "                    # evaluate paraphrases by cosine value in embedding space. Discard those that are not novel enough.\n",
        "                    # Index 0 (the original question) is always kept.\n",
        "                    keep = [0]\n",
        "                    for j in range(1, len(paras)):\n",
        "                        sims = sbert_util.cos_sim(embs[j], embs[keep]).cpu().numpy().flatten()\n",
        "                        if float(np.max(sims)) < dedup_thresh:\n",
        "                            keep.append(j)\n",
        "\n",
        "                    cls = [paras[j] for j in keep]\n",
        "                else:\n",
        "                    cls = paras[:]  # original + all paraphrases, in order, no discard\n",
        "\n",
        "                classes_all.append(cls)\n",
        "                if len(cls) >= 2:\n",
        "                    classes_usable.append(cls)\n",
        "\n",
        "            processed = end - start\n",
        "            since_last_flush += processed\n",
        "            pbar.update(processed)\n",
        "            pbar.set_postfix({\"usable\": len(classes_usable)})\n",
        "\n",
        "            # ---- periodic checkpoint ----\n",
        "            if cache_path and since_last_flush >= flush_every:\n",
        "                save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "                since_last_flush = 0\n",
        "\n",
        "        # final save\n",
        "        if cache_path:\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "\n",
        "    except BaseException as exc:\n",
        "        # Covers Ctrl-C as well as runtime failures (e.g. running out of GPU\n",
        "        # memory): keep whatever was generated since the last checkpoint.\n",
        "        # Skipped when this run added nothing, so an existing cache file is\n",
        "        # not rewritten for no reason.\n",
        "        if cache_path and len(classes_all) > start_at:\n",
        "            label = \"Interrupt\" if isinstance(exc, KeyboardInterrupt) else \"Error\"\n",
        "            print(f\"\\n[{label}] Saving partial progress…\")\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "        raise\n",
        "    finally:\n",
        "        pbar.close()\n",
        "\n",
        "    return classes_all, classes_usable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QM6pwSm3TAyW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QM6pwSm3TAyW",
        "outputId": "008c8c98-b045-47ad-fab8-44bc07781734"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded cache: /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl  (7473 classes, usable=7473)\n",
            "Total classes: 7473; usable: 7473\n"
          ]
        }
      ],
      "source": [
        "# Describe the paraphrase-generation settings so caches are consistent\n",
        "META = {\n",
        "    \"train_len\": len(train_ds),\n",
        "    \"per_class\": PARAPHRASES_PER_QUESTION,\n",
        "    \"paraphraser_model\": BASELINE_MODEL_NAME,\n",
        "    \"top_p\": TOP_P, \"top_k\": TOP_K, \"temperature\": GEN_TEMPERATURE,\n",
        "    \"max_new_tokens\": 48,\n",
        "    \"dedup_thresh\": 0.98,\n",
        "    \"embedder\": \"sentence-transformers/all-MiniLM-L6-v2\",\n",
        "}\n",
        "\n",
        "# Try to load cache first\n",
        "paraphrase_classes_all, paraphrase_classes, meta = load_paraphrase_cache(CACHE_PATH)\n",
        "# A cache is only reused when it covers every row of train_ds.\n",
        "need_build = paraphrase_classes_all is None or len(paraphrase_classes_all) != len(train_ds)\n",
        "\n",
        "if need_build:\n",
        "    print(\"No valid cache found — building now…\")\n",
        "    _ensure_parent_dir(CACHE_PATH)\n",
        "    paraphrase_classes_all, paraphrase_classes = build_classes_full_batched(\n",
        "        per_class=PARAPHRASES_PER_QUESTION,\n",
        "        dedup_thresh=META[\"dedup_thresh\"],   # single source of truth with the cache metadata\n",
        "        cache_path=CACHE_PATH,      # enables resume & periodic saves\n",
        "        flush_every=1000,           # save every 1000 rows (tweak as you like)\n",
        "        generate_fn=paraphrase_batch_chat,\n",
        "    )\n",
        "else:\n",
        "    # The complete cache is reused as-is, but flag it when it was generated\n",
        "    # with different settings than META describes.\n",
        "    if not meta_compatible(META, meta):\n",
        "        print(\"⚠️ Cached paraphrases were built with different settings than META; delete the cache file to rebuild.\")\n",
        "\n",
        "print(f\"Total classes: {len(paraphrase_classes_all)}; usable: {len(paraphrase_classes)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GSdAyKG47uP8",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GSdAyKG47uP8",
        "outputId": "7c83bd5f-287e-478b-81a8-61d6376cc8da"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [1000, 1300]\\nparas_per_q = []\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n  paras_per_q.append([paraphrase_classes[idx][1]])\\nprint(qs)\\nprint(paras_per_q)\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 28
        }
      ],
      "source": [
        "# --- Disabled smoke test: the code below is kept inside a string literal, so running this cell executes none of it ---\n",
        "'''\n",
        "PARAPHRASES_PER_QUESTION_tmp = 2\n",
        "k_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\n",
        "\n",
        "# Keep the indices so we can fetch ground truths\n",
        "qs = []\n",
        "idxs = [1000, 1300]\n",
        "paras_per_q = []\n",
        "for idx in idxs:\n",
        "  qs.append(train_ds[idx][\"question\"])\n",
        "  paras_per_q.append([paraphrase_classes[idx][1]])\n",
        "print(qs)\n",
        "print(paras_per_q)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "gdXynjcC9BGw",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gdXynjcC9BGw",
        "outputId": "b2d8c1d1-dadd-4b85-ae8d-e854c1881a5b"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 29
        }
      ],
      "source": [
        "'''\n",
        "# --- NEW: get model answers ---\n",
        "answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "# --- NEW: fetch ground-truths for the originals ---\n",
        "gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "# ---- Print nicely ----\n",
        "print(\"=== Test inputs ===\")\n",
        "for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "    print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "    print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "print(\"\\n=== Paraphrases + Answers ===\")\n",
        "for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "    print(f\"[Q{i}]\")\n",
        "    print(f\"  Original answer: {a_q}\")\n",
        "    print(f\"  Ground truth   : {gt}\")\n",
        "    if not paras:\n",
        "        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "    else:\n",
        "        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "            print(f\"  {j}. {p}\")\n",
        "            print(f\"     ↳ Answer: {ap}\")\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SfSfLvyWQ9aB",
      "metadata": {
        "id": "SfSfLvyWQ9aB"
      },
      "outputs": [],
      "source": [
        "FLAG_DEBUG = True"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ywbhUshonDfJ",
      "metadata": {
        "id": "ywbhUshonDfJ"
      },
      "source": [
        "## Evaluations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B5VVKuJrDVKo",
      "metadata": {
        "id": "B5VVKuJrDVKo"
      },
      "outputs": [],
      "source": [
        "import re, math\n",
        "import torch\n",
        "from typing import Dict, Any\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "def _wilson_ci(k: int, n: int, level: float = 0.95):\n",
        "    \"\"\"\n",
        "    Wilson score interval for a binomial proportion.\n",
        "\n",
        "    Args:\n",
        "        k: Number of successes.\n",
        "        n: Number of trials.\n",
        "        level: Two-sided confidence level, strictly between 0 and 1.\n",
        "\n",
        "    Returns:\n",
        "        (lower, upper), clipped to [0, 1]; (0.0, 0.0) when n == 0.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If level is not strictly between 0 and 1.\n",
        "    \"\"\"\n",
        "    from statistics import NormalDist  # stdlib; imported here so the helper is self-contained\n",
        "\n",
        "    if not 0.0 < level < 1.0:\n",
        "        raise ValueError(f\"level must be in (0, 1), got {level!r}\")\n",
        "    if n == 0:\n",
        "        return (0.0, 0.0)\n",
        "    # Two-sided critical value of the standard normal; about 1.96 for level=0.95.\n",
        "    z = NormalDist().inv_cdf(0.5 + level / 2.0)\n",
        "    phat = k / n\n",
        "    denom = 1 + (z*z)/n\n",
        "    center = (phat + (z*z)/(2*n)) / denom\n",
        "    half = (z / denom) * math.sqrt((phat*(1-phat) + (z*z)/(4*n)) / n)\n",
        "    return max(0.0, center - half), min(1.0, center + half)\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_gsm8k(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"eval\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of sample Q/pred/gold to print total\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Evaluate GSM8K with a tqdm progress bar.\n",
        "\n",
        "    Prints:\n",
        "      - tqdm progress bar over batches\n",
        "      - optional few sample outputs (total across run)\n",
        "      - final ACC + 95% CI\n",
        "\n",
        "    Returns result dict with preds + sample + ACC/CI.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    N_total = len(dataset)\n",
        "    N = min(int(limit), N_total) if limit is not None else N_total\n",
        "\n",
        "    sampled_indices = random.sample(range(N_total), k=N)\n",
        "    qs, gold_list, preds = [], [], []\n",
        "\n",
        "    # Collect questions + golds\n",
        "    for i in sampled_indices:\n",
        "        ex = dataset[i]\n",
        "        qs.append(ex[\"question\"])\n",
        "        gold_list.append(_extract_gsm8k_final(ex.get(\"answer\", \"\")))\n",
        "\n",
        "    num_batches = (N + batch_size - 1) // batch_size\n",
        "    pbar = tqdm(total=num_batches, desc=f\"Evaluating {name}\", dynamic_ncols=True)\n",
        "\n",
        "    printed = 0\n",
        "    for b, start in enumerate(range(0, N, batch_size), start=1):\n",
        "        end = min(start + batch_size, N)\n",
        "        batch_qs = qs[start:end]\n",
        "\n",
        "        batch_preds = answer_batch_chat(\n",
        "            batch_qs,\n",
        "            model=model,\n",
        "            tokenizer=tokenizer,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            use_greedy=use_greedy,\n",
        "        )\n",
        "        preds.extend(batch_preds)\n",
        "\n",
        "        # Update running accuracy for the bar\n",
        "        # (Compute correctness only for this batch)\n",
        "        batch_correct = 0\n",
        "        for j, p in enumerate(batch_preds):\n",
        "            g = gold_list[start + j]\n",
        "            p = (p or \"\").strip()\n",
        "            g = (g or \"\").strip()\n",
        "            if p == g and p != \"NA\" and g != \"NA\":\n",
        "                batch_correct += 1\n",
        "\n",
        "        # Running totals (avoid O(N^2) by tracking separately)\n",
        "        # We'll compute final ACC at the end accurately anyway.\n",
        "        # Use displayed running accuracy as a convenience.\n",
        "        # Derive running correct by re-counting cheaply (still fine at GSM8K scale),\n",
        "        # but better track it incrementally:\n",
        "        if b == 1:\n",
        "            running_correct = batch_correct\n",
        "        else:\n",
        "            running_correct += batch_correct\n",
        "        running_n = len(preds)\n",
        "        running_acc = running_correct / running_n if running_n else 0.0\n",
        "\n",
        "        pbar.set_postfix({\n",
        "            \"N\": running_n,\n",
        "            \"ACC%\": f\"{running_acc*100:.2f}\",\n",
        "            \"greedy\": use_greedy,\n",
        "        })\n",
        "        pbar.update(1)\n",
        "\n",
        "        # Optional: print a few samples total across run\n",
        "        if show_samples > 0 and printed < show_samples:\n",
        "            k = min(show_samples - printed, len(batch_qs))\n",
        "            for j in range(k):\n",
        "                idx = start + j\n",
        "                q_preview = batch_qs[j].replace(\"\\n\", \" \")\n",
        "                if len(q_preview) > 160:\n",
        "                    q_preview = q_preview[:160] + \"...\"\n",
        "                pbar.write(f\"[{name}] idx={idx} Q: {q_preview}\")\n",
        "                pbar.write(f\"[{name}]        pred: {batch_preds[j]}\")\n",
        "                pbar.write(f\"[{name}]        gold: {gold_list[idx]}\")\n",
        "                printed += 1\n",
        "\n",
        "    pbar.close()\n",
        "\n",
        "    # Final exact-match accuracy\n",
        "    correct = 0\n",
        "    pred_na = 0\n",
        "\n",
        "    for p, g in zip(preds, gold_list):\n",
        "        p = (p or \"\").strip()\n",
        "        g = (g or \"\").strip()\n",
        "\n",
        "        if p == \"NA\":\n",
        "            pred_na += 1\n",
        "\n",
        "        if p == g and p != \"NA\" and g != \"NA\":\n",
        "            correct += 1\n",
        "\n",
        "    n = len(preds)\n",
        "    acc_mean = correct / n if n else 0.0\n",
        "    acc_lo, acc_hi = _wilson_ci(correct, n, level=0.95)\n",
        "\n",
        "    pred_na_frac = (pred_na / n) if n else 0.0\n",
        "\n",
        "    print(\n",
        "        f\"[{name}] N={n}  \"\n",
        "        f\"ACC: {acc_mean*100:.2f} and ACC confidence \"\n",
        "        f\"[{acc_lo*100:.2f}, {acc_hi*100:.2f}]  |  \"\n",
        "        f\"pred NA: {pred_na}/{n} ({pred_na_frac*100:.2f}%)  \"\n",
        "    )\n",
        "\n",
        "    model.train(was_training)\n",
        "    return {\n",
        "        \"preds\": preds,\n",
        "        \"sample\": {\n",
        "            \"indices\": sampled_indices,\n",
        "            \"questions\": qs,\n",
        "            \"golds\": gold_list,\n",
        "        },\n",
        "        \"N\": len(preds),\n",
        "        \"ACC\": acc_mean,\n",
        "        \"ACC_CI\": (acc_lo, acc_hi),\n",
        "        \"ACC_CI_Level\": 0.95,\n",
        "    }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "h9Juy1GI6_9U",
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/",
          "height": 85,
          "referenced_widgets": [
            "dcdbcdeb464c4196857ff0f15736d8a3",
            "55a6d733b99c453eae69dee18d012599",
            "882ce08f63c24d5b9c537defb111268e",
            "ee29db4bfc0d42b2967ca9e672b9d3b7",
            "8136b4093a9b4d4eaf23bfb192fc9c90",
            "8221da52e41e4c5e8bc1f58179bc64a5",
            "c699c242e1ba4895850f422391292ed2",
            "32b63310775b4f598fd9aae36d7177d8",
            "e8b12a9c33ac4411b00e002e4a1c37ad",
            "57c61c01579a49659201a22d55ed73fa",
            "8a3672d2b9bb41af88d8cfc7a7333b7a"
          ]
        },
        "id": "h9Juy1GI6_9U",
        "outputId": "8120bc23-af37-40e3-dc56-25d0ff6497fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluation limit: 1319\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dcdbcdeb464c4196857ff0f15736d8a3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating llama greedy baseline:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[llama greedy baseline] N=1319  ACC: 66.19 and ACC confidence [63.59, 68.69]  |  pred NA: 68/1319 (5.16%)  \n"
          ]
        }
      ],
      "source": [
        "\n",
        "# Exact-match accuracy of `trainable_model` on the whole test split, decoded greedily.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "# NOTE(review): the run label says \"llama\" although the notebook title refers to Qwen,\n",
        "# and the model evaluated is `trainable_model` - confirm the label is intended.\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"llama greedy baseline\",\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ypXVQeE_zA1V",
      "metadata": {
        "id": "ypXVQeE_zA1V"
      },
      "outputs": [],
      "source": [
        "# Source - https://stackoverflow.com/a\n",
        "# Posted by Ayo Reis\n",
        "# Retrieved 2025-12-30, License - CC BY-SA 4.0\n",
        "\n",
        "# Release the Colab runtime once the cells above have finished.\n",
        "# NOTE(review): presumably this ends the session, so the cells below need a fresh\n",
        "# runtime - confirm this cell is meant to sit in the middle of the notebook.\n",
        "from google.colab import runtime\n",
        "runtime.unassign()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "BbHwE2vv1Xu5",
      "metadata": {
        "id": "BbHwE2vv1Xu5"
      },
      "source": [
        "### Best of N"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "C-BdL7ax1WXq",
      "metadata": {
        "id": "C-BdL7ax1WXq"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "from typing import Tuple\n",
        "\n",
        "def _mean_logprob_of_generated(\n",
        "    model, input_ids: torch.Tensor, attention_mask: torch.Tensor, input_len: int\n",
        ") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Compute the mean log-prob over generated tokens for each sequence in the batch.\n",
        "\n",
        "    Args:\n",
        "      model: callable taking `input_ids` / `attention_mask` keyword arguments and\n",
        "             returning an object with a `.logits` tensor of shape [B, T, V].\n",
        "      input_ids: [B, T] full prompt+generation token ids.\n",
        "      attention_mask: [B, T]; non-zero where a token should be counted.\n",
        "      input_len: scalar length of the prompt (assumed same across the batch).\n",
        "\n",
        "    Returns:\n",
        "      [B] mean log-probability per sequence (higher is better). A sequence with no\n",
        "      counted generated token yields 0.0 because the divisor is clamped to 1.\n",
        "    \"\"\"\n",
        "    # Forward once to score tokens\n",
        "    out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
        "    logprobs = torch.log_softmax(out.logits, dim=-1)  # [B, T, V]\n",
        "\n",
        "    # We want log p(y_t | x, y_<t) for generated tokens.\n",
        "    # Token at position t (0-index) is predicted by logits at position t-1,\n",
        "    # so we gather from logprobs[:, :-1] using input_ids[:, 1:].\n",
        "    pred_logprobs = logprobs[:, :-1].gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # [B, T-1]\n",
        "\n",
        "    # pred_logprobs[:, j] scores the token at position j + 1. The first generated token\n",
        "    # sits at position input_len, i.e. at shifted index input_len - 1. Starting the mask\n",
        "    # at input_len instead would silently drop that first generated token.\n",
        "    gen_start = max(int(input_len) - 1, 0)\n",
        "\n",
        "    # Sequence lengths may differ due to early EOS; rely on attention_mask to know actual used tokens.\n",
        "    valid_mask = attention_mask[:, 1:]  # [B, T-1] tokens that exist (shifted)\n",
        "    gen_mask = torch.zeros_like(valid_mask)\n",
        "    gen_mask[:, gen_start:] = 1  # mark generated region (shifted space)\n",
        "\n",
        "    mask = (valid_mask * gen_mask).to(pred_logprobs.dtype)  # [B, T-1]\n",
        "    # Sum and average over generated tokens actually present\n",
        "    token_counts = mask.sum(dim=-1).clamp_min(1.0)  # avoid /0\n",
        "    seq_logprob_sum = (pred_logprobs * mask).sum(dim=-1)  # [B]\n",
        "    return seq_logprob_sum / token_counts  # higher is better (less negative)\n",
        "\n",
        "\n",
        "def _pick_first_line(text: str) -> str:\n",
        "    \"\"\"Return the first line of `text`, with surrounding whitespace trimmed.\"\"\"\n",
        "    head, _sep, _rest = text.strip().partition(\"\\n\")\n",
        "    return head.strip()\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def batched_generate_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    questions: List[str],\n",
        "    best_of: int = 8,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Best-of-N sampling: for each question, draw `best_of` samples, score each by mean token log-prob,\n",
        "    and return the single highest-prob answer (first line) per question.\n",
        "\n",
        "    Args:\n",
        "      model: causal LM exposing `.generate(...)` and a forward pass that returns `.logits`.\n",
        "      tokenizer: tokenizer for the prompts; its `eos_token_id` is also used as the pad id.\n",
        "      questions: raw question strings; each one is wrapped with `format_prompt`.\n",
        "      best_of: number of sampled continuations per question (must be >= 1).\n",
        "      batch_size: number of questions per `generate` call.\n",
        "      max_new_tokens, temperature, top_p: forwarded to `model.generate`.\n",
        "\n",
        "    Returns:\n",
        "      One answer string per question, in input order.\n",
        "    \"\"\"\n",
        "    assert best_of >= 1\n",
        "    model_device = next(model.parameters()).device\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    results = []\n",
        "\n",
        "    for i in range(0, len(questions), batch_size):\n",
        "        batch_q = questions[i:i+batch_size]\n",
        "        prompts = [format_prompt(q) for q in batch_q]\n",
        "        enc = tokenizer(prompts, padding=True, return_tensors=\"pt\").to(model_device)\n",
        "\n",
        "        input_len = enc[\"input_ids\"].shape[1]\n",
        "        # Sample best_of sequences per prompt in a single call\n",
        "        gen_out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            do_sample=True,\n",
        "            temperature=temperature,\n",
        "            top_p=top_p,\n",
        "            num_return_sequences=best_of,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=eos_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "        # sequences: [B*best_of, input_len + gen_len<=]\n",
        "        seqs = gen_out.sequences  # token ids\n",
        "        # Continuation part: positions holding eos/pad are excluded from scoring.\n",
        "        attn = (seqs != eos_id).long()\n",
        "        # Prompt part: reuse the tokenizer's own mask so prompt padding stays masked,\n",
        "        # as it was during generation (forcing it to 1 would attend to pad tokens).\n",
        "        # Rows are repeated per sample to match the per-prompt grouping that the\n",
        "        # `view(B, best_of, -1)` below relies on.\n",
        "        attn[:, :input_len] = enc[\"attention_mask\"].repeat_interleave(best_of, dim=0)\n",
        "\n",
        "        # Score each sample by mean log-prob over the generated continuation\n",
        "        mean_lp = _mean_logprob_of_generated(model, seqs, attn, input_len)  # [B*best_of]\n",
        "\n",
        "        # Group back per original example\n",
        "        B = len(batch_q)\n",
        "        seqs_per_q = seqs.view(B, best_of, -1)\n",
        "        mean_lp_per_q = mean_lp.view(B, best_of)\n",
        "\n",
        "        # Choose argmax per question\n",
        "        best_idx = torch.argmax(mean_lp_per_q, dim=1)  # [B]\n",
        "        # Decode just the generated part (then take first line)\n",
        "        for b in range(B):\n",
        "            chosen = seqs_per_q[b, best_idx[b].item(), :]\n",
        "            gen_only = chosen[input_len:]  # tokens after the prompt\n",
        "            text = tokenizer.decode(gen_only, skip_special_tokens=True)\n",
        "            results.append(_pick_first_line(text))\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def evaluate_model_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 1000,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    best_of: int = 8,\n",
        "    name: str = \"model-bestofN\",\n",
        ") -> Dict:\n",
        "    \"\"\"\n",
        "    Evaluate with Best-of-N sampling (highest mean token log-prob per question).\n",
        "\n",
        "    Args:\n",
        "      model, tokenizer: forwarded to `batched_generate_bestof`.\n",
        "      dataset: iterable of examples with \"question\" and \"answer\" entries.\n",
        "      limit: maximum number of examples taken from the start of `dataset`;\n",
        "             None evaluates everything, 0 evaluates nothing.\n",
        "      batch_size: questions per generation call.\n",
        "      max_new_tokens: generation budget per sample.\n",
        "      best_of: number of samples drawn per question.\n",
        "      name: label used in the printed summary line.\n",
        "\n",
        "    Returns:\n",
        "      Dict with keys \"N\", \"EM\", \"F1\" (both in percent), \"F1_CI\", \"F1_CI_Level\",\n",
        "      \"preds\" and \"golds\". With an empty evaluation set EM and F1 are reported as 0.0.\n",
        "    \"\"\"\n",
        "    # Build eval slice. The limit is checked before appending so that limit=0 takes\n",
        "    # no example and limit=None takes all of them.\n",
        "    qs, gold_list = [], []\n",
        "    for ex in dataset:\n",
        "        if limit is not None and len(qs) >= limit:\n",
        "            break\n",
        "        qs.append(ex[\"question\"])\n",
        "        gold_list.append(canonical_answers(ex[\"answer\"]))\n",
        "\n",
        "    preds = batched_generate_bestof(\n",
        "        model,\n",
        "        tokenizer,\n",
        "        qs,\n",
        "        best_of=best_of,\n",
        "        batch_size=batch_size,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "    )\n",
        "    assert len(preds) == len(gold_list)\n",
        "\n",
        "    ems, f1s = [], []\n",
        "    for p, g in zip(preds, gold_list):\n",
        "        em, f1 = em_and_f1(p, g)\n",
        "        ems.append(em)\n",
        "        f1s.append(f1)\n",
        "\n",
        "    n = len(f1s)\n",
        "    # np.mean([]) would give NaN (plus a warning); report 0.0 for an empty run instead.\n",
        "    em_mean = float(np.mean(ems)) * 100.0 if n else 0.0\n",
        "    f1_mean_raw = float(np.mean(f1s)) if n else 0.0\n",
        "    f1_mean = f1_mean_raw * 100.0\n",
        "\n",
        "    # 95% normal-approx CI for F1 (needs at least two scores for a sample std)\n",
        "    if n > 1:\n",
        "        z = 1.96\n",
        "        std = float(np.std(f1s, ddof=1))\n",
        "        se = std / np.sqrt(n)\n",
        "        f1_lo = max(0.0, (f1_mean_raw - z * se) * 100.0)\n",
        "        f1_hi = min(100.0, (f1_mean_raw + z * se) * 100.0)\n",
        "    else:\n",
        "        f1_lo = f1_mean\n",
        "        f1_hi = f1_mean\n",
        "\n",
        "    print(f\"[{name}] N={len(preds)}  Best-of={best_of}  EM: {em_mean:.2f}  F1: {f1_mean:.2f}  \"\n",
        "          f\"F1 95% CI: [{f1_lo:.2f}, {f1_hi:.2f}]\")\n",
        "\n",
        "    return {\n",
        "        \"N\": len(preds),\n",
        "        \"EM\": em_mean,\n",
        "        \"F1\": f1_mean,\n",
        "        \"F1_CI\": (f1_lo, f1_hi),\n",
        "        \"F1_CI_Level\": 0.95,\n",
        "        \"preds\": preds,\n",
        "        \"golds\": gold_list,\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GHtbvCx3rK87",
      "metadata": {
        "id": "GHtbvCx3rK87"
      },
      "outputs": [],
      "source": [
        "# Greedy baseline: evaluation settings reused by the evaluation cells below.\n",
        "# NOTE(review): USE_GREEDY is never passed to evaluate_model in those cells;\n",
        "# presumably evaluate_model reads it as a global - confirm.\n",
        "USE_GREEDY = True\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "EVAL_LIMIT = len(val_ds)  # evaluate on the whole validation split"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "el9lZDDl1hSH",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "el9lZDDl1hSH"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Score `baseline_model` on the validation split (run label: baseline_greedy).\n",
        "res_base_greedy = evaluate_model(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_greedy\")\n",
        "\n",
        "# Show the first three questions with up to five gold answers and the prediction.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline Greedy:\", res_base_greedy[\"preds\"][i])\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "RQTjmepQ-jiM",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "RQTjmepQ-jiM"
      },
      "outputs": [],
      "source": [
        "# Best-of-N on the baseline model with N=24 samples per question.\n",
        "BATCH_SIZE = 2  # questions per generate call; each one expands to best_of sequences\n",
        "res_base_bestof = evaluate_model_bestof(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, best_of=24, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_bestof24\")\n",
        "\n",
        "# Show the first three questions with up to five gold answers and the prediction.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline Best-of-N:\", res_base_bestof[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rWY0OHvH1xLP",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "rWY0OHvH1xLP"
      },
      "outputs": [],
      "source": [
        "# Greedy evaluation of the trainable model on the validation split.\n",
        "# Smaller quick-check variant (limit=500, batch_size=8), currently disabled:\n",
        "# _ = evaluate_model(trainable_model, tokenizer, val_ds, limit=500, batch_size=8, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_greedy\")\n",
        "# trainable_model.eval()  # keep in eval for decoding\n",
        "\n",
        "# Choose evaluation budget (increase if you have time/compute).\n",
        "# BATCH_SIZE is set again here because the best-of-N cell above lowers it to 2.\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "trainable_model.eval()\n",
        "res_trained_greedy = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_greedy\")\n",
        "# Show the first three questions with up to five gold answers and the prediction.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained Greedy:\", res_trained_greedy[\"preds\"][i])\n",
        "\n",
        "\n",
        "# Best-of-N (e.g., N=8)\n",
        "# _ = evaluate_model_bestof(trainable_model, tokenizer, val_ds, limit=500, batch_size=4, best_of=8, name=\"trainable_bestof8\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VnX-WgxQ-BXZ",
      "metadata": {
        "id": "VnX-WgxQ-BXZ"
      },
      "outputs": [],
      "source": [
        "# Best-of-N on the trainable model with N=24 samples per question.\n",
        "BATCH_SIZE = 2  # questions per generate call; each one expands to best_of sequences\n",
        "res_trained_bestof = evaluate_model_bestof(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, best_of=24, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_bestof24\")\n",
        "# Show the first five questions with up to five gold answers and the prediction.\n",
        "for i in range(5):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained Best-of-N:\", res_trained_bestof[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Pb94eMijnT2q",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "Pb94eMijnT2q"
      },
      "outputs": [],
      "source": [
        "# Sampling-based evaluation of the trainable model (run label: trainable_sampling).\n",
        "# NOTE(review): USE_GREEDY is not passed to evaluate_model; presumably it is read\n",
        "# as a global there - confirm that setting it to False really enables sampling.\n",
        "USE_GREEDY = False\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "\n",
        "trainable_model.eval()\n",
        "# Stored under its own name so the greedy results kept in `res_trained_greedy`\n",
        "# are not overwritten by the sampling run.\n",
        "res_trained_sampling = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained sampling:\", res_trained_sampling[\"preds\"][i])\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Kxghp_lE_1sp",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "Kxghp_lE_1sp"
      },
      "outputs": [],
      "source": [
        "# Sampling-based evaluation of the baseline model (run label: baseline_sampling).\n",
        "# Stored under its own name: these are baseline results, so they must not replace\n",
        "# the trained-model results kept in `res_trained_greedy`.\n",
        "res_base_sampling = evaluate_model(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline sampling:\", res_base_sampling[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "uO0cn0IUP6UH",
      "metadata": {
        "id": "uO0cn0IUP6UH"
      },
      "source": [
        "## Check Coherence"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5DGKZyvoRTeJ",
      "metadata": {
        "id": "5DGKZyvoRTeJ"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import torch\n",
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "\n",
        "def _class_agrees(preds_for_class):\n",
        "    \"\"\"\n",
        "    Decide whether the predictions for one paraphrase class count as agreeing.\n",
        "\n",
        "    Each prediction is stripped (None is treated as \"\") and \"NA\" entries are set aside.\n",
        "    - No non-NA prediction at all (including an empty class): agreement (consistent failure).\n",
        "    - Exactly one non-NA prediction: no agreement.\n",
        "    - Otherwise: agreement iff all non-NA predictions are identical.\n",
        "    \"\"\"\n",
        "    answered = []\n",
        "    for pred in preds_for_class:\n",
        "        cleaned = (pred or \"\").strip()\n",
        "        if cleaned != \"NA\":\n",
        "            answered.append(cleaned)\n",
        "\n",
        "    if not answered:\n",
        "        return True\n",
        "    return len(answered) >= 2 and len(set(answered)) == 1\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of class samples to print total\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Evaluate agreement across paraphrases, one class at a time.\n",
        "\n",
        "    Args:\n",
        "      model: model handed to answer_batch_chat; put into eval mode for the run and\n",
        "             restored to its previous train/eval mode afterwards (also on error)\n",
        "      tokenizer: tokenizer handed to answer_batch_chat\n",
        "      dataset: list-like of paraphrase classes, each is List[str]\n",
        "               e.g., classes_usable where each element is [orig, para1, para2, ...]\n",
        "      limit: number of classes to evaluate (the first `limit` classes; None = all)\n",
        "      batch_size: generation batch size (over the paraphrase questions of one class)\n",
        "      max_new_tokens: passed to answer_batch_chat\n",
        "      name: label used in the progress bar and in the printed summary\n",
        "      use_greedy: greedy vs sampling\n",
        "      show_samples: number of classes whose questions and predictions are printed\n",
        "\n",
        "    Prints:\n",
        "      [{name}] N={num_classes}  AGREE: {mean:.2f} and AGREE confidence [lo, hi]\n",
        "\n",
        "    Returns:\n",
        "      {\n",
        "        \"preds\": per_class_preds,   # List[List[str]] answers for each class item\n",
        "        \"sample\": {\"indices\": sampled_indices, \"questions\": qs, \"golds\": None},\n",
        "        \"N\": num_classes,\n",
        "        \"AGREE\": agree_mean,\n",
        "        \"AGREE_CI\": (lo, hi),\n",
        "        \"AGREE_CI_Level\": 0.95,\n",
        "      }\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    pbar = None  # created inside the try so the finally block can always release it\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []  # store the paraphrases per class for debugging\n",
        "\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for idx in sampled_indices:\n",
        "            cls = dataset[idx]  # List[str] paraphrases (includes original)\n",
        "            qs_sample.append(cls)\n",
        "\n",
        "            # Generate in chunks if class is large (so we respect batch_size)\n",
        "            cls_preds: List[str] = []\n",
        "            for start in range(0, len(cls), batch_size):\n",
        "                sub_qs = cls[start:start + batch_size]\n",
        "                sub_preds = answer_batch_chat(\n",
        "                    sub_qs,\n",
        "                    model=model,\n",
        "                    tokenizer=tokenizer,\n",
        "                    max_new_tokens=max_new_tokens,\n",
        "                    use_greedy=use_greedy,\n",
        "                )\n",
        "                cls_preds.extend(sub_preds)\n",
        "\n",
        "            per_class_preds.append(cls_preds)\n",
        "\n",
        "            agrees = _class_agrees(cls_preds)\n",
        "            if agrees:\n",
        "                agree_yes += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(1)\n",
        "\n",
        "            # Optional: print a few class samples total\n",
        "            if show_samples > 0 and printed < show_samples:\n",
        "                pbar.write(f\"[{name}] class_idx={idx} size={len(cls)} agrees={agrees}\")\n",
        "                # show first 2 questions and all preds\n",
        "                for j in range(min(2, len(cls))):\n",
        "                    q_preview = cls[j].replace(\"\\n\", \" \")\n",
        "                    if len(q_preview) > 160:\n",
        "                        q_preview = q_preview[:160] + \"...\"\n",
        "                    pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                pbar.write(f\"  preds: {cls_preds}\")\n",
        "                printed += 1\n",
        "\n",
        "        # Close before the summary so the bar is finalized ahead of the printed line.\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        # Runs on success and on error: release the progress bar (a second close()\n",
        "        # is harmless) and put the model back into its previous train/eval mode.\n",
        "        if pbar is not None:\n",
        "            pbar.close()\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sa-QOGDcEJhd",
      "metadata": {
        "id": "sa-QOGDcEJhd"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "import torch\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,          # number of CLASSES per batch (not prompts)\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Batched agreement eval when each class has ~2 items (orig + 1 paraphrase).\n",
        "\n",
        "    batch_size here means \"number of classes per batch\".\n",
        "    Each batch does 1 generation call over ~2*batch_size prompts.\n",
        "\n",
        "    Args:\n",
        "      model: model handed to answer_batch_chat; put into eval mode for the run and\n",
        "             restored to its previous train/eval mode afterwards (also on error)\n",
        "      tokenizer: tokenizer handed to answer_batch_chat\n",
        "      dataset: list-like of paraphrase classes, each a list of question strings\n",
        "      limit: number of classes to evaluate (the first `limit` classes; None = all)\n",
        "      max_new_tokens: passed to answer_batch_chat\n",
        "      name: label used in the progress bar and in the printed summary\n",
        "      use_greedy: greedy vs sampling\n",
        "      show_samples: number of classes whose questions and predictions are printed\n",
        "\n",
        "    Returns:\n",
        "      Dict with keys \"preds\" (per-class predictions), \"sample\", \"N\", \"AGREE\",\n",
        "      \"AGREE_CI\" and \"AGREE_CI_Level\".\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if answer_batch_chat returns a different number of predictions\n",
        "                  than prompts, since the per-class split would then be misaligned.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    pbar = None  # created inside the try so the finally block can always release it\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []\n",
        "\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for start in range(0, n_classes, batch_size):\n",
        "            end = min(start + batch_size, n_classes)\n",
        "            cls_batch = [dataset[i] for i in range(start, end)]  # list of classes\n",
        "\n",
        "            # Record questions for the sample output\n",
        "            qs_sample.extend(cls_batch)\n",
        "\n",
        "            # Flatten prompts: [c0_q0, c0_q1, c1_q0, c1_q1, ...]\n",
        "            flat_prompts: List[str] = []\n",
        "            lengths: List[int] = []\n",
        "            for cls in cls_batch:\n",
        "                # cls is [orig, para] (or sometimes only [orig])\n",
        "                lengths.append(len(cls))\n",
        "                flat_prompts.extend(cls)\n",
        "\n",
        "            # One big batched generation call\n",
        "            flat_preds = answer_batch_chat(\n",
        "                flat_prompts,\n",
        "                model=model,\n",
        "                tokenizer=tokenizer,\n",
        "                max_new_tokens=max_new_tokens,\n",
        "                use_greedy=use_greedy,\n",
        "            )\n",
        "\n",
        "            # The slicing below assumes one prediction per prompt; a mismatch would\n",
        "            # silently shift predictions into the wrong class, so fail loudly instead.\n",
        "            if len(flat_preds) != len(flat_prompts):\n",
        "                raise ValueError(\n",
        "                    f\"answer_batch_chat returned {len(flat_preds)} predictions \"\n",
        "                    f\"for {len(flat_prompts)} prompts\"\n",
        "                )\n",
        "\n",
        "            # Unflatten back into per-class preds\n",
        "            offset = 0\n",
        "            for bi, L in enumerate(lengths):\n",
        "                cls_preds = flat_preds[offset:offset + L]\n",
        "                offset += L\n",
        "\n",
        "                per_class_preds.append(cls_preds)\n",
        "\n",
        "                agrees = _class_agrees(cls_preds)\n",
        "                if agrees:\n",
        "                    agree_yes += 1\n",
        "\n",
        "                # Optional sample prints (total across run)\n",
        "                if show_samples > 0 and printed < show_samples:\n",
        "                    class_idx = start + bi\n",
        "                    pbar.write(f\"[{name}] class_idx={class_idx} size={L} agrees={agrees}\")\n",
        "                    for j in range(min(2, L)):\n",
        "                        q_preview = cls_batch[bi][j].replace(\"\\n\", \" \")\n",
        "                        if len(q_preview) > 160:\n",
        "                            q_preview = q_preview[:160] + \"...\"\n",
        "                        pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                    pbar.write(f\"  preds: {cls_preds}\")\n",
        "                    printed += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(end - start)\n",
        "\n",
        "        # Close before the summary so the bar is finalized ahead of the printed line.\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        # Runs on success and on error: release the progress bar (a second close()\n",
        "        # is harmless) and put the model back into its previous train/eval mode.\n",
        "        if pbar is not None:\n",
        "            pbar.close()\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "pnyq4rJrRjAe",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 243,
          "referenced_widgets": [
            "b6c0617396d44837ac2bc00893b812d1",
            "39b9ffa02eb84f85a76156cbc1c8a88b",
            "7f8b51a3fe4343f39fa6ddc435410a3a",
            "3a3884b6b2c045c2ae33bba4c0c462b4",
            "51096091a9594b3cbba9e1f06c439f65",
            "346936dae8d0460b88f574168fea2209",
            "a7531cf6b513470da02333b1420e5c9c",
            "da8588237dec4ae9877dc1ebb1236e48",
            "59f2130e9c564172b17f10efa43f72d6",
            "3345258ab7fd4dc085c81ca3eacb35dd",
            "441473c91daf4aeea3a91dc460512b2a"
          ]
        },
        "id": "pnyq4rJrRjAe",
        "outputId": "bb932ce5-2731-471c-d6c8-3f2fd568ee2e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b6c0617396d44837ac2bc00893b812d1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Paraphrase agreement qwen3b:   0%|          | 0/7473 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[qwen3b] class_idx=0 size=2 agrees=True\n",
            "  q0: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n",
            "  q1: In April, Natalia sold clips to 48 of her friends, and in May she sold half as many. How many clips did Natalia sell in total between April and May?\n",
            "  preds: ['72', '72']\n",
            "[qwen3b] class_idx=1 size=2 agrees=True\n",
            "  q0: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\n",
            "  q1: Weng earns $12 per hour from babysitting. She worked for 50 minutes yesterday. How much did she earn?\n",
            "  preds: ['10', '10']\n",
            "[qwen3b] N=7473  AGREE: 67.14 and AGREE confidence [66.06, 68.19]\n"
          ]
        }
      ],
      "source": [
        "coherence_eval_limit = len(paraphrase_classes)\n",
        "\n",
        "# classes_usable is output of build_classes_full_batched(...)\n",
        "agree_res = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=paraphrase_classes,\n",
        "    limit=coherence_eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=256,\n",
        "    name=\"qwen3b\",\n",
        "    use_greedy=True,\n",
        "    show_samples=2,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "XUChbdyqVsLQ",
      "metadata": {
        "id": "XUChbdyqVsLQ"
      },
      "source": [
        "## GRPO Implementation"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3HnZFHkuHZVx",
      "metadata": {
        "id": "3HnZFHkuHZVx"
      },
      "source": [
        "### GRPO Main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dx7osq9fVvm1",
      "metadata": {
        "id": "dx7osq9fVvm1"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "from torch.utils.data import IterableDataset\n",
        "\n",
        "class ParaphrasePairIterableFormatted(IterableDataset):\n",
        "    \"\"\"\n",
        "    Infinite stream. Each \"pair\" yields exactly two consecutive items:\n",
        "      (q0, is_original=1) then (q1, is_original=0)\n",
        "    \"\"\"\n",
        "    def __init__(self, classes_usable, tokenizer, format_prompt_instruct, system_prompt, seed=0):\n",
        "        self.classes = classes_usable          # list[list[str]]; cls[0] is original\n",
        "        self.tok = tokenizer\n",
        "        self.format_prompt_instruct = format_prompt_instruct\n",
        "        self.system_prompt = system_prompt\n",
        "        self.rng = random.Random(seed)\n",
        "\n",
        "    def __iter__(self):\n",
        "        while True:\n",
        "            cls = self.rng.choice(self.classes)\n",
        "            q0 = cls[0]\n",
        "            q1 = self.rng.choice(cls[1:]) if len(cls) > 1 else q0\n",
        "            pair_id = self.rng.randrange(2**63)\n",
        "\n",
        "            p0 = self.format_prompt_instruct(q0, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "            p1 = self.format_prompt_instruct(q1, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "\n",
        "            yield {\"prompt\": p0, \"is_original\": 1, \"pair_id\": pair_id}\n",
        "            yield {\"prompt\": p1, \"is_original\": 0, \"pair_id\": pair_id}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1snZI84RW7ml",
      "metadata": {
        "id": "1snZI84RW7ml"
      },
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "\n",
        "def agreement_reward_fn(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def agreement_reward_(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def pair_agreement_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    # pair_id and is_original should be aligned with (prompts, completions)\n",
        "    # TRL typically repeats prompt-level columns across generations.\n",
        "\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        a0 = [answers[i] for i in idx0 if answers[i] != \"NA\"]\n",
        "        a1 = [answers[i] for i in idx1 if answers[i] != \"NA\"]\n",
        "\n",
        "        # Only reward if BOTH sides have at least one non-NA answer\n",
        "        if len(a0) == 0 or len(a1) == 0:\n",
        "            continue\n",
        "\n",
        "        c0, c1 = Counter(a0), Counter(a1)\n",
        "        d0, d1 = len(a0), len(a1)\n",
        "\n",
        "        # Reward each completion by \"how often the other paraphrase produced the same answer\"\n",
        "        for i in idx0:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c1.get(a, 0) / d1\n",
        "        for i in idx1:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c0.get(a, 0) / d0\n",
        "\n",
        "    return rewards"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from collections import Counter\n",
        "\n",
        "def pair_majority_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    \"\"\"\n",
        "    Reward = 1.0 iff the completion's extracted answer is in the (tied) majority vote\n",
        "    within its side (original vs paraphrase) for the same pair_id. Otherwise 0.0.\n",
        "\n",
        "    Separation between pairs: majority is computed independently per pair_id,\n",
        "    and also independently per side (is_original==1 vs ==0) within each pair.\n",
        "    \"\"\"\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # Extract answers (same as before)\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "\n",
        "    # Group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        # Split indices by side\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        # Majority vote within each side (including \"NA\", matching your current behavior)\n",
        "        for side_idxs in (idx0, idx1):\n",
        "            if not side_idxs:\n",
        "                continue\n",
        "\n",
        "            side_answers = [answers[i] for i in side_idxs]\n",
        "            counts = Counter(side_answers)\n",
        "            max_count = max(counts.values())\n",
        "            majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "            for i in side_idxs:\n",
        "                rewards[i] = 1.0 if answers[i] in majority_set else 0.0\n",
        "\n",
        "            # ensure NA stays 0\n",
        "            for i in side_idxs:\n",
        "                if answers[i] == \"NA\":\n",
        "                    rewards[i] = 0.0\n",
        "\n",
        "    return rewards\n"
      ],
      "metadata": {
        "id": "78-viI_4KVov"
      },
      "id": "78-viI_4KVov",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "xkJgG7HfW9Vn",
      "metadata": {
        "id": "xkJgG7HfW9Vn"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "from collections.abc import Mapping\n",
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "\n",
        "class OrigOnlyKLGRPOTrainer(GRPOTrainer):\n",
        "    def _prepare_inputs(self, inputs):\n",
        "        # inputs is list[dict]\n",
        "        is_orig_list = [ex[\"is_original\"] for ex in inputs]\n",
        "        pair_id_list = [ex[\"pair_id\"] for ex in inputs]\n",
        "\n",
        "        prepared = super()._prepare_inputs(inputs)\n",
        "\n",
        "        device = prepared[\"prompt_ids\"].device\n",
        "        B_prompts = prepared[\"prompt_ids\"].shape[0]   # prompt-level count (usually 2*PAIRS_PER_STEP)\n",
        "\n",
        "        def compress_expanded_to_prompt_level(t, B_prompts):\n",
        "            # t has length B_prompts*K (expanded); compress to length B_prompts\n",
        "            K = t.numel() // B_prompts\n",
        "            first_block = t[:B_prompts]\n",
        "\n",
        "            # If the first block already contains mixed values, it's interleaved by prompt:\n",
        "            # [q0,q1,q0,q1,...] per generation => first_block is exactly prompt-level.\n",
        "            if torch.unique(first_block).numel() > 1:\n",
        "                return first_block\n",
        "\n",
        "            # Otherwise grouped layout: [q0...q0, q1...q1]\n",
        "            idx = torch.arange(0, t.numel(), K, device=t.device)  # 0, K, 2K, ...\n",
        "            return t[idx]\n",
        "\n",
        "        def align_to_B(vec):\n",
        "            \"\"\"\n",
        "            Align vec to length B_prompts. vec may be:\n",
        "              - prompt-level (len == B_prompts)\n",
        "              - expanded (len == B_prompts*K)\n",
        "            \"\"\"\n",
        "            t = torch.as_tensor(vec, device=device).view(-1).long()\n",
        "\n",
        "            if t.numel() == B_prompts:\n",
        "                return t\n",
        "\n",
        "            if t.numel() > B_prompts:\n",
        "                # ✅ THIS is where \"already-expanded → compress\" goes\n",
        "                if t.numel() % B_prompts != 0:\n",
        "                    raise ValueError(f\"Cannot compress metadata: got {t.numel()}, need multiple of {B_prompts}\")\n",
        "                return compress_expanded_to_prompt_level(t, B_prompts)\n",
        "\n",
        "            # (rare) scalar -> broadcast\n",
        "            if t.numel() == 1:\n",
        "                return t.repeat(B_prompts)\n",
        "\n",
        "            raise ValueError(f\"Cannot align metadata: got {t.numel()}, need {B_prompts}\")\n",
        "\n",
        "        prepared[\"is_original_prompt\"] = align_to_B(is_orig_list)\n",
        "        prepared[\"pair_id_prompt\"] = align_to_B(pair_id_list)\n",
        "\n",
        "        # keep prompt-level values here; expand later in compute_loss to match rollouts\n",
        "        return prepared\n",
        "\n",
        "    def _get_logps(self, model, input_ids, attention_mask, logits_to_keep):\n",
        "        # TRL versions differ: some have _get_per_token_logps, newer use _get_per_token_logps_and_entropies\n",
        "        if hasattr(self, \"_get_per_token_logps\"):\n",
        "            return self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)\n",
        "\n",
        "        out = self._get_per_token_logps_and_entropies(\n",
        "            model,\n",
        "            input_ids,\n",
        "            attention_mask,\n",
        "            logits_to_keep,\n",
        "            compute_entropy=False,\n",
        "            pixel_values=None,\n",
        "            image_grid_thw=None,\n",
        "            num_images=None,\n",
        "        )\n",
        "        # out can be (logps, entropies) or {\"logps\":..., \"entropies\":...}\n",
        "        if isinstance(out, tuple):\n",
        "            return out[0]\n",
        "        return out[\"logps\"]\n",
        "\n",
        "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
        "        import torch\n",
        "\n",
        "        if return_outputs:\n",
        "            raise ValueError(\"GRPOTrainer does not support return_outputs=True\")\n",
        "\n",
        "        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n",
        "        completion_ids, completion_mask = inputs[\"completion_ids\"], inputs[\"completion_mask\"]\n",
        "\n",
        "        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n",
        "        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n",
        "\n",
        "        logits_to_keep = completion_ids.size(1)\n",
        "        per_token_logps = self._get_logps(model, input_ids, attention_mask, logits_to_keep)  # [N, T]\n",
        "\n",
        "        # ---- dimensions ----\n",
        "        N = per_token_logps.shape[0]  # rollout-level rows\n",
        "        denom = completion_mask.sum(dim=1).clamp_min(1)  # [N]\n",
        "\n",
        "        # ---- advantages: align to N ----\n",
        "        advantages = inputs[\"advantages\"].view(-1)  # could be [N] or [B]\n",
        "        if advantages.numel() != N:\n",
        "            if N % advantages.numel() == 0:\n",
        "                advantages = advantages.repeat_interleave(N // advantages.numel())\n",
        "            else:\n",
        "                raise ValueError(f\"advantages has {advantages.numel()} elems but need {N}\")\n",
        "        advantages = advantages.to(per_token_logps.device)\n",
        "\n",
        "        # ---- PPO/GRPO ratio term ----\n",
        "        if \"old_per_token_logps\" in inputs:\n",
        "            ratio = torch.exp(per_token_logps - inputs[\"old_per_token_logps\"])\n",
        "        else:\n",
        "            ratio = torch.exp(per_token_logps - per_token_logps.detach())\n",
        "\n",
        "        per_token_pg = ratio * advantages.unsqueeze(1)  # [N, T]\n",
        "\n",
        "        beta = float(getattr(self, \"beta\", getattr(self.args, \"beta\", 0.0)))\n",
        "\n",
        "        # ---- KL term (masked to originals only) ----\n",
        "        is_orig = None\n",
        "        per_token_kl = 0.0\n",
        "\n",
        "        if beta != 0.0:\n",
        "            ref_per_token_logps = inputs[\"ref_per_token_logps\"]\n",
        "            per_token_kl_raw = (\n",
        "                torch.exp(ref_per_token_logps - per_token_logps)\n",
        "                - (ref_per_token_logps - per_token_logps)\n",
        "                - 1\n",
        "            )  # [N, T]\n",
        "\n",
        "            # prompt-level mask produced in _prepare_inputs\n",
        "            is_orig_prompt = inputs.get(\"is_original_prompt\", None)\n",
        "            if is_orig_prompt is None:\n",
        "                raise ValueError(\"Missing inputs['is_original_prompt'] from _prepare_inputs().\")\n",
        "\n",
        "            is_orig_prompt = torch.as_tensor(is_orig_prompt, device=per_token_logps.device).float().view(-1)  # [B]\n",
        "            B = is_orig_prompt.numel()\n",
        "\n",
        "            if N % B != 0:\n",
        "                raise ValueError(f\"Cannot expand is_original_prompt: N={N}, B={B}\")\n",
        "            K = N // B\n",
        "\n",
        "            # Detect whether rollout rows are interleaved-by-prompt (first B rows are all different prompts)\n",
        "            def _is_interleaved_by_prompt(prompt_ids, B):\n",
        "                reps = []\n",
        "                for i in range(min(B, prompt_ids.size(0))):\n",
        "                    row = prompt_ids[i]\n",
        "                    if not any(torch.equal(row, r) for r in reps):\n",
        "                        reps.append(row)\n",
        "                return len(reps) == B\n",
        "\n",
        "            interleaved = _is_interleaved_by_prompt(prompt_ids, B)\n",
        "\n",
        "            if interleaved:\n",
        "                # [p0,p1,...,pB-1, p0,p1,...] repeated K times\n",
        "                is_orig = is_orig_prompt.repeat(K)  # [N]\n",
        "            else:\n",
        "                # [p0...p0, p1...p1, ...] each repeated K times\n",
        "                is_orig = is_orig_prompt.repeat_interleave(K)  # [N]\n",
        "\n",
        "            # Mask KL to originals only\n",
        "            per_token_kl = per_token_kl_raw * is_orig.unsqueeze(1)  # [N, T]\n",
        "\n",
        "            # ---- logging: masked KL only (as requested) ----\n",
        "            with torch.no_grad():\n",
        "                # prompt-level counts\n",
        "                is_p = is_orig_prompt.float().view(-1)\n",
        "                p_n1 = int((is_p == 1).sum().item())\n",
        "                p_n0 = int((is_p == 0).sum().item())\n",
        "\n",
        "                # rollout-level counts (after expansion)\n",
        "                r_n1 = int((is_orig == 1).sum().item())\n",
        "                r_n0 = int((is_orig == 0).sum().item())\n",
        "\n",
        "            self.log({\n",
        "                \"dbg/prompt_n1\": p_n1,\n",
        "                \"dbg/prompt_n0\": p_n0,\n",
        "                \"dbg/rollout_n1\": r_n1,\n",
        "                \"dbg/rollout_n0\": r_n0,\n",
        "            })\n",
        "            with torch.no_grad():\n",
        "                # mean KL over completion tokens for each sample\n",
        "                kl_per_sample = (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)\n",
        "\n",
        "                # raw mean KL (all samples)\n",
        "                kl_all = kl_per_sample.mean()\n",
        "\n",
        "                # orig-only mean KL (only samples with is_orig==1)\n",
        "                orig_mask = is_orig.bool()\n",
        "                kl_orig = kl_per_sample[orig_mask].mean() if orig_mask.any() else torch.tensor(0.0, device=kl_all.device)\n",
        "\n",
        "            self.log({\n",
        "                \"kl_all\": kl_all.detach().float().cpu().item(),\n",
        "                \"kl_orig_only\": kl_orig.detach().float().cpu().item(),\n",
        "            })\n",
        "\n",
        "        # ---- final loss ----\n",
        "        per_token_loss = -(per_token_pg - beta * per_token_kl)  # [N, T]\n",
        "        loss = ((per_token_loss * completion_mask).sum(dim=1) / denom).mean()\n",
        "\n",
        "        return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "i5QbPheK3EXX",
      "metadata": {
        "id": "i5QbPheK3EXX"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class GSM8KEvalCallback(TrainerCallback):\n",
        "    def __init__(self, trainer, *, eval_fn, log_every: int, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_all_raw\"):\n",
        "        self.trainer = trainer\n",
        "        self.eval_fn = eval_fn\n",
        "        self.log_every = log_every\n",
        "        self.eval_kwargs = eval_kwargs\n",
        "        self.reward_key = reward_key\n",
        "        self.kl_key = kl_key\n",
        "        self._last_reward = None\n",
        "        self._last_kl = None\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        logs = logs or {}\n",
        "        # cache latest values when they appear\n",
        "        if self.reward_key in logs:\n",
        "            self._last_reward = logs[self.reward_key]\n",
        "        if self.kl_key in logs:\n",
        "            self._last_kl = logs[self.kl_key]\n",
        "\n",
        "        step = int(state.global_step)\n",
        "        if step == 0 or (step % self.log_every) != 0:\n",
        "            return control\n",
        "\n",
        "        # main process only\n",
        "        if hasattr(self.trainer, \"is_world_process_zero\") and not self.trainer.is_world_process_zero():\n",
        "            return control\n",
        "\n",
        "        def fmt(x):\n",
        "            return \"NA\" if x is None else f\"{x:.6g}\"\n",
        "\n",
        "        print(f\"[step {step}] {self.reward_key}={fmt(self._last_reward)}  {self.kl_key}={fmt(self._last_kl)}\")\n",
        "\n",
        "        # run eval\n",
        "        model = self.trainer.model\n",
        "        was_training = model.training\n",
        "        model.eval()\n",
        "\n",
        "\n",
        "        with torch.no_grad():\n",
        "            metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        if was_training:\n",
        "            model.train()\n",
        "\n",
        "        self.trainer.log({f\"gsm8k/{k}\": v for k, v in metrics.items()})\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51XzkxBxBJNa",
      "metadata": {
        "id": "51XzkxBxBJNa"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class PrintAndEvalOnSave(TrainerCallback):\n",
        "    def __init__(self, trainer, *, eval_fn, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_orig_only\",\n",
        "                 n1_key=\"dbg/rollout_n1\", n0_key=\"dbg/rollout_n0\"):\n",
        "        self.trainer = trainer\n",
        "        self.eval_fn = eval_fn\n",
        "        self.eval_kwargs = eval_kwargs\n",
        "        self.reward_key = reward_key\n",
        "        self.kl_key = kl_key\n",
        "        self.n1_key = n1_key\n",
        "        self.n0_key = n0_key\n",
        "\n",
        "        self._last_reward = None\n",
        "        self._last_kl = None\n",
        "        self._last_n1 = None\n",
        "        self._last_n0 = None\n",
        "\n",
        "        self._last_eval_step = -1\n",
        "        self._logging_eval = False\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        # Cache train stats whenever they are logged\n",
        "        if self._logging_eval:\n",
        "            return control\n",
        "        logs = logs or {}\n",
        "\n",
        "        if self.reward_key in logs:\n",
        "            self._last_reward = logs[self.reward_key]\n",
        "        if self.kl_key in logs:\n",
        "            self._last_kl = logs[self.kl_key]\n",
        "        if self.n1_key in logs:\n",
        "            self._last_n1 = logs[self.n1_key]\n",
        "        if self.n0_key in logs:\n",
        "            self._last_n0 = logs[self.n0_key]\n",
        "\n",
        "        return control\n",
        "\n",
        "    def on_save(self, args, state, control, **kwargs):\n",
        "        step = int(state.global_step)\n",
        "        if step == 0 or step == self._last_eval_step:\n",
        "            return control\n",
        "        self._last_eval_step = step\n",
        "\n",
        "        if hasattr(self.trainer, \"is_world_process_zero\") and not self.trainer.is_world_process_zero():\n",
        "            return control\n",
        "\n",
        "        def fmt(x):\n",
        "            return \"NA\" if x is None else f\"{x:.6g}\" if isinstance(x, (float, int)) else str(x)\n",
        "\n",
        "        print(\n",
        "            f\"[save step {step}] \"\n",
        "            f\"{self.reward_key}={fmt(self._last_reward)}  \"\n",
        "            f\"{self.kl_key}={fmt(self._last_kl)}  \"\n",
        "            f\"n1={fmt(self._last_n1)}  n0={fmt(self._last_n0)}\"\n",
        "        )\n",
        "\n",
        "        # Run eval once per checkpoint\n",
        "        model = self.trainer.model\n",
        "        was_training = model.training\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        if was_training:\n",
        "            model.train()\n",
        "\n",
        "        # Log eval metrics, but avoid triggering our own caching logic\n",
        "        self._logging_eval = True\n",
        "        try:\n",
        "            self.trainer.log({f\"gsm8k/{k}\": v for k, v in metrics.items()})\n",
        "        finally:\n",
        "            self._logging_eval = False\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "P8qP661M7LAo",
      "metadata": {
        "id": "P8qP661M7LAo"
      },
      "outputs": [],
      "source": [
        "### GRPO hyperparameters (Llama and Qwen runs)\n",
        "\n",
        "from trl import GRPOConfig\n",
        "\n",
        "# --- Sampling and objective ---\n",
        "K = 32               # rollouts sampled per prompt\n",
        "PAIRS_PER_STEP = 4   # paraphrase pairs per device\n",
        "KL_WEIGHT = 0.05     # handed to GRPOConfig as `beta`\n",
        "\n",
        "# --- Schedule (all expressed as multiples of LOG_EVERY steps) ---\n",
        "LOG_EVERY = 8\n",
        "SAVE_EVERY = 20 * LOG_EVERY\n",
        "TOTAL_STEPS = 200 * LOG_EVERY\n",
        "\n",
        "# --- Evaluation / generation budget ---\n",
        "EVAL_LIMIT = len(test_ds)   # evaluate on every example of test_ds\n",
        "MAX_NEW_TOKENS = 512\n",
        "\n",
        "# Streaming dataset of paraphrase pairs (it has no __len__).\n",
        "train_dataset = ParaphrasePairIterableFormatted(\n",
        "    seed=0,\n",
        "    tokenizer=tokenizer,\n",
        "    system_prompt=SYSTEM_PROMPT,\n",
        "    format_prompt_instruct=format_prompt_instruct,\n",
        "    classes_usable=paraphrase_classes,\n",
        ")\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "\n",
        "    # Optimisation\n",
        "    learning_rate=1e-5,\n",
        "    beta=KL_WEIGHT,\n",
        "    max_steps=TOTAL_STEPS,   # explicit step budget: the IterableDataset has no __len__\n",
        "\n",
        "    # Batching: each pair contributes two prompts (q0, q1), K rollouts per prompt\n",
        "    per_device_train_batch_size=2 * PAIRS_PER_STEP,\n",
        "    num_generations=K,\n",
        "    generation_batch_size=2 * PAIRS_PER_STEP * K,   # a multiple of K by construction\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "    dataloader_num_workers=0,\n",
        "    # dispatch_batches=False: per the original note, avoids a str-concatenation crash\n",
        "    accelerator_config={\"dispatch_batches\": False, \"split_batches\": False},\n",
        "\n",
        "    # Checkpointing\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=SAVE_EVERY,\n",
        "    save_total_limit=30,     # upper bound on checkpoints kept on disk\n",
        "    save_safetensors=True,\n",
        "\n",
        "    # Logging runs at the checkpoint cadence (SAVE_EVERY, not LOG_EVERY); no external trackers\n",
        "    logging_steps=SAVE_EVERY,\n",
        "    report_to=[],\n",
        ")\n",
        "\n",
        "trainer = OrigOnlyKLGRPOTrainer(\n",
        "    model=trainable_model,\n",
        "    reward_funcs=pair_majority_reward_fn,\n",
        "    processing_class=tokenizer,\n",
        "    train_dataset=train_dataset,\n",
        "    args=grpo_args,\n",
        ")\n",
        "\n",
        "# Keyword arguments forwarded to the eval function by the save callback.\n",
        "eval_kwargs = {\n",
        "    \"tokenizer\": tokenizer,\n",
        "    \"dataset\": test_ds,\n",
        "    \"limit\": EVAL_LIMIT,\n",
        "    \"batch_size\": 64,\n",
        "    \"max_new_tokens\": MAX_NEW_TOKENS,\n",
        "    \"name\": \"trained_grpo\",\n",
        "}\n",
        "\n",
        "# Print the latest reward/KL and run an evaluation at every checkpoint save.\n",
        "trainer.add_callback(\n",
        "    PrintAndEvalOnSave(\n",
        "        trainer,\n",
        "        eval_fn=evaluate_gsm8k,\n",
        "        eval_kwargs=eval_kwargs,\n",
        "        kl_key=\"kl_orig_only\",\n",
        "        reward_key=\"reward\",\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Run label: llama majority vote.\n",
        "# Start from scratch unless CKPT_LOAD_PATH names a checkpoint to resume from.\n",
        "if not CKPT_LOAD_PATH:\n",
        "    trainer.train()\n",
        "else:\n",
        "    trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "c36ef80f6a5844c0996be20b99174001",
            "77173f4595d04b13835901a44459ca31",
            "251b3affacbb4960ab8bf9ed32846eac",
            "82b3538b27644ec1af7590511ff0e3df",
            "b2157dee918f43df82051cc5d732a454",
            "ec3be9c493074bc59a92e54ea73ff993",
            "2f296d2be6d245f0920fefd1fe1a6238",
            "83c0fddf60994a3d87fead56bdfe0ba0",
            "346ca0f9fb1d4b419cab0e2ba1613667",
            "b9dac3e0f2334d519994ae4a48ab44a3",
            "674f280a741241fe8d45f66c79854f66",
            "a83653628c70449192d8cb57ee28cbe9",
            "917281a4a2c34057b0def2278151ea93",
            "137b3f056d7f4e129e33ec050d48414f",
            "2d0c265acd9a4d71819e4185ccb0b6b9",
            "7af8dbf17624480b836be0e3cae58a6f",
            "0e35c0b8260748e3b9bf03a849bf1eea",
            "b1be8d7f613a4c7eaf179132d2561bcf",
            "0d213a47024a4e2094f008d607c8106d",
            "38faafa9b2a546ae9aeb23157c0ff021",
            "5ec946e4076c4bf8bec3da6caf68efab",
            "9b49048bc4004edcaa12dbd2a54144ec",
            "04071bf126594d6da8321b8955a1c65b",
            "fd5c75ff480b40bb8e36b5fd361de605",
            "927a6b39006048a6bb3a01ebdcccf2c5",
            "2f1650a54c8b4dac83c40e8d2673ae55",
            "895145a6853f4fce831e669ae9d2143a",
            "58d3f262baf34b2e8a93080b3a84078a",
            "8443464cc9ba4989b03885103110cf30",
            "755bcecd4dff430d93075086589764e9",
            "b402e3ccb16145ebbca9f84156ca44ab",
            "39a9791d71974f04881824b5d8329c0c",
            "5b88f6b250c74efda6b90af708b45c35",
            "41f75be0a84d4ee3b77fb035d606caaf",
            "caa8046611d84c989b19c49e9e5b3d08",
            "b4aeacbf641f4592ae18ebff7b8f975b",
            "a31fc33114a446f0be8be4f9a6471b10",
            "5de6f0fe0efe4be6b37b037215c9eef3",
            "0265a92c8dcb42d5a05f8ebc4979b477",
            "c7018c73765c44a6907811dfaea104fd",
            "91dbf3c0ccfe48258efdb90d4ae1d394",
            "61b22b9892fa40d3800bc83b170bcb00",
            "5843f563475e4fdd9443ff40b8271073",
            "62398a6c238c49c99982728f2ab8ea01",
            "53a250f5a9324300be750347202be76f",
            "0a493f8239634e0db75665d2914cae72",
            "96f8b20d52a34c209c494258d88cf3c8",
            "b1a12eddc2f44ddfb35bde51e96fc199",
            "d72d6bbd54af4ecbb9e13093fa21694d",
            "51ef546b4a9a43e280e0cf2feb1508aa",
            "665dbeb11ec3448f85852bde435c0da1",
            "273402276286463092c324b101c59601",
            "34cc57832ac24e0984ca5949813f0227",
            "e98a94ca7ea04254987dc8f055f88be6",
            "449196000f8d4871b6ba705e66d4f290",
            "7844cc9f801a46708d7dacc663e15f7a",
            "e5fa83622140415083f19a970d259dc4",
            "001146de57824784b70aed7f4a593f9f",
            "7d2d530a77a54099891746da666bb427",
            "f6f88c0d6fa049aa8250903f99ea2d3b",
            "0c5f203e7e9142eca7645a4929ebb129",
            "2a3f7af9b71241e58bfecb7f8a944e03",
            "9771ad7ebb694d7aaa0d8f1d34a38428",
            "b06bff990ed048ff98d38bbdfce3114d",
            "844b67e63dfc4992ac29ec24df6909b2",
            "ba95a62c687848b39dd84f4e643a4bbd",
            "eb61ff4eb86c4ca88a8fc6b6da801f5d",
            "d9413a542f64452082d6ec5fadf6fe30",
            "2ee552354cf2402bade39bd72e7efa2c",
            "325b042469a8410fb8fb2b0516f5d37d",
            "e9bf02ef4b124945a32c94901243412b",
            "7c12b48e75fc454f9d1b8f019e17876e",
            "25beb888633b4cd8a6fdb1f0d0d0ced1",
            "273021b6b8f34940900f6b99f555bc09",
            "c260bcf9957f42e68d07c54404d2dccc",
            "e25edd81ca9a43dca0047583ebab8ade",
            "35a222a0c74446bbbfedc30bd14394a8",
            "d4d031715a544f2e8b26c18d59d2c28d",
            "9bbd6ccdbe354198bd088bb7ee82a16f",
            "92f80130de624cb6aa572dad9b9a579d",
            "40ff3a6b33df450ebb0a1ca99af2100d",
            "02c1f6fce2c240209db6e4614e34f7de",
            "fb79b6da757a4817ab21b38ed543a5c2",
            "371f700d9add466caf67cfb2ad4795ce",
            "988e311f2a2a4f2d9548cee3c4a1cebd",
            "030fa296e29c4b5aa98e0e7b6b5c52fd",
            "56fab98371434fafb5bd5d67936045ff",
            "7df0ef87c19b46e598f837340953447a",
            "7b739b99a0f24798859e0046effb3335",
            "952bce6c047344f59ca5f392aba9fabf",
            "b37437cf9e3e43b588957a87a828848c",
            "006060c03f8f4c83904ae5b1e0ca581d",
            "d186476b39034e0c83f0c2b05eeff4a5",
            "5e3d047f381a40cdb27dbbdd374583e6",
            "208af50d0a4d4012affcd65fd686ef1e",
            "6894850ac2484010a1c0d2a6ac24eed5",
            "2c78341748fb43ccbf807d42fa7a7957",
            "6ff0bfc005c44c6c8815b36f150fa073",
            "2c3690523cd64c15835e51f610d81f1d",
            "310cb44652814df388dbcb586d454db0",
            "b76f2ab7adbd457bb44503230d7a05a1",
            "1d06caf9760d4c138d796d21929d8837",
            "e73c88572ad94a59ac5bed92cde2eb3d",
            "64673f99ade0414484cf703b8c8f42e9",
            "9a84a30201ad4e57b767aef60e9e5c07",
            "fb75d6c155a045968eb02c818f6484aa",
            "d99826be5e7243d986ed1b16471cdc7a",
            "5f762efb20ed44bb8ec9a213f26fa85e",
            "59748a4746354e1b8539c9e41236621a",
            "cc643c706e854c2fae91e7ddb3d3feb3"
          ]
        },
        "id": "iRXPcEI_Kiu9",
        "outputId": "feccbed4-c6d7-407a-ba81-285bf001444e"
      },
      "id": "iRXPcEI_Kiu9",
      "execution_count": null,
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1441' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1441/1600 5:20:42 < 35:26, 0.07 it/s, Epoch 0.90/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>-0.000200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 160] reward=0.945312  kl_orig_only=0.00133161  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c36ef80f6a5844c0996be20b99174001",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 68.39 and ACC confidence [65.83, 70.84]  |  pred NA: 19/1319 (1.44%)  \n",
            "[save step 320] reward=0.976562  kl_orig_only=0.00160498  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a83653628c70449192d8cb57ee28cbe9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.14 and ACC confidence [66.60, 71.58]  |  pred NA: 15/1319 (1.14%)  \n",
            "[save step 480] reward=0.964844  kl_orig_only=0.00179966  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "04071bf126594d6da8321b8955a1c65b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.37 and ACC confidence [66.83, 71.80]  |  pred NA: 9/1319 (0.68%)  \n",
            "[save step 640] reward=0.957031  kl_orig_only=0.00190489  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "41f75be0a84d4ee3b77fb035d606caaf",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 68.84 and ACC confidence [66.29, 71.28]  |  pred NA: 7/1319 (0.53%)  \n",
            "[save step 800] reward=0.96875  kl_orig_only=0.00529411  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "53a250f5a9324300be750347202be76f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.07 and ACC confidence [66.52, 71.50]  |  pred NA: 8/1319 (0.61%)  \n",
            "[save step 960] reward=0.988281  kl_orig_only=0.00211212  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7844cc9f801a46708d7dacc663e15f7a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.29 and ACC confidence [66.75, 71.73]  |  pred NA: 7/1319 (0.53%)  \n",
            "[save step 1120] reward=1  kl_orig_only=0.00196523  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "eb61ff4eb86c4ca88a8fc6b6da801f5d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.14 and ACC confidence [66.60, 71.58]  |  pred NA: 3/1319 (0.23%)  \n",
            "[save step 1280] reward=0.988281  kl_orig_only=0.00163574  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d4d031715a544f2e8b26c18d59d2c28d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 68.61 and ACC confidence [66.06, 71.06]  |  pred NA: 3/1319 (0.23%)  \n",
            "[save step 1440] reward=0.996094  kl_orig_only=0.00829473  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7b739b99a0f24798859e0046effb3335",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 68.76 and ACC confidence [66.21, 71.21]  |  pred NA: 5/1319 (0.38%)  \n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1600' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1600/1600 6:20:37, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>-0.000200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1440</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[save step 1600] reward=0.976562  kl_orig_only=0.005312  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "310cb44652814df388dbcb586d454db0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.37 and ACC confidence [66.83, 71.80]  |  pred NA: 3/1319 (0.23%)  \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "C_MXdDcl3Y70",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "64c565858ba245259d61e8cc7db85336",
            "b1345bf5ab304f66b1decc6b75723c49",
            "8aaa4aba009447b6bf7aacc0217320e7",
            "38576a9201d5449887c56eefc9e40782",
            "38ed94ffb06c4471ad562918944cc6d9",
            "0805508fb0924ad5b22caaf40cde4a27",
            "829bd5bc714c49b1bb4727091aafef98",
            "e081f3fa40b140928d72e293b5f183fc",
            "47cec929d24f43ba9964903f6dc60c71",
            "da9e73ebd00a4159a7df6e1269b6e1ea",
            "2609916bcc1c49aa8d18dbe31d3ff3f5",
            "1e3aa372aec74831943f083c02a7a14d",
            "7bee36f381ba410fbccc2de5e9528db5",
            "8e51a0ab5fe845469102dc0b9a10c4fa",
            "23e8803e888a4b549b7d075c7001aa39",
            "74d92c73431d4e8884319c1968105011",
            "c7b261a1890f492cab7542dcfdd9b59d",
            "8e4c530251a4417c9efd6704391e8ac4",
            "40a4643e69124288b23392eac5c14146",
            "08a4431774f24c75b8d444644ab51bbc",
            "7f0c70ff207a45af8d5e751d2ba6bf9c",
            "9175b0f83eee429d8c2b642de2a70e5e",
            "f5aa689e09ed440fb9f56c179bb265c9",
            "34e172988b824c1582cfd819e046c826",
            "e50ee262aedb465887063371888af73d",
            "83ecf75f478d4163b4137bf522cfc795",
            "54d2f6ff9c854e628af3f22e47cc9e44",
            "e17fd94fc6b7497fba6034b12ee10743",
            "7e2160fe500846438c02356d7c61305f",
            "cd1f0ef537164384a1c203978fb919cb",
            "2903ef393db64f0f94723d15854423a3",
            "712bf82474a24741b286830d8c5f09ff",
            "7583ecda0d00433e919347428773928b",
            "18ce7a5dbdb44472b3b84d81a13bdf7e",
            "4ca8064308b143e28cd73bd304b9ed83",
            "4575c0e40c0a4d9fbb9f27ce93520d72",
            "5b058c0f308e4f67b638c431716bd2ed",
            "5be74c54e1cd45ef8032e1ad2cfcadf4",
            "c700dee1151c4e419670fcb687230099",
            "254dfa7e95d1438eb3d638db4a531553",
            "c4bd00cd657d45c1b0550d0a882e4380",
            "67e18328058e446a9d6993a65ee9c419",
            "e0ea0f451a8a41b1a021b608977acbf1",
            "a4f8662c51c745f68b1ff72aa76f620e",
            "9acc7381d5b1426797cb66a9112a6911",
            "e90aaeabd9dd403c9337936ee63fdb25",
            "c484432bc7eb4735b6e573b8f3edc87d",
            "6c08e1c01a5b41c9aac3a1dcf746561c",
            "87a429738ad54504a5e9d85f3b62d274",
            "a68dd47670994fc8a73ba87df3002df6",
            "43aefa5dffd4481ca6452f9a191c7104",
            "991dd007c09a422191c56fde75f43e55",
            "6dc3745efacd4785ab7575f5a26bab4c",
            "ad350ba861c5477cb7d62b889e5b1aaa",
            "03deaf4188004bc78a45b4841b6be913",
            "ca20be70e6b34871901f6bedb7686672",
            "6efe992a340b41a493f3dc7d14b01dc0",
            "ee0e3b45bf3248dfa9f73971420ec62b",
            "6ac4140030014c768e8472e1a9a6f797",
            "ec0fb1046ff340ffa9a050e44db46d80",
            "e201cd53f07f4b0199e174765eec3188",
            "01463617e99b4ccd8c701adb2f28e4a4",
            "60e15841baab47e8a13245d0e1987861",
            "f8e1ff19275843bfb873010a36b843b1",
            "2d754b238b1343c4b42da6949a5dbff8",
            "bb133ad736554433b86d272da51db4bc",
            "c25fd1b2a7614cd8a257fa8b1ad2dfbc",
            "af578cd9cf5f44d9b8d921978d087687",
            "d18b9c5932384d059eeab9570c967a8e",
            "ae80b229d8cc47508613bac414c7b156",
            "e7eb013ea79f49f8990627838cecc7c0",
            "707c688d9b3d4ce3b143be6e0cc6186f",
            "95610e97f00747169d18a2d6e10d2c98",
            "0939ae8111eb479ba46b5d1d71dd2330",
            "e5582fe402584486885e6d1c2e64b4b9",
            "e4be724d889f48be9b9057d0631d6b62",
            "aef031448deb4717ac103c393099468e",
            "206d7699fc774334b83ad60b811b0b4e",
            "62dce4f6682b47f296485b8b3a9f1545",
            "2ad1efe489834087a1006d23af12d145",
            "5b7f08fd54d54974952cba793f4d8fba",
            "37f3d66df88d4ad9bc18290123d1bbe9",
            "024c5af233af4a86923d8042cec8d2f5",
            "761264599bdf466ea7389e3a4610925a",
            "d04a5fca29214b56b65160119dc50bd1",
            "1715c5e4363941849dbb18acfe3ec42b",
            "eff5a0f4044e496495f304c15e518b75",
            "ddd0c46883884014b665bee2518ca526",
            "d5a909a3d64a4a3ba4e48a8ce95a1464",
            "9fb1bc2a522e4638abd7bb91f216d73e",
            "f0f6e68424fe485caf3b9920c44d7db2",
            "19f850405ba1467881d8480c5f1939bc",
            "231e3eeca3504f91a75816d52b32856d",
            "48a543e859f541a79e83d7261d903d25",
            "3f5f6fce1b2e4230b4bc8fac6119836e",
            "7904d96d5c594e8487aa5f4f6a337a18",
            "ecfb15a27d594bedb493c33ab4b0b169",
            "1dfb7425bb8442a7b5003c2bb65a2574",
            "75360684e1394731aed53f617034c15f",
            "7a9b23a7f7874922be0e31e5055a00b9",
            "7ebf37faa9af4c4890cf16a7247d0e8f",
            "c39b6023a96f49789b8eb81331fee15e",
            "cc3d7bbfc91b4b84b853508b24fd0aaf",
            "00e1ce49ec70470f88b97d209b704656",
            "8b56adcad9ea446d96912ca6210b0d1b",
            "2d1fff1b0d8b4bb29a2effda2ee30723",
            "300545c787334b50b5cb1cc0fe14e92f",
            "5d74b673faf1401cbf517013ccfd4e17",
            "c9687ec0add942f095fb63d00547513c",
            "965e99c1182e4bb3810e32603d688d31"
          ]
        },
        "id": "C_MXdDcl3Y70",
        "outputId": "5aa3cadf-bfa9-4d5d-eed4-cf9fad27791f"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='801' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [ 801/1600 2:46:01 < 2:46:01, 0.08 it/s, Epoch 0.50/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>-0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 160] reward=0.539062  kl_orig_only=0.00124143  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "64c565858ba245259d61e8cc7db85336",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 69.14 and ACC confidence [66.60, 71.58]  |  pred NA: 27/1319 (2.05%)  \n",
            "[save step 320] reward=0.460938  kl_orig_only=0.0031046  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1e3aa372aec74831943f083c02a7a14d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 70.89 and ACC confidence [68.38, 73.28]  |  pred NA: 29/1319 (2.20%)  \n",
            "[save step 480] reward=0.53125  kl_orig_only=0.00391685  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f5aa689e09ed440fb9f56c179bb265c9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 71.04 and ACC confidence [68.53, 73.42]  |  pred NA: 29/1319 (2.20%)  \n",
            "[save step 640] reward=0.523438  kl_orig_only=0.00259704  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "18ce7a5dbdb44472b3b84d81a13bdf7e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 71.04 and ACC confidence [68.53, 73.42]  |  pred NA: 29/1319 (2.20%)  \n",
            "[save step 800] reward=0.5  kl_orig_only=0.00832482  n1=4  n0=4\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9acc7381d5b1426797cb66a9112a6911",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 71.72 and ACC confidence [69.23, 74.09]  |  pred NA: 29/1319 (2.20%)  \n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1600' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1600/1600 6:24:50, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>160</td>\n",
              "      <td>-0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>480</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1120</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1440</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>0.000100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[save step 960] reward=0.640625  kl_orig_only=0.00306134  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ca20be70e6b34871901f6bedb7686672"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 72.02 and ACC confidence [69.54, 74.38]  |  pred NA: 26/1319 (1.97%)  \n",
            "[save step 1120] reward=0.523438  kl_orig_only=0.00614992  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c25fd1b2a7614cd8a257fa8b1ad2dfbc"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 72.48 and ACC confidence [70.01, 74.82]  |  pred NA: 25/1319 (1.90%)  \n",
            "[save step 1280] reward=0.554688  kl_orig_only=0.00496465  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "206d7699fc774334b83ad60b811b0b4e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 72.93 and ACC confidence [70.47, 75.26]  |  pred NA: 26/1319 (1.97%)  \n",
            "[save step 1440] reward=0.585938  kl_orig_only=0.0114772  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d5a909a3d64a4a3ba4e48a8ce95a1464"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 72.63 and ACC confidence [70.16, 74.97]  |  pred NA: 27/1319 (2.05%)  \n",
            "[save step 1600] reward=0.351562  kl_orig_only=0.00569508  n1=4  n0=4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/21 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7a9b23a7f7874922be0e31e5055a00b9"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 72.71 and ACC confidence [70.24, 75.04]  |  pred NA: 27/1319 (2.05%)  \n"
          ]
        }
      ],
      "source": [
        "                                     # llama\n",
        "if CKPT_LOAD_PATH:\n",
        "    trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH)\n",
        "else:\n",
        "    trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Ge7S4ERp7oma",
      "metadata": {
        "id": "Ge7S4ERp7oma"
      },
      "outputs": [],
      "source": [
        "import gc, torch\n",
        "\n",
        "# if you still have a reference:\n",
        "del trainable_model          # and del any other refs like ref_model, baseline_model, trainer, etc.\n",
        "gc.collect()\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.empty_cache()\n",
        "    torch.cuda.ipc_collect()   # optional, sometimes helps in notebooks\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "k5M1jnQdXDxC",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k5M1jnQdXDxC",
        "outputId": "8f0fe89a-4491-40c2-a7da-d5b6344bf349"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using masked KL penalty.\n"
          ]
        }
      ],
      "source": [
        "# Option B: Train in chunks with TRL GRPOTrainer, run your eval() between chunks.\n",
        "# Assumes you already have:\n",
        "#   - model, tokenizer\n",
        "#   - classes_usable (paraphrase classes)\n",
        "#   - val_ds (gsm8k validation split you made)\n",
        "#   - your evaluate_gsm8k(...) and (optional) evaluate_gsm8k_paraphrase_agreement_batched(...)\n",
        "#   - your agreement_reward_fn(...) (reward over completions)\n",
        "#\n",
        "# Notes:\n",
        "# - GRPOTrainer supports resuming training; we use max_steps to control total steps.\n",
        "# - For IterableDataset, max_steps is the right way to bound training.\n",
        "import os\n",
        "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
        "\n",
        "import math\n",
        "from trl import GRPOConfig\n",
        "try:\n",
        "    # If you use the masked KL subclass we discussed earlier:\n",
        "    TrainerCls = MaskedKLGRPOTrainer\n",
        "    print(\"Using masked KL penalty.\")\n",
        "except NameError:\n",
        "    from trl import GRPOTrainer\n",
        "    TrainerCls = GRPOTrainer\n",
        "\n",
        "# --- dataset for TRL training (paired stream) ---\n",
        "train_dataset = ParaphrasePairIterableFormatted(paraphrase_classes, tokenizer, seed=42)\n",
        "\n",
        "# --- GRPO config ---\n",
        "K = SAMPLES_PER_PAIR\n",
        "\n",
        "# ----------------------------\n",
        "# Chunked training + evaluation\n",
        "# ----------------------------\n",
        "TOTAL_STEPS = LOG_EVERY * 200\n",
        "CHUNK_STEPS = LOG_EVERY * 20\n",
        "\n",
        "EVAL_LIMIT = 1000\n",
        "EVAL_BS = 64\n",
        "\n",
        "DO_PARAPHRASE_EVAL = False\n",
        "PARA_EVAL_LIMIT = 200\n",
        "PARA_EVAL_CLASS_BS = 16   # classes per batch (prompts ~ 2*bs if each class has 2 items)\n",
        "\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "    per_device_train_batch_size=2,       # (q0, q1)\n",
        "    gradient_accumulation_steps=1,\n",
        "    generation_batch_size=K * 2,   # <-- must be divisible by num_generations\n",
        "    num_generations=K,                   # K rollouts per prompt\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "    learning_rate=1e-5,\n",
        "    beta=KL_WEIGHT,                        # adjust if your version uses beta instead\n",
        "    logging_steps=LOG_EVERY,\n",
        "    save_steps= CHUNK_STEPS,\n",
        "    max_steps= TOTAL_STEPS,                         # we'll set this per chunk below\n",
        ")\n",
        "\n",
        "trainer = TrainerCls(\n",
        "    model=trainable_model,\n",
        "    args=grpo_args,\n",
        "    train_dataset=train_dataset,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=agreement_reward_fn,\n",
        "    )\n",
        "\n",
        "# Set collator AFTER init (works with your GRPOTrainer signature)\n",
        "trainer.data_collator = lambda feats: grpo_tokenizing_collator(\n",
        "    feats, tokenizer, max_prompt_length=2048\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "m2ssljPjv5gT",
      "metadata": {
        "id": "m2ssljPjv5gT"
      },
      "source": [
        "### Debug"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "oBWStm0RxM1n",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oBWStm0RxM1n",
        "outputId": "c628f178-7254-48d3-b019-6a5ff5bc3540"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-761637024.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# After building trainer, grab one batch from its dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_dataloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "# After building trainer, grab one batch from its dataloader\n",
        "dl = trainer.get_train_dataloader()\n",
        "batch = next(iter(dl))\n",
        "\n",
        "loss, out = trainer.compute_loss(trainer.model, batch, return_outputs=True)\n",
        "\n",
        "print(\"beta:\", float(out[\"beta\"]))\n",
        "print(\"kl_mean_all:\", float(out[\"kl_mean_all\"]))\n",
        "print(\"kl_orig_mean:\", float(out[\"kl_orig_mean\"]))\n",
        "print(\"kl_para_mean:\", float(out[\"kl_para_mean\"]))\n",
        "print(\"masked_kl:\", float(out[\"masked_kl\"]))\n",
        "print(\"loss:\", float(loss.detach().cpu()))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "q-p8g9fkv7gq",
      "metadata": {
        "id": "q-p8g9fkv7gq"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def debug_kl_mask_once(trainer, train_dataset, tokenizer, K: int):\n",
        "    \"\"\"\n",
        "    Run one loss computation on two records and print KL stats split by is_original.\n",
        "\n",
        "    Args:\n",
        "        trainer: object exposing `.model` and\n",
        "            `compute_loss(model, inputs, return_outputs=True)` returning\n",
        "            `(loss, outputs)`; `outputs` must be a dict holding a per-sample KL\n",
        "            tensor under one of \"kl\", \"kl_div\", \"per_sample_kl\", \"kl_per_sample\".\n",
        "        train_dataset: iterable whose records are mappings with \"prompt\" and\n",
        "            \"is_original\" entries; only the first two records are consumed.\n",
        "        tokenizer: unused; kept so existing calls keep working.\n",
        "        K: samples per prompt; each is_original flag is repeated K times to\n",
        "            line up with the per-sample KL vector.\n",
        "\n",
        "    Raises:\n",
        "        KeyError: if `outputs` is not a dict or has none of the KL keys.\n",
        "        ValueError: if the KL tensor's leading dimension is not 2 * K.\n",
        "\n",
        "    Note:\n",
        "        The recorded run in this notebook failed inside compute_loss with\n",
        "        \"The GRPOTrainer does not support returning outputs\", so this helper\n",
        "        only works with a trainer whose compute_loss accepts return_outputs=True.\n",
        "    \"\"\"\n",
        "    # Build a single batch: exactly 2 prompts (q0, q1).\n",
        "    it = iter(train_dataset)\n",
        "    batch = [next(it), next(it)]\n",
        "    prompts = [b[\"prompt\"] for b in batch]\n",
        "    is_orig = torch.tensor([b[\"is_original\"] for b in batch])\n",
        "\n",
        "    # Minimal inputs handed to compute_loss: raw prompts plus the original/paraphrase flags.\n",
        "    inputs = {\"prompt\": prompts, \"is_original\": is_orig}\n",
        "\n",
        "    # Diagnostics only: no backward pass follows, so skip autograd bookkeeping.\n",
        "    with torch.no_grad():\n",
        "        loss, outputs = trainer.compute_loss(trainer.model, inputs, return_outputs=True)\n",
        "\n",
        "    # Locate the per-sample KL vector; fail with KeyError (never AttributeError)\n",
        "    # when outputs is not a dict at all.\n",
        "    if not isinstance(outputs, dict):\n",
        "        raise KeyError(\n",
        "            f\"Expected dict outputs containing a per-sample KL entry, got {type(outputs).__name__}\"\n",
        "        )\n",
        "    kl_candidates = (\"kl\", \"kl_div\", \"per_sample_kl\", \"kl_per_sample\")\n",
        "    kl_key = next((k for k in kl_candidates if k in outputs), None)\n",
        "    if kl_key is None:\n",
        "        raise KeyError(f\"Couldn't find per-sample KL in outputs keys: {list(outputs.keys())}\")\n",
        "\n",
        "    kl = outputs[kl_key].detach().cpu()         # expected shape [B*G] = [2*K]\n",
        "    mask = is_orig.float().repeat_interleave(K).cpu()\n",
        "\n",
        "    # The boolean split below needs one mask entry per KL row.\n",
        "    if kl.dim() == 0 or kl.shape[0] != mask.shape[0]:\n",
        "        raise ValueError(\n",
        "            f\"Per-sample KL has shape {tuple(kl.shape)}; expected leading dimension \"\n",
        "            f\"{mask.shape[0]} (2 prompts x K={K})\"\n",
        "        )\n",
        "\n",
        "    # Split KL by prompt type\n",
        "    kl_orig = kl[mask == 1]\n",
        "    kl_para = kl[mask == 0]\n",
        "\n",
        "    print(\"KL key:\", kl_key, \"shape:\", tuple(kl.shape))\n",
        "    print(\"KL mean (orig):\", float(kl_orig.mean()) if len(kl_orig) else None)\n",
        "    print(\"KL mean (para):\", float(kl_para.mean()) if len(kl_para) else None)\n",
        "\n",
        "    if \"masked_kl\" in outputs:\n",
        "        print(\"masked_kl (trainer):\", float(outputs[\"masked_kl\"].detach().cpu()))\n",
        "    else:\n",
        "        # Recompute the masked KL ourselves: mean KL over original-prompt samples,\n",
        "        # with the denominator clamped so an all-paraphrase batch divides by 1.\n",
        "        masked_kl = (kl * mask).sum() / mask.sum().clamp_min(1.0)\n",
        "        print(\"masked_kl (recomputed):\", float(masked_kl))\n",
        "\n",
        "    print(\"loss:\", float(loss.detach().cpu()))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fr6HpoQjv9QY",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fr6HpoQjv9QY",
        "outputId": "573da719-a9b0-48d7-cdc3-c38e51ca5446"
      },
      "outputs": [
        {
          "ename": "ValueError",
          "evalue": "The GRPOTrainer does not support returning outputs",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-1024889168.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdebug_kl_mask_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mSAMPLES_PER_PAIR\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/tmp/ipython-input-1970023020.py\u001b[0m in \u001b[0;36mdebug_kl_mask_once\u001b[0;34m(trainer, train_dataset, tokenizer, K)\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0;31m# Call compute_loss with outputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m     \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0;31m# Find KL vector key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-276940471.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, **kwargs)\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0;31m# Run the normal GRPO forward to get outputs that include per-sample stats.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0;31m# We rely on GRPOTrainer producing a dict-like outputs containing KL.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0;31m# inputs[\"is_original\"]: shape [batch_prompts] (here batch_prompts should be 2)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/extras/profiling.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     96\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mprofiling_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/trainer/grpo_trainer.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m   2104\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2105\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2106\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The GRPOTrainer does not support returning outputs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2107\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_liger_kernel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2108\u001b[0m             \u001b[0;31m# Compute the loss using the liger grpo loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mValueError\u001b[0m: The GRPOTrainer does not support returning outputs"
          ]
        }
      ],
      "source": [
        "# Smoke-test the KL diagnostic on the current trainer, with K samples per prompt.\n",
        "# NOTE: the recorded run of this cell raised\n",
        "# ValueError: The GRPOTrainer does not support returning outputs\n",
        "debug_kl_mask_once(trainer, train_dataset, tokenizer, K=SAMPLES_PER_PAIR)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sx5zM-iWwAJR",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "sx5zM-iWwAJR",
        "outputId": "34036967-7614-447c-9d61-5bc4162034de"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-3887887193.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;31m# Train up to cur_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresume\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mresume\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   2323\u001b[0m                 \u001b[0mhf_hub_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_progress_bars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2324\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2325\u001b[0;31m             return inner_training_loop(\n\u001b[0m\u001b[1;32m   2326\u001b[0m                 \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2327\u001b[0m                 \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2616\u001b[0m                 \u001b[0mupdate_step\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2617\u001b[0m                 \u001b[0mnum_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient_accumulation_steps\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mupdate_step\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtotal_updates\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mremainder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2618\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_batch_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2619\u001b[0m                 \u001b[0;31m# Store the number of batches for current gradient accumulation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2620\u001b[0m                 \u001b[0;31m# This is used to correctly scale the loss when the last accumulation step has fewer batches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mget_batch_samples\u001b[0;34m(self, epoch_iterator, num_batches, device)\u001b[0m\n\u001b[1;32m   5652\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5653\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5654\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   5655\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5656\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "# Train in CHUNK_STEPS-sized increments up to TOTAL_STEPS, evaluating after each chunk.\n",
        "resume = None  # None on the first train() call, True on every later one\n",
        "cur_target = 0\n",
        "while cur_target < TOTAL_STEPS:\n",
        "    cur_target = min(cur_target + CHUNK_STEPS, TOTAL_STEPS)\n",
        "\n",
        "    # Raise the trainer's global step target to the end of this chunk, then train up to it.\n",
        "    trainer.args.max_steps = cur_target\n",
        "    trainer.train(resume_from_checkpoint=resume)\n",
        "    resume = True\n",
        "\n",
        "    # Keyword arguments shared by both evaluation calls below.\n",
        "    shared_eval_kwargs = dict(\n",
        "        model=trainer.model,\n",
        "        tokenizer=tokenizer,\n",
        "        max_new_tokens=MAX_NEW_TOKENS,\n",
        "        use_greedy=True,\n",
        "    )\n",
        "\n",
        "    # 1) Accuracy evaluation on the held-out test set\n",
        "    _ = evaluate_gsm8k(\n",
        "        dataset=test_ds,\n",
        "        limit=EVAL_LIMIT,\n",
        "        batch_size=EVAL_BS,\n",
        "        name=f\"grpo_step{cur_target}\",\n",
        "        **shared_eval_kwargs,\n",
        "    )\n",
        "\n",
        "    # 2) Optional paraphrase agreement (fast batched version)\n",
        "    if DO_PARAPHRASE_EVAL:\n",
        "        _ = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "            dataset=paraphrase_classes,\n",
        "            limit=PARA_EVAL_LIMIT,\n",
        "            batch_size=PARA_EVAL_CLASS_BS,\n",
        "            name=f\"agree_step{cur_target}\",\n",
        "            show_samples=0,\n",
        "            **shared_eval_kwargs,\n",
        "        )\n",
        "\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "SD17xtwHtI6d",
      "metadata": {
        "id": "SD17xtwHtI6d"
      },
      "source": [
        "## Load/Generate Answer-Level Partitions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "V5ZfIEeatNHc",
      "metadata": {
        "id": "V5ZfIEeatNHc"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from typing import List\n",
        "from collections import defaultdict\n",
        "\n",
        "# Assume something like:\n",
        "# from sentence_transformers import util as sbert_util\n",
        "# and 'embedder' is a global SentenceTransformer\n",
        "def cluster_embeddings_online(\n",
        "    embs: torch.Tensor,           # [K, D]\n",
        "    sim_threshold: float = 0.95,\n",
        ") -> List[int]:\n",
        "    \"\"\"\n",
        "    Simple online clustering:\n",
        "      - Start first cluster with embs[0]\n",
        "      - For each next embedding:\n",
        "            if cos_sim_to_any_center >= threshold -> join that cluster\n",
        "            else create a new cluster\n",
        "    Returns a list cluster_ids[k] in {0, 1, ...}\n",
        "    \"\"\"\n",
        "    K, D = embs.shape\n",
        "    centers = []          # list of [D] tensors\n",
        "    counts = []           # num points in each cluster\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k : k + 1]          # [1, D] for util.cos_sim\n",
        "\n",
        "        best_cid = None\n",
        "        best_sim = -1.0\n",
        "\n",
        "        for cid, center in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e, center.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # join existing cluster\n",
        "            cluster_ids.append(best_cid)\n",
        "            # update running mean center\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e.squeeze(0)) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # start new cluster\n",
        "            cid = len(centers)\n",
        "            centers.append(e.squeeze(0))\n",
        "            counts.append(1)\n",
        "            cluster_ids.append(cid)\n",
        "\n",
        "    return cluster_ids\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def cluster_answers_for_question(\n",
        "    answers: List[str],\n",
        "    *,\n",
        "    question: str,\n",
        "    embedder,\n",
        "    sim_threshold: float = 0.95,\n",
        "):\n",
        "    \"\"\"\n",
        "    Returns:\n",
        "      cluster_ids: [K] int\n",
        "      centers:     [C, D] tensor of cluster centers\n",
        "    \"\"\"\n",
        "    texts = [f\"Q: {question}\\nA: {a}\" for a in answers]\n",
        "    embs = embedder.encode(texts, convert_to_tensor=True, show_progress_bar=False)  # [K, D]\n",
        "\n",
        "    K, D = embs.shape\n",
        "    centers = []\n",
        "    counts = []\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k]  # [D]\n",
        "        best_cid = None\n",
        "        best_sim = -1.0\n",
        "\n",
        "        for cid, c in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e.unsqueeze(0), c.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # join existing cluster\n",
        "            cluster_ids.append(best_cid)\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # create new cluster\n",
        "            cid = len(centers)\n",
        "            centers.append(e)\n",
        "            counts.append(1)\n",
        "            cluster_ids.append(cid)\n",
        "\n",
        "    centers_tensor = torch.stack(centers, dim=0)  # [C, D]\n",
        "    return cluster_ids, centers_tensor\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2izcDe6yyIxP",
      "metadata": {
        "id": "2izcDe6yyIxP"
      },
      "outputs": [],
      "source": [
        "import json, time, hashlib, random, pathlib\n",
        "from collections import Counter\n",
        "from typing import List, Dict, Any\n",
        "\n",
        "\n",
        "\n",
        "COH_CACHE_PATH = f\"cache/triviaqa_cluster.jsonl\"\n",
        "COH_CACHE_PATH = os.path.join(COHERENCE_DATA_PATH, COH_CACHE_PATH)\n",
        "COH_CACHE_PATH\n",
        "\n",
        "def _coh_meta_hash(meta: dict) -> str:\n",
        "    key = json.dumps(meta, sort_keys=True).encode(\"utf-8\")\n",
        "    return hashlib.sha1(key).hexdigest()[:10]\n",
        "\n",
        "\n",
        "def _coh_meta_compatible(current: dict, cached: dict) -> bool:\n",
        "    keys = [\n",
        "        \"train_size\",\n",
        "        \"n_questions\",\n",
        "        \"samples_per_x\",\n",
        "        \"baseline_model_id\",\n",
        "        \"clustering_version\",\n",
        "        \"seed\",\n",
        "    ]\n",
        "    return all(current.get(k) == cached.get(k) for k in keys)\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def _sample_baseline_answers_for_x(\n",
        "    x: str,\n",
        "    *,\n",
        "    model,\n",
        "    tokenizer,\n",
        "    K: int,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    ys = answer_batch_chat(\n",
        "        [x] * K,\n",
        "        model=model,\n",
        "        tokenizer=tokenizer,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "    )\n",
        "    return ys\n",
        "\n",
        "\n",
        "def build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    *,\n",
        "    baseline_model,\n",
        "    tokenizer,\n",
        "    embedder,\n",
        "    paraphrase_classes_all = paraphrase_classes_all,              # <--- NEW: list per train_idx\n",
        "    cache_path: str = COH_CACHE_PATH,\n",
        "    n_questions: int = 600,\n",
        "    samples_per_x: int = 32,\n",
        "    sim_threshold: float = 0.95,\n",
        "    seed: int = 0,\n",
        ") -> List[Dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Coherence via paraphrases:\n",
        "      For each chosen train_idx:\n",
        "        - Get original question q0 and paraphrases q1..qM-1 from paraphrase_classes_all[train_idx]\n",
        "        - Generate 'samples_per_x' answers for EACH of these questions.\n",
        "          Total rollouts = samples_per_x * M.\n",
        "        - Cluster ALL answers together into semantic classes z.\n",
        "        - nu0(z | x): empirical distribution over clusters using only answers to q0.\n",
        "        - nu_coh(z | x): empirical distribution over clusters using all answers (q0 + paraphrases).\n",
        "        - ratio(z) = nu_coh(z) / nu0(z) for z with nu0(z) > 0.\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(cache_path)\n",
        "    meta_current = {\n",
        "        \"train_size\": len(train_ds),\n",
        "        \"n_questions\": n_questions,\n",
        "        \"samples_per_x\": samples_per_x,\n",
        "        \"baseline_model_id\": getattr(baseline_model, \"name_or_path\", \"unknown\"),\n",
        "        \"clustering_version\": f\"cluster_sim_{sim_threshold}_with_paraphrases\",\n",
        "        \"seed\": seed,\n",
        "    }\n",
        "\n",
        "    # ---------- Try loading existing cache ----------\n",
        "    if path.exists():\n",
        "        records: List[Dict[str, Any]] = []\n",
        "        cached_meta = None\n",
        "        with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "            for line in f:\n",
        "                obj = json.loads(line)\n",
        "                if \"__meta__\" in obj:\n",
        "                    cached_meta = obj[\"__meta__\"]\n",
        "                else:\n",
        "                    records.append(obj)\n",
        "\n",
        "        if cached_meta is not None and _coh_meta_compatible(meta_current, cached_meta):\n",
        "            print(f\"[coh] Loaded clustered coherence set from {path}\")\n",
        "            print(f\"[coh] {len(records)} questions, \"\n",
        "                  f\"{cached_meta['samples_per_x']} samples each (per question variant)\")\n",
        "            return records\n",
        "        else:\n",
        "            print(\"[coh] Cache exists but metadata mismatch; rebuilding.\")\n",
        "\n",
        "    # ---------- Need to build a new coherence set ----------\n",
        "    print(\"[coh] Building new clustered coherence set (with paraphrases)...\")\n",
        "    path.parent.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    rng = random.Random(seed)\n",
        "    all_idxs = list(range(len(train_ds)))\n",
        "    chosen_idxs = rng.sample(all_idxs, n_questions)\n",
        "\n",
        "    records: List[Dict[str, Any]] = []\n",
        "\n",
        "    for i, train_idx in enumerate(chosen_idxs):\n",
        "        ex = train_ds[train_idx]\n",
        "\n",
        "        # 1) Original question + paraphrases\n",
        "        q_variants = paraphrase_classes_all[train_idx]\n",
        "        if not q_variants:\n",
        "            # Fallback: if somehow no paraphrases exist, just use the original question\n",
        "            q0 = ex[\"question\"]\n",
        "            q_variants = [q0]\n",
        "        else:\n",
        "            # Ensure index 0 is the original as you described\n",
        "            q0 = q_variants[0]\n",
        "\n",
        "        # 2) Generate answers for each question variant\n",
        "        all_answers: List[str] = []\n",
        "        answer_src: List[int] = []  # which variant index produced each answer (0 = original)\n",
        "\n",
        "        for j, q_text in enumerate(q_variants):\n",
        "            ys = _sample_baseline_answers_for_x(\n",
        "                q_text,\n",
        "                model=baseline_model,\n",
        "                tokenizer=tokenizer,\n",
        "                K=samples_per_x,\n",
        "            )\n",
        "            all_answers.extend(ys)\n",
        "            answer_src.extend([j] * samples_per_x)\n",
        "\n",
        "        # 3) CLUSTERING-BASED SEMANTICS over ALL answers\n",
        "        #    (cluster IDs shared between original+paraphrases)\n",
        "        cluster_ids, centers = cluster_answers_for_question(\n",
        "            all_answers,\n",
        "            question=q0,             # cluster function can ignore this if it's answer-only\n",
        "            embedder=embedder,\n",
        "            sim_threshold=sim_threshold,\n",
        "        )\n",
        "\n",
        "        z_all = [f\"c{cid}\" for cid in cluster_ids]  # semantics E(y, x) for all rollouts\n",
        "\n",
        "        # 3a) nu_coh(z | x): empirical distribution over ALL rollouts (original + paraphrases)\n",
        "        counts_all = Counter(z_all)\n",
        "        total_all = float(len(z_all))\n",
        "        nu_coh = {z: c / total_all for z, c in counts_all.items()}\n",
        "\n",
        "        # 3b) nu0(z | x): distribution using ONLY answers to the original question (variant index 0)\n",
        "        z_orig = [z for z, src in zip(z_all, answer_src) if src == 0]\n",
        "        counts_orig = Counter(z_orig)\n",
        "        total_orig = float(len(z_orig))  # should be samples_per_x\n",
        "        nu0 = {z: c / total_orig for z, c in counts_orig.items()}\n",
        "\n",
        "        # 3c) ratio(z) = nu_coh(z) / nu0(z) where nu0(z) > 0\n",
        "        ratio = {}\n",
        "        for z, p0 in nu0.items():\n",
        "            if p0 > 0.0:\n",
        "                p_coh = nu_coh.get(z, 0.0)\n",
        "                ratio[z] = p_coh / p0\n",
        "\n",
        "        rec = {\n",
        "            \"train_idx\": train_idx,\n",
        "            \"question\": q0,                # original question\n",
        "            \"questions\": q_variants,       # original + paraphrases\n",
        "            \"answers\": all_answers,        # flattened, length = samples_per_x * len(q_variants)\n",
        "            \"answer_src\": answer_src,      # variant index for each answer\n",
        "            \"z_list\": z_all,               # cluster ID per answer\n",
        "            \"nu0\": nu0,                    # distribution from original-only rollouts\n",
        "            \"nu_coh\": nu_coh,              # distribution from all (original + paraphrases)\n",
        "            \"ratio\": ratio,                # nu_coh(z) / nu0(z) for z with nu0(z) > 0\n",
        "            \"centers\": centers.cpu().tolist(),\n",
        "        }\n",
        "        records.append(rec)\n",
        "\n",
        "        if (i + 1) % 50 == 0:\n",
        "            print(f\"[coh] processed {i+1}/{n_questions} base questions\")\n",
        "\n",
        "    # ---------- Save cache ----------\n",
        "    meta_to_save = dict(meta_current)\n",
        "    meta_to_save[\"created_utc\"] = int(time.time())\n",
        "    meta_to_save[\"hash\"] = _coh_meta_hash(meta_to_save)\n",
        "\n",
        "    # Uncomment this block when you're ready to persist to disk\n",
        "\n",
        "    tmp = path.with_suffix(path.suffix + \".tmp\")\n",
        "    with tmp.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta_to_save}) + \"\\n\")\n",
        "        for rec in records:\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    tmp.replace(path)\n",
        "\n",
        "    print(f\"[coh] Saved clustered coherence set to {path}\")\n",
        "\n",
        "\n",
        "    return records\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "DrkEk66cDZ4i",
      "metadata": {
        "id": "DrkEk66cDZ4i"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from typing import List, Union, Iterable, Optional\n",
        "\n",
        "class LMEmbedder:\n",
        "    \"\"\"\n",
        "    Generic wrapper to use any HF LM (Qwen, LLaMA, Phi-3, etc.)\n",
        "    as a sentence embedder with a SentenceTransformer-like interface.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model,\n",
        "        tokenizer,\n",
        "        max_length: int = 128,\n",
        "        normalize: bool = True,\n",
        "        device: Optional[Union[str, torch.device]] = None,\n",
        "        batch_size: int = 32,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model: HF model (e.g. AutoModelForCausalLM / LlamaForCausalLM / Phi3ForCausalLM, possibly PEFT-wrapped)\n",
        "            tokenizer: matching tokenizer\n",
        "            max_length: truncation length for inputs\n",
        "            normalize: whether to L2-normalize embeddings\n",
        "            device: override device; if None, use model.device\n",
        "            batch_size: batching size for .encode()\n",
        "        \"\"\"\n",
        "        self.model = model\n",
        "        self.tokenizer = tokenizer\n",
        "        self.max_length = max_length\n",
        "        self.normalize = normalize\n",
        "        self.batch_size = batch_size\n",
        "\n",
        "        # Ensure we can access hidden states\n",
        "        if hasattr(self.model.config, \"output_hidden_states\"):\n",
        "            self.model.config.output_hidden_states = True\n",
        "\n",
        "        self.model.eval()\n",
        "\n",
        "        # Figure out device\n",
        "        if device is None:\n",
        "            # works for most HF models, including PEFT-wrapped\n",
        "            try:\n",
        "                self.device = next(self.model.parameters()).device\n",
        "            except StopIteration:\n",
        "                self.device = torch.device(\"cpu\")\n",
        "        else:\n",
        "            self.device = torch.device(device)\n",
        "\n",
        "        # Make sure we have a pad_token_id\n",
        "        if self.tokenizer.pad_token_id is None:\n",
        "            # common choice: use EOS as pad\n",
        "            if self.tokenizer.eos_token_id is not None:\n",
        "                self.tokenizer.pad_token = self.tokenizer.eos_token\n",
        "                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def encode(\n",
        "        self,\n",
        "        texts: Union[str, List[str], Iterable[str]],\n",
        "        convert_to_tensor: bool = True,\n",
        "        show_progress_bar: bool = False,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        SentenceTransformer-style API.\n",
        "\n",
        "        Args:\n",
        "            texts: single string or list/iterable of strings\n",
        "            convert_to_tensor: if True, returns a torch.Tensor [N, D]\n",
        "                               else returns a list of numpy arrays\n",
        "            show_progress_bar: ignored here, but kept for API compatibility\n",
        "\n",
        "        Returns:\n",
        "            embeddings: [N, D] tensor or list of np.array\n",
        "        \"\"\"\n",
        "        if isinstance(texts, str):\n",
        "            texts = [texts]\n",
        "        else:\n",
        "            texts = list(texts)\n",
        "\n",
        "        all_embs = []\n",
        "        bs = self.batch_size\n",
        "\n",
        "        for i in range(0, len(texts), bs):\n",
        "            batch_texts = texts[i : i + bs]\n",
        "\n",
        "            enc = self.tokenizer(\n",
        "                batch_texts,\n",
        "                padding=True,\n",
        "                truncation=True,\n",
        "                max_length=self.max_length,\n",
        "                return_tensors=\"pt\",\n",
        "            )\n",
        "            input_ids = enc[\"input_ids\"].to(self.device)\n",
        "            attention_mask = enc[\"attention_mask\"].to(self.device)\n",
        "\n",
        "            out = self.model(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                output_hidden_states=True,\n",
        "            )\n",
        "\n",
        "            # last hidden state: [B, T, D]\n",
        "            # some models return .last_hidden_state, some in hidden_states[-1]\n",
        "            if hasattr(out, \"hidden_states\") and out.hidden_states is not None:\n",
        "                hidden = out.hidden_states[-1]\n",
        "            else:\n",
        "                # fallback to last_hidden_state if hidden_states is not present\n",
        "                hidden = out.last_hidden_state\n",
        "\n",
        "            # mean-pool over non-pad tokens\n",
        "            mask = attention_mask.unsqueeze(-1)  # [B, T, 1]\n",
        "            masked_hidden = hidden * mask\n",
        "            lengths = mask.sum(dim=1).clamp(min=1)  # [B, 1]\n",
        "            emb = masked_hidden.sum(dim=1) / lengths  # [B, D]\n",
        "\n",
        "            if self.normalize:\n",
        "                emb = F.normalize(emb, p=2, dim=1)\n",
        "\n",
        "            all_embs.append(emb.cpu())\n",
        "\n",
        "        embs = torch.cat(all_embs, dim=0)  # [N, D]\n",
        "\n",
        "        if convert_to_tensor:\n",
        "            return embs\n",
        "        else:\n",
        "            return [e.numpy() for e in embs]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Zp-xhGcSyXRq",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 404
        },
        "id": "Zp-xhGcSyXRq",
        "outputId": "503eaed4-b348-42d1-ce84-873811d4939b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "28501ff67bde42a6a9410c55c8977bb9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "994a0db0c63545118bce1b33958c1614",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cc3f6cd9667a49edafbab2449e02b4e5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0d5d447849144930962682f67ed4b1cb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4288dfb5897345768b2378e8d159f539",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f1e6da22975748af9609388a2c5d5798",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ce24fe0613244bc6887926753798d897",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a593af02a7cf489f98067e067d0c5a38",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.txt: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e3507cdb84464004b6add70103ccfe09",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c05e536a07e44073895cc90f77f7655d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a6837fefcfe9491787b137373a41995e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[coh] Loaded clustered coherence set from /content/drive/MyDrive/coherence-answer/Qwen/Qwen2.5-3B-Instruct/instruct/cache/triviaqa_cluster.jsonl\n",
            "[coh] 6400 questions, 16 samples each (per question variant)\n"
          ]
        }
      ],
      "source": [
        "from sentence_transformers import SentenceTransformer\n",
        "\n",
        "embedder = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n",
        "\n",
        "# embedder = LMEmbedder(\n",
        "#     model=baseline_model,\n",
        "#     tokenizer=tokenizer,\n",
        "#     max_length=128,\n",
        "#     normalize=True,\n",
        "#     batch_size=32,\n",
        "# )\n",
        "\n",
        "coh_records = build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    baseline_model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    embedder = embedder,\n",
        "    n_questions=32 * 200,\n",
        "    samples_per_x=16,\n",
        "    cache_path=COH_CACHE_PATH,\n",
        "    seed = 42\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ikGAzn8nFtkS",
      "metadata": {
        "id": "ikGAzn8nFtkS"
      },
      "outputs": [],
      "source": [
        "\n",
        "import json\n",
        "\n",
        "for rec in coh_records[:1]:\n",
        "    temp_rec = {k: v for k, v in rec.items() if k != 'centers'}\n",
        "    print(json.dumps(temp_rec, indent=2))\n",
        "\n",
        "    # Fetch and print ground truth\n",
        "    question_idx = rec[\"train_idx\"]\n",
        "    ground_truth_obj = train_ds[question_idx][\"answer\"]\n",
        "    print(f\"  Ground Truth Answers (train_idx={question_idx}): {ground_truth_obj}\")\n",
        "    print(\"\\n\" + \"=\"*80 + \"\\n\") # Separator for readability\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "aJqwV-W114QE",
      "metadata": {
        "id": "aJqwV-W114QE"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def E_cluster_for_known_x(\n",
        "    y: str,\n",
        "    x: str,\n",
        "    rec: dict,        # one coherence record for this x\n",
        "    embedder,\n",
        "    min_sim: float = 0.0,  # or e.g. 0.7 if you want a floor\n",
        ") -> str:\n",
        "    \"\"\"\n",
        "    Map a new completion y for a *known* question x\n",
        "    onto one of the existing clusters in rec.\n",
        "    Returns a cluster label like 'c0', 'c1', ...\n",
        "    \"\"\"\n",
        "    # 1) embed the (x, y) pair the same way as during clustering\n",
        "    text = f\"Q: {x}\\nA: {y}\"\n",
        "    emb = embedder.encode([text], convert_to_tensor=True, show_progress_bar=False)[0]  # [D]\n",
        "\n",
        "    # 2) load centers\n",
        "    centers = torch.tensor(rec[\"centers\"], dtype=emb.dtype, device=emb.device)  # [C, D]\n",
        "\n",
        "    # 3) cosine similarity to each center\n",
        "    sims = torch.nn.functional.cosine_similarity(\n",
        "        emb.unsqueeze(0), centers, dim=1\n",
        "    )  # [C]\n",
        "\n",
        "    best_cid = int(torch.argmax(sims).item())\n",
        "    best_sim = float(sims[best_cid].item())\n",
        "\n",
        "    if best_sim < min_sim:\n",
        "        # Optional: treat as 'other' cluster; depends on your design.\n",
        "        # For many WSFT setups it’s simpler to just force assignment.\n",
        "        pass\n",
        "\n",
        "    return f\"c{best_cid}\"\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0jXMFGTV1607",
      "metadata": {
        "id": "0jXMFGTV1607"
      },
      "outputs": [],
      "source": [
        "# Build a lookup once after loading coh_records\n",
        "coh_by_idx = {rec[\"train_idx\"]: rec for rec in coh_records}\n",
        "\n",
        "def E_for_training_example(train_idx: int, x: str, y: str) -> str:\n",
        "    rec = coh_by_idx[train_idx]\n",
        "    return E_cluster_for_known_x(y, x, rec, embedder)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5_3wH5NkftWo",
      "metadata": {
        "id": "5_3wH5NkftWo"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Iterable, Tuple, Optional\n",
        "\n",
        "\n",
        "from dataclasses import dataclass\n",
        "\n",
        "@dataclass\n",
        "class WSFTExample:\n",
        "    x: str\n",
        "    y: str\n",
        "    w: float\n",
        "    train_idx: int\n",
        "    sample_idx: int   # 0..K-1 within that question\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def make_wsft_examples_from_coh_records(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    # By default, \"coh\" means use x0 and x1 only (original + 1 paraphrase).\n",
        "    # If you want to include more paraphrases, pass e.g. use_variants=None to use all.\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Build WSFT training tuples (x, y, w) from coherence cluster records.\n",
        "\n",
        "    E(y) is implemented by using the stored cluster id z in rec[\"z_list\"].\n",
        "\n",
        "    v0(z|x): empirical distribution of cluster labels among answers generated for *that* x (variant-specific).\n",
        "    v_coh(z|x): empirical distribution of cluster labels among answers pooled over the\n",
        "                selected variants (x0 and x1 by default).\n",
        "\n",
        "    w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x)\n",
        "\n",
        "    Args:\n",
        "        coh_records: dicts with keys \"train_idx\", \"questions\", \"answers\", \"answer_src\"\n",
        "            and \"z_list\"; the last three are read in parallel (zip semantics).\n",
        "        alpha: mixing coefficient in [0, 1]; alpha=0 gives weight 1 before clipping.\n",
        "        use_variants: indices into rec[\"questions\"] to use. Indices outside\n",
        "            [0, len(questions)) are ignored; if none remain, variant 0 is used.\n",
        "            None selects every variant.\n",
        "        eps: lower bound applied to v0(z|x) before dividing.\n",
        "        weight_clip: optional upper bound applied to each weight.\n",
        "\n",
        "    Returns:\n",
        "        One WSFTExample per answer whose source variant is selected, in input order.\n",
        "        Records with no questions, no answers, or no answer from a selected variant\n",
        "        contribute nothing.\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]          # [x0, x1, ...]\n",
        "        answers: List[str] = rec[\"answers\"]              # flattened\n",
        "        answer_src: List[int] = rec[\"answer_src\"]        # which x_j produced each answer\n",
        "        z_list: List[str] = rec[\"z_list\"]                # E(y): cluster id per answer\n",
        "\n",
        "        if not questions or not answers:\n",
        "            continue\n",
        "\n",
        "        # Decide which variants define the \"coherence\" set (default: only x0 and x1).\n",
        "        if use_variants is None:\n",
        "            coh_var_set = set(range(len(questions)))\n",
        "        else:\n",
        "            # The lower bound matters: a negative index would pass a bare upper-bound\n",
        "            # check, skip the fallback, and could leave the record with no matching answers.\n",
        "            coh_var_set = {v for v in use_variants if 0 <= v < len(questions)}\n",
        "            if not coh_var_set:\n",
        "                coh_var_set = {0}  # fallback\n",
        "\n",
        "        # ---- single pass: pooled counts (for v_coh) and per-variant counts (for v0) ----\n",
        "        coh_counts = Counter()\n",
        "        counts_by_var: Dict[int, Counter] = defaultdict(Counter)\n",
        "        for z, src in zip(z_list, answer_src):\n",
        "            if src in coh_var_set:\n",
        "                coh_counts[z] += 1\n",
        "                counts_by_var[src][z] += 1\n",
        "        if not coh_counts:\n",
        "            continue\n",
        "\n",
        "        # ---- v_coh over answers from selected variants (default: x0+x1) ----\n",
        "        coh_total = float(sum(coh_counts.values()))\n",
        "        v_coh = {z: c / coh_total for z, c in coh_counts.items()}\n",
        "\n",
        "        # ---- v0 per variant x_j (only variants that produced at least one answer) ----\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j, cnt in counts_by_var.items():\n",
        "            tot = float(sum(cnt.values()))\n",
        "            v0_by_var[j] = {z: c / tot for z, c in cnt.items()}\n",
        "\n",
        "        # ---- build WSFT examples ----\n",
        "        per_var_k = defaultdict(int)  # sample_idx within that x_j\n",
        "\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            if src not in coh_var_set:\n",
        "                continue\n",
        "            if src not in v0_by_var:\n",
        "                continue  # no stats for this variant\n",
        "\n",
        "            # eps only guards the division; z was counted for src above, so v0 > 0 here.\n",
        "            v0_z = max(v0_by_var[src].get(z, 0.0), eps)\n",
        "            vcoh_z = v_coh.get(z, 0.0)\n",
        "\n",
        "            w = (1.0 - alpha) + alpha * (vcoh_z / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            wsft_examples.append(\n",
        "                WSFTExample(\n",
        "                    x=questions[src],           # x is either x0 or x1 (or more if you allow)\n",
        "                    y=y,\n",
        "                    w=float(w),\n",
        "                    train_idx=train_idx,\n",
        "                    sample_idx=int(per_var_k[src]),\n",
        "                )\n",
        "            )\n",
        "            per_var_k[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_hU9FbRClOjZ",
      "metadata": {
        "id": "_hU9FbRClOjZ"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Optional, Tuple, Iterable\n",
        "\n",
        "def make_wsft_examples_from_coh_records_majority(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Build WSFT training tuples (x, y, w) from coherence cluster records.\n",
        "\n",
        "    E(y) is implemented by using the stored cluster id z in rec[\"z_list\"].\n",
        "\n",
        "    v0(z|x): empirical distribution of cluster labels among answers generated for *that* x (variant-specific).\n",
        "    v_coh(z|x): MAJORITY VOTE over cluster labels among answers from the selected variants\n",
        "               (default: variants 0 and 1). Implemented as a degenerate distribution:\n",
        "                 v_coh(z|x) = 1 if z == z*, else 0,\n",
        "               where z* is the most frequent cluster among those answers.\n",
        "\n",
        "    w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x)\n",
        "      = (1-alpha) + alpha * (1 / v0(z|x)) if z==z*, else (1-alpha)\n",
        "\n",
        "    Args:\n",
        "        coh_records: dicts with keys \"train_idx\", \"questions\", \"answers\", \"answer_src\"\n",
        "            and \"z_list\"; the last three are read in parallel (zip semantics).\n",
        "        alpha: mixing coefficient in [0, 1].\n",
        "        use_variants: indices into rec[\"questions\"] to use. Indices outside\n",
        "            [0, len(questions)) are ignored; if none remain, variant 0 is used.\n",
        "            None selects every variant.\n",
        "        eps: lower bound applied to v0(z|x) before dividing.\n",
        "        weight_clip: optional upper bound applied to each weight.\n",
        "\n",
        "    Returns:\n",
        "        One WSFTExample per answer whose source variant is selected, in input order.\n",
        "        Records with no questions, no answers, or no answer from a selected variant\n",
        "        contribute nothing.\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    def _z_numeric(z: str) -> int:\n",
        "        # deterministic tie-break (prefer smaller cluster id); labels look like \"c<int>\" or \"<int>\",\n",
        "        # anything unparsable ranks as 0\n",
        "        try:\n",
        "            if isinstance(z, str) and z.startswith(\"c\"):\n",
        "                return int(z[1:])\n",
        "            return int(z)\n",
        "        except Exception:\n",
        "            return 0\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]\n",
        "        answers: List[str] = rec[\"answers\"]\n",
        "        answer_src: List[int] = rec[\"answer_src\"]\n",
        "        z_list: List[str] = rec[\"z_list\"]\n",
        "\n",
        "        if not questions or not answers:\n",
        "            continue\n",
        "\n",
        "        # Decide which variants define the \"coherence\" majority vote set.\n",
        "        if use_variants is None:\n",
        "            coh_var_set = set(range(len(questions)))\n",
        "        else:\n",
        "            # The lower bound matters: a negative index would pass a bare upper-bound\n",
        "            # check, skip the fallback, and could leave the record with no matching answers.\n",
        "            coh_var_set = {v for v in use_variants if 0 <= v < len(questions)}\n",
        "            if not coh_var_set:\n",
        "                coh_var_set = {0}\n",
        "\n",
        "        # ---- single pass: pooled counts (for z*) and per-variant counts (for v0) ----\n",
        "        # coh_counts is filled in first-appearance order, which max() relies on when\n",
        "        # both the count and the numeric id tie.\n",
        "        coh_counts = Counter()\n",
        "        counts_by_var: Dict[int, Counter] = defaultdict(Counter)\n",
        "        for z, src in zip(z_list, answer_src):\n",
        "            if src in coh_var_set:\n",
        "                coh_counts[z] += 1\n",
        "                counts_by_var[src][z] += 1\n",
        "        if not coh_counts:\n",
        "            continue\n",
        "\n",
        "        # ---- majority vote z* over answers from selected variants ----\n",
        "        # pick max count; tie-break by smaller numeric cluster id for determinism\n",
        "        z_star = max(coh_counts.items(), key=lambda kv: (kv[1], -_z_numeric(kv[0])))[0]\n",
        "\n",
        "        # ---- v0 per variant x_j (only variants that produced at least one answer) ----\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j, cnt in counts_by_var.items():\n",
        "            tot = float(sum(cnt.values()))\n",
        "            v0_by_var[j] = {z: c / tot for z, c in cnt.items()}\n",
        "\n",
        "        # ---- build WSFT examples ----\n",
        "        per_var_k = defaultdict(int)  # sample_idx within that x_j\n",
        "\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            if src not in coh_var_set:\n",
        "                continue\n",
        "            if src not in v0_by_var:\n",
        "                continue\n",
        "\n",
        "            # eps only guards the division; z was counted for src above, so v0 > 0 here.\n",
        "            v0_z = max(v0_by_var[src].get(z, 0.0), eps)\n",
        "            vcoh_z = 1.0 if z == z_star else 0.0  # majority vote distribution\n",
        "\n",
        "            w = (1.0 - alpha) + alpha * (vcoh_z / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            wsft_examples.append(\n",
        "                WSFTExample(\n",
        "                    x=questions[src],\n",
        "                    y=y,\n",
        "                    w=float(w),\n",
        "                    train_idx=train_idx,\n",
        "                    sample_idx=int(per_var_k[src]),\n",
        "                )\n",
        "            )\n",
        "            per_var_k[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4P4XzkJNf6oG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4P4XzkJNf6oG",
        "outputId": "59e001c4-6903-49ff-80dd-9c05dc3c958b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[wsft] Built 204800 WSFT examples (alpha=0.5, use_variants=(0, 1))\n"
          ]
        }
      ],
      "source": [
        "# Mixing coefficient of the WSFT weight w = (1-alpha) + alpha * v_coh / v0.\n",
        "alpha = 0.5\n",
        "# Majority-vote v_coh, built with the default use_variants=(0, 1).\n",
        "wsft_examples = make_wsft_examples_from_coh_records_majority(coh_records, alpha=alpha)  # uses x0+x1 only\n",
        "# Alternative (empirical v_coh instead of majority vote), pooling all paraphrases in v_coh:\n",
        "# wsft_examples_all = make_wsft_examples_from_coh_records(coh_records, alpha=alpha, use_variants=None)\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Q3RRjK20mNjF",
        "u7T3ljRumNjG",
        "Ul00h3tomNjH",
        "ChcYKgQMmNjI",
        "SD17xtwHtI6d"
      ],
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "32b63310775b4f598fd9aae36d7177d8": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3345258ab7fd4dc085c81ca3eacb35dd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "346936dae8d0460b88f574168fea2209": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "39b9ffa02eb84f85a76156cbc1c8a88b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_346936dae8d0460b88f574168fea2209",
            "placeholder": "​",
            "style": "IPY_MODEL_a7531cf6b513470da02333b1420e5c9c",
            "value": "Paraphrase agreement qwen3b: 100%"
          }
        },
        "3a3884b6b2c045c2ae33bba4c0c462b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3345258ab7fd4dc085c81ca3eacb35dd",
            "placeholder": "​",
            "style": "IPY_MODEL_441473c91daf4aeea3a91dc460512b2a",
            "value": " 7473/7473 [2:12:43&lt;00:00,  1.12s/it, classes=7473, AGREE%=67.14, greedy=1]"
          }
        },
        "441473c91daf4aeea3a91dc460512b2a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51096091a9594b3cbba9e1f06c439f65": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "55a6d733b99c453eae69dee18d012599": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8221da52e41e4c5e8bc1f58179bc64a5",
            "placeholder": "​",
            "style": "IPY_MODEL_c699c242e1ba4895850f422391292ed2",
            "value": "Evaluating llama greedy baseline:   0%"
          }
        },
        "57c61c01579a49659201a22d55ed73fa": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "59f2130e9c564172b17f10efa43f72d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7f8b51a3fe4343f39fa6ddc435410a3a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_da8588237dec4ae9877dc1ebb1236e48",
            "max": 7473,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_59f2130e9c564172b17f10efa43f72d6",
            "value": 7473
          }
        },
        "8136b4093a9b4d4eaf23bfb192fc9c90": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "8221da52e41e4c5e8bc1f58179bc64a5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "882ce08f63c24d5b9c537defb111268e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_32b63310775b4f598fd9aae36d7177d8",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_e8b12a9c33ac4411b00e002e4a1c37ad",
            "value": 0
          }
        },
        "8a3672d2b9bb41af88d8cfc7a7333b7a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a7531cf6b513470da02333b1420e5c9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b6c0617396d44837ac2bc00893b812d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_39b9ffa02eb84f85a76156cbc1c8a88b",
              "IPY_MODEL_7f8b51a3fe4343f39fa6ddc435410a3a",
              "IPY_MODEL_3a3884b6b2c045c2ae33bba4c0c462b4"
            ],
            "layout": "IPY_MODEL_51096091a9594b3cbba9e1f06c439f65"
          }
        },
        "c699c242e1ba4895850f422391292ed2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "da8588237dec4ae9877dc1ebb1236e48": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dcdbcdeb464c4196857ff0f15736d8a3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_55a6d733b99c453eae69dee18d012599",
              "IPY_MODEL_882ce08f63c24d5b9c537defb111268e",
              "IPY_MODEL_ee29db4bfc0d42b2967ca9e672b9d3b7"
            ],
            "layout": "IPY_MODEL_8136b4093a9b4d4eaf23bfb192fc9c90"
          }
        },
        "e8b12a9c33ac4411b00e002e4a1c37ad": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ee29db4bfc0d42b2967ca9e672b9d3b7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_57c61c01579a49659201a22d55ed73fa",
            "placeholder": "​",
            "style": "IPY_MODEL_8a3672d2b9bb41af88d8cfc7a7333b7a",
            "value": " 0/21 [00:00&lt;?, ?it/s]"
          }
        },
        "64c565858ba245259d61e8cc7db85336": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b1345bf5ab304f66b1decc6b75723c49",
              "IPY_MODEL_8aaa4aba009447b6bf7aacc0217320e7",
              "IPY_MODEL_38576a9201d5449887c56eefc9e40782"
            ],
            "layout": "IPY_MODEL_38ed94ffb06c4471ad562918944cc6d9"
          }
        },
        "b1345bf5ab304f66b1decc6b75723c49": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0805508fb0924ad5b22caaf40cde4a27",
            "placeholder": "​",
            "style": "IPY_MODEL_829bd5bc714c49b1bb4727091aafef98",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "8aaa4aba009447b6bf7aacc0217320e7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e081f3fa40b140928d72e293b5f183fc",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_47cec929d24f43ba9964903f6dc60c71",
            "value": 21
          }
        },
        "38576a9201d5449887c56eefc9e40782": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_da9e73ebd00a4159a7df6e1269b6e1ea",
            "placeholder": "​",
            "style": "IPY_MODEL_2609916bcc1c49aa8d18dbe31d3ff3f5",
            "value": " 21/21 [22:15&lt;00:00, 64.08s/it, N=1319, ACC%=69.14, greedy=1]"
          }
        },
        "38ed94ffb06c4471ad562918944cc6d9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0805508fb0924ad5b22caaf40cde4a27": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "829bd5bc714c49b1bb4727091aafef98": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e081f3fa40b140928d72e293b5f183fc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "47cec929d24f43ba9964903f6dc60c71": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "da9e73ebd00a4159a7df6e1269b6e1ea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2609916bcc1c49aa8d18dbe31d3ff3f5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1e3aa372aec74831943f083c02a7a14d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7bee36f381ba410fbccc2de5e9528db5",
              "IPY_MODEL_8e51a0ab5fe845469102dc0b9a10c4fa",
              "IPY_MODEL_23e8803e888a4b549b7d075c7001aa39"
            ],
            "layout": "IPY_MODEL_74d92c73431d4e8884319c1968105011"
          }
        },
        "7bee36f381ba410fbccc2de5e9528db5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c7b261a1890f492cab7542dcfdd9b59d",
            "placeholder": "​",
            "style": "IPY_MODEL_8e4c530251a4417c9efd6704391e8ac4",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "8e51a0ab5fe845469102dc0b9a10c4fa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_40a4643e69124288b23392eac5c14146",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_08a4431774f24c75b8d444644ab51bbc",
            "value": 21
          }
        },
        "23e8803e888a4b549b7d075c7001aa39": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7f0c70ff207a45af8d5e751d2ba6bf9c",
            "placeholder": "​",
            "style": "IPY_MODEL_9175b0f83eee429d8c2b642de2a70e5e",
            "value": " 21/21 [22:02&lt;00:00, 70.87s/it, N=1319, ACC%=70.89, greedy=1]"
          }
        },
        "74d92c73431d4e8884319c1968105011": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "c7b261a1890f492cab7542dcfdd9b59d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8e4c530251a4417c9efd6704391e8ac4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "40a4643e69124288b23392eac5c14146": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "08a4431774f24c75b8d444644ab51bbc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7f0c70ff207a45af8d5e751d2ba6bf9c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9175b0f83eee429d8c2b642de2a70e5e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f5aa689e09ed440fb9f56c179bb265c9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_34e172988b824c1582cfd819e046c826",
              "IPY_MODEL_e50ee262aedb465887063371888af73d",
              "IPY_MODEL_83ecf75f478d4163b4137bf522cfc795"
            ],
            "layout": "IPY_MODEL_54d2f6ff9c854e628af3f22e47cc9e44"
          }
        },
        "34e172988b824c1582cfd819e046c826": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e17fd94fc6b7497fba6034b12ee10743",
            "placeholder": "​",
            "style": "IPY_MODEL_7e2160fe500846438c02356d7c61305f",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "e50ee262aedb465887063371888af73d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cd1f0ef537164384a1c203978fb919cb",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2903ef393db64f0f94723d15854423a3",
            "value": 21
          }
        },
        "83ecf75f478d4163b4137bf522cfc795": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_712bf82474a24741b286830d8c5f09ff",
            "placeholder": "​",
            "style": "IPY_MODEL_7583ecda0d00433e919347428773928b",
            "value": " 21/21 [21:09&lt;00:00, 50.44s/it, N=1319, ACC%=71.04, greedy=1]"
          }
        },
        "54d2f6ff9c854e628af3f22e47cc9e44": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "e17fd94fc6b7497fba6034b12ee10743": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7e2160fe500846438c02356d7c61305f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cd1f0ef537164384a1c203978fb919cb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2903ef393db64f0f94723d15854423a3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "712bf82474a24741b286830d8c5f09ff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7583ecda0d00433e919347428773928b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "18ce7a5dbdb44472b3b84d81a13bdf7e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4ca8064308b143e28cd73bd304b9ed83",
              "IPY_MODEL_4575c0e40c0a4d9fbb9f27ce93520d72",
              "IPY_MODEL_5b058c0f308e4f67b638c431716bd2ed"
            ],
            "layout": "IPY_MODEL_5be74c54e1cd45ef8032e1ad2cfcadf4"
          }
        },
        "4ca8064308b143e28cd73bd304b9ed83": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c700dee1151c4e419670fcb687230099",
            "placeholder": "​",
            "style": "IPY_MODEL_254dfa7e95d1438eb3d638db4a531553",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "4575c0e40c0a4d9fbb9f27ce93520d72": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c4bd00cd657d45c1b0550d0a882e4380",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_67e18328058e446a9d6993a65ee9c419",
            "value": 21
          }
        },
        "5b058c0f308e4f67b638c431716bd2ed": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e0ea0f451a8a41b1a021b608977acbf1",
            "placeholder": "​",
            "style": "IPY_MODEL_a4f8662c51c745f68b1ff72aa76f620e",
            "value": " 21/21 [21:41&lt;00:00, 52.67s/it, N=1319, ACC%=71.04, greedy=1]"
          }
        },
        "5be74c54e1cd45ef8032e1ad2cfcadf4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "c700dee1151c4e419670fcb687230099": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "254dfa7e95d1438eb3d638db4a531553": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c4bd00cd657d45c1b0550d0a882e4380": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "67e18328058e446a9d6993a65ee9c419": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e0ea0f451a8a41b1a021b608977acbf1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a4f8662c51c745f68b1ff72aa76f620e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9acc7381d5b1426797cb66a9112a6911": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e90aaeabd9dd403c9337936ee63fdb25",
              "IPY_MODEL_c484432bc7eb4735b6e573b8f3edc87d",
              "IPY_MODEL_6c08e1c01a5b41c9aac3a1dcf746561c"
            ],
            "layout": "IPY_MODEL_87a429738ad54504a5e9d85f3b62d274"
          }
        },
        "e90aaeabd9dd403c9337936ee63fdb25": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a68dd47670994fc8a73ba87df3002df6",
            "placeholder": "​",
            "style": "IPY_MODEL_43aefa5dffd4481ca6452f9a191c7104",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "c484432bc7eb4735b6e573b8f3edc87d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_991dd007c09a422191c56fde75f43e55",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6dc3745efacd4785ab7575f5a26bab4c",
            "value": 21
          }
        },
        "6c08e1c01a5b41c9aac3a1dcf746561c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ad350ba861c5477cb7d62b889e5b1aaa",
            "placeholder": "​",
            "style": "IPY_MODEL_03deaf4188004bc78a45b4841b6be913",
            "value": " 21/21 [22:19&lt;00:00, 62.33s/it, N=1319, ACC%=71.72, greedy=1]"
          }
        },
        "87a429738ad54504a5e9d85f3b62d274": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "a68dd47670994fc8a73ba87df3002df6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "43aefa5dffd4481ca6452f9a191c7104": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "991dd007c09a422191c56fde75f43e55": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6dc3745efacd4785ab7575f5a26bab4c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ad350ba861c5477cb7d62b889e5b1aaa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "03deaf4188004bc78a45b4841b6be913": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ca20be70e6b34871901f6bedb7686672": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_6efe992a340b41a493f3dc7d14b01dc0",
              "IPY_MODEL_ee0e3b45bf3248dfa9f73971420ec62b",
              "IPY_MODEL_6ac4140030014c768e8472e1a9a6f797"
            ],
            "layout": "IPY_MODEL_ec0fb1046ff340ffa9a050e44db46d80"
          }
        },
        "6efe992a340b41a493f3dc7d14b01dc0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e201cd53f07f4b0199e174765eec3188",
            "placeholder": "​",
            "style": "IPY_MODEL_01463617e99b4ccd8c701adb2f28e4a4",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "ee0e3b45bf3248dfa9f73971420ec62b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_60e15841baab47e8a13245d0e1987861",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f8e1ff19275843bfb873010a36b843b1",
            "value": 21
          }
        },
        "6ac4140030014c768e8472e1a9a6f797": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2d754b238b1343c4b42da6949a5dbff8",
            "placeholder": "​",
            "style": "IPY_MODEL_bb133ad736554433b86d272da51db4bc",
            "value": " 21/21 [23:29&lt;00:00, 73.38s/it, N=1319, ACC%=72.02, greedy=1]"
          }
        },
        "ec0fb1046ff340ffa9a050e44db46d80": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "e201cd53f07f4b0199e174765eec3188": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "01463617e99b4ccd8c701adb2f28e4a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "60e15841baab47e8a13245d0e1987861": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f8e1ff19275843bfb873010a36b843b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "2d754b238b1343c4b42da6949a5dbff8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bb133ad736554433b86d272da51db4bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c25fd1b2a7614cd8a257fa8b1ad2dfbc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_af578cd9cf5f44d9b8d921978d087687",
              "IPY_MODEL_d18b9c5932384d059eeab9570c967a8e",
              "IPY_MODEL_ae80b229d8cc47508613bac414c7b156"
            ],
            "layout": "IPY_MODEL_e7eb013ea79f49f8990627838cecc7c0"
          }
        },
        "af578cd9cf5f44d9b8d921978d087687": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_707c688d9b3d4ce3b143be6e0cc6186f",
            "placeholder": "​",
            "style": "IPY_MODEL_95610e97f00747169d18a2d6e10d2c98",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "d18b9c5932384d059eeab9570c967a8e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0939ae8111eb479ba46b5d1d71dd2330",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_e5582fe402584486885e6d1c2e64b4b9",
            "value": 21
          }
        },
        "ae80b229d8cc47508613bac414c7b156": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e4be724d889f48be9b9057d0631d6b62",
            "placeholder": "​",
            "style": "IPY_MODEL_aef031448deb4717ac103c393099468e",
            "value": " 21/21 [22:35&lt;00:00, 54.32s/it, N=1319, ACC%=72.48, greedy=1]"
          }
        },
        "e7eb013ea79f49f8990627838cecc7c0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "707c688d9b3d4ce3b143be6e0cc6186f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "95610e97f00747169d18a2d6e10d2c98": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0939ae8111eb479ba46b5d1d71dd2330": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e5582fe402584486885e6d1c2e64b4b9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e4be724d889f48be9b9057d0631d6b62": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "aef031448deb4717ac103c393099468e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "206d7699fc774334b83ad60b811b0b4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_62dce4f6682b47f296485b8b3a9f1545",
              "IPY_MODEL_2ad1efe489834087a1006d23af12d145",
              "IPY_MODEL_5b7f08fd54d54974952cba793f4d8fba"
            ],
            "layout": "IPY_MODEL_37f3d66df88d4ad9bc18290123d1bbe9"
          }
        },
        "62dce4f6682b47f296485b8b3a9f1545": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_024c5af233af4a86923d8042cec8d2f5",
            "placeholder": "​",
            "style": "IPY_MODEL_761264599bdf466ea7389e3a4610925a",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "2ad1efe489834087a1006d23af12d145": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d04a5fca29214b56b65160119dc50bd1",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1715c5e4363941849dbb18acfe3ec42b",
            "value": 21
          }
        },
        "5b7f08fd54d54974952cba793f4d8fba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_eff5a0f4044e496495f304c15e518b75",
            "placeholder": "​",
            "style": "IPY_MODEL_ddd0c46883884014b665bee2518ca526",
            "value": " 21/21 [22:38&lt;00:00, 62.87s/it, N=1319, ACC%=72.93, greedy=1]"
          }
        },
        "37f3d66df88d4ad9bc18290123d1bbe9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "024c5af233af4a86923d8042cec8d2f5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "761264599bdf466ea7389e3a4610925a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d04a5fca29214b56b65160119dc50bd1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1715c5e4363941849dbb18acfe3ec42b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "eff5a0f4044e496495f304c15e518b75": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ddd0c46883884014b665bee2518ca526": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d5a909a3d64a4a3ba4e48a8ce95a1464": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9fb1bc2a522e4638abd7bb91f216d73e",
              "IPY_MODEL_f0f6e68424fe485caf3b9920c44d7db2",
              "IPY_MODEL_19f850405ba1467881d8480c5f1939bc"
            ],
            "layout": "IPY_MODEL_231e3eeca3504f91a75816d52b32856d"
          }
        },
        "9fb1bc2a522e4638abd7bb91f216d73e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_48a543e859f541a79e83d7261d903d25",
            "placeholder": "​",
            "style": "IPY_MODEL_3f5f6fce1b2e4230b4bc8fac6119836e",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "f0f6e68424fe485caf3b9920c44d7db2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7904d96d5c594e8487aa5f4f6a337a18",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ecfb15a27d594bedb493c33ab4b0b169",
            "value": 21
          }
        },
        "19f850405ba1467881d8480c5f1939bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1dfb7425bb8442a7b5003c2bb65a2574",
            "placeholder": "​",
            "style": "IPY_MODEL_75360684e1394731aed53f617034c15f",
            "value": " 21/21 [24:14&lt;00:00, 60.94s/it, N=1319, ACC%=72.63, greedy=1]"
          }
        },
        "231e3eeca3504f91a75816d52b32856d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "48a543e859f541a79e83d7261d903d25": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3f5f6fce1b2e4230b4bc8fac6119836e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7904d96d5c594e8487aa5f4f6a337a18": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ecfb15a27d594bedb493c33ab4b0b169": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "1dfb7425bb8442a7b5003c2bb65a2574": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "75360684e1394731aed53f617034c15f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7a9b23a7f7874922be0e31e5055a00b9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7ebf37faa9af4c4890cf16a7247d0e8f",
              "IPY_MODEL_c39b6023a96f49789b8eb81331fee15e",
              "IPY_MODEL_cc3d7bbfc91b4b84b853508b24fd0aaf"
            ],
            "layout": "IPY_MODEL_00e1ce49ec70470f88b97d209b704656"
          }
        },
        "7ebf37faa9af4c4890cf16a7247d0e8f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8b56adcad9ea446d96912ca6210b0d1b",
            "placeholder": "​",
            "style": "IPY_MODEL_2d1fff1b0d8b4bb29a2effda2ee30723",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "c39b6023a96f49789b8eb81331fee15e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_300545c787334b50b5cb1cc0fe14e92f",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5d74b673faf1401cbf517013ccfd4e17",
            "value": 21
          }
        },
        "cc3d7bbfc91b4b84b853508b24fd0aaf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c9687ec0add942f095fb63d00547513c",
            "placeholder": "​",
            "style": "IPY_MODEL_965e99c1182e4bb3810e32603d688d31",
            "value": " 21/21 [22:19&lt;00:00, 64.03s/it, N=1319, ACC%=72.71, greedy=1]"
          }
        },
        "00e1ce49ec70470f88b97d209b704656": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "8b56adcad9ea446d96912ca6210b0d1b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2d1fff1b0d8b4bb29a2effda2ee30723": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "300545c787334b50b5cb1cc0fe14e92f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5d74b673faf1401cbf517013ccfd4e17": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c9687ec0add942f095fb63d00547513c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "965e99c1182e4bb3810e32603d688d31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5c2a8e737abb41cd861753909ecc4f1b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_2a7970cb9ff44c3b8f9e370223327d49",
              "IPY_MODEL_feca0d3b91d241c0a81824c57a28d6df",
              "IPY_MODEL_ff754a2a323741b585a73b55b63cd219"
            ],
            "layout": "IPY_MODEL_9e1fae8fcf3545fcb8dfcf9fd14d5890"
          }
        },
        "2a7970cb9ff44c3b8f9e370223327d49": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e2f37aea6afd4ede81a240588f1d88dd",
            "placeholder": "​",
            "style": "IPY_MODEL_4068211b5c1b48f59253923e3237ba25",
            "value": "tokenizer_config.json: 100%"
          }
        },
        "feca0d3b91d241c0a81824c57a28d6df": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b118a18d02b24044853c086278044d88",
            "max": 54528,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4410990ea8f346c39703ff02ca7996af",
            "value": 54528
          }
        },
        "ff754a2a323741b585a73b55b63cd219": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_65597fbb2ccc43d282a1f8366e94f72f",
            "placeholder": "​",
            "style": "IPY_MODEL_d02a4ead5c7b434f913ec8e70550187d",
            "value": " 54.5k/54.5k [00:00&lt;00:00, 3.80MB/s]"
          }
        },
        "9e1fae8fcf3545fcb8dfcf9fd14d5890": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e2f37aea6afd4ede81a240588f1d88dd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4068211b5c1b48f59253923e3237ba25": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b118a18d02b24044853c086278044d88": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4410990ea8f346c39703ff02ca7996af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "65597fbb2ccc43d282a1f8366e94f72f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d02a4ead5c7b434f913ec8e70550187d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "66263c8eadfd474c8a67a11b30055ddb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_22eed0d1eef643c4b9f5d5f02a3a5bd9",
              "IPY_MODEL_d14709117ec147ea9547b72fd3bc3dbb",
              "IPY_MODEL_feacc7ad63334712a7f7e08b8d45624c"
            ],
            "layout": "IPY_MODEL_9a38e5b3eb314071bed4f955632605ec"
          }
        },
        "22eed0d1eef643c4b9f5d5f02a3a5bd9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e879a2550d514432bccd7f8f76169a56",
            "placeholder": "​",
            "style": "IPY_MODEL_b226dc2cdc844e37852fb2c98950812e",
            "value": "tokenizer.json: 100%"
          }
        },
        "d14709117ec147ea9547b72fd3bc3dbb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f97c79e9765847a49b5b5c987ed7e835",
            "max": 9085657,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_848ec43f0e5148bbb452f89d29870b30",
            "value": 9085657
          }
        },
        "feacc7ad63334712a7f7e08b8d45624c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_af63a7c7c8dc4a68a6436e9b02c7b428",
            "placeholder": "​",
            "style": "IPY_MODEL_5e2d2cc4bcc6407ea39a4fb70d10155c",
            "value": " 9.09M/9.09M [00:00&lt;00:00, 14.3MB/s]"
          }
        },
        "9a38e5b3eb314071bed4f955632605ec": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e879a2550d514432bccd7f8f76169a56": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b226dc2cdc844e37852fb2c98950812e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f97c79e9765847a49b5b5c987ed7e835": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "848ec43f0e5148bbb452f89d29870b30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "af63a7c7c8dc4a68a6436e9b02c7b428": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5e2d2cc4bcc6407ea39a4fb70d10155c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2f32d607fbcf4eb39ea3962a8d9bc520": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_92f7ecd8d8fa4f429a78e1499cc7f400",
              "IPY_MODEL_ac2c082b53ac4e1fbb5d78b9ef80b81b",
              "IPY_MODEL_52a3413726ce4ae79ec9bad9b972f01d"
            ],
            "layout": "IPY_MODEL_2d01798b5c3d49a5b056f2ef6bfb6c25"
          }
        },
        "92f7ecd8d8fa4f429a78e1499cc7f400": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9ca9bdc58f3a42e395905139385b9e9b",
            "placeholder": "​",
            "style": "IPY_MODEL_7134e35138e64efea8abff67100ae40e",
            "value": "special_tokens_map.json: 100%"
          }
        },
        "ac2c082b53ac4e1fbb5d78b9ef80b81b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8ffee91d96b34a13bf448178a109ce79",
            "max": 296,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ea7672d0767f480696aa09f075d4772d",
            "value": 296
          }
        },
        "52a3413726ce4ae79ec9bad9b972f01d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_164133cae7374a258b3d217b115bc379",
            "placeholder": "​",
            "style": "IPY_MODEL_6b0436f2c9ac4e78b8b456bc4c73afae",
            "value": " 296/296 [00:00&lt;00:00, 20.6kB/s]"
          }
        },
        "2d01798b5c3d49a5b056f2ef6bfb6c25": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9ca9bdc58f3a42e395905139385b9e9b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7134e35138e64efea8abff67100ae40e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8ffee91d96b34a13bf448178a109ce79": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea7672d0767f480696aa09f075d4772d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "164133cae7374a258b3d217b115bc379": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6b0436f2c9ac4e78b8b456bc4c73afae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "345807be67e54d2ebd9bdadebf8992b2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a79a2eb11be6486a9e42993f1e8ba2d4",
              "IPY_MODEL_279630ffc6a544e196f8d0fd4e5ead06",
              "IPY_MODEL_b1aed9b47c5842b6afa5d91461728259"
            ],
            "layout": "IPY_MODEL_9f014d810d5e4be892ed0fff1cd9183c"
          }
        },
        "a79a2eb11be6486a9e42993f1e8ba2d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d58e83c0ede843a6bfca39fb18753489",
            "placeholder": "​",
            "style": "IPY_MODEL_e01eb59e7a9f4c059cd08ba8d2c8957a",
            "value": "config.json: 100%"
          }
        },
        "279630ffc6a544e196f8d0fd4e5ead06": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c3e13b43249e49b9b270a1f9aecb15d8",
            "max": 878,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f94d24a0eee24f8bb99afd296a6fb884",
            "value": 878
          }
        },
        "b1aed9b47c5842b6afa5d91461728259": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_248768b4966d45318f0feaf8da35411b",
            "placeholder": "​",
            "style": "IPY_MODEL_95d87cd4cde34ba4a449a98eb27cf169",
            "value": " 878/878 [00:00&lt;00:00, 70.4kB/s]"
          }
        },
        "9f014d810d5e4be892ed0fff1cd9183c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d58e83c0ede843a6bfca39fb18753489": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e01eb59e7a9f4c059cd08ba8d2c8957a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c3e13b43249e49b9b270a1f9aecb15d8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f94d24a0eee24f8bb99afd296a6fb884": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "248768b4966d45318f0feaf8da35411b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "95d87cd4cde34ba4a449a98eb27cf169": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "253871439f644e97815714dcf14a55ae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d4a34bbaf6264706a6a79953649fd31e",
              "IPY_MODEL_e48f6ef7549d4cc9bc3ea6acff8ba375",
              "IPY_MODEL_bca11efb3a754b39b055445f470ae89f"
            ],
            "layout": "IPY_MODEL_686d32adbffc4600b4ea5e9ed755f620"
          }
        },
        "d4a34bbaf6264706a6a79953649fd31e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d2dfb8df5c27478eaa1d2c674a62faa5",
            "placeholder": "​",
            "style": "IPY_MODEL_1d5500bb6f31454293e8d922a1ca1101",
            "value": "model.safetensors.index.json: 100%"
          }
        },
        "e48f6ef7549d4cc9bc3ea6acff8ba375": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d18c1c5e9c884514bb794373d78caa2e",
            "max": 20919,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b828683ae02c44aaab389538a41e67ea",
            "value": 20919
          }
        },
        "bca11efb3a754b39b055445f470ae89f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3cb48f1eabd543afb0cefd44b004c348",
            "placeholder": "​",
            "style": "IPY_MODEL_5419d0335507421a9a7a8ae28697a679",
            "value": " 20.9k/20.9k [00:00&lt;00:00, 2.33MB/s]"
          }
        },
        "686d32adbffc4600b4ea5e9ed755f620": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d2dfb8df5c27478eaa1d2c674a62faa5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1d5500bb6f31454293e8d922a1ca1101": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d18c1c5e9c884514bb794373d78caa2e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b828683ae02c44aaab389538a41e67ea": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "3cb48f1eabd543afb0cefd44b004c348": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5419d0335507421a9a7a8ae28697a679": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6885f9a52efd4f9b9e65d43b062f3103": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_aceb980633344130a3086cbd0b45fa9b",
              "IPY_MODEL_ae4d8432117d4928bdc3f16392749120",
              "IPY_MODEL_4ef3316dded5415589b49e2f2cf5988a"
            ],
            "layout": "IPY_MODEL_5649585455c74d7a88664b6c6dfbe575"
          }
        },
        "aceb980633344130a3086cbd0b45fa9b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_04f3bcff5fd84c988a1c67154e3d3460",
            "placeholder": "​",
            "style": "IPY_MODEL_53f36048a2764289a49ea26692000911",
            "value": "Fetching 2 files: 100%"
          }
        },
        "ae4d8432117d4928bdc3f16392749120": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_637ca7d0597942a0ab719e2920334174",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_8c19165f390e44f284e696ad9c5eecca",
            "value": 2
          }
        },
        "4ef3316dded5415589b49e2f2cf5988a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_eb5c01168ef64d85873e5726e45be03b",
            "placeholder": "​",
            "style": "IPY_MODEL_34a4977199cb4d05ab2fbc8c3076100d",
            "value": " 2/2 [00:13&lt;00:00, 13.63s/it]"
          }
        },
        "5649585455c74d7a88664b6c6dfbe575": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "04f3bcff5fd84c988a1c67154e3d3460": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "53f36048a2764289a49ea26692000911": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "637ca7d0597942a0ab719e2920334174": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8c19165f390e44f284e696ad9c5eecca": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "eb5c01168ef64d85873e5726e45be03b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "34a4977199cb4d05ab2fbc8c3076100d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d1fd251b122f4b7eb9d0887adb94d8c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7a1088b57ed4412283e917b2312f3799",
              "IPY_MODEL_f10ef23bdd034cbea4c5151f16977351",
              "IPY_MODEL_2405965d92fd41dc96360500b2aa72dd"
            ],
            "layout": "IPY_MODEL_20e493de23184f508bfed4832b3ae345"
          }
        },
        "7a1088b57ed4412283e917b2312f3799": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4e98e8d97cf34aaa9142e96c70bf3bfa",
            "placeholder": "​",
            "style": "IPY_MODEL_87a4daada34f4cc493b83fa9cba200f0",
            "value": "model-00001-of-00002.safetensors: 100%"
          }
        },
        "f10ef23bdd034cbea4c5151f16977351": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9a0b62601c9b4658a68d67f365fc21b5",
            "max": 4965799096,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6bddf26b59eb4a948457b1cc7d5a4044",
            "value": 4965799096
          }
        },
        "2405965d92fd41dc96360500b2aa72dd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ae6c402cf7c1407bad8c8d0e31572bf5",
            "placeholder": "​",
            "style": "IPY_MODEL_bb359cf0ce3a4981a53cff4535022f69",
            "value": " 4.97G/4.97G [00:13&lt;00:00, 1.19GB/s]"
          }
        },
        "20e493de23184f508bfed4832b3ae345": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4e98e8d97cf34aaa9142e96c70bf3bfa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "87a4daada34f4cc493b83fa9cba200f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9a0b62601c9b4658a68d67f365fc21b5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6bddf26b59eb4a948457b1cc7d5a4044": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ae6c402cf7c1407bad8c8d0e31572bf5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bb359cf0ce3a4981a53cff4535022f69": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1d2ca94bc22e499cba0ef83b22541e6d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e691a50ec860483f964bd98319327d1b",
              "IPY_MODEL_5abe002f92ed43b38fccb3c02eeb1d3b",
              "IPY_MODEL_b1b744c870a54f26ba7161beda026379"
            ],
            "layout": "IPY_MODEL_4e0a7ad1a6454b57b1aaffa91bd7fe50"
          }
        },
        "e691a50ec860483f964bd98319327d1b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8b9c5f0a2b1d4a539de9c7bafac96146",
            "placeholder": "​",
            "style": "IPY_MODEL_f79777359f684bddbdf5172dd200d737",
            "value": "model-00002-of-00002.safetensors: 100%"
          }
        },
        "5abe002f92ed43b38fccb3c02eeb1d3b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6dbf455b1bb441b1bf4eede9bec60111",
            "max": 1459729952,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_cbb8f7741cec45f9aad53b04913313fa",
            "value": 1459729952
          }
        },
        "b1b744c870a54f26ba7161beda026379": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5c159ca87657449aa0a49d5fcd056b2c",
            "placeholder": "​",
            "style": "IPY_MODEL_b5f4573a6b314e879b2e315f64a543dc",
            "value": " 1.46G/1.46G [00:05&lt;00:00, 414MB/s]"
          }
        },
        "4e0a7ad1a6454b57b1aaffa91bd7fe50": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8b9c5f0a2b1d4a539de9c7bafac96146": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f79777359f684bddbdf5172dd200d737": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6dbf455b1bb441b1bf4eede9bec60111": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cbb8f7741cec45f9aad53b04913313fa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5c159ca87657449aa0a49d5fcd056b2c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b5f4573a6b314e879b2e315f64a543dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7750c11e1e194ff0890a57e56e27e263": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_5a8b0227548d43bcaec232d96adf734e",
              "IPY_MODEL_1e1871367f7d4094af05db9d37f04607",
              "IPY_MODEL_195a11df03dd4d70ba54464f8976b073"
            ],
            "layout": "IPY_MODEL_dc410ad99a5a44cdb6b9762bd713afff"
          }
        },
        "5a8b0227548d43bcaec232d96adf734e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d3e708079ef64dc6b7b3f466f834b7be",
            "placeholder": "​",
            "style": "IPY_MODEL_786cdc0611a34b8c854f68b8b2e09e9c",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "1e1871367f7d4094af05db9d37f04607": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_be406211e41b44cf83e862e30e81ab75",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1c7ce4adb8da4386b26a63396da0a371",
            "value": 2
          }
        },
        "195a11df03dd4d70ba54464f8976b073": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0aee1ef79f444de2b9f339171e632de0",
            "placeholder": "​",
            "style": "IPY_MODEL_0039028878bb4068950bc57dbefbdf58",
            "value": " 2/2 [00:05&lt;00:00,  2.56s/it]"
          }
        },
        "dc410ad99a5a44cdb6b9762bd713afff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d3e708079ef64dc6b7b3f466f834b7be": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "786cdc0611a34b8c854f68b8b2e09e9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "be406211e41b44cf83e862e30e81ab75": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1c7ce4adb8da4386b26a63396da0a371": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0aee1ef79f444de2b9f339171e632de0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0039028878bb4068950bc57dbefbdf58": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8fa409bfc803409b8fd72b6139524acf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a05dfcf3316640c59b68fab1d32f44b7",
              "IPY_MODEL_288f07b9303a4673a3f47abf3a61c216",
              "IPY_MODEL_5c358655cd494ee8bce41d6cbe5b0634"
            ],
            "layout": "IPY_MODEL_d3b8ff18c54c43fd976997a87a90dd57"
          }
        },
        "a05dfcf3316640c59b68fab1d32f44b7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6ad1432aadfb42329a77a17954efeda5",
            "placeholder": "​",
            "style": "IPY_MODEL_2eed566299d74471a737b836dff92cf0",
            "value": "generation_config.json: 100%"
          }
        },
        "288f07b9303a4673a3f47abf3a61c216": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_53c53f1199644f7f9cd06ac35f96d83e",
            "max": 189,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_8c1ced63f1bc49ada2525d473fc9a29f",
            "value": 189
          }
        },
        "5c358655cd494ee8bce41d6cbe5b0634": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7829a146b0ff44f9ac880cd6af96107c",
            "placeholder": "​",
            "style": "IPY_MODEL_b62990369a7847daa3d23cf2ee3b6f53",
            "value": " 189/189 [00:00&lt;00:00, 21.1kB/s]"
          }
        },
        "d3b8ff18c54c43fd976997a87a90dd57": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6ad1432aadfb42329a77a17954efeda5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2eed566299d74471a737b836dff92cf0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "53c53f1199644f7f9cd06ac35f96d83e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8c1ced63f1bc49ada2525d473fc9a29f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7829a146b0ff44f9ac880cd6af96107c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b62990369a7847daa3d23cf2ee3b6f53": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a4518227f933454e83361482e7f5ecc1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0cd442287f6f4520b5b5200c2f9758e3",
              "IPY_MODEL_fe0c6587aee94e05a7bec607117b40b3",
              "IPY_MODEL_859155a87335429c90010ff0d313c6bc"
            ],
            "layout": "IPY_MODEL_858ab81b0e8742418817e79af380d4b3"
          }
        },
        "0cd442287f6f4520b5b5200c2f9758e3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_875250b80ba64b72a98c07e39a39a474",
            "placeholder": "​",
            "style": "IPY_MODEL_5d81c35f053a43e8bf9c2fb40ab139eb",
            "value": "README.md: "
          }
        },
        "fe0c6587aee94e05a7bec607117b40b3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9c51453c5d604047aec2aa5402aadf95",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5face42daebf44cbb6492e2657f034f0",
            "value": 1
          }
        },
        "859155a87335429c90010ff0d313c6bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6eca88023c744926a3be2f7a28d77eca",
            "placeholder": "​",
            "style": "IPY_MODEL_c5890ad41ae041c68ee0d150217e4245",
            "value": " 7.93k/? [00:00&lt;00:00, 561kB/s]"
          }
        },
        "858ab81b0e8742418817e79af380d4b3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "875250b80ba64b72a98c07e39a39a474": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5d81c35f053a43e8bf9c2fb40ab139eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9c51453c5d604047aec2aa5402aadf95": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "5face42daebf44cbb6492e2657f034f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6eca88023c744926a3be2f7a28d77eca": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c5890ad41ae041c68ee0d150217e4245": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6ca09a109c37482fad60132dacf7f627": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_1f53a10fd5244fa1b7ee5cf085b1c0db",
              "IPY_MODEL_cb419fb7a1f84e598581226a5558c9a5",
              "IPY_MODEL_c96bd872ae1b4b0e93a94a975e3add57"
            ],
            "layout": "IPY_MODEL_76170c900810451381851319bd759780"
          }
        },
        "1f53a10fd5244fa1b7ee5cf085b1c0db": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a717000b640b4351aede226cc9e6bcbc",
            "placeholder": "​",
            "style": "IPY_MODEL_500b78f96bc74f1f95a22f526aad88e5",
            "value": "main/train-00000-of-00001.parquet: 100%"
          }
        },
        "cb419fb7a1f84e598581226a5558c9a5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a5ee3532f896434db2e7d7c7e00035fb",
            "max": 2306545,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_705462646414442bbb77892f1c3aba30",
            "value": 2306545
          }
        },
        "c96bd872ae1b4b0e93a94a975e3add57": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fd9ed38e06914c92a86023d7c7660304",
            "placeholder": "​",
            "style": "IPY_MODEL_e89fa9ee7c6f4e94b8732c79e32c32b3",
            "value": " 2.31M/2.31M [00:00&lt;00:00, 77.4kB/s]"
          }
        },
        "76170c900810451381851319bd759780": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a717000b640b4351aede226cc9e6bcbc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "500b78f96bc74f1f95a22f526aad88e5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a5ee3532f896434db2e7d7c7e00035fb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "705462646414442bbb77892f1c3aba30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "fd9ed38e06914c92a86023d7c7660304": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e89fa9ee7c6f4e94b8732c79e32c32b3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9a21e4b6dee74fb4830a0c1103b3656b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_80644f0914fa4f10b831463170a2a3d8",
              "IPY_MODEL_d1bcd0983a7f44298171ac6379bec20c",
              "IPY_MODEL_094fc444f90949d193cfb80ff5810374"
            ],
            "layout": "IPY_MODEL_54f317390144492a9b5991e5e6bec878"
          }
        },
        "80644f0914fa4f10b831463170a2a3d8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0084018313a44300bcadddb278a45cab",
            "placeholder": "​",
            "style": "IPY_MODEL_d0fc20b0afdf4abf919252ffb947c7a6",
            "value": "main/test-00000-of-00001.parquet: 100%"
          }
        },
        "d1bcd0983a7f44298171ac6379bec20c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a7d8dc630cac47909571bc3518f0af91",
            "max": 419088,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b1c05bc0cbec4aecb1c47daff4780786",
            "value": 419088
          }
        },
        "094fc444f90949d193cfb80ff5810374": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0fa15707dd78402e94311d212cd4055c",
            "placeholder": "​",
            "style": "IPY_MODEL_4eb99c2017ac46a88e68d0188997fdd6",
            "value": " 419k/419k [00:00&lt;00:00, 813kB/s]"
          }
        },
        "54f317390144492a9b5991e5e6bec878": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0084018313a44300bcadddb278a45cab": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d0fc20b0afdf4abf919252ffb947c7a6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a7d8dc630cac47909571bc3518f0af91": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b1c05bc0cbec4aecb1c47daff4780786": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0fa15707dd78402e94311d212cd4055c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4eb99c2017ac46a88e68d0188997fdd6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bf3c2eb9670640db8b482673a5828d2d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_bc7495bab4c449abb4491ffdfa0eb8ab",
              "IPY_MODEL_6d855ebb31314cc88a702c53ba3e9062",
              "IPY_MODEL_42c7e5f00cff4f39ac1c93aece0fd68c"
            ],
            "layout": "IPY_MODEL_029c3e9269ce4a218a0f40867c4d6a03"
          }
        },
        "bc7495bab4c449abb4491ffdfa0eb8ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1b8551e2775447cbb01e95dbc687a8cc",
            "placeholder": "​",
            "style": "IPY_MODEL_9ed581154c1244b2a2a26a974ca49631",
            "value": "Generating train split: 100%"
          }
        },
        "6d855ebb31314cc88a702c53ba3e9062": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_28c5646e07eb4bc99ae1c7f5626700fd",
            "max": 7473,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_02786c3cab8841f098e39e3ad6b8cc2b",
            "value": 7473
          }
        },
        "42c7e5f00cff4f39ac1c93aece0fd68c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8276467f67c745588905fcb23411ac0f",
            "placeholder": "​",
            "style": "IPY_MODEL_11bc6eb594db45109e2b4b31b901178f",
            "value": " 7473/7473 [00:00&lt;00:00, 264432.97 examples/s]"
          }
        },
        "029c3e9269ce4a218a0f40867c4d6a03": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1b8551e2775447cbb01e95dbc687a8cc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9ed581154c1244b2a2a26a974ca49631": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "28c5646e07eb4bc99ae1c7f5626700fd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "02786c3cab8841f098e39e3ad6b8cc2b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "8276467f67c745588905fcb23411ac0f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "11bc6eb594db45109e2b4b31b901178f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cf819f8b33384f9da580a5ba1b7984ae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_640fd24e01e54a3382480567db216dd8",
              "IPY_MODEL_646e9ed927ad44ca8d012482319f33b5",
              "IPY_MODEL_addca7f2ee384ceca6f112a599a645fe"
            ],
            "layout": "IPY_MODEL_9f15c60d32f14a62b18056a1ddab3f43"
          }
        },
        "640fd24e01e54a3382480567db216dd8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_204bdd9a7cbb44eeb96023a43c950203",
            "placeholder": "​",
            "style": "IPY_MODEL_87dcd3131094476aa0820406d2dcb45c",
            "value": "Generating test split: 100%"
          }
        },
        "646e9ed927ad44ca8d012482319f33b5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e26f0d7f97074887a0fd2eb73adb0cfc",
            "max": 1319,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2a9feac629564491905da452e00a627f",
            "value": 1319
          }
        },
        "addca7f2ee384ceca6f112a599a645fe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_24571193f2e74ff0af863b54ef390554",
            "placeholder": "​",
            "style": "IPY_MODEL_c3f69e194e9e4383b62396bc70162936",
            "value": " 1319/1319 [00:00&lt;00:00, 97576.36 examples/s]"
          }
        },
        "9f15c60d32f14a62b18056a1ddab3f43": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "204bdd9a7cbb44eeb96023a43c950203": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "87dcd3131094476aa0820406d2dcb45c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e26f0d7f97074887a0fd2eb73adb0cfc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2a9feac629564491905da452e00a627f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "24571193f2e74ff0af863b54ef390554": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c3f69e194e9e4383b62396bc70162936": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c36ef80f6a5844c0996be20b99174001": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_77173f4595d04b13835901a44459ca31",
              "IPY_MODEL_251b3affacbb4960ab8bf9ed32846eac",
              "IPY_MODEL_82b3538b27644ec1af7590511ff0e3df"
            ],
            "layout": "IPY_MODEL_b2157dee918f43df82051cc5d732a454"
          }
        },
        "77173f4595d04b13835901a44459ca31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ec3be9c493074bc59a92e54ea73ff993",
            "placeholder": "​",
            "style": "IPY_MODEL_2f296d2be6d245f0920fefd1fe1a6238",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "251b3affacbb4960ab8bf9ed32846eac": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_83c0fddf60994a3d87fead56bdfe0ba0",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_346ca0f9fb1d4b419cab0e2ba1613667",
            "value": 21
          }
        },
        "82b3538b27644ec1af7590511ff0e3df": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b9dac3e0f2334d519994ae4a48ab44a3",
            "placeholder": "​",
            "style": "IPY_MODEL_674f280a741241fe8d45f66c79854f66",
            "value": " 21/21 [20:47&lt;00:00, 58.35s/it, N=1319, ACC%=68.39, greedy=1]"
          }
        },
        "b2157dee918f43df82051cc5d732a454": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "ec3be9c493074bc59a92e54ea73ff993": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2f296d2be6d245f0920fefd1fe1a6238": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "83c0fddf60994a3d87fead56bdfe0ba0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "346ca0f9fb1d4b419cab0e2ba1613667": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "b9dac3e0f2334d519994ae4a48ab44a3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "674f280a741241fe8d45f66c79854f66": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a83653628c70449192d8cb57ee28cbe9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_917281a4a2c34057b0def2278151ea93",
              "IPY_MODEL_137b3f056d7f4e129e33ec050d48414f",
              "IPY_MODEL_2d0c265acd9a4d71819e4185ccb0b6b9"
            ],
            "layout": "IPY_MODEL_7af8dbf17624480b836be0e3cae58a6f"
          }
        },
        "917281a4a2c34057b0def2278151ea93": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0e35c0b8260748e3b9bf03a849bf1eea",
            "placeholder": "​",
            "style": "IPY_MODEL_b1be8d7f613a4c7eaf179132d2561bcf",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "137b3f056d7f4e129e33ec050d48414f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0d213a47024a4e2094f008d607c8106d",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_38faafa9b2a546ae9aeb23157c0ff021",
            "value": 21
          }
        },
        "2d0c265acd9a4d71819e4185ccb0b6b9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5ec946e4076c4bf8bec3da6caf68efab",
            "placeholder": "​",
            "style": "IPY_MODEL_9b49048bc4004edcaa12dbd2a54144ec",
            "value": " 21/21 [22:10&lt;00:00, 72.18s/it, N=1319, ACC%=69.14, greedy=1]"
          }
        },
        "7af8dbf17624480b836be0e3cae58a6f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0e35c0b8260748e3b9bf03a849bf1eea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b1be8d7f613a4c7eaf179132d2561bcf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0d213a47024a4e2094f008d607c8106d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "38faafa9b2a546ae9aeb23157c0ff021": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5ec946e4076c4bf8bec3da6caf68efab": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9b49048bc4004edcaa12dbd2a54144ec": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "04071bf126594d6da8321b8955a1c65b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_fd5c75ff480b40bb8e36b5fd361de605",
              "IPY_MODEL_927a6b39006048a6bb3a01ebdcccf2c5",
              "IPY_MODEL_2f1650a54c8b4dac83c40e8d2673ae55"
            ],
            "layout": "IPY_MODEL_895145a6853f4fce831e669ae9d2143a"
          }
        },
        "fd5c75ff480b40bb8e36b5fd361de605": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_58d3f262baf34b2e8a93080b3a84078a",
            "placeholder": "​",
            "style": "IPY_MODEL_8443464cc9ba4989b03885103110cf30",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "927a6b39006048a6bb3a01ebdcccf2c5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_755bcecd4dff430d93075086589764e9",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b402e3ccb16145ebbca9f84156ca44ab",
            "value": 21
          }
        },
        "2f1650a54c8b4dac83c40e8d2673ae55": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_39a9791d71974f04881824b5d8329c0c",
            "placeholder": "​",
            "style": "IPY_MODEL_5b88f6b250c74efda6b90af708b45c35",
            "value": " 21/21 [22:43&lt;00:00, 66.68s/it, N=1319, ACC%=69.37, greedy=1]"
          }
        },
        "895145a6853f4fce831e669ae9d2143a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "58d3f262baf34b2e8a93080b3a84078a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8443464cc9ba4989b03885103110cf30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "755bcecd4dff430d93075086589764e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b402e3ccb16145ebbca9f84156ca44ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "39a9791d71974f04881824b5d8329c0c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5b88f6b250c74efda6b90af708b45c35": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "41f75be0a84d4ee3b77fb035d606caaf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_caa8046611d84c989b19c49e9e5b3d08",
              "IPY_MODEL_b4aeacbf641f4592ae18ebff7b8f975b",
              "IPY_MODEL_a31fc33114a446f0be8be4f9a6471b10"
            ],
            "layout": "IPY_MODEL_5de6f0fe0efe4be6b37b037215c9eef3"
          }
        },
        "caa8046611d84c989b19c49e9e5b3d08": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0265a92c8dcb42d5a05f8ebc4979b477",
            "placeholder": "​",
            "style": "IPY_MODEL_c7018c73765c44a6907811dfaea104fd",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "b4aeacbf641f4592ae18ebff7b8f975b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_91dbf3c0ccfe48258efdb90d4ae1d394",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_61b22b9892fa40d3800bc83b170bcb00",
            "value": 21
          }
        },
        "a31fc33114a446f0be8be4f9a6471b10": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5843f563475e4fdd9443ff40b8271073",
            "placeholder": "​",
            "style": "IPY_MODEL_62398a6c238c49c99982728f2ab8ea01",
            "value": " 21/21 [21:29&lt;00:00, 54.48s/it, N=1319, ACC%=68.84, greedy=1]"
          }
        },
        "5de6f0fe0efe4be6b37b037215c9eef3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0265a92c8dcb42d5a05f8ebc4979b477": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c7018c73765c44a6907811dfaea104fd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "91dbf3c0ccfe48258efdb90d4ae1d394": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "61b22b9892fa40d3800bc83b170bcb00": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5843f563475e4fdd9443ff40b8271073": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "62398a6c238c49c99982728f2ab8ea01": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "53a250f5a9324300be750347202be76f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0a493f8239634e0db75665d2914cae72",
              "IPY_MODEL_96f8b20d52a34c209c494258d88cf3c8",
              "IPY_MODEL_b1a12eddc2f44ddfb35bde51e96fc199"
            ],
            "layout": "IPY_MODEL_d72d6bbd54af4ecbb9e13093fa21694d"
          }
        },
        "0a493f8239634e0db75665d2914cae72": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_51ef546b4a9a43e280e0cf2feb1508aa",
            "placeholder": "​",
            "style": "IPY_MODEL_665dbeb11ec3448f85852bde435c0da1",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "96f8b20d52a34c209c494258d88cf3c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_273402276286463092c324b101c59601",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_34cc57832ac24e0984ca5949813f0227",
            "value": 21
          }
        },
        "b1a12eddc2f44ddfb35bde51e96fc199": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e98a94ca7ea04254987dc8f055f88be6",
            "placeholder": "​",
            "style": "IPY_MODEL_449196000f8d4871b6ba705e66d4f290",
            "value": " 21/21 [20:52&lt;00:00, 47.79s/it, N=1319, ACC%=69.07, greedy=1]"
          }
        },
        "d72d6bbd54af4ecbb9e13093fa21694d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "51ef546b4a9a43e280e0cf2feb1508aa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "665dbeb11ec3448f85852bde435c0da1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "273402276286463092c324b101c59601": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "34cc57832ac24e0984ca5949813f0227": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e98a94ca7ea04254987dc8f055f88be6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "449196000f8d4871b6ba705e66d4f290": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7844cc9f801a46708d7dacc663e15f7a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e5fa83622140415083f19a970d259dc4",
              "IPY_MODEL_001146de57824784b70aed7f4a593f9f",
              "IPY_MODEL_7d2d530a77a54099891746da666bb427"
            ],
            "layout": "IPY_MODEL_f6f88c0d6fa049aa8250903f99ea2d3b"
          }
        },
        "e5fa83622140415083f19a970d259dc4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0c5f203e7e9142eca7645a4929ebb129",
            "placeholder": "​",
            "style": "IPY_MODEL_2a3f7af9b71241e58bfecb7f8a944e03",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "001146de57824784b70aed7f4a593f9f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9771ad7ebb694d7aaa0d8f1d34a38428",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b06bff990ed048ff98d38bbdfce3114d",
            "value": 21
          }
        },
        "7d2d530a77a54099891746da666bb427": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_844b67e63dfc4992ac29ec24df6909b2",
            "placeholder": "​",
            "style": "IPY_MODEL_ba95a62c687848b39dd84f4e643a4bbd",
            "value": " 21/21 [21:32&lt;00:00, 65.61s/it, N=1319, ACC%=69.29, greedy=1]"
          }
        },
        "f6f88c0d6fa049aa8250903f99ea2d3b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0c5f203e7e9142eca7645a4929ebb129": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2a3f7af9b71241e58bfecb7f8a944e03": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9771ad7ebb694d7aaa0d8f1d34a38428": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b06bff990ed048ff98d38bbdfce3114d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "844b67e63dfc4992ac29ec24df6909b2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ba95a62c687848b39dd84f4e643a4bbd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "eb61ff4eb86c4ca88a8fc6b6da801f5d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d9413a542f64452082d6ec5fadf6fe30",
              "IPY_MODEL_2ee552354cf2402bade39bd72e7efa2c",
              "IPY_MODEL_325b042469a8410fb8fb2b0516f5d37d"
            ],
            "layout": "IPY_MODEL_e9bf02ef4b124945a32c94901243412b"
          }
        },
        "d9413a542f64452082d6ec5fadf6fe30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7c12b48e75fc454f9d1b8f019e17876e",
            "placeholder": "​",
            "style": "IPY_MODEL_25beb888633b4cd8a6fdb1f0d0d0ced1",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "2ee552354cf2402bade39bd72e7efa2c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_273021b6b8f34940900f6b99f555bc09",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c260bcf9957f42e68d07c54404d2dccc",
            "value": 21
          }
        },
        "325b042469a8410fb8fb2b0516f5d37d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e25edd81ca9a43dca0047583ebab8ade",
            "placeholder": "​",
            "style": "IPY_MODEL_35a222a0c74446bbbfedc30bd14394a8",
            "value": " 21/21 [21:44&lt;00:00, 55.74s/it, N=1319, ACC%=69.14, greedy=1]"
          }
        },
        "e9bf02ef4b124945a32c94901243412b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "7c12b48e75fc454f9d1b8f019e17876e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "25beb888633b4cd8a6fdb1f0d0d0ced1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "273021b6b8f34940900f6b99f555bc09": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c260bcf9957f42e68d07c54404d2dccc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e25edd81ca9a43dca0047583ebab8ade": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "35a222a0c74446bbbfedc30bd14394a8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d4d031715a544f2e8b26c18d59d2c28d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9bbd6ccdbe354198bd088bb7ee82a16f",
              "IPY_MODEL_92f80130de624cb6aa572dad9b9a579d",
              "IPY_MODEL_40ff3a6b33df450ebb0a1ca99af2100d"
            ],
            "layout": "IPY_MODEL_02c1f6fce2c240209db6e4614e34f7de"
          }
        },
        "9bbd6ccdbe354198bd088bb7ee82a16f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fb79b6da757a4817ab21b38ed543a5c2",
            "placeholder": "​",
            "style": "IPY_MODEL_371f700d9add466caf67cfb2ad4795ce",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "92f80130de624cb6aa572dad9b9a579d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_988e311f2a2a4f2d9548cee3c4a1cebd",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_030fa296e29c4b5aa98e0e7b6b5c52fd",
            "value": 21
          }
        },
        "40ff3a6b33df450ebb0a1ca99af2100d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_56fab98371434fafb5bd5d67936045ff",
            "placeholder": "​",
            "style": "IPY_MODEL_7df0ef87c19b46e598f837340953447a",
            "value": " 21/21 [20:23&lt;00:00, 59.05s/it, N=1319, ACC%=68.61, greedy=1]"
          }
        },
        "02c1f6fce2c240209db6e4614e34f7de": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "fb79b6da757a4817ab21b38ed543a5c2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "371f700d9add466caf67cfb2ad4795ce": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "988e311f2a2a4f2d9548cee3c4a1cebd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "030fa296e29c4b5aa98e0e7b6b5c52fd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "56fab98371434fafb5bd5d67936045ff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7df0ef87c19b46e598f837340953447a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7b739b99a0f24798859e0046effb3335": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_952bce6c047344f59ca5f392aba9fabf",
              "IPY_MODEL_b37437cf9e3e43b588957a87a828848c",
              "IPY_MODEL_006060c03f8f4c83904ae5b1e0ca581d"
            ],
            "layout": "IPY_MODEL_d186476b39034e0c83f0c2b05eeff4a5"
          }
        },
        "952bce6c047344f59ca5f392aba9fabf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5e3d047f381a40cdb27dbbdd374583e6",
            "placeholder": "​",
            "style": "IPY_MODEL_208af50d0a4d4012affcd65fd686ef1e",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "b37437cf9e3e43b588957a87a828848c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6894850ac2484010a1c0d2a6ac24eed5",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2c78341748fb43ccbf807d42fa7a7957",
            "value": 21
          }
        },
        "006060c03f8f4c83904ae5b1e0ca581d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6ff0bfc005c44c6c8815b36f150fa073",
            "placeholder": "​",
            "style": "IPY_MODEL_2c3690523cd64c15835e51f610d81f1d",
            "value": " 21/21 [21:43&lt;00:00, 55.59s/it, N=1319, ACC%=68.76, greedy=1]"
          }
        },
        "d186476b39034e0c83f0c2b05eeff4a5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "5e3d047f381a40cdb27dbbdd374583e6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "208af50d0a4d4012affcd65fd686ef1e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6894850ac2484010a1c0d2a6ac24eed5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2c78341748fb43ccbf807d42fa7a7957": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6ff0bfc005c44c6c8815b36f150fa073": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2c3690523cd64c15835e51f610d81f1d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "310cb44652814df388dbcb586d454db0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b76f2ab7adbd457bb44503230d7a05a1",
              "IPY_MODEL_1d06caf9760d4c138d796d21929d8837",
              "IPY_MODEL_e73c88572ad94a59ac5bed92cde2eb3d"
            ],
            "layout": "IPY_MODEL_64673f99ade0414484cf703b8c8f42e9"
          }
        },
        "b76f2ab7adbd457bb44503230d7a05a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9a84a30201ad4e57b767aef60e9e5c07",
            "placeholder": "​",
            "style": "IPY_MODEL_fb75d6c155a045968eb02c818f6484aa",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "1d06caf9760d4c138d796d21929d8837": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d99826be5e7243d986ed1b16471cdc7a",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5f762efb20ed44bb8ec9a213f26fa85e",
            "value": 21
          }
        },
        "e73c88572ad94a59ac5bed92cde2eb3d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_59748a4746354e1b8539c9e41236621a",
            "placeholder": "​",
            "style": "IPY_MODEL_cc643c706e854c2fae91e7ddb3d3feb3",
            "value": " 21/21 [21:46&lt;00:00, 67.12s/it, N=1319, ACC%=69.37, greedy=1]"
          }
        },
        "64673f99ade0414484cf703b8c8f42e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "9a84a30201ad4e57b767aef60e9e5c07": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fb75d6c155a045968eb02c818f6484aa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d99826be5e7243d986ed1b16471cdc7a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5f762efb20ed44bb8ec9a213f26fa85e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "59748a4746354e1b8539c9e41236621a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cc643c706e854c2fae91e7ddb3d3feb3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}