{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6r2zXuvOThnh",
        "outputId": "955889ba-68f1-49c3-b527-36fdf6c08d94"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (4.56.1)\n",
            "Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.10.1)\n",
            "Requirement already satisfied: rouge-score in /usr/local/lib/python3.12/dist-packages (0.1.2)\n",
            "Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (3.9.1)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.19.1)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.34.4)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.2)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2024.11.6)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers) (2.32.4)\n",
            "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.0)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.6.2)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (5.9.5)\n",
            "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.8.0+cu126)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.12/dist-packages (from rouge-score) (1.4.0)\n",
            "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.12/dist-packages (from rouge-score) (1.17.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from nltk) (8.2.1)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk) (1.5.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (2025.3.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.1.9)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.13.3)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.5)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.27.3)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.4.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (3.4.3)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers) (2025.8.3)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)\n"
          ]
        }
      ],
      "source": [
        "!pip install transformers accelerate rouge-score nltk tqdm\n",
        "\n",
        "import os, torch, json, time, math, argparse, pathlib\n",
        "import torch.nn.functional as F\n",
        "from tqdm import tqdm\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from transformers import StoppingCriteria, StoppingCriteriaList\n",
        "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
        "from huggingface_hub import login\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# === Config ===\n",
        "# NOTE: For Colab, pick a smaller model unless you have Colab Pro+ with A100.\n",
        "MODEL_ID = \"Qwen/Qwen2-7B\"\n",
        "# MODEL_ID = \"sshleifer/tiny-gpt2\"          # debug / demonstration model\n",
        "\n",
        "OUT_BASE  = \"/content/out\"\n",
        "os.makedirs(OUT_BASE, exist_ok=True)\n",
        "\n",
        "SHORT_LEN = 32\n",
        "NUCLEUS_P = 0.9\n",
        "NUM_SAMPLES = 5    # keep small for Colab demo\n",
        "MAX_NEW_TOKENS = 32\n",
        "DISTANCE_METRIC = \"jsd\"  # \"jsd\" or \"tvd\"\n",
        "num_examples = 100\n",
        "\n",
        "\n",
        "longctx_detection_thresh=0.12\n",
        "lambda_boost=4.0\n",
        "epsilon=0.05\n",
        "\n",
        "from google.colab import files\n",
        "\n",
        "# uploaded = files.upload() # Uncomment this one when uploading the file\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SJ4ezxq1asfv"
      },
      "outputs": [],
      "source": [
        "import os, json\n",
        "from glob import glob\n",
        "from collections import defaultdict\n",
        "import pandas as pd\n",
        "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
        "from rouge_score import rouge_scorer\n",
        "\n",
        "# --- metric helpers ---\n",
        "def compute_f1(pred, gold):\n",
        "    pred_tokens = pred.lower().split()\n",
        "    gold_tokens = gold.lower().split()\n",
        "    common = set(pred_tokens) & set(gold_tokens)\n",
        "    if not pred_tokens or not gold_tokens:\n",
        "        return 0.0\n",
        "    prec = len(common) / len(pred_tokens)\n",
        "    rec = len(common) / len(gold_tokens)\n",
        "    return 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0\n",
        "\n",
        "def compute_bleu(pred, gold):\n",
        "    ref = [gold.split()]\n",
        "    hyp = pred.split()\n",
        "    if not hyp or not ref[0]:\n",
        "        return 0.0\n",
        "    return sentence_bleu(ref, hyp, smoothing_function=SmoothingFunction().method1)\n",
        "\n",
        "def compute_rougeL(pred, gold):\n",
        "    scorer = rouge_scorer.RougeScorer([\"rougeL\"], use_stemmer=True)\n",
        "    return scorer.score(gold, pred)[\"rougeL\"].fmeasure\n",
        "\n",
        "def safe_metric(pred, gold):\n",
        "    if not pred or not gold:\n",
        "        return {}\n",
        "    return {\n",
        "        \"f1\": compute_f1(pred, gold),\n",
        "        \"bleu\": compute_bleu(pred, gold),\n",
        "        \"rougeL\": compute_rougeL(pred, gold),\n",
        "    }\n",
        "\n",
        "\n",
        "def evaluate_partial(buffer_json):\n",
        "    \"\"\"\n",
        "    Evaluate metrics on a buffer of JSONL-formatted experiment outputs.\n",
        "\n",
        "    Args:\n",
        "        buffer_json: list[str] where each item is a JSON line (like what you save to shard files).\n",
        "\n",
        "    Returns:\n",
        "        DataFrame with average (over all generations) and best-per-example scores per case.\n",
        "    \"\"\"\n",
        "    case_scores_all = defaultdict(list)   # all scores (for avg over all generations)\n",
        "    case_scores_best = defaultdict(list)  # best score per example\n",
        "\n",
        "    for line in buffer_json:\n",
        "        ex = json.loads(line)\n",
        "        gold = ex.get(\"answer\", \"\").strip()\n",
        "        if not gold:\n",
        "            continue\n",
        "\n",
        "        # --- collect per-example scores ---\n",
        "        per_case_scores = {c: {\"f1\": [], \"bleu\": [], \"rougeL\": []}\n",
        "                           for c in [\"case1\", \"case2\", \"case3\", \"case4\"]}\n",
        "\n",
        "        # Case 1\n",
        "        for ans in ex.get(\"generated_answers_case1\", []):\n",
        "            m = safe_metric(ans, gold)\n",
        "            for k, v in m.items():\n",
        "                case_scores_all[f\"case1_{k}\"].append(v)\n",
        "                per_case_scores[\"case1\"][k].append(v)\n",
        "\n",
        "        # Case 2\n",
        "        case2 = ex.get(\"generated_answer_case2\", []) or ex.get(\"generated_answers_case2\", [])\n",
        "        if isinstance(case2, str):\n",
        "            case2 = [case2]\n",
        "        for ans in case2:\n",
        "            m = safe_metric(ans, gold)\n",
        "            for k, v in m.items():\n",
        "                case_scores_all[f\"case2_{k}\"].append(v)\n",
        "                per_case_scores[\"case2\"][k].append(v)\n",
        "\n",
        "        # Case 3\n",
        "        for ans in ex.get(\"generated_answers_case3\", []):\n",
        "            m = safe_metric(ans, gold)\n",
        "            for k, v in m.items():\n",
        "                case_scores_all[f\"case3_{k}\"].append(v)\n",
        "                per_case_scores[\"case3\"][k].append(v)\n",
        "\n",
        "        # Case 4\n",
        "        for ans in ex.get(\"generated_answers_case4\", []):\n",
        "            m = safe_metric(ans, gold)\n",
        "            for k, v in m.items():\n",
        "                case_scores_all[f\"case4_{k}\"].append(v)\n",
        "                per_case_scores[\"case4\"][k].append(v)\n",
        "\n",
        "        # --- log the *best* score per case for this example ---\n",
        "        for case in per_case_scores:\n",
        "            for k, vals in per_case_scores[case].items():\n",
        "                if vals:  # non-empty\n",
        "                    case_scores_best[f\"{case}_{k}\"].append(max(vals))\n",
        "\n",
        "    # --- compute averages ---\n",
        "    results = []\n",
        "    for case in [\"case1\", \"case2\", \"case3\", \"case4\"]:\n",
        "        results.append({\n",
        "            \"case\": case,\n",
        "            \"F1_avg\": (sum(case_scores_all.get(f\"{case}_f1\", [])) /\n",
        "                       len(case_scores_all.get(f\"{case}_f1\", []))) if case_scores_all.get(f\"{case}_f1\") else 0.0,\n",
        "            \"F1_best\": (sum(case_scores_best.get(f\"{case}_f1\", [])) /\n",
        "                        len(case_scores_best.get(f\"{case}_f1\", []))) if case_scores_best.get(f\"{case}_f1\") else 0.0,\n",
        "            \"BLEU_avg\": (sum(case_scores_all.get(f\"{case}_bleu\", [])) /\n",
        "                         len(case_scores_all.get(f\"{case}_bleu\", []))) if case_scores_all.get(f\"{case}_bleu\") else 0.0,\n",
        "            \"BLEU_best\": (sum(case_scores_best.get(f\"{case}_bleu\", [])) /\n",
        "                          len(case_scores_best.get(f\"{case}_bleu\", []))) if case_scores_best.get(f\"{case}_bleu\") else 0.0,\n",
        "            \"ROUGE-L_avg\": (sum(case_scores_all.get(f\"{case}_rougeL\", [])) /\n",
        "                            len(case_scores_all.get(f\"{case}_rougeL\", []))) if case_scores_all.get(f\"{case}_rougeL\") else 0.0,\n",
        "            \"ROUGE-L_best\": (sum(case_scores_best.get(f\"{case}_rougeL\", [])) /\n",
        "                             len(case_scores_best.get(f\"{case}_rougeL\", []))) if case_scores_best.get(f\"{case}_rougeL\") else 0.0,\n",
        "        })\n",
        "\n",
        "    df = pd.DataFrame(results)\n",
        "    print(\"\\n=== Partial Evaluation (avg over all vs. best-per-example) ===\")\n",
        "    print(df.to_string(index=False, float_format=lambda x: f\"{x:.4f}\"))\n",
        "\n",
        "    return df\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "7c98acfb10cc4f0e913a9dd7d609cb20",
            "b1039b4a4fe44165a4197d4b9f9a9f03",
            "05400a2fbf074541a4496a22a6fa5018",
            "2162fc48cda6465aba6bd2f9abed1ae9",
            "9ac7ab1c499e46218731fc237fa45933",
            "d5ff5750629740b4bf0b2d7d7e2fc6f7",
            "2af4884b51a64f9bae66fcc72c308cb8",
            "09c2f611bb074dceb309a832fee6bb64",
            "b864b94bd86e42fcb9d7d654d324fd69",
            "31362b8b55a748ce9c7580b14df5a71e",
            "1e5cf8eeb8a84f4a92483fc4030436c6"
          ]
        },
        "id": "hvuclV_TP-wF",
        "outputId": "6e90c61f-ca5a-4d2a-eea2-27acfdc98130"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7c98acfb10cc4f0e913a9dd7d609cb20",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  0%|          | 0/100 [00:00<?, ?it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.8703 sec (91.56% of total)\n",
            "softmax_nuc    : 0.1066 sec (0.99% of total)\n",
            "boosting       : 0.1237 sec (1.15% of total)\n",
            "sampling_topk  : 0.1120 sec (1.04% of total)\n",
            "TOTAL          : 10.7802 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  1%|          | 1/100 [00:16<26:42, 16.19s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.8737 sec (91.59% of total)\n",
            "softmax_nuc    : 0.0152 sec (0.74% of total)\n",
            "boosting       : 0.0573 sec (2.80% of total)\n",
            "sampling_topk  : 0.0162 sec (0.79% of total)\n",
            "TOTAL          : 2.0458 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  2%|▏         | 2/100 [00:22<17:00, 10.41s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.7474 sec (91.59% of total)\n",
            "softmax_nuc    : 0.1166 sec (0.99% of total)\n",
            "boosting       : 0.1303 sec (1.11% of total)\n",
            "sampling_topk  : 0.1224 sec (1.04% of total)\n",
            "TOTAL          : 11.7339 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  3%|▎         | 3/100 [00:44<25:34, 15.82s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6916 sec (91.38% of total)\n",
            "softmax_nuc    : 0.0249 sec (0.84% of total)\n",
            "boosting       : 0.0644 sec (2.19% of total)\n",
            "sampling_topk  : 0.0262 sec (0.89% of total)\n",
            "TOTAL          : 2.9457 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  4%|▍         | 4/100 [00:51<19:46, 12.36s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.9593 sec (91.47% of total)\n",
            "softmax_nuc    : 0.0738 sec (0.97% of total)\n",
            "boosting       : 0.1012 sec (1.33% of total)\n",
            "sampling_topk  : 0.0774 sec (1.02% of total)\n",
            "TOTAL          : 7.6081 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  5%|▌         | 5/100 [01:05<20:26, 12.91s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.7062 sec (91.52% of total)\n",
            "softmax_nuc    : 0.0364 sec (0.90% of total)\n",
            "boosting       : 0.0728 sec (1.80% of total)\n",
            "sampling_topk  : 0.0384 sec (0.95% of total)\n",
            "TOTAL          : 4.0495 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  6%|▌         | 6/100 [01:15<18:43, 11.95s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.0762 sec (91.63% of total)\n",
            "softmax_nuc    : 0.0979 sec (0.99% of total)\n",
            "boosting       : 0.1145 sec (1.16% of total)\n",
            "sampling_topk  : 0.1025 sec (1.03% of total)\n",
            "TOTAL          : 9.9057 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  7%|▋         | 7/100 [01:34<21:59, 14.19s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.7342 sec (91.77% of total)\n",
            "softmax_nuc    : 0.0591 sec (0.95% of total)\n",
            "boosting       : 0.0858 sec (1.37% of total)\n",
            "sampling_topk  : 0.0618 sec (0.99% of total)\n",
            "TOTAL          : 6.2481 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  8%|▊         | 8/100 [01:46<20:39, 13.47s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.1140 sec (91.55% of total)\n",
            "softmax_nuc    : 0.0640 sec (0.96% of total)\n",
            "boosting       : 0.0924 sec (1.38% of total)\n",
            "sampling_topk  : 0.0671 sec (1.00% of total)\n",
            "TOTAL          : 6.6783 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r  9%|▉         | 9/100 [02:00<20:33, 13.55s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.7726 sec (91.58% of total)\n",
            "softmax_nuc    : 0.0944 sec (0.99% of total)\n",
            "boosting       : 0.1150 sec (1.20% of total)\n",
            "sampling_topk  : 0.0990 sec (1.03% of total)\n",
            "TOTAL          : 9.5793 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 10%|█         | 10/100 [02:16<21:24, 14.27s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.8631 sec (91.46% of total)\n",
            "softmax_nuc    : 0.0152 sec (0.75% of total)\n",
            "boosting       : 0.0576 sec (2.83% of total)\n",
            "sampling_topk  : 0.0162 sec (0.80% of total)\n",
            "TOTAL          : 2.0371 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 11%|█         | 11/100 [02:21<17:08, 11.56s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.1789 sec (91.65% of total)\n",
            "softmax_nuc    : 0.0532 sec (0.94% of total)\n",
            "boosting       : 0.0837 sec (1.48% of total)\n",
            "sampling_topk  : 0.0557 sec (0.99% of total)\n",
            "TOTAL          : 5.6507 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 12%|█▏        | 12/100 [02:32<16:50, 11.48s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.9645 sec (91.64% of total)\n",
            "softmax_nuc    : 0.0390 sec (0.90% of total)\n",
            "boosting       : 0.0743 sec (1.72% of total)\n",
            "sampling_topk  : 0.0411 sec (0.95% of total)\n",
            "TOTAL          : 4.3264 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 13%|█▎        | 13/100 [02:41<15:28, 10.68s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6236 sec (91.58% of total)\n",
            "softmax_nuc    : 0.0243 sec (0.85% of total)\n",
            "boosting       : 0.0633 sec (2.21% of total)\n",
            "sampling_topk  : 0.0255 sec (0.89% of total)\n",
            "TOTAL          : 2.8649 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 14%|█▍        | 14/100 [02:49<14:00,  9.77s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.2774 sec (91.47% of total)\n",
            "softmax_nuc    : 0.0202 sec (0.81% of total)\n",
            "boosting       : 0.0604 sec (2.43% of total)\n",
            "sampling_topk  : 0.0212 sec (0.85% of total)\n",
            "TOTAL          : 2.4899 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 15%|█▌        | 15/100 [02:56<12:51,  9.08s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.5903 sec (91.45% of total)\n",
            "softmax_nuc    : 0.0463 sec (0.92% of total)\n",
            "boosting       : 0.0807 sec (1.61% of total)\n",
            "sampling_topk  : 0.0490 sec (0.98% of total)\n",
            "TOTAL          : 5.0197 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 16%|█▌        | 16/100 [03:06<12:57,  9.25s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.2759 sec (91.63% of total)\n",
            "softmax_nuc    : 0.0199 sec (0.80% of total)\n",
            "boosting       : 0.0596 sec (2.40% of total)\n",
            "sampling_topk  : 0.0210 sec (0.85% of total)\n",
            "TOTAL          : 2.4838 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 17%|█▋        | 17/100 [03:13<11:38,  8.42s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.2639 sec (91.51% of total)\n",
            "softmax_nuc    : 0.0085 sec (0.62% of total)\n",
            "boosting       : 0.0526 sec (3.81% of total)\n",
            "sampling_topk  : 0.0090 sec (0.65% of total)\n",
            "TOTAL          : 1.3811 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 18%|█▊        | 18/100 [03:17<09:57,  7.28s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.1581 sec (91.43% of total)\n",
            "softmax_nuc    : 0.0188 sec (0.80% of total)\n",
            "boosting       : 0.0595 sec (2.52% of total)\n",
            "sampling_topk  : 0.0201 sec (0.85% of total)\n",
            "TOTAL          : 2.3603 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 19%|█▉        | 19/100 [03:24<09:34,  7.09s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.5965 sec (92.08% of total)\n",
            "softmax_nuc    : 0.0584 sec (0.96% of total)\n",
            "boosting       : 0.0786 sec (1.29% of total)\n",
            "sampling_topk  : 0.0614 sec (1.01% of total)\n",
            "TOTAL          : 6.0780 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 20%|██        | 20/100 [03:34<10:53,  8.17s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "💾 Saved 20 examples so far\n",
            "\n",
            "=== Partial Evaluation (avg over all vs. best-per-example) ===\n",
            " case  F1_avg  F1_best  BLEU_avg  BLEU_best  ROUGE-L_avg  ROUGE-L_best\n",
            "case1  0.3467   0.6350    0.0815     0.1938       0.4328        0.6864\n",
            "case2  0.4612   0.6948    0.0945     0.2256       0.5246        0.7494\n",
            "case3  0.0878   0.1726    0.0186     0.0462       0.0964        0.1749\n",
            "case4  0.5898   0.6466    0.1270     0.1899       0.6654        0.6883\n",
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.1501 sec (91.55% of total)\n",
            "softmax_nuc    : 0.1104 sec (1.00% of total)\n",
            "boosting       : 0.1259 sec (1.14% of total)\n",
            "sampling_topk  : 0.1158 sec (1.04% of total)\n",
            "TOTAL          : 11.0873 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 21%|██        | 21/100 [03:51<14:02, 10.66s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.1532 sec (91.45% of total)\n",
            "softmax_nuc    : 0.0186 sec (0.79% of total)\n",
            "boosting       : 0.0595 sec (2.53% of total)\n",
            "sampling_topk  : 0.0195 sec (0.83% of total)\n",
            "TOTAL          : 2.3544 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 22%|██▏       | 22/100 [03:58<12:33,  9.65s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6799 sec (91.62% of total)\n",
            "softmax_nuc    : 0.0250 sec (0.85% of total)\n",
            "boosting       : 0.0637 sec (2.18% of total)\n",
            "sampling_topk  : 0.0259 sec (0.89% of total)\n",
            "TOTAL          : 2.9249 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 23%|██▎       | 23/100 [04:06<11:28,  8.94s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.9209 sec (91.65% of total)\n",
            "softmax_nuc    : 0.0848 sec (0.98% of total)\n",
            "boosting       : 0.1056 sec (1.22% of total)\n",
            "sampling_topk  : 0.0888 sec (1.03% of total)\n",
            "TOTAL          : 8.6421 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 24%|██▍       | 24/100 [04:20<13:16, 10.48s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.6578 sec (91.48% of total)\n",
            "softmax_nuc    : 0.0480 sec (0.94% of total)\n",
            "boosting       : 0.0798 sec (1.57% of total)\n",
            "sampling_topk  : 0.0499 sec (0.98% of total)\n",
            "TOTAL          : 5.0916 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 25%|██▌       | 25/100 [04:31<13:37, 10.90s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.3056 sec (91.76% of total)\n",
            "softmax_nuc    : 0.0204 sec (0.81% of total)\n",
            "boosting       : 0.0593 sec (2.36% of total)\n",
            "sampling_topk  : 0.0212 sec (0.85% of total)\n",
            "TOTAL          : 2.5125 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 26%|██▌       | 26/100 [04:39<12:06,  9.82s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.9778 sec (91.51% of total)\n",
            "softmax_nuc    : 0.0852 sec (0.98% of total)\n",
            "boosting       : 0.1092 sec (1.25% of total)\n",
            "sampling_topk  : 0.0896 sec (1.03% of total)\n",
            "TOTAL          : 8.7184 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 27%|██▋       | 27/100 [04:59<15:36, 12.83s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.0292 sec (91.41% of total)\n",
            "softmax_nuc    : 0.0175 sec (0.79% of total)\n",
            "boosting       : 0.0577 sec (2.60% of total)\n",
            "sampling_topk  : 0.0187 sec (0.84% of total)\n",
            "TOTAL          : 2.2198 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 28%|██▊       | 28/100 [05:04<12:52, 10.73s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.8706 sec (91.66% of total)\n",
            "softmax_nuc    : 0.1071 sec (0.99% of total)\n",
            "boosting       : 0.1191 sec (1.11% of total)\n",
            "sampling_topk  : 0.1120 sec (1.04% of total)\n",
            "TOTAL          : 10.7689 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 29%|██▉       | 29/100 [05:19<14:10, 11.98s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6343 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0246 sec (0.86% of total)\n",
            "boosting       : 0.0610 sec (2.12% of total)\n",
            "sampling_topk  : 0.0258 sec (0.90% of total)\n",
            "TOTAL          : 2.8737 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 30%|███       | 30/100 [05:26<12:07, 10.40s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.9710 sec (91.94% of total)\n",
            "softmax_nuc    : 0.0625 sec (0.96% of total)\n",
            "boosting       : 0.0846 sec (1.30% of total)\n",
            "sampling_topk  : 0.0658 sec (1.01% of total)\n",
            "TOTAL          : 6.4944 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 31%|███       | 31/100 [05:38<12:22, 10.77s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.8971 sec (91.60% of total)\n",
            "softmax_nuc    : 0.0383 sec (0.90% of total)\n",
            "boosting       : 0.0734 sec (1.73% of total)\n",
            "sampling_topk  : 0.0404 sec (0.95% of total)\n",
            "TOTAL          : 4.2544 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 32%|███▏      | 32/100 [05:47<11:45, 10.38s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.3287 sec (91.67% of total)\n",
            "softmax_nuc    : 0.1118 sec (0.99% of total)\n",
            "boosting       : 0.1225 sec (1.09% of total)\n",
            "sampling_topk  : 0.1182 sec (1.05% of total)\n",
            "TOTAL          : 11.2668 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 33%|███▎      | 33/100 [06:17<17:57, 16.08s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.8593 sec (91.68% of total)\n",
            "softmax_nuc    : 0.0497 sec (0.94% of total)\n",
            "boosting       : 0.0782 sec (1.48% of total)\n",
            "sampling_topk  : 0.0527 sec (0.99% of total)\n",
            "TOTAL          : 5.3003 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 34%|███▍      | 34/100 [06:28<16:07, 14.66s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.8319 sec (91.76% of total)\n",
            "softmax_nuc    : 0.1064 sec (0.99% of total)\n",
            "boosting       : 0.1177 sec (1.10% of total)\n",
            "sampling_topk  : 0.1118 sec (1.04% of total)\n",
            "TOTAL          : 10.7153 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 35%|███▌      | 35/100 [06:49<18:05, 16.70s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.9567 sec (91.74% of total)\n",
            "softmax_nuc    : 0.0854 sec (0.99% of total)\n",
            "boosting       : 0.1036 sec (1.19% of total)\n",
            "sampling_topk  : 0.0896 sec (1.03% of total)\n",
            "TOTAL          : 8.6730 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 36%|███▌      | 36/100 [07:06<17:42, 16.59s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.9544 sec (91.57% of total)\n",
            "softmax_nuc    : 0.0276 sec (0.85% of total)\n",
            "boosting       : 0.0662 sec (2.05% of total)\n",
            "sampling_topk  : 0.0290 sec (0.90% of total)\n",
            "TOTAL          : 3.2265 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 37%|███▋      | 37/100 [07:17<15:40, 14.93s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.4919 sec (91.53% of total)\n",
            "softmax_nuc    : 0.0572 sec (0.95% of total)\n",
            "boosting       : 0.0871 sec (1.45% of total)\n",
            "sampling_topk  : 0.0598 sec (1.00% of total)\n",
            "TOTAL          : 6.0003 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 38%|███▊      | 38/100 [07:31<15:17, 14.80s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.9653 sec (91.58% of total)\n",
            "softmax_nuc    : 0.0851 sec (0.98% of total)\n",
            "boosting       : 0.1075 sec (1.24% of total)\n",
            "sampling_topk  : 0.0895 sec (1.03% of total)\n",
            "TOTAL          : 8.6980 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 39%|███▉      | 39/100 [07:49<15:57, 15.69s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.6132 sec (91.63% of total)\n",
            "softmax_nuc    : 0.0126 sec (0.72% of total)\n",
            "boosting       : 0.0534 sec (3.03% of total)\n",
            "sampling_topk  : 0.0135 sec (0.77% of total)\n",
            "TOTAL          : 1.7606 sec\n",
            "\n",
            "💾 Saved 40 examples so far\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 40%|████      | 40/100 [07:55<12:39, 12.67s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=== Partial Evaluation (avg over all vs. best-per-example) ===\n",
            " case  F1_avg  F1_best  BLEU_avg  BLEU_best  ROUGE-L_avg  ROUGE-L_best\n",
            "case1  0.2512   0.5062    0.0489     0.1206       0.3208        0.5668\n",
            "case2  0.3548   0.5814    0.0741     0.1759       0.4326        0.6419\n",
            "case3  0.0667   0.1282    0.0112     0.0277       0.0749        0.1356\n",
            "case4  0.4702   0.5357    0.0931     0.1483       0.5738        0.6308\n",
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.2522 sec (91.85% of total)\n",
            "softmax_nuc    : 0.0658 sec (0.97% of total)\n",
            "boosting       : 0.0869 sec (1.28% of total)\n",
            "sampling_topk  : 0.0693 sec (1.02% of total)\n",
            "TOTAL          : 6.8066 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 41%|████      | 41/100 [08:07<12:22, 12.59s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.7086 sec (91.59% of total)\n",
            "softmax_nuc    : 0.0362 sec (0.89% of total)\n",
            "boosting       : 0.0723 sec (1.78% of total)\n",
            "sampling_topk  : 0.0381 sec (0.94% of total)\n",
            "TOTAL          : 4.0492 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 42%|████▏     | 42/100 [08:17<11:28, 11.87s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.1329 sec (91.56% of total)\n",
            "softmax_nuc    : 0.0184 sec (0.79% of total)\n",
            "boosting       : 0.0588 sec (2.52% of total)\n",
            "sampling_topk  : 0.0195 sec (0.84% of total)\n",
            "TOTAL          : 2.3296 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 43%|████▎     | 43/100 [08:23<09:32, 10.04s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.1762 sec (91.53% of total)\n",
            "softmax_nuc    : 0.0418 sec (0.92% of total)\n",
            "boosting       : 0.0764 sec (1.67% of total)\n",
            "sampling_topk  : 0.0442 sec (0.97% of total)\n",
            "TOTAL          : 4.5628 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 44%|████▍     | 44/100 [08:35<09:48, 10.50s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.0677 sec (91.54% of total)\n",
            "softmax_nuc    : 0.0180 sec (0.80% of total)\n",
            "boosting       : 0.0578 sec (2.56% of total)\n",
            "sampling_topk  : 0.0188 sec (0.83% of total)\n",
            "TOTAL          : 2.2589 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 45%|████▌     | 45/100 [08:40<08:19,  9.08s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.6213 sec (91.63% of total)\n",
            "softmax_nuc    : 0.0924 sec (0.98% of total)\n",
            "boosting       : 0.1097 sec (1.17% of total)\n",
            "sampling_topk  : 0.0972 sec (1.03% of total)\n",
            "TOTAL          : 9.4085 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 46%|████▌     | 46/100 [09:00<11:00, 12.22s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6872 sec (91.59% of total)\n",
            "softmax_nuc    : 0.0246 sec (0.84% of total)\n",
            "boosting       : 0.0627 sec (2.14% of total)\n",
            "sampling_topk  : 0.0263 sec (0.90% of total)\n",
            "TOTAL          : 2.9339 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 47%|████▋     | 47/100 [09:08<09:36, 10.87s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.3276 sec (91.58% of total)\n",
            "softmax_nuc    : 0.0776 sec (0.97% of total)\n",
            "boosting       : 0.1026 sec (1.28% of total)\n",
            "sampling_topk  : 0.0813 sec (1.02% of total)\n",
            "TOTAL          : 8.0009 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 48%|████▊     | 48/100 [09:25<11:12, 12.94s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.1739 sec (91.50% of total)\n",
            "softmax_nuc    : 0.0422 sec (0.93% of total)\n",
            "boosting       : 0.0751 sec (1.65% of total)\n",
            "sampling_topk  : 0.0452 sec (0.99% of total)\n",
            "TOTAL          : 4.5616 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 49%|████▉     | 49/100 [09:37<10:38, 12.52s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.1576 sec (91.78% of total)\n",
            "softmax_nuc    : 0.0642 sec (0.96% of total)\n",
            "boosting       : 0.0895 sec (1.33% of total)\n",
            "sampling_topk  : 0.0674 sec (1.00% of total)\n",
            "TOTAL          : 6.7093 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 50%|█████     | 50/100 [09:51<10:49, 12.98s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.4734 sec (91.71% of total)\n",
            "softmax_nuc    : 0.0684 sec (0.97% of total)\n",
            "boosting       : 0.0928 sec (1.32% of total)\n",
            "sampling_topk  : 0.0722 sec (1.02% of total)\n",
            "TOTAL          : 7.0586 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 51%|█████     | 51/100 [10:02<10:04, 12.33s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.7930 sec (91.56% of total)\n",
            "softmax_nuc    : 0.0606 sec (0.96% of total)\n",
            "boosting       : 0.0897 sec (1.42% of total)\n",
            "sampling_topk  : 0.0637 sec (1.01% of total)\n",
            "TOTAL          : 6.3270 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 52%|█████▏    | 52/100 [10:14<09:51, 12.33s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.4109 sec (91.50% of total)\n",
            "softmax_nuc    : 0.0214 sec (0.81% of total)\n",
            "boosting       : 0.0620 sec (2.35% of total)\n",
            "sampling_topk  : 0.0227 sec (0.86% of total)\n",
            "TOTAL          : 2.6350 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 53%|█████▎    | 53/100 [10:21<08:17, 10.58s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.5863 sec (91.54% of total)\n",
            "softmax_nuc    : 0.0466 sec (0.93% of total)\n",
            "boosting       : 0.0795 sec (1.59% of total)\n",
            "sampling_topk  : 0.0492 sec (0.98% of total)\n",
            "TOTAL          : 5.0099 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 54%|█████▍    | 54/100 [10:30<07:56, 10.36s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.9715 sec (91.62% of total)\n",
            "softmax_nuc    : 0.0285 sec (0.88% of total)\n",
            "boosting       : 0.0630 sec (1.94% of total)\n",
            "sampling_topk  : 0.0300 sec (0.92% of total)\n",
            "TOTAL          : 3.2433 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 55%|█████▌    | 55/100 [10:37<06:51,  9.15s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.7964 sec (91.70% of total)\n",
            "softmax_nuc    : 0.0384 sec (0.93% of total)\n",
            "boosting       : 0.0689 sec (1.66% of total)\n",
            "sampling_topk  : 0.0405 sec (0.98% of total)\n",
            "TOTAL          : 4.1399 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 56%|█████▌    | 56/100 [10:48<07:06,  9.69s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.3097 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0438 sec (0.93% of total)\n",
            "boosting       : 0.0757 sec (1.61% of total)\n",
            "sampling_topk  : 0.0456 sec (0.97% of total)\n",
            "TOTAL          : 4.7016 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 57%|█████▋    | 57/100 [10:57<06:45,  9.42s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.7223 sec (91.65% of total)\n",
            "softmax_nuc    : 0.0717 sec (0.98% of total)\n",
            "boosting       : 0.0936 sec (1.28% of total)\n",
            "sampling_topk  : 0.0754 sec (1.03% of total)\n",
            "TOTAL          : 7.3347 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 58%|█████▊    | 58/100 [11:12<07:48, 11.16s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.8215 sec (91.51% of total)\n",
            "softmax_nuc    : 0.0726 sec (0.97% of total)\n",
            "boosting       : 0.0970 sec (1.30% of total)\n",
            "sampling_topk  : 0.0767 sec (1.03% of total)\n",
            "TOTAL          : 7.4542 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 59%|█████▉    | 59/100 [11:26<08:19, 12.18s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.5980 sec (91.42% of total)\n",
            "softmax_nuc    : 0.0711 sec (0.99% of total)\n",
            "boosting       : 0.0954 sec (1.32% of total)\n",
            "sampling_topk  : 0.0746 sec (1.03% of total)\n",
            "TOTAL          : 7.2175 sec\n",
            "\n",
            "💾 Saved 60 examples so far\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 60%|██████    | 60/100 [11:38<07:57, 11.94s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=== Partial Evaluation (avg over all vs. best-per-example) ===\n",
            " case  F1_avg  F1_best  BLEU_avg  BLEU_best  ROUGE-L_avg  ROUGE-L_best\n",
            "case1  0.2335   0.5105    0.0465     0.1232       0.3289        0.6035\n",
            "case2  0.3091   0.5382    0.0669     0.1602       0.4116        0.6397\n",
            "case3  0.0505   0.1041    0.0086     0.0230       0.0559        0.1090\n",
            "case4  0.4174   0.5649    0.0828     0.1439       0.5623        0.6757\n",
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.1581 sec (91.68% of total)\n",
            "softmax_nuc    : 0.0300 sec (0.87% of total)\n",
            "boosting       : 0.0657 sec (1.91% of total)\n",
            "sampling_topk  : 0.0317 sec (0.92% of total)\n",
            "TOTAL          : 3.4446 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 61%|██████    | 61/100 [11:45<06:50, 10.52s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.4631 sec (91.66% of total)\n",
            "softmax_nuc    : 0.0215 sec (0.80% of total)\n",
            "boosting       : 0.0620 sec (2.31% of total)\n",
            "sampling_topk  : 0.0226 sec (0.84% of total)\n",
            "TOTAL          : 2.6871 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 62%|██████▏   | 62/100 [11:53<06:11,  9.77s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.2892 sec (91.76% of total)\n",
            "softmax_nuc    : 0.0426 sec (0.91% of total)\n",
            "boosting       : 0.0753 sec (1.61% of total)\n",
            "sampling_topk  : 0.0446 sec (0.95% of total)\n",
            "TOTAL          : 4.6746 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 63%|██████▎   | 63/100 [12:02<05:57,  9.66s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.3172 sec (91.74% of total)\n",
            "softmax_nuc    : 0.0889 sec (0.98% of total)\n",
            "boosting       : 0.1052 sec (1.16% of total)\n",
            "sampling_topk  : 0.0930 sec (1.03% of total)\n",
            "TOTAL          : 9.0660 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 64%|██████▍   | 64/100 [12:19<07:05, 11.83s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.7404 sec (92.07% of total)\n",
            "softmax_nuc    : 0.0824 sec (0.98% of total)\n",
            "boosting       : 0.0927 sec (1.10% of total)\n",
            "sampling_topk  : 0.0863 sec (1.03% of total)\n",
            "TOTAL          : 8.4069 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 65%|██████▌   | 65/100 [12:34<07:28, 12.81s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.9026 sec (91.57% of total)\n",
            "softmax_nuc    : 0.1069 sec (0.99% of total)\n",
            "boosting       : 0.1197 sec (1.11% of total)\n",
            "sampling_topk  : 0.1121 sec (1.04% of total)\n",
            "TOTAL          : 10.8147 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 66%|██████▌   | 66/100 [12:55<08:39, 15.27s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.6106 sec (91.56% of total)\n",
            "softmax_nuc    : 0.0232 sec (0.81% of total)\n",
            "boosting       : 0.0626 sec (2.20% of total)\n",
            "sampling_topk  : 0.0245 sec (0.86% of total)\n",
            "TOTAL          : 2.8511 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 67%|██████▋   | 67/100 [13:04<07:18, 13.28s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.0259 sec (91.56% of total)\n",
            "softmax_nuc    : 0.0396 sec (0.90% of total)\n",
            "boosting       : 0.0752 sec (1.71% of total)\n",
            "sampling_topk  : 0.0412 sec (0.94% of total)\n",
            "TOTAL          : 4.3968 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 68%|██████▊   | 68/100 [13:15<06:42, 12.57s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.6122 sec (92.05% of total)\n",
            "softmax_nuc    : 0.0346 sec (0.88% of total)\n",
            "boosting       : 0.0669 sec (1.70% of total)\n",
            "sampling_topk  : 0.0365 sec (0.93% of total)\n",
            "TOTAL          : 3.9241 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 69%|██████▉   | 69/100 [13:23<05:47, 11.22s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.7427 sec (91.75% of total)\n",
            "softmax_nuc    : 0.0248 sec (0.83% of total)\n",
            "boosting       : 0.0627 sec (2.10% of total)\n",
            "sampling_topk  : 0.0266 sec (0.89% of total)\n",
            "TOTAL          : 2.9892 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 70%|███████   | 70/100 [13:30<05:00, 10.02s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.4452 sec (91.36% of total)\n",
            "softmax_nuc    : 0.0338 sec (0.90% of total)\n",
            "boosting       : 0.0705 sec (1.87% of total)\n",
            "sampling_topk  : 0.0358 sec (0.95% of total)\n",
            "TOTAL          : 3.7711 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 71%|███████   | 71/100 [13:39<04:40,  9.66s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.3190 sec (91.65% of total)\n",
            "softmax_nuc    : 0.1123 sec (1.00% of total)\n",
            "boosting       : 0.1202 sec (1.07% of total)\n",
            "sampling_topk  : 0.1174 sec (1.04% of total)\n",
            "TOTAL          : 11.2594 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 72%|███████▏  | 72/100 [13:59<05:57, 12.78s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.2802 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0310 sec (0.87% of total)\n",
            "boosting       : 0.0676 sec (1.89% of total)\n",
            "sampling_topk  : 0.0325 sec (0.91% of total)\n",
            "TOTAL          : 3.5784 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 73%|███████▎  | 73/100 [14:07<05:07, 11.40s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 9.3113 sec (91.52% of total)\n",
            "softmax_nuc    : 0.1001 sec (0.98% of total)\n",
            "boosting       : 0.1184 sec (1.16% of total)\n",
            "sampling_topk  : 0.1048 sec (1.03% of total)\n",
            "TOTAL          : 10.1735 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 74%|███████▍  | 74/100 [14:28<06:11, 14.30s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.5001 sec (91.36% of total)\n",
            "softmax_nuc    : 0.0339 sec (0.89% of total)\n",
            "boosting       : 0.0715 sec (1.87% of total)\n",
            "sampling_topk  : 0.0359 sec (0.94% of total)\n",
            "TOTAL          : 3.8311 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 75%|███████▌  | 75/100 [14:39<05:31, 13.25s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.7921 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0944 sec (0.98% of total)\n",
            "boosting       : 0.1068 sec (1.11% of total)\n",
            "sampling_topk  : 0.0990 sec (1.03% of total)\n",
            "TOTAL          : 9.5912 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 76%|███████▌  | 76/100 [14:57<05:50, 14.61s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.6436 sec (91.52% of total)\n",
            "softmax_nuc    : 0.0122 sec (0.68% of total)\n",
            "boosting       : 0.0559 sec (3.11% of total)\n",
            "sampling_topk  : 0.0129 sec (0.72% of total)\n",
            "TOTAL          : 1.7959 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 77%|███████▋  | 77/100 [15:03<04:37, 12.06s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.2997 sec (91.76% of total)\n",
            "softmax_nuc    : 0.0774 sec (0.97% of total)\n",
            "boosting       : 0.0956 sec (1.20% of total)\n",
            "sampling_topk  : 0.0811 sec (1.02% of total)\n",
            "TOTAL          : 7.9555 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 78%|███████▊  | 78/100 [15:15<04:25, 12.06s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.8299 sec (91.89% of total)\n",
            "softmax_nuc    : 0.0376 sec (0.90% of total)\n",
            "boosting       : 0.0686 sec (1.65% of total)\n",
            "sampling_topk  : 0.0396 sec (0.95% of total)\n",
            "TOTAL          : 4.1679 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 79%|███████▉  | 79/100 [15:25<04:00, 11.45s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 1.9914 sec (91.91% of total)\n",
            "softmax_nuc    : 0.0163 sec (0.75% of total)\n",
            "boosting       : 0.0558 sec (2.58% of total)\n",
            "sampling_topk  : 0.0173 sec (0.80% of total)\n",
            "TOTAL          : 2.1665 sec\n",
            "\n",
            "💾 Saved 80 examples so far\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 80%|████████  | 80/100 [15:32<03:19,  9.96s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=== Partial Evaluation (avg over all vs. best-per-example) ===\n",
            " case  F1_avg  F1_best  BLEU_avg  BLEU_best  ROUGE-L_avg  ROUGE-L_best\n",
            "case1  0.2437   0.5201    0.0457     0.1170       0.3243        0.5958\n",
            "case2  0.3163   0.5338    0.0636     0.1500       0.4002        0.6125\n",
            "case3  0.0506   0.1098    0.0080     0.0228       0.0540        0.1117\n",
            "case4  0.4137   0.5572    0.0765     0.1295       0.5377        0.6452\n",
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.4740 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0221 sec (0.82% of total)\n",
            "boosting       : 0.0602 sec (2.23% of total)\n",
            "sampling_topk  : 0.0232 sec (0.86% of total)\n",
            "TOTAL          : 2.6987 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 81%|████████  | 81/100 [15:39<02:54,  9.19s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.3048 sec (91.67% of total)\n",
            "softmax_nuc    : 0.0430 sec (0.92% of total)\n",
            "boosting       : 0.0753 sec (1.60% of total)\n",
            "sampling_topk  : 0.0452 sec (0.96% of total)\n",
            "TOTAL          : 4.6961 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 82%|████████▏ | 82/100 [15:48<02:44,  9.12s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.6904 sec (91.96% of total)\n",
            "softmax_nuc    : 0.0363 sec (0.90% of total)\n",
            "boosting       : 0.0671 sec (1.67% of total)\n",
            "sampling_topk  : 0.0379 sec (0.94% of total)\n",
            "TOTAL          : 4.0130 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 83%|████████▎ | 83/100 [16:00<02:52, 10.15s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.0327 sec (91.54% of total)\n",
            "softmax_nuc    : 0.1083 sec (0.99% of total)\n",
            "boosting       : 0.1205 sec (1.10% of total)\n",
            "sampling_topk  : 0.1139 sec (1.04% of total)\n",
            "TOTAL          : 10.9599 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 84%|████████▍ | 84/100 [16:25<03:51, 14.48s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.2089 sec (91.62% of total)\n",
            "softmax_nuc    : 0.0873 sec (0.97% of total)\n",
            "boosting       : 0.1045 sec (1.17% of total)\n",
            "sampling_topk  : 0.0916 sec (1.02% of total)\n",
            "TOTAL          : 8.9598 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 85%|████████▌ | 85/100 [16:43<03:51, 15.42s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 11.0538 sec (91.75% of total)\n",
            "softmax_nuc    : 0.1205 sec (1.00% of total)\n",
            "boosting       : 0.1214 sec (1.01% of total)\n",
            "sampling_topk  : 0.1260 sec (1.05% of total)\n",
            "TOTAL          : 12.0480 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 86%|████████▌ | 86/100 [17:05<04:04, 17.49s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.0554 sec (92.00% of total)\n",
            "softmax_nuc    : 0.0167 sec (0.75% of total)\n",
            "boosting       : 0.0574 sec (2.57% of total)\n",
            "sampling_topk  : 0.0177 sec (0.79% of total)\n",
            "TOTAL          : 2.2340 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 87%|████████▋ | 87/100 [17:11<03:02, 14.03s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.1135 sec (91.85% of total)\n",
            "softmax_nuc    : 0.0631 sec (0.95% of total)\n",
            "boosting       : 0.0878 sec (1.32% of total)\n",
            "sampling_topk  : 0.0662 sec (0.99% of total)\n",
            "TOTAL          : 6.6563 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 88%|████████▊ | 88/100 [17:26<02:52, 14.38s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 10.3559 sec (91.89% of total)\n",
            "softmax_nuc    : 0.1125 sec (1.00% of total)\n",
            "boosting       : 0.1151 sec (1.02% of total)\n",
            "sampling_topk  : 0.1182 sec (1.05% of total)\n",
            "TOTAL          : 11.2697 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 89%|████████▉ | 89/100 [17:50<03:09, 17.21s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 6.4364 sec (92.00% of total)\n",
            "softmax_nuc    : 0.0671 sec (0.96% of total)\n",
            "boosting       : 0.0860 sec (1.23% of total)\n",
            "sampling_topk  : 0.0704 sec (1.01% of total)\n",
            "TOTAL          : 6.9961 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 90%|█████████ | 90/100 [18:03<02:40, 16.07s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 3.6219 sec (91.83% of total)\n",
            "softmax_nuc    : 0.0346 sec (0.88% of total)\n",
            "boosting       : 0.0692 sec (1.76% of total)\n",
            "sampling_topk  : 0.0364 sec (0.92% of total)\n",
            "TOTAL          : 3.9440 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 91%|█████████ | 91/100 [18:11<02:01, 13.50s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.7904 sec (91.46% of total)\n",
            "softmax_nuc    : 0.0255 sec (0.84% of total)\n",
            "boosting       : 0.0659 sec (2.16% of total)\n",
            "sampling_topk  : 0.0266 sec (0.87% of total)\n",
            "TOTAL          : 3.0510 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 92%|█████████▏| 92/100 [18:19<01:34, 11.80s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 8.2795 sec (91.60% of total)\n",
            "softmax_nuc    : 0.0880 sec (0.97% of total)\n",
            "boosting       : 0.1069 sec (1.18% of total)\n",
            "sampling_topk  : 0.0926 sec (1.02% of total)\n",
            "TOTAL          : 9.0391 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 93%|█████████▎| 93/100 [18:37<01:36, 13.74s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 5.2614 sec (91.62% of total)\n",
            "softmax_nuc    : 0.0537 sec (0.94% of total)\n",
            "boosting       : 0.0819 sec (1.43% of total)\n",
            "sampling_topk  : 0.0568 sec (0.99% of total)\n",
            "TOTAL          : 5.7426 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 94%|█████████▍| 94/100 [18:50<01:20, 13.38s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.8720 sec (91.50% of total)\n",
            "softmax_nuc    : 0.0266 sec (0.85% of total)\n",
            "boosting       : 0.0656 sec (2.09% of total)\n",
            "sampling_topk  : 0.0278 sec (0.89% of total)\n",
            "TOTAL          : 3.1390 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 95%|█████████▌| 95/100 [18:58<00:59, 11.83s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 4.8353 sec (91.59% of total)\n",
            "softmax_nuc    : 0.0489 sec (0.93% of total)\n",
            "boosting       : 0.0808 sec (1.53% of total)\n",
            "sampling_topk  : 0.0512 sec (0.97% of total)\n",
            "TOTAL          : 5.2790 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 96%|█████████▌| 96/100 [19:09<00:46, 11.63s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.1223 sec (91.64% of total)\n",
            "softmax_nuc    : 0.0176 sec (0.76% of total)\n",
            "boosting       : 0.0590 sec (2.55% of total)\n",
            "sampling_topk  : 0.0187 sec (0.81% of total)\n",
            "TOTAL          : 2.3159 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 97%|█████████▋| 97/100 [19:18<00:32, 10.85s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.9501 sec (91.61% of total)\n",
            "softmax_nuc    : 0.0271 sec (0.84% of total)\n",
            "boosting       : 0.0666 sec (2.07% of total)\n",
            "sampling_topk  : 0.0286 sec (0.89% of total)\n",
            "TOTAL          : 3.2205 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 98%|█████████▊| 98/100 [19:27<00:20, 10.34s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 2.7721 sec (91.51% of total)\n",
            "softmax_nuc    : 0.0253 sec (0.84% of total)\n",
            "boosting       : 0.0653 sec (2.16% of total)\n",
            "sampling_topk  : 0.0269 sec (0.89% of total)\n",
            "TOTAL          : 3.0295 sec\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\r 99%|█████████▉| 99/100 [19:36<00:10, 10.02s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "===== Timing Report =====\n",
            "forward        : 7.3091 sec (91.60% of total)\n",
            "softmax_nuc    : 0.0776 sec (0.97% of total)\n",
            "boosting       : 0.1001 sec (1.25% of total)\n",
            "sampling_topk  : 0.0811 sec (1.02% of total)\n",
            "TOTAL          : 7.9795 sec\n",
            "\n",
            "💾 Saved 100 examples so far\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 100/100 [19:53<00:00, 11.93s/it]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "=== Partial Evaluation (avg over all vs. best-per-example) ===\n",
            " case  F1_avg  F1_best  BLEU_avg  BLEU_best  ROUGE-L_avg  ROUGE-L_best\n",
            "case1  0.2345   0.5039    0.0435     0.1105       0.3110        0.5836\n",
            "case2  0.3241   0.5565    0.0679     0.1539       0.4021        0.6240\n",
            "case3  0.0482   0.1079    0.0073     0.0208       0.0538        0.1142\n",
            "case4  0.4174   0.5667    0.0821     0.1347       0.5344        0.6555\n",
            "✅ Finished, results at /content/out/results.jsonl and /content/out/results.txt\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# ============== your helper functions (top_p_sampling, JSD, TVD, StopOnTokens, etc.)\n",
        "\n",
        "def top_p_sampling(probs: torch.Tensor, top_p: float = 0.9, eps: float = 1e-12) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Apply pure top-p (nucleus) filtering on a probability distribution.\n",
        "\n",
        "    Args:\n",
        "        probs: Tensor [..., vocab_size], already normalized (sum = 1)\n",
        "        top_p: cumulative probability cutoff\n",
        "        eps: numerical stability constant\n",
        "\n",
        "    Returns:\n",
        "        masked_probs: Tensor [..., vocab_size] with nucleus probs renormalized\n",
        "    \"\"\"\n",
        "\n",
        "    # If 1D, add batch dimension\n",
        "    if probs.dim() == 1:\n",
        "        probs = probs.unsqueeze(0)\n",
        "        squeeze_back = True\n",
        "    else:\n",
        "        squeeze_back = False\n",
        "\n",
        "    # Sort\n",
        "    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
        "    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n",
        "\n",
        "    # Build mask\n",
        "    sorted_indices_to_remove = cumulative_probs > top_p\n",
        "    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n",
        "    sorted_indices_to_remove[..., 0] = 0\n",
        "\n",
        "    # Scatter back\n",
        "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
        "\n",
        "    # Mask\n",
        "    masked_probs = probs.masked_fill(indices_to_remove, 0.0)\n",
        "\n",
        "    # Renormalize\n",
        "    masked_probs = masked_probs / (masked_probs.sum(dim=-1, keepdim=True) + eps)\n",
        "\n",
        "    if squeeze_back:\n",
        "        masked_probs = masked_probs.squeeze(0)\n",
        "\n",
        "    return masked_probs\n",
        "\n",
        "\n",
        "def jensen_shannon_distance(p, q, eps=1e-12):\n",
        "    # do math in float32 for stability\n",
        "    p32 = p.float().clamp_min(eps)\n",
        "    q32 = q.float().clamp_min(eps)\n",
        "    m = 0.5 * (p32 + q32)\n",
        "\n",
        "    kl_pm = torch.sum(p32 * (torch.log(p32) - torch.log(m)), dim=-1)\n",
        "    kl_qm = torch.sum(q32 * (torch.log(q32) - torch.log(m)), dim=-1)\n",
        "    jsd = 0.5 * (kl_pm + kl_qm)\n",
        "    return torch.sqrt(jsd)   # stays in float32\n",
        "\n",
        "\n",
        "def total_variation_distance(p, q):\n",
        "    return 0.5 * torch.sum(torch.abs(p - q))\n",
        "\n",
        "\n",
        "def compute_distance(p_full, p_short, metric=\"jsd\"):\n",
        "    if metric == \"jsd\":\n",
        "        return jensen_shannon_distance(p_full, p_short)\n",
        "    elif metric == \"tvd\":\n",
        "        return total_variation_distance(p_full, p_short)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown distance metric: {metric}\")\n",
        "\n",
        "\n",
        "# === Custom stopping ===\n",
        "# TODO: CHECK THIS ,\n",
        "class StopOnTokens(StoppingCriteria):\n",
        "    def __init__(self, tokenizer, stop_strs=[\"\\nQ:\", \"\\nA:\", \"\\n\"]):\n",
        "        self.tokenizer = tokenizer\n",
        "        self.stop_ids = [tokenizer.eos_token_id]\n",
        "        self.stop_strs = stop_strs\n",
        "\n",
        "    def __call__(self, input_ids, scores, **kwargs):\n",
        "        if input_ids[0, -1].item() in self.stop_ids:\n",
        "            return True\n",
        "        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)\n",
        "        for s in self.stop_strs:\n",
        "            if text.endswith(s):\n",
        "                return True\n",
        "        return False\n",
        "\n",
        "def generate_with_longcontext_stepwise(model, tokenizer, context_ids,\n",
        "                                       max_new_tokens=64,\n",
        "                                       top_p=0.9,\n",
        "                                       num_return_sequences=5,\n",
        "                                       temperature=1.0,\n",
        "                                       log_topk=20):\n",
        "    \"\"\"\n",
        "    Long-context baseline: token-by-token generation with KV caching.\n",
        "    Uses same structure as boosting for fairness.\n",
        "    \"\"\"\n",
        "    all_sequences, all_logs = [], []\n",
        "    stop_checker = StopOnTokens(tokenizer)\n",
        "\n",
        "    for seq_i in range(num_return_sequences):\n",
        "        generated = []\n",
        "        step_logs = []\n",
        "\n",
        "        # === initial forward to build cache ===\n",
        "        with torch.no_grad():\n",
        "            out_full = model(context_ids, use_cache=True)\n",
        "            logits_full = out_full.logits\n",
        "            past_full = out_full.past_key_values\n",
        "\n",
        "        ctx = context_ids.clone()\n",
        "\n",
        "        for step in range(max_new_tokens):\n",
        "            # probs from last token\n",
        "            p_full = F.softmax(logits_full[0, -1] / temperature, dim=-1)\n",
        "\n",
        "            # 🔹 nucleus filtering with your function\n",
        "            probs_nuc = top_p_sampling(p_full, top_p=top_p)\n",
        "\n",
        "            # sample next token\n",
        "            next_token = torch.multinomial(probs_nuc, 1).view(1, 1).to(ctx.device)\n",
        "            next_prob = probs_nuc[next_token]\n",
        "\n",
        "            # log info\n",
        "            topk_probs, topk_ids = torch.topk(p_full, log_topk)\n",
        "            step_logs.append({\n",
        "                \"step\": step,\n",
        "                \"next_token\": int(next_token.item()),\n",
        "                \"decoded_next_token\": tokenizer.decode([next_token.item()]).strip(),\n",
        "                \"next_token_prob\": float(next_prob.item()),\n",
        "                \"topk\": {\n",
        "                    tokenizer.decode([tid.item()]): float(val.item())\n",
        "                    for tid, val in zip(topk_ids, topk_probs)\n",
        "                }\n",
        "            })\n",
        "\n",
        "            # === incremental forward with cache ===\n",
        "            with torch.no_grad():\n",
        "                out_full = model(next_token, past_key_values=past_full, use_cache=True)\n",
        "                logits_full = out_full.logits\n",
        "                past_full = out_full.past_key_values\n",
        "\n",
        "            # update context + stop check\n",
        "            ctx = torch.cat([ctx, next_token], dim=1)\n",
        "            generated.append(next_token.item())\n",
        "            if stop_checker(ctx, None):\n",
        "                break\n",
        "\n",
        "        all_sequences.append(tokenizer.decode(generated, skip_special_tokens=True).strip())\n",
        "        all_logs.append(step_logs)\n",
        "\n",
        "    return all_sequences, all_logs\n",
        "\n",
        "\n",
        "\n",
        "# === Case 3: short-context nucleus\n",
        "def generate_with_nucleus_short(model, tokenizer, context_ids,\n",
        "                                short_len=32,\n",
        "                                max_new_tokens=64,\n",
        "                                top_p=0.9,\n",
        "                                num_return_sequences=10,\n",
        "                                temperature=1.0,\n",
        "                                log_topk=20):\n",
        "    \"\"\"Case 3 baseline: nucleus sampling from a truncated prompt.\n",
        "\n",
        "    Only the last `short_len` tokens of `context_ids` are passed to\n",
        "    `model.generate`; each sampled continuation is decoded without the prompt.\n",
        "\n",
        "    Args:\n",
        "        model: causal LM exposing the Hugging Face `generate` API.\n",
        "        tokenizer: used for decoding and for the EOS id (also used as pad id).\n",
        "        context_ids: prompt token ids, indexed as [row, position].\n",
        "        short_len: number of trailing prompt tokens to keep. Note that 0 keeps\n",
        "            the whole prompt, because a `-0:` slice starts at the beginning.\n",
        "        max_new_tokens: generation budget per sample.\n",
        "        top_p: nucleus probability mass.\n",
        "        num_return_sequences: number of samples to draw.\n",
        "        temperature: sampling temperature.\n",
        "        log_topk: unused here; accepted so the signature matches the other decoders.\n",
        "\n",
        "    Returns:\n",
        "        (sequences, logs): the decoded, stripped continuations and an empty list\n",
        "        (this decoder records no per-step logs).\n",
        "    \"\"\"\n",
        "    short_ids = context_ids[:, -short_len:]\n",
        "    stop_criteria = StoppingCriteriaList([StopOnTokens(tokenizer)])\n",
        "    outputs = model.generate(\n",
        "        short_ids,\n",
        "        # pad_token_id equals eos_token_id below, so the mask cannot be inferred\n",
        "        # from the ids. The prompt is passed unpadded, hence an all-ones mask.\n",
        "        attention_mask=torch.ones_like(short_ids),\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        do_sample=True,\n",
        "        top_p=top_p,\n",
        "        num_return_sequences=num_return_sequences,\n",
        "        pad_token_id=tokenizer.eos_token_id,\n",
        "        eos_token_id=tokenizer.eos_token_id,\n",
        "        temperature=temperature,\n",
        "        stopping_criteria=stop_criteria,\n",
        "    )\n",
        "\n",
        "    # generate() returns prompt + continuation; drop the prompt part before decoding.\n",
        "    prompt_len = short_ids.shape[-1]\n",
        "    sequences = [\n",
        "        tokenizer.decode(out[prompt_len:], skip_special_tokens=True).strip()\n",
        "        for out in outputs\n",
        "    ]\n",
        "\n",
        "    return sequences, []\n",
        "\n",
        "\n",
        "def generate_with_boosting(model, tokenizer, context_ids,\n",
        "                           short_len=32,\n",
        "                           metric=\"jsd\",\n",
        "                           longctx_detection_thresh=0.05,\n",
        "                           max_new_tokens=64,\n",
        "                           top_p=0.9,\n",
        "                           lambda_boost=2.0,\n",
        "                           log_topk=20,\n",
        "                           num_return_sequences=10,\n",
        "                           temperature=1.0,\n",
        "                           epsilon=0.05):\n",
        "    \"\"\"Case 2: stepwise nucleus sampling that boosts long-context tokens.\n",
        "\n",
        "    At every step the next-token distribution from the full prompt is compared\n",
        "    with the one from only the last `short_len` prompt tokens (both caches are\n",
        "    extended with each sampled token). When `jensen_shannon_distance` of the two\n",
        "    nucleus-filtered distributions exceeds `longctx_detection_thresh`, tokens whose\n",
        "    nucleus probability is more than `epsilon` higher under the full context get\n",
        "    their full-context probability multiplied by `lambda_boost`; the result is\n",
        "    renormalised, nucleus-filtered again and sampled.\n",
        "\n",
        "    Args:\n",
        "        model: causal LM called as `model(ids, past_key_values=..., use_cache=True)`.\n",
        "        tokenizer: used for decoding tokens and by `StopOnTokens`.\n",
        "        context_ids: prompt token ids; only row 0 of the logits is used.\n",
        "        short_len: number of trailing prompt tokens forming the short context.\n",
        "        metric: currently ignored; `jensen_shannon_distance` is always used.\n",
        "        longctx_detection_thresh: divergence above which boosting is attempted.\n",
        "        max_new_tokens: maximum number of tokens sampled per sequence.\n",
        "        top_p: nucleus probability mass passed to `top_p_sampling`.\n",
        "        lambda_boost: multiplicative boost applied to the selected tokens.\n",
        "        log_topk: number of top tokens logged per distribution (clamped to the\n",
        "            vocabulary size).\n",
        "        num_return_sequences: number of independent samples to draw.\n",
        "        temperature: softmax temperature.\n",
        "        epsilon: minimum (full - short) nucleus probability gap for a token to be boosted.\n",
        "\n",
        "    Returns:\n",
        "        (all_sequences, all_logs): decoded, stripped answers and, per sequence, a\n",
        "        list of per-step dicts (divergence, sampled token, top-k tables, ...).\n",
        "\n",
        "    Side effects:\n",
        "        Puts `model` in eval mode and prints a timing report.\n",
        "    \"\"\"\n",
        "    all_sequences, all_logs = [], []\n",
        "    stop_checker = StopOnTokens(tokenizer)\n",
        "\n",
        "    # Wall-clock totals per phase. On GPU these are approximate because kernels\n",
        "    # may still be running when a phase is timed.\n",
        "    timers = {\n",
        "        \"forward\": 0.0,\n",
        "        \"softmax_nuc\": 0.0,\n",
        "        \"boosting\": 0.0,\n",
        "        \"sampling_topk\": 0.0,\n",
        "    }\n",
        "\n",
        "    def _topk_as_dict(token_ids, values):\n",
        "        # Decoded token text -> probability for one top-k slice (one bulk transfer).\n",
        "        return {\n",
        "            tokenizer.decode([tid]): float(val)\n",
        "            for tid, val in zip(token_ids.tolist(), values.tolist())\n",
        "        }\n",
        "\n",
        "    t_global_start = time.perf_counter()\n",
        "    model.eval()\n",
        "\n",
        "    for _ in range(num_return_sequences):\n",
        "        generated = []\n",
        "        step_logs = []\n",
        "\n",
        "        # === initial forward (initialize caches) ===\n",
        "        # Each sample starts from freshly computed caches, since both caches are\n",
        "        # extended token by token during decoding.\n",
        "        t1 = time.perf_counter()\n",
        "        with torch.no_grad():\n",
        "            out_full = model(context_ids, use_cache=True)\n",
        "            logits_full = out_full.logits\n",
        "            past_full = out_full.past_key_values\n",
        "\n",
        "            ctx_short = context_ids[:, -short_len:]\n",
        "            out_short = model(ctx_short, use_cache=True)\n",
        "            logits_short = out_short.logits\n",
        "            past_short = out_short.past_key_values\n",
        "        timers[\"forward\"] += time.perf_counter() - t1\n",
        "\n",
        "        ctx = context_ids.clone()\n",
        "\n",
        "        for step in range(max_new_tokens):\n",
        "            # === probability computation ===\n",
        "            t2 = time.perf_counter()\n",
        "            p_full = F.softmax(logits_full[0, -1] / temperature, dim=-1)\n",
        "            p_short = F.softmax(logits_short[0, -1] / temperature, dim=-1)\n",
        "            p_full_nuc = top_p_sampling(p_full, top_p)\n",
        "            p_short_nuc = top_p_sampling(p_short, top_p)\n",
        "            timers[\"softmax_nuc\"] += time.perf_counter() - t2\n",
        "\n",
        "            # === divergence + boosting ===\n",
        "            t3 = time.perf_counter()\n",
        "            divergence = jensen_shannon_distance(p_full_nuc, p_short_nuc)\n",
        "            boosted_applied = bool(divergence > longctx_detection_thresh)\n",
        "\n",
        "            probs = p_full\n",
        "            boost_mask = None\n",
        "            if boosted_applied:\n",
        "                # Tokens that the full context favours by more than epsilon.\n",
        "                boost_mask = (p_full_nuc - p_short_nuc) > epsilon\n",
        "                if boost_mask.any():\n",
        "                    probs = p_full.clone()\n",
        "                    probs[boost_mask] *= lambda_boost\n",
        "                    probs /= probs.sum()\n",
        "            timers[\"boosting\"] += time.perf_counter() - t3\n",
        "\n",
        "            # === sampling & top-k logging ===\n",
        "            t4 = time.perf_counter()\n",
        "            k = min(log_topk, p_full_nuc.shape[-1])  # topk raises if k > vocab size\n",
        "            topk_probs_full, topk_indices_full = torch.topk(p_full_nuc, k)\n",
        "            topk_probs_short, topk_indices_short = torch.topk(p_short_nuc, k)\n",
        "\n",
        "            probs_nuc = top_p_sampling(probs, top_p)\n",
        "            topk_probs_boosted, topk_indices_boosted = torch.topk(probs_nuc, k)\n",
        "            next_token = torch.multinomial(probs_nuc, 1).view(1, 1).to(ctx.device)\n",
        "            token_id = int(next_token.item())\n",
        "            next_token_prob = float(probs_nuc[token_id])\n",
        "            timers[\"sampling_topk\"] += time.perf_counter() - t4\n",
        "\n",
        "            # === step log with all details ===\n",
        "            step_logs.append({\n",
        "                \"step\": step,\n",
        "                \"divergence\": float(divergence),\n",
        "                \"boosted_applied\": boosted_applied,\n",
        "                \"next_token\": token_id,\n",
        "                \"decoded_next_token\": tokenizer.decode([token_id]).strip(),\n",
        "                \"next_token_prob\": next_token_prob,\n",
        "                \"topk_full\": _topk_as_dict(topk_indices_full, topk_probs_full),\n",
        "                \"topk_short\": _topk_as_dict(topk_indices_short, topk_probs_short),\n",
        "                \"topk_boost\": (\n",
        "                    _topk_as_dict(topk_indices_boosted, topk_probs_boosted)\n",
        "                    if boosted_applied else None\n",
        "                ),\n",
        "                \"boosted_tokens\": (\n",
        "                    [tokenizer.decode([tid]).strip()\n",
        "                     for tid in torch.nonzero(boost_mask, as_tuple=True)[0][:log_topk].tolist()]\n",
        "                    if boosted_applied else []\n",
        "                )\n",
        "            })\n",
        "\n",
        "            # update context for stop_checker\n",
        "            ctx = torch.cat([ctx, next_token], dim=1)\n",
        "            generated.append(token_id)\n",
        "\n",
        "            # Stop before advancing the caches: the logits of another forward pass\n",
        "            # would never be read once decoding ends (stop token or last step).\n",
        "            if stop_checker(ctx, None):\n",
        "                break\n",
        "            if step == max_new_tokens - 1:\n",
        "                break\n",
        "\n",
        "            # === update caches incrementally ===\n",
        "            t1 = time.perf_counter()\n",
        "            with torch.no_grad():\n",
        "                out_full = model(next_token, past_key_values=past_full, use_cache=True)\n",
        "                logits_full = out_full.logits\n",
        "                past_full = out_full.past_key_values\n",
        "\n",
        "                out_short = model(next_token, past_key_values=past_short, use_cache=True)\n",
        "                logits_short = out_short.logits\n",
        "                past_short = out_short.past_key_values\n",
        "            timers[\"forward\"] += time.perf_counter() - t1\n",
        "\n",
        "        all_sequences.append(tokenizer.decode(generated, skip_special_tokens=True).strip())\n",
        "        all_logs.append(step_logs)\n",
        "\n",
        "    total_time = time.perf_counter() - t_global_start\n",
        "    print(\"\\n===== Timing Report =====\")\n",
        "    for k, v in timers.items():\n",
        "        share = 100 * v / total_time if total_time > 0 else 0.0\n",
        "        print(f\"{k:15s}: {v:.4f} sec ({share:.2f}% of total)\")\n",
        "    print(f\"{'TOTAL':15s}: {total_time:.4f} sec\\n\")\n",
        "\n",
        "    return all_sequences, all_logs\n",
        "\n",
        "\n",
        "\n",
        "def cad_combine_logits(logits_full, logits_short, alpha=0.5, temperature=1.0):\n",
        "    \"\"\"\n",
        "    CAD adjustment at the logit level.\n",
        "    Args:\n",
        "        logits_full: [vocab_size] logits from long context\n",
        "        logits_short: [vocab_size] logits from short context\n",
        "        alpha: contrastive strength (0 = no adjustment)\n",
        "        temperature: softmax temperature\n",
        "    \"\"\"\n",
        "    adjusted_logits = ((1 + alpha) * logits_full - alpha * logits_short) / temperature\n",
        "    probs_cad = F.softmax(adjusted_logits, dim=-1)\n",
        "    return probs_cad\n",
        "\n",
        "\n",
        "\n",
        "def generate_with_cad(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    context_ids,\n",
        "    question_ids=None,\n",
        "    max_new_tokens=64,\n",
        "    top_p=0.9,\n",
        "    num_return_sequences=5,\n",
        "    temperature=1.0,\n",
        "    log_topk=20,\n",
        "    alpha=0.5,\n",
        "):\n",
        "    \"\"\"\n",
        "    Context-Aware Decoding (CAD):\n",
        "    Contrast full-context vs. question-only context.\n",
        "\n",
        "    At each step the last-position logits of both contexts are merged by\n",
        "    `cad_combine_logits`, nucleus-filtered with `top_p_sampling`, and one token is\n",
        "    sampled and appended to both caches.\n",
        "\n",
        "    Args:\n",
        "        model: causal LM called as `model(ids, past_key_values=..., use_cache=True)`.\n",
        "        tokenizer: used for decoding tokens and by `StopOnTokens`.\n",
        "        context_ids: token ids of the full prompt; only row 0 of the logits is used.\n",
        "        question_ids: token ids of the question-only prompt (required).\n",
        "        max_new_tokens: maximum number of tokens sampled per sequence.\n",
        "        top_p: nucleus probability mass.\n",
        "        num_return_sequences: number of independent samples to draw.\n",
        "        temperature: softmax temperature.\n",
        "        log_topk: number of top tokens logged per distribution (clamped to the\n",
        "            vocabulary size).\n",
        "        alpha: contrastive strength forwarded to `cad_combine_logits`.\n",
        "\n",
        "    Returns:\n",
        "        (all_sequences, all_logs): decoded, stripped answers and, per sequence, a\n",
        "        list of per-step dicts with the sampled token and top-k tables.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `question_ids` is None.\n",
        "    \"\"\"\n",
        "    # Validate up front so no forward pass is spent on a call that must fail.\n",
        "    if question_ids is None:\n",
        "        raise ValueError(\"Must provide question_ids for CAD (question-only mode).\")\n",
        "\n",
        "    all_sequences, all_logs = [], []\n",
        "    stop_checker = StopOnTokens(tokenizer)\n",
        "    model.eval()\n",
        "\n",
        "    def _topk_as_dict(token_ids, values):\n",
        "        # Decoded token text -> probability for one top-k slice (one bulk transfer).\n",
        "        return {\n",
        "            tokenizer.decode([tid]): float(val)\n",
        "            for tid, val in zip(token_ids.tolist(), values.tolist())\n",
        "        }\n",
        "\n",
        "    for _ in range(num_return_sequences):\n",
        "        generated = []\n",
        "        step_logs = []\n",
        "\n",
        "        # === Initial forward passes ===\n",
        "        # Each sample starts from freshly computed caches, since both caches are\n",
        "        # extended token by token during decoding.\n",
        "        with torch.no_grad():\n",
        "            # full context forward\n",
        "            out_full = model(context_ids, use_cache=True)\n",
        "            logits_full = out_full.logits\n",
        "            past_full = out_full.past_key_values\n",
        "\n",
        "            # short context = just question\n",
        "            out_short = model(question_ids, use_cache=True)\n",
        "            logits_short = out_short.logits\n",
        "            past_short = out_short.past_key_values\n",
        "\n",
        "        ctx = context_ids.clone()\n",
        "\n",
        "        for step in range(max_new_tokens):\n",
        "            # === Current step logits ===\n",
        "            logits_full_step = logits_full[0, -1]\n",
        "            logits_short_step = logits_short[0, -1]\n",
        "\n",
        "            # === CAD adjustment ===\n",
        "            probs_cad = cad_combine_logits(\n",
        "                logits_full_step, logits_short_step,\n",
        "                alpha=alpha, temperature=temperature\n",
        "            )\n",
        "            probs_nuc = top_p_sampling(probs_cad, top_p=top_p)\n",
        "\n",
        "            # === Sample ===\n",
        "            next_token = torch.multinomial(probs_nuc, 1).view(1, 1).to(ctx.device)\n",
        "            token_id = int(next_token.item())\n",
        "            next_prob = float(probs_nuc[token_id])\n",
        "\n",
        "            # === Logging ===\n",
        "            p_full = F.softmax(logits_full_step / temperature, dim=-1)\n",
        "            p_short = F.softmax(logits_short_step / temperature, dim=-1)\n",
        "\n",
        "            k = min(log_topk, probs_cad.shape[-1])  # topk raises if k > vocab size\n",
        "            topk_probs_full, topk_ids_full = torch.topk(p_full, k)\n",
        "            topk_probs_short, topk_ids_short = torch.topk(p_short, k)\n",
        "            topk_probs_cad, topk_ids_cad = torch.topk(probs_cad, k)\n",
        "\n",
        "            step_logs.append({\n",
        "                \"step\": step,\n",
        "                \"next_token\": token_id,\n",
        "                \"decoded_next_token\": tokenizer.decode([token_id]).strip(),\n",
        "                \"next_token_prob\": next_prob,\n",
        "                \"topk_full\": _topk_as_dict(topk_ids_full, topk_probs_full),\n",
        "                \"topk_short\": _topk_as_dict(topk_ids_short, topk_probs_short),\n",
        "                \"topk_cad\": _topk_as_dict(topk_ids_cad, topk_probs_cad),\n",
        "            })\n",
        "\n",
        "            ctx = torch.cat([ctx, next_token], dim=1)\n",
        "            generated.append(token_id)\n",
        "\n",
        "            # Stop before advancing the caches: the logits of another forward pass\n",
        "            # would never be read once decoding ends (stop token or last step).\n",
        "            if stop_checker(ctx, None):\n",
        "                break\n",
        "            if step == max_new_tokens - 1:\n",
        "                break\n",
        "\n",
        "            # === Update caches ===\n",
        "            with torch.no_grad():\n",
        "                out_full = model(next_token, past_key_values=past_full, use_cache=True)\n",
        "                logits_full = out_full.logits\n",
        "                past_full = out_full.past_key_values\n",
        "\n",
        "                out_short = model(next_token, past_key_values=past_short, use_cache=True)\n",
        "                logits_short = out_short.logits\n",
        "                past_short = out_short.past_key_values\n",
        "\n",
        "        all_sequences.append(tokenizer.decode(generated, skip_special_tokens=True).strip())\n",
        "        all_logs.append(step_logs)\n",
        "\n",
        "    return all_sequences, all_logs\n",
        "\n",
        "\n",
        "\n",
        "def format_example_text(ex, tokenizer=None):\n",
        "    \"\"\"Format one example dict into a human-readable string block.\"\"\"\n",
        "    lines = []\n",
        "    lines.append(\"=\" * 80)\n",
        "    lines.append(f\"Index: {ex.get('index', ex.get('global_index', 'N/A'))}\")\n",
        "    lines.append(\"-\" * 80)\n",
        "    lines.append(\"Context:\")\n",
        "    lines.append(ex[\"context\"].strip())\n",
        "    lines.append(\"\")\n",
        "    lines.append(\"Question:\")\n",
        "    lines.append(ex[\"question\"].strip())\n",
        "    lines.append(\"\")\n",
        "    lines.append(\"Ground Truth Answer:\")\n",
        "    lines.append(ex[\"answer\"].strip())\n",
        "    lines.append(\"\")\n",
        "\n",
        "    # Case 1\n",
        "    lines.append(\"Case 1: Vanilla Nucleus Sampling\")\n",
        "    if isinstance(ex.get(\"generated_answers_case1\"), list):\n",
        "        for j, ans in enumerate(ex[\"generated_answers_case1\"], 1):\n",
        "            lines.append(f\"  {j}. {ans}\")\n",
        "    else:\n",
        "        lines.append(\"  \" + str(ex.get(\"generated_answers_case1\", \"\")))\n",
        "    lines.append(\"\")\n",
        "\n",
        "    # Case 2\n",
        "    lines.append(\"Case 2: Boosting (epsilon filter)\")\n",
        "    if isinstance(ex.get(\"generated_answer_case2\"), list):\n",
        "        for j, ans in enumerate(ex[\"generated_answer_case2\"], 1):\n",
        "            lines.append(f\"  {j}. {ans}\")\n",
        "    else:\n",
        "        lines.append(\"  \" + str(ex.get(\"generated_answer_case2\", \"\")))\n",
        "    lines.append(\"\")\n",
        "\n",
        "    # Case 3\n",
        "    lines.append(\"Case 3: Short-Context Nucleus Sampling\")\n",
        "    if isinstance(ex.get(\"generated_answers_case3\"), list):\n",
        "        for j, ans in enumerate(ex[\"generated_answers_case3\"], 1):\n",
        "            lines.append(f\"  {j}. {ans}\")\n",
        "    else:\n",
        "        lines.append(\"  \" + str(ex.get(\"generated_answers_case3\", \"\")))\n",
        "    lines.append(\"\")\n",
        "\n",
        "    # Case 4\n",
        "    lines.append(\"Case 4: Context-Aware Decoding (CAD)\")\n",
        "    if isinstance(ex.get(\"generated_answers_case4\"), list):\n",
        "        for j, ans in enumerate(ex[\"generated_answers_case4\"], 1):\n",
        "            lines.append(f\"  {j}. {ans}\")\n",
        "    else:\n",
        "        lines.append(\"  \" + str(ex.get(\"generated_answers_case4\", \"\")))\n",
        "    lines.append(\"\")\n",
        "\n",
        "    # Boosting logs (optional)\n",
        "    if \"boosting_logs\" in ex and ex[\"boosting_logs\"]:\n",
        "        lines.append(\"Boosting Logs (Case 2):\")\n",
        "        for seq_id, seq_logs in enumerate(ex[\"boosting_logs\"], 1):\n",
        "            lines.append(f\"  Sequence {seq_id}:\")\n",
        "            for step in seq_logs:\n",
        "                tok_id = step[\"next_token\"]\n",
        "                # If tokenizer is provided, decode token string\n",
        "                if tokenizer is not None:\n",
        "                    tok_str = tokenizer.convert_ids_to_tokens([tok_id])[0]\n",
        "                else:\n",
        "                    tok_str = str(tok_id)\n",
        "                lines.append(\n",
        "                    f\"    Step {step['step']}: \"\n",
        "                    f\"div={step['divergence']:.4f}, \"\n",
        "                    f\"boosted={step['boosted_applied']}, \"\n",
        "                    f\"token_id={tok_id}, token='{tok_str}'\"\n",
        "                )\n",
        "        lines.append(\"\")\n",
        "\n",
        "    # CAD logs (optional)\n",
        "    if \"cad_logs\" in ex and ex[\"cad_logs\"]:\n",
        "        lines.append(\"CAD Logs (Case 4):\")\n",
        "        for seq_id, seq_logs in enumerate(ex[\"cad_logs\"], 1):\n",
        "            lines.append(f\"  Sequence {seq_id}:\")\n",
        "            for step in seq_logs:\n",
        "                tok_id = step[\"next_token\"]\n",
        "                tok_str = tokenizer.convert_ids_to_tokens([tok_id])[0] if tokenizer else str(tok_id)\n",
        "                lines.append(\n",
        "                    f\"    Step {step['step']}: \"\n",
        "                    f\"token_id={tok_id}, token='{tok_str}', \"\n",
        "                    f\"prob={step['next_token_prob']:.4f}\"\n",
        "                )\n",
        "        lines.append(\"\")\n",
        "\n",
        "    return \"\\n\".join(lines) + \"\\n\\n\"\n",
        "\n",
        "# ============== Minimal run loop (no shards) ==============\n",
        "# For every example: build the prompt, run the four decoding strategies,\n",
        "# buffer the JSON / text reports and flush + evaluate every BATCH_SIZE examples.\n",
        "\n",
        "DATA_PATH = \"/content/filtered_narrativeqa_10k.jsonl\"\n",
        "CAD_ALPHA = 0.5  # contrastive strength passed to generate_with_cad (Case 4)\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
        "if tokenizer.pad_token is None:\n",
        "    tokenizer.pad_token = tokenizer.eos_token\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    MODEL_ID,\n",
        "    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n",
        "    device_map=\"auto\"\n",
        ").eval()\n",
        "\n",
        "# --- load dataset (upload JSONL file in Colab) ---\n",
        "# Blank lines are skipped so a trailing newline does not break json.loads.\n",
        "with open(DATA_PATH, encoding=\"utf-8\") as f:\n",
        "    data = [json.loads(line) for line in f if line.strip()]\n",
        "\n",
        "# just a small slice for demo\n",
        "data = data[:num_examples]\n",
        "\n",
        "\n",
        "if OUT_BASE:\n",
        "    os.makedirs(OUT_BASE, exist_ok=True)\n",
        "OUT_PATH = os.path.join(OUT_BASE, \"results.jsonl\")\n",
        "OUT_TXT  = os.path.join(OUT_BASE, \"results.txt\")\n",
        "\n",
        "BATCH_SIZE = 20  # how many stories per flush/eval\n",
        "\n",
        "buffer_json, buffer_txt = [], []\n",
        "all_examples = []\n",
        "\n",
        "# Records are dumped with ensure_ascii=False, so the files are written as UTF-8 explicitly.\n",
        "with open(OUT_PATH, \"w\", encoding=\"utf-8\") as fout_json, open(OUT_TXT, \"w\", encoding=\"utf-8\") as fout_txt:\n",
        "    for i, ex in enumerate(tqdm(data)):\n",
        "        context = ex[\"context\"].strip()\n",
        "        question = ex[\"question\"].strip()\n",
        "        answer = ex[\"answer\"].strip()\n",
        "        if not answer:\n",
        "            continue\n",
        "\n",
        "        # Tensors follow the model's device instead of a hard-coded .cuda(), so the\n",
        "        # CPU fallback chosen at model load time keeps working.\n",
        "        prompt = context + \"\\n\\nQ: \" + question + \"\\nA:\"\n",
        "        input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
        "\n",
        "        question_only_prompt = question + \"\\nA:\"   # only Q for short context\n",
        "        question_ids = tokenizer(question_only_prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
        "\n",
        "        # Case 1: long-context\n",
        "        generated_answers_case1, _ = generate_with_longcontext_stepwise(\n",
        "            model, tokenizer, input_ids.clone(),\n",
        "            max_new_tokens=MAX_NEW_TOKENS,\n",
        "            top_p=NUCLEUS_P,\n",
        "            num_return_sequences=NUM_SAMPLES,\n",
        "        )\n",
        "\n",
        "        # Case 2: boosting\n",
        "        generated_answer_case2, boosting_logs = generate_with_boosting(\n",
        "            model, tokenizer, input_ids.clone(),\n",
        "            short_len=SHORT_LEN,\n",
        "            metric=DISTANCE_METRIC,\n",
        "            longctx_detection_thresh=longctx_detection_thresh,\n",
        "            max_new_tokens=MAX_NEW_TOKENS,\n",
        "            top_p=NUCLEUS_P,\n",
        "            num_return_sequences=NUM_SAMPLES,\n",
        "            lambda_boost=lambda_boost,\n",
        "            epsilon=epsilon,\n",
        "        )\n",
        "\n",
        "        # Case 3: short-context\n",
        "        generated_answers_case3, _ = generate_with_nucleus_short(\n",
        "            model, tokenizer, input_ids.clone(),\n",
        "            short_len=SHORT_LEN,\n",
        "            max_new_tokens=MAX_NEW_TOKENS,\n",
        "            top_p=NUCLEUS_P,\n",
        "            num_return_sequences=NUM_SAMPLES,\n",
        "        )\n",
        "\n",
        "        # Case 4: CAD (question-only)\n",
        "        generated_answers_case4, cad_logs = generate_with_cad(\n",
        "            model, tokenizer, input_ids.clone(),\n",
        "            question_ids=question_ids,\n",
        "            max_new_tokens=MAX_NEW_TOKENS,\n",
        "            top_p=NUCLEUS_P,\n",
        "            num_return_sequences=NUM_SAMPLES,\n",
        "            alpha=CAD_ALPHA,\n",
        "        )\n",
        "\n",
        "        # Save outputs\n",
        "        example_out = {\n",
        "            \"index\": i,\n",
        "            \"question\": question,\n",
        "            \"context\": context,\n",
        "            \"answer\": answer,\n",
        "            \"generated_answers_case1\": generated_answers_case1,\n",
        "            \"generated_answer_case2\": generated_answer_case2,\n",
        "            \"generated_answers_case3\": generated_answers_case3,\n",
        "            \"generated_answers_case4\": generated_answers_case4,\n",
        "            \"boosting_logs\": boosting_logs,\n",
        "            \"cad_logs\": cad_logs,\n",
        "        }\n",
        "\n",
        "        serialized = json.dumps(example_out, ensure_ascii=False)  # serialise once, reuse\n",
        "        buffer_json.append(serialized)\n",
        "        # Passing the tokenizer lets the text report show token strings in the logs.\n",
        "        buffer_txt.append(format_example_text(example_out, tokenizer))\n",
        "        all_examples.append(serialized)\n",
        "\n",
        "        # === Flush once every BATCH_SIZE examples ===\n",
        "        # Keyed on the buffer size, not on i: examples skipped above would\n",
        "        # otherwise make a batch boundary pass without a flush.\n",
        "        if len(buffer_json) >= BATCH_SIZE:\n",
        "            fout_json.write(\"\\n\".join(buffer_json) + \"\\n\")\n",
        "            fout_json.flush()\n",
        "            buffer_json = []\n",
        "\n",
        "            fout_txt.write(\"\".join(buffer_txt))\n",
        "            fout_txt.flush()\n",
        "            buffer_txt = []\n",
        "\n",
        "            print(f\"💾 Saved {len(all_examples)} examples so far\")\n",
        "\n",
        "            evaluate_partial(all_examples)\n",
        "\n",
        "    # final flush if leftovers\n",
        "    if buffer_json:\n",
        "        fout_json.write(\"\\n\".join(buffer_json) + \"\\n\")\n",
        "    if buffer_txt:\n",
        "        fout_txt.write(\"\".join(buffer_txt))\n",
        "\n",
        "# The last partial batch was not covered by the in-loop evaluation.\n",
        "if buffer_json:\n",
        "    evaluate_partial(all_examples)\n",
        "\n",
        "print(f\"✅ Finished, results at {OUT_PATH} and {OUT_TXT}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zbe9HQqyQFab"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "05400a2fbf074541a4496a22a6fa5018": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_09c2f611bb074dceb309a832fee6bb64",
            "max": 4,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b864b94bd86e42fcb9d7d654d324fd69",
            "value": 4
          }
        },
        "09c2f611bb074dceb309a832fee6bb64": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1e5cf8eeb8a84f4a92483fc4030436c6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2162fc48cda6465aba6bd2f9abed1ae9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_31362b8b55a748ce9c7580b14df5a71e",
            "placeholder": "​",
            "style": "IPY_MODEL_1e5cf8eeb8a84f4a92483fc4030436c6",
            "value": " 4/4 [00:05&lt;00:00,  1.10s/it]"
          }
        },
        "2af4884b51a64f9bae66fcc72c308cb8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "31362b8b55a748ce9c7580b14df5a71e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7c98acfb10cc4f0e913a9dd7d609cb20": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b1039b4a4fe44165a4197d4b9f9a9f03",
              "IPY_MODEL_05400a2fbf074541a4496a22a6fa5018",
              "IPY_MODEL_2162fc48cda6465aba6bd2f9abed1ae9"
            ],
            "layout": "IPY_MODEL_9ac7ab1c499e46218731fc237fa45933"
          }
        },
        "9ac7ab1c499e46218731fc237fa45933": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b1039b4a4fe44165a4197d4b9f9a9f03": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d5ff5750629740b4bf0b2d7d7e2fc6f7",
            "placeholder": "​",
            "style": "IPY_MODEL_2af4884b51a64f9bae66fcc72c308cb8",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "b864b94bd86e42fcb9d7d654d324fd69": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d5ff5750629740b4bf0b2d7d7e2fc6f7": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
