{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "fU6KXDe-mNjD",
      "metadata": {
        "id": "fU6KXDe-mNjD"
      },
      "source": [
        "# Qwen: Direct Projection Training with Paraphrase Coherence (TriviaQA)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cva0_WZcmNjF",
      "metadata": {
        "id": "cva0_WZcmNjF"
      },
      "source": [
        "This Colab trains a **Qwen causal LM** with a **direct projection** objective:\n",
        "\n",
        "- **KL to frozen baseline** keeps the model near \\(\\omega_0\\).\n",
        "- **Pairwise JS** improves **coherence across paraphrases** of the same question.\n",
        "- Uses **REINFORCE** estimators with baselines, multi-sample averaging, and log-sum-exp numerics."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Q3RRjK20mNjF",
      "metadata": {
        "id": "Q3RRjK20mNjF"
      },
      "source": [
        "## 0) Runtime & GPU check"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "AIguI7ArakvF",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AIguI7ArakvF",
        "outputId": "eb9f6d00-b2c3-417b-a8ba-5cd1667bc6e2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]\n",
            "OS: Linux-6.6.105+-x86_64-with-glibc2.35\n",
            "Arch: x86_64\n",
            "Torch: 2.9.0+cu126\n",
            "CUDA (toolkit from Torch): 12.6\n",
            "CUDA available: True\n",
            "GPU: NVIDIA A100-SXM4-80GB\n",
            "SM capability: (8, 0)\n",
            "VRAM (GB): 79.32\n"
          ]
        }
      ],
      "source": [
        "import platform, torch, subprocess, sys\n",
        "import os, torch\n",
        "\n",
        "print(\"Python:\", sys.version)\n",
        "print(\"OS:\", platform.platform())\n",
        "print(\"Arch:\", platform.machine())\n",
        "print(\"Torch:\", torch.__version__)\n",
        "print(\"CUDA (toolkit from Torch):\", torch.version.cuda)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"SM capability:\", torch.cuda.get_device_capability(0))\n",
        "    print(\"VRAM (GB):\", round(torch.cuda.get_device_properties(0).total_memory/2**30,2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "zEE6ahLV5LZX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zEE6ahLV5LZX",
        "outputId": "d40f09b2-2a11-4ab2-af1a-bda1a289d117"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading: https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installing: flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n",
            "Installed flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import subprocess\n",
        "\n",
        "# 1) Force flash-attn to use a prebuilt wheel or fail fast\n",
        "os.environ[\"FLASH_ATTENTION_SKIP_CUDA_BUILD\"] = \"TRUE\"\n",
        "\n",
        "# 2) Download the wheel for:\n",
        "#    - flash_attn 2.8.3\n",
        "#    - CUDA 12.6 (cu126)\n",
        "#    - torch 2.9\n",
        "#    - Python 3.12 (cp312)\n",
        "wheel_url = \"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.15/flash_attn-2.8.3+cu126torch2.9-cp312-cp312-linux_x86_64.whl\"\n",
        "wheel_name = wheel_url.split(\"/\")[-1]\n",
        "\n",
        "print(\"Downloading:\", wheel_url)\n",
        "subprocess.run([\"wget\", \"-q\", wheel_url, \"-O\", wheel_name], check=True)\n",
        "\n",
        "print(\"Installing:\", wheel_name)\n",
        "subprocess.run(\n",
        "    [\"pip\", \"install\", \"--no-dependencies\", \"--upgrade\", wheel_name],\n",
        "    check=True,\n",
        ")\n",
        "\n",
        "print(\"Installed\", wheel_name)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rgVu_ySz5ReL",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rgVu_ySz5ReL",
        "outputId": "aedaada6-5353-4ad8-b8f3-4218f292b437"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "flash_attn version: 2.8.3\n"
          ]
        }
      ],
      "source": [
        "import flash_attn\n",
        "\n",
        "print(\"flash_attn version:\", getattr(flash_attn, \"__version__\", \"unknown\"))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qQ3UlHbC5dLW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qQ3UlHbC5dLW",
        "outputId": "668c354d-3d29-4801-9b69-cd00f8dc65d0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Torch: 2.9.0+cu126\n",
            "CUDA available: True\n",
            "Device: NVIDIA A100-SXM4-80GB\n",
            "Compute capability: (8, 0)\n",
            "FlashAttention forward pass OK.\n",
            "Output shape: torch.Size([2, 128, 8, 64])\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from flash_attn import flash_attn_func\n",
        "\n",
        "device = \"cuda\"\n",
        "\n",
        "print(\"Torch:\", torch.__version__)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "print(\"Device:\", torch.cuda.get_device_name(0))\n",
        "print(\"Compute capability:\", torch.cuda.get_device_capability(0))\n",
        "\n",
        "# Small test config\n",
        "B = 2      # batch size\n",
        "L = 128    # sequence length\n",
        "H = 8      # number of heads\n",
        "D = 64     # head dimension (must be <= 256 for FA2)\n",
        "\n",
        "# Q, K, V shapes: (batch, seqlen, nheads, headdim)\n",
        "q = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "k = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "v = torch.randn(B, L, H, D, device=device, dtype=torch.float16)\n",
        "\n",
        "try:\n",
        "    out = flash_attn_func(\n",
        "        q, k, v,\n",
        "        dropout_p=0.0,\n",
        "        softmax_scale=None,\n",
        "        causal=False,\n",
        "    )\n",
        "    print(\"FlashAttention forward pass OK.\")\n",
        "    print(\"Output shape:\", out.shape)\n",
        "except Exception as e:\n",
        "    print(\"FlashAttention FAILED on GPU:\")\n",
        "    print(type(e).__name__, \":\", e)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "uUeRWvsdQv1C",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uUeRWvsdQv1C",
        "outputId": "c5acb48c-e609-4174-aba9-be1780f3e065"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "u7T3ljRumNjG",
      "metadata": {
        "id": "u7T3ljRumNjG"
      },
      "source": [
        "## 1) Install packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2snW-PemNjG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "d2snW-PemNjG",
        "outputId": "b69ec5c8-62d4-4243-d1ee-37bd12e1732f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.8 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m78.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n"
          ]
        }
      ],
      "source": [
        "!pip -q install --upgrade pip\n",
        "!pip -q install -U \"transformers>=4.43.0\" \"accelerate>=0.33.0\" \"datasets>=2.19.0\" \"bitsandbytes>=0.43.0\" \"peft>=0.11.0\" sentence-transformers einops\n",
        "# Optional: flash-attn (if wheel exists for your GPU)\n",
        "# !pip -q install flash-attn --no-build-isolation\n",
        "!pip install tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bBwLq3tAZBfd",
      "metadata": {
        "id": "bBwLq3tAZBfd"
      },
      "outputs": [],
      "source": [
        "!pip -q install -U \"trl>=0.8.0\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "IviYIdL0mNjG",
      "metadata": {
        "id": "IviYIdL0mNjG"
      },
      "source": [
        "## 2) Imports & global config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QGO8iNsFmNjG",
      "metadata": {
        "id": "QGO8iNsFmNjG"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "from typing import List, Tuple\n",
        "import torch, math, random, numpy as np\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "from datasets import load_dataset\n",
        "from sentence_transformers import SentenceTransformer, util as sbert_util\n",
        "\n",
        "SEED = 42\n",
        "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "# Choose model (smaller first for Colab T4/L4)\n",
        "TRAINABLE_MODEL_NAME = \"microsoft/Phi-3-mini-4k-instruct\"#\"meta-llama/Llama-3.2-3B-Instruct\"#\"Qwen/Qwen2.5-3B-Instruct\"#\n",
        "BASELINE_MODEL_NAME  = TRAINABLE_MODEL_NAME\n",
        "\n",
        "USE_4BIT = True\n",
        "BNB_CONFIG = BitsAndBytesConfig(\n",
        "    load_in_4bit=USE_4BIT,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
        ")\n",
        "\n",
        "BASELINE_DEVICE_MAP = \"auto\"#\"cpu\"   # \"cpu\" to save VRAM; set to \"auto\" if you have >=24GB\n",
        "\n",
        "MAX_NEW_TOKENS = 512\n",
        "GEN_TEMPERATURE = 0.7\n",
        "TOP_P = 0.9\n",
        "TOP_K = 50\n",
        "\n",
        "PARAPHRASES_PER_QUESTION = 2     # total per class incl. original\n",
        "NUM_CLASSES_FOR_TRAIN = 10000\n",
        "SAMPLES_PER_PAIR = 16\n",
        "JS_WEIGHT = 1.0\n",
        "KL_WEIGHT = 0.02\n",
        "\n",
        "LR = 2e-5\n",
        "WEIGHT_DECAY = 0.01\n",
        "GRAD_CLIP = 1.0\n",
        "STEPS = 50\n",
        "BATCH_SIZE_CLASSES = 32\n",
        "LOG_EVERY = 20\n",
        "\n",
        "#\n",
        "WORK_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-mode-jan16/weight-5e-2\"\n",
        "CKPT_LOAD_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-mode-jan16/weight-5e-2/checkpoint-960\"#f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct-majority-dec29/weight-1/checkpoint-1600\"\n",
        "COHERENCE_DATA_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/answerlevel/{TRAINABLE_MODEL_NAME}/instruct/\"\n",
        "os.makedirs(WORK_PATH, exist_ok=True)\n",
        "CACHE_PATH = f\"/content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZMmHw__UNCjX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZMmHw__UNCjX",
        "outputId": "8ab2d6d6-6398-4ecf-f6ab-f323239de84d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "adapter_config.json\t   README.md\t\t    tokenizer_config.json\n",
            "adapter_model.safetensors  ref\t\t\t    tokenizer.json\n",
            "added_tokens.json\t   rng_state.pth\t    tokenizer.model\n",
            "chat_template.jinja\t   scheduler.pt\t\t    trainer_state.json\n",
            "optimizer.pt\t\t   special_tokens_map.json  training_args.bin\n"
          ]
        }
      ],
      "source": [
        "!ls $CKPT_LOAD_PATH"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c7DuQugEeSaD",
      "metadata": {
        "id": "c7DuQugEeSaD"
      },
      "source": [
        "## Utils"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Ul00h3tomNjH",
      "metadata": {
        "id": "Ul00h3tomNjH"
      },
      "source": [
        "## 3) Load tokenizer, trainable QLoRA model, and frozen baseline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "T5xYHUDKds80",
      "metadata": {
        "id": "T5xYHUDKds80"
      },
      "outputs": [],
      "source": [
        "# @title\n",
        "# === utils_loaders.py ===\n",
        "import os, gc, torch, platform\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
        "\n",
        "# ---------- Helpers ----------\n",
        "\n",
        "def gpu_supports_bf16() -> bool:\n",
        "    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
        "\n",
        "def pick_dtype():\n",
        "    return torch.bfloat16 if gpu_supports_bf16() else (torch.float16 if torch.cuda.is_available() else torch.float32)\n",
        "\n",
        "def pick_attn_impl() -> str:\n",
        "    \"\"\"\n",
        "    FlashAttention-2 only on Ampere+ (SM >= 80). Otherwise use SDPA.\n",
        "    \"\"\"\n",
        "    if not torch.cuda.is_available():\n",
        "        return \"sdpa\"\n",
        "    sm_major, _ = torch.cuda.get_device_capability(0)\n",
        "    print (\"using flash_attention_2\") if sm_major >= 8 else print (\"using sdpa\")\n",
        "    return \"flash_attention_2\" if sm_major >= 8 else \"sdpa\"\n",
        "\n",
        "def hard_cleanup(*names_to_del, also_empty_cache=True):\n",
        "    \"\"\"\n",
        "    Delete large globals by name (if present), run gc, and empty CUDA cache.\n",
        "    \"\"\"\n",
        "    g = globals()\n",
        "    for n in names_to_del:\n",
        "        if n in g:\n",
        "            try: del g[n]\n",
        "            except: pass\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available() and also_empty_cache:\n",
        "        torch.cuda.empty_cache()\n",
        "        torch.cuda.synchronize()\n",
        "\n",
        "def make_bnb_config(load_in_4bit=True):\n",
        "    # Lazy import to avoid dependency unless used\n",
        "    from transformers import BitsAndBytesConfig\n",
        "    return BitsAndBytesConfig(\n",
        "        load_in_4bit=load_in_4bit,\n",
        "        load_in_8bit=not load_in_4bit and True or False,   # if you ever want 8-bit, flip flags accordingly\n",
        "        bnb_4bit_compute_dtype=pick_dtype(),\n",
        "        bnb_4bit_use_double_quant=True,\n",
        "        bnb_4bit_quant_type=\"nf4\",\n",
        "    )\n",
        "\n",
        "def setup_tokenizer(model_id: str, use_fast=True):\n",
        "    tok = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, trust_remote_code=True)\n",
        "    # padding & truncation for decoder-only batched gen\n",
        "    if tok.pad_token is None:\n",
        "        # Prefer a dedicated pad token if it exists, else fall back to eos\n",
        "        tok.pad_token = tok.eos_token\n",
        "    tok.padding_side = \"left\"\n",
        "    tok.truncation_side = \"left\"\n",
        "    return tok\n",
        "\n",
        "# ---------- Trainable (QLoRA) ----------\n",
        "\n",
        "def load_trainable_model(\n",
        "    model_id: str,\n",
        "    use_4bit: bool = True,\n",
        "    lora_r: int = 8,\n",
        "    lora_alpha: int = 16,\n",
        "    lora_dropout: float = 0.05,\n",
        "    extra_target_modules: list[str] | None = None,\n",
        "    device_map: str | dict = \"auto\",\n",
        "):\n",
        "    \"\"\"\n",
        "    Load a k-bit QLoRA-ready model + tokenizer and wrap with PEFT LoRA.\n",
        "    Returns: (tokenizer, peft_model)\n",
        "\n",
        "    Special case:\n",
        "      - For microsoft/Phi-3* models, we force trust_remote_code=False\n",
        "        to avoid the DynamicCache.seen_tokens AttributeError.\n",
        "    \"\"\"\n",
        "\n",
        "    # --- tokenizer ---\n",
        "    tok = setup_tokenizer(model_id)\n",
        "\n",
        "    # --- quantization config ---\n",
        "    quant_cfg = make_bnb_config(load_in_4bit=use_4bit) if use_4bit else None\n",
        "\n",
        "    # --- pick trust_remote_code depending on model ---\n",
        "    # Phi-3 has outdated remote modeling code that breaks with recent transformers.\n",
        "    is_phi3 = \"phi-3\" in model_id.lower()\n",
        "    trust_remote = False if is_phi3 else True\n",
        "\n",
        "    # --- load base model ---\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_id,\n",
        "        quantization_config=quant_cfg,\n",
        "        torch_dtype=pick_dtype(),\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=trust_remote,\n",
        "        attn_implementation=pick_attn_impl(),  # HF will choose SDPA if needed\n",
        "    )\n",
        "\n",
        "    # Prepare for k-bit training\n",
        "    model = prepare_model_for_kbit_training(model)\n",
        "\n",
        "    # Auto-detect common projection module names if not provided\n",
        "    targets = set(extra_target_modules or [])\n",
        "    for name, _ in model.named_modules():\n",
        "        low = name.lower().split(\".\")[-1]\n",
        "        if low in {\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"}:\n",
        "            targets.add(low)\n",
        "    if not targets:\n",
        "        targets = {\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\"}\n",
        "\n",
        "    lora_cfg = LoraConfig(\n",
        "        r=lora_r,\n",
        "        lora_alpha=lora_alpha,\n",
        "        target_modules=sorted(targets),\n",
        "        lora_dropout=lora_dropout,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "    )\n",
        "    model = get_peft_model(model, lora_cfg)\n",
        "\n",
        "    # Required for gradient checkpointing with PEFT\n",
        "    model.config.use_cache = False\n",
        "\n",
        "    model.enable_input_require_grads()\n",
        "    model.gradient_checkpointing_enable()\n",
        "\n",
        "    model.print_trainable_parameters()\n",
        "\n",
        "    return tok, model\n",
        "\n",
        "\n",
        "# ---------- Frozen Baseline (inference) ----------\n",
        "\n",
        "def load_baseline_model(\n",
        "    model_id: str,\n",
        "    force_cpu_embedder: bool = True,\n",
        "    use_4bit: bool = False,\n",
        "    device_map: str | dict = \"auto\",\n",
        "    max_memory: dict | None = None,\n",
        "    allow_offload_folder: str | None = None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Load a frozen inference model with the fastest attention available for your GPU.\n",
        "    Returns: (tokenizer, baseline_model)\n",
        "    \"\"\"\n",
        "    tok = setup_tokenizer(model_id)\n",
        "\n",
        "    kwargs = dict(\n",
        "        device_map=device_map,\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=True,\n",
        "        attn_implementation=pick_attn_impl(),\n",
        "    )\n",
        "\n",
        "    if use_4bit:\n",
        "        kwargs[\"quantization_config\"] = make_bnb_config(load_in_4bit=True)\n",
        "    else:\n",
        "        kwargs[\"torch_dtype\"] = pick_dtype()\n",
        "\n",
        "    if max_memory is not None:\n",
        "        kwargs[\"max_memory\"] = max_memory\n",
        "    if allow_offload_folder is not None:\n",
        "        os.makedirs(allow_offload_folder, exist_ok=True)\n",
        "        kwargs[\"offload_folder\"] = allow_offload_folder\n",
        "        kwargs[\"offload_state_dict\"] = True\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()\n",
        "\n",
        "    # Freeze to be explicit\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad_(False)\n",
        "\n",
        "    return tok, model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "FswzqYDPdwoH",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "referenced_widgets": [
            "7a5e4a148a8046ce9f9eebb891e0e997",
            "f39003ffa74d483d9915bdbb9a9687a8",
            "ded174fdf0274aefa8d4f607d0cc5d10",
            "5ced36961b7c410eaf907f18e36ec4e9",
            "b4e1b888783d45c2888fa275d07a26f4",
            "797ad36b4a6f499592e36e142c7cf56a",
            "ffe990e78ecb4baa9b89bcaafd1fcd45",
            "618604fc21de45e796c41a76dc9b377c",
            "3f0627222871458ca034b76ba320d062",
            "6a0da7f66ea44a77a3ec8050d3a2e830",
            "45b4e334a9b94033ae78b96791bdf4e1",
            "ec08eb790dbc4a00bfaf4c9174307d5f",
            "b119d8e2d9c94031840b6d4977d9db07",
            "b892c27f302946d09ecabd65f2bd2b52",
            "ebd53e2c38e7455c859f1e30d8f1a085",
            "799b5bebcd6b4fd89a5872b3f8af1a50",
            "2276b8a3c8c7494f988822c71035ee3f",
            "7a5d280f7f04401dbadc42de3d4b110b",
            "1ddd36facd51491da8c1d34a4d340a79",
            "0d19e1d1729b4e88a7fe046a3fcf907c",
            "37f2190d83b04182a80c83c891b3b445",
            "4efa2599bb644aba81fc2ed14216a602",
            "0d9f778be8994c1f9c0d20bc2d5f1f11",
            "e373c7e122834faab7ff309257d3f3f1",
            "07fe66a591e548d0a4e72dd952d129b4",
            "0e51bea0b0d84a65bead2df6b8f4195b",
            "187704687eec405d953b997968c540ce",
            "41a4945ba4fe4a5290f6caac50f78ead",
            "7b0acbeaa67c4622aec005e0720259d7",
            "ea0b144cce3f46f491d3b3a202efaa84",
            "6a8695dc8b3b45539633cdbae26b2f7a",
            "1f5ec773417c4cbda9fac915ea5dc94a",
            "d45bdcfc6f5d485d8239ddc0e548983b",
            "7456f00a6fea4823a1f9513fd6d07cf2",
            "9c696098fb87434aaf25c504970a2190",
            "bbbdc35b1d4045e696daa8cd20700ce2",
            "0707764eb0ab418493663a026597a4e0",
            "67f9ad5d3221419d94a6be7197ada150",
            "2c453740ebd74bbebe04b2138b58edea",
            "6c4217ad778b4ec698adcd8922aa7881",
            "5ce9ada55af34ab1a3a13b665821d81e",
            "d81d9529b42c4f76865bddb827df56f0",
            "60e76b9f6e0840ada2adf78ac1397d27",
            "2a9521b75b4f4658a4abddfa74706a04",
            "67aa64c60ad24a8889e1f34a44ac4fc9",
            "8789ab53fa334c9c8da9941f1a5c3e51",
            "e47ba5263c9f477db5a126f589ec1333",
            "20f653836cbc4d839b62aac4b1aa904d",
            "26ac80c7d933450e90a8cdb682b48660",
            "93d0bff923cb4aa2af17260db5fc5aa3",
            "9a3da4e17ca24fc89752879eaf0a1f1e",
            "aaeb6d43ea2c4d0a97137ccbd5b6b56e",
            "7b9fe861916b473eadc34d7791dcd8a5",
            "5dd4d8669b294a0ba9fc23fd4d56ddf6",
            "7db97fac4f9e496c947c039f27f9911f",
            "24913f7d4e384e9ba15136053fd6dce2",
            "2cd20eccac484d57ab1df90a2911d322",
            "f5abc702151f43249d9c6205e6a4ab3b",
            "69a9e28b095a40bdb220b3b5c80174f6",
            "ea4e39e1326640fa807513eee76150a8",
            "44e3abdc7ee949bbb352e5d032069563",
            "d03c0cab1f4e4bc882520a07f591dce0",
            "862d908202b947639e1b266a1052b313",
            "7565c74477314ec5aafcbdec86d2fa09",
            "e2a4c231c1c94455b59512a0f3b8ab01",
            "f9f211046f8449749459c61e0257ac16",
            "cdbad6d0d7d74f04bef0a0bef72bc049",
            "0a3a3123bd2b4defb16b8c4bbe82fe8b",
            "96de9938e80f436982115b248107f388",
            "9a392cb3bb2f40a5b554ae29e366405c",
            "afd3fec1be764e66ab6231552f9fcede",
            "19c68f9546764c66b7dc74ee4ffd8c1d",
            "76ac6c1bef1f45a8ad24703744e8c0b2",
            "8c0dad254951417d8fb28165ad139143",
            "38291db55bc746bd8114e0bf5ed213a1",
            "736cf06ab906427ca04087e3180f7930",
            "109f3593f3864fea934c60edbe46b70d",
            "01e5c0e01f384a29a3f98c7fefc0dded",
            "23562b6755dc498ca3423bf68b77b65a",
            "ba863d5535c047b99c7d26d066555f0b",
            "8d5a36bdf6c049eab68fdd2567095d64",
            "d4a968ea181f43688bbf975326484e2a",
            "ffe0a6b2edd8486fa9b0b93dc916d0a1",
            "a39590bba0d0492abcf414490763f559",
            "807c6777747b4ada902120426de8a611",
            "f14f6c98b4884884bbfedb4aa4200123",
            "4694232e149c428aae43a1ec9b79f2ee",
            "36db6fbb53e34ca7bd71bdd22deb8fc3",
            "3f5a11b214724194a8fa39bb355e2bbe",
            "2f43805da2ee43c38f7276896125e455",
            "7f14f4d8df694f30aa71af895b1155af",
            "510b3441845a4b9593229e10ff1ddc26",
            "b8b08af66b2f448b95e3da143758c90b",
            "b00f0afb302a43119aa20912dbbd37b3",
            "48dec881fbc14ed5ba3d1c16a7dd47c0",
            "666ece6cceed4a709fbf2f7fe226e877",
            "4600ae8cc09b4cce8ab03d6fc97eca8f",
            "d74d9a726d7249e3b0de95057d3f36cc",
            "bb93c0459aa242f49f3cd7e1ae734ee0",
            "9a4a5417e4cc45ca8fa8f07e76fde0a8",
            "fbc22ca925314cf2b59c285a2db85ca2",
            "e288c2fd0ee24608abffdba553045a7e",
            "43d2252a5e4f4622ac362770258ed7ba",
            "a2a6de3dd62b4a6fbfd291b2ca4e8730",
            "713c5923a6b3410b8672ecfe38e66f29",
            "ea8d320b8213469cb81460d0fbeee3d8",
            "73447fc92e084749a7581a67b2bbc336",
            "e6c4157e89c7479a9148ccb7de462366",
            "7656329c18ea4c649be0a4daa666b4f0",
            "4cccd262bcfa4d72950c433cf579e17f",
            "ae3cc799469b4bb9b65702e3bb81311e",
            "9ed53200c8e049bfa57c037d79ebaf2f",
            "1e2bbf3a5d37489784b68fe9239ceed8",
            "1c39bab6d8354d80a21d4b8e10a34258",
            "230d4ede8bdf436e8b94e62c0ed1f1d3",
            "e6f2ad9ac925407eb4cab5d383e814ab",
            "1dfce580b8ab401398bb1319d8e816ca",
            "15c68305dffb4d0c9a2deb5d1730c04a",
            "3450aea3fc904dbe831388c60c1dffb8",
            "9ef9e3aa3251406498565affe909b1c7",
            "d1db823bd4544a6aa718f8fb413cdf85",
            "17714a5200ad439689748a4b8a9e73a6",
            "07c4803e23194b1abf9d8cccf8b68d0f",
            "a230326aab3e4502afa4027a27a1bf42",
            "0cfa1d8cde204151aad13b1af64925a7",
            "6786981118b24cf88c6954116da51911",
            "6c98c97bfd0a4d97ac65839a3db66c3d",
            "056fe12e14094451a2f9f10745031797",
            "8bfa751712744c409598ee69280bdeb0",
            "26f555af847b4aa89bb4b012d7fb0747",
            "4c36dd3013414f35a66a52f133c5b5d6",
            "9c58a1b03fef4bcd9b00bda9e1f0741b"
          ]
        },
        "collapsed": true,
        "id": "FswzqYDPdwoH",
        "outputId": "9e5679d2-8774-4613-e82f-7d116940724f"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7a5e4a148a8046ce9f9eebb891e0e997"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ec08eb790dbc4a00bfaf4c9174307d5f"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "0d9f778be8994c1f9c0d20bc2d5f1f11"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7456f00a6fea4823a1f9513fd6d07cf2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "67aa64c60ad24a8889e1f34a44ac4fc9"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "using flash_attention_2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "24913f7d4e384e9ba15136053fd6dce2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "`torch_dtype` is deprecated! Use `dtype` instead!\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors.index.json: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "cdbad6d0d7d74f04bef0a0bef72bc049"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "01e5c0e01f384a29a3f98c7fefc0dded"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3f5a11b214724194a8fa39bb355e2bbe"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "9a4a5417e4cc45ca8fa8f07e76fde0a8"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ae3cc799469b4bb9b65702e3bb81311e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "17714a5200ad439689748a4b8a9e73a6"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "trainable params: 4,456,448 || all params: 3,825,536,000 || trainable%: 0.1165\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n# 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\\nimport os\\nos.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\\n\\n# 1) Clean any old models from VRAM\\n\\n# hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\\n\\n# 3) Load FROZEN baseline model for inference/paraphrasing\\nBASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\\n# If VRAM is tight, leave some headroom and allow offload\\nmax_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\\ntokenizer, baseline_model = load_baseline_model(\\n    BASELINE_MODEL_NAME,\\n    use_4bit=False,               # can set True to save VRAM\\n    device_map=BASELINE_DEVICE_MAP,\\n    max_memory=max_mem,           # optional\\n    allow_offload_folder=\"./offload\"  # optional\\n)\\n\\n# 4) Embedder (keep on CPU unless you really need GPU)\\nfrom sentence_transformers import SentenceTransformer\\nembedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\\n# embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 11
        }
      ],
      "source": [
        "# 2) Load fresh TRAINABLE (QLoRA) model\n",
        "#\n",
        "# Assumes load_trainable_model() builds the 4-bit base model and attaches LoRA\n",
        "# adapters with these hyperparameters; when resuming from a checkpoint, pass\n",
        "# exactly the same arguments that were used for training.\n",
        "tokenizer, trainable_model = load_trainable_model(\n",
        "    TRAINABLE_MODEL_NAME,\n",
        "    use_4bit=True,           # QLoRA default\n",
        "    lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "# ---------------------------------------------------------------------------\n",
        "# Disabled alternatives, kept for reference. They are '#' comments instead of\n",
        "# triple-quoted strings so that the cell does not echo them as its result.\n",
        "# ---------------------------------------------------------------------------\n",
        "#\n",
        "# (a) Frozen baseline as a deep copy of the trainable model:\n",
        "# import copy\n",
        "# baseline_model = copy.deepcopy(trainable_model).eval()\n",
        "# for p in baseline_model.parameters():\n",
        "#     p.requires_grad_(False)\n",
        "#\n",
        "# (b) Sentence embedder:\n",
        "# from sentence_transformers import SentenceTransformer\n",
        "# embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# # embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches\n",
        "#\n",
        "# (c) Separately loaded frozen baseline plus embedder:\n",
        "# # 0) Optional: reduce fragmentation (set *before* importing torch in a fresh runtime)\n",
        "# import os\n",
        "# os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"expandable_segments:True,max_split_size_mb:256\")\n",
        "#\n",
        "# # 1) Clean any old models from VRAM\n",
        "#\n",
        "# # hard_cleanup(\"baseline_model\", \"trainable_model\", \"embedder\")\n",
        "#\n",
        "# # 3) Load FROZEN baseline model for inference/paraphrasing\n",
        "# BASELINE_MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "# # If VRAM is tight, leave some headroom and allow offload\n",
        "# max_mem = {0: \"13GiB\", \"cpu\": \"48GiB\"} if torch.cuda.is_available() else None\n",
        "# tokenizer, baseline_model = load_baseline_model(\n",
        "#     BASELINE_MODEL_NAME,\n",
        "#     use_4bit=False,               # can set True to save VRAM\n",
        "#     device_map=BASELINE_DEVICE_MAP,\n",
        "#     max_memory=max_mem,           # optional\n",
        "#     allow_offload_folder=\"./offload\"  # optional\n",
        "# )\n",
        "#\n",
        "# # 4) Embedder (keep on CPU unless you really need GPU)\n",
        "# from sentence_transformers import SentenceTransformer\n",
        "# embedder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")  # CPU is fine\n",
        "# # embedder.to(\"cuda\")  # only if you have VRAM headroom and big batches"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tKvuvb_gBq3k",
      "metadata": {
        "id": "tKvuvb_gBq3k"
      },
      "outputs": [],
      "source": [
        "\n",
        "if CKPT_LOAD_PATH:\n",
        "    # Resume from a saved LoRA adapter: load its weights into the PEFT model built above.\n",
        "    ADAPTER_DIR = CKPT_LOAD_PATH  # folder containing adapter_config.json\n",
        "    print(\"loading from: \", ADAPTER_DIR)\n",
        "    try:\n",
        "        # Preferred: load adapter into the existing PeftModel\n",
        "        trainable_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\n",
        "        trainable_model.set_adapter(\"default\")\n",
        "    except AttributeError:\n",
        "        # Fallback for older PEFT versions: load the adapter state dict directly\n",
        "        from peft.utils import set_peft_model_state_dict\n",
        "        import torch, os\n",
        "        safetensors_path = os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\")\n",
        "        if os.path.exists(safetensors_path):\n",
        "            # torch.load() cannot parse the safetensors format; use its own loader.\n",
        "            # safetensors is assumed to be installed alongside transformers/PEFT.\n",
        "            from safetensors.torch import load_file\n",
        "            state = load_file(safetensors_path, device=\"cpu\")\n",
        "        else:\n",
        "            # Otherwise expect the pickled state-dict file name.\n",
        "            state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.bin\"),\n",
        "                               map_location=\"cpu\")\n",
        "        set_peft_model_state_dict(trainable_model, state)\n",
        "        # make sure only LoRA params are trainable if you plan to keep training\n",
        "        for n, p in trainable_model.named_parameters():\n",
        "            p.requires_grad = (\"lora_\" in n)\n",
        "\n",
        "    trainable_model.train()  # if you're resuming training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "glPKBu0Uo3sX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "id": "glPKBu0Uo3sX",
        "outputId": "eabd011e-8af5-4bd7-e074-fce6295bb488"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\n_, baseline_model = load_trainable_model(\\n    TRAINABLE_MODEL_NAME,\\n    use_4bit=True,           # QLoRA default\\n    lora_r=8, lora_alpha=16, lora_dropout=0.05,\\n    device_map=\"auto\",\\n)\\n\\nfor p in baseline_model.parameters():\\n     p.requires_grad_(False)\\n'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Disabled: load a second QLoRA model to act as the frozen baseline.\n",
        "# Kept as '#' comments (not a string literal) so the cell has no result to echo.\n",
        "#\n",
        "# _, baseline_model = load_trainable_model(\n",
        "#     TRAINABLE_MODEL_NAME,\n",
        "#     use_4bit=True,           # QLoRA default\n",
        "#     lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "#     device_map=\"auto\",\n",
        "# )\n",
        "#\n",
        "# for p in baseline_model.parameters():\n",
        "#     p.requires_grad_(False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B6OVE7wD_awp",
      "metadata": {
        "id": "B6OVE7wD_awp"
      },
      "outputs": [],
      "source": [
        "# BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "\n",
        "\n",
        "# _, baseline_model = load_trainable_model(\n",
        "#     TRAINABLE_MODEL_NAME,\n",
        "#     use_4bit=True,           # QLoRA default\n",
        "#     lora_r=8, lora_alpha=16, lora_dropout=0.05,\n",
        "#     device_map=\"auto\",\n",
        "# )\n",
        "\n",
        "\n",
        "# baseline_model = AutoModelForCausalLM.from_pretrained(\n",
        "#         BASELINE_PATH, device_map=\"auto\", quantization_config=bnb_config\n",
        "#     )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "UNCMXFLbCJ6F",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 209
        },
        "id": "UNCMXFLbCJ6F",
        "outputId": "f2d7a387-1d54-4972-e7c2-7ceca727bcc2"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'\\nBASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\\nif BASELINE_PATH:\\n    # Now load the baseline adapter weights into this PEFT model\\n    ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\\n    try:\\n        # Preferred: load adapter into the existing PeftModel\\n        baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\\n        baseline_model.set_adapter(\"default\")\\n    except AttributeError:\\n        # Fallback for older PEFT versions: load the adapter state dict directly\\n        from peft.utils import set_peft_model_state_dict\\n        import torch, os\\n        state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\\n                          map_location=\"cpu\")  # if safetensors, use safetensors.load_file\\n        set_peft_model_state_dict(baseline_model, state)\\n        # make sure only LoRA params are trainable if you plan to keep training\\n        for n, p in baseline_model.named_parameters():\\n            p.requires_grad = (\"lora_\" in n)\\n\\n    baseline_model.eval()\\n'"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Disabled: load a saved adapter into baseline_model and put it in eval mode.\n",
        "# Kept as '#' comments (not a string literal) so the cell has no result to echo.\n",
        "#\n",
        "# BASELINE_PATH = \"/content/drive/MyDrive/coherence/Qwen/Qwen2.5-3B-Instruct/checkpoint600-lambda-1-baseline/\"\n",
        "# if BASELINE_PATH:\n",
        "#     # Now load the baseline adapter weights into this PEFT model\n",
        "#     ADAPTER_DIR = BASELINE_PATH # folder containing adapter_config.json\n",
        "#     try:\n",
        "#         # Preferred: load adapter into the existing PeftModel\n",
        "#         baseline_model.load_adapter(ADAPTER_DIR, adapter_name=\"default\")\n",
        "#         baseline_model.set_adapter(\"default\")\n",
        "#     except AttributeError:\n",
        "#         # Fallback for older PEFT versions: load the adapter state dict directly\n",
        "#         from peft.utils import set_peft_model_state_dict\n",
        "#         import torch, os\n",
        "#         state = torch.load(os.path.join(ADAPTER_DIR, \"adapter_model.safetensors\"),\n",
        "#                           map_location=\"cpu\")  # if safetensors, use safetensors.load_file\n",
        "#         set_peft_model_state_dict(baseline_model, state)\n",
        "#         # make sure only LoRA params are trainable if you plan to keep training\n",
        "#         for n, p in baseline_model.named_parameters():\n",
        "#             p.requires_grad = (\"lora_\" in n)\n",
        "#\n",
        "#     baseline_model.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ChcYKgQMmNjI",
      "metadata": {
        "id": "ChcYKgQMmNjI"
      },
      "source": [
        "## 4) Load GSM8K and build paraphrase classes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "CKweB9VZLnN2",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "referenced_widgets": [
            "daa54c05faeb42f8b108905ae30eb5bb",
            "a3530137220b42f4a5b40b12624580ba",
            "a7838f19c8714080b81480dfc3f2d243",
            "f258d048c22944588894c61b61fecc9f",
            "5c58d0fe6fb34b7d8f05f0c1ceadb35c",
            "1014f966098e4c518c830f756c929d65",
            "82f9f1c61ab3419ba812e25c6c0ce6f6",
            "3dfaf9bfe0604969b7a3677f73b753b7",
            "b185c9f868fc419bb8f1dc1c0d44a465",
            "ce4b11394235491f9c6f073011a8ba94",
            "227cf2d26b2f42d8a7acbb6e5547b728",
            "072017f5299d4c85a4f54fbe34ae344f",
            "056cded851da46c686fcf32e109d8b94",
            "4904aa32cf67436e8d54d37dea0f0762",
            "dcc2d22e0aea4d18aac0b60c5d325fc2",
            "55eed6957fec4dce8da7cf8adaf48597",
            "092624ebb29f436894e4396f03953e86",
            "b727e50e2d1b4d21a05164530f7053f1",
            "b980c155b06c49adad63fe83fb36a373",
            "94afaa6eecbd480cb6b149dff4e4f80c",
            "8511a603a53543dab3f567a9c172b652",
            "3485568498df43de8ee14437f61b3cd6",
            "8ab15221c29c4b26a72f791d2f16fa41",
            "2e4112f36f7a4863b34423490f4d9139",
            "3a6b4d6a40644ac086a090d7a046a9b1",
            "712db2e1b43446e29131d92dbc7b1a81",
            "b756871eb62d4f7bbd098ee10a922a39",
            "5c7f29b0d5e842b18fb868a6e25375b6",
            "0fa62fd72be942f9976482d33c127229",
            "f4a02ee3abd8429fafd72fa5cd469e4d",
            "9c2e944601b641d5bbbd8d17818530bb",
            "21465763a3f84881be0d58c5b48db247",
            "b085867817fd451dad0f099c0d99f2b2",
            "cda513e07ffc47ffbe6d39b7db49a833",
            "ae53cb72ec024ebf8d0358bc2cc77fdb",
            "d6de9cd36dea474a9fd61076e797dc50",
            "63de55d0846b4237969336a4f5355712",
            "cc1ae692118b421ab11c7194a95b048f",
            "76903d1a10a249d785a889f0be761858",
            "3d6a5e6a18e14113aa8512420d38045a",
            "f880abbd4444494c9c41120297951364",
            "c4159664e6bd439e8c3a03e80a14f122",
            "de0b9490927c4524b0a544f664779eb8",
            "38b494b59b8a4c459f726c96d0b36390",
            "c2beba96514345648cf43cb638583d52",
            "41501aac60d644a78c75863d3e2fa4c4",
            "138698b26b38420fb43a18c99f2a7a64",
            "05277fc8241f41ffbc57eab74ccc82c7",
            "41adc3e56fc94969a27a87556c279c0f",
            "cf494d2c2fdd46088506136d6bdf677c",
            "dd07772ab98342d5bfb7d770c81a8ebf",
            "0e16c268f4ef47e482ad5fa9ebc9185b",
            "2926d30ad85742e29a6f9d23df909c59",
            "123d4849dffa453b9f16447920072e1e",
            "621fca71bdab4fe9897429a325760829"
          ]
        },
        "id": "CKweB9VZLnN2",
        "outputId": "49694879-2274-4c48-ad2d-9241bdb8204c"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "daa54c05faeb42f8b108905ae30eb5bb"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "072017f5299d4c85a4f54fbe34ae344f"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8ab15221c29c4b26a72f791d2f16fa41"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "cda513e07ffc47ffbe6d39b7db49a833"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c2beba96514345648cf43cb638583d52"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# GSM8K, \"main\" configuration: 7,473 train / 1,319 test rows, each with\n",
        "# `question` and `answer` fields.\n",
        "ds = load_dataset(\"gsm8k\", \"main\")\n",
        "\n",
        "# Optional held-out validation split (disabled):\n",
        "# split = ds[\"train\"].train_test_split(test_size=0.05, seed=42)\n",
        "# train_ds = split[\"train\"]\n",
        "# val_ds   = split[\"test\"]   # your “validation”\n",
        "train_ds, test_ds = ds[\"train\"], ds[\"test\"]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yWg7PjfnRTx7",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yWg7PjfnRTx7",
        "outputId": "45b7d4e0-4acb-40bf-a221-ab55cd18854e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 7473\n",
            "    })\n",
            "    test: Dataset({\n",
            "        features: ['question', 'answer'],\n",
            "        num_rows: 1319\n",
            "    })\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "# Show the available splits with their columns and row counts.\n",
        "print(ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZyDraTzXfx9S",
      "metadata": {
        "id": "ZyDraTzXfx9S"
      },
      "outputs": [],
      "source": [
        "import unicodedata\n",
        "\n",
        "def to_ascii_numbers(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Rewrite every Unicode numeric character of `text` as ASCII digits.\n",
        "\n",
        "    - Characters with a digit value become that single digit (１９３０ → 1930, ① → 1).\n",
        "    - Other characters with an integer numeric value become its decimal form (Ⅻ → 12).\n",
        "    - Characters with a non-integer numeric value (e.g. ½) and all non-numeric\n",
        "      characters are kept as they are.\n",
        "    - Finally, a leading \"Paraphrase: \" prefix is removed from the result.\n",
        "    \"\"\"\n",
        "    def _ascii_for(ch: str) -> str:\n",
        "        # First choice: the character's digit value (0..9).\n",
        "        try:\n",
        "            return str(unicodedata.digit(ch))\n",
        "        except (TypeError, ValueError):\n",
        "            pass\n",
        "\n",
        "        # Second choice: a whole-number numeric value; anything else is left alone.\n",
        "        try:\n",
        "            value = unicodedata.numeric(ch)\n",
        "        except (TypeError, ValueError):\n",
        "            return ch\n",
        "        return str(int(value)) if float(value).is_integer() else ch\n",
        "\n",
        "    converted = \"\".join(map(_ascii_for, text))\n",
        "    return converted.removeprefix(\"Paraphrase: \")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "321VfgHiTxzV",
      "metadata": {
        "id": "321VfgHiTxzV"
      },
      "outputs": [],
      "source": [
        "# Speed knobs you can tune safely on a T4:\n",
        "# Paraphrases to generate per question: one fewer than PARAPHRASES_PER_QUESTION,\n",
        "# never negative. Presumably the original question fills the remaining slot -- confirm.\n",
        "PARA_PER_QUESTION = max(0, PARAPHRASES_PER_QUESTION - 1)  # e.g., 2\n",
        "BATCH_QUESTIONS = 64   # try 32–128 on T4; lower if you OOM\n",
        "# Upper bound on generated tokens per paraphrase (passed to generate() as max_new_tokens).\n",
        "MAX_NEW_TOKENS_PARAPHRASE = 128  # shorter is faster; you can try 16–32\n",
        "# NOTE(review): paraphrase_batch_chat() does not appear to read this flag -- confirm it is used.\n",
        "USE_GREEDY = True      # greedy is faster & cleaner for paraphrasing\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "import math, torch, numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZBryMmXepKoN",
      "metadata": {
        "id": "ZBryMmXepKoN"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import torch\n",
        "from transformers import StoppingCriteria, StoppingCriteriaList\n",
        "\n",
        "# Identity stand-in, installed only when no to_ascii_numbers() helper has been\n",
        "# defined earlier in the session.\n",
        "if \"to_ascii_numbers\" not in globals():\n",
        "    def to_ascii_numbers(s: str) -> str:\n",
        "        \"\"\"Return `s` unchanged.\"\"\"\n",
        "        return s\n",
        "\n",
        "class StopOnTokenSeq(StoppingCriteria):\n",
        "    \"\"\"Per-row stopping criterion: a row is done once it ends with any stop string.\n",
        "\n",
        "    Each stop string is tokenized once at construction (without special tokens);\n",
        "    at every generation step the end of each row is compared with those token\n",
        "    sequences.\n",
        "    \"\"\"\n",
        "    def __init__(self, tok, stop_strings):\n",
        "        self.tok = tok\n",
        "        self.stop_ids = [tok.encode(s, add_special_tokens=False) for s in stop_strings]\n",
        "        self.maxlen = max((len(x) for x in self.stop_ids), default=1)\n",
        "\n",
        "    def __call__(self, input_ids, scores, **kwargs):\n",
        "        \"\"\"Return a bool tensor of shape (batch,) flagging the rows that hit a stop sequence.\n",
        "\n",
        "        A single Python bool would end generation for the whole batch as soon as\n",
        "        one row matched, cutting every other row short.\n",
        "        \"\"\"\n",
        "        done = []\n",
        "        for row in input_ids.tolist():\n",
        "            # A stop string that tokenizes to nothing can never match.\n",
        "            done.append(any(\n",
        "                len(seq) > 0 and row[-len(seq):] == seq\n",
        "                for seq in self.stop_ids\n",
        "            ))\n",
        "        return torch.tensor(done, dtype=torch.bool, device=input_ids.device)\n",
        "\n",
        "\n",
        "def paraphrase_batch_chat(qs, k=1):\n",
        "    \"\"\"\n",
        "    Qwen2/2.5 chat-friendly batch paraphraser.\n",
        "\n",
        "    Renders one chat prompt per question, generates `k` continuations per prompt\n",
        "    with the global `baseline_model`, keeps the first non-empty line of each\n",
        "    continuation and passes it through `to_ascii_numbers`.\n",
        "\n",
        "    - Uses LEFT padding (required for decoder-only batched generation).\n",
        "    - Generation ends at the tokenizer's EOS token or after\n",
        "      MAX_NEW_TOKENS_PARAPHRASE new tokens.\n",
        "    - k == 1 decodes greedily; k > 1 samples (temperature=0.3, top_p=0.3),\n",
        "      because greedy search cannot return several sequences per prompt.\n",
        "    - Decodes with specials skipped to avoid <|endoftext|> noise.\n",
        "    - An empty paraphrase falls back to the original question text.\n",
        "\n",
        "    Args:\n",
        "        qs: sequence of question strings.\n",
        "        k: number of paraphrases per question (>= 1).\n",
        "\n",
        "    Returns: List[List[str]] of shape [len(qs)][k]\n",
        "    \"\"\"\n",
        "    assert k >= 1\n",
        "    if not qs:\n",
        "        # Nothing to paraphrase: skip the tokenizer/model round trip entirely.\n",
        "        return []\n",
        "\n",
        "    system_prompt = (\n",
        "      \"You paraphrase GSM8K math word problems.\\n\"\n",
        "      \"Rephrase without changing meaning.\\n\"\n",
        "      \"Preserve all names, quantities, units, and numbers exactly.\\n\"\n",
        "      \"Do NOT solve the problem.\\n\"\n",
        "      \"Output ONLY the paraphrased question, one line, no quotes, no prefix.\"\n",
        "    )\n",
        "\n",
        "    # --- Build prompts (once per question) ---\n",
        "    rendered = []\n",
        "    for q in qs:\n",
        "        msgs = [\n",
        "            {\"role\": \"system\", \"content\": system_prompt},\n",
        "            {\"role\": \"user\", \"content\": f\"Paraphrase the following question.\\n{q}\"},\n",
        "        ]\n",
        "        rendered.append(tokenizer.apply_chat_template(\n",
        "            msgs, tokenize=False, add_generation_prompt=True\n",
        "        ))\n",
        "\n",
        "    # --- Encode with LEFT padding for decoder-only models ---\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(rendered, return_tensors=\"pt\", padding=True, truncation=True).to(baseline_model.device)\n",
        "    finally:\n",
        "        # Always restore the caller's padding side, even if encoding fails.\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    B = len(qs)\n",
        "    # With left padding every prompt ends at the same column, so a single cut\n",
        "    # index separates prompt from continuation for all rows.\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- Generation config (Qwen) ---\n",
        "    EOT_ID = tokenizer.eos_token_id\n",
        "    assert isinstance(EOT_ID, int) and EOT_ID >= 0, \"Could not resolve <|im_end|> id\"\n",
        "    # Explicit None check instead of `or`: a pad id of 0 is valid but falsy.\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else EOT_ID\n",
        "\n",
        "    # Respect external bad list, but make sure it doesn't block EOS\n",
        "    try:\n",
        "        bad_words_ids = bad if (isinstance(bad, list) and bad) else None\n",
        "        if bad_words_ids:\n",
        "            assert all(EOT_ID not in seq for seq in bad_words_ids), \"bad_words_ids blocks <|im_end|>\"\n",
        "    except NameError:\n",
        "        # No global `bad` list defined in this session.\n",
        "        bad_words_ids = None\n",
        "\n",
        "    # # Stop if the model begins the next user message (Qwen header)\n",
        "    # stopping = StoppingCriteriaList([\n",
        "    #     StopOnTokenSeq(tokenizer, [\"<|im_start|>user\", \"\\nuser\"])\n",
        "    # ])\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        max_new_tokens=MAX_NEW_TOKENS_PARAPHRASE,\n",
        "        min_new_tokens=0,                    # don't force past EOS\n",
        "        num_beams=1,\n",
        "        num_return_sequences=k,\n",
        "        bad_words_ids=bad_words_ids,\n",
        "        eos_token_id=EOT_ID,                 # Qwen chat EOT\n",
        "        pad_token_id=pad_id,\n",
        "        # stopping_criteria=stopping,\n",
        "        return_dict_in_generate=True,\n",
        "    )\n",
        "    if k > 1:\n",
        "        # Greedy search with num_beams=1 rejects num_return_sequences > 1,\n",
        "        # so several paraphrases per question require sampling.\n",
        "        gen_kwargs.update(do_sample=True, temperature=0.3, top_p=0.3)\n",
        "    else:\n",
        "        # Sampling parameters are ignored in greedy mode, so they are not passed.\n",
        "        gen_kwargs[\"do_sample\"] = False\n",
        "\n",
        "    with torch.no_grad():\n",
        "        out = baseline_model.generate(**enc, **gen_kwargs)\n",
        "\n",
        "    # --- Slice off continuations (after full padded input) ---\n",
        "    seqs = out.sequences                                # [B*k, cut_idx + new_len]\n",
        "    cont_tokens = [seqs[row, cut_idx:] for row in range(B * k)]\n",
        "\n",
        "    # --- Decode: skip specials to hide <|endoftext|>/<|im_end|> ---\n",
        "    outs = tokenizer.batch_decode(cont_tokens, skip_special_tokens=True)\n",
        "\n",
        "    # --- Cleanup: drop a naked role header if it slipped through ---\n",
        "    ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "    def clean(s: str) -> str:\n",
        "        # Strip first so a leading blank line does not hide the paraphrase.\n",
        "        s = s.strip().split(\"\\n\", 1)[0].strip()\n",
        "        if ROLE_ONLY.match(s):\n",
        "            s = \"\"\n",
        "        return to_ascii_numbers(s)\n",
        "\n",
        "    outs = [clean(o) for o in outs]                     # len = B*k\n",
        "\n",
        "    # Regroup the flat list (k consecutive rows per question); if a paraphrase\n",
        "    # came back empty, use the original question text instead.\n",
        "    per_q, ptr = [], 0\n",
        "    for q in qs:\n",
        "        group = outs[ptr:ptr + k]\n",
        "        group = [g if g else q for g in group]\n",
        "        per_q.append(group)\n",
        "        ptr += k\n",
        "\n",
        "    return per_q"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bPEoICWTndw_",
      "metadata": {
        "id": "bPEoICWTndw_"
      },
      "outputs": [],
      "source": [
        "# --- Answer generation for Qwen-style chat, batched & left-padded ---\n",
        "\n",
        "import re, math, torch\n",
        "from typing import List, Tuple\n",
        "\n",
        "# Matches a line that consists solely of a chat role name (\"assistant\" or\n",
        "# \"user\", any case), optionally followed by a colon. Compiled once.\n",
        "_ROLE_ONLY = re.compile(r'^(?:assistant|user)\\s*:?\\s*$', re.I)\n",
        "\n",
        "# Default system prompt for GSM8K answer generation.\n",
        "SYSTEM_PROMPT = (\n",
        "    \"You solve grade-school math word problems.\\n\"\n",
        "    \"Think step by step. Keep the reasoning brief.\\n\"\n",
        "    \"Then repeat the final answer on a new line in exactly this format:\\n\"\n",
        "    \"#### <number>\"\n",
        ")\n",
        "\n",
        "def _clean_one_line(s: str) -> str:\n",
        "    \"\"\"Return the stripped first line of `s`, or \"\" when that line is only a role header.\"\"\"\n",
        "    first_line = s.partition(\"\\n\")[0].strip()\n",
        "    if _ROLE_ONLY.match(first_line):\n",
        "        return \"\"\n",
        "    return first_line\n",
        "\n",
        "def gsm8k_messages(question: str,\n",
        "                   *,\n",
        "                   system_prompt = SYSTEM_PROMPT):\n",
        "    \"\"\"\n",
        "    Build the chat message list for one GSM8K question.\n",
        "\n",
        "    Returns a two-element list: a system message carrying `system_prompt`, then\n",
        "    a user message holding the question followed by the solve instruction.\n",
        "    \"\"\"\n",
        "    system_msg = {\"role\": \"system\", \"content\": system_prompt}\n",
        "    user_msg = {\n",
        "        \"role\": \"user\",\n",
        "        \"content\": f\"{question}\\n\\nSolve step by step, then give the final answer.\",\n",
        "    }\n",
        "    return [system_msg, user_msg]\n",
        "\n",
        "\n",
        "def format_prompt_instruct(q: str,\n",
        "                           *,\n",
        "                           system_prompt = SYSTEM_PROMPT,\n",
        "                           tokenizer=None) -> str:\n",
        "    \"\"\"\n",
        "    Render the chat prompt string for question `q`.\n",
        "\n",
        "    The messages come from gsm8k_messages() and are rendered with the\n",
        "    tokenizer's chat template (untokenized, with the generation prompt\n",
        "    appended). When `tokenizer` is None the module-level `tokenizer` is used.\n",
        "    \"\"\"\n",
        "    tok = globals()[\"tokenizer\"] if tokenizer is None else tokenizer\n",
        "    return tok.apply_chat_template(\n",
        "        gsm8k_messages(q, system_prompt=system_prompt),\n",
        "        tokenize=False,\n",
        "        add_generation_prompt=True,\n",
        "    )\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8jHm5ZuAHSaM",
      "metadata": {
        "id": "8jHm5ZuAHSaM"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "from typing import List, Tuple, Optional\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "def _extract_gsm8k_final(text: str) -> str:\n",
        "    \"\"\"\n",
        "    Find the last line that begins with ### (allow leading spaces),\n",
        "    and return the number immediately following ###.\n",
        "    If not found or no number after ###, return \"NA\".\n",
        "    \"\"\"\n",
        "    for line in reversed(text.splitlines()):\n",
        "        # number must immediately follow ### (allow spaces), then end or whitespace\n",
        "        m = re.match(r\"^\\s*####\\s*([-+]?\\d+(?:\\.\\d+)?)(?:\\s|$)\", line)\n",
        "        if m:\n",
        "            return m.group(1)\n",
        "        # if we see a ### line with no number, treat as missing and stop\n",
        "        if re.match(r\"^\\s*####(?:\\s|$)\", line):\n",
        "            return \"NA\"\n",
        "    return \"NA\"\n",
        "\n",
        "\n",
        "def answer_batch_chat(\n",
        "    qs: List[str],\n",
        "    *,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Generate one answer per input question, in a single batched call.\n",
        "    Uses format_prompt_instruct() to render the chat prompt string.\n",
        "    \"\"\"\n",
        "    # Fallback to globals if not provided\n",
        "    if model is None:\n",
        "        model = baseline_model\n",
        "    if tokenizer is None:\n",
        "        tokenizer = globals()[\"tokenizer\"]\n",
        "\n",
        "    # --- Build rendered prompts using format_prompt_instruct ---\n",
        "    rendered = [\n",
        "        format_prompt_instruct(q, tokenizer=tokenizer)\n",
        "        for q in qs\n",
        "    ]\n",
        "\n",
        "    # --- Encode with LEFT padding ---\n",
        "    old_side = tokenizer.padding_side\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    try:\n",
        "        enc = tokenizer(\n",
        "            rendered,\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        ).to(next(model.parameters()).device)\n",
        "    finally:\n",
        "        tokenizer.padding_side = old_side\n",
        "\n",
        "    cut_idx = enc[\"input_ids\"].size(1)\n",
        "\n",
        "    # --- EOS / PAD handling ---\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id\n",
        "\n",
        "    # --- Generation ---\n",
        "    with torch.inference_mode():\n",
        "        out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            min_new_tokens=0,\n",
        "            do_sample=not use_greedy,\n",
        "            temperature=(temperature if not use_greedy else None),\n",
        "            top_p=(top_p if not use_greedy else None),\n",
        "            num_beams=1,\n",
        "            num_return_sequences=1,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=pad_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "\n",
        "    # --- Slice off the prompt and decode just the continuation ---\n",
        "    seqs = out.sequences\n",
        "    cont = [seqs[i, cut_idx:] for i in range(len(qs))]\n",
        "    outs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n",
        "    # --- Cleanup: prefer GSM8K final answer if present, else one-line fallback ---\n",
        "    cleaned = []\n",
        "    for o in outs:\n",
        "        s = o.strip()\n",
        "\n",
        "        # If you're using GSM8K prompting, this will pull the final numeric answer.\n",
        "        gsm = _extract_gsm8k_final(s)\n",
        "        if gsm is not None:\n",
        "            cleaned.append(gsm)\n",
        "            continue\n",
        "    return cleaned\n",
        "\n",
        "\n",
        "def answer_originals_and_paraphrases(\n",
        "    qs: List[str],\n",
        "    paras_per_q: List[List[str]],\n",
        "    *,\n",
        "    micro_bs: int = 64,\n",
        "    model=None,\n",
        "    tokenizer=None,\n",
        "    max_new_tokens: int = MAX_NEW_TOKENS,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        "    use_greedy: bool = False,\n",
        ") -> Tuple[List[str], List[List[str]]]:\n",
        "    \"\"\"\n",
        "    Answers each original question and each of its paraphrases.\n",
        "    Returns (answers_for_originals, answers_for_paraphrases)\n",
        "    where answers_for_paraphrases[i] aligns with paras_per_q[i].\n",
        "    Uses micro-batching to avoid OOM.\n",
        "    \"\"\"\n",
        "\n",
        "    def _batched_answer(batch: List[str]) -> List[str]:\n",
        "        return answer_batch_chat(\n",
        "            batch,\n",
        "            model=model,\n",
        "            tokenizer=tokenizer,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            temperature=temperature,\n",
        "            top_p=top_p,\n",
        "            use_greedy=use_greedy,\n",
        "        )\n",
        "\n",
        "    # 1) Answers for originals\n",
        "    ans_orig: List[str] = []\n",
        "    for i in range(0, len(qs), micro_bs):\n",
        "        ans_orig.extend(_batched_answer(qs[i : i + micro_bs]))\n",
        "\n",
        "    # 2) Answers for all paraphrases (flatten, answer, then regroup)\n",
        "    flat_paras = [p for group in paras_per_q for p in group]\n",
        "    ans_paras_flat: List[str] = []\n",
        "    for i in range(0, len(flat_paras), micro_bs):\n",
        "        ans_paras_flat.extend(_batched_answer(flat_paras[i : i + micro_bs]))\n",
        "\n",
        "    # Regroup by question\n",
        "    ans_paras: List[List[str]] = []\n",
        "    ptr = 0\n",
        "    for group in paras_per_q:\n",
        "        n = len(group)\n",
        "        ans_paras.append(ans_paras_flat[ptr : ptr + n])\n",
        "        ptr += n\n",
        "\n",
        "    return ans_orig, ans_paras\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6hWe785vn5Uw",
      "metadata": {
        "id": "6hWe785vn5Uw"
      },
      "outputs": [],
      "source": [
        "def _to_text(x):\n",
        "    # Accept str, list[str], dict-like {\"text\": \"...\"} / {\"answer\": \"...\"} / [{\"text\": ...}]\n",
        "    if x is None:\n",
        "        return \"\"\n",
        "    if isinstance(x, str):\n",
        "        return x\n",
        "    if isinstance(x, (int, float)):\n",
        "        return str(x)\n",
        "    if isinstance(x, list):\n",
        "        # join multiple references; tweak as you like\n",
        "        return \"; \".join(_to_text(xx) for xx in x if xx is not None)\n",
        "    if isinstance(x, dict):\n",
        "        for k in (\"text\", \"answer\", \"final\", \"final_answer\", \"label\", \"gold\", \"target\"):\n",
        "            if k in x:\n",
        "                return _to_text(x[k])\n",
        "        # fallback: stringify dict\n",
        "        return str(x)\n",
        "    return str(x)\n",
        "\n",
        "def fetch_ground_truths(train_ds, idxs):\n",
        "    gts = []\n",
        "    for i in idxs:\n",
        "        rec = train_ds[i]\n",
        "        # Try common fields in order\n",
        "        val = None\n",
        "        for key in (\"answer\", \"answers\", \"final_answer\", \"final\", \"label\", \"target\"):\n",
        "            if key in rec:\n",
        "                val = rec[key]; break\n",
        "        gts.append(_to_text(val))\n",
        "    return gts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "t_pao7CWT7Kg",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "t_pao7CWT7Kg",
        "outputId": "874347be-ffae-4aaf-ade1-679f05737e4a"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [0]\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n\\nparas_per_q = paraphrase_batch_chat(qs, k=k_tmp)\\n\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 20
        }
      ],
      "source": [
        "# --- your snippet (kept) ---\n",
        "'''\n",
        "PARAPHRASES_PER_QUESTION_tmp = 2\n",
        "k_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\n",
        "\n",
        "# Keep the indices so we can fetch ground truths\n",
        "qs = []\n",
        "idxs = [0]\n",
        "for idx in idxs:\n",
        "  qs.append(train_ds[idx][\"question\"])\n",
        "\n",
        "paras_per_q = paraphrase_batch_chat(qs, k=k_tmp)\n",
        "\n",
        "# --- NEW: get model answers ---\n",
        "answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "# --- NEW: fetch ground-truths for the originals ---\n",
        "gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "# ---- Print nicely ----\n",
        "print(\"=== Test inputs ===\")\n",
        "for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "    print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "    print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "print(\"\\n=== Paraphrases + Answers ===\")\n",
        "for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "    print(f\"[Q{i}]\")\n",
        "    print(f\"  Original answer: {a_q}\")\n",
        "    print(f\"  Ground truth   : {gt}\")\n",
        "    if not paras:\n",
        "        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "    else:\n",
        "        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "            print(f\"  {j}. {p}\")\n",
        "            print(f\"     ↳ Answer: {ap}\")\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "lejUKL25Sf3U",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lejUKL25Sf3U",
        "outputId": "68db188a-bc2c-4b52-982e-c4eff2d58a56"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CACHE_PATH:  /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl\n"
          ]
        }
      ],
      "source": [
        "import os, json, time, hashlib, pathlib\n",
        "\n",
        "\n",
        "print(\"CACHE_PATH: \", CACHE_PATH)\n",
        "\n",
        "import os, time\n",
        "\n",
        "from pathlib import Path\n",
        "def _ensure_parent_dir(path: str | Path) -> Path:\n",
        "    p = Path(path)\n",
        "    p.parent.mkdir(parents=True, exist_ok=True)\n",
        "    print (\"Making directory: \", path)\n",
        "    return p\n",
        "\n",
        "\n",
        "def save_paraphrase_cache(path, classes_all, classes_usable, meta: dict):\n",
        "    path = _ensure_parent_dir(path)                    # NEW: ensure folder exists\n",
        "    meta = dict(meta)\n",
        "    meta[\"created_utc\"] = int(time.time())\n",
        "    meta[\"hash\"] = _meta_hash(**meta)\n",
        "\n",
        "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta}) + \"\\n\")\n",
        "        for i, cls in enumerate(classes_all):\n",
        "            rec = {\"idx\": i, \"size\": len(cls), \"prompts\": cls, \"usable\": int(len(cls) >= 2)}\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    print(f\"Saved cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "\n",
        "\n",
        "import uuid, os\n",
        "\n",
        "def save_paraphrase_cache_atomic(path, classes_all, classes_usable, meta):\n",
        "    p = _ensure_parent_dir(path)                       # NEW\n",
        "    tmp = p.with_name(p.name + f\".tmp-{uuid.uuid4().hex}\")  # same dir → safe rename\n",
        "    save_paraphrase_cache(tmp, classes_all, classes_usable, meta)\n",
        "    os.replace(tmp, p)                                 # atomic on POSIX\n",
        "\n",
        "\n",
        "def meta_compatible(current_meta: dict, cached_meta: dict | None) -> bool:\n",
        "    if not cached_meta:\n",
        "        return False\n",
        "    # Be strict about the knobs that affect outputs\n",
        "    keys = [\n",
        "        \"train_len\", \"per_class\", \"paraphraser_model\",\n",
        "        \"top_p\", \"top_k\", \"temperature\", \"max_new_tokens\",\n",
        "        \"dedup_thresh\", \"embedder\",\n",
        "    ]\n",
        "    return all(current_meta.get(k) == cached_meta.get(k) for k in keys)\n",
        "\n",
        "\n",
        "def _meta_hash(**kw):\n",
        "    # tie cache to settings that affect paraphrases\n",
        "    key = json.dumps(kw, sort_keys=True).encode()\n",
        "    return hashlib.sha1(key).hexdigest()[:10]\n",
        "\n",
        "def save_paraphrase_cache(path, classes_all, classes_usable, meta: dict):\n",
        "    path = pathlib.Path(path)\n",
        "    meta = dict(meta)\n",
        "    meta[\"created_utc\"] = int(time.time())\n",
        "    meta[\"hash\"] = _meta_hash(**meta)\n",
        "\n",
        "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta}) + \"\\n\")\n",
        "        # classes_all is a list[list[str]]\n",
        "        for i, cls in enumerate(classes_all):\n",
        "            rec = {\n",
        "                \"idx\": i,\n",
        "                \"size\": len(cls),\n",
        "                \"prompts\": cls,\n",
        "                \"usable\": int(len(cls) >= 2)\n",
        "            }\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    print(f\"Saved cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "\n",
        "def load_paraphrase_cache(path, require_hash: str | None = None):\n",
        "    path = pathlib.Path(path)\n",
        "    if not path.exists():\n",
        "        return None, None, None\n",
        "    classes_all, classes_usable = [], []\n",
        "    meta = None\n",
        "    with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "        for line in f:\n",
        "            obj = json.loads(line)\n",
        "            if \"__meta__\" in obj:\n",
        "                meta = obj[\"__meta__\"]\n",
        "                continue\n",
        "            cls = obj[\"prompts\"]\n",
        "            classes_all.append(cls)\n",
        "            if obj.get(\"usable\") and len(cls) >= 2:\n",
        "                classes_usable.append(cls)\n",
        "    if require_hash and meta and meta.get(\"hash\") != require_hash:\n",
        "        print(\"⚠️ Cache meta hash mismatch; rebuilding may be safer.\")\n",
        "    print(f\"Loaded cache: {path}  ({len(classes_all)} classes, usable={len(classes_usable)})\")\n",
        "    return classes_all, classes_usable, meta"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B07v0ZOKuI3F",
      "metadata": {
        "id": "B07v0ZOKuI3F"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "def build_classes_full_batched(\n",
        "    per_class=PARAPHRASES_PER_QUESTION,\n",
        "    dedup_thresh=0.98,\n",
        "    do_dedup=False,   # <- new flag\n",
        "    cache_path: str | None = None,\n",
        "    flush_every: int = 1000,\n",
        "    generate_fn=paraphrase_batch_chat,\n",
        "):\n",
        "    \"\"\"\n",
        "    Process the ENTIRE train_ds using big, batched generation calls.\n",
        "    Periodically checkpoints to `cache_path` and resumes if compatible.\n",
        "    Returns (classes_all, classes_usable).\n",
        "    \"\"\"\n",
        "    k = max(0, per_class - 1)\n",
        "    n_total = len(train_ds)\n",
        "\n",
        "    # --- try to resume from a compatible cache ---\n",
        "    start_at = 0\n",
        "    classes_all, classes_usable = [], []\n",
        "    cached = (None, None, None)\n",
        "\n",
        "    if cache_path:\n",
        "        try:\n",
        "            cached = load_paraphrase_cache(cache_path)\n",
        "        except Exception:\n",
        "            cached = (None, None, None)\n",
        "\n",
        "    cached_all, cached_usable, cached_meta = cached\n",
        "    if (\n",
        "        cache_path\n",
        "        and cached_all is not None\n",
        "        and 0 < len(cached_all) <= n_total\n",
        "        and meta_compatible(META, cached_meta)\n",
        "    ):\n",
        "        classes_all = list(cached_all)\n",
        "        classes_usable = list(cached_usable or [])\n",
        "        start_at = len(classes_all)\n",
        "        print(f\"[Resume] Compatible cache found: {start_at}/{n_total} rows; resuming…\")\n",
        "    else:\n",
        "        if cache_path:\n",
        "            print(\"[Resume] No compatible partial cache; starting fresh.\")\n",
        "\n",
        "    # --- Put embedder on GPU if possible ---\n",
        "    try:\n",
        "        embedder.to(\"cuda\")\n",
        "    except Exception:\n",
        "        pass\n",
        "\n",
        "    pbar = tqdm(total=n_total, desc=\"Building (batched) classes\", leave=True)\n",
        "    pbar.update(start_at)\n",
        "    since_last_flush = 0\n",
        "\n",
        "    try:\n",
        "        for start in range(start_at, n_total, BATCH_QUESTIONS):\n",
        "            end = min(start + BATCH_QUESTIONS, n_total)\n",
        "            batch = train_ds[start:end]\n",
        "            qs = batch[\"question\"] # <-- list of questions\n",
        "            # Fast column slice (HF datasets return list for a column slice)\n",
        "            # qs = train_ds[\"question\"][start:end]\n",
        "\n",
        "            # ---- fast batched paraphrase generation ----\n",
        "            if k > 0:\n",
        "                paras_per_q = generate_fn(qs, k=k)\n",
        "            else:\n",
        "                paras_per_q = [[] for _ in qs]\n",
        "\n",
        "            # ---- dedup per question (batched embeddings for speed) ----\n",
        "            for i, q in enumerate(qs):\n",
        "                paras = [q] + paras_per_q[i]\n",
        "\n",
        "                if do_dedup:\n",
        "                    # evaluate paraphrases by cosine value in embedding space. Discard those that are not novel enough.\n",
        "                    embs = embedder.encode(\n",
        "                        paras, convert_to_tensor=True,\n",
        "                        normalize_embeddings=True, batch_size=64\n",
        "                    )\n",
        "\n",
        "\n",
        "                    # evaluate paraphrases by cosine value in embedding space. Discard those that are not novel enough.\n",
        "                    keep = [0]\n",
        "                    for j in range(1, len(paras)):\n",
        "                        sims = sbert_util.cos_sim(embs[j], embs[keep]).cpu().numpy().flatten()\n",
        "                        if float(np.max(sims)) < dedup_thresh:\n",
        "                            keep.append(j)\n",
        "\n",
        "                    cls = [paras[j] for j in keep]\n",
        "                else:\n",
        "                    cls = paras[:]  # original + all paraphrases, in order, no discard\n",
        "\n",
        "                classes_all.append(cls)\n",
        "                if len(cls) >= 2:\n",
        "                    classes_usable.append(cls)\n",
        "\n",
        "            processed = end - start\n",
        "            since_last_flush += processed\n",
        "            pbar.update(processed)\n",
        "            pbar.set_postfix({\"usable\": len(classes_usable)})\n",
        "\n",
        "            # ---- periodic checkpoint ----\n",
        "            if cache_path and since_last_flush >= flush_every:\n",
        "                save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "                since_last_flush = 0\n",
        "\n",
        "        # final save\n",
        "        if cache_path:\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "\n",
        "    except KeyboardInterrupt:\n",
        "        if cache_path:\n",
        "            print(\"\\n[Interrupt] Saving partial progress…\")\n",
        "            save_paraphrase_cache_atomic(cache_path, classes_all, classes_usable, META)\n",
        "        raise\n",
        "    finally:\n",
        "        pbar.close()\n",
        "\n",
        "    return classes_all, classes_usable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QM6pwSm3TAyW",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QM6pwSm3TAyW",
        "outputId": "a572db7f-ff87-45ad-cfe6-bc7e6c4f889a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded cache: /content/drive/MyDrive/coherence-GSM8K/Qwen/Qwen2.5-3B-Instruct/paraphrase.jsonl  (7473 classes, usable=7473)\n",
            "Total classes: 7473; usable: 7473\n"
          ]
        }
      ],
      "source": [
        "# Describe the paraphrase-generation settings so caches are consistent\n",
        "META = {\n",
        "    \"train_len\": len(train_ds),\n",
        "    \"per_class\": PARAPHRASES_PER_QUESTION,\n",
        "    \"paraphraser_model\": BASELINE_MODEL_NAME,\n",
        "    \"top_p\": TOP_P, \"top_k\": TOP_K, \"temperature\": GEN_TEMPERATURE,\n",
        "    \"max_new_tokens\": 48,\n",
        "    \"dedup_thresh\": 0.98,\n",
        "    \"embedder\": \"sentence-transformers/all-MiniLM-L6-v2\",\n",
        "}\n",
        "\n",
        "# Try to load cache first\n",
        "paraphrase_classes_all, paraphrase_classes, meta = load_paraphrase_cache(CACHE_PATH)\n",
        "need_build = paraphrase_classes_all is None or len(paraphrase_classes_all) != len(train_ds)\n",
        "\n",
        "if need_build:\n",
        "    print(\"No valid cache found — building now…\")\n",
        "    _ensure_parent_dir(CACHE_PATH)\n",
        "    paraphrase_classes_all, paraphrase_classes = build_classes_full_batched(\n",
        "        per_class=PARAPHRASES_PER_QUESTION,\n",
        "        dedup_thresh=0.98,\n",
        "        cache_path=CACHE_PATH,      # enables resume & periodic saves\n",
        "        flush_every=1000,           # save every 1000 rows (tweak as you like)\n",
        "        generate_fn=paraphrase_batch_chat,\n",
        "    )\n",
        "else:\n",
        "    # Optional: check compatibility (hash) if you want strict matching\n",
        "    pass\n",
        "\n",
        "print(f\"Total classes: {len(paraphrase_classes_all)}; usable: {len(paraphrase_classes)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GSdAyKG47uP8",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GSdAyKG47uP8",
        "outputId": "e3e9eedd-0e22-484b-f7ad-8f8d12dc9e50"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\nPARAPHRASES_PER_QUESTION_tmp = 2\\nk_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\\n\\n# Keep the indices so we can fetch ground truths\\nqs = []\\nidxs = [1000, 1300]\\nparas_per_q = []\\nfor idx in idxs:\\n  qs.append(train_ds[idx][\"question\"])\\n  paras_per_q.append([paraphrase_classes[idx][1]])\\nprint(qs)\\nprint(paras_per_q)\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 24
        }
      ],
      "source": [
        "# --- your snippet (kept) ---\n",
        "'''\n",
        "PARAPHRASES_PER_QUESTION_tmp = 2\n",
        "k_tmp = max(0, PARAPHRASES_PER_QUESTION_tmp - 1)\n",
        "\n",
        "# Keep the indices so we can fetch ground truths\n",
        "qs = []\n",
        "idxs = [1000, 1300]\n",
        "paras_per_q = []\n",
        "for idx in idxs:\n",
        "  qs.append(train_ds[idx][\"question\"])\n",
        "  paras_per_q.append([paraphrase_classes[idx][1]])\n",
        "print(qs)\n",
        "print(paras_per_q)\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "gdXynjcC9BGw",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gdXynjcC9BGw",
        "outputId": "66175f30-6d82-4d53-b411-aa42fa6dc5c3"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n# --- NEW: get model answers ---\\nanswers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\\n\\n# --- NEW: fetch ground-truths for the originals ---\\ngt_for_qs = fetch_ground_truths(train_ds, idxs)\\n\\n# ---- Print nicely ----\\nprint(\"=== Test inputs ===\")\\nfor i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\\n    print(f\"[Q{i} @ idx={idx}] {q}\")\\n    print(f\"    Ground truth: {gt}\")\\n\\nprint(\"\\n=== Paraphrases + Answers ===\")\\nfor i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\\n    print(f\"[Q{i}]\")\\n    print(f\"  Original answer: {a_q}\")\\n    print(f\"  Ground truth   : {gt}\")\\n    if not paras:\\n        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\\n    else:\\n        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\\n            print(f\"  {j}. {p}\")\\n            print(f\"     ↳ Answer: {ap}\")\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 25
        }
      ],
      "source": [
        "'''\n",
        "# --- NEW: get model answers ---\n",
        "answers_for_qs, answers_for_paras = answer_originals_and_paraphrases(qs, paras_per_q, micro_bs=64)\n",
        "\n",
        "# --- NEW: fetch ground-truths for the originals ---\n",
        "gt_for_qs = fetch_ground_truths(train_ds, idxs)\n",
        "\n",
        "# ---- Print nicely ----\n",
        "print(\"=== Test inputs ===\")\n",
        "for i, (idx, q, gt) in enumerate(zip(idxs, qs, gt_for_qs), 1):\n",
        "    print(f\"[Q{i} @ idx={idx}] {q}\")\n",
        "    print(f\"    Ground truth: {gt}\")\n",
        "\n",
        "print(\"\\n=== Paraphrases + Answers ===\")\n",
        "for i, (q, paras, a_q, a_paras, gt) in enumerate(zip(qs, paras_per_q, answers_for_qs, answers_for_paras, gt_for_qs), 1):\n",
        "    print(f\"[Q{i}]\")\n",
        "    print(f\"  Original answer: {a_q}\")\n",
        "    print(f\"  Ground truth   : {gt}\")\n",
        "    if not paras:\n",
        "        print(\"  (no paraphrases requested; set PARAPHRASES_PER_QUESTION >= 2)\")\n",
        "    else:\n",
        "        for j, (p, ap) in enumerate(zip(paras, a_paras), 1):\n",
        "            print(f\"  {j}. {p}\")\n",
        "            print(f\"     ↳ Answer: {ap}\")\n",
        "'''"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SfSfLvyWQ9aB",
      "metadata": {
        "id": "SfSfLvyWQ9aB"
      },
      "outputs": [],
      "source": [
        "FLAG_DEBUG = True"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ywbhUshonDfJ",
      "metadata": {
        "id": "ywbhUshonDfJ"
      },
      "source": [
        "## Evaluations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "B5VVKuJrDVKo",
      "metadata": {
        "id": "B5VVKuJrDVKo"
      },
      "outputs": [],
      "source": [
        "import re, math\n",
        "import torch\n",
        "from typing import Dict, Any\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "def _wilson_ci(k: int, n: int, level: float = 0.95):\n",
        "    if n == 0:\n",
        "        return (0.0, 0.0)\n",
        "    z = 1.959963984540054  # 95%\n",
        "    phat = k / n\n",
        "    denom = 1 + (z*z)/n\n",
        "    center = (phat + (z*z)/(2*n)) / denom\n",
        "    half = (z / denom) * math.sqrt((phat*(1-phat) + (z*z)/(4*n)) / n)\n",
        "    return max(0.0, center - half), min(1.0, center + half)\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_gsm8k(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"eval\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of sample Q/pred/gold to print total\n",
        "    seed=None,               # optional int seed for a reproducible subsample\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Evaluate exact-match accuracy on GSM8K with a tqdm progress bar.\n",
        "\n",
        "    Draws `limit` examples without replacement (all of them when `limit` is\n",
        "    None), answers them in batches with `answer_batch_chat`, and compares each\n",
        "    prediction to the gold produced by `_extract_gsm8k_final`. A prediction\n",
        "    equal to \"NA\" never counts as correct.\n",
        "\n",
        "    Args:\n",
        "      model, tokenizer: forwarded to `answer_batch_chat`.\n",
        "      dataset: indexable collection of dict examples with a \"question\" key\n",
        "               and an optional \"answer\" key.\n",
        "      limit: number of examples to sample; None means the whole dataset.\n",
        "      batch_size: number of questions per generation call (must be >= 1).\n",
        "      max_new_tokens, use_greedy: forwarded to `answer_batch_chat`.\n",
        "      name: label used in the progress bar and in printed lines.\n",
        "      show_samples: number of Q/pred/gold samples to print across the run.\n",
        "      seed: if given, the subsample is drawn from `random.Random(seed)` so the\n",
        "            run is reproducible; otherwise the global `random` state is used.\n",
        "\n",
        "    Prints:\n",
        "      - tqdm progress bar over batches (with running accuracy)\n",
        "      - optional few sample outputs (total across run)\n",
        "      - final ACC + 95% CI and the count of \"NA\" predictions\n",
        "\n",
        "    Returns result dict with preds + sample + ACC/CI + \"pred_NA\" count.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if batch_size < 1.\n",
        "    \"\"\"\n",
        "    import random  # imported locally: this cell's import block does not provide it\n",
        "\n",
        "    if batch_size < 1:\n",
        "        raise ValueError(f\"batch_size must be >= 1, got {batch_size}\")\n",
        "\n",
        "    def _is_correct(pred, gold) -> bool:\n",
        "        # Exact match after trimming; \"NA\" marks a failed answer extraction.\n",
        "        pred = (pred or \"\").strip()\n",
        "        gold = (gold or \"\").strip()\n",
        "        return pred == gold and pred != \"NA\"\n",
        "\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    pbar = None\n",
        "\n",
        "    try:\n",
        "        N_total = len(dataset)\n",
        "        N = min(int(limit), N_total) if limit is not None else N_total\n",
        "        N = max(N, 0)  # a negative limit means \"evaluate nothing\"\n",
        "\n",
        "        rng = random.Random(seed) if seed is not None else random\n",
        "        sampled_indices = rng.sample(range(N_total), k=N)\n",
        "        qs, gold_list, preds = [], [], []\n",
        "\n",
        "        # Collect questions + golds\n",
        "        for i in sampled_indices:\n",
        "            ex = dataset[i]\n",
        "            qs.append(ex[\"question\"])\n",
        "            gold_list.append(_extract_gsm8k_final(ex.get(\"answer\", \"\")))\n",
        "\n",
        "        num_batches = (N + batch_size - 1) // batch_size\n",
        "        pbar = tqdm(total=num_batches, desc=f\"Evaluating {name}\", dynamic_ncols=True)\n",
        "\n",
        "        printed = 0\n",
        "        running_correct = 0\n",
        "        for start in range(0, N, batch_size):\n",
        "            end = min(start + batch_size, N)\n",
        "            batch_qs = qs[start:end]\n",
        "\n",
        "            batch_preds = answer_batch_chat(\n",
        "                batch_qs,\n",
        "                model=model,\n",
        "                tokenizer=tokenizer,\n",
        "                max_new_tokens=max_new_tokens,\n",
        "                use_greedy=use_greedy,\n",
        "            )\n",
        "            preds.extend(batch_preds)\n",
        "\n",
        "            # Running accuracy shown on the bar (tracked incrementally).\n",
        "            running_correct += sum(\n",
        "                _is_correct(p, gold_list[start + j]) for j, p in enumerate(batch_preds)\n",
        "            )\n",
        "            running_n = len(preds)\n",
        "            running_acc = running_correct / running_n if running_n else 0.0\n",
        "\n",
        "            pbar.set_postfix({\n",
        "                \"N\": running_n,\n",
        "                \"ACC%\": f\"{running_acc*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(1)\n",
        "\n",
        "            # Optional: print a few samples total across run\n",
        "            if show_samples > 0 and printed < show_samples:\n",
        "                k = min(show_samples - printed, len(batch_qs))\n",
        "                for j in range(k):\n",
        "                    idx = start + j\n",
        "                    q_preview = batch_qs[j].replace(\"\\n\", \" \")\n",
        "                    if len(q_preview) > 160:\n",
        "                        q_preview = q_preview[:160] + \"...\"\n",
        "                    pbar.write(f\"[{name}] idx={idx} Q: {q_preview}\")\n",
        "                    pbar.write(f\"[{name}]        pred: {batch_preds[j]}\")\n",
        "                    pbar.write(f\"[{name}]        gold: {gold_list[idx]}\")\n",
        "                    printed += 1\n",
        "\n",
        "        pbar.close()\n",
        "\n",
        "        # Final exact-match accuracy and count of unparseable predictions\n",
        "        correct = 0\n",
        "        pred_na = 0\n",
        "        for p, g in zip(preds, gold_list):\n",
        "            if (p or \"\").strip() == \"NA\":\n",
        "                pred_na += 1\n",
        "            if _is_correct(p, g):\n",
        "                correct += 1\n",
        "\n",
        "        n = len(preds)\n",
        "        acc_mean = correct / n if n else 0.0\n",
        "        acc_lo, acc_hi = _wilson_ci(correct, n, level=0.95)\n",
        "\n",
        "        pred_na_frac = (pred_na / n) if n else 0.0\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n}  \"\n",
        "            f\"ACC: {acc_mean*100:.2f} and ACC confidence \"\n",
        "            f\"[{acc_lo*100:.2f}, {acc_hi*100:.2f}]  |  \"\n",
        "            f\"pred NA: {pred_na}/{n} ({pred_na_frac*100:.2f}%)  \"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs,\n",
        "                \"golds\": gold_list,\n",
        "            },\n",
        "            \"N\": len(preds),\n",
        "            \"ACC\": acc_mean,\n",
        "            \"ACC_CI\": (acc_lo, acc_hi),\n",
        "            \"ACC_CI_Level\": 0.95,\n",
        "            \"pred_NA\": pred_na,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        # Always release the bar and restore the caller's train/eval mode,\n",
        "        # even if generation raises part-way through.\n",
        "        if pbar is not None:\n",
        "            pbar.close()\n",
        "        model.train(was_training)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "O1w08WfY85lE",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 103,
          "referenced_widgets": [
            "630f910e53c848eb837d2b38607fa849",
            "9b19089b3ef840a2b0df283d9ae9abef",
            "9e2cb734212a4abe9ecf419d3e673223",
            "d0869f15627f48d89e1fff45de94aafd",
            "b3fdeedd557a4743beb7038c796ecc0f",
            "e550a8f4bf6e4a258ce3cdd39981e076",
            "82ce5afcd1ff42318c5caf3aaa03f753",
            "fa91884fa4bb45658e3c905c83b5a592",
            "ea6a755fa09d42fe9faaa9d756d3bf7f",
            "a5f645314be841a88da3af297c88f5ee",
            "116dc2b6474e4266a2a3aa71365627d5"
          ]
        },
        "id": "O1w08WfY85lE",
        "outputId": "d540398e-e2db-46b8-9f53-7428143b6140"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Evaluation limit: 1319\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "630f910e53c848eb837d2b38607fa849",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating phi3 greedy:   0%|          | 0/21 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[phi3 greedy] N=1319  ACC: 73.69 and ACC confidence [71.25, 76.00]\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# Greedy evaluation of trainable_model on the full test split.\n",
        "# NOTE(review): the run name says \"phi3\" while the notebook title says Qwen — confirm which model this is.\n",
        "# NOTE(review): MAX_NEW_TOKENS must already be defined by an earlier cell; later cells reassign it to 24.\n",
        "eval_limit = len(test_ds)\n",
        "print(\"Evaluation limit:\", eval_limit)\n",
        "\n",
        "res = evaluate_gsm8k(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,\n",
        "    limit=eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"phi3 greedy\",\n",
        "    use_greedy=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ypXVQeE_zA1V",
      "metadata": {
        "id": "ypXVQeE_zA1V"
      },
      "outputs": [],
      "source": [
        "# Source - https://stackoverflow.com/a\n",
        "# Posted by Ayo Reis\n",
        "# Retrieved 2025-12-30, License - CC BY-SA 4.0\n",
        "\n",
        "# Releases the Colab runtime once the evaluation above has finished.\n",
        "# NOTE(review): this sits mid-notebook, so under Run All the cells below presumably never execute — confirm this is intended.\n",
        "from google.colab import runtime\n",
        "runtime.unassign()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "BbHwE2vv1Xu5",
      "metadata": {
        "id": "BbHwE2vv1Xu5"
      },
      "source": [
        "### Best of N"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "C-BdL7ax1WXq",
      "metadata": {
        "id": "C-BdL7ax1WXq"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "from typing import Tuple\n",
        "\n",
        "def _mean_logprob_of_generated(\n",
        "    model, input_ids: torch.Tensor, attention_mask: torch.Tensor, input_len: int\n",
        ") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Compute mean log-prob over generated tokens for each sequence in the batch.\n",
        "    input_ids: [B, T] full prompt+generation\n",
        "    attention_mask: [B, T]\n",
        "    input_len: scalar length of the prompt (assumed same across the batch)\n",
        "    Returns: [B] mean log-probability per sequence (higher is better).\n",
        "    \"\"\"\n",
        "    # Forward once to score tokens\n",
        "    out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
        "    logits = out.logits  # [B, T, V]\n",
        "    logprobs = torch.log_softmax(logits, dim=-1)\n",
        "\n",
        "    # We want log p(y_t | x, y_<t>) for generated tokens.\n",
        "    # Token at position t (0-index) is predicted by logits at position t-1.\n",
        "    # So we gather from logprobs[:, :-1] using input_ids[:, 1:].\n",
        "    pred_logprobs = logprobs[:, :-1].gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # [B, T-1]\n",
        "\n",
        "    B, T = input_ids.shape\n",
        "    gen_start = input_len  # first generated token index in input_ids\n",
        "    # Sequence lengths may differ due to early EOS; rely on attention_mask to know actual used tokens.\n",
        "    # We'll mask everything before gen_start and after the last real token.\n",
        "    valid_mask = attention_mask[:, 1:]  # [B, T-1] tokens that exist (shifted)\n",
        "    gen_mask = torch.zeros_like(valid_mask)\n",
        "    gen_mask[:, gen_start:] = 1  # mark generated region (shifted space)\n",
        "\n",
        "    mask = (valid_mask * gen_mask).to(pred_logprobs.dtype)  # [B, T-1]\n",
        "    # Sum and average over generated tokens actually present\n",
        "    token_counts = mask.sum(dim=-1).clamp_min(1.0)  # avoid /0\n",
        "    seq_logprob_sum = (pred_logprobs * mask).sum(dim=-1)  # [B]\n",
        "    mean_logprob = seq_logprob_sum / token_counts  # [B]\n",
        "    return mean_logprob  # higher is better (less negative)\n",
        "\n",
        "\n",
        "def _pick_first_line(text: str) -> str:\n",
        "    return text.strip().split(\"\\n\")[0].strip()\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def batched_generate_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    questions: List[str],\n",
        "    best_of: int = 8,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    \"\"\"\n",
        "    Best-of-N sampling: for each question, draw `best_of` samples, score each by mean token log-prob,\n",
        "    and return the single highest-prob answer (first line) per question.\n",
        "\n",
        "    Args:\n",
        "      model, tokenizer: model used for both sampling and scoring, and its tokenizer.\n",
        "      questions: raw questions; each is wrapped with `format_prompt` before tokenization.\n",
        "      best_of: number of samples drawn per question (must be >= 1).\n",
        "      batch_size: number of questions per `generate` call.\n",
        "      max_new_tokens, temperature, top_p: sampling settings passed to `generate`.\n",
        "\n",
        "    Returns:\n",
        "      One answer string per question, in input order.\n",
        "    \"\"\"\n",
        "    assert best_of >= 1\n",
        "    model_device = next(model.parameters()).device\n",
        "    eos_id = tokenizer.eos_token_id\n",
        "    results = []\n",
        "\n",
        "    for i in range(0, len(questions), batch_size):\n",
        "        batch_q = questions[i:i+batch_size]\n",
        "        prompts = [format_prompt(q) for q in batch_q]\n",
        "        enc = tokenizer(prompts, padding=True, return_tensors=\"pt\").to(model_device)\n",
        "\n",
        "        input_len = enc[\"input_ids\"].shape[1]\n",
        "        # Sample best_of sequences per prompt in a single call\n",
        "        gen_out = model.generate(\n",
        "            **enc,\n",
        "            max_new_tokens=max_new_tokens,\n",
        "            do_sample=True,\n",
        "            temperature=temperature,\n",
        "            top_p=top_p,\n",
        "            num_return_sequences=best_of,\n",
        "            eos_token_id=eos_id,\n",
        "            pad_token_id=eos_id,\n",
        "            return_dict_in_generate=True,\n",
        "        )\n",
        "        # sequences: [B*best_of, input_len + gen_len<=]\n",
        "        seqs = gen_out.sequences  # token ids\n",
        "        # Build attention mask for the combined sequences: in the continuation,\n",
        "        # EOS doubles as padding, so every EOS position is masked out.\n",
        "        attn = (seqs != eos_id).long()\n",
        "        # For the prompt, reuse the tokenizer's own mask so padded prompt positions\n",
        "        # stay masked during scoring exactly as they were during generation.\n",
        "        # Each prompt's samples are adjacent rows, hence repeat_interleave.\n",
        "        if \"attention_mask\" in enc:\n",
        "            prompt_attn = enc[\"attention_mask\"].repeat_interleave(best_of, dim=0)\n",
        "            attn[:, :input_len] = prompt_attn.to(attn.dtype)\n",
        "        else:\n",
        "            # No mask from the tokenizer: treat the whole prompt as real tokens.\n",
        "            attn[:, :input_len] = 1\n",
        "\n",
        "        # Score each sample by mean log-prob over the generated continuation\n",
        "        mean_lp = _mean_logprob_of_generated(model, seqs, attn, input_len)  # [B*best_of]\n",
        "\n",
        "        # Group back per original example\n",
        "        B = len(batch_q)\n",
        "        seqs_per_q = seqs.view(B, best_of, -1)\n",
        "        mean_lp_per_q = mean_lp.view(B, best_of)\n",
        "\n",
        "        # Choose argmax per question\n",
        "        best_idx = torch.argmax(mean_lp_per_q, dim=1)  # [B]\n",
        "        # Decode just the generated part (then take first line)\n",
        "        for b in range(B):\n",
        "            chosen = seqs_per_q[b, best_idx[b].item(), :]\n",
        "            gen_only = chosen[input_len:]  # tokens after the prompt\n",
        "            text = tokenizer.decode(gen_only, skip_special_tokens=True)\n",
        "            results.append(_pick_first_line(text))\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def evaluate_model_bestof(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 1000,\n",
        "    batch_size: int = 4,\n",
        "    max_new_tokens: int = 32,\n",
        "    best_of: int = 8,\n",
        "    name: str = \"model-bestofN\",\n",
        ") -> Dict:\n",
        "    \"\"\"\n",
        "    Evaluate with Best-of-N sampling (highest mean token log-prob per question).\n",
        "\n",
        "    Args:\n",
        "      model, tokenizer: forwarded to `batched_generate_bestof`.\n",
        "      dataset: iterable of examples with \"question\" and \"answer\" keys.\n",
        "      limit: maximum number of examples taken from the start of `dataset`;\n",
        "             None means no limit, 0 (or less) evaluates nothing.\n",
        "      batch_size, max_new_tokens, best_of: forwarded to `batched_generate_bestof`.\n",
        "      name: label used in the printed summary line.\n",
        "\n",
        "    Returns:\n",
        "      Dict with keys N, EM, F1, F1_CI, F1_CI_Level, preds, golds\n",
        "      (EM/F1 in percent; all metrics are 0.0 when nothing was evaluated).\n",
        "    \"\"\"\n",
        "    # Build eval slice; check the cap before appending so limit=0 yields no examples.\n",
        "    qs, gold_list = [], []\n",
        "    for ex in dataset:\n",
        "        if limit is not None and len(qs) >= limit:\n",
        "            break\n",
        "        qs.append(ex[\"question\"])\n",
        "        gold_list.append(canonical_answers(ex[\"answer\"]))\n",
        "\n",
        "    preds = batched_generate_bestof(\n",
        "        model,\n",
        "        tokenizer,\n",
        "        qs,\n",
        "        best_of=best_of,\n",
        "        batch_size=batch_size,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "    )\n",
        "    assert len(preds) == len(gold_list)\n",
        "\n",
        "    ems, f1s = [], []\n",
        "    for p, g in zip(preds, gold_list):\n",
        "        em, f1 = em_and_f1(p, g)\n",
        "        ems.append(em); f1s.append(f1)\n",
        "\n",
        "    n = len(f1s)\n",
        "    # Guard the empty case: np.mean([]) would give NaN plus a RuntimeWarning.\n",
        "    em_mean = float(np.mean(ems)) * 100.0 if n else 0.0\n",
        "    f1_mean_raw = float(np.mean(f1s)) if n else 0.0\n",
        "    f1_mean = f1_mean_raw * 100.0\n",
        "\n",
        "    # 95% normal-approx CI for F1\n",
        "    if n > 1:\n",
        "        z = 1.96\n",
        "        std = float(np.std(f1s, ddof=1))\n",
        "        se = std / np.sqrt(n)\n",
        "        f1_lo = max(0.0, (f1_mean_raw - z * se) * 100.0)\n",
        "        f1_hi = min(100.0, (f1_mean_raw + z * se) * 100.0)\n",
        "    else:\n",
        "        # With 0 or 1 samples there is no variance estimate; collapse the interval.\n",
        "        f1_lo = f1_mean\n",
        "        f1_hi = f1_mean\n",
        "\n",
        "    print(f\"[{name}] N={len(preds)}  Best-of={best_of}  EM: {em_mean:.2f}  F1: {f1_mean:.2f}  \"\n",
        "          f\"F1 95% CI: [{f1_lo:.2f}, {f1_hi:.2f}]\")\n",
        "\n",
        "    return {\n",
        "        \"N\": len(preds),\n",
        "        \"EM\": em_mean,\n",
        "        \"F1\": f1_mean,\n",
        "        \"F1_CI\": (f1_lo, f1_hi),\n",
        "        \"F1_CI_Level\": 0.95,\n",
        "        \"preds\": preds,\n",
        "        \"golds\": gold_list,\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GHtbvCx3rK87",
      "metadata": {
        "id": "GHtbvCx3rK87"
      },
      "outputs": [],
      "source": [
        "# Greedy baseline\n",
        "# Shared settings for the evaluation cells that follow.\n",
        "# NOTE(review): USE_GREEDY is not passed explicitly to evaluate_model in the calls below; presumably it is read as a global — confirm.\n",
        "USE_GREEDY = True\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "EVAL_LIMIT = len(val_ds)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "el9lZDDl1hSH",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "el9lZDDl1hSH"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Evaluate baseline_model on val_ds, then show the first three questions with golds and predictions.\n",
        "res_base_greedy = evaluate_model(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_greedy\")\n",
        "\n",
        "# NOTE(review): assumes res_base_greedy[\"preds\"][i] lines up with val_ds[i] — confirm against evaluate_model.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline Greedy:\", res_base_greedy[\"preds\"][i])\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "RQTjmepQ-jiM",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "RQTjmepQ-jiM"
      },
      "outputs": [],
      "source": [
        "# Best-of-N for the baseline model with N=24 (see best_of below).\n",
        "# BATCH_SIZE is lowered to 2 here; each question expands to best_of sampled sequences.\n",
        "BATCH_SIZE = 2\n",
        "res_base_bestof = evaluate_model_bestof(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, best_of=24, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_bestof24\")\n",
        "\n",
        "# Show the first three questions with golds and the selected best-of-N answer.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline Best-of-N:\", res_base_bestof[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rWY0OHvH1xLP",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "rWY0OHvH1xLP"
      },
      "outputs": [],
      "source": [
        "# Greedy trainable\n",
        "# _ = evaluate_model(trainable_model, tokenizer, val_ds, limit=500, batch_size=8, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_greedy\")\n",
        "# trainable_model.eval()  # keep in eval for decoding\n",
        "\n",
        "# Choose evaluation budget (increase if you have time/compute)\n",
        "# BATCH_SIZE is reset to 64 because the best-of-N cell above lowered it to 2.\n",
        "BATCH_SIZE = 64\n",
        "MAX_NEW_TOKENS = 24\n",
        "\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "trainable_model.eval()\n",
        "res_trained_greedy = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_greedy\")\n",
        "# Show the first three questions with golds and the trained model's answers.\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained Greedy:\", res_trained_greedy[\"preds\"][i])\n",
        "\n",
        "\n",
        "# Best-of-N (e.g., N=8)\n",
        "# _ = evaluate_model_bestof(trainable_model, tokenizer, val_ds, limit=500, batch_size=4, best_of=8, name=\"trainable_bestof8\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VnX-WgxQ-BXZ",
      "metadata": {
        "id": "VnX-WgxQ-BXZ"
      },
      "outputs": [],
      "source": [
        "# Best-of-24 for the trainable model; same settings as the baseline best-of-N cell.\n",
        "BATCH_SIZE = 2\n",
        "res_trained_bestof = evaluate_model_bestof(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, best_of=24, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_bestof24\")\n",
        "# Show the first five questions with golds and the selected best-of-N answer.\n",
        "for i in range(5):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained Best-of-N:\", res_trained_bestof[\"preds\"][i])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Pb94eMijnT2q",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "Pb94eMijnT2q"
      },
      "outputs": [],
      "source": [
        "# Sampling run for the trainable model (the call below uses trainable_model, not the baseline).\n",
        "# NOTE(review): USE_GREEDY is set here but not passed to evaluate_model; presumably it is read as a global — confirm.\n",
        "USE_GREEDY = False\n",
        "BATCH_SIZE = 64\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "MAX_NEW_TOKENS = 24\n",
        "\n",
        "# NOTE(review): EVAL_LIMIT is assigned twice in this cell with the same value.\n",
        "EVAL_LIMIT = len(val_ds)\n",
        "trainable_model.eval()\n",
        "# NOTE(review): the sampling result is stored in res_trained_greedy, overwriting the greedy result from the earlier cell.\n",
        "res_trained_greedy = evaluate_model(trainable_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"trainable_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Trained sampling:\", res_trained_greedy[\"preds\"][i])\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Kxghp_lE_1sp",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "Kxghp_lE_1sp"
      },
      "outputs": [],
      "source": [
        "# Sampling run for the baseline model, reusing the settings from the previous cell.\n",
        "# NOTE(review): the baseline result is stored in res_trained_greedy, so that name no longer holds trainable-model results after this cell.\n",
        "res_trained_greedy = evaluate_model(baseline_model, tokenizer, val_ds, limit=EVAL_LIMIT, batch_size=BATCH_SIZE, max_new_tokens=MAX_NEW_TOKENS, name=\"baseline_sampling\")\n",
        "for i in range(3):\n",
        "    print(\"\\nQ:\", val_ds[i][\"question\"])\n",
        "    print(\"Golds:\", canonical_answers(val_ds[i][\"answer\"])[:5])\n",
        "    print(\"Baseline sampling:\", res_trained_greedy[\"preds\"][i])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "uO0cn0IUP6UH",
      "metadata": {
        "id": "uO0cn0IUP6UH"
      },
      "source": [
        "## Check Coherence"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5DGKZyvoRTeJ",
      "metadata": {
        "id": "5DGKZyvoRTeJ"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import torch\n",
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "\n",
        "\n",
        "def _class_agrees(preds_for_class):\n",
        "    # If every prediction is NA, count as agreement (consistent failure)\n",
        "    ps = [(p or \"\").strip() for p in preds_for_class]\n",
        "    if all(p == \"NA\" for p in ps):\n",
        "        return True\n",
        "\n",
        "    vals = [p for p in ps if p != \"NA\"]\n",
        "    if len(vals) < 2:\n",
        "        return False\n",
        "    return len(set(vals)) == 1\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,   # number of class samples to print total\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Evaluate agreement across paraphrases.\n",
        "\n",
        "    The first `limit` classes are answered one class at a time and each class\n",
        "    is scored with `_class_agrees`.\n",
        "\n",
        "    Args:\n",
        "      dataset: list-like of paraphrase classes, each is List[str]\n",
        "               e.g., classes_usable where each element is [orig, para1, para2, ...]\n",
        "      limit: number of classes to evaluate (None means all)\n",
        "      batch_size: generation batch size (over paraphrase questions), must be >= 1\n",
        "      max_new_tokens: passed to answer_batch_chat\n",
        "      use_greedy: greedy vs sampling\n",
        "      show_samples: number of classes whose questions/preds are printed\n",
        "\n",
        "    Prints:\n",
        "      [{name}] N={num_classes}  AGREE: {mean:.2f} and AGREE confidence [lo, hi]\n",
        "\n",
        "    Returns:\n",
        "      {\n",
        "        \"preds\": per_class_preds,   # List[List[str]] answers for each class item\n",
        "        \"sample\": {\"indices\": sampled_indices, \"questions\": qs, \"golds\": None},\n",
        "        \"N\": num_classes,\n",
        "        \"AGREE\": agree_mean,\n",
        "        \"AGREE_CI\": (lo, hi),\n",
        "        \"AGREE_CI_Level\": 0.95,\n",
        "      }\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if batch_size < 1.\n",
        "    \"\"\"\n",
        "    # A non-positive batch_size would produce empty prediction lists, which\n",
        "    # `_class_agrees` counts as agreement; fail loudly instead.\n",
        "    if batch_size < 1:\n",
        "        raise ValueError(f\"batch_size must be >= 1, got {batch_size}\")\n",
        "\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    pbar = None\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        n_classes = max(n_classes, 0)  # a negative limit means \"evaluate nothing\"\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []  # store the paraphrases per class for debugging\n",
        "\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for idx in sampled_indices:\n",
        "            cls = dataset[idx]  # List[str] paraphrases (includes original)\n",
        "            qs_sample.append(cls)\n",
        "\n",
        "            # Generate in chunks if class is large (so we respect batch_size)\n",
        "            cls_preds: List[str] = []\n",
        "            for start in range(0, len(cls), batch_size):\n",
        "                sub_qs = cls[start:start + batch_size]\n",
        "                sub_preds = answer_batch_chat(\n",
        "                    sub_qs,\n",
        "                    model=model,\n",
        "                    tokenizer=tokenizer,\n",
        "                    max_new_tokens=max_new_tokens,\n",
        "                    use_greedy=use_greedy,\n",
        "                )\n",
        "                cls_preds.extend(sub_preds)\n",
        "\n",
        "            per_class_preds.append(cls_preds)\n",
        "\n",
        "            agrees = _class_agrees(cls_preds)\n",
        "            if agrees:\n",
        "                agree_yes += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(1)\n",
        "\n",
        "            # Optional: print a few class samples total\n",
        "            if show_samples > 0 and printed < show_samples:\n",
        "                pbar.write(f\"[{name}] class_idx={idx} size={len(cls)} agrees={agrees}\")\n",
        "                # show first 2 questions and all preds\n",
        "                for j in range(min(2, len(cls))):\n",
        "                    q_preview = cls[j].replace(\"\\n\", \" \")\n",
        "                    if len(q_preview) > 160:\n",
        "                        q_preview = q_preview[:160] + \"...\"\n",
        "                    pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                pbar.write(f\"  preds: {cls_preds}\")\n",
        "                printed += 1\n",
        "\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        # Release the bar and restore the caller's train/eval mode even when\n",
        "        # generation raises part-way through.\n",
        "        if pbar is not None:\n",
        "            pbar.close()\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sa-QOGDcEJhd",
      "metadata": {
        "id": "sa-QOGDcEJhd"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, Any, List\n",
        "from tqdm.auto import tqdm\n",
        "import torch\n",
        "\n",
        "def evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    dataset,\n",
        "    limit: int = 100,\n",
        "    batch_size: int = 8,          # number of CLASSES per batch (not prompts)\n",
        "    max_new_tokens: int = 256,\n",
        "    name: str = \"agree\",\n",
        "    use_greedy: bool = True,\n",
        "    *,\n",
        "    show_samples: int = 0,\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Batched agreement eval when each class has ~2 items (orig + 1 paraphrase).\n",
        "\n",
        "    batch_size here means \"number of classes per batch\".\n",
        "    Each batch does 1 generation call over ~2*batch_size prompts.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    try:\n",
        "        n_total = len(dataset)\n",
        "        n_classes = min(int(limit), n_total) if limit is not None else n_total\n",
        "        sampled_indices = list(range(n_classes))\n",
        "\n",
        "        per_class_preds: List[List[str]] = []\n",
        "        qs_sample: List[List[str]] = []\n",
        "\n",
        "        pbar = tqdm(total=n_classes, desc=f\"Paraphrase agreement {name}\", dynamic_ncols=True)\n",
        "        agree_yes = 0\n",
        "        printed = 0\n",
        "\n",
        "        for start in range(0, n_classes, batch_size):\n",
        "            end = min(start + batch_size, n_classes)\n",
        "            cls_batch = [dataset[i] for i in range(start, end)]  # list of classes\n",
        "\n",
        "            # Record questions for the sample output\n",
        "            qs_sample.extend(cls_batch)\n",
        "\n",
        "            # Flatten prompts: [c0_q0, c0_q1, c1_q0, c1_q1, ...]\n",
        "            flat_prompts: List[str] = []\n",
        "            lengths: List[int] = []\n",
        "            for cls in cls_batch:\n",
        "                # cls is [orig, para] (or sometimes only [orig])\n",
        "                lengths.append(len(cls))\n",
        "                flat_prompts.extend(cls)\n",
        "\n",
        "            # One big batched generation call\n",
        "            flat_preds = answer_batch_chat(\n",
        "                flat_prompts,\n",
        "                model=model,\n",
        "                tokenizer=tokenizer,\n",
        "                max_new_tokens=max_new_tokens,\n",
        "                use_greedy=use_greedy,\n",
        "            )\n",
        "\n",
        "            # Unflatten back into per-class preds\n",
        "            offset = 0\n",
        "            for bi, L in enumerate(lengths):\n",
        "                cls_preds = flat_preds[offset:offset + L]\n",
        "                offset += L\n",
        "\n",
        "                per_class_preds.append(cls_preds)\n",
        "\n",
        "                agrees = _class_agrees(cls_preds)\n",
        "                if agrees:\n",
        "                    agree_yes += 1\n",
        "\n",
        "                # Optional sample prints (total across run)\n",
        "                if show_samples > 0 and printed < show_samples:\n",
        "                    class_idx = start + bi\n",
        "                    pbar.write(f\"[{name}] class_idx={class_idx} size={L} agrees={agrees}\")\n",
        "                    for j in range(min(2, L)):\n",
        "                        q_preview = cls_batch[bi][j].replace(\"\\n\", \" \")\n",
        "                        if len(q_preview) > 160:\n",
        "                            q_preview = q_preview[:160] + \"...\"\n",
        "                        pbar.write(f\"  q{j}: {q_preview}\")\n",
        "                    pbar.write(f\"  preds: {cls_preds}\")\n",
        "                    printed += 1\n",
        "\n",
        "            done = len(per_class_preds)\n",
        "            agree_mean_running = agree_yes / done if done else 0.0\n",
        "            pbar.set_postfix({\n",
        "                \"classes\": done,\n",
        "                \"AGREE%\": f\"{agree_mean_running*100:.2f}\",\n",
        "                \"greedy\": use_greedy,\n",
        "            })\n",
        "            pbar.update(end - start)\n",
        "\n",
        "        pbar.close()\n",
        "\n",
        "        agree_mean = agree_yes / n_classes if n_classes else 0.0\n",
        "        lo, hi = _wilson_ci(agree_yes, n_classes, level=0.95)\n",
        "\n",
        "        print(\n",
        "            f\"[{name}] N={n_classes}  \"\n",
        "            f\"AGREE: {agree_mean*100:.2f} and AGREE confidence \"\n",
        "            f\"[{lo*100:.2f}, {hi*100:.2f}]\"\n",
        "        )\n",
        "\n",
        "        return {\n",
        "            \"preds\": per_class_preds,\n",
        "            \"sample\": {\n",
        "                \"indices\": sampled_indices,\n",
        "                \"questions\": qs_sample,\n",
        "                \"golds\": None,\n",
        "            },\n",
        "            \"N\": n_classes,\n",
        "            \"AGREE\": agree_mean,\n",
        "            \"AGREE_CI\": (lo, hi),\n",
        "            \"AGREE_CI_Level\": 0.95,\n",
        "        }\n",
        "\n",
        "    finally:\n",
        "        model.train(was_training)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "pnyq4rJrRjAe",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 243,
          "referenced_widgets": [
            "b6c0617396d44837ac2bc00893b812d1",
            "39b9ffa02eb84f85a76156cbc1c8a88b",
            "7f8b51a3fe4343f39fa6ddc435410a3a",
            "3a3884b6b2c045c2ae33bba4c0c462b4",
            "51096091a9594b3cbba9e1f06c439f65",
            "346936dae8d0460b88f574168fea2209",
            "a7531cf6b513470da02333b1420e5c9c",
            "da8588237dec4ae9877dc1ebb1236e48",
            "59f2130e9c564172b17f10efa43f72d6",
            "3345258ab7fd4dc085c81ca3eacb35dd",
            "441473c91daf4aeea3a91dc460512b2a"
          ]
        },
        "id": "pnyq4rJrRjAe",
        "outputId": "bb932ce5-2731-471c-d6c8-3f2fd568ee2e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b6c0617396d44837ac2bc00893b812d1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Paraphrase agreement qwen3b:   0%|          | 0/7473 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[qwen3b] class_idx=0 size=2 agrees=True\n",
            "  q0: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n",
            "  q1: In April, Natalia sold clips to 48 of her friends, and in May she sold half as many. How many clips did Natalia sell in total between April and May?\n",
            "  preds: ['72', '72']\n",
            "[qwen3b] class_idx=1 size=2 agrees=True\n",
            "  q0: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\n",
            "  q1: Weng earns $12 per hour from babysitting. She worked for 50 minutes yesterday. How much did she earn?\n",
            "  preds: ['10', '10']\n",
            "[qwen3b] N=7473  AGREE: 67.14 and AGREE confidence [66.06, 68.19]\n"
          ]
        }
      ],
      "source": [
        "coherence_eval_limit = len(paraphrase_classes)\n",
        "\n",
        "# classes_usable is output of build_classes_full_batched(...)\n",
        "agree_res = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "    model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=paraphrase_classes,\n",
        "    limit=coherence_eval_limit,\n",
        "    batch_size=64,\n",
        "    max_new_tokens=256,\n",
        "    name=\"qwen3b\",\n",
        "    use_greedy=True,\n",
        "    show_samples=2,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "XUChbdyqVsLQ",
      "metadata": {
        "id": "XUChbdyqVsLQ"
      },
      "source": [
        "## GRPO Implementation"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3HnZFHkuHZVx",
      "metadata": {
        "id": "3HnZFHkuHZVx"
      },
      "source": [
        "### GRPO Main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dx7osq9fVvm1",
      "metadata": {
        "id": "dx7osq9fVvm1"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "from torch.utils.data import IterableDataset\n",
        "\n",
        "class ParaphrasePairIterableFormatted(IterableDataset):\n",
        "    \"\"\"\n",
        "    Infinite stream. Each \"pair\" yields exactly two consecutive items:\n",
        "      (q0, is_original=1) then (q1, is_original=0)\n",
        "    \"\"\"\n",
        "    def __init__(self, classes_usable, tokenizer, format_prompt_instruct, system_prompt, seed=0):\n",
        "        self.classes = classes_usable          # list[list[str]]; cls[0] is original\n",
        "        self.tok = tokenizer\n",
        "        self.format_prompt_instruct = format_prompt_instruct\n",
        "        self.system_prompt = system_prompt\n",
        "        self.rng = random.Random(seed)\n",
        "\n",
        "    def __iter__(self):\n",
        "        while True:\n",
        "            cls = self.rng.choice(self.classes)\n",
        "            q0 = cls[0]\n",
        "            q1 = self.rng.choice(cls[1:]) if len(cls) > 1 else q0\n",
        "            pair_id = self.rng.randrange(2**63)\n",
        "\n",
        "            p0 = self.format_prompt_instruct(q0, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "            p1 = self.format_prompt_instruct(q1, tokenizer=self.tok, system_prompt=self.system_prompt)\n",
        "\n",
        "            yield {\"prompt\": p0, \"is_original\": 1, \"pair_id\": pair_id}\n",
        "            yield {\"prompt\": p1, \"is_original\": 0, \"pair_id\": pair_id}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1snZI84RW7ml",
      "metadata": {
        "id": "1snZI84RW7ml"
      },
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "\n",
        "def agreement_reward_fn(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def agreement_reward_(prompts, completions, **kwargs):\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    # majority vote among all answers (including \"NA\", per your current behavior)\n",
        "    if not answers:\n",
        "        return []\n",
        "\n",
        "    counts = Counter(answers)\n",
        "    max_count = max(counts.values())\n",
        "    majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "    # reward 1 to (tied) majority answer(s), 0 otherwise\n",
        "    return [1.0 if a in majority_set else 0.0 for a in answers]\n",
        "\n",
        "\n",
        "\n",
        "from collections import Counter\n",
        "\n",
        "def pair_agreement_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    # pair_id and is_original should be aligned with (prompts, completions)\n",
        "    # TRL typically repeats prompt-level columns across generations.\n",
        "\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        a0 = [answers[i] for i in idx0 if answers[i] != \"NA\"]\n",
        "        a1 = [answers[i] for i in idx1 if answers[i] != \"NA\"]\n",
        "\n",
        "        # Only reward if BOTH sides have at least one non-NA answer\n",
        "        if len(a0) == 0 or len(a1) == 0:\n",
        "            continue\n",
        "\n",
        "        c0, c1 = Counter(a0), Counter(a1)\n",
        "        d0, d1 = len(a0), len(a1)\n",
        "\n",
        "        # Reward each completion by \"how often the other paraphrase produced the same answer\"\n",
        "        for i in idx0:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c1.get(a, 0) / d1\n",
        "        for i in idx1:\n",
        "            a = answers[i]\n",
        "            if a != \"NA\":\n",
        "                rewards[i] = c0.get(a, 0) / d0\n",
        "\n",
        "    return rewards"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fyyil8TixJyR",
      "metadata": {
        "id": "fyyil8TixJyR"
      },
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "def pair_majority_reward_fn(prompts, completions, pair_id=None, is_original=None, **kwargs):\n",
        "    \"\"\"\n",
        "    Reward = 1.0 iff the completion's extracted answer is in the (tied) majority vote\n",
        "    within its side (original vs paraphrase) for the same pair_id. Otherwise 0.0.\n",
        "\n",
        "    Separation between pairs: majority is computed independently per pair_id,\n",
        "    and also independently per side (is_original==1 vs ==0) within each pair.\n",
        "    \"\"\"\n",
        "    if pair_id is None or is_original is None:\n",
        "        raise ValueError(\"Need pair_id and is_original passed into reward fn (from dataset columns).\")\n",
        "\n",
        "    # Extract answers (same as before)\n",
        "    answers = []\n",
        "    for c in completions:\n",
        "        a = _extract_gsm8k_final(c or \"\")\n",
        "        answers.append(a if a else \"NA\")\n",
        "\n",
        "    rewards = [0.0] * len(answers)\n",
        "\n",
        "    # Group indices by pair_id\n",
        "    by_pid = {}\n",
        "    for i, pid in enumerate(pair_id):\n",
        "        by_pid.setdefault(pid, []).append(i)\n",
        "\n",
        "    for pid, idxs in by_pid.items():\n",
        "        # Split indices by side\n",
        "        idx0 = [i for i in idxs if int(is_original[i]) == 1]  # original prompt (q0)\n",
        "        idx1 = [i for i in idxs if int(is_original[i]) == 0]  # paraphrase (q1)\n",
        "\n",
        "        # Majority vote within each side (including \"NA\", matching your current behavior)\n",
        "        for side_idxs in (idx0, idx1):\n",
        "            if not side_idxs:\n",
        "                continue\n",
        "\n",
        "            side_answers = [answers[i] for i in side_idxs]\n",
        "            counts = Counter(side_answers)\n",
        "            max_count = max(counts.values())\n",
        "            majority_set = {a for a, cnt in counts.items() if cnt == max_count}\n",
        "\n",
        "            for i in side_idxs:\n",
        "                rewards[i] = 1.0 if answers[i] in majority_set else 0.0\n",
        "\n",
        "            # ensure NA stays 0\n",
        "            for i in side_idxs:\n",
        "                if answers[i] == \"NA\":\n",
        "                    rewards[i] = 0.0\n",
        "\n",
        "    return rewards\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "xkJgG7HfW9Vn",
      "metadata": {
        "id": "xkJgG7HfW9Vn"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "from collections.abc import Mapping\n",
        "import torch\n",
        "from trl import GRPOTrainer\n",
        "\n",
        "class OrigOnlyKLGRPOTrainer(GRPOTrainer):\n",
        "    def _prepare_inputs(self, inputs):\n",
        "        # inputs is list[dict]\n",
        "        is_orig_list = [ex[\"is_original\"] for ex in inputs]\n",
        "        pair_id_list = [ex[\"pair_id\"] for ex in inputs]\n",
        "\n",
        "        prepared = super()._prepare_inputs(inputs)\n",
        "\n",
        "        device = prepared[\"prompt_ids\"].device\n",
        "        B_prompts = prepared[\"prompt_ids\"].shape[0]   # prompt-level count (usually 2*PAIRS_PER_STEP)\n",
        "\n",
        "        def compress_expanded_to_prompt_level(t, B_prompts):\n",
        "            # t has length B_prompts*K (expanded); compress to length B_prompts\n",
        "            K = t.numel() // B_prompts\n",
        "            first_block = t[:B_prompts]\n",
        "\n",
        "            # If the first block already contains mixed values, it's interleaved by prompt:\n",
        "            # [q0,q1,q0,q1,...] per generation => first_block is exactly prompt-level.\n",
        "            if torch.unique(first_block).numel() > 1:\n",
        "                return first_block\n",
        "\n",
        "            # Otherwise grouped layout: [q0...q0, q1...q1]\n",
        "            idx = torch.arange(0, t.numel(), K, device=t.device)  # 0, K, 2K, ...\n",
        "            return t[idx]\n",
        "\n",
        "        def align_to_B(vec):\n",
        "            \"\"\"\n",
        "            Align vec to length B_prompts. vec may be:\n",
        "              - prompt-level (len == B_prompts)\n",
        "              - expanded (len == B_prompts*K)\n",
        "            \"\"\"\n",
        "            t = torch.as_tensor(vec, device=device).view(-1).long()\n",
        "\n",
        "            if t.numel() == B_prompts:\n",
        "                return t\n",
        "\n",
        "            if t.numel() > B_prompts:\n",
        "                # ✅ THIS is where \"already-expanded → compress\" goes\n",
        "                if t.numel() % B_prompts != 0:\n",
        "                    raise ValueError(f\"Cannot compress metadata: got {t.numel()}, need multiple of {B_prompts}\")\n",
        "                return compress_expanded_to_prompt_level(t, B_prompts)\n",
        "\n",
        "            # (rare) scalar -> broadcast\n",
        "            if t.numel() == 1:\n",
        "                return t.repeat(B_prompts)\n",
        "\n",
        "            raise ValueError(f\"Cannot align metadata: got {t.numel()}, need {B_prompts}\")\n",
        "\n",
        "        prepared[\"is_original_prompt\"] = align_to_B(is_orig_list)\n",
        "        prepared[\"pair_id_prompt\"] = align_to_B(pair_id_list)\n",
        "\n",
        "        # keep prompt-level values here; expand later in compute_loss to match rollouts\n",
        "        return prepared\n",
        "\n",
        "    def _get_logps(self, model, input_ids, attention_mask, logits_to_keep):\n",
        "        # TRL versions differ: some have _get_per_token_logps, newer use _get_per_token_logps_and_entropies\n",
        "        if hasattr(self, \"_get_per_token_logps\"):\n",
        "            return self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)\n",
        "\n",
        "        out = self._get_per_token_logps_and_entropies(\n",
        "            model,\n",
        "            input_ids,\n",
        "            attention_mask,\n",
        "            logits_to_keep,\n",
        "            compute_entropy=False,\n",
        "            pixel_values=None,\n",
        "            image_grid_thw=None,\n",
        "            num_images=None,\n",
        "        )\n",
        "        # out can be (logps, entropies) or {\"logps\":..., \"entropies\":...}\n",
        "        if isinstance(out, tuple):\n",
        "            return out[0]\n",
        "        return out[\"logps\"]\n",
        "\n",
        "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
        "        import torch\n",
        "\n",
        "        if return_outputs:\n",
        "            raise ValueError(\"GRPOTrainer does not support return_outputs=True\")\n",
        "\n",
        "        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n",
        "        completion_ids, completion_mask = inputs[\"completion_ids\"], inputs[\"completion_mask\"]\n",
        "\n",
        "        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n",
        "        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n",
        "\n",
        "        logits_to_keep = completion_ids.size(1)\n",
        "        per_token_logps = self._get_logps(model, input_ids, attention_mask, logits_to_keep)  # [N, T]\n",
        "\n",
        "        # ---- dimensions ----\n",
        "        N = per_token_logps.shape[0]  # rollout-level rows\n",
        "        denom = completion_mask.sum(dim=1).clamp_min(1)  # [N]\n",
        "\n",
        "        # ---- advantages: align to N ----\n",
        "        advantages = inputs[\"advantages\"].view(-1)  # could be [N] or [B]\n",
        "        if advantages.numel() != N:\n",
        "            if N % advantages.numel() == 0:\n",
        "                advantages = advantages.repeat_interleave(N // advantages.numel())\n",
        "            else:\n",
        "                raise ValueError(f\"advantages has {advantages.numel()} elems but need {N}\")\n",
        "        advantages = advantages.to(per_token_logps.device)\n",
        "\n",
        "        # ---- PPO/GRPO ratio term ----\n",
        "        if \"old_per_token_logps\" in inputs:\n",
        "            ratio = torch.exp(per_token_logps - inputs[\"old_per_token_logps\"])\n",
        "        else:\n",
        "            ratio = torch.exp(per_token_logps - per_token_logps.detach())\n",
        "\n",
        "        per_token_pg = ratio * advantages.unsqueeze(1)  # [N, T]\n",
        "\n",
        "        beta = float(getattr(self, \"beta\", getattr(self.args, \"beta\", 0.0)))\n",
        "\n",
        "        # ---- KL term (masked to originals only) ----\n",
        "        is_orig = None\n",
        "        per_token_kl = 0.0\n",
        "\n",
        "        if beta != 0.0:\n",
        "            ref_per_token_logps = inputs[\"ref_per_token_logps\"]\n",
        "            per_token_kl_raw = (\n",
        "                torch.exp(ref_per_token_logps - per_token_logps)\n",
        "                - (ref_per_token_logps - per_token_logps)\n",
        "                - 1\n",
        "            )  # [N, T]\n",
        "\n",
        "            # prompt-level mask produced in _prepare_inputs\n",
        "            is_orig_prompt = inputs.get(\"is_original_prompt\", None)\n",
        "            if is_orig_prompt is None:\n",
        "                raise ValueError(\"Missing inputs['is_original_prompt'] from _prepare_inputs().\")\n",
        "\n",
        "            is_orig_prompt = torch.as_tensor(is_orig_prompt, device=per_token_logps.device).float().view(-1)  # [B]\n",
        "            B = is_orig_prompt.numel()\n",
        "\n",
        "            if N % B != 0:\n",
        "                raise ValueError(f\"Cannot expand is_original_prompt: N={N}, B={B}\")\n",
        "            K = N // B\n",
        "\n",
        "            # Detect whether rollout rows are interleaved-by-prompt (first B rows are all different prompts)\n",
        "            def _is_interleaved_by_prompt(prompt_ids, B):\n",
        "                reps = []\n",
        "                for i in range(min(B, prompt_ids.size(0))):\n",
        "                    row = prompt_ids[i]\n",
        "                    if not any(torch.equal(row, r) for r in reps):\n",
        "                        reps.append(row)\n",
        "                return len(reps) == B\n",
        "\n",
        "            interleaved = _is_interleaved_by_prompt(prompt_ids, B)\n",
        "\n",
        "            if interleaved:\n",
        "                # [p0,p1,...,pB-1, p0,p1,...] repeated K times\n",
        "                is_orig = is_orig_prompt.repeat(K)  # [N]\n",
        "            else:\n",
        "                # [p0...p0, p1...p1, ...] each repeated K times\n",
        "                is_orig = is_orig_prompt.repeat_interleave(K)  # [N]\n",
        "\n",
        "            # Mask KL to originals only\n",
        "            per_token_kl = per_token_kl_raw * is_orig.unsqueeze(1)  # [N, T]\n",
        "\n",
        "            # ---- logging: masked KL only (as requested) ----\n",
        "            with torch.no_grad():\n",
        "                # prompt-level counts\n",
        "                is_p = is_orig_prompt.float().view(-1)\n",
        "                p_n1 = int((is_p == 1).sum().item())\n",
        "                p_n0 = int((is_p == 0).sum().item())\n",
        "\n",
        "                # rollout-level counts (after expansion)\n",
        "                r_n1 = int((is_orig == 1).sum().item())\n",
        "                r_n0 = int((is_orig == 0).sum().item())\n",
        "\n",
        "            self.log({\n",
        "                \"dbg/prompt_n1\": p_n1,\n",
        "                \"dbg/prompt_n0\": p_n0,\n",
        "                \"dbg/rollout_n1\": r_n1,\n",
        "                \"dbg/rollout_n0\": r_n0,\n",
        "            })\n",
        "            with torch.no_grad():\n",
        "                # mean KL over completion tokens for each sample\n",
        "                kl_per_sample = (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)\n",
        "\n",
        "                # raw mean KL (all samples)\n",
        "                kl_all = kl_per_sample.mean()\n",
        "\n",
        "                # orig-only mean KL (only samples with is_orig==1)\n",
        "                orig_mask = is_orig.bool()\n",
        "                kl_orig = kl_per_sample[orig_mask].mean() if orig_mask.any() else torch.tensor(0.0, device=kl_all.device)\n",
        "\n",
        "            self.log({\n",
        "                \"kl_all\": kl_all.detach().float().cpu().item(),\n",
        "                \"kl_orig_only\": kl_orig.detach().float().cpu().item(),\n",
        "            })\n",
        "\n",
        "        # ---- final loss ----\n",
        "        per_token_loss = -(per_token_pg - beta * per_token_kl)  # [N, T]\n",
        "        loss = ((per_token_loss * completion_mask).sum(dim=1) / denom).mean()\n",
        "\n",
        "        return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "i5QbPheK3EXX",
      "metadata": {
        "id": "i5QbPheK3EXX"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class GSM8KEvalCallback(TrainerCallback):\n",
        "    def __init__(self, trainer, *, eval_fn, log_every: int, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_all_raw\"):\n",
        "        self.trainer = trainer\n",
        "        self.eval_fn = eval_fn\n",
        "        self.log_every = log_every\n",
        "        self.eval_kwargs = eval_kwargs\n",
        "        self.reward_key = reward_key\n",
        "        self.kl_key = kl_key\n",
        "        self._last_reward = None\n",
        "        self._last_kl = None\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        logs = logs or {}\n",
        "        # cache latest values when they appear\n",
        "        if self.reward_key in logs:\n",
        "            self._last_reward = logs[self.reward_key]\n",
        "        if self.kl_key in logs:\n",
        "            self._last_kl = logs[self.kl_key]\n",
        "\n",
        "        step = int(state.global_step)\n",
        "        if step == 0 or (step % self.log_every) != 0:\n",
        "            return control\n",
        "\n",
        "        # main process only\n",
        "        if hasattr(self.trainer, \"is_world_process_zero\") and not self.trainer.is_world_process_zero():\n",
        "            return control\n",
        "\n",
        "        def fmt(x):\n",
        "            return \"NA\" if x is None else f\"{x:.6g}\"\n",
        "\n",
        "        print(f\"[step {step}] {self.reward_key}={fmt(self._last_reward)}  {self.kl_key}={fmt(self._last_kl)}\")\n",
        "\n",
        "        # run eval\n",
        "        model = self.trainer.model\n",
        "        was_training = model.training\n",
        "        model.eval()\n",
        "\n",
        "\n",
        "        with torch.no_grad():\n",
        "            metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        if was_training:\n",
        "            model.train()\n",
        "\n",
        "        self.trainer.log({f\"gsm8k/{k}\": v for k, v in metrics.items()})\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51XzkxBxBJNa",
      "metadata": {
        "id": "51XzkxBxBJNa"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainerCallback\n",
        "import torch\n",
        "\n",
        "class PrintAndEvalOnSave(TrainerCallback):\n",
        "    def __init__(self, trainer, *, eval_fn, eval_kwargs: dict,\n",
        "                 reward_key=\"reward\", kl_key=\"kl_orig_only\",\n",
        "                 n1_key=\"dbg/rollout_n1\", n0_key=\"dbg/rollout_n0\"):\n",
        "        self.trainer = trainer\n",
        "        self.eval_fn = eval_fn\n",
        "        self.eval_kwargs = eval_kwargs\n",
        "        self.reward_key = reward_key\n",
        "        self.kl_key = kl_key\n",
        "        self.n1_key = n1_key\n",
        "        self.n0_key = n0_key\n",
        "\n",
        "        self._last_reward = None\n",
        "        self._last_kl = None\n",
        "        self._last_n1 = None\n",
        "        self._last_n0 = None\n",
        "\n",
        "        self._last_eval_step = -1\n",
        "        self._logging_eval = False\n",
        "\n",
        "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
        "        # Cache train stats whenever they are logged\n",
        "        if self._logging_eval:\n",
        "            return control\n",
        "        logs = logs or {}\n",
        "\n",
        "        if self.reward_key in logs:\n",
        "            self._last_reward = logs[self.reward_key]\n",
        "        if self.kl_key in logs:\n",
        "            self._last_kl = logs[self.kl_key]\n",
        "        if self.n1_key in logs:\n",
        "            self._last_n1 = logs[self.n1_key]\n",
        "        if self.n0_key in logs:\n",
        "            self._last_n0 = logs[self.n0_key]\n",
        "\n",
        "        return control\n",
        "\n",
        "    def on_save(self, args, state, control, **kwargs):\n",
        "        step = int(state.global_step)\n",
        "        if step == 0 or step == self._last_eval_step:\n",
        "            return control\n",
        "        self._last_eval_step = step\n",
        "\n",
        "        if hasattr(self.trainer, \"is_world_process_zero\") and not self.trainer.is_world_process_zero():\n",
        "            return control\n",
        "\n",
        "        def fmt(x):\n",
        "            return \"NA\" if x is None else f\"{x:.6g}\" if isinstance(x, (float, int)) else str(x)\n",
        "\n",
        "        print(\n",
        "            f\"[save step {step}] \"\n",
        "            f\"{self.reward_key}={fmt(self._last_reward)}  \"\n",
        "            f\"{self.kl_key}={fmt(self._last_kl)}  \"\n",
        "            f\"n1={fmt(self._last_n1)}  n0={fmt(self._last_n0)}\"\n",
        "        )\n",
        "\n",
        "        # Run eval once per checkpoint\n",
        "        model = self.trainer.model\n",
        "        was_training = model.training\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            metrics = self.eval_fn(model=model, **self.eval_kwargs)\n",
        "        if was_training:\n",
        "            model.train()\n",
        "\n",
        "        # Log eval metrics, but avoid triggering our own caching logic\n",
        "        self._logging_eval = True\n",
        "        try:\n",
        "            self.trainer.log({f\"gsm8k/{k}\": v for k, v in metrics.items()})\n",
        "        finally:\n",
        "            self._logging_eval = False\n",
        "\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "\n",
        "        return control"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "CnoU-MPWDw1A",
      "metadata": {
        "id": "CnoU-MPWDw1A"
      },
      "outputs": [],
      "source": [
        "### Phi 3 params\n",
        "\n",
        "from trl import GRPOConfig\n",
        "\n",
        "K = 24               # rollouts per prompt\n",
        "PAIRS_PER_STEP = 2  # per device\n",
        "KL_WEIGHT = 0.05\n",
        "\n",
        "LOG_EVERY = 16\n",
        "SAVE_EVERY = 20 * LOG_EVERY   #\n",
        "EVAL_LIMIT = len(test_ds)\n",
        "\n",
        "TOTAL_STEPS = 200 * LOG_EVERY  # pick whatever\n",
        "\n",
        "train_dataset = ParaphrasePairIterableFormatted(\n",
        "    classes_usable=paraphrase_classes,\n",
        "    tokenizer=tokenizer,\n",
        "    format_prompt_instruct=format_prompt_instruct,\n",
        "    system_prompt=SYSTEM_PROMPT,\n",
        "    seed=0,\n",
        ")\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "    per_device_train_batch_size=2 * PAIRS_PER_STEP,   # 2 prompts = (q0,q1)\n",
        "    num_generations=K,\n",
        "    generation_batch_size = (2 * PAIRS_PER_STEP) * K,     # MUST be divisible by K\n",
        "\n",
        "    max_steps=TOTAL_STEPS,           # required for IterableDataset (no __len__)\n",
        "    accelerator_config={\n",
        "        \"dispatch_batches\": False,   # <-- fixes your str concatenation crash\n",
        "        \"split_batches\": False,\n",
        "    },\n",
        "\n",
        "    # ✅ checkpointing\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=SAVE_EVERY,\n",
        "    save_total_limit=30,        # keep last 30 checkpoints (optional)\n",
        "    save_safetensors=True,     # optional\n",
        "\n",
        "    beta=KL_WEIGHT,\n",
        "    logging_steps=SAVE_EVERY, #LOG_EVERY,\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "    learning_rate=1e-5,\n",
        "    dataloader_num_workers=0,\n",
        "    report_to=[]\n",
        ")\n",
        "\n",
        "\n",
        "trainer = OrigOnlyKLGRPOTrainer(\n",
        "    model=trainable_model,\n",
        "    args=grpo_args,\n",
        "    train_dataset=train_dataset,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=pair_majority_reward_fn,\n",
        ")\n",
        "\n",
        "eval_kwargs = dict(\n",
        "    tokenizer=tokenizer,\n",
        "    dataset=test_ds,          # or test split\n",
        "    limit=EVAL_LIMIT,              # whatever you use\n",
        "    batch_size=32,\n",
        "    max_new_tokens=MAX_NEW_TOKENS,\n",
        "    name=\"trained_grpo\",\n",
        ")\n",
        "\n",
        "\n",
        "trainer.add_callback(\n",
        "    PrintAndEvalOnSave(\n",
        "        trainer,\n",
        "        eval_fn=evaluate_gsm8k,\n",
        "        eval_kwargs=eval_kwargs,\n",
        "        reward_key=\"reward\",\n",
        "        kl_key=\"kl_orig_only\",\n",
        "    )\n",
        ")\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2ivyGTz_wngt",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2ivyGTz_wngt",
        "outputId": "883766d2-263c-4f5a-a69b-5a2da114b8e6"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/coherence-GSM8K/answerlevel/microsoft/Phi-3-mini-4k-instruct/instruct-mode-jan16/weight-5e-2/checkpoint-960\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# CKPT_LOAD_PATH = os.path.join(WORK_PATH, \"checkpoint-1000\")\n",
        "print(CKPT_LOAD_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "JcH_iz27xxgb",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 787,
          "referenced_widgets": [
            "32f6799a1fc64819a03c877739231d87",
            "a1c96f4de2034d968f9b5753b0a40a29",
            "c6f77082b18a4e81a3a62afe975eb13f",
            "2bac1d5c35be4a7ebcbe73630cf793ba",
            "ccf52167ed7943e683dc8867b667d04b",
            "3381deb1bc12409d95b206282c52a4fa",
            "ac01321b67174b19a5729e55214bbfd3",
            "49a218f81ccd4c21b21b06e70fb7437a",
            "8f27b215d4644bb987cacdacfdff300d",
            "bd1867fce7ec40f5ab51987c6a15dc01",
            "76c9e6ab15ff4a42b491d83a86dd39f1",
            "05c8b40c0b824ff78c29a88acc28def0",
            "06ba9fd853cc41cb9a3da9f2a468bcbc",
            "ce21a30b7c05419fb9dfa60a03b860f0",
            "462cff19cfe245d5bdea79800e31fbb9",
            "43e6ca1b275644d38cf635e85c707e37",
            "6c3f843e18dd490e9fdd6f803c561c88",
            "75642ef6991447988a0fdb1a42525a54",
            "20fba02e7e904962ab20895229def0b9",
            "f8470899b99440a0bb6b16c02a83228a",
            "87b9c3fdb66f45e5995514f4dc2a8484",
            "e131acb0fcb34a1b8272f4d54735c7f9",
            "3e988fa7173f43deb3b983b18f3fa43a",
            "a8786cfd96e44ac8b30d746df82423dd",
            "cdbc5c0efd43440382ac0a344b8596e1",
            "851469eda17447f7a21d3d6f54ad955b",
            "3ba0bdcbbb7b4baba0f52e3b895dfc0a",
            "a832547670bd4e98a5fea376b41474fa",
            "d9675716381646f083c45dc38b205566",
            "09d8b677e54c4690882fc5135746b551",
            "cfe47daa7dd3457880191c9c884fd9d1",
            "6e3eb8f53b3648a4a34cf0c9304cd487",
            "59afc93f3c394424b0b40d96363293d1",
            "b578aef376294c2fbdd44ea77221dc49",
            "ac505e13948d4b889852f45b1e59abbf",
            "6dc843cb39f74cb68bf3404581d52302",
            "ef13d5aabbd74bc9bf7d1ef00573d318",
            "b4c526ed80ac473296e9466c3b1de900",
            "15c860f3f7b84a648b965d46f8089f02",
            "3583c46056cc4bb1aa80b54a150d6f49",
            "31c28f0dd70b40e9b9c58e38e7212702",
            "3c806eead3e54684aa9b4241282a5a4e",
            "da39f0035da542bd8411d9af0214897a",
            "02651e4f6390479f830f7c386c5f3760",
            "52823df7480e4fbe9145261be21d3078",
            "0af9121c170646b4bd8164825bf8c066",
            "7daa4689dcc444608b14b508a6cb2ed0",
            "57bffe25f3ea446ab840f0da486c45b6",
            "72325f3267264cba8567e2305d99bb02",
            "3c0c6563a42042efa63655d705966d7e",
            "0d00ec98f31a42b8b4e5eed967075949",
            "5312c593f45841febf846af3abf7dd0b",
            "cc056289a37e479ca61b8c71c26750dc",
            "ff2117536ec04841bb2275b63124ec68",
            "16700eb57b07426f9483715d2d15ab00",
            "e5492db54a1440fa82a8efcd1d8d9c31",
            "9efb508c93a34fa8925c8def2f62c91f",
            "508dbe13f9de4f2d94c53f8a731dcfe5",
            "d3ff6d9ec7ad4bd081c2be962a328891",
            "23891f2226224121912cdde1f515a403",
            "431188b0e5c647fa91ff0074412f16f8",
            "cf5a0838d81e489ab7b835374cba9d90",
            "07cdbbac3e624cfea4bd7b8c62f905a2",
            "281bda6512684a4fac0e44c6e2595a6c",
            "10b2adc8db8d4329b6e639ae4fd7cfd5",
            "f508cc9928f64596951cfa8db6fa7e6c",
            "e5178faae6c2443dade448218fe426a0",
            "49b2ac9455524a20a6f401d09dccbb16",
            "034749bdd790473b833e84d6736d372f",
            "af68c43396554c76a342b1e654470297",
            "d57f38d5576b465e9ba982c974c414e5",
            "3625f52ccc124938aa38fa846f2603ed",
            "450ed427db1f488ba2a4b3adbae3e37a",
            "2302fa98d451435dad4005f38a2243d1",
            "dbac9ac62a4f4e49a2e07d309c4762f8",
            "75df96751cda43cc989d2e6bd6ffec22",
            "3a28131ebfc844d2be2325cc1111a1c8"
          ]
        },
        "id": "JcH_iz27xxgb",
        "outputId": "047dc6d8-3262-4c1c-fd3f-363d913d21ff"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3200' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3200/3200 7:39:02, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>-0.000400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>-0.004400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1920</td>\n",
              "      <td>0.005300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2240</td>\n",
              "      <td>-0.001700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2560</td>\n",
              "      <td>-0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2880</td>\n",
              "      <td>0.000800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3200</td>\n",
              "      <td>-0.002200</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[save step 1280] reward=0.770833  kl_orig_only=0.00183819  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "32f6799a1fc64819a03c877739231d87"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 79.83 and ACC confidence [77.58, 81.91]  |  pred NA: 52/1319 (3.94%)  \n",
            "[save step 1600] reward=0.84375  kl_orig_only=0.00104847  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "05c8b40c0b824ff78c29a88acc28def0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.35 and ACC confidence [79.16, 83.36]  |  pred NA: 30/1319 (2.27%)  \n",
            "[save step 1920] reward=0.854167  kl_orig_only=0.00164702  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3e988fa7173f43deb3b983b18f3fa43a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.27 and ACC confidence [79.08, 83.29]  |  pred NA: 30/1319 (2.27%)  \n",
            "[save step 2240] reward=0.885417  kl_orig_only=0.00346876  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "b578aef376294c2fbdd44ea77221dc49"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.58 and ACC confidence [79.39, 83.58]  |  pred NA: 29/1319 (2.20%)  \n",
            "[save step 2560] reward=0.854167  kl_orig_only=0.00121428  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "52823df7480e4fbe9145261be21d3078"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.50 and ACC confidence [79.32, 83.50]  |  pred NA: 31/1319 (2.35%)  \n",
            "[save step 2880] reward=0.802083  kl_orig_only=0.00157572  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e5492db54a1440fa82a8efcd1d8d9c31"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.73 and ACC confidence [79.55, 83.72]  |  pred NA: 32/1319 (2.43%)  \n",
            "[save step 3200] reward=0.760417  kl_orig_only=0.00357192  n1=2  n0=2\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e5178faae6c2443dade448218fe426a0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.50 and ACC confidence [79.32, 83.50]  |  pred NA: 24/1319 (1.82%)  \n"
          ]
        }
      ],
      "source": [
        "if CKPT_LOAD_PATH:\n",
        "    trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH)\n",
        "else:\n",
        "    trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "xUUbbN053K8i",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "d51f209b03524107b7ddf87d216b3e8c",
            "8fb456aa3d3c40b28811423b89741a57",
            "ee65a0bd33014bbf9bce43841123173e",
            "0125142e12494997bae874d02db91c04",
            "9038cdab7ca249979251730657071401",
            "e1b9e73487eb48baac1632bb4c64ffa0",
            "12aea8a6e24d41e09fa5abc7e96540c5",
            "e410aca5539d420592f6824b7ae91350",
            "8da70f99a4134e79a215998718df41ec",
            "627412ffcc784ebbb5f201e23ba9a47e",
            "48f1bb6613244a678c7c70957459aa71",
            "2c2690553c28419e8c205070265fec8b",
            "bfd3da20dbea45e891614b19f66fef4e",
            "452663c398e1482da9cca0301829c064",
            "bdb1cbec5e874db3bffff56442f6ec3c",
            "003329dff7a54a35afc3ebbe07bbb161",
            "d2ba3b9690d44e649e694c6f72d5bfa8",
            "602c2e372b6c40e7b8904f1c1525c2fa",
            "5889e1ddd2bb439fb860fc23e77093c5",
            "ecb11be7e0f445e99ea4be35fa486cdf",
            "b9d4e8a9db9d41e696f73535ac327beb",
            "170a030d0b954dd4a0597954e8bcb780",
            "64957ba8da1345288f1803d04b512014",
            "0f5c722767084418a3018b67bb63f295",
            "316b66e65c3e4ab18b283ee3a0be6c23",
            "c4d0d0db991d42468b1ddd6f1768be8b",
            "e2dd3220008c47d8ace31cfe8595a443",
            "1620cf0812d94dfdb188cabd69d7185f",
            "c55dcb589588469c93da86a5fb17c01f",
            "375df7873dcc4e76bb96ee39ddb859bf",
            "5f40d588569f4580a720f57923bf1969",
            "c253ba6221404c8b936244d468ab2774",
            "05b29e18fea8439ba26726bfd832a6e2",
            "0cc00ca50eb14244b8623fa83346b224",
            "dacd682ca95047248ad4c58556ef6752",
            "d7a0a9267a9b41df998216c52db22b9c",
            "3716b28289c9497c87c8655c358a54d6",
            "02fad28d4889434c89b37c102c227de8",
            "3b29cebc081f425f94cf7993561bfb0e",
            "a5d5a68661f84681ba2c0e5fed956c93",
            "d68b8bc8db334d7e9464f76cbbe41365",
            "35ce343819c44025a76d6683a59069a2",
            "1166eda0f8dd4b66a842cc6ed3796780",
            "f0342932dde4420c90682462b8a56396",
            "fb7ff2e985d4495ea059ae8deb1886ba",
            "40e14b1878eb4251bba8e6b57d4414f5",
            "f2f87e1839314a68aed4a4b5b78c42fa",
            "bbc5a0432af04709820be86c0d6e8713",
            "5b81518166804f87b51ca52ed12db7af",
            "fdfb607e571447e992ac015475b36167",
            "b15072f9000c4b3c91433f7dd3f21c67",
            "41db2fadc95849a3a5d8d408a07a4b24",
            "02d7422302084d3ca7f138055a525be0",
            "182ddfeba28d4fd59d5f58d1e5bff82c",
            "4e49a3b0ee974fe9bd5810803d598c78",
            "033e4e64f5bf48f1b95a65c2064ae67d",
            "4dd26f7c71f349dbb280f54ebd9235f7",
            "0b998539d29f45498def1bf2919fce18",
            "a0c2f6ab96f84135a094ab9f19c49138",
            "e031f3547b184bd3a1088b871194b67c",
            "78eb2ea591a348c5aa5d7b5b3b0eb810",
            "ce3bb42eb2d34294ad84019a80bb88d9",
            "a5495921684f4548aa6ac00cc26955aa",
            "2e019358119a4075828e44df864b591b",
            "2df7055b183549bc83e6c4a2a83d8a6a",
            "463b986f48ee4d43bc2529ef692fa42a",
            "ed5e3bf8692942e7ac794f8dbaceec99",
            "9728c65e75584e69a5f0079731ec1fd8",
            "b1d45f37aee24ebba7448decec0614ab",
            "044cf3f4c2314762a33d56c9294556f4",
            "d42a59e47b40495696da12aef8ceaffd",
            "1b1db59d3f574c84be39ed5d1f39d907",
            "f109a67864244d2aa2540ada8c5283d3",
            "f2bcb8a62be74c0dbbd06bd67b31ad54",
            "40a8fb2262984e0c91e6a84fbd01de97",
            "e944ddcecc534d1a839f877ed505a5b0",
            "32fcd0ac8ce6414da88d850132127704",
            "bef65b1b05e7490797390d19f20fa404",
            "9729511a432840199f41bf61d206f45b",
            "1fa46bb97549437d9671bc1107cb8890",
            "8cdab187cfc44daf9070d72467195684",
            "57bd29ef05214f3daa997ff96c2217aa",
            "9b386e24cec74153a17b0e462001ee97",
            "39e96afa265f4d2884e9fa4050ab0969",
            "4a8441bc5b2f406ab79d71559e593277",
            "79b071ae0d88466dab56dfad767b2039",
            "842c349e5ad44dcb8be771f5e686a254",
            "5ab35da2fb244dbe8b03cee0e29f476a",
            "e70ec3fb1e0d4b1194943d5bb67cb90c",
            "ce436a8911e54da09f24658596726b92",
            "9a98c8273f154c61af807aab551a51f9",
            "c95adefa67bb4cafa341eabd32b5a7be",
            "bdde60c90cb346a2b67354d91e83fe96",
            "f248b893c3ea4fa4870d38a953502385",
            "b49d6f19dfe245efb788c1fbfc6e25a6",
            "90d1a3f56d7c47b9bbf1be269514ebbc",
            "384e3bc506734b98a03abd4d17dde926",
            "038a3a47b941448b94b77807875c89d6",
            "db5d5452eb464df7b04d5269793f731e",
            "451226c8426946d7a8a992b86bf4120a",
            "7f9a014278204f52b7013a647da4184b",
            "dcbb7b89a6564a4f96ceff62312be8ba",
            "6ff6f6b967e74fef94b10898d57af4eb",
            "3b5734769485414281a2452bac8edac4",
            "d626846a29154eb0a817cfe6b7e99540",
            "f17d5c68b4f84ff88ac4b9c28802e949",
            "93887bc7f9aa41d6bcad100e1adc8349",
            "af84042f6ed14f48a3942e27970c7062",
            "0c382a8bb37b4e85acd76ee9ec735bab",
            "ad8d0aef86744743aa1ab5ecccb62931"
          ]
        },
        "id": "xUUbbN053K8i",
        "outputId": "e3f1b7db-34d7-4410-d25e-3c7ecaca8ff9"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2881' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2881/3200 9:07:38 < 1:00:40, 0.09 it/s, Epoch 0.90/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>-0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>-0.000700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>-0.000300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>-0.000200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1920</td>\n",
              "      <td>0.001000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2240</td>\n",
              "      <td>0.005000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2560</td>\n",
              "      <td>0.002400</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 320] reward=0.458333  kl_orig_only=0.0054243  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d51f209b03524107b7ddf87d216b3e8c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 80.21 and ACC confidence [77.98, 82.27]  |  pred NA: 33/1319 (2.50%)  \n",
            "[save step 640] reward=0.5625  kl_orig_only=0.0548678  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2c2690553c28419e8c205070265fec8b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.26 and ACC confidence [80.10, 84.23]  |  pred NA: 25/1319 (1.90%)  \n",
            "[save step 960] reward=0.5  kl_orig_only=0.00559903  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "64957ba8da1345288f1803d04b512014",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 81.96 and ACC confidence [79.79, 83.94]  |  pred NA: 25/1319 (1.90%)  \n",
            "[save step 1280] reward=0.541667  kl_orig_only=0.00486956  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0cc00ca50eb14244b8623fa83346b224",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 83.24 and ACC confidence [81.13, 85.16]  |  pred NA: 18/1319 (1.36%)  \n",
            "[save step 1600] reward=0.729167  kl_orig_only=0.00308666  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "fb7ff2e985d4495ea059ae8deb1886ba",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.94 and ACC confidence [80.82, 84.88]  |  pred NA: 18/1319 (1.36%)  \n",
            "[save step 1920] reward=0.625  kl_orig_only=0.00271481  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "033e4e64f5bf48f1b95a65c2064ae67d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.34 and ACC confidence [80.18, 84.30]  |  pred NA: 21/1319 (1.59%)  \n",
            "[save step 2240] reward=0.666667  kl_orig_only=0.00617779  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ed5e3bf8692942e7ac794f8dbaceec99",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.79 and ACC confidence [80.66, 84.73]  |  pred NA: 17/1319 (1.29%)  \n",
            "[save step 2560] reward=0.6875  kl_orig_only=0.00259527  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bef65b1b05e7490797390d19f20fa404",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.64 and ACC confidence [80.50, 84.59]  |  pred NA: 17/1319 (1.29%)  \n",
            "[save step 2880] reward=0.625  kl_orig_only=0.00648064  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e70ec3fb1e0d4b1194943d5bb67cb90c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 83.02 and ACC confidence [80.90, 84.95]  |  pred NA: 18/1319 (1.36%)  \n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3200' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3200/3200 10:50:22, Epoch 1/9223372036854775807]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>320</td>\n",
              "      <td>-0.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>640</td>\n",
              "      <td>0.000900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>960</td>\n",
              "      <td>-0.000700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1280</td>\n",
              "      <td>-0.000300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1600</td>\n",
              "      <td>-0.000200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1920</td>\n",
              "      <td>0.001000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2240</td>\n",
              "      <td>0.005000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2560</td>\n",
              "      <td>0.002400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2880</td>\n",
              "      <td>-0.006900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3200</td>\n",
              "      <td>-0.001500</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[save step 3200] reward=0.520833  kl_orig_only=0.0174067  n1=2  n0=2\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "451226c8426946d7a8a992b86bf4120a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Evaluating trained_grpo:   0%|          | 0/42 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trained_grpo] N=1319  ACC: 82.87 and ACC confidence [80.74, 84.80]  |  pred NA: 17/1319 (1.29%)  \n"
          ]
        }
      ],
      "source": [
        "# Start a fresh run unless CKPT_LOAD_PATH is set (truthy), in which case resume from it.\n",
        "if not CKPT_LOAD_PATH:\n",
        "    trainer.train()\n",
        "else:\n",
        "    trainer.train(resume_from_checkpoint=CKPT_LOAD_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Ge7S4ERp7oma",
      "metadata": {
        "id": "Ge7S4ERp7oma"
      },
      "outputs": [],
      "source": [
        "import gc, torch\n",
        "\n",
        "# Drop this notebook's reference to the policy model so its memory can be reclaimed.\n",
        "# The guard keeps the cell re-runnable: a bare `del` raises NameError once the name\n",
        "# is gone (second run of this cell, or a kernel where it was never defined).\n",
        "# NOTE(review): other live references (trainer, ref_model, baseline_model, ...) would\n",
        "# presumably keep the weights alive; delete those too if the goal is to free GPU memory.\n",
        "try:\n",
        "    del trainable_model\n",
        "except NameError:\n",
        "    pass  # nothing to release: the name is not bound in this session\n",
        "gc.collect()\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.empty_cache()   # release cached, unoccupied allocator blocks\n",
        "    torch.cuda.ipc_collect()   # optional, sometimes helps in notebooks\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "k5M1jnQdXDxC",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k5M1jnQdXDxC",
        "outputId": "8f0fe89a-4491-40c2-a7da-d5b6344bf349"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using masked KL penalty.\n"
          ]
        }
      ],
      "source": [
        "# GRPO setup: build the paired-paraphrase training stream, the GRPOConfig and the trainer.\n",
        "# Names this cell reads from earlier cells:\n",
        "#   - trainable_model, tokenizer\n",
        "#   - paraphrase_classes, ParaphrasePairIterableFormatted\n",
        "#   - agreement_reward_fn (reward over completions), grpo_tokenizing_collator\n",
        "#   - SAMPLES_PER_PAIR, MAX_NEW_TOKENS, KL_WEIGHT, LOG_EVERY, WORK_PATH\n",
        "#   - optionally MaskedKLGRPOTrainer (stock GRPOTrainer is used when it is missing)\n",
        "#\n",
        "# Notes:\n",
        "# - The run length is bounded through max_steps (the usual approach for an IterableDataset).\n",
        "# - NOTE(review): the stderr recorded for this cell reports WANDB_DISABLED as deprecated in\n",
        "#   favour of report_to; it is kept here so logging behaviour stays exactly as before.\n",
        "import os\n",
        "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
        "\n",
        "import math\n",
        "from trl import GRPOConfig\n",
        "\n",
        "# Prefer the masked-KL trainer subclass when an earlier cell has defined it.\n",
        "try:\n",
        "    TrainerCls = MaskedKLGRPOTrainer\n",
        "except NameError:\n",
        "    from trl import GRPOTrainer\n",
        "    TrainerCls = GRPOTrainer\n",
        "else:\n",
        "    print(\"Using masked KL penalty.\")\n",
        "\n",
        "# --- dataset for TRL training (paired stream) ---\n",
        "train_dataset = ParaphrasePairIterableFormatted(paraphrase_classes, tokenizer, seed=42)\n",
        "\n",
        "# Rollouts sampled per prompt.\n",
        "K = SAMPLES_PER_PAIR\n",
        "\n",
        "# ----------------------------\n",
        "# Schedule + evaluation knobs\n",
        "# ----------------------------\n",
        "TOTAL_STEPS = LOG_EVERY * 200   # overall optimizer-step budget (passed as max_steps)\n",
        "CHUNK_STEPS = LOG_EVERY * 20    # checkpoint interval (passed as save_steps)\n",
        "\n",
        "EVAL_LIMIT = 1000\n",
        "EVAL_BS = 64\n",
        "\n",
        "DO_PARAPHRASE_EVAL = False\n",
        "PARA_EVAL_LIMIT = 200\n",
        "PARA_EVAL_CLASS_BS = 16   # classes per batch (prompts ~ 2*bs if each class has 2 items)\n",
        "\n",
        "\n",
        "grpo_args = GRPOConfig(\n",
        "    output_dir=WORK_PATH,\n",
        "    # -- batch geometry --\n",
        "    per_device_train_batch_size=2,       # (q0, q1)\n",
        "    gradient_accumulation_steps=1,\n",
        "    num_generations=K,                   # K rollouts per prompt\n",
        "    generation_batch_size=2 * K,         # must be divisible by num_generations\n",
        "    max_completion_length=MAX_NEW_TOKENS,\n",
        "    # -- optimisation --\n",
        "    learning_rate=1e-5,\n",
        "    beta=KL_WEIGHT,                      # KL penalty coefficient\n",
        "    # -- bookkeeping --\n",
        "    max_steps=TOTAL_STEPS,\n",
        "    logging_steps=LOG_EVERY,\n",
        "    save_steps=CHUNK_STEPS,\n",
        ")\n",
        "\n",
        "trainer = TrainerCls(\n",
        "    model=trainable_model,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=agreement_reward_fn,\n",
        "    train_dataset=train_dataset,\n",
        "    args=grpo_args,\n",
        ")\n",
        "\n",
        "\n",
        "def _paired_prompt_collator(feats):\n",
        "    \"\"\"Forward raw records to grpo_tokenizing_collator with max_prompt_length=2048.\"\"\"\n",
        "    return grpo_tokenizing_collator(feats, tokenizer, max_prompt_length=2048)\n",
        "\n",
        "\n",
        "# The collator is attached AFTER init, by assignment rather than as a constructor argument.\n",
        "trainer.data_collator = _paired_prompt_collator\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "m2ssljPjv5gT",
      "metadata": {
        "id": "m2ssljPjv5gT"
      },
      "source": [
        "### Debug"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "oBWStm0RxM1n",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oBWStm0RxM1n",
        "outputId": "c628f178-7254-48d3-b019-6a5ff5bc3540"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-761637024.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# After building trainer, grab one batch from its dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_dataloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "# Smoke test: pull one batch from the trainer's own dataloader, run compute_loss on it,\n",
        "# and print the KL diagnostics expected in the returned outputs next to the loss.\n",
        "# NOTE(review): the output recorded for this cell is a TypeError raised while fetching the\n",
        "# batch (\"Can only concatenate tensors but got <class 'str'>\"), so that run never reached\n",
        "# compute_loss.\n",
        "dl = trainer.get_train_dataloader()\n",
        "batch = next(iter(dl))\n",
        "\n",
        "loss, out = trainer.compute_loss(trainer.model, batch, return_outputs=True)\n",
        "\n",
        "for stat_name in (\"beta\", \"kl_mean_all\", \"kl_orig_mean\", \"kl_para_mean\", \"masked_kl\"):\n",
        "    print(f\"{stat_name}:\", float(out[stat_name]))\n",
        "print(\"loss:\", float(loss.detach().cpu()))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "q-p8g9fkv7gq",
      "metadata": {
        "id": "q-p8g9fkv7gq"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def debug_kl_mask_once(trainer, train_dataset, tokenizer, K: int):\n",
        "    \"\"\"\n",
        "    Runs one forward pass and prints KL stats split by is_original.\n",
        "    \"\"\"\n",
        "    # Build a single batch: exactly 2 prompts (q0,q1)\n",
        "    it = iter(train_dataset)\n",
        "    batch = [next(it), next(it)]  # two records from ParaphrasePairIterable\n",
        "    # TRL will usually handle collation; easiest is to use trainer's dataloader.\n",
        "    # But we can mimic minimal inputs expected by compute_loss:\n",
        "    prompts = [b[\"prompt\"] for b in batch]\n",
        "    is_orig = torch.tensor([b[\"is_original\"] for b in batch])\n",
        "\n",
        "    # Let TRL tokenize the prompts the same way it does internally:\n",
        "    # Many TRL versions expect \"prompt\" in inputs and will process it.\n",
        "    inputs = {\"prompt\": prompts, \"is_original\": is_orig}\n",
        "\n",
        "    # Call compute_loss with outputs\n",
        "    loss, outputs = trainer.compute_loss(trainer.model, inputs, return_outputs=True)\n",
        "\n",
        "    # Find KL vector key\n",
        "    kl_key = None\n",
        "    for k in [\"kl\", \"kl_div\", \"per_sample_kl\", \"kl_per_sample\"]:\n",
        "        if isinstance(outputs, dict) and k in outputs:\n",
        "            kl_key = k\n",
        "            break\n",
        "    if kl_key is None:\n",
        "        raise KeyError(f\"Couldn't find per-sample KL in outputs keys: {list(outputs.keys())}\")\n",
        "\n",
        "    kl = outputs[kl_key].detach().cpu()         # expected shape [B*G] = [2*K]\n",
        "    mask = is_orig.float().repeat_interleave(K).cpu()\n",
        "\n",
        "    # Split KL by prompt type\n",
        "    kl_orig = kl[mask == 1]\n",
        "    kl_para = kl[mask == 0]\n",
        "\n",
        "    print(\"KL key:\", kl_key, \"shape:\", tuple(kl.shape))\n",
        "    print(\"KL mean (orig):\", float(kl_orig.mean()) if len(kl_orig) else None)\n",
        "    print(\"KL mean (para):\", float(kl_para.mean()) if len(kl_para) else None)\n",
        "\n",
        "    if \"masked_kl\" in outputs:\n",
        "        print(\"masked_kl (trainer):\", float(outputs[\"masked_kl\"].detach().cpu()))\n",
        "    else:\n",
        "        # recompute masked KL ourselves\n",
        "        masked_kl = (kl * mask).sum() / mask.sum().clamp_min(1.0)\n",
        "        print(\"masked_kl (recomputed):\", float(masked_kl))\n",
        "\n",
        "    print(\"loss:\", float(loss.detach().cpu()))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fr6HpoQjv9QY",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fr6HpoQjv9QY",
        "outputId": "573da719-a9b0-48d7-cdc3-c38e51ca5446"
      },
      "outputs": [
        {
          "ename": "ValueError",
          "evalue": "The GRPOTrainer does not support returning outputs",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-1024889168.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdebug_kl_mask_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mSAMPLES_PER_PAIR\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/tmp/ipython-input-1970023020.py\u001b[0m in \u001b[0;36mdebug_kl_mask_once\u001b[0;34m(trainer, train_dataset, tokenizer, K)\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0;31m# Call compute_loss with outputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m     \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0;31m# Find KL vector key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-276940471.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, **kwargs)\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0;31m# Run the normal GRPO forward to get outputs that include per-sample stats.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0;31m# We rely on GRPOTrainer producing a dict-like outputs containing KL.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0;31m# inputs[\"is_original\"]: shape [batch_prompts] (here batch_prompts should be 2)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/extras/profiling.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     96\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mprofiling_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/trl/trainer/grpo_trainer.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, model, inputs, return_outputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m   2104\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2105\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreturn_outputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2106\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The GRPOTrainer does not support returning outputs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2107\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_liger_kernel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2108\u001b[0m             \u001b[0;31m# Compute the loss using the liger grpo loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mValueError\u001b[0m: The GRPOTrainer does not support returning outputs"
          ]
        }
      ],
      "source": [
        "# NOTE: the recorded run of this cell failed with\n",
        "#   ValueError: The GRPOTrainer does not support returning outputs\n",
        "# because the call chain reaches GRPOTrainer.compute_loss with return_outputs=True\n",
        "# (see the saved traceback in this cell's output).\n",
        "debug_kl_mask_once(trainer, train_dataset, tokenizer, K=SAMPLES_PER_PAIR)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sx5zM-iWwAJR",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "sx5zM-iWwAJR",
        "outputId": "34036967-7614-447c-9d61-5bc4162034de"
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "Can only concatenate tensors but got <class 'str'>",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-3887887193.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;31m# Train up to cur_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresume\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mresume\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   2323\u001b[0m                 \u001b[0mhf_hub_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_progress_bars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2324\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2325\u001b[0;31m             return inner_training_loop(\n\u001b[0m\u001b[1;32m   2326\u001b[0m                 \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2327\u001b[0m                 \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2616\u001b[0m                 \u001b[0mupdate_step\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2617\u001b[0m                 \u001b[0mnum_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient_accumulation_steps\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mupdate_step\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtotal_updates\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mremainder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2618\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_items_in_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_batch_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2619\u001b[0m                 \u001b[0;31m# Store the number of batches for current gradient accumulation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2620\u001b[0m                 \u001b[0;31m# This is used to correctly scale the loss when the last accumulation step has fewer batches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mget_batch_samples\u001b[0;34m(self, epoch_iterator, num_batches, device)\u001b[0m\n\u001b[1;32m   5652\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5653\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5654\u001b[0;31m                 \u001b[0mbatch_samples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   5655\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5656\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stop_iteration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mfirst_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 866\u001b[0;31m         \u001b[0mnext_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_batch_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmain_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    867\u001b[0m         \u001b[0mbatch_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    868\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstop_iteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/data_loader.py\u001b[0m in \u001b[0;36m_fetch_batches\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    820\u001b[0m                             \u001b[0mbatches\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m                     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m                         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mRuntimeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m                         raise RuntimeError(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    613\u001b[0m     \"\"\"\n\u001b[1;32m    614\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    615\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhonor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    616\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(data, dim)\u001b[0m\n\u001b[1;32m    617\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    618\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Can only concatenate tensors but got {type(data[0])}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    620\u001b[0m     \u001b[0;32mreturn\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mTypeError\u001b[0m: Can only concatenate tensors but got <class 'str'>"
          ]
        }
      ],
      "source": [
        "\n",
        "# Train in chunks of CHUNK_STEPS and evaluate after each chunk, until TOTAL_STEPS is reached.\n",
        "resume = None   # the first train() call gets no resume request\n",
        "cur_target = 0  # global step the current chunk trains up to\n",
        "while True:\n",
        "    if not (cur_target < TOTAL_STEPS):\n",
        "        break\n",
        "\n",
        "    # Advance the target by one chunk, capped at the overall step budget.\n",
        "    cur_target = min(cur_target + CHUNK_STEPS, TOTAL_STEPS)\n",
        "\n",
        "    # The trainer reads its global step target from args.max_steps.\n",
        "    trainer.args.max_steps = cur_target\n",
        "    trainer.train(resume_from_checkpoint=resume)\n",
        "    resume = True  # every later chunk asks the trainer to resume\n",
        "\n",
        "    # --- Evaluation after this chunk ---\n",
        "    # 1) Accuracy on the held-out set.\n",
        "    _ = evaluate_gsm8k(\n",
        "        model=trainer.model, tokenizer=tokenizer, dataset=test_ds,\n",
        "        limit=EVAL_LIMIT, batch_size=EVAL_BS, max_new_tokens=MAX_NEW_TOKENS,\n",
        "        name=f\"grpo_step{cur_target}\", use_greedy=True,\n",
        "    )\n",
        "\n",
        "    # 2) Paraphrase agreement (batched) is optional; skip it unless enabled.\n",
        "    if not DO_PARAPHRASE_EVAL:\n",
        "        continue\n",
        "    _ = evaluate_gsm8k_paraphrase_agreement_batched(\n",
        "        model=trainer.model, tokenizer=tokenizer, dataset=paraphrase_classes,\n",
        "        limit=PARA_EVAL_LIMIT, batch_size=PARA_EVAL_CLASS_BS, max_new_tokens=MAX_NEW_TOKENS,\n",
        "        name=f\"agree_step{cur_target}\", use_greedy=True, show_samples=0,\n",
        "    )\n",
        "\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "SD17xtwHtI6d",
      "metadata": {
        "id": "SD17xtwHtI6d"
      },
      "source": [
        "## Load/Generate Answer-Level Partitions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "V5ZfIEeatNHc",
      "metadata": {
        "id": "V5ZfIEeatNHc"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from typing import List\n",
        "from collections import defaultdict\n",
        "\n",
        "# Assume something like:\n",
        "# from sentence_transformers import util as sbert_util\n",
        "# and 'embedder' is a global SentenceTransformer\n",
        "def cluster_embeddings_online(\n",
        "    embs: torch.Tensor,           # [K, D]\n",
        "    sim_threshold: float = 0.95,\n",
        ") -> List[int]:\n",
        "    \"\"\"\n",
        "    Greedy single-pass (online) clustering of embeddings by cosine similarity.\n",
        "\n",
        "    Rows are visited in order. Each row joins the existing cluster whose center\n",
        "    is most similar to it, provided that similarity is >= sim_threshold;\n",
        "    otherwise it starts a new cluster. A cluster center is the running mean of\n",
        "    the rows assigned to that cluster so far.\n",
        "\n",
        "    Args:\n",
        "        embs: 2-D tensor of shape [K, D], one embedding per row.\n",
        "        sim_threshold: minimum cosine similarity required to join a cluster.\n",
        "\n",
        "    Returns:\n",
        "        cluster_ids: list of K ints in {0, 1, ...}, numbered in creation order.\n",
        "        An input with K == 0 rows yields an empty list.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if embs is not 2-D (the shape unpacking fails).\n",
        "    \"\"\"\n",
        "    K, _ = embs.shape\n",
        "    centers = []          # running-mean center of each cluster, each of shape [D]\n",
        "    counts = []           # number of rows assigned to each cluster\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k : k + 1]          # [1, D], keeps a batch dim for cosine_similarity\n",
        "\n",
        "        # Start below every possible similarity so that a similarity of exactly\n",
        "        # -1.0 can still be selected; the strict '>' keeps the earliest cluster on ties.\n",
        "        best_cid = None\n",
        "        best_sim = float(\"-inf\")\n",
        "\n",
        "        for cid, center in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e, center.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # Join the best cluster and fold this row into its running mean.\n",
        "            cluster_ids.append(best_cid)\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e.squeeze(0)) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # No cluster is close enough (or none exists yet): start a new one.\n",
        "            cluster_ids.append(len(centers))\n",
        "            centers.append(e.squeeze(0))\n",
        "            counts.append(1)\n",
        "\n",
        "    return cluster_ids\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def cluster_answers_for_question(\n",
        "    answers: List[str],\n",
        "    *,\n",
        "    question: str,\n",
        "    embedder,\n",
        "    sim_threshold: float = 0.95,\n",
        "):\n",
        "    \"\"\"\n",
        "    Embed each answer together with its question and cluster the embeddings online.\n",
        "\n",
        "    Every answer is rendered as a \"Q: ... / A: ...\" text, encoded with\n",
        "    embedder.encode(..., convert_to_tensor=True), and then assigned greedily:\n",
        "    it joins the most similar existing cluster when the cosine similarity to\n",
        "    that cluster's running-mean center is >= sim_threshold, otherwise it\n",
        "    starts a new cluster.\n",
        "\n",
        "    Args:\n",
        "        answers: answer strings for one question.\n",
        "        question: the question text, prepended to every answer before embedding.\n",
        "        embedder: object exposing encode(texts, convert_to_tensor=..., show_progress_bar=...)\n",
        "            that returns a [K, D] tensor (presumably a SentenceTransformer; confirm at call site).\n",
        "        sim_threshold: minimum cosine similarity required to join a cluster.\n",
        "\n",
        "    Returns:\n",
        "      cluster_ids: [K] int\n",
        "      centers:     [C, D] tensor of cluster centers\n",
        "      For an empty `answers` list this is ([], empty tensor of shape [0, 0]).\n",
        "    \"\"\"\n",
        "    # Nothing to embed or stack: torch.stack on an empty list would raise below.\n",
        "    if len(answers) == 0:\n",
        "        return [], torch.empty((0, 0))\n",
        "\n",
        "    texts = [f\"Q: {question}\\nA: {a}\" for a in answers]\n",
        "    embs = embedder.encode(texts, convert_to_tensor=True, show_progress_bar=False)  # [K, D]\n",
        "\n",
        "    K, _ = embs.shape\n",
        "    centers = []      # running-mean center of each cluster, each of shape [D]\n",
        "    counts = []       # number of answers assigned to each cluster\n",
        "    cluster_ids = []\n",
        "\n",
        "    for k in range(K):\n",
        "        e = embs[k]  # [D]\n",
        "\n",
        "        # Start below every possible similarity so that a similarity of exactly\n",
        "        # -1.0 can still be selected; the strict '>' keeps the earliest cluster on ties.\n",
        "        best_cid = None\n",
        "        best_sim = float(\"-inf\")\n",
        "\n",
        "        for cid, c in enumerate(centers):\n",
        "            sim = torch.nn.functional.cosine_similarity(\n",
        "                e.unsqueeze(0), c.unsqueeze(0), dim=1\n",
        "            ).item()\n",
        "            if sim > best_sim:\n",
        "                best_sim = sim\n",
        "                best_cid = cid\n",
        "\n",
        "        if best_cid is not None and best_sim >= sim_threshold:\n",
        "            # Join the best cluster and fold this embedding into its running mean.\n",
        "            cluster_ids.append(best_cid)\n",
        "            n = counts[best_cid]\n",
        "            centers[best_cid] = (centers[best_cid] * n + e) / (n + 1)\n",
        "            counts[best_cid] = n + 1\n",
        "        else:\n",
        "            # No cluster is close enough (or none exists yet): start a new one.\n",
        "            cluster_ids.append(len(centers))\n",
        "            centers.append(e)\n",
        "            counts.append(1)\n",
        "\n",
        "    centers_tensor = torch.stack(centers, dim=0)  # [C, D]\n",
        "    return cluster_ids, centers_tensor\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2izcDe6yyIxP",
      "metadata": {
        "id": "2izcDe6yyIxP"
      },
      "outputs": [],
      "source": [
        "import json, time, hashlib, random, pathlib\n",
        "from collections import Counter\n",
        "from typing import List, Dict, Any\n",
        "\n",
        "\n",
        "\n",
        "# JSONL cache file for the cluster-based coherence set, located under COHERENCE_DATA_PATH.\n",
        "COH_CACHE_PATH = os.path.join(COHERENCE_DATA_PATH, \"cache/triviaqa_cluster.jsonl\")\n",
        "\n",
        "def _coh_meta_hash(meta: dict) -> str:\n",
        "    \"\"\"Return a 10-hex-character SHA-1 fingerprint of `meta`, independent of key order.\"\"\"\n",
        "    # Sorting keys makes the JSON text (and therefore the digest) canonical.\n",
        "    canonical = json.dumps(meta, sort_keys=True)\n",
        "    digest = hashlib.sha1(canonical.encode(\"utf-8\"))\n",
        "    return digest.hexdigest()[:10]\n",
        "\n",
        "\n",
        "def _coh_meta_compatible(current: dict, cached: dict) -> bool:\n",
        "    \"\"\"\n",
        "    Tell whether a cached metadata dict matches the current one.\n",
        "\n",
        "    Only the fields listed below are compared; a field missing on both sides\n",
        "    counts as equal (both lookups return None). Any other keys are ignored.\n",
        "    \"\"\"\n",
        "    compared_fields = (\n",
        "        \"train_size\",\n",
        "        \"n_questions\",\n",
        "        \"samples_per_x\",\n",
        "        \"baseline_model_id\",\n",
        "        \"clustering_version\",\n",
        "        \"seed\",\n",
        "    )\n",
        "    for field in compared_fields:\n",
        "        if not (current.get(field) == cached.get(field)):\n",
        "            return False\n",
        "    return True\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def _sample_baseline_answers_for_x(\n",
        "    x: str,\n",
        "    *,\n",
        "    model,\n",
        "    tokenizer,\n",
        "    K: int,\n",
        "    max_new_tokens: int = 32,\n",
        "    temperature: float = 0.7,\n",
        "    top_p: float = 0.9,\n",
        ") -> List[str]:\n",
        "    \"\"\"Sample K answers to question ``x`` in one batched call.\n",
        "\n",
        "    The prompt is repeated K times and handed to ``answer_batch_chat`` together\n",
        "    with the sampling settings; whatever that helper returns is passed back\n",
        "    unchanged. Gradients are disabled for the duration of the call.\n",
        "    \"\"\"\n",
        "    prompts = [x for _ in range(K)]\n",
        "    sampling_kwargs = dict(\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        top_p=top_p,\n",
        "    )\n",
        "    return answer_batch_chat(\n",
        "        prompts,\n",
        "        model=model,\n",
        "        tokenizer=tokenizer,\n",
        "        **sampling_kwargs,\n",
        "    )\n",
        "\n",
        "\n",
        "def build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    *,\n",
        "    baseline_model,\n",
        "    tokenizer,\n",
        "    embedder,\n",
        "    paraphrase_classes_all = paraphrase_classes_all,  # question variants, indexed by train_idx\n",
        "    cache_path: str = COH_CACHE_PATH,\n",
        "    n_questions: int = 600,\n",
        "    samples_per_x: int = 32,\n",
        "    sim_threshold: float = 0.95,\n",
        "    seed: int = 0,\n",
        ") -> List[Dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Build, or load from a JSONL cache, the paraphrase-coherence records.\n",
        "\n",
        "    Coherence via paraphrases:\n",
        "      For each chosen train_idx:\n",
        "        - Get original question q0 and paraphrases q1..qM-1 from paraphrase_classes_all[train_idx]\n",
        "          (falls back to train_ds[train_idx][\"question\"] alone when that entry is empty).\n",
        "        - Generate 'samples_per_x' answers for EACH of these questions.\n",
        "          Total rollouts = samples_per_x * M.\n",
        "        - Cluster ALL answers together into semantic classes z.\n",
        "        - nu0(z | x): empirical distribution over clusters using only answers to q0.\n",
        "        - nu_coh(z | x): empirical distribution over clusters using all answers (q0 + paraphrases).\n",
        "        - ratio(z) = nu_coh(z) / nu0(z) for z with nu0(z) > 0.\n",
        "\n",
        "    Caching:\n",
        "      The cache is a JSONL file holding one {\"__meta__\": ...} header line plus\n",
        "      one line per record. It is reused only when train_size, n_questions,\n",
        "      samples_per_x, baseline_model_id, clustering_version (which embeds\n",
        "      sim_threshold) and seed all match this call; otherwise the set is\n",
        "      rebuilt and the file is replaced.\n",
        "\n",
        "    Args:\n",
        "      train_ds: indexable dataset; items are read through their \"question\" field.\n",
        "      baseline_model: model used to sample answers; its name_or_path (when\n",
        "        present) is recorded in the cache metadata.\n",
        "      tokenizer: tokenizer forwarded to the answer sampler.\n",
        "      embedder: embedder forwarded to cluster_answers_for_question.\n",
        "      paraphrase_classes_all: per-train_idx lists of question variants, original first.\n",
        "      cache_path: location of the JSONL cache.\n",
        "      n_questions: number of training questions sampled without replacement.\n",
        "      samples_per_x: answers requested per question variant.\n",
        "      sim_threshold: similarity threshold forwarded to the clustering routine.\n",
        "      seed: seed of the RNG that picks the questions.\n",
        "\n",
        "    Returns:\n",
        "      One dict per question with keys train_idx, question, questions, answers,\n",
        "      answer_src, z_list, nu0, nu_coh, ratio and centers.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if n_questions exceeds len(train_ds) (raised by random.sample).\n",
        "    \"\"\"\n",
        "    path = pathlib.Path(cache_path)\n",
        "    meta_current = {\n",
        "        \"train_size\": len(train_ds),\n",
        "        \"n_questions\": n_questions,\n",
        "        \"samples_per_x\": samples_per_x,\n",
        "        \"baseline_model_id\": getattr(baseline_model, \"name_or_path\", \"unknown\"),\n",
        "        \"clustering_version\": f\"cluster_sim_{sim_threshold}_with_paraphrases\",\n",
        "        \"seed\": seed,\n",
        "    }\n",
        "\n",
        "    # ---------- Try loading existing cache ----------\n",
        "    if path.exists():\n",
        "        records: List[Dict[str, Any]] = []\n",
        "        cached_meta = None\n",
        "        with path.open(\"r\", encoding=\"utf-8\") as f:\n",
        "            for line in f:\n",
        "                line = line.strip()\n",
        "                if not line:\n",
        "                    # Tolerate blank lines (e.g. a trailing newline added by an editor).\n",
        "                    continue\n",
        "                obj = json.loads(line)\n",
        "                if \"__meta__\" in obj:\n",
        "                    cached_meta = obj[\"__meta__\"]\n",
        "                else:\n",
        "                    records.append(obj)\n",
        "\n",
        "        if cached_meta is not None and _coh_meta_compatible(meta_current, cached_meta):\n",
        "            print(f\"[coh] Loaded clustered coherence set from {path}\")\n",
        "            print(f\"[coh] {len(records)} questions, \"\n",
        "                  f\"{cached_meta['samples_per_x']} samples each (per question variant)\")\n",
        "            return records\n",
        "        else:\n",
        "            print(\"[coh] Cache exists but metadata mismatch; rebuilding.\")\n",
        "\n",
        "    # ---------- Need to build a new coherence set ----------\n",
        "    print(\"[coh] Building new clustered coherence set (with paraphrases)...\")\n",
        "    path.parent.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    rng = random.Random(seed)\n",
        "    all_idxs = list(range(len(train_ds)))\n",
        "    chosen_idxs = rng.sample(all_idxs, n_questions)\n",
        "\n",
        "    records: List[Dict[str, Any]] = []\n",
        "\n",
        "    for i, train_idx in enumerate(chosen_idxs):\n",
        "        ex = train_ds[train_idx]\n",
        "\n",
        "        # 1) Original question + paraphrases\n",
        "        q_variants = paraphrase_classes_all[train_idx]\n",
        "        if not q_variants:\n",
        "            # Fallback: no paraphrases recorded, so use the dataset question alone\n",
        "            q0 = ex[\"question\"]\n",
        "            q_variants = [q0]\n",
        "        else:\n",
        "            # Index 0 is taken to be the original question\n",
        "            q0 = q_variants[0]\n",
        "\n",
        "        # 2) Generate answers for each question variant\n",
        "        all_answers: List[str] = []\n",
        "        answer_src: List[int] = []  # which variant index produced each answer (0 = original)\n",
        "\n",
        "        for j, q_text in enumerate(q_variants):\n",
        "            ys = _sample_baseline_answers_for_x(\n",
        "                q_text,\n",
        "                model=baseline_model,\n",
        "                tokenizer=tokenizer,\n",
        "                K=samples_per_x,\n",
        "            )\n",
        "            all_answers.extend(ys)\n",
        "            # Tag by the number of answers actually returned so that answer_src\n",
        "            # stays aligned with all_answers even if the sampler yields fewer\n",
        "            # (or more) than samples_per_x.\n",
        "            answer_src.extend([j] * len(ys))\n",
        "\n",
        "        # 3) CLUSTERING-BASED SEMANTICS over ALL answers\n",
        "        #    (cluster IDs shared between original+paraphrases)\n",
        "        cluster_ids, centers = cluster_answers_for_question(\n",
        "            all_answers,\n",
        "            question=q0,             # cluster function can ignore this if it's answer-only\n",
        "            embedder=embedder,\n",
        "            sim_threshold=sim_threshold,\n",
        "        )\n",
        "\n",
        "        z_all = [f\"c{cid}\" for cid in cluster_ids]  # semantics E(y, x) for all rollouts\n",
        "\n",
        "        # 3a) nu_coh(z | x): empirical distribution over ALL rollouts (original + paraphrases)\n",
        "        counts_all = Counter(z_all)\n",
        "        total_all = float(len(z_all))\n",
        "        nu_coh = {z: c / total_all for z, c in counts_all.items()}\n",
        "\n",
        "        # 3b) nu0(z | x): distribution using ONLY answers to the original question (variant index 0)\n",
        "        z_orig = [z for z, src in zip(z_all, answer_src) if src == 0]\n",
        "        counts_orig = Counter(z_orig)\n",
        "        total_orig = float(len(z_orig))  # normally samples_per_x\n",
        "        nu0 = {z: c / total_orig for z, c in counts_orig.items()}\n",
        "\n",
        "        # 3c) ratio(z) = nu_coh(z) / nu0(z) where nu0(z) > 0\n",
        "        ratio = {}\n",
        "        for z, p0 in nu0.items():\n",
        "            if p0 > 0.0:\n",
        "                p_coh = nu_coh.get(z, 0.0)\n",
        "                ratio[z] = p_coh / p0\n",
        "\n",
        "        rec = {\n",
        "            \"train_idx\": train_idx,\n",
        "            \"question\": q0,                # original question\n",
        "            \"questions\": q_variants,       # original + paraphrases\n",
        "            \"answers\": all_answers,        # flattened over all variants\n",
        "            \"answer_src\": answer_src,      # variant index for each answer\n",
        "            \"z_list\": z_all,               # cluster ID per answer\n",
        "            \"nu0\": nu0,                    # distribution from original-only rollouts\n",
        "            \"nu_coh\": nu_coh,              # distribution from all (original + paraphrases)\n",
        "            \"ratio\": ratio,                # nu_coh(z) / nu0(z) for z with nu0(z) > 0\n",
        "            \"centers\": centers.cpu().tolist(),\n",
        "        }\n",
        "        records.append(rec)\n",
        "\n",
        "        if (i + 1) % 50 == 0:\n",
        "            print(f\"[coh] processed {i+1}/{n_questions} base questions\")\n",
        "\n",
        "    # ---------- Save cache ----------\n",
        "    meta_to_save = dict(meta_current)\n",
        "    meta_to_save[\"created_utc\"] = int(time.time())\n",
        "    meta_to_save[\"hash\"] = _coh_meta_hash(meta_to_save)\n",
        "\n",
        "    # Write to a temporary sibling file and rename it over the target, so an\n",
        "    # interrupted run never leaves a half-written cache at `path`.\n",
        "    tmp = path.with_suffix(path.suffix + \".tmp\")\n",
        "    with tmp.open(\"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(json.dumps({\"__meta__\": meta_to_save}) + \"\\n\")\n",
        "        for rec in records:\n",
        "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
        "    tmp.replace(path)\n",
        "\n",
        "    print(f\"[coh] Saved clustered coherence set to {path}\")\n",
        "\n",
        "\n",
        "    return records\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "DrkEk66cDZ4i",
      "metadata": {
        "id": "DrkEk66cDZ4i"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from typing import List, Union, Iterable, Optional\n",
        "\n",
        "class LMEmbedder:\n",
        "    \"\"\"\n",
        "    Generic wrapper to use any HF LM (Qwen, LLaMA, Phi-3, etc.)\n",
        "    as a sentence embedder with a SentenceTransformer-like interface.\n",
        "\n",
        "    An embedding is the attention-mask-weighted mean of the final hidden\n",
        "    layer, computed in float32 and optionally L2-normalised.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model,\n",
        "        tokenizer,\n",
        "        max_length: int = 128,\n",
        "        normalize: bool = True,\n",
        "        device: Optional[Union[str, torch.device]] = None,\n",
        "        batch_size: int = 32,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model: HF model (e.g. AutoModelForCausalLM / LlamaForCausalLM / Phi3ForCausalLM, possibly PEFT-wrapped)\n",
        "            tokenizer: matching tokenizer\n",
        "            max_length: truncation length for inputs\n",
        "            normalize: whether to L2-normalize embeddings\n",
        "            device: device the tokenized inputs are sent to; if None, the device\n",
        "                    of the model's first parameter is used (the model itself is not moved)\n",
        "            batch_size: batching size for .encode()\n",
        "\n",
        "        Side effects: puts `model` in eval mode, sets\n",
        "        model.config.output_hidden_states = True when that attribute exists, and\n",
        "        assigns the EOS token as pad token if the tokenizer has no pad token.\n",
        "        \"\"\"\n",
        "        self.model = model\n",
        "        self.tokenizer = tokenizer\n",
        "        self.max_length = max_length\n",
        "        self.normalize = normalize\n",
        "        self.batch_size = batch_size\n",
        "\n",
        "        # Ensure we can access hidden states.\n",
        "        # NOTE(review): this changes the config of the caller's model object, so\n",
        "        # every later forward pass of that model also returns hidden states.\n",
        "        # encode() already passes output_hidden_states=True explicitly; consider\n",
        "        # dropping this assignment if nothing else relies on it.\n",
        "        if hasattr(self.model.config, \"output_hidden_states\"):\n",
        "            self.model.config.output_hidden_states = True\n",
        "\n",
        "        self.model.eval()\n",
        "\n",
        "        # Figure out device\n",
        "        if device is None:\n",
        "            # works for most HF models, including PEFT-wrapped\n",
        "            try:\n",
        "                self.device = next(self.model.parameters()).device\n",
        "            except StopIteration:\n",
        "                self.device = torch.device(\"cpu\")\n",
        "        else:\n",
        "            self.device = torch.device(device)\n",
        "\n",
        "        # Make sure we have a pad_token_id\n",
        "        if self.tokenizer.pad_token_id is None:\n",
        "            # common choice: use EOS as pad\n",
        "            if self.tokenizer.eos_token_id is not None:\n",
        "                self.tokenizer.pad_token = self.tokenizer.eos_token\n",
        "                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def encode(\n",
        "        self,\n",
        "        texts: Union[str, List[str], Iterable[str]],\n",
        "        convert_to_tensor: bool = True,\n",
        "        show_progress_bar: bool = False,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        SentenceTransformer-style API.\n",
        "\n",
        "        Args:\n",
        "            texts: single string or list/iterable of strings\n",
        "            convert_to_tensor: if True, returns a torch.Tensor [N, D]\n",
        "                               else returns a list of numpy arrays\n",
        "            show_progress_bar: ignored here, but kept for API compatibility\n",
        "\n",
        "        Returns:\n",
        "            embeddings: float32 CPU tensor [N, D], or a list of N np.array of\n",
        "            shape [D]. Empty input yields an empty tensor (0 rows) or an empty list.\n",
        "        \"\"\"\n",
        "        if isinstance(texts, str):\n",
        "            texts = [texts]\n",
        "        else:\n",
        "            texts = list(texts)\n",
        "\n",
        "        if not texts:\n",
        "            # Nothing to embed: torch.cat on an empty list would raise.\n",
        "            if not convert_to_tensor:\n",
        "                return []\n",
        "            hidden_size = getattr(self.model.config, \"hidden_size\", 0) or 0\n",
        "            return torch.empty((0, hidden_size), dtype=torch.float32)\n",
        "\n",
        "        all_embs = []\n",
        "        bs = self.batch_size\n",
        "\n",
        "        for i in range(0, len(texts), bs):\n",
        "            batch_texts = texts[i : i + bs]\n",
        "\n",
        "            enc = self.tokenizer(\n",
        "                batch_texts,\n",
        "                padding=True,\n",
        "                truncation=True,\n",
        "                max_length=self.max_length,\n",
        "                return_tensors=\"pt\",\n",
        "            )\n",
        "            input_ids = enc[\"input_ids\"].to(self.device)\n",
        "            attention_mask = enc[\"attention_mask\"].to(self.device)\n",
        "\n",
        "            out = self.model(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                output_hidden_states=True,\n",
        "            )\n",
        "\n",
        "            # last hidden state: [B, T, D]\n",
        "            # some models return .last_hidden_state, some in hidden_states[-1]\n",
        "            if hasattr(out, \"hidden_states\") and out.hidden_states is not None:\n",
        "                hidden = out.hidden_states[-1]\n",
        "            else:\n",
        "                # fallback to last_hidden_state if hidden_states is not present\n",
        "                hidden = out.last_hidden_state\n",
        "\n",
        "            # Pool in float32: sums over tokens lose precision in fp16/bf16, and\n",
        "            # bfloat16 tensors cannot be converted to NumPy afterwards.\n",
        "            hidden = hidden.float()\n",
        "\n",
        "            # mean-pool over non-pad tokens\n",
        "            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [B, T, 1]\n",
        "            masked_hidden = hidden * mask\n",
        "            lengths = mask.sum(dim=1).clamp(min=1)  # [B, 1]\n",
        "            emb = masked_hidden.sum(dim=1) / lengths  # [B, D]\n",
        "\n",
        "            if self.normalize:\n",
        "                emb = F.normalize(emb, p=2, dim=1)\n",
        "\n",
        "            all_embs.append(emb.cpu())\n",
        "\n",
        "        embs = torch.cat(all_embs, dim=0)  # [N, D]\n",
        "\n",
        "        if convert_to_tensor:\n",
        "            return embs\n",
        "        else:\n",
        "            return [e.numpy() for e in embs]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Zp-xhGcSyXRq",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 404
        },
        "id": "Zp-xhGcSyXRq",
        "outputId": "503eaed4-b348-42d1-ce84-873811d4939b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "28501ff67bde42a6a9410c55c8977bb9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "994a0db0c63545118bce1b33958c1614",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cc3f6cd9667a49edafbab2449e02b4e5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0d5d447849144930962682f67ed4b1cb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4288dfb5897345768b2378e8d159f539",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f1e6da22975748af9609388a2c5d5798",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ce24fe0613244bc6887926753798d897",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a593af02a7cf489f98067e067d0c5a38",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.txt: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e3507cdb84464004b6add70103ccfe09",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c05e536a07e44073895cc90f77f7655d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a6837fefcfe9491787b137373a41995e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[coh] Loaded clustered coherence set from /content/drive/MyDrive/coherence-answer/Qwen/Qwen2.5-3B-Instruct/instruct/cache/triviaqa_cluster.jsonl\n",
            "[coh] 6400 questions, 16 samples each (per question variant)\n"
          ]
        }
      ],
      "source": [
        "from sentence_transformers import SentenceTransformer\n",
        "\n",
        "# Sentence embedder handed to the coherence-set builder, which forwards it to\n",
        "# the answer-clustering routine.\n",
        "embedder = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n",
        "\n",
        "# Alternative (disabled): embed with the language model itself via LMEmbedder.\n",
        "# embedder = LMEmbedder(\n",
        "#     model=baseline_model,\n",
        "#     tokenizer=tokenizer,\n",
        "#     max_length=128,\n",
        "#     normalize=True,\n",
        "#     batch_size=32,\n",
        "# )\n",
        "\n",
        "# Build the coherence records, or load them from COH_CACHE_PATH when the cached\n",
        "# metadata matches these settings.\n",
        "# NOTE(review): the `baseline_model` argument receives `trainable_model` here.\n",
        "# Confirm this is intended (e.g. the trainable model has not been updated yet at\n",
        "# this point), since the sampled answers are treated as baseline rollouts.\n",
        "coh_records = build_or_load_coherence_set_cluster(\n",
        "    train_ds,\n",
        "    baseline_model=trainable_model,\n",
        "    tokenizer=tokenizer,\n",
        "    embedder = embedder,\n",
        "    n_questions=32 * 200,\n",
        "    samples_per_x=16,\n",
        "    cache_path=COH_CACHE_PATH,\n",
        "    seed = 42\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ikGAzn8nFtkS",
      "metadata": {
        "id": "ikGAzn8nFtkS"
      },
      "outputs": [],
      "source": [
        "\n",
        "import json\n",
        "\n",
        "# Show the first coherence record next to the dataset's reference answer.\n",
        "# The cluster centers are left out of the dump: they are long lists of floats.\n",
        "for rec in coh_records[:1]:\n",
        "    temp_rec = dict(rec)\n",
        "    temp_rec.pop('centers', None)\n",
        "    print(json.dumps(temp_rec, indent=2))\n",
        "\n",
        "    # Look up the reference answer through the record's dataset index\n",
        "    question_idx = rec[\"train_idx\"]\n",
        "    ground_truth_obj = train_ds[question_idx][\"answer\"]\n",
        "    print(f\"  Ground Truth Answers (train_idx={question_idx}): {ground_truth_obj}\")\n",
        "    print(\"\\n\" + \"=\"*80 + \"\\n\")  # visual separator between records\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "aJqwV-W114QE",
      "metadata": {
        "id": "aJqwV-W114QE"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def E_cluster_for_known_x(\n",
        "    y: str,\n",
        "    x: str,\n",
        "    rec: dict,        # coherence record holding the cluster centers for this x\n",
        "    embedder,\n",
        "    min_sim: float = 0.0,\n",
        ") -> str:\n",
        "    \"\"\"\n",
        "    Assign a new completion y for a *known* question x to the nearest of the\n",
        "    clusters stored in rec, returning its label ('c0', 'c1', ...).\n",
        "\n",
        "    The question and answer are joined into one 'Q: ... / A: ...' text, embedded\n",
        "    with embedder.encode, and compared by cosine similarity against every row\n",
        "    of rec[\"centers\"]; the index of the most similar center is the label.\n",
        "    (Presumably this text format mirrors the one used when the centers were\n",
        "    built -- verify against the clustering routine.)\n",
        "\n",
        "    min_sim is accepted but currently has no effect: the completion is always\n",
        "    forced onto the best-matching cluster, however low the similarity.\n",
        "    \"\"\"\n",
        "    pair_text = f\"Q: {x}\\nA: {y}\"\n",
        "    query = embedder.encode(\n",
        "        [pair_text], convert_to_tensor=True, show_progress_bar=False\n",
        "    )[0]  # [D]\n",
        "\n",
        "    # Centers are stored as nested lists; rebuild them to match the query tensor\n",
        "    center_matrix = torch.tensor(\n",
        "        rec[\"centers\"], dtype=query.dtype, device=query.device\n",
        "    )  # [C, D]\n",
        "\n",
        "    similarities = torch.nn.functional.cosine_similarity(\n",
        "        query.unsqueeze(0), center_matrix, dim=1\n",
        "    )  # [C]\n",
        "\n",
        "    nearest = int(similarities.argmax().item())\n",
        "    return f\"c{nearest}\"\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0jXMFGTV1607",
      "metadata": {
        "id": "0jXMFGTV1607"
      },
      "outputs": [],
      "source": [
        "# Index the loaded coherence records by train_idx (built once, reused per lookup)\n",
        "coh_by_idx = dict((r[\"train_idx\"], r) for r in coh_records)\n",
        "\n",
        "def E_for_training_example(train_idx: int, x: str, y: str) -> str:\n",
        "    \"\"\"Return the cluster label of answer y for the question stored under train_idx.\n",
        "\n",
        "    Looks up the coherence record in coh_by_idx (KeyError if train_idx has no\n",
        "    record) and delegates to E_cluster_for_known_x with the module-level embedder.\n",
        "    \"\"\"\n",
        "    return E_cluster_for_known_x(y, x, coh_by_idx[train_idx], embedder)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5_3wH5NkftWo",
      "metadata": {
        "id": "5_3wH5NkftWo"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Iterable, Tuple, Optional\n",
        "\n",
        "\n",
        "from dataclasses import dataclass\n",
        "\n",
        "@dataclass\n",
        "class WSFTExample:\n",
        "    \"\"\"One weighted training tuple (x, y, w) plus bookkeeping indices.\"\"\"\n",
        "    x: str            # question text (the specific variant that produced y)\n",
        "    y: str            # sampled answer\n",
        "    w: float          # weight attached to this (x, y) pair\n",
        "    train_idx: int    # train_idx of the coherence record the pair came from\n",
        "    sample_idx: int   # 0-based position among the answers kept for that question variant\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def make_wsft_examples_from_coh_records(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Turn coherence cluster records into weighted SFT tuples (x, y, w).\n",
        "\n",
        "    The semantic class E(y) of an answer is its stored cluster id in rec[\"z_list\"].\n",
        "\n",
        "    For one record and its selected question variants:\n",
        "      v0(z|x_j)  : share of cluster z among the answers sampled for variant x_j.\n",
        "      v_coh(z|x) : share of cluster z among the answers pooled over all selected variants.\n",
        "\n",
        "    Every answer y sampled for a selected variant x_j gets\n",
        "      w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x_j)\n",
        "    optionally capped at weight_clip.\n",
        "\n",
        "    Args:\n",
        "      coh_records: records with \"train_idx\", \"questions\", \"answers\", \"answer_src\" and \"z_list\".\n",
        "      alpha: interpolation weight in [0, 1].\n",
        "      use_variants: variant indices to pool (default: original + first paraphrase).\n",
        "        Indices >= the number of questions in a record are dropped, falling back\n",
        "        to variant 0 if none remain. None selects every variant of the record.\n",
        "      eps: lower bound applied to v0 before dividing.\n",
        "      weight_clip: optional upper bound on w.\n",
        "\n",
        "    Returns:\n",
        "      Flat list of WSFTExample; sample_idx counts answers per (record, variant).\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]          # [x0, x1, ...]\n",
        "        answers: List[str] = rec[\"answers\"]              # flattened over variants\n",
        "        answer_src: List[int] = rec[\"answer_src\"]        # variant index of each answer\n",
        "        z_list: List[str] = rec[\"z_list\"]                # cluster id of each answer\n",
        "\n",
        "        if not (questions and answers):\n",
        "            continue\n",
        "\n",
        "        # Variants whose answers are pooled into v_coh\n",
        "        if use_variants is None:\n",
        "            selected = set(range(len(questions)))\n",
        "        else:\n",
        "            selected = {v for v in use_variants if v < len(questions)} or {0}\n",
        "\n",
        "        # One pass over the labels: tally them overall and per selected variant\n",
        "        pooled_counts: Counter = Counter()\n",
        "        counts_by_var: Dict[int, Counter] = defaultdict(Counter)\n",
        "        for z, src in zip(z_list, answer_src):\n",
        "            if src in selected:\n",
        "                pooled_counts[z] += 1\n",
        "                counts_by_var[src][z] += 1\n",
        "\n",
        "        if not pooled_counts:\n",
        "            continue  # no answer belongs to a selected variant\n",
        "\n",
        "        pooled_total = float(sum(pooled_counts.values()))\n",
        "        v_coh = {z: c / pooled_total for z, c in pooled_counts.items()}\n",
        "\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j, cnt in counts_by_var.items():\n",
        "            var_total = float(sum(cnt.values()))\n",
        "            v0_by_var[j] = {z: c / var_total for z, c in cnt.items()}\n",
        "\n",
        "        # Emit one example per answer of a selected variant, in stored order\n",
        "        next_sample_idx = defaultdict(int)  # running sample_idx per variant\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            v0 = v0_by_var.get(src)\n",
        "            if v0 is None:\n",
        "                continue  # variant not selected (hence no statistics for it)\n",
        "\n",
        "            v0_z = max(v0.get(z, 0.0), eps)\n",
        "            w = (1.0 - alpha) + alpha * (v_coh.get(z, 0.0) / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            example = WSFTExample(\n",
        "                x=questions[src],\n",
        "                y=y,\n",
        "                w=float(w),\n",
        "                train_idx=train_idx,\n",
        "                sample_idx=int(next_sample_idx[src]),\n",
        "            )\n",
        "            wsft_examples.append(example)\n",
        "            next_sample_idx[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_hU9FbRClOjZ",
      "metadata": {
        "id": "_hU9FbRClOjZ"
      },
      "outputs": [],
      "source": [
        "from collections import Counter, defaultdict\n",
        "from typing import List, Dict, Any, Optional, Tuple, Iterable\n",
        "\n",
        "def make_wsft_examples_from_coh_records_majority(\n",
        "    coh_records: List[Dict[str, Any]],\n",
        "    alpha: float,\n",
        "    *,\n",
        "    use_variants: Optional[Tuple[int, ...]] = (0, 1),\n",
        "    eps: float = 1e-8,\n",
        "    weight_clip: Optional[float] = None,\n",
        ") -> List[WSFTExample]:\n",
        "    \"\"\"\n",
        "    Build WSFT training tuples (x, y, w) from coherence cluster records.\n",
        "\n",
        "    E(y) is implemented by using the stored cluster id z in rec[\"z_list\"].\n",
        "\n",
        "    v0(z|x): empirical distribution of cluster labels among answers generated for *that* x (variant-specific).\n",
        "    v_coh(z|x): MAJORITY VOTE over cluster labels among answers from the selected variants\n",
        "               (default: variants 0 and 1). Implemented as a degenerate distribution:\n",
        "                 v_coh(z|x) = 1 if z == z*, else 0,\n",
        "               where z* is the most frequent cluster among those answers.\n",
        "\n",
        "    w = (1-alpha) + alpha * v_coh(z|x) / v0(z|x)\n",
        "      = (1-alpha) + alpha * (1 / v0(z|x)) if z==z*, else (1-alpha)\n",
        "    \"\"\"\n",
        "    assert 0.0 <= alpha <= 1.0, \"alpha must be in [0, 1].\"\n",
        "\n",
        "    def _z_numeric(z: str) -> int:\n",
        "        # deterministic tie-break (prefer smaller cluster id)\n",
        "        try:\n",
        "            if isinstance(z, str) and z.startswith(\"c\"):\n",
        "                return int(z[1:])\n",
        "            return int(z)\n",
        "        except Exception:\n",
        "            return 0\n",
        "\n",
        "    wsft_examples: List[WSFTExample] = []\n",
        "\n",
        "    for rec in coh_records:\n",
        "        train_idx = rec[\"train_idx\"]\n",
        "        questions: List[str] = rec[\"questions\"]\n",
        "        answers: List[str] = rec[\"answers\"]\n",
        "        answer_src: List[int] = rec[\"answer_src\"]\n",
        "        z_list: List[str] = rec[\"z_list\"]\n",
        "\n",
        "        if not questions or not answers:\n",
        "            continue\n",
        "\n",
        "        # Decide which variants define the \"coherence\" majority vote set.\n",
        "        if use_variants is None:\n",
        "            coh_vars: Iterable[int] = range(len(questions))\n",
        "        else:\n",
        "            coh_vars = tuple(v for v in use_variants if v < len(questions))\n",
        "            if len(coh_vars) == 0:\n",
        "                coh_vars = (0,)\n",
        "\n",
        "        coh_var_set = set(coh_vars)\n",
        "\n",
        "        # ---- majority vote z* over answers from selected variants ----\n",
        "        coh_z = [z for z, src in zip(z_list, answer_src) if src in coh_var_set]\n",
        "        if len(coh_z) == 0:\n",
        "            continue\n",
        "\n",
        "        coh_counts = Counter(coh_z)\n",
        "        # pick max count; tie-break by smaller numeric cluster id for determinism\n",
        "        z_star = max(coh_counts.items(), key=lambda kv: (kv[1], -_z_numeric(kv[0])))[0]\n",
        "\n",
        "        # ---- v0 per variant x_j ----\n",
        "        v0_by_var: Dict[int, Dict[str, float]] = {}\n",
        "        for j in coh_var_set:\n",
        "            z_j = [z for z, src in zip(z_list, answer_src) if src == j]\n",
        "            if len(z_j) == 0:\n",
        "                continue\n",
        "            cnt = Counter(z_j)\n",
        "            tot = float(len(z_j))\n",
        "            v0_by_var[j] = {z: c / tot for z, c in cnt.items()}\n",
        "\n",
        "        # ---- build WSFT examples ----\n",
        "        per_var_k = defaultdict(int)\n",
        "\n",
        "        for y, z, src in zip(answers, z_list, answer_src):\n",
        "            if src not in coh_var_set:\n",
        "                continue\n",
        "            if src not in v0_by_var:\n",
        "                continue\n",
        "\n",
        "            v0_z = max(v0_by_var[src].get(z, 0.0), eps)\n",
        "            vcoh_z = 1.0 if z == z_star else 0.0  # majority vote distribution\n",
        "\n",
        "            w = (1.0 - alpha) + alpha * (vcoh_z / v0_z)\n",
        "            if weight_clip is not None:\n",
        "                w = min(w, weight_clip)\n",
        "\n",
        "            wsft_examples.append(\n",
        "                WSFTExample(\n",
        "                    x=questions[src],\n",
        "                    y=y,\n",
        "                    w=float(w),\n",
        "                    train_idx=train_idx,\n",
        "                    sample_idx=int(per_var_k[src]),\n",
        "                )\n",
        "            )\n",
        "            per_var_k[src] += 1\n",
        "\n",
        "    print(f\"[wsft] Built {len(wsft_examples)} WSFT examples (alpha={alpha}, use_variants={use_variants})\")\n",
        "    return wsft_examples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4P4XzkJNf6oG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4P4XzkJNf6oG",
        "outputId": "59e001c4-6903-49ff-80dd-9c05dc3c958b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[wsft] Built 204800 WSFT examples (alpha=0.5, use_variants=(0, 1))\n"
          ]
        }
      ],
      "source": [
        "alpha = 0.5\n",
        "wsft_examples = make_wsft_examples_from_coh_records_majority(coh_records, alpha=alpha)  # uses x0+x1 only\n",
        "# or include all paraphrases in v_coh:\n",
        "# wsft_examples_all = make_wsft_examples_from_coh_records(coh_records, alpha=alpha, use_variants=None)\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Q3RRjK20mNjF",
        "u7T3ljRumNjG",
        "IviYIdL0mNjG",
        "Ul00h3tomNjH",
        "ChcYKgQMmNjI",
        "SD17xtwHtI6d"
      ],
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "003329dff7a54a35afc3ebbe07bbb161": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0125142e12494997bae874d02db91c04": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_627412ffcc784ebbb5f201e23ba9a47e",
            "placeholder": "​",
            "style": "IPY_MODEL_48f1bb6613244a678c7c70957459aa71",
            "value": " 42/42 [34:14&lt;00:00, 40.36s/it, N=1319, ACC%=80.21, greedy=1]"
          }
        },
        "02d7422302084d3ca7f138055a525be0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "02fad28d4889434c89b37c102c227de8": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "033e4e64f5bf48f1b95a65c2064ae67d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4dd26f7c71f349dbb280f54ebd9235f7",
              "IPY_MODEL_0b998539d29f45498def1bf2919fce18",
              "IPY_MODEL_a0c2f6ab96f84135a094ab9f19c49138"
            ],
            "layout": "IPY_MODEL_e031f3547b184bd3a1088b871194b67c"
          }
        },
        "038a3a47b941448b94b77807875c89d6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "044cf3f4c2314762a33d56c9294556f4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e944ddcecc534d1a839f877ed505a5b0",
            "placeholder": "​",
            "style": "IPY_MODEL_32fcd0ac8ce6414da88d850132127704",
            "value": " 42/42 [34:49&lt;00:00, 48.42s/it, N=1319, ACC%=82.79, greedy=1]"
          }
        },
        "05b29e18fea8439ba26726bfd832a6e2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0b998539d29f45498def1bf2919fce18": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a5495921684f4548aa6ac00cc26955aa",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2e019358119a4075828e44df864b591b",
            "value": 42
          }
        },
        "0c382a8bb37b4e85acd76ee9ec735bab": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0cc00ca50eb14244b8623fa83346b224": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_dacd682ca95047248ad4c58556ef6752",
              "IPY_MODEL_d7a0a9267a9b41df998216c52db22b9c",
              "IPY_MODEL_3716b28289c9497c87c8655c358a54d6"
            ],
            "layout": "IPY_MODEL_02fad28d4889434c89b37c102c227de8"
          }
        },
        "0f5c722767084418a3018b67bb63f295": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1620cf0812d94dfdb188cabd69d7185f",
            "placeholder": "​",
            "style": "IPY_MODEL_c55dcb589588469c93da86a5fb17c01f",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "1166eda0f8dd4b66a842cc6ed3796780": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "116dc2b6474e4266a2a3aa71365627d5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "12aea8a6e24d41e09fa5abc7e96540c5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1620cf0812d94dfdb188cabd69d7185f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "170a030d0b954dd4a0597954e8bcb780": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "182ddfeba28d4fd59d5f58d1e5bff82c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1b1db59d3f574c84be39ed5d1f39d907": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1fa46bb97549437d9671bc1107cb8890": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4a8441bc5b2f406ab79d71559e593277",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_79b071ae0d88466dab56dfad767b2039",
            "value": 42
          }
        },
        "2c2690553c28419e8c205070265fec8b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_bfd3da20dbea45e891614b19f66fef4e",
              "IPY_MODEL_452663c398e1482da9cca0301829c064",
              "IPY_MODEL_bdb1cbec5e874db3bffff56442f6ec3c"
            ],
            "layout": "IPY_MODEL_003329dff7a54a35afc3ebbe07bbb161"
          }
        },
        "2df7055b183549bc83e6c4a2a83d8a6a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2e019358119a4075828e44df864b591b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "316b66e65c3e4ab18b283ee3a0be6c23": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_375df7873dcc4e76bb96ee39ddb859bf",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5f40d588569f4580a720f57923bf1969",
            "value": 42
          }
        },
        "32fcd0ac8ce6414da88d850132127704": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3345258ab7fd4dc085c81ca3eacb35dd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "346936dae8d0460b88f574168fea2209": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "35ce343819c44025a76d6683a59069a2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "3716b28289c9497c87c8655c358a54d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1166eda0f8dd4b66a842cc6ed3796780",
            "placeholder": "​",
            "style": "IPY_MODEL_f0342932dde4420c90682462b8a56396",
            "value": " 42/42 [34:05&lt;00:00, 47.28s/it, N=1319, ACC%=83.24, greedy=1]"
          }
        },
        "375df7873dcc4e76bb96ee39ddb859bf": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "384e3bc506734b98a03abd4d17dde926": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "39b9ffa02eb84f85a76156cbc1c8a88b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_346936dae8d0460b88f574168fea2209",
            "placeholder": "​",
            "style": "IPY_MODEL_a7531cf6b513470da02333b1420e5c9c",
            "value": "Paraphrase agreement qwen3b: 100%"
          }
        },
        "39e96afa265f4d2884e9fa4050ab0969": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3a3884b6b2c045c2ae33bba4c0c462b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3345258ab7fd4dc085c81ca3eacb35dd",
            "placeholder": "​",
            "style": "IPY_MODEL_441473c91daf4aeea3a91dc460512b2a",
            "value": " 7473/7473 [2:12:43&lt;00:00,  1.12s/it, classes=7473, AGREE%=67.14, greedy=1]"
          }
        },
        "3b29cebc081f425f94cf7993561bfb0e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3b5734769485414281a2452bac8edac4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "40a8fb2262984e0c91e6a84fbd01de97": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "40e14b1878eb4251bba8e6b57d4414f5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fdfb607e571447e992ac015475b36167",
            "placeholder": "​",
            "style": "IPY_MODEL_b15072f9000c4b3c91433f7dd3f21c67",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "41db2fadc95849a3a5d8d408a07a4b24": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "441473c91daf4aeea3a91dc460512b2a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "451226c8426946d7a8a992b86bf4120a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7f9a014278204f52b7013a647da4184b",
              "IPY_MODEL_dcbb7b89a6564a4f96ceff62312be8ba",
              "IPY_MODEL_6ff6f6b967e74fef94b10898d57af4eb"
            ],
            "layout": "IPY_MODEL_3b5734769485414281a2452bac8edac4"
          }
        },
        "452663c398e1482da9cca0301829c064": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5889e1ddd2bb439fb860fc23e77093c5",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ecb11be7e0f445e99ea4be35fa486cdf",
            "value": 42
          }
        },
        "463b986f48ee4d43bc2529ef692fa42a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "48f1bb6613244a678c7c70957459aa71": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4a8441bc5b2f406ab79d71559e593277": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4dd26f7c71f349dbb280f54ebd9235f7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_78eb2ea591a348c5aa5d7b5b3b0eb810",
            "placeholder": "​",
            "style": "IPY_MODEL_ce3bb42eb2d34294ad84019a80bb88d9",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "4e49a3b0ee974fe9bd5810803d598c78": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51096091a9594b3cbba9e1f06c439f65": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "57bd29ef05214f3daa997ff96c2217aa": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "5889e1ddd2bb439fb860fc23e77093c5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "59f2130e9c564172b17f10efa43f72d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5ab35da2fb244dbe8b03cee0e29f476a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5b81518166804f87b51ca52ed12db7af": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "5f40d588569f4580a720f57923bf1969": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "602c2e372b6c40e7b8904f1c1525c2fa": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "627412ffcc784ebbb5f201e23ba9a47e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "630f910e53c848eb837d2b38607fa849": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9b19089b3ef840a2b0df283d9ae9abef",
              "IPY_MODEL_9e2cb734212a4abe9ecf419d3e673223",
              "IPY_MODEL_d0869f15627f48d89e1fff45de94aafd"
            ],
            "layout": "IPY_MODEL_b3fdeedd557a4743beb7038c796ecc0f"
          }
        },
        "64957ba8da1345288f1803d04b512014": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0f5c722767084418a3018b67bb63f295",
              "IPY_MODEL_316b66e65c3e4ab18b283ee3a0be6c23",
              "IPY_MODEL_c4d0d0db991d42468b1ddd6f1768be8b"
            ],
            "layout": "IPY_MODEL_e2dd3220008c47d8ace31cfe8595a443"
          }
        },
        "6ff6f6b967e74fef94b10898d57af4eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0c382a8bb37b4e85acd76ee9ec735bab",
            "placeholder": "​",
            "style": "IPY_MODEL_ad8d0aef86744743aa1ab5ecccb62931",
            "value": " 42/42 [34:47&lt;00:00, 46.20s/it, N=1319, ACC%=82.87, greedy=1]"
          }
        },
        "78eb2ea591a348c5aa5d7b5b3b0eb810": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "79b071ae0d88466dab56dfad767b2039": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7f8b51a3fe4343f39fa6ddc435410a3a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_da8588237dec4ae9877dc1ebb1236e48",
            "max": 7473,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_59f2130e9c564172b17f10efa43f72d6",
            "value": 7473
          }
        },
        "7f9a014278204f52b7013a647da4184b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d626846a29154eb0a817cfe6b7e99540",
            "placeholder": "​",
            "style": "IPY_MODEL_f17d5c68b4f84ff88ac4b9c28802e949",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "82ce5afcd1ff42318c5caf3aaa03f753": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "842c349e5ad44dcb8be771f5e686a254": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8cdab187cfc44daf9070d72467195684": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_842c349e5ad44dcb8be771f5e686a254",
            "placeholder": "​",
            "style": "IPY_MODEL_5ab35da2fb244dbe8b03cee0e29f476a",
            "value": " 42/42 [34:57&lt;00:00, 49.37s/it, N=1319, ACC%=82.64, greedy=1]"
          }
        },
        "8da70f99a4134e79a215998718df41ec": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "8fb456aa3d3c40b28811423b89741a57": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e1b9e73487eb48baac1632bb4c64ffa0",
            "placeholder": "​",
            "style": "IPY_MODEL_12aea8a6e24d41e09fa5abc7e96540c5",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "9038cdab7ca249979251730657071401": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "90d1a3f56d7c47b9bbf1be269514ebbc": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "93887bc7f9aa41d6bcad100e1adc8349": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9728c65e75584e69a5f0079731ec1fd8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1b1db59d3f574c84be39ed5d1f39d907",
            "placeholder": "​",
            "style": "IPY_MODEL_f109a67864244d2aa2540ada8c5283d3",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "9729511a432840199f41bf61d206f45b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9b386e24cec74153a17b0e462001ee97",
            "placeholder": "​",
            "style": "IPY_MODEL_39e96afa265f4d2884e9fa4050ab0969",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "9a98c8273f154c61af807aab551a51f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_90d1a3f56d7c47b9bbf1be269514ebbc",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_384e3bc506734b98a03abd4d17dde926",
            "value": 42
          }
        },
        "9b19089b3ef840a2b0df283d9ae9abef": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e550a8f4bf6e4a258ce3cdd39981e076",
            "placeholder": "​",
            "style": "IPY_MODEL_82ce5afcd1ff42318c5caf3aaa03f753",
            "value": "Evaluating phi3 greedy: 100%"
          }
        },
        "9b386e24cec74153a17b0e462001ee97": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9e2cb734212a4abe9ecf419d3e673223": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fa91884fa4bb45658e3c905c83b5a592",
            "max": 21,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ea6a755fa09d42fe9faaa9d756d3bf7f",
            "value": 21
          }
        },
        "a0c2f6ab96f84135a094ab9f19c49138": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2df7055b183549bc83e6c4a2a83d8a6a",
            "placeholder": "​",
            "style": "IPY_MODEL_463b986f48ee4d43bc2529ef692fa42a",
            "value": " 42/42 [35:11&lt;00:00, 46.65s/it, N=1319, ACC%=82.34, greedy=1]"
          }
        },
        "a5495921684f4548aa6ac00cc26955aa": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a5d5a68661f84681ba2c0e5fed956c93": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a5f645314be841a88da3af297c88f5ee": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a7531cf6b513470da02333b1420e5c9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ad8d0aef86744743aa1ab5ecccb62931": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "af84042f6ed14f48a3942e27970c7062": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "b15072f9000c4b3c91433f7dd3f21c67": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b1d45f37aee24ebba7448decec0614ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f2bcb8a62be74c0dbbd06bd67b31ad54",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_40a8fb2262984e0c91e6a84fbd01de97",
            "value": 42
          }
        },
        "b3fdeedd557a4743beb7038c796ecc0f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "b49d6f19dfe245efb788c1fbfc6e25a6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b6c0617396d44837ac2bc00893b812d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_39b9ffa02eb84f85a76156cbc1c8a88b",
              "IPY_MODEL_7f8b51a3fe4343f39fa6ddc435410a3a",
              "IPY_MODEL_3a3884b6b2c045c2ae33bba4c0c462b4"
            ],
            "layout": "IPY_MODEL_51096091a9594b3cbba9e1f06c439f65"
          }
        },
        "b9d4e8a9db9d41e696f73535ac327beb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bbc5a0432af04709820be86c0d6e8713": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_182ddfeba28d4fd59d5f58d1e5bff82c",
            "placeholder": "​",
            "style": "IPY_MODEL_4e49a3b0ee974fe9bd5810803d598c78",
            "value": " 42/42 [33:55&lt;00:00, 53.16s/it, N=1319, ACC%=82.94, greedy=1]"
          }
        },
        "bdb1cbec5e874db3bffff56442f6ec3c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b9d4e8a9db9d41e696f73535ac327beb",
            "placeholder": "​",
            "style": "IPY_MODEL_170a030d0b954dd4a0597954e8bcb780",
            "value": " 42/42 [33:41&lt;00:00, 40.76s/it, N=1319, ACC%=82.26, greedy=1]"
          }
        },
        "bdde60c90cb346a2b67354d91e83fe96": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "bef65b1b05e7490797390d19f20fa404": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9729511a432840199f41bf61d206f45b",
              "IPY_MODEL_1fa46bb97549437d9671bc1107cb8890",
              "IPY_MODEL_8cdab187cfc44daf9070d72467195684"
            ],
            "layout": "IPY_MODEL_57bd29ef05214f3daa997ff96c2217aa"
          }
        },
        "bfd3da20dbea45e891614b19f66fef4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d2ba3b9690d44e649e694c6f72d5bfa8",
            "placeholder": "​",
            "style": "IPY_MODEL_602c2e372b6c40e7b8904f1c1525c2fa",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "c253ba6221404c8b936244d468ab2774": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c4d0d0db991d42468b1ddd6f1768be8b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c253ba6221404c8b936244d468ab2774",
            "placeholder": "​",
            "style": "IPY_MODEL_05b29e18fea8439ba26726bfd832a6e2",
            "value": " 42/42 [35:11&lt;00:00, 48.73s/it, N=1319, ACC%=81.96, greedy=1]"
          }
        },
        "c55dcb589588469c93da86a5fb17c01f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c95adefa67bb4cafa341eabd32b5a7be": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_038a3a47b941448b94b77807875c89d6",
            "placeholder": "​",
            "style": "IPY_MODEL_db5d5452eb464df7b04d5269793f731e",
            "value": " 42/42 [35:08&lt;00:00, 44.72s/it, N=1319, ACC%=83.02, greedy=1]"
          }
        },
        "ce3bb42eb2d34294ad84019a80bb88d9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ce436a8911e54da09f24658596726b92": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f248b893c3ea4fa4870d38a953502385",
            "placeholder": "​",
            "style": "IPY_MODEL_b49d6f19dfe245efb788c1fbfc6e25a6",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "d0869f15627f48d89e1fff45de94aafd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a5f645314be841a88da3af297c88f5ee",
            "placeholder": "​",
            "style": "IPY_MODEL_116dc2b6474e4266a2a3aa71365627d5",
            "value": " 21/21 [23:14&lt;00:00, 55.75s/it, N=1319, ACC%=73.69, greedy=1]"
          }
        },
        "d2ba3b9690d44e649e694c6f72d5bfa8": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d42a59e47b40495696da12aef8ceaffd": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "d51f209b03524107b7ddf87d216b3e8c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8fb456aa3d3c40b28811423b89741a57",
              "IPY_MODEL_ee65a0bd33014bbf9bce43841123173e",
              "IPY_MODEL_0125142e12494997bae874d02db91c04"
            ],
            "layout": "IPY_MODEL_9038cdab7ca249979251730657071401"
          }
        },
        "d626846a29154eb0a817cfe6b7e99540": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d68b8bc8db334d7e9464f76cbbe41365": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d7a0a9267a9b41df998216c52db22b9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d68b8bc8db334d7e9464f76cbbe41365",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_35ce343819c44025a76d6683a59069a2",
            "value": 42
          }
        },
        "da8588237dec4ae9877dc1ebb1236e48": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dacd682ca95047248ad4c58556ef6752": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3b29cebc081f425f94cf7993561bfb0e",
            "placeholder": "​",
            "style": "IPY_MODEL_a5d5a68661f84681ba2c0e5fed956c93",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "db5d5452eb464df7b04d5269793f731e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "dcbb7b89a6564a4f96ceff62312be8ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_93887bc7f9aa41d6bcad100e1adc8349",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_af84042f6ed14f48a3942e27970c7062",
            "value": 42
          }
        },
        "e031f3547b184bd3a1088b871194b67c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "e1b9e73487eb48baac1632bb4c64ffa0": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e2dd3220008c47d8ace31cfe8595a443": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "e410aca5539d420592f6824b7ae91350": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e550a8f4bf6e4a258ce3cdd39981e076": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e70ec3fb1e0d4b1194943d5bb67cb90c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ce436a8911e54da09f24658596726b92",
              "IPY_MODEL_9a98c8273f154c61af807aab551a51f9",
              "IPY_MODEL_c95adefa67bb4cafa341eabd32b5a7be"
            ],
            "layout": "IPY_MODEL_bdde60c90cb346a2b67354d91e83fe96"
          }
        },
        "e944ddcecc534d1a839f877ed505a5b0": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea6a755fa09d42fe9faaa9d756d3bf7f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ecb11be7e0f445e99ea4be35fa486cdf": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ed5e3bf8692942e7ac794f8dbaceec99": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9728c65e75584e69a5f0079731ec1fd8",
              "IPY_MODEL_b1d45f37aee24ebba7448decec0614ab",
              "IPY_MODEL_044cf3f4c2314762a33d56c9294556f4"
            ],
            "layout": "IPY_MODEL_d42a59e47b40495696da12aef8ceaffd"
          }
        },
        "ee65a0bd33014bbf9bce43841123173e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e410aca5539d420592f6824b7ae91350",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_8da70f99a4134e79a215998718df41ec",
            "value": 42
          }
        },
        "f0342932dde4420c90682462b8a56396": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f109a67864244d2aa2540ada8c5283d3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f17d5c68b4f84ff88ac4b9c28802e949": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f248b893c3ea4fa4870d38a953502385": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f2bcb8a62be74c0dbbd06bd67b31ad54": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f2f87e1839314a68aed4a4b5b78c42fa": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_41db2fadc95849a3a5d8d408a07a4b24",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_02d7422302084d3ca7f138055a525be0",
            "value": 42
          }
        },
        "fa91884fa4bb45658e3c905c83b5a592": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fb7ff2e985d4495ea059ae8deb1886ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_40e14b1878eb4251bba8e6b57d4414f5",
              "IPY_MODEL_f2f87e1839314a68aed4a4b5b78c42fa",
              "IPY_MODEL_bbc5a0432af04709820be86c0d6e8713"
            ],
            "layout": "IPY_MODEL_5b81518166804f87b51ca52ed12db7af"
          }
        },
        "fdfb607e571447e992ac015475b36167": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7a5e4a148a8046ce9f9eebb891e0e997": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_f39003ffa74d483d9915bdbb9a9687a8",
              "IPY_MODEL_ded174fdf0274aefa8d4f607d0cc5d10",
              "IPY_MODEL_5ced36961b7c410eaf907f18e36ec4e9"
            ],
            "layout": "IPY_MODEL_b4e1b888783d45c2888fa275d07a26f4"
          }
        },
        "f39003ffa74d483d9915bdbb9a9687a8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_797ad36b4a6f499592e36e142c7cf56a",
            "placeholder": "​",
            "style": "IPY_MODEL_ffe990e78ecb4baa9b89bcaafd1fcd45",
            "value": "tokenizer_config.json: "
          }
        },
        "ded174fdf0274aefa8d4f607d0cc5d10": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_618604fc21de45e796c41a76dc9b377c",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3f0627222871458ca034b76ba320d062",
            "value": 1
          }
        },
        "5ced36961b7c410eaf907f18e36ec4e9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6a0da7f66ea44a77a3ec8050d3a2e830",
            "placeholder": "​",
            "style": "IPY_MODEL_45b4e334a9b94033ae78b96791bdf4e1",
            "value": " 3.44k/? [00:00&lt;00:00, 299kB/s]"
          }
        },
        "b4e1b888783d45c2888fa275d07a26f4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "797ad36b4a6f499592e36e142c7cf56a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ffe990e78ecb4baa9b89bcaafd1fcd45": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "618604fc21de45e796c41a76dc9b377c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "3f0627222871458ca034b76ba320d062": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6a0da7f66ea44a77a3ec8050d3a2e830": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "45b4e334a9b94033ae78b96791bdf4e1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ec08eb790dbc4a00bfaf4c9174307d5f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b119d8e2d9c94031840b6d4977d9db07",
              "IPY_MODEL_b892c27f302946d09ecabd65f2bd2b52",
              "IPY_MODEL_ebd53e2c38e7455c859f1e30d8f1a085"
            ],
            "layout": "IPY_MODEL_799b5bebcd6b4fd89a5872b3f8af1a50"
          }
        },
        "b119d8e2d9c94031840b6d4977d9db07": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2276b8a3c8c7494f988822c71035ee3f",
            "placeholder": "​",
            "style": "IPY_MODEL_7a5d280f7f04401dbadc42de3d4b110b",
            "value": "tokenizer.model: 100%"
          }
        },
        "b892c27f302946d09ecabd65f2bd2b52": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1ddd36facd51491da8c1d34a4d340a79",
            "max": 499723,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_0d19e1d1729b4e88a7fe046a3fcf907c",
            "value": 499723
          }
        },
        "ebd53e2c38e7455c859f1e30d8f1a085": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_37f2190d83b04182a80c83c891b3b445",
            "placeholder": "​",
            "style": "IPY_MODEL_4efa2599bb644aba81fc2ed14216a602",
            "value": " 500k/500k [00:01&lt;00:00, 379kB/s]"
          }
        },
        "799b5bebcd6b4fd89a5872b3f8af1a50": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2276b8a3c8c7494f988822c71035ee3f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7a5d280f7f04401dbadc42de3d4b110b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1ddd36facd51491da8c1d34a4d340a79": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0d19e1d1729b4e88a7fe046a3fcf907c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "37f2190d83b04182a80c83c891b3b445": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4efa2599bb644aba81fc2ed14216a602": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0d9f778be8994c1f9c0d20bc2d5f1f11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e373c7e122834faab7ff309257d3f3f1",
              "IPY_MODEL_07fe66a591e548d0a4e72dd952d129b4",
              "IPY_MODEL_0e51bea0b0d84a65bead2df6b8f4195b"
            ],
            "layout": "IPY_MODEL_187704687eec405d953b997968c540ce"
          }
        },
        "e373c7e122834faab7ff309257d3f3f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_41a4945ba4fe4a5290f6caac50f78ead",
            "placeholder": "​",
            "style": "IPY_MODEL_7b0acbeaa67c4622aec005e0720259d7",
            "value": "tokenizer.json: "
          }
        },
        "07fe66a591e548d0a4e72dd952d129b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ea0b144cce3f46f491d3b3a202efaa84",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6a8695dc8b3b45539633cdbae26b2f7a",
            "value": 1
          }
        },
        "0e51bea0b0d84a65bead2df6b8f4195b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1f5ec773417c4cbda9fac915ea5dc94a",
            "placeholder": "​",
            "style": "IPY_MODEL_d45bdcfc6f5d485d8239ddc0e548983b",
            "value": " 1.94M/? [00:00&lt;00:00, 85.8MB/s]"
          }
        },
        "187704687eec405d953b997968c540ce": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "41a4945ba4fe4a5290f6caac50f78ead": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7b0acbeaa67c4622aec005e0720259d7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ea0b144cce3f46f491d3b3a202efaa84": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "6a8695dc8b3b45539633cdbae26b2f7a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "1f5ec773417c4cbda9fac915ea5dc94a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d45bdcfc6f5d485d8239ddc0e548983b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7456f00a6fea4823a1f9513fd6d07cf2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9c696098fb87434aaf25c504970a2190",
              "IPY_MODEL_bbbdc35b1d4045e696daa8cd20700ce2",
              "IPY_MODEL_0707764eb0ab418493663a026597a4e0"
            ],
            "layout": "IPY_MODEL_67f9ad5d3221419d94a6be7197ada150"
          }
        },
        "9c696098fb87434aaf25c504970a2190": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2c453740ebd74bbebe04b2138b58edea",
            "placeholder": "​",
            "style": "IPY_MODEL_6c4217ad778b4ec698adcd8922aa7881",
            "value": "added_tokens.json: 100%"
          }
        },
        "bbbdc35b1d4045e696daa8cd20700ce2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5ce9ada55af34ab1a3a13b665821d81e",
            "max": 306,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d81d9529b42c4f76865bddb827df56f0",
            "value": 306
          }
        },
        "0707764eb0ab418493663a026597a4e0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_60e76b9f6e0840ada2adf78ac1397d27",
            "placeholder": "​",
            "style": "IPY_MODEL_2a9521b75b4f4658a4abddfa74706a04",
            "value": " 306/306 [00:00&lt;00:00, 41.8kB/s]"
          }
        },
        "67f9ad5d3221419d94a6be7197ada150": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2c453740ebd74bbebe04b2138b58edea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6c4217ad778b4ec698adcd8922aa7881": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5ce9ada55af34ab1a3a13b665821d81e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d81d9529b42c4f76865bddb827df56f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "60e76b9f6e0840ada2adf78ac1397d27": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2a9521b75b4f4658a4abddfa74706a04": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "67aa64c60ad24a8889e1f34a44ac4fc9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8789ab53fa334c9c8da9941f1a5c3e51",
              "IPY_MODEL_e47ba5263c9f477db5a126f589ec1333",
              "IPY_MODEL_20f653836cbc4d839b62aac4b1aa904d"
            ],
            "layout": "IPY_MODEL_26ac80c7d933450e90a8cdb682b48660"
          }
        },
        "8789ab53fa334c9c8da9941f1a5c3e51": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_93d0bff923cb4aa2af17260db5fc5aa3",
            "placeholder": "​",
            "style": "IPY_MODEL_9a3da4e17ca24fc89752879eaf0a1f1e",
            "value": "special_tokens_map.json: 100%"
          }
        },
        "e47ba5263c9f477db5a126f589ec1333": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_aaeb6d43ea2c4d0a97137ccbd5b6b56e",
            "max": 599,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7b9fe861916b473eadc34d7791dcd8a5",
            "value": 599
          }
        },
        "20f653836cbc4d839b62aac4b1aa904d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5dd4d8669b294a0ba9fc23fd4d56ddf6",
            "placeholder": "​",
            "style": "IPY_MODEL_7db97fac4f9e496c947c039f27f9911f",
            "value": " 599/599 [00:00&lt;00:00, 75.8kB/s]"
          }
        },
        "26ac80c7d933450e90a8cdb682b48660": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "93d0bff923cb4aa2af17260db5fc5aa3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9a3da4e17ca24fc89752879eaf0a1f1e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "aaeb6d43ea2c4d0a97137ccbd5b6b56e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7b9fe861916b473eadc34d7791dcd8a5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "5dd4d8669b294a0ba9fc23fd4d56ddf6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7db97fac4f9e496c947c039f27f9911f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "24913f7d4e384e9ba15136053fd6dce2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_2cd20eccac484d57ab1df90a2911d322",
              "IPY_MODEL_f5abc702151f43249d9c6205e6a4ab3b",
              "IPY_MODEL_69a9e28b095a40bdb220b3b5c80174f6"
            ],
            "layout": "IPY_MODEL_ea4e39e1326640fa807513eee76150a8"
          }
        },
        "2cd20eccac484d57ab1df90a2911d322": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_44e3abdc7ee949bbb352e5d032069563",
            "placeholder": "​",
            "style": "IPY_MODEL_d03c0cab1f4e4bc882520a07f591dce0",
            "value": "config.json: 100%"
          }
        },
        "f5abc702151f43249d9c6205e6a4ab3b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_862d908202b947639e1b266a1052b313",
            "max": 967,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7565c74477314ec5aafcbdec86d2fa09",
            "value": 967
          }
        },
        "69a9e28b095a40bdb220b3b5c80174f6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e2a4c231c1c94455b59512a0f3b8ab01",
            "placeholder": "​",
            "style": "IPY_MODEL_f9f211046f8449749459c61e0257ac16",
            "value": " 967/967 [00:00&lt;00:00, 119kB/s]"
          }
        },
        "ea4e39e1326640fa807513eee76150a8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "44e3abdc7ee949bbb352e5d032069563": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d03c0cab1f4e4bc882520a07f591dce0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "862d908202b947639e1b266a1052b313": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7565c74477314ec5aafcbdec86d2fa09": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e2a4c231c1c94455b59512a0f3b8ab01": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f9f211046f8449749459c61e0257ac16": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cdbad6d0d7d74f04bef0a0bef72bc049": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0a3a3123bd2b4defb16b8c4bbe82fe8b",
              "IPY_MODEL_96de9938e80f436982115b248107f388",
              "IPY_MODEL_9a392cb3bb2f40a5b554ae29e366405c"
            ],
            "layout": "IPY_MODEL_afd3fec1be764e66ab6231552f9fcede"
          }
        },
        "0a3a3123bd2b4defb16b8c4bbe82fe8b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_19c68f9546764c66b7dc74ee4ffd8c1d",
            "placeholder": "​",
            "style": "IPY_MODEL_76ac6c1bef1f45a8ad24703744e8c0b2",
            "value": "model.safetensors.index.json: "
          }
        },
        "96de9938e80f436982115b248107f388": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8c0dad254951417d8fb28165ad139143",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_38291db55bc746bd8114e0bf5ed213a1",
            "value": 1
          }
        },
        "9a392cb3bb2f40a5b554ae29e366405c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_736cf06ab906427ca04087e3180f7930",
            "placeholder": "​",
            "style": "IPY_MODEL_109f3593f3864fea934c60edbe46b70d",
            "value": " 16.5k/? [00:00&lt;00:00, 1.72MB/s]"
          }
        },
        "afd3fec1be764e66ab6231552f9fcede": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "19c68f9546764c66b7dc74ee4ffd8c1d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "76ac6c1bef1f45a8ad24703744e8c0b2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8c0dad254951417d8fb28165ad139143": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "38291db55bc746bd8114e0bf5ed213a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "736cf06ab906427ca04087e3180f7930": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "109f3593f3864fea934c60edbe46b70d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "01e5c0e01f384a29a3f98c7fefc0dded": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_23562b6755dc498ca3423bf68b77b65a",
              "IPY_MODEL_ba863d5535c047b99c7d26d066555f0b",
              "IPY_MODEL_8d5a36bdf6c049eab68fdd2567095d64"
            ],
            "layout": "IPY_MODEL_d4a968ea181f43688bbf975326484e2a"
          }
        },
        "23562b6755dc498ca3423bf68b77b65a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ffe0a6b2edd8486fa9b0b93dc916d0a1",
            "placeholder": "​",
            "style": "IPY_MODEL_a39590bba0d0492abcf414490763f559",
            "value": "Fetching 2 files: 100%"
          }
        },
        "ba863d5535c047b99c7d26d066555f0b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_807c6777747b4ada902120426de8a611",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f14f6c98b4884884bbfedb4aa4200123",
            "value": 2
          }
        },
        "8d5a36bdf6c049eab68fdd2567095d64": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4694232e149c428aae43a1ec9b79f2ee",
            "placeholder": "​",
            "style": "IPY_MODEL_36db6fbb53e34ca7bd71bdd22deb8fc3",
            "value": " 2/2 [00:19&lt;00:00, 19.01s/it]"
          }
        },
        "d4a968ea181f43688bbf975326484e2a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ffe0a6b2edd8486fa9b0b93dc916d0a1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a39590bba0d0492abcf414490763f559": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "807c6777747b4ada902120426de8a611": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f14f6c98b4884884bbfedb4aa4200123": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4694232e149c428aae43a1ec9b79f2ee": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "36db6fbb53e34ca7bd71bdd22deb8fc3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3f5a11b214724194a8fa39bb355e2bbe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_2f43805da2ee43c38f7276896125e455",
              "IPY_MODEL_7f14f4d8df694f30aa71af895b1155af",
              "IPY_MODEL_510b3441845a4b9593229e10ff1ddc26"
            ],
            "layout": "IPY_MODEL_b8b08af66b2f448b95e3da143758c90b"
          }
        },
        "2f43805da2ee43c38f7276896125e455": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b00f0afb302a43119aa20912dbbd37b3",
            "placeholder": "​",
            "style": "IPY_MODEL_48dec881fbc14ed5ba3d1c16a7dd47c0",
            "value": "model-00001-of-00002.safetensors: 100%"
          }
        },
        "7f14f4d8df694f30aa71af895b1155af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_666ece6cceed4a709fbf2f7fe226e877",
            "max": 4972489328,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4600ae8cc09b4cce8ab03d6fc97eca8f",
            "value": 4972489328
          }
        },
        "510b3441845a4b9593229e10ff1ddc26": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d74d9a726d7249e3b0de95057d3f36cc",
            "placeholder": "​",
            "style": "IPY_MODEL_bb93c0459aa242f49f3cd7e1ae734ee0",
            "value": " 4.97G/4.97G [00:18&lt;00:00, 380MB/s]"
          }
        },
        "b8b08af66b2f448b95e3da143758c90b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b00f0afb302a43119aa20912dbbd37b3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "48dec881fbc14ed5ba3d1c16a7dd47c0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "666ece6cceed4a709fbf2f7fe226e877": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4600ae8cc09b4cce8ab03d6fc97eca8f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d74d9a726d7249e3b0de95057d3f36cc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bb93c0459aa242f49f3cd7e1ae734ee0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9a4a5417e4cc45ca8fa8f07e76fde0a8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_fbc22ca925314cf2b59c285a2db85ca2",
              "IPY_MODEL_e288c2fd0ee24608abffdba553045a7e",
              "IPY_MODEL_43d2252a5e4f4622ac362770258ed7ba"
            ],
            "layout": "IPY_MODEL_a2a6de3dd62b4a6fbfd291b2ca4e8730"
          }
        },
        "fbc22ca925314cf2b59c285a2db85ca2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_713c5923a6b3410b8672ecfe38e66f29",
            "placeholder": "​",
            "style": "IPY_MODEL_ea8d320b8213469cb81460d0fbeee3d8",
            "value": "model-00002-of-00002.safetensors: 100%"
          }
        },
        "e288c2fd0ee24608abffdba553045a7e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_73447fc92e084749a7581a67b2bbc336",
            "max": 2669692552,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_e6c4157e89c7479a9148ccb7de462366",
            "value": 2669692552
          }
        },
        "43d2252a5e4f4622ac362770258ed7ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7656329c18ea4c649be0a4daa666b4f0",
            "placeholder": "​",
            "style": "IPY_MODEL_4cccd262bcfa4d72950c433cf579e17f",
            "value": " 2.67G/2.67G [00:17&lt;00:00, 303MB/s]"
          }
        },
        "a2a6de3dd62b4a6fbfd291b2ca4e8730": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "713c5923a6b3410b8672ecfe38e66f29": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea8d320b8213469cb81460d0fbeee3d8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "73447fc92e084749a7581a67b2bbc336": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e6c4157e89c7479a9148ccb7de462366": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7656329c18ea4c649be0a4daa666b4f0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4cccd262bcfa4d72950c433cf579e17f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ae3cc799469b4bb9b65702e3bb81311e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9ed53200c8e049bfa57c037d79ebaf2f",
              "IPY_MODEL_1e2bbf3a5d37489784b68fe9239ceed8",
              "IPY_MODEL_1c39bab6d8354d80a21d4b8e10a34258"
            ],
            "layout": "IPY_MODEL_230d4ede8bdf436e8b94e62c0ed1f1d3"
          }
        },
        "9ed53200c8e049bfa57c037d79ebaf2f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e6f2ad9ac925407eb4cab5d383e814ab",
            "placeholder": "​",
            "style": "IPY_MODEL_1dfce580b8ab401398bb1319d8e816ca",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "1e2bbf3a5d37489784b68fe9239ceed8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_15c68305dffb4d0c9a2deb5d1730c04a",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3450aea3fc904dbe831388c60c1dffb8",
            "value": 2
          }
        },
        "1c39bab6d8354d80a21d4b8e10a34258": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9ef9e3aa3251406498565affe909b1c7",
            "placeholder": "​",
            "style": "IPY_MODEL_d1db823bd4544a6aa718f8fb413cdf85",
            "value": " 2/2 [00:08&lt;00:00,  4.03s/it]"
          }
        },
        "230d4ede8bdf436e8b94e62c0ed1f1d3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e6f2ad9ac925407eb4cab5d383e814ab": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1dfce580b8ab401398bb1319d8e816ca": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "15c68305dffb4d0c9a2deb5d1730c04a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3450aea3fc904dbe831388c60c1dffb8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "9ef9e3aa3251406498565affe909b1c7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d1db823bd4544a6aa718f8fb413cdf85": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "17714a5200ad439689748a4b8a9e73a6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_07c4803e23194b1abf9d8cccf8b68d0f",
              "IPY_MODEL_a230326aab3e4502afa4027a27a1bf42",
              "IPY_MODEL_0cfa1d8cde204151aad13b1af64925a7"
            ],
            "layout": "IPY_MODEL_6786981118b24cf88c6954116da51911"
          }
        },
        "07c4803e23194b1abf9d8cccf8b68d0f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6c98c97bfd0a4d97ac65839a3db66c3d",
            "placeholder": "​",
            "style": "IPY_MODEL_056fe12e14094451a2f9f10745031797",
            "value": "generation_config.json: 100%"
          }
        },
        "a230326aab3e4502afa4027a27a1bf42": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8bfa751712744c409598ee69280bdeb0",
            "max": 181,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_26f555af847b4aa89bb4b012d7fb0747",
            "value": 181
          }
        },
        "0cfa1d8cde204151aad13b1af64925a7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4c36dd3013414f35a66a52f133c5b5d6",
            "placeholder": "​",
            "style": "IPY_MODEL_9c58a1b03fef4bcd9b00bda9e1f0741b",
            "value": " 181/181 [00:00&lt;00:00, 25.3kB/s]"
          }
        },
        "6786981118b24cf88c6954116da51911": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6c98c97bfd0a4d97ac65839a3db66c3d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "056fe12e14094451a2f9f10745031797": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8bfa751712744c409598ee69280bdeb0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "26f555af847b4aa89bb4b012d7fb0747": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4c36dd3013414f35a66a52f133c5b5d6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9c58a1b03fef4bcd9b00bda9e1f0741b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "daa54c05faeb42f8b108905ae30eb5bb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a3530137220b42f4a5b40b12624580ba",
              "IPY_MODEL_a7838f19c8714080b81480dfc3f2d243",
              "IPY_MODEL_f258d048c22944588894c61b61fecc9f"
            ],
            "layout": "IPY_MODEL_5c58d0fe6fb34b7d8f05f0c1ceadb35c"
          }
        },
        "a3530137220b42f4a5b40b12624580ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1014f966098e4c518c830f756c929d65",
            "placeholder": "​",
            "style": "IPY_MODEL_82f9f1c61ab3419ba812e25c6c0ce6f6",
            "value": "README.md: "
          }
        },
        "a7838f19c8714080b81480dfc3f2d243": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3dfaf9bfe0604969b7a3677f73b753b7",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b185c9f868fc419bb8f1dc1c0d44a465",
            "value": 1
          }
        },
        "f258d048c22944588894c61b61fecc9f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ce4b11394235491f9c6f073011a8ba94",
            "placeholder": "​",
            "style": "IPY_MODEL_227cf2d26b2f42d8a7acbb6e5547b728",
            "value": " 7.93k/? [00:00&lt;00:00, 810kB/s]"
          }
        },
        "5c58d0fe6fb34b7d8f05f0c1ceadb35c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1014f966098e4c518c830f756c929d65": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "82f9f1c61ab3419ba812e25c6c0ce6f6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3dfaf9bfe0604969b7a3677f73b753b7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "b185c9f868fc419bb8f1dc1c0d44a465": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ce4b11394235491f9c6f073011a8ba94": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "227cf2d26b2f42d8a7acbb6e5547b728": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "072017f5299d4c85a4f54fbe34ae344f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_056cded851da46c686fcf32e109d8b94",
              "IPY_MODEL_4904aa32cf67436e8d54d37dea0f0762",
              "IPY_MODEL_dcc2d22e0aea4d18aac0b60c5d325fc2"
            ],
            "layout": "IPY_MODEL_55eed6957fec4dce8da7cf8adaf48597"
          }
        },
        "056cded851da46c686fcf32e109d8b94": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_092624ebb29f436894e4396f03953e86",
            "placeholder": "​",
            "style": "IPY_MODEL_b727e50e2d1b4d21a05164530f7053f1",
            "value": "main/train-00000-of-00001.parquet: 100%"
          }
        },
        "4904aa32cf67436e8d54d37dea0f0762": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b980c155b06c49adad63fe83fb36a373",
            "max": 2306545,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_94afaa6eecbd480cb6b149dff4e4f80c",
            "value": 2306545
          }
        },
        "dcc2d22e0aea4d18aac0b60c5d325fc2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8511a603a53543dab3f567a9c172b652",
            "placeholder": "​",
            "style": "IPY_MODEL_3485568498df43de8ee14437f61b3cd6",
            "value": " 2.31M/2.31M [00:00&lt;00:00, 4.06MB/s]"
          }
        },
        "55eed6957fec4dce8da7cf8adaf48597": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "092624ebb29f436894e4396f03953e86": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b727e50e2d1b4d21a05164530f7053f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b980c155b06c49adad63fe83fb36a373": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "94afaa6eecbd480cb6b149dff4e4f80c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "8511a603a53543dab3f567a9c172b652": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3485568498df43de8ee14437f61b3cd6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8ab15221c29c4b26a72f791d2f16fa41": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_2e4112f36f7a4863b34423490f4d9139",
              "IPY_MODEL_3a6b4d6a40644ac086a090d7a046a9b1",
              "IPY_MODEL_712db2e1b43446e29131d92dbc7b1a81"
            ],
            "layout": "IPY_MODEL_b756871eb62d4f7bbd098ee10a922a39"
          }
        },
        "2e4112f36f7a4863b34423490f4d9139": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5c7f29b0d5e842b18fb868a6e25375b6",
            "placeholder": "​",
            "style": "IPY_MODEL_0fa62fd72be942f9976482d33c127229",
            "value": "main/test-00000-of-00001.parquet: 100%"
          }
        },
        "3a6b4d6a40644ac086a090d7a046a9b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f4a02ee3abd8429fafd72fa5cd469e4d",
            "max": 419088,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_9c2e944601b641d5bbbd8d17818530bb",
            "value": 419088
          }
        },
        "712db2e1b43446e29131d92dbc7b1a81": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_21465763a3f84881be0d58c5b48db247",
            "placeholder": "​",
            "style": "IPY_MODEL_b085867817fd451dad0f099c0d99f2b2",
            "value": " 419k/419k [00:00&lt;00:00, 802kB/s]"
          }
        },
        "b756871eb62d4f7bbd098ee10a922a39": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5c7f29b0d5e842b18fb868a6e25375b6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0fa62fd72be942f9976482d33c127229": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f4a02ee3abd8429fafd72fa5cd469e4d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9c2e944601b641d5bbbd8d17818530bb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "21465763a3f84881be0d58c5b48db247": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b085867817fd451dad0f099c0d99f2b2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cda513e07ffc47ffbe6d39b7db49a833": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ae53cb72ec024ebf8d0358bc2cc77fdb",
              "IPY_MODEL_d6de9cd36dea474a9fd61076e797dc50",
              "IPY_MODEL_63de55d0846b4237969336a4f5355712"
            ],
            "layout": "IPY_MODEL_cc1ae692118b421ab11c7194a95b048f"
          }
        },
        "ae53cb72ec024ebf8d0358bc2cc77fdb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_76903d1a10a249d785a889f0be761858",
            "placeholder": "​",
            "style": "IPY_MODEL_3d6a5e6a18e14113aa8512420d38045a",
            "value": "Generating train split: 100%"
          }
        },
        "d6de9cd36dea474a9fd61076e797dc50": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f880abbd4444494c9c41120297951364",
            "max": 7473,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c4159664e6bd439e8c3a03e80a14f122",
            "value": 7473
          }
        },
        "63de55d0846b4237969336a4f5355712": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_de0b9490927c4524b0a544f664779eb8",
            "placeholder": "​",
            "style": "IPY_MODEL_38b494b59b8a4c459f726c96d0b36390",
            "value": " 7473/7473 [00:00&lt;00:00, 265724.24 examples/s]"
          }
        },
        "cc1ae692118b421ab11c7194a95b048f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "76903d1a10a249d785a889f0be761858": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3d6a5e6a18e14113aa8512420d38045a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f880abbd4444494c9c41120297951364": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c4159664e6bd439e8c3a03e80a14f122": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "de0b9490927c4524b0a544f664779eb8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "38b494b59b8a4c459f726c96d0b36390": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c2beba96514345648cf43cb638583d52": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_41501aac60d644a78c75863d3e2fa4c4",
              "IPY_MODEL_138698b26b38420fb43a18c99f2a7a64",
              "IPY_MODEL_05277fc8241f41ffbc57eab74ccc82c7"
            ],
            "layout": "IPY_MODEL_41adc3e56fc94969a27a87556c279c0f"
          }
        },
        "41501aac60d644a78c75863d3e2fa4c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cf494d2c2fdd46088506136d6bdf677c",
            "placeholder": "​",
            "style": "IPY_MODEL_dd07772ab98342d5bfb7d770c81a8ebf",
            "value": "Generating test split: 100%"
          }
        },
        "138698b26b38420fb43a18c99f2a7a64": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0e16c268f4ef47e482ad5fa9ebc9185b",
            "max": 1319,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2926d30ad85742e29a6f9d23df909c59",
            "value": 1319
          }
        },
        "05277fc8241f41ffbc57eab74ccc82c7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_123d4849dffa453b9f16447920072e1e",
            "placeholder": "​",
            "style": "IPY_MODEL_621fca71bdab4fe9897429a325760829",
            "value": " 1319/1319 [00:00&lt;00:00, 94895.06 examples/s]"
          }
        },
        "41adc3e56fc94969a27a87556c279c0f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cf494d2c2fdd46088506136d6bdf677c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dd07772ab98342d5bfb7d770c81a8ebf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0e16c268f4ef47e482ad5fa9ebc9185b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2926d30ad85742e29a6f9d23df909c59": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "123d4849dffa453b9f16447920072e1e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "621fca71bdab4fe9897429a325760829": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "32f6799a1fc64819a03c877739231d87": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a1c96f4de2034d968f9b5753b0a40a29",
              "IPY_MODEL_c6f77082b18a4e81a3a62afe975eb13f",
              "IPY_MODEL_2bac1d5c35be4a7ebcbe73630cf793ba"
            ],
            "layout": "IPY_MODEL_ccf52167ed7943e683dc8867b667d04b"
          }
        },
        "a1c96f4de2034d968f9b5753b0a40a29": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3381deb1bc12409d95b206282c52a4fa",
            "placeholder": "​",
            "style": "IPY_MODEL_ac01321b67174b19a5729e55214bbfd3",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "c6f77082b18a4e81a3a62afe975eb13f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_49a218f81ccd4c21b21b06e70fb7437a",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_8f27b215d4644bb987cacdacfdff300d",
            "value": 42
          }
        },
        "2bac1d5c35be4a7ebcbe73630cf793ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bd1867fce7ec40f5ab51987c6a15dc01",
            "placeholder": "​",
            "style": "IPY_MODEL_76c9e6ab15ff4a42b491d83a86dd39f1",
            "value": " 42/42 [34:52&lt;00:00, 49.20s/it, N=1319, ACC%=79.83, greedy=1]"
          }
        },
        "ccf52167ed7943e683dc8867b667d04b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "3381deb1bc12409d95b206282c52a4fa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ac01321b67174b19a5729e55214bbfd3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "49a218f81ccd4c21b21b06e70fb7437a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8f27b215d4644bb987cacdacfdff300d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "bd1867fce7ec40f5ab51987c6a15dc01": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "76c9e6ab15ff4a42b491d83a86dd39f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "05c8b40c0b824ff78c29a88acc28def0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_06ba9fd853cc41cb9a3da9f2a468bcbc",
              "IPY_MODEL_ce21a30b7c05419fb9dfa60a03b860f0",
              "IPY_MODEL_462cff19cfe245d5bdea79800e31fbb9"
            ],
            "layout": "IPY_MODEL_43e6ca1b275644d38cf635e85c707e37"
          }
        },
        "06ba9fd853cc41cb9a3da9f2a468bcbc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6c3f843e18dd490e9fdd6f803c561c88",
            "placeholder": "​",
            "style": "IPY_MODEL_75642ef6991447988a0fdb1a42525a54",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "ce21a30b7c05419fb9dfa60a03b860f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_20fba02e7e904962ab20895229def0b9",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f8470899b99440a0bb6b16c02a83228a",
            "value": 42
          }
        },
        "462cff19cfe245d5bdea79800e31fbb9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_87b9c3fdb66f45e5995514f4dc2a8484",
            "placeholder": "​",
            "style": "IPY_MODEL_e131acb0fcb34a1b8272f4d54735c7f9",
            "value": " 42/42 [33:42&lt;00:00, 44.02s/it, N=1319, ACC%=81.35, greedy=1]"
          }
        },
        "43e6ca1b275644d38cf635e85c707e37": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "6c3f843e18dd490e9fdd6f803c561c88": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "75642ef6991447988a0fdb1a42525a54": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "20fba02e7e904962ab20895229def0b9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f8470899b99440a0bb6b16c02a83228a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "87b9c3fdb66f45e5995514f4dc2a8484": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e131acb0fcb34a1b8272f4d54735c7f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3e988fa7173f43deb3b983b18f3fa43a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a8786cfd96e44ac8b30d746df82423dd",
              "IPY_MODEL_cdbc5c0efd43440382ac0a344b8596e1",
              "IPY_MODEL_851469eda17447f7a21d3d6f54ad955b"
            ],
            "layout": "IPY_MODEL_3ba0bdcbbb7b4baba0f52e3b895dfc0a"
          }
        },
        "a8786cfd96e44ac8b30d746df82423dd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a832547670bd4e98a5fea376b41474fa",
            "placeholder": "​",
            "style": "IPY_MODEL_d9675716381646f083c45dc38b205566",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "cdbc5c0efd43440382ac0a344b8596e1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_09d8b677e54c4690882fc5135746b551",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_cfe47daa7dd3457880191c9c884fd9d1",
            "value": 42
          }
        },
        "851469eda17447f7a21d3d6f54ad955b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6e3eb8f53b3648a4a34cf0c9304cd487",
            "placeholder": "​",
            "style": "IPY_MODEL_59afc93f3c394424b0b40d96363293d1",
            "value": " 42/42 [34:19&lt;00:00, 48.88s/it, N=1319, ACC%=81.27, greedy=1]"
          }
        },
        "3ba0bdcbbb7b4baba0f52e3b895dfc0a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "a832547670bd4e98a5fea376b41474fa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d9675716381646f083c45dc38b205566": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "09d8b677e54c4690882fc5135746b551": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cfe47daa7dd3457880191c9c884fd9d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6e3eb8f53b3648a4a34cf0c9304cd487": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "59afc93f3c394424b0b40d96363293d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b578aef376294c2fbdd44ea77221dc49": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ac505e13948d4b889852f45b1e59abbf",
              "IPY_MODEL_6dc843cb39f74cb68bf3404581d52302",
              "IPY_MODEL_ef13d5aabbd74bc9bf7d1ef00573d318"
            ],
            "layout": "IPY_MODEL_b4c526ed80ac473296e9466c3b1de900"
          }
        },
        "ac505e13948d4b889852f45b1e59abbf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_15c860f3f7b84a648b965d46f8089f02",
            "placeholder": "​",
            "style": "IPY_MODEL_3583c46056cc4bb1aa80b54a150d6f49",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "6dc843cb39f74cb68bf3404581d52302": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_31c28f0dd70b40e9b9c58e38e7212702",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3c806eead3e54684aa9b4241282a5a4e",
            "value": 42
          }
        },
        "ef13d5aabbd74bc9bf7d1ef00573d318": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_da39f0035da542bd8411d9af0214897a",
            "placeholder": "​",
            "style": "IPY_MODEL_02651e4f6390479f830f7c386c5f3760",
            "value": " 42/42 [34:50&lt;00:00, 46.55s/it, N=1319, ACC%=81.58, greedy=1]"
          }
        },
        "b4c526ed80ac473296e9466c3b1de900": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "15c860f3f7b84a648b965d46f8089f02": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3583c46056cc4bb1aa80b54a150d6f49": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "31c28f0dd70b40e9b9c58e38e7212702": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3c806eead3e54684aa9b4241282a5a4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "da39f0035da542bd8411d9af0214897a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "02651e4f6390479f830f7c386c5f3760": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "52823df7480e4fbe9145261be21d3078": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0af9121c170646b4bd8164825bf8c066",
              "IPY_MODEL_7daa4689dcc444608b14b508a6cb2ed0",
              "IPY_MODEL_57bffe25f3ea446ab840f0da486c45b6"
            ],
            "layout": "IPY_MODEL_72325f3267264cba8567e2305d99bb02"
          }
        },
        "0af9121c170646b4bd8164825bf8c066": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3c0c6563a42042efa63655d705966d7e",
            "placeholder": "​",
            "style": "IPY_MODEL_0d00ec98f31a42b8b4e5eed967075949",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "7daa4689dcc444608b14b508a6cb2ed0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5312c593f45841febf846af3abf7dd0b",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_cc056289a37e479ca61b8c71c26750dc",
            "value": 42
          }
        },
        "57bffe25f3ea446ab840f0da486c45b6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ff2117536ec04841bb2275b63124ec68",
            "placeholder": "​",
            "style": "IPY_MODEL_16700eb57b07426f9483715d2d15ab00",
            "value": " 42/42 [35:40&lt;00:00, 47.22s/it, N=1319, ACC%=81.50, greedy=1]"
          }
        },
        "72325f3267264cba8567e2305d99bb02": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "3c0c6563a42042efa63655d705966d7e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0d00ec98f31a42b8b4e5eed967075949": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5312c593f45841febf846af3abf7dd0b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cc056289a37e479ca61b8c71c26750dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ff2117536ec04841bb2275b63124ec68": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "16700eb57b07426f9483715d2d15ab00": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e5492db54a1440fa82a8efcd1d8d9c31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9efb508c93a34fa8925c8def2f62c91f",
              "IPY_MODEL_508dbe13f9de4f2d94c53f8a731dcfe5",
              "IPY_MODEL_d3ff6d9ec7ad4bd081c2be962a328891"
            ],
            "layout": "IPY_MODEL_23891f2226224121912cdde1f515a403"
          }
        },
        "9efb508c93a34fa8925c8def2f62c91f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_431188b0e5c647fa91ff0074412f16f8",
            "placeholder": "​",
            "style": "IPY_MODEL_cf5a0838d81e489ab7b835374cba9d90",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "508dbe13f9de4f2d94c53f8a731dcfe5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_07cdbbac3e624cfea4bd7b8c62f905a2",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_281bda6512684a4fac0e44c6e2595a6c",
            "value": 42
          }
        },
        "d3ff6d9ec7ad4bd081c2be962a328891": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_10b2adc8db8d4329b6e639ae4fd7cfd5",
            "placeholder": "​",
            "style": "IPY_MODEL_f508cc9928f64596951cfa8db6fa7e6c",
            "value": " 42/42 [34:43&lt;00:00, 49.22s/it, N=1319, ACC%=81.73, greedy=1]"
          }
        },
        "23891f2226224121912cdde1f515a403": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "431188b0e5c647fa91ff0074412f16f8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cf5a0838d81e489ab7b835374cba9d90": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "07cdbbac3e624cfea4bd7b8c62f905a2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "281bda6512684a4fac0e44c6e2595a6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "10b2adc8db8d4329b6e639ae4fd7cfd5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f508cc9928f64596951cfa8db6fa7e6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e5178faae6c2443dade448218fe426a0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_49b2ac9455524a20a6f401d09dccbb16",
              "IPY_MODEL_034749bdd790473b833e84d6736d372f",
              "IPY_MODEL_af68c43396554c76a342b1e654470297"
            ],
            "layout": "IPY_MODEL_d57f38d5576b465e9ba982c974c414e5"
          }
        },
        "49b2ac9455524a20a6f401d09dccbb16": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3625f52ccc124938aa38fa846f2603ed",
            "placeholder": "​",
            "style": "IPY_MODEL_450ed427db1f488ba2a4b3adbae3e37a",
            "value": "Evaluating trained_grpo: 100%"
          }
        },
        "034749bdd790473b833e84d6736d372f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2302fa98d451435dad4005f38a2243d1",
            "max": 42,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_dbac9ac62a4f4e49a2e07d309c4762f8",
            "value": 42
          }
        },
        "af68c43396554c76a342b1e654470297": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_75df96751cda43cc989d2e6bd6ffec22",
            "placeholder": "​",
            "style": "IPY_MODEL_3a28131ebfc844d2be2325cc1111a1c8",
            "value": " 42/42 [35:12&lt;00:00, 44.83s/it, N=1319, ACC%=81.50, greedy=1]"
          }
        },
        "d57f38d5576b465e9ba982c974c414e5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "3625f52ccc124938aa38fa846f2603ed": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "450ed427db1f488ba2a4b3adbae3e37a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2302fa98d451435dad4005f38a2243d1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dbac9ac62a4f4e49a2e07d309c4762f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "75df96751cda43cc989d2e6bd6ffec22": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3a28131ebfc844d2be2325cc1111a1c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}