{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "5a4d692a01d94f7cab1511d5eec1abce": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_047397799c5d4881bac4c4a3b9e12ca0",
              "IPY_MODEL_f55b8944a3664794a4ac96677698d4bd",
              "IPY_MODEL_72520679538e4e6a8c371044f5069b7f"
            ],
            "layout": "IPY_MODEL_2aa99c87bd7f42cf99c7327f467a04c8"
          }
        },
        "047397799c5d4881bac4c4a3b9e12ca0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8a3aeabc2c0a484f90e4fbd16f99d39e",
            "placeholder": "​",
            "style": "IPY_MODEL_d2886700adc3466a8021298a0bec98b4",
            "value": "Generating train split: "
          }
        },
        "f55b8944a3664794a4ac96677698d4bd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b1722adf42a64d119158a6bdb9a253ca",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_810346fa9f914a668527dcbae242bc31",
            "value": 1
          }
        },
        "72520679538e4e6a8c371044f5069b7f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d35c8757f7bf403f9f17517e048676a9",
            "placeholder": "​",
            "style": "IPY_MODEL_8082e3b040614b92b101d07ecddd1976",
            "value": " 450/0 [00:00&lt;00:00, 19320.28 examples/s]"
          }
        },
        "2aa99c87bd7f42cf99c7327f467a04c8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8a3aeabc2c0a484f90e4fbd16f99d39e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d2886700adc3466a8021298a0bec98b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b1722adf42a64d119158a6bdb9a253ca": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "810346fa9f914a668527dcbae242bc31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d35c8757f7bf403f9f17517e048676a9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8082e3b040614b92b101d07ecddd1976": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bfddb7b6a72449fb851e89f7d920ba75": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d734c53e1e3d4374a257ffae53840465",
              "IPY_MODEL_3ba15572efdf4cd596536135c28788a4",
              "IPY_MODEL_a17f02f12d2c4d55917c7d928bc5e887"
            ],
            "layout": "IPY_MODEL_362c14c3b92b42d9a2e07820a26b4fd5"
          }
        },
        "d734c53e1e3d4374a257ffae53840465": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_82d8856b6fe64e098cf9b159bc7abba2",
            "placeholder": "​",
            "style": "IPY_MODEL_0a127db705304452a3445492f99c7163",
            "value": "Generating valid split: "
          }
        },
        "3ba15572efdf4cd596536135c28788a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8f62f575213b4f6983c68169cdbf84e9",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_954ccd098ac9410faa80643063db93d6",
            "value": 1
          }
        },
        "a17f02f12d2c4d55917c7d928bc5e887": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d1f70b0c0cd742cdad59097e7f670a49",
            "placeholder": "​",
            "style": "IPY_MODEL_5856e585b81f4c0ea4ed1ed5f526afa5",
            "value": " 50/0 [00:00&lt;00:00, 4521.67 examples/s]"
          }
        },
        "362c14c3b92b42d9a2e07820a26b4fd5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "82d8856b6fe64e098cf9b159bc7abba2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0a127db705304452a3445492f99c7163": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8f62f575213b4f6983c68169cdbf84e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "954ccd098ac9410faa80643063db93d6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d1f70b0c0cd742cdad59097e7f670a49": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5856e585b81f4c0ea4ed1ed5f526afa5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51a75fd6d0934b91959e253c29c719fa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_846089535c934ec0b9a95b8396e08db7",
              "IPY_MODEL_d05dd9d70be04d37a4be5c77ae67f31e",
              "IPY_MODEL_5db5adda0f534191b979c7eebfd05bc7"
            ],
            "layout": "IPY_MODEL_ecdc1cdfb64e47aa89042357d135e43d"
          }
        },
        "846089535c934ec0b9a95b8396e08db7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_51a12a24c340480a97fe76c5778ac3ba",
            "placeholder": "​",
            "style": "IPY_MODEL_d243ff91c9094c7a9eda97c76c4ec5ca",
            "value": "tokenizer_config.json: 100%"
          }
        },
        "d05dd9d70be04d37a4be5c77ae67f31e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4752f090513e4bc7a9d26cdc3b870192",
            "max": 26,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_9bad7b27e6a44363bfe6a14ea5879614",
            "value": 26
          }
        },
        "5db5adda0f534191b979c7eebfd05bc7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9c99bf3220bb4bda8bcb1090256213e4",
            "placeholder": "​",
            "style": "IPY_MODEL_918f162f351d4a8da318532450686aa7",
            "value": " 26.0/26.0 [00:00&lt;00:00, 3.36kB/s]"
          }
        },
        "ecdc1cdfb64e47aa89042357d135e43d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "51a12a24c340480a97fe76c5778ac3ba": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d243ff91c9094c7a9eda97c76c4ec5ca": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4752f090513e4bc7a9d26cdc3b870192": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9bad7b27e6a44363bfe6a14ea5879614": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "9c99bf3220bb4bda8bcb1090256213e4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "918f162f351d4a8da318532450686aa7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1b0834cc60af40dc94d0757626c1d229": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8630ff9993474c008c0a7dc6bbfdc62b",
              "IPY_MODEL_4c4960aa2c9a428ab4ff9489d287e7eb",
              "IPY_MODEL_062532ef0378428a97d562e6660ac750"
            ],
            "layout": "IPY_MODEL_455c78c887894d19864bab0bb4e9d898"
          }
        },
        "8630ff9993474c008c0a7dc6bbfdc62b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_707b9b0d1c3a44b4a29f311f2522283a",
            "placeholder": "​",
            "style": "IPY_MODEL_3eb28343a2e64de9b53edc959f470357",
            "value": "config.json: 100%"
          }
        },
        "4c4960aa2c9a428ab4ff9489d287e7eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bb701c7575a04c578b5940e75b38e881",
            "max": 665,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_695ea302e91f44089a4d410704af00cf",
            "value": 665
          }
        },
        "062532ef0378428a97d562e6660ac750": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c8eb6d335baf429b94b11191d248e3b5",
            "placeholder": "​",
            "style": "IPY_MODEL_15a51ec133734a78a95ad40cf0502226",
            "value": " 665/665 [00:00&lt;00:00, 87.0kB/s]"
          }
        },
        "455c78c887894d19864bab0bb4e9d898": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "707b9b0d1c3a44b4a29f311f2522283a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3eb28343a2e64de9b53edc959f470357": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bb701c7575a04c578b5940e75b38e881": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "695ea302e91f44089a4d410704af00cf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c8eb6d335baf429b94b11191d248e3b5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "15a51ec133734a78a95ad40cf0502226": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "65449b3640104062a89ce4b6e8b4c3a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_53d3d2da1cab45debfe909b386638b31",
              "IPY_MODEL_41a35a6af6af4b8fb1d7bcfd6072f01d",
              "IPY_MODEL_9fdaa541bf4a469394188c7644b74ad6"
            ],
            "layout": "IPY_MODEL_398fe1a10f1947dea1af1b215d255c96"
          }
        },
        "53d3d2da1cab45debfe909b386638b31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_98275add997c4357b2d1551516c04e9d",
            "placeholder": "​",
            "style": "IPY_MODEL_b79781a5c0554b29860910fde297355c",
            "value": "vocab.json: 100%"
          }
        },
        "41a35a6af6af4b8fb1d7bcfd6072f01d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_051005a3267a45599a9a0dcacdde540e",
            "max": 1042301,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b87845b2213745ef93503617bdc1c1c4",
            "value": 1042301
          }
        },
        "9fdaa541bf4a469394188c7644b74ad6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ac6f06611bcf4c25ab9c53d8f3d8c864",
            "placeholder": "​",
            "style": "IPY_MODEL_22c873a1e9014e4089882952f1e295d9",
            "value": " 1.04M/1.04M [00:00&lt;00:00, 6.53MB/s]"
          }
        },
        "398fe1a10f1947dea1af1b215d255c96": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "98275add997c4357b2d1551516c04e9d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b79781a5c0554b29860910fde297355c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "051005a3267a45599a9a0dcacdde540e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b87845b2213745ef93503617bdc1c1c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ac6f06611bcf4c25ab9c53d8f3d8c864": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "22c873a1e9014e4089882952f1e295d9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a457850460284d6881cde372fcc92cc3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7baa8024ac31428da1c859534db313e7",
              "IPY_MODEL_5237f210856e40628af19cfa04669a63",
              "IPY_MODEL_5f7b8bf1f75f4ad68f86116f417c4c6b"
            ],
            "layout": "IPY_MODEL_002b709d7d1241a48e3816a700c4269c"
          }
        },
        "7baa8024ac31428da1c859534db313e7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5792e92d475244f4a5486b7679e628c5",
            "placeholder": "​",
            "style": "IPY_MODEL_389c9dc8a0b7485a94c81c77ffec85d3",
            "value": "merges.txt: 100%"
          }
        },
        "5237f210856e40628af19cfa04669a63": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4cf7d429d27349d293ce98b6a5c6bc23",
            "max": 456318,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_43b9d89923134e3689b845ebfb07f00b",
            "value": 456318
          }
        },
        "5f7b8bf1f75f4ad68f86116f417c4c6b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1b97176337744550a48d7dbb0043159a",
            "placeholder": "​",
            "style": "IPY_MODEL_ba70d4ade124424286868abe8ae49b11",
            "value": " 456k/456k [00:00&lt;00:00, 27.0MB/s]"
          }
        },
        "002b709d7d1241a48e3816a700c4269c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5792e92d475244f4a5486b7679e628c5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "389c9dc8a0b7485a94c81c77ffec85d3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4cf7d429d27349d293ce98b6a5c6bc23": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "43b9d89923134e3689b845ebfb07f00b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "1b97176337744550a48d7dbb0043159a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ba70d4ade124424286868abe8ae49b11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4fa4876f1a054e41a9256f3d2b44a57a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_926dbe61ce064dcc8652a094faa6eec4",
              "IPY_MODEL_5991dd46a7a0410ea5644574fe8b4caa",
              "IPY_MODEL_58937c09c63c4c1ab35dac561cfedcdf"
            ],
            "layout": "IPY_MODEL_9aec9101c31640b2873098bb28021b21"
          }
        },
        "926dbe61ce064dcc8652a094faa6eec4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6852f79d3d3e4c76a5b5bcff2f50c76b",
            "placeholder": "​",
            "style": "IPY_MODEL_3feb1c9cd86a44e488009db1cf999bda",
            "value": "tokenizer.json: 100%"
          }
        },
        "5991dd46a7a0410ea5644574fe8b4caa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c72b5e79d6c440e6875ffb5dcb04b52b",
            "max": 1355256,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_73e96e2be3ea4666ab58e9c094c3643c",
            "value": 1355256
          }
        },
        "58937c09c63c4c1ab35dac561cfedcdf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b5bec25933f845d38b3538b45229dfea",
            "placeholder": "​",
            "style": "IPY_MODEL_c3ae153a277e4196b30a35de8bcf4729",
            "value": " 1.36M/1.36M [00:00&lt;00:00, 33.6MB/s]"
          }
        },
        "9aec9101c31640b2873098bb28021b21": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6852f79d3d3e4c76a5b5bcff2f50c76b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3feb1c9cd86a44e488009db1cf999bda": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c72b5e79d6c440e6875ffb5dcb04b52b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "73e96e2be3ea4666ab58e9c094c3643c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "b5bec25933f845d38b3538b45229dfea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c3ae153a277e4196b30a35de8bcf4729": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1bc69149c27146b79156d03073360065": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_faae8a18382b48488f8547a399c15f04",
              "IPY_MODEL_111332dfcd02423b896554e7d349aa5e",
              "IPY_MODEL_c982b34f368b42fb8e7428327c545485"
            ],
            "layout": "IPY_MODEL_be77fc5574ca4f6fbc7ce8c8b7154307"
          }
        },
        "faae8a18382b48488f8547a399c15f04": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9f271481f88a4d75b87075cc4f2b1ee9",
            "placeholder": "​",
            "style": "IPY_MODEL_86aef75c5dc446c4aeacafbbaa339f44",
            "value": "Map: 100%"
          }
        },
        "111332dfcd02423b896554e7d349aa5e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7817be8594404c6cb633b7ce5ea3e3e7",
            "max": 450,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d94138f285ee4e4ba017ae69d731cd62",
            "value": 450
          }
        },
        "c982b34f368b42fb8e7428327c545485": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a6083d2162cf404fa29893982fb2001f",
            "placeholder": "​",
            "style": "IPY_MODEL_48669efbd9f64e15825d68c2e23b470c",
            "value": " 450/450 [00:00&lt;00:00, 2036.85 examples/s]"
          }
        },
        "be77fc5574ca4f6fbc7ce8c8b7154307": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9f271481f88a4d75b87075cc4f2b1ee9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "86aef75c5dc446c4aeacafbbaa339f44": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7817be8594404c6cb633b7ce5ea3e3e7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d94138f285ee4e4ba017ae69d731cd62": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "a6083d2162cf404fa29893982fb2001f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "48669efbd9f64e15825d68c2e23b470c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7f86942e576546c4a23c866366d0bed8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_23286dddb7f04ceb943a04ab7f8175d7",
              "IPY_MODEL_0565533f89b3411494e7a553277a6199",
              "IPY_MODEL_a811f2fb0d1e4d88bab1dd6e900612aa"
            ],
            "layout": "IPY_MODEL_d3d10e665cbf42d998b4412f688873ca"
          }
        },
        "23286dddb7f04ceb943a04ab7f8175d7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_34662a3d000b445f8dab2f5d0d0f113c",
            "placeholder": "​",
            "style": "IPY_MODEL_a80350d153cf4adfb387481a2d977be5",
            "value": "Map: 100%"
          }
        },
        "0565533f89b3411494e7a553277a6199": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0e4add2bff444711995482289b2d95ba",
            "max": 50,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4e1abdd046a3424b80eec140156576f8",
            "value": 50
          }
        },
        "a811f2fb0d1e4d88bab1dd6e900612aa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_210127ed6fbe454bbb9e3835f85de443",
            "placeholder": "​",
            "style": "IPY_MODEL_941ebfae2f8e4058894d19fb90a8df92",
            "value": " 50/50 [00:00&lt;00:00, 1294.46 examples/s]"
          }
        },
        "d3d10e665cbf42d998b4412f688873ca": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "34662a3d000b445f8dab2f5d0d0f113c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a80350d153cf4adfb387481a2d977be5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "0e4add2bff444711995482289b2d95ba": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4e1abdd046a3424b80eec140156576f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "210127ed6fbe454bbb9e3835f85de443": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "941ebfae2f8e4058894d19fb90a8df92": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "381548db6ead42459da97840b0f54b16": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_6d9cbb4c99a744848496e63629aeb3e8",
              "IPY_MODEL_1b6bba6e5b4a4512a50a8c4ab3d415ac",
              "IPY_MODEL_2198c3ba0d5d49ccbdd187921fafa798"
            ],
            "layout": "IPY_MODEL_4c09e5a01450403f81df3819120d88fe"
          }
        },
        "6d9cbb4c99a744848496e63629aeb3e8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ed997e3ffd3b40e887636a9e95b903ce",
            "placeholder": "​",
            "style": "IPY_MODEL_7fe9282b7fbd46e8a1faa7d98d7ad1c0",
            "value": "Map: 100%"
          }
        },
        "1b6bba6e5b4a4512a50a8c4ab3d415ac": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_959cf7589e19440bb78d56fa7d1def8d",
            "max": 900,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7d24ab3282c14a44be2be7cde3fd56c8",
            "value": 900
          }
        },
        "2198c3ba0d5d49ccbdd187921fafa798": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4ead520c7df344918a990cf6ba8c2f3f",
            "placeholder": "​",
            "style": "IPY_MODEL_5239692c65e440a39cdf748a1f60b667",
            "value": " 900/900 [00:00&lt;00:00, 3221.51 examples/s]"
          }
        },
        "4c09e5a01450403f81df3819120d88fe": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ed997e3ffd3b40e887636a9e95b903ce": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7fe9282b7fbd46e8a1faa7d98d7ad1c0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "959cf7589e19440bb78d56fa7d1def8d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7d24ab3282c14a44be2be7cde3fd56c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4ead520c7df344918a990cf6ba8c2f3f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5239692c65e440a39cdf748a1f60b667": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b48ed565cfc04d0ba733668b130d0fb0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_78a886fe6b284c879d71a2afe43bd91a",
              "IPY_MODEL_3c2232f7f7e3460aa3c6e99785c27a57",
              "IPY_MODEL_8ca45b1f685d45909246b85832121dba"
            ],
            "layout": "IPY_MODEL_3ecb1e39af094fda937efe6de5de7993"
          }
        },
        "78a886fe6b284c879d71a2afe43bd91a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6da09df24e2a45018415b588743b4d3d",
            "placeholder": "​",
            "style": "IPY_MODEL_46f6bc4c75ab431580f94bb94def62e2",
            "value": "Map: 100%"
          }
        },
        "3c2232f7f7e3460aa3c6e99785c27a57": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fb4f4d74ce6d4eb793c699ae8aa0d2ff",
            "max": 100,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ee26f5d04c0546a1ae41b7b2205a6461",
            "value": 100
          }
        },
        "8ca45b1f685d45909246b85832121dba": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_27cb0bb0a119498392f632f07ac1d7e9",
            "placeholder": "​",
            "style": "IPY_MODEL_247c9d8367b5478cbbbf11155432d073",
            "value": " 100/100 [00:00&lt;00:00, 2262.76 examples/s]"
          }
        },
        "3ecb1e39af094fda937efe6de5de7993": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6da09df24e2a45018415b588743b4d3d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "46f6bc4c75ab431580f94bb94def62e2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "fb4f4d74ce6d4eb793c699ae8aa0d2ff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ee26f5d04c0546a1ae41b7b2205a6461": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "27cb0bb0a119498392f632f07ac1d7e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "247c9d8367b5478cbbbf11155432d073": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "09b0bebc79584050b8931b15644fb325": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_be1d4b1ed1844eab8fbaed64b20779ad",
              "IPY_MODEL_4370d56bbc9e4515bcdd108619704277",
              "IPY_MODEL_7ea66de875d84742b5c76f8504f944f9"
            ],
            "layout": "IPY_MODEL_6532a2daf0b44383903f73251a9a6e88"
          }
        },
        "be1d4b1ed1844eab8fbaed64b20779ad": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7a2c4d23243a4342b000df4532afaae3",
            "placeholder": "​",
            "style": "IPY_MODEL_c72185240b6342bf8c0c531436505c9f",
            "value": "Tokenizing: 100%"
          }
        },
        "4370d56bbc9e4515bcdd108619704277": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_77146b17f5774806a96bf8cdb457f121",
            "max": 900,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_55a126d40c754a9aa36788397f8ce6e9",
            "value": 900
          }
        },
        "7ea66de875d84742b5c76f8504f944f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_245ba20912424acd8323e3661b88394c",
            "placeholder": "​",
            "style": "IPY_MODEL_75d6f9a9f57d4881815ddfbf6e1dfb41",
            "value": " 900/900 [00:00&lt;00:00, 3346.63 examples/s]"
          }
        },
        "6532a2daf0b44383903f73251a9a6e88": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7a2c4d23243a4342b000df4532afaae3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c72185240b6342bf8c0c531436505c9f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "77146b17f5774806a96bf8cdb457f121": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "55a126d40c754a9aa36788397f8ce6e9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "245ba20912424acd8323e3661b88394c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "75d6f9a9f57d4881815ddfbf6e1dfb41": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ee69a8d253c748358186b1f26d29ec4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_38ab58a277eb43958f22ca0cee740132",
              "IPY_MODEL_8f272570cea44afa9d64581f707ba4b0",
              "IPY_MODEL_94777907c6144240897752d958245853"
            ],
            "layout": "IPY_MODEL_59e69d8f54c44a15a249b62b76bb3bcb"
          }
        },
        "38ab58a277eb43958f22ca0cee740132": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_23ef84f47b874096853669d2a4c4947c",
            "placeholder": "​",
            "style": "IPY_MODEL_b4204dd627bb4731ab9afe25ff5edd00",
            "value": "Tokenizing: 100%"
          }
        },
        "8f272570cea44afa9d64581f707ba4b0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d308b58195264256a9bf7d66afc845de",
            "max": 100,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4d55674c08994337969e10622cd70119",
            "value": 100
          }
        },
        "94777907c6144240897752d958245853": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c7f99a52d224416fae28369cbb0700b0",
            "placeholder": "​",
            "style": "IPY_MODEL_5bd14312cc6a4b5f87932bbfc9198ac5",
            "value": " 100/100 [00:00&lt;00:00, 2354.41 examples/s]"
          }
        },
        "59e69d8f54c44a15a249b62b76bb3bcb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "23ef84f47b874096853669d2a4c4947c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b4204dd627bb4731ab9afe25ff5edd00": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d308b58195264256a9bf7d66afc845de": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4d55674c08994337969e10622cd70119": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c7f99a52d224416fae28369cbb0700b0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5bd14312cc6a4b5f87932bbfc9198ac5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9757fbbfc29446e4a33e2c05d0d3e2ea": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_a76ddf3f73104ce9b64ba50364638e3e",
              "IPY_MODEL_3458ec519f264a3fa509ebf0c74fbd31",
              "IPY_MODEL_37bd6882e68f4434b7d3c37e4ea9d5d1"
            ],
            "layout": "IPY_MODEL_6804429f104d4975be4a6b5d062173a1"
          }
        },
        "a76ddf3f73104ce9b64ba50364638e3e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ac616871665d4119a33cecb55c044973",
            "placeholder": "​",
            "style": "IPY_MODEL_4f8b7e2b921f4b87a985dc08ac8973e6",
            "value": "model.safetensors: 100%"
          }
        },
        "3458ec519f264a3fa509ebf0c74fbd31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b109f435b07645438eeae722496191b2",
            "max": 548105171,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_5bf1c13096b44efb887c3e49020f7cb1",
            "value": 548105171
          }
        },
        "37bd6882e68f4434b7d3c37e4ea9d5d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0a30ac4f70a24e0387ec4854e1674433",
            "placeholder": "​",
            "style": "IPY_MODEL_c3170eaa6d6c41469ea75551da32b819",
            "value": " 548M/548M [00:02&lt;00:00, 429MB/s]"
          }
        },
        "6804429f104d4975be4a6b5d062173a1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ac616871665d4119a33cecb55c044973": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4f8b7e2b921f4b87a985dc08ac8973e6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b109f435b07645438eeae722496191b2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5bf1c13096b44efb887c3e49020f7cb1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0a30ac4f70a24e0387ec4854e1674433": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c3170eaa6d6c41469ea75551da32b819": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "357de5f04aae458983354e5cab166abb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4e51527b545b4289aea7529b54aa26ed",
              "IPY_MODEL_5625c13acd974a9b8246944ed77ea0e1",
              "IPY_MODEL_918091e6c864437fb593f39fff1c8adb"
            ],
            "layout": "IPY_MODEL_aefbb5c5c85a44fd97cba63a46554298"
          }
        },
        "4e51527b545b4289aea7529b54aa26ed": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d145dcac47a44e629aa26bb760bcf2dd",
            "placeholder": "​",
            "style": "IPY_MODEL_fad7e8b0231f42f88039697b1f69424f",
            "value": "generation_config.json: 100%"
          }
        },
        "5625c13acd974a9b8246944ed77ea0e1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7fbafffee20d4d388d2a8d499dcdff7f",
            "max": 124,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_830b421e5e8e4ae39cabe24c721101e8",
            "value": 124
          }
        },
        "918091e6c864437fb593f39fff1c8adb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_59344bac75424a67ba5b615ba429a9ee",
            "placeholder": "​",
            "style": "IPY_MODEL_345008c02dbd4d3c8a3cad943a379681",
            "value": " 124/124 [00:00&lt;00:00, 14.9kB/s]"
          }
        },
        "aefbb5c5c85a44fd97cba63a46554298": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d145dcac47a44e629aa26bb760bcf2dd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fad7e8b0231f42f88039697b1f69424f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7fbafffee20d4d388d2a8d499dcdff7f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "830b421e5e8e4ae39cabe24c721101e8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "59344bac75424a67ba5b615ba429a9ee": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "345008c02dbd4d3c8a3cad943a379681": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4ada5cf2751b4e1aaaa8c295fa02e3a2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_04aa7dcd8b7a434eb78c1a772c6d37b9",
              "IPY_MODEL_1a5805a19b814b36a0108ecb27efc193",
              "IPY_MODEL_ed2b04ccc8d74be3a4638c23a1ac11cb"
            ],
            "layout": "IPY_MODEL_e6f9246369754f3eac04e470750db8bf"
          }
        },
        "04aa7dcd8b7a434eb78c1a772c6d37b9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1e9b8a248f484015a53dacfae3de00d5",
            "placeholder": "​",
            "style": "IPY_MODEL_37fd580abb0045f58898f28db1d6e1a1",
            "value": "gen tuple: 100%"
          }
        },
        "1a5805a19b814b36a0108ecb27efc193": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5d2f6db2b87a4364b89004ed2681a303",
            "max": 300,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_dd839e0c62ab4700a38b48345f3c1901",
            "value": 300
          }
        },
        "ed2b04ccc8d74be3a4638c23a1ac11cb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_62caee6d28c6472bb387f8b665965d1c",
            "placeholder": "​",
            "style": "IPY_MODEL_85aa0e551b63437aa19c812bdc735c32",
            "value": " 300/300 [01:39&lt;00:00,  3.60it/s]"
          }
        },
        "e6f9246369754f3eac04e470750db8bf": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1e9b8a248f484015a53dacfae3de00d5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "37fd580abb0045f58898f28db1d6e1a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5d2f6db2b87a4364b89004ed2681a303": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "dd839e0c62ab4700a38b48345f3c1901": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "62caee6d28c6472bb387f8b665965d1c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "85aa0e551b63437aa19c812bdc735c32": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d06e2e74cf2946ee9860872c85e5d737": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8d87b2473ffd4974895d04a9ecee4345",
              "IPY_MODEL_5384c9886ba349f18019c576e82ed9ff",
              "IPY_MODEL_5f0542e612fd4d20a659c4eefeab93d4"
            ],
            "layout": "IPY_MODEL_29f442e792284003ae4ce244a2201be9"
          }
        },
        "8d87b2473ffd4974895d04a9ecee4345": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f4dc4538dbdb48ebb6b268f7dcd58c2e",
            "placeholder": "​",
            "style": "IPY_MODEL_d6b1368ec8c4493581d8f0964cd5b45a",
            "value": "gen tuple: 100%"
          }
        },
        "5384c9886ba349f18019c576e82ed9ff": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f9b463b4e189476793ff720d97bc262f",
            "max": 300,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c76e72dccf7d438eb97ce21dda6d713a",
            "value": 300
          }
        },
        "5f0542e612fd4d20a659c4eefeab93d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fef03dc1317b4fd7b4157793ffc867b7",
            "placeholder": "​",
            "style": "IPY_MODEL_ce3762e4bd9f46708aa8a523fb9d0a60",
            "value": " 300/300 [01:37&lt;00:00,  2.93it/s]"
          }
        },
        "29f442e792284003ae4ce244a2201be9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f4dc4538dbdb48ebb6b268f7dcd58c2e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d6b1368ec8c4493581d8f0964cd5b45a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f9b463b4e189476793ff720d97bc262f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c76e72dccf7d438eb97ce21dda6d713a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "fef03dc1317b4fd7b4157793ffc867b7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ce3762e4bd9f46708aa8a523fb9d0a60": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e67da516334d4af789088a61c88eccbe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_c13e1806f18f4f54944c5329e65a7d15",
              "IPY_MODEL_51254f50d31f48808fbaac817dee1797",
              "IPY_MODEL_61742ecb4c30442a8c25462a05c4fc26"
            ],
            "layout": "IPY_MODEL_fb51927b3c0644cb89a632350069569c"
          }
        },
        "c13e1806f18f4f54944c5329e65a7d15": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_eb9afa64f365483c86227d637b6d2a90",
            "placeholder": "​",
            "style": "IPY_MODEL_d422aeb5b2984f939e617519d0272c9b",
            "value": "gen tuple: 100%"
          }
        },
        "51254f50d31f48808fbaac817dee1797": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9ed96370923a409680d0274271a1357e",
            "max": 300,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_c3da3445e914457e971838f5b2770735",
            "value": 300
          }
        },
        "61742ecb4c30442a8c25462a05c4fc26": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_34734887f09f4f3983142aea0d21ba2f",
            "placeholder": "​",
            "style": "IPY_MODEL_0f5b7b905bd44619b0c434da9310460e",
            "value": " 300/300 [03:08&lt;00:00,  1.82it/s]"
          }
        },
        "fb51927b3c0644cb89a632350069569c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "eb9afa64f365483c86227d637b6d2a90": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d422aeb5b2984f939e617519d0272c9b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9ed96370923a409680d0274271a1357e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c3da3445e914457e971838f5b2770735": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "34734887f09f4f3983142aea0d21ba2f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0f5b7b905bd44619b0c434da9310460e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2fca38d0792c49c58978e063e285c047": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_f025f2e1d9bb45d3a262b517670821fe",
              "IPY_MODEL_f987279ed5e94ee6a71fbc878555f55b",
              "IPY_MODEL_66fc7add874b4db4ba0b772341fe08d7"
            ],
            "layout": "IPY_MODEL_edc00b7881d54f7eac9c679a4db3fb82"
          }
        },
        "f025f2e1d9bb45d3a262b517670821fe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e9cde8ed3d55478f9230b582cee12fcd",
            "placeholder": "​",
            "style": "IPY_MODEL_73c92e7d86cd48f1b3163239b98e2aed",
            "value": "gen tuple: 100%"
          }
        },
        "f987279ed5e94ee6a71fbc878555f55b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_14daaba923694c468cb5eb5633fc5be8",
            "max": 300,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_04ba18a4cb93445dbb36a8cdff27e585",
            "value": 300
          }
        },
        "66fc7add874b4db4ba0b772341fe08d7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0eb9c264ee6a49ac8e9cfad176bbb22c",
            "placeholder": "​",
            "style": "IPY_MODEL_fb4796e95f2441d9b2fc8d0507ab2e24",
            "value": " 300/300 [02:53&lt;00:00,  1.94it/s]"
          }
        },
        "edc00b7881d54f7eac9c679a4db3fb82": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e9cde8ed3d55478f9230b582cee12fcd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "73c92e7d86cd48f1b3163239b98e2aed": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "14daaba923694c468cb5eb5633fc5be8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "04ba18a4cb93445dbb36a8cdff27e585": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0eb9c264ee6a49ac8e9cfad176bbb22c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fb4796e95f2441d9b2fc8d0507ab2e24": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e094fabbbe114b1aab9749e16cc3289f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_472103c43d514d80962c987889fd4073",
              "IPY_MODEL_73b0972353ec42e89c57546a5d90fcdd",
              "IPY_MODEL_c713b709d6a4403a8ce187b57e13093c"
            ],
            "layout": "IPY_MODEL_339303633a98401bbef674665a1eee4c"
          }
        },
        "472103c43d514d80962c987889fd4073": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f56b2ca1e4ad4571a809fc9de6fee243",
            "placeholder": "​",
            "style": "IPY_MODEL_6a2fc26c54e84923ba4909982fb374d4",
            "value": "gen tuple: 100%"
          }
        },
        "73b0972353ec42e89c57546a5d90fcdd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f115dfe9be204ba8921bcc4377e329d4",
            "max": 300,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_49127d03967b4683bc46f1d722686649",
            "value": 300
          }
        },
        "c713b709d6a4403a8ce187b57e13093c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_934b7e03e8c14fc6b46f1e3e7f31f180",
            "placeholder": "​",
            "style": "IPY_MODEL_eb15376661514499ad6944fc01f154cc",
            "value": " 300/300 [02:45&lt;00:00,  1.83it/s]"
          }
        },
        "339303633a98401bbef674665a1eee4c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f56b2ca1e4ad4571a809fc9de6fee243": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6a2fc26c54e84923ba4909982fb374d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f115dfe9be204ba8921bcc4377e329d4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "49127d03967b4683bc46f1d722686649": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "934b7e03e8c14fc6b46f1e3e7f31f180": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "eb15376661514499ad6944fc01f154cc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "263e47861b534685ba526df2308b1250": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_904690f5edff4ffc82736f80b09a89fb",
              "IPY_MODEL_ca5b09f8c87f40b0a904288daf1bd65e",
              "IPY_MODEL_feca4172757643c099466934c4a1c0fd"
            ],
            "layout": "IPY_MODEL_c68dbaff763241a2bb161732902d8dfd"
          }
        },
        "904690f5edff4ffc82736f80b09a89fb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_82ebe94d17084d30ab89ac9e35883cba",
            "placeholder": "​",
            "style": "IPY_MODEL_c27d713403834026af8f245f8f7928c1",
            "value": "Generating train split: "
          }
        },
        "ca5b09f8c87f40b0a904288daf1bd65e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_79779ba0630840788fbd99aa09e49c4d",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f681a700e2d7401ea95aed4dbfbca136",
            "value": 1
          }
        },
        "feca4172757643c099466934c4a1c0fd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2a583b6d773440bbbaa5e823883ca483",
            "placeholder": "​",
            "style": "IPY_MODEL_d519d504c4294b748606f592187b1bdf",
            "value": " 1350/0 [00:00&lt;00:00, 76097.12 examples/s]"
          }
        },
        "c68dbaff763241a2bb161732902d8dfd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "82ebe94d17084d30ab89ac9e35883cba": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c27d713403834026af8f245f8f7928c1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "79779ba0630840788fbd99aa09e49c4d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "f681a700e2d7401ea95aed4dbfbca136": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "2a583b6d773440bbbaa5e823883ca483": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d519d504c4294b748606f592187b1bdf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ca0e40624d57417d817715031ee8aed2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e774d95a3c45463d9665d371809f67fe",
              "IPY_MODEL_f1ad7ea75c8447b0abc60d1ec57e3ee7",
              "IPY_MODEL_39cd2fcf02f84abc8fe8c70446ac7451"
            ],
            "layout": "IPY_MODEL_14fe07dc0e7e4a8fb8dad1704a18fc6d"
          }
        },
        "e774d95a3c45463d9665d371809f67fe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3728040eff18411dbb4420b248669df2",
            "placeholder": "​",
            "style": "IPY_MODEL_3f610dda67274ffe9cc8f815267f8f1d",
            "value": "Generating valid split: "
          }
        },
        "f1ad7ea75c8447b0abc60d1ec57e3ee7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e02bdfd5b1de41c2be4d467b18b35826",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_34c26982726047b1906c0beb53464e3c",
            "value": 1
          }
        },
        "39cd2fcf02f84abc8fe8c70446ac7451": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_649e5524bc4b45bdb61383f31a9ef628",
            "placeholder": "​",
            "style": "IPY_MODEL_a80a69d402834743bf20501273bb7940",
            "value": " 150/0 [00:00&lt;00:00, 13194.34 examples/s]"
          }
        },
        "14fe07dc0e7e4a8fb8dad1704a18fc6d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3728040eff18411dbb4420b248669df2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3f610dda67274ffe9cc8f815267f8f1d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e02bdfd5b1de41c2be4d467b18b35826": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "34c26982726047b1906c0beb53464e3c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "649e5524bc4b45bdb61383f31a9ef628": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a80a69d402834743bf20501273bb7940": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9a4e26c3007b4f71afbadaa9a9236bd7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_3c5142ba07f045b6975f28690c5ceec6",
              "IPY_MODEL_1a0552fd86cd4f3c88d2a309b4227a99",
              "IPY_MODEL_fa181fcafc6c47dd82aaca4601b764a4"
            ],
            "layout": "IPY_MODEL_adc9f1f9fc88404786040da1b74fc5ec"
          }
        },
        "3c5142ba07f045b6975f28690c5ceec6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b34021f29da54dd98ea638166e13e90d",
            "placeholder": "​",
            "style": "IPY_MODEL_ee41291817ad40869f4e33bb65116025",
            "value": "Tokenizing: 100%"
          }
        },
        "1a0552fd86cd4f3c88d2a309b4227a99": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f02b9ef437ed4f6a8d901f5b8036b16b",
            "max": 1350,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_691496ed95824d638607c7dd2d4b2796",
            "value": 1350
          }
        },
        "fa181fcafc6c47dd82aaca4601b764a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d5cc9456304a4c238300ab3018ae40d4",
            "placeholder": "​",
            "style": "IPY_MODEL_73a2488192cd430eae988ad657efaf51",
            "value": " 1350/1350 [00:00&lt;00:00, 3161.40 examples/s]"
          }
        },
        "adc9f1f9fc88404786040da1b74fc5ec": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b34021f29da54dd98ea638166e13e90d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ee41291817ad40869f4e33bb65116025": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f02b9ef437ed4f6a8d901f5b8036b16b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "691496ed95824d638607c7dd2d4b2796": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d5cc9456304a4c238300ab3018ae40d4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "73a2488192cd430eae988ad657efaf51": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ba1ddb5ced99455b986d17ef4e1211f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_66fcc4a411d14e349e3f4a2740916237",
              "IPY_MODEL_c7012ed54a844e3da7ae7cc8cbcd63a5",
              "IPY_MODEL_ced8276aeda74415b462e58dcf7b28c8"
            ],
            "layout": "IPY_MODEL_55f502b710934fecbc8d1829b146fae5"
          }
        },
        "66fcc4a411d14e349e3f4a2740916237": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e0196770f7a845f9be55129b3a9a0333",
            "placeholder": "​",
            "style": "IPY_MODEL_9467e4f9168a4951a552b74fe187ff4d",
            "value": "Tokenizing: 100%"
          }
        },
        "c7012ed54a844e3da7ae7cc8cbcd63a5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_53e6a5557aa14d3a96d12aa2f035cbd1",
            "max": 150,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_896aab17a40f44c5ab5f7c5181d60dda",
            "value": 150
          }
        },
        "ced8276aeda74415b462e58dcf7b28c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2c1f3fc5f2614e17a28c9038b17ece1c",
            "placeholder": "​",
            "style": "IPY_MODEL_1f1bac8d17564f24ac9b3fbdb8423397",
            "value": " 150/150 [00:00&lt;00:00, 2523.72 examples/s]"
          }
        },
        "55f502b710934fecbc8d1829b146fae5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e0196770f7a845f9be55129b3a9a0333": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9467e4f9168a4951a552b74fe187ff4d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "53e6a5557aa14d3a96d12aa2f035cbd1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "896aab17a40f44c5ab5f7c5181d60dda": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "2c1f3fc5f2614e17a28c9038b17ece1c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1f1bac8d17564f24ac9b3fbdb8423397": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "3ad05c31b2104c198ef674a39f6debf1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4e70421d2e894b9a9d912d94c3daccb1",
              "IPY_MODEL_e0dcd9953de24cc698dadbdf84b3f4fb",
              "IPY_MODEL_dea0cf892aa34c2fb7dab56f128e4abf"
            ],
            "layout": "IPY_MODEL_2b34250312154e34a150ab7b0c1c95de"
          }
        },
        "4e70421d2e894b9a9d912d94c3daccb1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_a668c5c60d254a1db35f410ebfa591aa",
            "placeholder": "​",
            "style": "IPY_MODEL_540bbc6bb56049fcaac97d03c5c8d7c8",
            "value": "Generating train split: "
          }
        },
        "e0dcd9953de24cc698dadbdf84b3f4fb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f3ce882707744c55abc701e16810b004",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_8101456698c3446c88b53277599fee5a",
            "value": 1
          }
        },
        "dea0cf892aa34c2fb7dab56f128e4abf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_274b1ed2ee8d448a95e8bf57597760ce",
            "placeholder": "​",
            "style": "IPY_MODEL_f5b970f7fa9f49938bb933666516e505",
            "value": " 7200/0 [00:00&lt;00:00, 326659.19 examples/s]"
          }
        },
        "2b34250312154e34a150ab7b0c1c95de": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a668c5c60d254a1db35f410ebfa591aa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "540bbc6bb56049fcaac97d03c5c8d7c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f3ce882707744c55abc701e16810b004": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "8101456698c3446c88b53277599fee5a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "274b1ed2ee8d448a95e8bf57597760ce": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f5b970f7fa9f49938bb933666516e505": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4ad39cac1d144a2c84b830e6c9baa1a3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_ea62d7989e4d491b9dfd51166793d310",
              "IPY_MODEL_b56653e8a6f54f539eaeebfa39cb8ea4",
              "IPY_MODEL_13df40ffb4c749dd8fe7b5e460d84c61"
            ],
            "layout": "IPY_MODEL_a0a85dca43404141ac79576a963317e1"
          }
        },
        "ea62d7989e4d491b9dfd51166793d310": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_de05875f55da4b248c1d7032c1686664",
            "placeholder": "​",
            "style": "IPY_MODEL_16c99617bf194a7581bb49048a4e4aff",
            "value": "Generating valid split: "
          }
        },
        "b56653e8a6f54f539eaeebfa39cb8ea4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c824e738ac614fcabf1bb096b6dd5084",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_4e7fcf7bbd904cea9891e3b9b9cbe5a0",
            "value": 1
          }
        },
        "13df40ffb4c749dd8fe7b5e460d84c61": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4066bb6227c94453a95cf2c3eed120e5",
            "placeholder": "​",
            "style": "IPY_MODEL_6b63695e162044248a4e2968c48b73a2",
            "value": " 800/0 [00:00&lt;00:00, 61762.69 examples/s]"
          }
        },
        "a0a85dca43404141ac79576a963317e1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "de05875f55da4b248c1d7032c1686664": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "16c99617bf194a7581bb49048a4e4aff": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c824e738ac614fcabf1bb096b6dd5084": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "4e7fcf7bbd904cea9891e3b9b9cbe5a0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "4066bb6227c94453a95cf2c3eed120e5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6b63695e162044248a4e2968c48b73a2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "80fd69b332764105b76c947c1cdf143a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_1907264b33654a63aec89b9dc275ae89",
              "IPY_MODEL_e49bc39059d84c6dbcc67589182d9ab9",
              "IPY_MODEL_d5179f91a2a14f1e8628970bc896c4f4"
            ],
            "layout": "IPY_MODEL_33366890271d41539db1dc2e1067ba0e"
          }
        },
        "1907264b33654a63aec89b9dc275ae89": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_538b63d76f8545139342189198e4394b",
            "placeholder": "​",
            "style": "IPY_MODEL_6b9af069d5284ac589129425e5a06f30",
            "value": "tokenize: 100%"
          }
        },
        "e49bc39059d84c6dbcc67589182d9ab9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7380385a7deb460b899030d2cdf9959f",
            "max": 7200,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1463990667e94734b034c486ec254d4d",
            "value": 7200
          }
        },
        "d5179f91a2a14f1e8628970bc896c4f4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_28326c42c8384e72a428b7c7049e327a",
            "placeholder": "​",
            "style": "IPY_MODEL_8e769c17b5d64896840bf5f6f3e3b444",
            "value": " 7200/7200 [00:02&lt;00:00, 3090.75 examples/s]"
          }
        },
        "33366890271d41539db1dc2e1067ba0e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "538b63d76f8545139342189198e4394b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6b9af069d5284ac589129425e5a06f30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7380385a7deb460b899030d2cdf9959f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1463990667e94734b034c486ec254d4d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "28326c42c8384e72a428b7c7049e327a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8e769c17b5d64896840bf5f6f3e3b444": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8d16f77d657a441990dcfaf53c102daa": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_5db267ca3aa4488b938da51755fef568",
              "IPY_MODEL_26415d23dedd44778958e6e769a3d754",
              "IPY_MODEL_7ce3926f3e7c40d59176dd51ac82fc70"
            ],
            "layout": "IPY_MODEL_40e6b9a67e214d3e855221a1f96f3681"
          }
        },
        "5db267ca3aa4488b938da51755fef568": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_10601c710bb74681a4a54febde08fbc6",
            "placeholder": "​",
            "style": "IPY_MODEL_6cceb94162814954bb18a21f4fbc714f",
            "value": "tokenize: 100%"
          }
        },
        "26415d23dedd44778958e6e769a3d754": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_008bf0f700534659b6a69147fa7a37f3",
            "max": 800,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_828142c9bbca4644802913173a5acbc8",
            "value": 800
          }
        },
        "7ce3926f3e7c40d59176dd51ac82fc70": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c43dbb7148ba48c0921280672979ffc0",
            "placeholder": "​",
            "style": "IPY_MODEL_2e5537e2cb564277806fa4bb6a2c02c2",
            "value": " 800/800 [00:00&lt;00:00, 3266.77 examples/s]"
          }
        },
        "40e6b9a67e214d3e855221a1f96f3681": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "10601c710bb74681a4a54febde08fbc6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6cceb94162814954bb18a21f4fbc714f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "008bf0f700534659b6a69147fa7a37f3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "828142c9bbca4644802913173a5acbc8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c43dbb7148ba48c0921280672979ffc0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2e5537e2cb564277806fa4bb6a2c02c2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rsYMOJ8iDQcN",
        "outputId": "dfae4f9a-88a8-4596-bfd6-bb400c21c1e1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m126.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m101.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m54.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.7/188.7 MB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m79.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hPython: 3.11.13\n",
            "Torch : 2.6.0+cu124\n",
            "CUDA available: True\n",
            "GPU: NVIDIA A100-SXM4-40GB\n",
            "Mon Aug 18 01:59:33 2025       \n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |\n",
            "|-----------------------------------------+------------------------+----------------------+\n",
            "| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\n",
            "|                                         |                        |               MIG M. |\n",
            "|=========================================+========================+======================|\n",
            "|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |\n",
            "| N/A   32C    P0             43W /  400W |       5MiB /  40960MiB |      0%      Default |\n",
            "|                                         |                        |             Disabled |\n",
            "+-----------------------------------------+------------------------+----------------------+\n",
            "                                                                                         \n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| Processes:                                                                              |\n",
            "|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |\n",
            "|        ID   ID                                                               Usage      |\n",
            "|=========================================================================================|\n",
            "|  No running processes found                                                             |\n",
            "+-----------------------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "!pip -q install \"transformers>=4.40\" \"datasets>=2.18\" accelerate openai tiktoken\n",
        "\n",
        "import torch, platform\n",
        "print(\"Python:\", platform.python_version())\n",
        "print(\"Torch :\", torch.__version__)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%bash\n",
        "cat > /content/python_type_tokenizer.py <<'PY'\n",
        "from __future__ import annotations\n",
        "import ast, io, re, tokenize as py_tok\n",
        "from typing import List, Tuple\n",
        "\n",
        "__all__ = [\"PyTypeTokenizer\"]\n",
        "\n",
        "_CONST_TAG = {int: \"<INT>\", float: \"<FLOAT>\", bool: \"<BOOL>\", str: \"<STR>\"}\n",
        "ALL_TAGS = list(_CONST_TAG.values()) + [\"<LIST>\", \"<TUPLE>\"]\n",
        "_TAG_RE = re.compile(r\"<[^>]+>\")\n",
        "_MINUS_FIX = re.compile(r\"-(<INT>|<FLOAT>)(?=[0-9])\")\n",
        "_LIST_RE = re.compile(r\"\\[[^\\[\\]]*?\\]\")\n",
        "_TUPLE_RE = re.compile(r\"\\([^()]*?,[^()]*?\\)\")\n",
        "_EMPTY_TUP = re.compile(r\"\\(\\)\")\n",
        "\n",
        "_SPLIT_RE = re.compile(\n",
        "    r\"<TUPLE>\\(\\)\"\n",
        "    r\"|<BOOL>True|<BOOL>False\"\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\\.\\d+(?:e[-+]?\\d+)?\"\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\"\n",
        "    r\"|<[A-Z]+>'[^']*'|<[A-Z]+>\\\"[^\\\"]*\\\"\"\n",
        "    r\"|<(?:LIST|TUPLE)>[\\[\\(\\]\\)]\"\n",
        "    r\"|<[^>]+>\"\n",
        "    r\"|[A-Za-z_][A-Za-z0-9_]*\"\n",
        "    r\"|[-+*/%^=(){}\\[\\].?:]\"\n",
        ")\n",
        "\n",
        "class PyTypeTokenizer:\n",
        "    def tag_text(self, text: str) -> str:\n",
        "        spans: List[Tuple[int, int, str]] = []\n",
        "        buf = io.BytesIO(text.encode())\n",
        "        prev = None\n",
        "        try:\n",
        "            for tok in py_tok.tokenize(buf.readline):\n",
        "                ttype, tstr, (_, scol), (_, ecol), _ = tok\n",
        "                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:\n",
        "                    scol = prev.start[1]; tstr = '-' + tstr; prev = None\n",
        "                else:\n",
        "                    prev = tok\n",
        "                tag = None\n",
        "                if ttype == py_tok.NUMBER:\n",
        "                    try:\n",
        "                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]\n",
        "                    except Exception:\n",
        "                        pass\n",
        "                elif ttype == py_tok.STRING:\n",
        "                    tag = \"<STR>\"\n",
        "                elif ttype == py_tok.NAME and tstr in (\"True\", \"False\"):\n",
        "                    tag = \"<BOOL>\"\n",
        "                if tag:\n",
        "                    spans.append((scol, ecol, tag + tstr))\n",
        "        except py_tok.TokenError:\n",
        "            pass\n",
        "\n",
        "        chars = list(text)\n",
        "        for s, e, rep in reversed(spans):\n",
        "            chars[s:e] = [rep]\n",
        "        tagged = \"\".join(chars)\n",
        "        tagged = _MINUS_FIX.sub(lambda m: f\"{m.group(1)}-\", tagged)\n",
        "\n",
        "        tagged = _LIST_RE.sub(lambda m: f\"<LIST>[{m.group(0)[1:-1]}<LIST>]\", tagged)\n",
        "        tagged = _TUPLE_RE.sub(lambda m: f\"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)\", tagged)\n",
        "        tagged = _EMPTY_TUP.sub(\"<TUPLE>()\", tagged)\n",
        "        return tagged\n",
        "\n",
        "    def detag_text(self, s: str) -> str:\n",
        "        return _TAG_RE.sub(\"\", s)\n",
        "\n",
        "    def tokenize(self, s: str, *, pretagged: bool = False):\n",
        "        text = s if pretagged else self.tag_text(s)\n",
        "        raw = [t for t in _SPLIT_RE.findall(text) if t != ',']\n",
        "        cleaned = []\n",
        "        for tok in raw:\n",
        "            if tok.startswith(\"<STR>\"):\n",
        "                lit = tok[5:]\n",
        "                if lit and lit[0] in (\"'\", '\"') and lit[-1] == lit[0]:\n",
        "                    lit = lit[1:-1]\n",
        "                cleaned.append(\"<STR>\" + lit)\n",
        "            else:\n",
        "                cleaned.append(tok)\n",
        "        return cleaned\n",
        "\n",
        "    @staticmethod\n",
        "    def register_tokenizer(hf_tok, extra=None):\n",
        "        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)\n",
        "        return hf_tok\n",
        "PY"
      ],
      "metadata": {
        "id": "CDG78SzbFry1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os, getpass\n",
        "try:\n",
        "    from openai import OpenAI\n",
        "except Exception:\n",
        "    !pip -q install openai\n",
        "    from openai import OpenAI\n",
        "\n",
        "if not os.getenv(\"OPENAI_API_KEY\"):\n",
        "    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key (sk-…): \").strip()\n",
        "\n",
        "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"].strip())\n",
        "TEACHER_MODEL = \"gpt-4o-mini\"  # or \"gpt-4o\"\n",
        "print(\"✅ OpenAI client ready\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ohoL8p1DFv0s",
        "outputId": "527316bd-59d0-46d4-ba7a-bd0874f0ea02"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Enter your OpenAI API key (sk-…): ··········\n",
            "✅ OpenAI client ready\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import json, random, pathlib, re, ast, time\n",
        "from typing import List, Dict\n",
        "\n",
        "out_dir = pathlib.Path(\"data_teacher\"); out_dir.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "USE_TEACHER = True      # set False to use the local fallback\n",
        "TARGET_PER_SKILL = 100  # start small to control spend; scale later\n",
        "TRAIN_FRAC = 0.9\n",
        "SKILLS = [\"add\",\"sub\",\"max\",\"min\",\"sort\"]\n",
        "\n",
        "NUM_TEMPLATES = {\n",
        "    \"add\": [\n",
        "        \"Add {a} and {b}.\",\n",
        "        \"What is {a} plus {b}?\",\n",
        "        \"Please compute the sum of {a} and {b}.\"\n",
        "    ],\n",
        "    \"sub\": [\n",
        "        \"Subtract {b} from {a}.\",\n",
        "        \"What is {a} minus {b}?\",\n",
        "        \"Compute {a} - {b}.\"\n",
        "    ],\n",
        "    \"max\": [\n",
        "        \"What is the maximum of {lst}?\",\n",
        "        \"Find the largest value in {lst}.\",\n",
        "        \"Max of {lst}.\"\n",
        "    ],\n",
        "    \"min\": [\n",
        "        \"What is the minimum of {lst}?\",\n",
        "        \"Find the smallest value in {lst}.\",\n",
        "        \"Min of {lst}.\"\n",
        "    ],\n",
        "    \"sort\": [\n",
        "        \"Sort the list {lst}.\",\n",
        "        \"Please sort {lst}.\",\n",
        "        \"Return {lst} sorted.\"\n",
        "    ],\n",
        "}\n",
        "\n",
        "def ri(a=-999, b=999): return random.randint(a, b)\n",
        "def rf(): return round(random.uniform(-999, 999), 2)\n",
        "def rnum(): return rf() if random.random() < 0.4 else ri()\n",
        "def rlist(n=None):\n",
        "    n = n or random.randint(4, 10)\n",
        "    return [rnum() for _ in range(n)]\n",
        "\n",
        "def make_seed(skill):\n",
        "    if skill in (\"add\",\"sub\"):\n",
        "        a, b = rnum(), rnum()\n",
        "        tpl = random.choice(NUM_TEMPLATES[skill])\n",
        "        prompt = tpl.format(a=a, b=b)\n",
        "        code   = f\"{a} + {b}\" if skill==\"add\" else f\"{a} - {b}\"\n",
        "    else:\n",
        "        lst = rlist()\n",
        "        tpl = random.choice(NUM_TEMPLATES[skill])\n",
        "        prompt = tpl.format(lst=lst)\n",
        "        if skill==\"max\": code = f\"max({lst})\"\n",
        "        elif skill==\"min\": code = f\"min({lst})\"\n",
        "        else: code = f\"sorted({lst})\"\n",
        "    return prompt, code\n",
        "\n",
        "def valid_code(skill, code):\n",
        "    try:\n",
        "        ast.parse(code, mode=\"eval\")\n",
        "        val = eval(code, {\"__builtins__\": {}}, {\"max\": max, \"min\": min, \"sorted\": sorted})\n",
        "        if skill in (\"add\",\"sub\"): return isinstance(val, (int,float))\n",
        "        if skill in (\"max\",\"min\"): return isinstance(val, (int,float))\n",
        "        if skill == \"sort\": return isinstance(val, list)\n",
        "        return True\n",
        "    except Exception:\n",
        "        return False\n",
        "\n",
        "def ask_teacher(seed_prompt, skill):\n",
        "    sys_prompt = (\n",
        "        \"You write a one-line Python expression that answers the user's request. \"\n",
        "        \"Return strict JSON with keys: prompt, code.\"\n",
        "    )\n",
        "    user_prompt = f\"Task type: {skill}\\nUser: {seed_prompt}\\nReturn JSON.\"\n",
        "    rsp = client.chat.completions.create(\n",
        "        model=TEACHER_MODEL,\n",
        "        response_format={\"type\": \"json_object\"},\n",
        "        temperature=0.7,\n",
        "        max_tokens=120,\n",
        "        messages=[{\"role\":\"system\",\"content\":sys_prompt},\n",
        "                  {\"role\":\"user\",\"content\":user_prompt}]\n",
        "    )\n",
        "    txt = rsp.choices[0].message.content\n",
        "    data = json.loads(txt)\n",
        "    prompt = data.get(\"prompt\", seed_prompt)\n",
        "    code   = data[\"code\"].strip()\n",
        "    return prompt, code\n",
        "\n",
        "rows: List[Dict] = []\n",
        "for skill in SKILLS:\n",
        "    got = 0\n",
        "    seen = set()\n",
        "    while got < TARGET_PER_SKILL:\n",
        "        p0, c0 = make_seed(skill)\n",
        "        if USE_TEACHER:\n",
        "            try:\n",
        "                p, c = ask_teacher(p0, skill)\n",
        "            except Exception as e:\n",
        "                msg = str(e)\n",
        "                if \"insufficient_quota\" in msg or \"You exceeded your current quota\" in msg:\n",
        "                    raise RuntimeError(\"Insufficient quota. Reduce TARGET_PER_SKILL or add credits.\") from e\n",
        "                time.sleep(2.0)\n",
        "                continue\n",
        "        else:\n",
        "            p, c = p0, c0\n",
        "\n",
        "        key = (skill, p, c)\n",
        "        if key in seen:\n",
        "            continue\n",
        "        if not valid_code(skill, c):\n",
        "            continue\n",
        "\n",
        "        rows.append({\"skill\": skill, \"prompt\": p, \"code\": c})\n",
        "        seen.add(key)\n",
        "        got += 1\n",
        "\n",
        "random.shuffle(rows)\n",
        "split = int(TRAIN_FRAC * len(rows))\n",
        "with open(out_dir/\"train.jsonl\",\"w\") as f:\n",
        "    for r in rows[:split]: f.write(json.dumps(r)+\"\\n\")\n",
        "with open(out_dir/\"valid.jsonl\",\"w\") as f:\n",
        "    for r in rows[split:]: f.write(json.dumps(r)+\"\\n\")\n",
        "print(f\"Saved {split} train and {len(rows)-split} valid to {out_dir}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MNS_vq0pF0f9",
        "outputId": "edbaa9e9-03e9-4293-bc20-5966058c0394"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saved 450 train and 50 valid to data_teacher\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset, DatasetDict\n",
        "from transformers import AutoTokenizer\n",
        "from pathlib import Path\n",
        "import re\n",
        "\n",
        "data_dir = Path(\"data_teacher\")\n",
        "assert (data_dir/\"train.jsonl\").exists(), \"Run the generator cell first.\"\n",
        "\n",
        "ds = load_dataset(\n",
        "    \"json\",\n",
        "    data_files={\"train\": str(data_dir/\"train.jsonl\"),\n",
        "                \"valid\": str(data_dir/\"valid.jsonl\")}\n",
        ")\n",
        "\n",
        "tok = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
        "tok.pad_token = tok.eos_token                      # GPT-2 pad fix\n",
        "SEP = \" <|END|> \"\n",
        "\n",
        "digit_or_punct = re.compile(r\"[0-9\\-\\[\\]\\(\\),]\")\n",
        "\n",
        "def build(row):\n",
        "    prompt = row[\"prompt\"]\n",
        "    code   = row[\"code\"]\n",
        "\n",
        "    # Full text\n",
        "    full_text = prompt + SEP + code\n",
        "    enc_full  = tok(full_text, truncation=True, padding=False, add_special_tokens=False)\n",
        "    input_ids = enc_full[\"input_ids\"]\n",
        "\n",
        "    # Mask code tokens for CLM loss, ignore prompt+SEP with -100\n",
        "    prefix_len = len(tok(prompt + SEP, add_special_tokens=False)[\"input_ids\"])\n",
        "    labels_code = [-100]*min(prefix_len, len(input_ids)) + input_ids[min(prefix_len, len(input_ids)):]\n",
        "    labels_code = labels_code[:len(input_ids)]\n",
        "\n",
        "    # Span mask on prompt only: mark tokens overlapping digits or list punctuation\n",
        "    enc_prompt = tok(prompt, return_offsets_mapping=True, add_special_tokens=False)\n",
        "    span_mask_prompt = []\n",
        "    for (s,e) in enc_prompt[\"offset_mapping\"]:\n",
        "        sub = prompt[s:e]\n",
        "        span_mask_prompt.append(1 if digit_or_punct.search(sub) else 0)\n",
        "    # extend to full length with zeros\n",
        "    span_mask = span_mask_prompt + [0]*(len(input_ids)-len(span_mask_prompt))\n",
        "    span_mask = span_mask[:len(input_ids)]\n",
        "\n",
        "    row[\"input_ids\"] = input_ids\n",
        "    row[\"attention_mask\"] = enc_full[\"attention_mask\"]\n",
        "    row[\"labels_code\"] = labels_code\n",
        "    row[\"labels_span\"] = span_mask\n",
        "    return row\n",
        "\n",
        "ds_proc = ds.map(build, remove_columns=ds[\"train\"].column_names)\n",
        "print(ds_proc)\n",
        "print(\"Example processed row:\", {k: len(ds_proc['train'][0][k]) for k in [\"input_ids\",\"attention_mask\",\"labels_code\",\"labels_span\"]})"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 600,
          "referenced_widgets": [
            "5a4d692a01d94f7cab1511d5eec1abce",
            "047397799c5d4881bac4c4a3b9e12ca0",
            "f55b8944a3664794a4ac96677698d4bd",
            "72520679538e4e6a8c371044f5069b7f",
            "2aa99c87bd7f42cf99c7327f467a04c8",
            "8a3aeabc2c0a484f90e4fbd16f99d39e",
            "d2886700adc3466a8021298a0bec98b4",
            "b1722adf42a64d119158a6bdb9a253ca",
            "810346fa9f914a668527dcbae242bc31",
            "d35c8757f7bf403f9f17517e048676a9",
            "8082e3b040614b92b101d07ecddd1976",
            "bfddb7b6a72449fb851e89f7d920ba75",
            "d734c53e1e3d4374a257ffae53840465",
            "3ba15572efdf4cd596536135c28788a4",
            "a17f02f12d2c4d55917c7d928bc5e887",
            "362c14c3b92b42d9a2e07820a26b4fd5",
            "82d8856b6fe64e098cf9b159bc7abba2",
            "0a127db705304452a3445492f99c7163",
            "8f62f575213b4f6983c68169cdbf84e9",
            "954ccd098ac9410faa80643063db93d6",
            "d1f70b0c0cd742cdad59097e7f670a49",
            "5856e585b81f4c0ea4ed1ed5f526afa5",
            "51a75fd6d0934b91959e253c29c719fa",
            "846089535c934ec0b9a95b8396e08db7",
            "d05dd9d70be04d37a4be5c77ae67f31e",
            "5db5adda0f534191b979c7eebfd05bc7",
            "ecdc1cdfb64e47aa89042357d135e43d",
            "51a12a24c340480a97fe76c5778ac3ba",
            "d243ff91c9094c7a9eda97c76c4ec5ca",
            "4752f090513e4bc7a9d26cdc3b870192",
            "9bad7b27e6a44363bfe6a14ea5879614",
            "9c99bf3220bb4bda8bcb1090256213e4",
            "918f162f351d4a8da318532450686aa7",
            "1b0834cc60af40dc94d0757626c1d229",
            "8630ff9993474c008c0a7dc6bbfdc62b",
            "4c4960aa2c9a428ab4ff9489d287e7eb",
            "062532ef0378428a97d562e6660ac750",
            "455c78c887894d19864bab0bb4e9d898",
            "707b9b0d1c3a44b4a29f311f2522283a",
            "3eb28343a2e64de9b53edc959f470357",
            "bb701c7575a04c578b5940e75b38e881",
            "695ea302e91f44089a4d410704af00cf",
            "c8eb6d335baf429b94b11191d248e3b5",
            "15a51ec133734a78a95ad40cf0502226",
            "65449b3640104062a89ce4b6e8b4c3a4",
            "53d3d2da1cab45debfe909b386638b31",
            "41a35a6af6af4b8fb1d7bcfd6072f01d",
            "9fdaa541bf4a469394188c7644b74ad6",
            "398fe1a10f1947dea1af1b215d255c96",
            "98275add997c4357b2d1551516c04e9d",
            "b79781a5c0554b29860910fde297355c",
            "051005a3267a45599a9a0dcacdde540e",
            "b87845b2213745ef93503617bdc1c1c4",
            "ac6f06611bcf4c25ab9c53d8f3d8c864",
            "22c873a1e9014e4089882952f1e295d9",
            "a457850460284d6881cde372fcc92cc3",
            "7baa8024ac31428da1c859534db313e7",
            "5237f210856e40628af19cfa04669a63",
            "5f7b8bf1f75f4ad68f86116f417c4c6b",
            "002b709d7d1241a48e3816a700c4269c",
            "5792e92d475244f4a5486b7679e628c5",
            "389c9dc8a0b7485a94c81c77ffec85d3",
            "4cf7d429d27349d293ce98b6a5c6bc23",
            "43b9d89923134e3689b845ebfb07f00b",
            "1b97176337744550a48d7dbb0043159a",
            "ba70d4ade124424286868abe8ae49b11",
            "4fa4876f1a054e41a9256f3d2b44a57a",
            "926dbe61ce064dcc8652a094faa6eec4",
            "5991dd46a7a0410ea5644574fe8b4caa",
            "58937c09c63c4c1ab35dac561cfedcdf",
            "9aec9101c31640b2873098bb28021b21",
            "6852f79d3d3e4c76a5b5bcff2f50c76b",
            "3feb1c9cd86a44e488009db1cf999bda",
            "c72b5e79d6c440e6875ffb5dcb04b52b",
            "73e96e2be3ea4666ab58e9c094c3643c",
            "b5bec25933f845d38b3538b45229dfea",
            "c3ae153a277e4196b30a35de8bcf4729",
            "1bc69149c27146b79156d03073360065",
            "faae8a18382b48488f8547a399c15f04",
            "111332dfcd02423b896554e7d349aa5e",
            "c982b34f368b42fb8e7428327c545485",
            "be77fc5574ca4f6fbc7ce8c8b7154307",
            "9f271481f88a4d75b87075cc4f2b1ee9",
            "86aef75c5dc446c4aeacafbbaa339f44",
            "7817be8594404c6cb633b7ce5ea3e3e7",
            "d94138f285ee4e4ba017ae69d731cd62",
            "a6083d2162cf404fa29893982fb2001f",
            "48669efbd9f64e15825d68c2e23b470c",
            "7f86942e576546c4a23c866366d0bed8",
            "23286dddb7f04ceb943a04ab7f8175d7",
            "0565533f89b3411494e7a553277a6199",
            "a811f2fb0d1e4d88bab1dd6e900612aa",
            "d3d10e665cbf42d998b4412f688873ca",
            "34662a3d000b445f8dab2f5d0d0f113c",
            "a80350d153cf4adfb387481a2d977be5",
            "0e4add2bff444711995482289b2d95ba",
            "4e1abdd046a3424b80eec140156576f8",
            "210127ed6fbe454bbb9e3835f85de443",
            "941ebfae2f8e4058894d19fb90a8df92"
          ]
        },
        "id": "3IOTZeKSKav0",
        "outputId": "d0d913a5-9c72-47e8-9d7e-c7aacf8a90ea"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "5a4d692a01d94f7cab1511d5eec1abce"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating valid split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "bfddb7b6a72449fb851e89f7d920ba75"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "51a75fd6d0934b91959e253c29c719fa"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "1b0834cc60af40dc94d0757626c1d229"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "65449b3640104062a89ce4b6e8b4c3a4"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a457850460284d6881cde372fcc92cc3"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "4fa4876f1a054e41a9256f3d2b44a57a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/450 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "1bc69149c27146b79156d03073360065"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/50 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7f86942e576546c4a23c866366d0bed8"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['input_ids', 'attention_mask', 'labels_code', 'labels_span'],\n",
            "        num_rows: 450\n",
            "    })\n",
            "    valid: Dataset({\n",
            "        features: ['input_ids', 'attention_mask', 'labels_code', 'labels_span'],\n",
            "        num_rows: 50\n",
            "    })\n",
            "})\n",
            "Example processed row: {'input_ids': 49, 'attention_mask': 49, 'labels_code': 49, 'labels_span': 49}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from transformers import AutoModelForCausalLM\n",
        "\n",
        "class GPT2Dual(nn.Module):\n",
        "    def __init__(self, base: AutoModelForCausalLM, span_weight: float = 0.3):\n",
        "        super().__init__()\n",
        "        self.base = base\n",
        "        hidden = base.config.n_embd\n",
        "        self.span_head = nn.Linear(hidden, 1)\n",
        "        self.span_weight = span_weight\n",
        "\n",
        "    def forward(self, input_ids, attention_mask, labels_code=None, labels_span=None):\n",
        "        # Ask base to compute CLM loss and give us hidden states\n",
        "        out = self.base(input_ids=input_ids,\n",
        "                        attention_mask=attention_mask,\n",
        "                        labels=labels_code,\n",
        "                        output_hidden_states=True,\n",
        "                        return_dict=True)\n",
        "        loss = out.loss\n",
        "\n",
        "        # Span head\n",
        "        h = out.hidden_states[-1]              # [B,T,H]\n",
        "        span_logits = self.span_head(h).squeeze(-1)  # [B,T]\n",
        "        if labels_span is not None:\n",
        "            # BCE over valid tokens only\n",
        "            bce = F.binary_cross_entropy_with_logits(\n",
        "                span_logits, labels_span.float(), reduction=\"none\"\n",
        "            )  # [B,T]\n",
        "            mask = attention_mask.float()\n",
        "            bce = (bce * mask).sum() / mask.sum().clamp_min(1.0)\n",
        "            loss = loss + self.span_weight * bce\n",
        "\n",
        "        return {\"loss\": loss, \"logits\": out.logits, \"span_logits\": span_logits}"
      ],
      "metadata": {
        "id": "8VREriPHLLA8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- 7) Train + save without safetensors conflict ---------------------------\n",
        "from transformers import TrainingArguments, Trainer\n",
        "import torch, os\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "model.to(device)\n",
        "\n",
        "args = TrainingArguments(\n",
        "    output_dir=\"ckpt\",\n",
        "    overwrite_output_dir=True,\n",
        "    num_train_epochs=1,              # bump after you verify the run\n",
        "    per_device_train_batch_size=8,\n",
        "    per_device_eval_batch_size=8,\n",
        "    learning_rate=2e-5,\n",
        "    fp16=torch.cuda.is_available(),\n",
        "    logging_steps=200,\n",
        "    save_steps=1000,                 # or set very large to avoid mid-run saves\n",
        "    save_safetensors=False,          # <-- critical fix for tied embeddings\n",
        "    report_to=\"none\",                # avoid W&B if you don’t need it\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    train_dataset=ds_proc[\"train\"],\n",
        "    eval_dataset=ds_proc[\"valid\"],\n",
        "    data_collator=collate,\n",
        ")\n",
        "\n",
        "trainer.train()\n",
        "\n",
        "# Robust manual save\n",
        "os.makedirs(\"ckpt/final\", exist_ok=True)\n",
        "\n",
        "# 1) Save the GPT-2 base in HF format (safe for tied weights)\n",
        "model.base.save_pretrained(\"ckpt/final/base\")\n",
        "\n",
        "# 2) Save the extra span head weights\n",
        "torch.save(\n",
        "    {\"span_head\": model.span_head.state_dict(),\n",
        "     \"span_weight\": model.span_weight},\n",
        "    \"ckpt/final/dual_heads.pt\"\n",
        ")\n",
        "\n",
        "# 3) Save tokenizer\n",
        "tok.save_pretrained(\"ckpt/final\")\n",
        "\n",
        "print(\"✅ saved model to ckpt/final\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 211
        },
        "id": "Yr1GPnabLQgs",
        "outputId": "a46fd394-9781-40e9-e556-5c258da6f49d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "NameError",
          "evalue": "name 'model' is not defined",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-3579327375.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"cuda\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m args = TrainingArguments(\n",
            "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "\n",
        "SEP = \" <|END|> \"\n",
        "\n",
        "def generate_code(prompt, max_new=40):\n",
        "    prefix = prompt + SEP\n",
        "    ids = tok(prefix, return_tensors=\"pt\").to(model.base.device)\n",
        "    out = model.base.generate(\n",
        "        **ids, max_new_tokens=max_new, do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id\n",
        "    )\n",
        "    txt  = tok.decode(out[0], skip_special_tokens=True)\n",
        "    code = txt.split(SEP, 1)[-1].strip()\n",
        "    code = code.splitlines()[0].strip()\n",
        "    # light cleanup for safety\n",
        "    code = re.sub(r\"[#].*$\", \"\", code).strip()\n",
        "    return code\n",
        "\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\"\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    code = generate_code(p)\n",
        "    try:\n",
        "        val = eval(code, {\"__builtins__\": {}}, {\"max\": max, \"min\": min, \"sorted\": sorted})\n",
        "    except Exception:\n",
        "        val = \"❌\"\n",
        "    print(f\"{p:45} → {code:28} → {val}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 332
        },
        "id": "UMcqd_60O2Lb",
        "outputId": "6409d6c8-9a72-4f59-805a-6c045eefba4a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "NameError",
          "evalue": "name 'model' is not defined",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-2037805910.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtests\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m     \u001b[0mcode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgenerate_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0mval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"__builtins__\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"max\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"min\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"sorted\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-2037805910.py\u001b[0m in \u001b[0;36mgenerate_code\u001b[0;34m(prompt, max_new)\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgenerate_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m40\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mprefix\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprompt\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mSEP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m     \u001b[0mids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtok\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprefix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m     out = model.base.generate(\n\u001b[1;32m      9\u001b[0m         \u001b[0;34m**\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_new\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdo_sample\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# %%capture\n",
        "!pip -q install --upgrade transformers datasets accelerate regex tqdm\n",
        "\n",
        "import os, json, random, re, math, ast, time, textwrap, itertools\n",
        "from pathlib import Path\n",
        "from dataclasses import dataclass\n",
        "from typing import List, Dict, Any, Tuple\n",
        "\n",
        "import torch\n",
        "from datasets import Dataset, DatasetDict\n",
        "from transformers import (AutoTokenizer, AutoModelForCausalLM,\n",
        "                          Trainer, TrainingArguments, DataCollatorForLanguageModeling)\n",
        "print(\"Torch:\", torch.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ABxyd81lQY-j",
        "outputId": "7077e49b-fa5b-4329-f052-87949b9cf57f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.0/42.0 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.3/11.3 MB\u001b[0m \u001b[31m120.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m798.9/798.9 kB\u001b[0m \u001b[31m48.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hTorch: 2.6.0+cu124\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# If you have it on Drive, do:\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')\n",
        "# %cp /content/drive/MyDrive/python_type_tokenizer.py /content/python_type_tokenizer.py\n",
        "\n",
        "# Otherwise write a clean version here:\n",
        "%%writefile python_type_tokenizer.py\n",
        "import ast, io, re, tokenize as py_tok\n",
        "\n",
        "_CONST_TAG = {int: \"<INT>\", float: \"<FLOAT>\", bool: \"<BOOL>\", str: \"<STR>\"}\n",
        "ALL_TAGS = list(_CONST_TAG.values()) + [\"<LIST>\", \"<TUPLE>\"]\n",
        "\n",
        "_TAG_RE = re.compile(r\"<[^>]+>\")\n",
        "_MINUS_FIX = re.compile(r\"-(<INT>|<FLOAT>)(?=[0-9])\")\n",
        "_LIST_RE = re.compile(r\"\\[[^\\[\\]]*?\\]\")\n",
        "_TUPLE_RE = re.compile(r\"\\([^()]*?,[^()]*?\\)\")\n",
        "_EMPTY_TUP = re.compile(r\"\\(\\)\")\n",
        "\n",
        "_SPLIT_RE = re.compile(\n",
        "    r\"<TUPLE>\\(\\)\"\n",
        "    r\"|<BOOL>True|<BOOL>False\"\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\\.\\d+(?:e[-+]?\\d+)?\"\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\"\n",
        "    r\"|<[A-Z]+>'[^']*'|<[A-Z]+>\\\"[^\\\"]*\\\"\"\n",
        "    r\"|<(?:LIST|TUPLE)>[\\[\\(\\]\\)]\"\n",
        "    r\"|<[^>]+>\"\n",
        "    r\"|[A-Za-z_][A-Za-z0-9_]*\"\n",
        "    r\"|[-+*/%^=(){}\\[\\].?:]\"\n",
        ")\n",
        "\n",
        "class PyTypeTokenizer:\n",
        "    def tag_text(self, text: str) -> str:\n",
        "        spans = []\n",
        "        buf = io.BytesIO(text.encode())\n",
        "        prev = None\n",
        "        try:\n",
        "            for tok in py_tok.tokenize(buf.readline):\n",
        "                ttype, tstr, (_, scol), (_, ecol), _ = tok\n",
        "                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:\n",
        "                    scol = prev.start[1]; tstr = '-' + tstr; prev = None\n",
        "                else:\n",
        "                    prev = tok\n",
        "                tag = None\n",
        "                if ttype == py_tok.NUMBER:\n",
        "                    try:\n",
        "                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]\n",
        "                    except Exception:\n",
        "                        pass\n",
        "                elif ttype == py_tok.STRING:\n",
        "                    tag = \"<STR>\"\n",
        "                elif ttype == py_tok.NAME and tstr in (\"True\", \"False\"):\n",
        "                    tag = \"<BOOL>\"\n",
        "                if tag:\n",
        "                    spans.append((scol, ecol, tag + tstr))\n",
        "        except py_tok.TokenError:\n",
        "            pass\n",
        "\n",
        "        chars = list(text)\n",
        "        for s, e, rep in reversed(spans):\n",
        "            chars[s:e] = [rep]\n",
        "        tagged = \"\".join(chars)\n",
        "        tagged = _MINUS_FIX.sub(lambda m: f\"{m.group(1)}-\", tagged)\n",
        "\n",
        "        tagged = _LIST_RE.sub(lambda m: f\"<LIST>[{m.group(0)[1:-1]}<LIST>]\", tagged)\n",
        "        tagged = _TUPLE_RE.sub(lambda m: f\"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)\", tagged)\n",
        "        tagged = _EMPTY_TUP.sub(\"<TUPLE>()\", tagged)\n",
        "        return tagged\n",
        "\n",
        "    def detag_text(self, s: str) -> str:\n",
        "        return _TAG_RE.sub(\"\", s)\n",
        "\n",
        "    def tokenize(self, s: str, *, pretagged: bool = False):\n",
        "        text = s if pretagged else self.tag_text(s)\n",
        "        raw = [t for t in _SPLIT_RE.findall(text) if t != ',']\n",
        "        cleaned = []\n",
        "        for tok in raw:\n",
        "            if tok.startswith(\"<STR>\"):\n",
        "                lit = tok[5:]\n",
        "                if lit and lit[0] in (\"'\", '\"') and lit[-1] == lit[0]:\n",
        "                    lit = lit[1:-1]\n",
        "                cleaned.append(\"<STR>\" + lit)\n",
        "            else:\n",
        "                cleaned.append(tok)\n",
        "        return cleaned\n",
        "\n",
        "    @staticmethod\n",
        "    def register_tokenizer(hf_tok, extra=None):\n",
        "        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)\n",
        "        return hf_tok"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_qBINH9bQcCD",
        "outputId": "af3a0198-5314-479f-be7b-0a006c28a0d8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting python_type_tokenizer.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "USE_OPENAI = True  # set to False to use local paraphraser fallback\n",
        "\n",
        "if USE_OPENAI:\n",
        "    !pip -q install --upgrade openai\n",
        "    import os, random, time\n",
        "    from openai import OpenAI\n",
        "    assert \"OPENAI_API_KEY\" in os.environ, \"Set your OpenAI API key in the Colab environment.\"\n",
        "    client = OpenAI()\n",
        "    TEACHER_MODEL = \"gpt-4o-mini\"  # change if you like"
      ],
      "metadata": {
        "id": "lUEkPcEWQfIz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import json, random\n",
        "\n",
        "def ri(a=-99, b=99): return random.randint(a,b)\n",
        "def rlist():\n",
        "    k = random.randint(4,8)\n",
        "    # allow negatives and repeats\n",
        "    return [random.randint(-50,50) for _ in range(k)]\n",
        "\n",
        "def make_example() -> Dict[str, Any]:\n",
        "    skill = random.choice([\"add\",\"sub\",\"max\",\"min\",\"sort\"])\n",
        "    if skill == \"add\":\n",
        "        a,b = ri(), ri()\n",
        "        return {\"skill\": skill, \"code\": f\"{a} + {b}\", \"inputs\": [a,b]}\n",
        "    if skill == \"sub\":\n",
        "        a,b = ri(), ri()\n",
        "        return {\"skill\": skill, \"code\": f\"{a} - {b}\", \"inputs\": [a,b]}\n",
        "    if skill == \"max\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"max({xs})\", \"inputs\": xs}\n",
        "    if skill == \"min\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"min({xs})\", \"inputs\": xs}\n",
        "    if skill == \"sort\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"sorted({xs})\", \"inputs\": xs}\n",
        "\n",
        "def safe_eval_python(code: str):\n",
        "    # One-line safe eval for our limited arithmetic/list calls\n",
        "    allowed_names = {\"max\": max, \"min\": min, \"sorted\": sorted}\n",
        "    return eval(code, {\"__builtins__\": {}}, allowed_names)\n",
        "\n",
        "def teacher_prompts_for(code: str, skill: str, k: int = 5) -> List[str]:\n",
        "    if not USE_OPENAI:\n",
        "        # Simple local paraphrases as a fallback\n",
        "        templates = {\n",
        "            \"add\": [\n",
        "                \"Add {a} and {b}.\", \"Compute the sum of {a} and {b}.\",\n",
        "                \"What is {a} plus {b}?\", \"Please add {a} to {b}.\", \"Sum {a} with {b}.\"\n",
        "            ],\n",
        "            \"sub\": [\n",
        "                \"Subtract {b} from {a}.\", \"What is {a} minus {b}?\",\n",
        "                \"Compute {a} − {b}.\", \"Please subtract {b} from {a}.\", \"Difference of {a} and {b}.\"\n",
        "            ],\n",
        "            \"max\": [\n",
        "                \"What is the maximum of {xs}?\", \"Find the largest in {xs}.\",\n",
        "                \"Return the max element of {xs}.\", \"Pick the greatest in {xs}.\", \"Max value from {xs}?\"\n",
        "            ],\n",
        "            \"min\": [\n",
        "                \"What is the minimum of {xs}?\", \"Find the smallest in {xs}.\",\n",
        "                \"Return the min element of {xs}.\", \"Pick the least in {xs}.\", \"Min value from {xs}?\"\n",
        "            ],\n",
        "            \"sort\": [\n",
        "                \"Sort the list {xs}.\", \"Return {xs} in ascending order.\",\n",
        "                \"Please sort {xs}.\", \"Order the list {xs}.\", \"Sorted version of {xs}?\"\n",
        "            ],\n",
        "        }\n",
        "        if skill in (\"add\",\"sub\"):\n",
        "            nums = [int(s) for s in re.findall(r\"-?\\d+\", code)]\n",
        "            a,b = nums[0], nums[1]\n",
        "            cands = [t.format(a=a,b=b) for t in templates[skill]]\n",
        "        else:\n",
        "            xs = re.findall(r\"\\[.*\\]\", code)[0]\n",
        "            cands = [t.format(xs=xs) for t in templates[skill]]\n",
        "        random.shuffle(cands)\n",
        "        return cands[:k]\n",
        "\n",
        "    sys_prompt = (\n",
        "        \"You are generating natural-language prompts for a given Python one-liner.\\n\"\n",
        "        \"You must ONLY return JSON like:\\n\"\n",
        "        \"{ \\\"prompts\\\": [\\\"...\\\", \\\"...\\\"] }\\n\"\n",
        "        \"Rules:\\n\"\n",
        "        \"- Prompts must ask for the same computation as the code, not reveal the code.\\n\"\n",
        "        \"- No extra words like 'Answer:' or 'Code:'.\\n\"\n",
        "        \"- Write short, varied phrasings.\\n\"\n",
        "        \"- American English.\\n\"\n",
        "    )\n",
        "    user_msg = f\"code: {code}\\nskill: {skill}\\nPlease return 6 varied prompts in JSON.\"\n",
        "\n",
        "    # Resilient call\n",
        "    for attempt in range(4):\n",
        "        try:\n",
        "            r = client.chat.completions.create(\n",
        "                model=TEACHER_MODEL,\n",
        "                temperature=0.7,\n",
        "                max_tokens=300,\n",
        "                messages=[{\"role\": \"system\", \"content\": sys_prompt},\n",
        "                          {\"role\": \"user\", \"content\": user_msg}],\n",
        "                response_format={\"type\": \"json_object\"},\n",
        "            )\n",
        "            obj = json.loads(r.choices[0].message.content)\n",
        "            prompts = obj.get(\"prompts\", [])\n",
        "            prompts = [p.strip() for p in prompts if isinstance(p, str) and p.strip()]\n",
        "            if len(prompts) >= 3:\n",
        "                return prompts[:6]\n",
        "        except Exception as e:\n",
        "            time.sleep(1.5*(2**attempt))\n",
        "    # fallback if API flaky\n",
        "    return teacher_prompts_for(code, skill, k=5)"
      ],
      "metadata": {
        "id": "E-hrwDXrQmPD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ==== Cell 4 (fixed) — build a validated distilled dataset ====================\n",
        "import json, random\n",
        "from pathlib import Path\n",
        "\n",
        "# Ensure tokenizer instance exists (handles fresh runtimes too)\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "# Make sure teacher function exists (should be defined earlier)\n",
        "assert 'teacher_prompts_for' in globals(), \"Run the teacher setup cell first.\"\n",
        "\n",
        "# Output dir\n",
        "DATA_DIR = Path(\"distilled_data\")\n",
        "DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "# How many examples per skill to generate (start small to test)\n",
        "TARGET_PER_SKILL = 200    # raise after it works cleanly (e.g., 1000)\n",
        "SKILLS = [\"add\",\"sub\",\"max\",\"min\",\"sort\"]\n",
        "\n",
        "def ri(a=-99, b=99): return random.randint(a,b)\n",
        "def rlist():\n",
        "    k = random.randint(4,8)\n",
        "    return [random.randint(-50,50) for _ in range(k)]\n",
        "\n",
        "def make_example():\n",
        "    skill = random.choice(SKILLS)\n",
        "    if skill == \"add\":\n",
        "        a,b = ri(), ri()\n",
        "        return {\"skill\": skill, \"code\": f\"{a} + {b}\", \"inputs\": [a,b]}\n",
        "    if skill == \"sub\":\n",
        "        a,b = ri(), ri()\n",
        "        return {\"skill\": skill, \"code\": f\"{a} - {b}\", \"inputs\": [a,b]}\n",
        "    if skill == \"max\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"max({xs})\", \"inputs\": xs}\n",
        "    if skill == \"min\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"min({xs})\", \"inputs\": xs}\n",
        "    if skill == \"sort\":\n",
        "        xs = rlist()\n",
        "        return {\"skill\": skill, \"code\": f\"sorted({xs})\", \"inputs\": xs}\n",
        "\n",
        "def safe_eval_python(code: str):\n",
        "    allowed_names = {\"max\": max, \"min\": min, \"sorted\": sorted}\n",
        "    return eval(code, {\"__builtins__\": {}}, allowed_names)\n",
        "\n",
        "records = []\n",
        "\n",
        "for skill in SKILLS:\n",
        "    got = 0\n",
        "    while got < TARGET_PER_SKILL:\n",
        "        ex = make_example()\n",
        "        code = ex[\"code\"]\n",
        "        # validate code\n",
        "        try:\n",
        "            import ast\n",
        "            ast.parse(code)\n",
        "            _ = safe_eval_python(code)\n",
        "        except Exception:\n",
        "            continue\n",
        "\n",
        "        prompts = teacher_prompts_for(code, ex[\"skill\"])\n",
        "        for p in prompts:\n",
        "            tagged_p = type_tok.tag_text(p)\n",
        "            tagged_c = type_tok.tag_text(code)\n",
        "            records.append({\n",
        "                \"skill\": ex[\"skill\"],\n",
        "                \"prompt\": p,\n",
        "                \"code\": code,\n",
        "                \"tagged_prompt\": tagged_p,\n",
        "                \"tagged_code\": tagged_c\n",
        "            })\n",
        "            got += 1\n",
        "            if got >= TARGET_PER_SKILL:\n",
        "                break\n",
        "    print(f\"✓ {skill}: {got}\")\n",
        "\n",
        "# Shuffle and split\n",
        "random.shuffle(records)\n",
        "split = int(0.9 * len(records))\n",
        "train = records[:split]\n",
        "valid = records[split:]\n",
        "\n",
        "with open(DATA_DIR/\"train.jsonl\",\"w\") as f:\n",
        "    for r in train: f.write(json.dumps(r)+\"\\n\")\n",
        "with open(DATA_DIR/\"valid.jsonl\",\"w\") as f:\n",
        "    for r in valid: f.write(json.dumps(r)+\"\\n\")\n",
        "\n",
        "print(\"Saved:\", len(train), \"train and\", len(valid), \"valid\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a9a0GsbwQok7",
        "outputId": "6452c1fa-ec4f-4161-9e3b-ffbddb4edc39"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "✓ add: 200\n",
            "✓ sub: 200\n",
            "✓ max: 200\n",
            "✓ min: 200\n",
            "✓ sort: 200\n",
            "Saved: 900 train and 100 valid\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def load_jsonl(path: Path) -> List[Dict[str,Any]]:\n",
        "    return [json.loads(x) for x in path.read_text().splitlines()]\n",
        "\n",
        "train_rows = load_jsonl(DATA_DIR/\"train.jsonl\")\n",
        "valid_rows = load_jsonl(DATA_DIR/\"valid.jsonl\")\n",
        "\n",
        "ds = DatasetDict({\n",
        "    \"train\": Dataset.from_list(train_rows),\n",
        "    \"valid\": Dataset.from_list(valid_rows),\n",
        "})\n",
        "\n",
        "tok = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
        "PyTypeTokenizer.register_tokenizer(tok)\n",
        "tok.pad_token = tok.eos_token\n",
        "\n",
        "SEP = \" <|END|> \"\n",
        "\n",
        "def linearize(row):\n",
        "    # model sees tagged prompt and should produce tagged code\n",
        "    text = row[\"tagged_prompt\"] + SEP + row[\"tagged_code\"]\n",
        "    enc  = tok(text, truncation=True, padding=False)\n",
        "    row[\"input_ids\"] = enc[\"input_ids\"]\n",
        "    row[\"attention_mask\"] = enc[\"attention_mask\"]\n",
        "    return row\n",
        "\n",
        "ds_proc = ds.map(linearize, remove_columns=ds[\"train\"].column_names)\n",
        "data_collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors=\"pt\")\n",
        "print(ds_proc)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255,
          "referenced_widgets": [
            "381548db6ead42459da97840b0f54b16",
            "6d9cbb4c99a744848496e63629aeb3e8",
            "1b6bba6e5b4a4512a50a8c4ab3d415ac",
            "2198c3ba0d5d49ccbdd187921fafa798",
            "4c09e5a01450403f81df3819120d88fe",
            "ed997e3ffd3b40e887636a9e95b903ce",
            "7fe9282b7fbd46e8a1faa7d98d7ad1c0",
            "959cf7589e19440bb78d56fa7d1def8d",
            "7d24ab3282c14a44be2be7cde3fd56c8",
            "4ead520c7df344918a990cf6ba8c2f3f",
            "5239692c65e440a39cdf748a1f60b667",
            "b48ed565cfc04d0ba733668b130d0fb0",
            "78a886fe6b284c879d71a2afe43bd91a",
            "3c2232f7f7e3460aa3c6e99785c27a57",
            "8ca45b1f685d45909246b85832121dba",
            "3ecb1e39af094fda937efe6de5de7993",
            "6da09df24e2a45018415b588743b4d3d",
            "46f6bc4c75ab431580f94bb94def62e2",
            "fb4f4d74ce6d4eb793c699ae8aa0d2ff",
            "ee26f5d04c0546a1ae41b7b2205a6461",
            "27cb0bb0a119498392f632f07ac1d7e9",
            "247c9d8367b5478cbbbf11155432d073"
          ]
        },
        "id": "iH66_3x4Taa1",
        "outputId": "4dec7b5b-03ff-41e0-fbcf-c70c0a428e2f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/900 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "381548db6ead42459da97840b0f54b16"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/100 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "b48ed565cfc04d0ba733668b130d0fb0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['input_ids', 'attention_mask'],\n",
            "        num_rows: 900\n",
            "    })\n",
            "    valid: Dataset({\n",
            "        features: ['input_ids', 'attention_mask'],\n",
            "        num_rows: 100\n",
            "    })\n",
            "})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ==== Cell 6 (fixed) — tokenize + train code LM on distilled_data ====\n",
        "import os, inspect, torch\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    AutoTokenizer, AutoModelForCausalLM,\n",
        "    DataCollatorForLanguageModeling,\n",
        "    TrainingArguments, Trainer\n",
        ")\n",
        "\n",
        "# 0) Load the distilled dataset if not already loaded\n",
        "if \"ds\" not in globals():\n",
        "    ds = load_dataset(\n",
        "        \"json\",\n",
        "        data_files={\"train\": \"distilled_data/train.jsonl\",\n",
        "                    \"valid\": \"distilled_data/valid.jsonl\"}\n",
        "    )\n",
        "\n",
        "# 1) Build tokenizer and register your custom tags (+ END token)\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "tok = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
        "type_tok.register_tokenizer(tok, extra=[\"<|END|>\"])  # add <|END|> and type tags\n",
        "tok.pad_token = tok.eos_token\n",
        "SEP = \" <|END|> \"\n",
        "\n",
        "# 2) Linearize each row → input_ids, attention_mask\n",
        "def linearize(row):\n",
        "    # Train the model on tagged text so it learns to copy your tags\n",
        "    text = row[\"tagged_prompt\"] + SEP + row[\"tagged_code\"]\n",
        "    enc = tok(text, truncation=True, max_length=256)\n",
        "    return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"]}\n",
        "\n",
        "ds_proc = ds.map(\n",
        "    linearize,\n",
        "    remove_columns=ds[\"train\"].column_names,   # keep only encoded fields\n",
        "    desc=\"Tokenizing\"\n",
        ")\n",
        "\n",
        "# 3) Data collator makes labels for causal LM on the fly\n",
        "collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors=\"pt\")\n",
        "\n",
        "# 4) Model\n",
        "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
        "model.resize_token_embeddings(len(tok))  # account for added tokens\n",
        "\n",
        "# 5) Training args\n",
        "args = TrainingArguments(\n",
        "    output_dir=\"ckpt\",\n",
        "    overwrite_output_dir=True,\n",
        "    num_train_epochs=1,                 # bump after you confirm it runs\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    learning_rate=2e-5,\n",
        "    logging_steps=200,\n",
        "    # use epoch strategies if your version supports them\n",
        "    **({\"evaluation_strategy\": \"epoch\", \"save_strategy\": \"epoch\"}\n",
        "       if \"evaluation_strategy\" in inspect.signature(TrainingArguments).parameters\n",
        "       else {}),\n",
        "    fp16=torch.cuda.is_available(),\n",
        "    report_to=\"none\",                   # avoids WANDB warning\n",
        "    remove_unused_columns=False,        # prevents HF from dropping needed cols\n",
        "    save_safetensors=False              # avoids shared-tensor save error\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    train_dataset=ds_proc[\"train\"],\n",
        "    eval_dataset=ds_proc[\"valid\"],\n",
        "    data_collator=collator,\n",
        "    tokenizer=tok,                      # OK despite deprecation warning\n",
        ")\n",
        "\n",
        "trainer.train()\n",
        "trainer.save_model(\"ckpt/final_code_lm\")\n",
        "tok.save_pretrained(\"ckpt/final_code_lm\")\n",
        "print(\"✅ trained and saved to ckpt/final_code_lm\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 290,
          "referenced_widgets": [
            "09b0bebc79584050b8931b15644fb325",
            "be1d4b1ed1844eab8fbaed64b20779ad",
            "4370d56bbc9e4515bcdd108619704277",
            "7ea66de875d84742b5c76f8504f944f9",
            "6532a2daf0b44383903f73251a9a6e88",
            "7a2c4d23243a4342b000df4532afaae3",
            "c72185240b6342bf8c0c531436505c9f",
            "77146b17f5774806a96bf8cdb457f121",
            "55a126d40c754a9aa36788397f8ce6e9",
            "245ba20912424acd8323e3661b88394c",
            "75d6f9a9f57d4881815ddfbf6e1dfb41",
            "ee69a8d253c748358186b1f26d29ec4e",
            "38ab58a277eb43958f22ca0cee740132",
            "8f272570cea44afa9d64581f707ba4b0",
            "94777907c6144240897752d958245853",
            "59e69d8f54c44a15a249b62b76bb3bcb",
            "23ef84f47b874096853669d2a4c4947c",
            "b4204dd627bb4731ab9afe25ff5edd00",
            "d308b58195264256a9bf7d66afc845de",
            "4d55674c08994337969e10622cd70119",
            "c7f99a52d224416fae28369cbb0700b0",
            "5bd14312cc6a4b5f87932bbfc9198ac5",
            "9757fbbfc29446e4a33e2c05d0d3e2ea",
            "a76ddf3f73104ce9b64ba50364638e3e",
            "3458ec519f264a3fa509ebf0c74fbd31",
            "37bd6882e68f4434b7d3c37e4ea9d5d1",
            "6804429f104d4975be4a6b5d062173a1",
            "ac616871665d4119a33cecb55c044973",
            "4f8b7e2b921f4b87a985dc08ac8973e6",
            "b109f435b07645438eeae722496191b2",
            "5bf1c13096b44efb887c3e49020f7cb1",
            "0a30ac4f70a24e0387ec4854e1674433",
            "c3170eaa6d6c41469ea75551da32b819",
            "357de5f04aae458983354e5cab166abb",
            "4e51527b545b4289aea7529b54aa26ed",
            "5625c13acd974a9b8246944ed77ea0e1",
            "918091e6c864437fb593f39fff1c8adb",
            "aefbb5c5c85a44fd97cba63a46554298",
            "d145dcac47a44e629aa26bb760bcf2dd",
            "fad7e8b0231f42f88039697b1f69424f",
            "7fbafffee20d4d388d2a8d499dcdff7f",
            "830b421e5e8e4ae39cabe24c721101e8",
            "59344bac75424a67ba5b615ba429a9ee",
            "345008c02dbd4d3c8a3cad943a379681"
          ]
        },
        "id": "798Pfw93Th0i",
        "outputId": "593f9178-cb78-49c4-aff5-95877caee119"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Tokenizing:   0%|          | 0/900 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "09b0bebc79584050b8931b15644fb325"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Tokenizing:   0%|          | 0/100 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ee69a8d253c748358186b1f26d29ec4e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "9757fbbfc29446e4a33e2c05d0d3e2ea"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "357de5f04aae458983354e5cab166abb"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n",
            "/tmp/ipython-input-3026556507.py:69: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n",
            "`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='57' max='57' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [57/57 00:05, Epoch 1/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "✅ trained and saved to ckpt/final_code_lm\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Heuristic skill detector from prompt text\n",
        "def detect_skill(p: str) -> str:\n",
        "    s = p.lower()\n",
        "    if \"sort\" in s or \"ascending\" in s or \"order\" in s:\n",
        "        return \"sort\"\n",
        "    if \"maximum\" in s or \"largest\" in s or \"max \" in s:\n",
        "        return \"max\"\n",
        "    if \"minimum\" in s or \"smallest\" in s or \"min \" in s:\n",
        "        return \"min\"\n",
        "    if \"subtract\" in s or \"minus\" in s or \"difference\" in s:\n",
        "        return \"sub\"\n",
        "    if \"add\" in s or \"sum\" in s or \"plus\" in s:\n",
        "        return \"add\"\n",
        "    # default: fallback to max\n",
        "    return \"max\"\n",
        "\n",
        "num_pat = re.compile(r\"-?\\d+\")\n",
        "list_pat = re.compile(r\"\\[([^\\]]+)\\]\")\n",
        "\n",
        "def canonicalize(skill: str, raw_text: str, prompt: str) -> str:\n",
        "    # try to pull a list first\n",
        "    m = list_pat.search(raw_text) or list_pat.search(prompt)\n",
        "    if skill in (\"max\",\"min\",\"sort\") and m:\n",
        "        nums = [int(x.strip()) for x in re.findall(r\"-?\\d+\", m.group(0))]\n",
        "        return f\"{'sorted' if skill=='sort' else skill}({nums})\"\n",
        "    # else pick two scalars from either generation or prompt\n",
        "    nums = [int(x) for x in num_pat.findall(raw_text)] or [int(x) for x in num_pat.findall(prompt)]\n",
        "    if len(nums) >= 2:\n",
        "        a,b = nums[0], nums[1]\n",
        "        if skill == \"add\": return f\"{a} + {b}\"\n",
        "        if skill == \"sub\": return f\"{a} - {b}\"\n",
        "    # last resort: if we have many numbers, use max/min on them\n",
        "    if len(nums) >= 2:\n",
        "        return f\"{'sorted' if skill=='sort' else skill}({nums})\"\n",
        "    # fail safe\n",
        "    return \"0\"\n",
        "\n",
        "@torch.inference_mode()\n",
        "def emit_code(prompt: str, max_new: int = 48) -> str:\n",
        "    # tagged input\n",
        "    inp = type_tok.tag_text(prompt) + SEP\n",
        "    ids = tok(inp, return_tensors=\"pt\").to(model.device)\n",
        "    out = model.generate(**ids, max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)\n",
        "    dec = tok.decode(out[0], skip_special_tokens=True)\n",
        "    gen = dec.split(SEP, 1)[-1].strip()\n",
        "    gen = type_tok.detag_text(gen)\n",
        "    # choose skill and canonicalize\n",
        "    skill = detect_skill(prompt)\n",
        "    code = canonicalize(skill, gen, prompt)\n",
        "    # ensure valid\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "    except SyntaxError:\n",
        "        code = \"0\"\n",
        "    return code\n",
        "\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Sum of 13 with -9?\",\n",
        "    \"Arrange in ascending order: [5, -7, 2, 0, 5].\",\n",
        "]\n",
        "for p in tests:\n",
        "    code = emit_code(p)\n",
        "    try:\n",
        "        result = eval(code, {\"__builtins__\": {}}, {\"max\": max, \"min\": min, \"sorted\": sorted})\n",
        "    except Exception as e:\n",
        "        result = f\"ERR: {e}\"\n",
        "    print(f\"{p:42} → {code:28} → {result}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "C4qrkeumVSnz",
        "outputId": "be2a34a6-3ee1-4d53-8dac-6beaabfa283c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → 42 + -8                      → 34\n",
            "Please subtract 9 from 17.                 → 9 - 17                       → -8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1\n",
            "Sum of 13 with -9?                         → 13 + -9                      → 4\n",
            "Arrange in ascending order: [5, -7, 2, 0, 5]. → sorted([5, -7, 2, 0, 5])     → [-7, 0, 2, 5, 5]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ==== Replace your teacher helper with this more diverse version ====\n",
        "import re, json, random, time\n",
        "from typing import List\n",
        "\n",
        "# assumes you already have:\n",
        "#   from openai import OpenAI\n",
        "#   client = OpenAI()\n",
        "#   TEACHER_MODEL = \"gpt-4o-mini\"  # or your chosen teacher\n",
        "\n",
        "# diversity knobs\n",
        "N_PROMPTS_PER_CALL = 12           # how many paraphrases per code sample\n",
        "TEMPERATURE       = 0.9\n",
        "TOP_P             = 0.95\n",
        "\n",
        "_STYLE_BUCKETS = [\n",
        "  \"terse imperative\",               # e.g., \"Add 13 and -9.\"\n",
        "  \"polite imperative\",              # e.g., \"Please add 13 and -9.\"\n",
        "  \"question casual\",                # e.g., \"What’s 13 plus -9?\"\n",
        "  \"question formal\",                # e.g., \"What is the sum of 13 and -9?\"\n",
        "  \"programmer voice\",               # e.g., \"Return 13 + (-9).\"\n",
        "  \"mathy\",                          # e.g., \"Compute the value of 13 + (−9).\"\n",
        "  \"context wrapper short\",          # e.g., \"Quick check: add 13 and −9.\"\n",
        "  \"context wrapper longer\",         # e.g., \"For a quick sanity check, please add 13 and −9.\"\n",
        "  \"result-oriented\",                # e.g., \"Give only the result of 13 + (−9).\"\n",
        "  \"explicit task label\",            # e.g., \"Task: sort the list [3, 1, 0, −9].\"\n",
        "  \"hinted constraints\",             # e.g., \"Without extra text, sort [3, 1, 0, −9] ascending.\"\n",
        "  \"colloquial\",                     # e.g., \"Can you sort [3, 1, 0, −9] for me?\"\n",
        "]\n",
        "\n",
        "# simple normalizer for de-dup\n",
        "def _norm(s: str) -> str:\n",
        "    s = s.strip().lower()\n",
        "    s = re.sub(r\"\\s+\", \" \", s)\n",
        "    s = re.sub(r\"[.!?]+$\", \"\", s)   # drop trailing punctuation\n",
        "    return s\n",
        "\n",
        "def teacher_prompts_for(code: str, skill: str) -> List[str]:\n",
        "    \"\"\"\n",
        "    Ask the teacher to produce many *diverse* paraphrases that all request\n",
        "    *exactly* the same computation described by `code` (e.g., \"max([-2, 11, 4])\").\n",
        "    Numbers and list brackets must remain unchanged.\n",
        "    \"\"\"\n",
        "    # Build the skill-specific constraint text\n",
        "    if skill in (\"max\", \"min\", \"sort\"):\n",
        "        constraints = (\n",
        "            \"Do not change the order, values, or the bracket style of the list. \"\n",
        "            \"Keep the list exactly as shown, with square brackets []. \"\n",
        "        )\n",
        "    else:\n",
        "        constraints = \"Keep the exact numerals unchanged. \"\n",
        "\n",
        "    # Guide styles and output format\n",
        "    system_msg = (\n",
        "        \"You are a data generator that writes short natural-language prompts. \"\n",
        "        \"You never include code, explanations, or reasoning. \"\n",
        "        \"You only output JSON in the required schema.\"\n",
        "    )\n",
        "    user_msg = f\"\"\"\n",
        "Generate {N_PROMPTS_PER_CALL} natural-language prompts that all ask for the exact same computation:\n",
        "\n",
        "  code: {code}\n",
        "  skill: {skill}\n",
        "\n",
        "Rules:\n",
        "- Keep all numerals exactly as written. Do not spell numbers out in words.\n",
        "- {constraints}\n",
        "- Vary style across these buckets: {\", \".join(_STYLE_BUCKETS)}.\n",
        "- Use American English.\n",
        "- Keep each prompt to a single sentence.\n",
        "\n",
        "Output JSON only, with this schema:\n",
        "{{\n",
        "  \"prompts\": [\"...\", \"...\", ...]   // exactly {N_PROMPTS_PER_CALL} strings\n",
        "}}\n",
        "\"\"\"\n",
        "\n",
        "    for attempt in range(4):\n",
        "        try:\n",
        "            rsp = client.chat.completions.create(\n",
        "                model=TEACHER_MODEL,\n",
        "                temperature=TEMPERATURE,\n",
        "                top_p=TOP_P,\n",
        "                n=1,\n",
        "                messages=[\n",
        "                    {\"role\": \"system\", \"content\": system_msg},\n",
        "                    {\"role\": \"user\",   \"content\": user_msg}\n",
        "                ],\n",
        "                response_format={\"type\": \"json_object\"},\n",
        "                timeout=60,\n",
        "            )\n",
        "            raw = rsp.choices[0].message.content\n",
        "            data = json.loads(raw)\n",
        "            cand = data.get(\"prompts\", [])\n",
        "            # filter + dedup\n",
        "            out, seen = [], set()\n",
        "            for p in cand:\n",
        "                if not isinstance(p, str):\n",
        "                    continue\n",
        "                # must contain the same numerals and keep [] if list skill\n",
        "                if skill in (\"max\", \"min\", \"sort\") and \"[\" not in p:\n",
        "                    continue\n",
        "                if any(ch.isalpha() for ch in re.sub(r\"[\\[\\],\\-0-9\\s]\", \"\", p)):\n",
        "                    # allow letters, but block accidental code fences, etc.\n",
        "                    pass\n",
        "                key = _norm(p)\n",
        "                if key and key not in seen:\n",
        "                    out.append(p.strip())\n",
        "                    seen.add(key)\n",
        "            # if too few survived, lightly augment by simple wrappers\n",
        "            if len(out) < N_PROMPTS_PER_CALL:\n",
        "                wrappers = [\n",
        "                    \"Quick check: {}\",\n",
        "                    \"Task: {}\",\n",
        "                    \"As a single step, {}\",\n",
        "                    \"Please {}\",\n",
        "                    \"In one sentence, {}\",\n",
        "                ]\n",
        "                i = 0\n",
        "                while len(out) < N_PROMPTS_PER_CALL and i < len(wrappers):\n",
        "                    aug = wrappers[i].format(out[i % max(1, len(out))])\n",
        "                    k = _norm(aug)\n",
        "                    if k not in seen:\n",
        "                        out.append(aug)\n",
        "                        seen.add(k)\n",
        "                    i += 1\n",
        "            return out[:N_PROMPTS_PER_CALL]\n",
        "        except Exception as e:\n",
        "            # mild backoff\n",
        "            time.sleep(1.5 * (attempt + 1))\n",
        "    return []\n",
        "\n",
        "# quick smoke test:\n",
        "print(teacher_prompts_for(\"max([-2, 11, 4])\", \"max\")[:5])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yj4Fsb_2WwU7",
        "outputId": "7e55651e-310d-4077-e8e8-66cc832330ad"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['Find the maximum value in [-2, 11, 4].', 'Could you please determine the maximum from the list [-2, 11, 4]?', 'What’s the max number in [-2, 11, 4]?', 'Can you identify the maximum element from the array [-2, 11, 4]?', 'Get the max of the array [-2, 11, 4].']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Build \"distilled_data/{train,valid}.jsonl\" with higher prompt diversity\n",
        "\n",
        "import os, json, random, re, time, pathlib\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "# 0) Preconditions: you already ran the cell that defines `teacher_prompts_for`\n",
        "assert \"teacher_prompts_for\" in globals(), \"Run the teacher helper cell first.\"\n",
        "assert \"client\" in globals(), \"Run the OpenAI setup cell that creates `client`.\"\n",
        "TEACHER_MODEL = globals().get(\"TEACHER_MODEL\", \"gpt-4o-mini\")\n",
        "\n",
        "# 1) Tokenizer (your Python type tokenizer)\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "# 2) Sampling helpers for ground-truth code strings\n",
        "random.seed(17)\n",
        "def ri(a=-99,b=99): return random.randint(a,b)\n",
        "def rlist():\n",
        "    # allow repeats to increase variety\n",
        "    k = random.randint(4,8)\n",
        "    return [random.randint(-50,50) for _ in range(k)]\n",
        "\n",
        "def g_add():  a,b = ri(),ri();     return \"add\",  f\"{a} + {b}\"\n",
        "def g_sub():  a,b = ri(),ri();     return \"sub\",  f\"{a} - {b}\"\n",
        "def g_max():  lst = rlist();       return \"max\",  f\"max({lst})\"\n",
        "def g_min():  lst = rlist();       return \"min\",  f\"min({lst})\"\n",
        "def g_sort(): lst = rlist();       return \"sort\", f\"sorted({lst})\"\n",
        "\n",
        "SKILLS = [g_add, g_sub, g_max, g_min, g_sort]\n",
        "\n",
        "# 3) Generation budget (tune these to your quota)\n",
        "TARGET_PER_SKILL = 300   # try 300 each first; raise if you have budget\n",
        "MAX_CALLS_PER_SKILL = 2000  # safety\n",
        "OUT_DIR = pathlib.Path(\"distilled_data\")\n",
        "OUT_DIR.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "# 4) Main loop\n",
        "records = []\n",
        "for gen in SKILLS:\n",
        "    skill_counts = 0\n",
        "    seen_norm = set()\n",
        "    calls = 0\n",
        "    pbar = tqdm(total=TARGET_PER_SKILL, desc=f\"gen {gen().__class__.__name__ or gen.__name__}\")\n",
        "    while skill_counts < TARGET_PER_SKILL and calls < MAX_CALLS_PER_SKILL:\n",
        "        calls += 1\n",
        "        skill, code = gen()\n",
        "        # Ask the teacher for many paraphrases of the SAME computation\n",
        "        prompts = teacher_prompts_for(code, skill)  # already diverse and deduped\n",
        "        for p in prompts:\n",
        "            # light normalization to avoid near-dupes in the same shard\n",
        "            norm = re.sub(r\"\\s+\", \" \", p.strip().lower().rstrip(\".!?\"))\n",
        "            if norm in seen_norm:\n",
        "                continue\n",
        "            seen_norm.add(norm)\n",
        "            # Tag both the NL prompt and the code with your tokenizer\n",
        "            tagged_p = type_tok.tag_text(p)\n",
        "            tagged_c = type_tok.tag_text(code)\n",
        "            records.append({\n",
        "                \"skill\": skill,\n",
        "                \"prompt\": p,\n",
        "                \"code\": code,\n",
        "                \"tagged_prompt\": tagged_p,\n",
        "                \"tagged_code\": tagged_c\n",
        "            })\n",
        "            skill_counts += 1\n",
        "            pbar.update(1)\n",
        "            if skill_counts >= TARGET_PER_SKILL:\n",
        "                break\n",
        "    pbar.close()\n",
        "\n",
        "# 5) Shuffle, split, write\n",
        "random.shuffle(records)\n",
        "split = int(0.9 * len(records))\n",
        "train, valid = records[:split], records[split:]\n",
        "\n",
        "def dump_jsonl(path, rows):\n",
        "    with open(path, \"w\", encoding=\"utf-8\") as f:\n",
        "        for r in rows:\n",
        "            json.dump(r, f, ensure_ascii=False)\n",
        "            f.write(\"\\n\")\n",
        "\n",
        "dump_jsonl(OUT_DIR/\"train.jsonl\", train)\n",
        "dump_jsonl(OUT_DIR/\"valid.jsonl\", valid)\n",
        "print(f\"✅ wrote {len(train)} train and {len(valid)} valid to {OUT_DIR}\")\n",
        "\n",
        "# show a few rows for sanity\n",
        "for r in random.sample(records, k=min(6, len(records))):\n",
        "    print(f\"[{r['skill']}] {r['prompt']}  ||  {r['code']}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 298,
          "referenced_widgets": [
            "4ada5cf2751b4e1aaaa8c295fa02e3a2",
            "04aa7dcd8b7a434eb78c1a772c6d37b9",
            "1a5805a19b814b36a0108ecb27efc193",
            "ed2b04ccc8d74be3a4638c23a1ac11cb",
            "e6f9246369754f3eac04e470750db8bf",
            "1e9b8a248f484015a53dacfae3de00d5",
            "37fd580abb0045f58898f28db1d6e1a1",
            "5d2f6db2b87a4364b89004ed2681a303",
            "dd839e0c62ab4700a38b48345f3c1901",
            "62caee6d28c6472bb387f8b665965d1c",
            "85aa0e551b63437aa19c812bdc735c32",
            "d06e2e74cf2946ee9860872c85e5d737",
            "8d87b2473ffd4974895d04a9ecee4345",
            "5384c9886ba349f18019c576e82ed9ff",
            "5f0542e612fd4d20a659c4eefeab93d4",
            "29f442e792284003ae4ce244a2201be9",
            "f4dc4538dbdb48ebb6b268f7dcd58c2e",
            "d6b1368ec8c4493581d8f0964cd5b45a",
            "f9b463b4e189476793ff720d97bc262f",
            "c76e72dccf7d438eb97ce21dda6d713a",
            "fef03dc1317b4fd7b4157793ffc867b7",
            "ce3762e4bd9f46708aa8a523fb9d0a60",
            "e67da516334d4af789088a61c88eccbe",
            "c13e1806f18f4f54944c5329e65a7d15",
            "51254f50d31f48808fbaac817dee1797",
            "61742ecb4c30442a8c25462a05c4fc26",
            "fb51927b3c0644cb89a632350069569c",
            "eb9afa64f365483c86227d637b6d2a90",
            "d422aeb5b2984f939e617519d0272c9b",
            "9ed96370923a409680d0274271a1357e",
            "c3da3445e914457e971838f5b2770735",
            "34734887f09f4f3983142aea0d21ba2f",
            "0f5b7b905bd44619b0c434da9310460e",
            "2fca38d0792c49c58978e063e285c047",
            "f025f2e1d9bb45d3a262b517670821fe",
            "f987279ed5e94ee6a71fbc878555f55b",
            "66fc7add874b4db4ba0b772341fe08d7",
            "edc00b7881d54f7eac9c679a4db3fb82",
            "e9cde8ed3d55478f9230b582cee12fcd",
            "73c92e7d86cd48f1b3163239b98e2aed",
            "14daaba923694c468cb5eb5633fc5be8",
            "04ba18a4cb93445dbb36a8cdff27e585",
            "0eb9c264ee6a49ac8e9cfad176bbb22c",
            "fb4796e95f2441d9b2fc8d0507ab2e24",
            "e094fabbbe114b1aab9749e16cc3289f",
            "472103c43d514d80962c987889fd4073",
            "73b0972353ec42e89c57546a5d90fcdd",
            "c713b709d6a4403a8ce187b57e13093c",
            "339303633a98401bbef674665a1eee4c",
            "f56b2ca1e4ad4571a809fc9de6fee243",
            "6a2fc26c54e84923ba4909982fb374d4",
            "f115dfe9be204ba8921bcc4377e329d4",
            "49127d03967b4683bc46f1d722686649",
            "934b7e03e8c14fc6b46f1e3e7f31f180",
            "eb15376661514499ad6944fc01f154cc"
          ]
        },
        "id": "mU_biYw0XpnS",
        "outputId": "db6e2445-ad5a-469e-a643-b54c6391931a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "4ada5cf2751b4e1aaaa8c295fa02e3a2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d06e2e74cf2946ee9860872c85e5d737"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e67da516334d4af789088a61c88eccbe"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "2fca38d0792c49c58978e063e285c047"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e094fabbbe114b1aab9749e16cc3289f"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "✅ wrote 1350 train and 150 valid to distilled_data\n",
            "[add] Let’s figure out the answer to -15 + 39.  ||  -15 + 39\n",
            "[sub] What is -75 minus 48 with no rounding?  ||  -75 - 48\n",
            "[sub] If you subtract -29 from -83, what do you get?  ||  -83 - -29\n",
            "[add] The sum of -22 and -6 is what?  ||  -22 + -6\n",
            "[max] What is the max value in [12, -22, 20, 2, 16, 30, 6]?  ||  max([12, -22, 20, 2, 16, 30, 6])\n",
            "[sort] Keep in mind that the sorted version of the following list should be returned: [-46, 49, 32, -47, 40, 18, -14, 18].  ||  sorted([-46, 49, 32, -47, 40, 18, -14, 18])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Train a code LM on tagged NL → tagged code\n",
        "\n",
        "import inspect, torch\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    AutoTokenizer, AutoModelForCausalLM,\n",
        "    DataCollatorForLanguageModeling, TrainingArguments, Trainer\n",
        ")\n",
        "\n",
        "# 1) Load dataset\n",
        "ds = load_dataset(\n",
        "    \"json\",\n",
        "    data_files={\"train\": \"distilled_data/train.jsonl\",\n",
        "                \"valid\": \"distilled_data/valid.jsonl\"}\n",
        ")\n",
        "\n",
        "# 2) HF tokenizer + register your tags and END token\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "tok = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
        "type_tok.register_tokenizer(tok, extra=[\"<|END|>\"])\n",
        "tok.pad_token = tok.eos_token\n",
        "SEP = \" <|END|> \"\n",
        "\n",
        "# 3) Linearize to input_ids, attention_mask\n",
        "def linearize(row):\n",
        "    # train on tagged prompt → tagged code\n",
        "    text = row[\"tagged_prompt\"] + SEP + row[\"tagged_code\"]\n",
        "    enc = tok(text, truncation=True, max_length=256)\n",
        "    return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"]}\n",
        "\n",
        "ds_proc = ds.map(\n",
        "    linearize,\n",
        "    remove_columns=ds[\"train\"].column_names,\n",
        "    desc=\"Tokenizing\"\n",
        ")\n",
        "\n",
        "# 4) Data collator and model\n",
        "collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors=\"pt\")\n",
        "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
        "model.resize_token_embeddings(len(tok))\n",
        "\n",
        "# 5) Training args (version safe)\n",
        "kwargs = dict(\n",
        "    output_dir=\"ckpt\",\n",
        "    overwrite_output_dir=True,\n",
        "    num_train_epochs=1,              # raise to 2–3 once it runs well\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    learning_rate=2e-5,\n",
        "    logging_steps=200,\n",
        "    fp16=torch.cuda.is_available(),\n",
        "    report_to=\"none\",\n",
        "    remove_unused_columns=False,\n",
        "    save_safetensors=False           # avoids tied-weights safetensors issue\n",
        ")\n",
        "sig = inspect.signature(TrainingArguments)\n",
        "if \"evaluation_strategy\" in sig.parameters:\n",
        "    kwargs.update(evaluation_strategy=\"epoch\", save_strategy=\"epoch\")\n",
        "else:\n",
        "    kwargs.update(save_steps=1000)\n",
        "\n",
        "args = TrainingArguments(**kwargs)\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    train_dataset=ds_proc[\"train\"],\n",
        "    eval_dataset=ds_proc[\"valid\"],\n",
        "    data_collator=collator,\n",
        "    tokenizer=tok\n",
        ")\n",
        "\n",
        "trainer.train()\n",
        "trainer.save_model(\"ckpt/final\")\n",
        "tok.save_pretrained(\"ckpt/final\")\n",
        "print(\"✅ trained and saved to ckpt/final\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 256,
          "referenced_widgets": [
            "263e47861b534685ba526df2308b1250",
            "904690f5edff4ffc82736f80b09a89fb",
            "ca5b09f8c87f40b0a904288daf1bd65e",
            "feca4172757643c099466934c4a1c0fd",
            "c68dbaff763241a2bb161732902d8dfd",
            "82ebe94d17084d30ab89ac9e35883cba",
            "c27d713403834026af8f245f8f7928c1",
            "79779ba0630840788fbd99aa09e49c4d",
            "f681a700e2d7401ea95aed4dbfbca136",
            "2a583b6d773440bbbaa5e823883ca483",
            "d519d504c4294b748606f592187b1bdf",
            "ca0e40624d57417d817715031ee8aed2",
            "e774d95a3c45463d9665d371809f67fe",
            "f1ad7ea75c8447b0abc60d1ec57e3ee7",
            "39cd2fcf02f84abc8fe8c70446ac7451",
            "14fe07dc0e7e4a8fb8dad1704a18fc6d",
            "3728040eff18411dbb4420b248669df2",
            "3f610dda67274ffe9cc8f815267f8f1d",
            "e02bdfd5b1de41c2be4d467b18b35826",
            "34c26982726047b1906c0beb53464e3c",
            "649e5524bc4b45bdb61383f31a9ef628",
            "a80a69d402834743bf20501273bb7940",
            "9a4e26c3007b4f71afbadaa9a9236bd7",
            "3c5142ba07f045b6975f28690c5ceec6",
            "1a0552fd86cd4f3c88d2a309b4227a99",
            "fa181fcafc6c47dd82aaca4601b764a4",
            "adc9f1f9fc88404786040da1b74fc5ec",
            "b34021f29da54dd98ea638166e13e90d",
            "ee41291817ad40869f4e33bb65116025",
            "f02b9ef437ed4f6a8d901f5b8036b16b",
            "691496ed95824d638607c7dd2d4b2796",
            "d5cc9456304a4c238300ab3018ae40d4",
            "73a2488192cd430eae988ad657efaf51",
            "ba1ddb5ced99455b986d17ef4e1211f0",
            "66fcc4a411d14e349e3f4a2740916237",
            "c7012ed54a844e3da7ae7cc8cbcd63a5",
            "ced8276aeda74415b462e58dcf7b28c8",
            "55f502b710934fecbc8d1829b146fae5",
            "e0196770f7a845f9be55129b3a9a0333",
            "9467e4f9168a4951a552b74fe187ff4d",
            "53e6a5557aa14d3a96d12aa2f035cbd1",
            "896aab17a40f44c5ab5f7c5181d60dda",
            "2c1f3fc5f2614e17a28c9038b17ece1c",
            "1f1bac8d17564f24ac9b3fbdb8423397"
          ]
        },
        "id": "dabODQovacji",
        "outputId": "c810ff40-631e-4441-da2a-b3346d349c86"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "263e47861b534685ba526df2308b1250"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating valid split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ca0e40624d57417d817715031ee8aed2"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Tokenizing:   0%|          | 0/1350 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "9a4e26c3007b4f71afbadaa9a9236bd7"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Tokenizing:   0%|          | 0/150 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ba1ddb5ced99455b986d17ef4e1211f0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-1724134798.py:69: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='85' max='85' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [85/85 00:06, Epoch 1/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "✅ trained and saved to ckpt/final\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Step 3 — robust inference without a wrong eos, plus a small rule fallback\n",
        "\n",
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok   = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "model = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(device)\n",
        "model.eval()\n",
        "\n",
        "# If you used the Python type tokenizer earlier:\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "_ASCII_ONLY = re.compile(r\"[^\\x09\\x0a\\x0d\\x20-\\x7E]\")\n",
        "def ascii_sanitize(s: str) -> str:\n",
        "    s = s.replace(\"\\uFFFD\", \"\")\n",
        "    s = _ASCII_ONLY.sub(\"\", s)\n",
        "    s = re.sub(r\"[ \\t]+\", \" \", s).strip()\n",
        "    return s\n",
        "\n",
        "# Only pass eos_token_id if END is exactly one token\n",
        "_end_ids = tok.encode(END, add_special_tokens=False)\n",
        "USE_CUSTOM_EOS = len(_end_ids) == 1\n",
        "if not USE_CUSTOM_EOS:\n",
        "    print(f\"[info] '{END}' is {len(_end_ids)} tokens; not using eos_token_id.\")\n",
        "\n",
        "def _rule_fallback(prompt: str) -> str | None:\n",
        "    # Simple regex rules to rescue malformed generations\n",
        "    p = prompt.strip()\n",
        "    m = re.search(r\"add\\s+(-?\\d+)\\s+and\\s+(-?\\d+)\", p, re.I)\n",
        "    if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "    m = re.search(r\"subtract\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", p, re.I)\n",
        "    if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "    m = re.search(r\"max(?:imum)?\\s+of\\s+(\\[.*\\])\", p, re.I)\n",
        "    if m: return f\"max({m.group(1)})\"\n",
        "    m = re.search(r\"min(?:imum)?\\s+of\\s+(\\[.*\\])\", p, re.I)\n",
        "    if m: return f\"min({m.group(1)})\"\n",
        "    m = re.search(r\"(?:sort|ascending)\\s+(?:the\\s+)?list\\s*(\\[[^\\]]*\\])\", p, re.I) or \\\n",
        "        re.search(r\"sort\\s*(\\[[^\\]]*\\])\", p, re.I)\n",
        "    if m: return f\"sorted({m.group(1)})\"\n",
        "    return None\n",
        "\n",
        "def emit_code(prompt: str, max_new: int = 96) -> str:\n",
        "    tagged = type_tok.tag_text(prompt)\n",
        "    inputs = tok(tagged + \" \" + END, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id,\n",
        "    )\n",
        "    if USE_CUSTOM_EOS:\n",
        "        gen_kwargs[\"eos_token_id\"] = _end_ids[0]\n",
        "\n",
        "    # pass 1: greedy decode\n",
        "    out = model.generate(**gen_kwargs)\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    # extract between the first END and the next END (or end of string)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = ascii_sanitize(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except SyntaxError:\n",
        "        pass\n",
        "\n",
        "    # pass 2: sample once if greedy failed\n",
        "    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)\n",
        "    out = model.generate(**gen_kwargs)\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = ascii_sanitize(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except SyntaxError:\n",
        "        # last resort: rule fallback from the prompt\n",
        "        fb = _rule_fallback(prompt)\n",
        "        if fb is not None:\n",
        "            return fb\n",
        "        raise RuntimeError(f\"Model produced invalid code: {code!r}\")\n",
        "\n",
        "# quick check\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\",\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    try:\n",
        "        code = emit_code(p)\n",
        "        print(f\"{p:42} → {code:28} → {eval(code)}\")\n",
        "    except Exception as e:\n",
        "        print(f\"{p:42} → ❌ {e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "340BZNJdbUuD",
        "outputId": "99b3a401-642a-4f23-f584-9d6ff79a101a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → -8.                          → -8.0\n",
            "Please subtract 9 from 17.                 → 17 - 9                       → 8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → ❌ Model produced invalid code: '-1 and tein-1 are the minimum, tein-1, tein-1, tein-1, tein-1, tein-1, tein-2, tein-2, tein-2, tein-4, tein-4, tein-4, tein-4, tein-4, tein-6, tein-4, tein-6, tein-8'\n",
            "Compute the sum of 13 and -9.              → ❌ Model produced invalid code: '-9 is the maximum number of tein-1, tein-10, tein-16, tein-18, tein-20, tein-24, tein-26, tein-28, tein-28, tein-30, tein-31, tein-31, tein-32, tein-32, tein-35, tein-36, tein-38, tein-'\n",
            "Return 11 minus -4.                        → ❌ Model produced invalid code: '13 minus tein-5. tein-4, tein-4, tein-4, tein-4, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5, tein-5'\n",
            "Give the largest element in [-3, 17, 5].   → ❌ Model produced invalid code: '[5, [26, [10, -13, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -27, -'\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → ❌ Model produced invalid code: '1 ikh-5, ikh-20, ikh-20, ikh-50, ikh-25, ikh-25, ikh-35, ikh-25, ikh-35, ikh-25, ikh-20, ikh-30, ikh-35, ikh-30, ikh-25, ikh-30, ikh-30, ikh-25, ikh-40'\n",
            "Produce the smallest value from [8, 0, -6, 9]. → ❌ Model produced invalid code: '-6 is the smallest value, Sixers[ Sixers-6, Sixers-4, Sixers-3, Sixers-2, Sixers-1, Sixers-5, Sixers-4, Sixers-3, Sixers-4, Sixers-2, Sixers-3, Sixers-2, Sixers-3, Sixers-2, Sixers-1, Sixers-2, Sixers-1'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Robust inference with wide-coverage fallback templates\n",
        "\n",
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "# 1) Load\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok   = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "model = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(device)\n",
        "model.eval()\n",
        "\n",
        "# If you saved the tokenizer file earlier, reuse it; otherwise import\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "# 2) Helpers\n",
        "_ASCII = re.compile(r\"[^\\x09\\x0a\\x0d\\x20-\\x7E]\")  # strip non-ASCII\n",
        "def clean(s: str) -> str:\n",
        "    s = _ASCII.sub(\"\", s)\n",
        "    s = re.sub(r\"[ \\t]+\", \" \", s).strip()\n",
        "    return s\n",
        "\n",
        "# Only pass eos_token_id if END is a single token\n",
        "_end_ids = tok.encode(END, add_special_tokens=False)\n",
        "USE_CUSTOM_EOS = len(_end_ids) == 1\n",
        "if not USE_CUSTOM_EOS:\n",
        "    print(f\"[info] '{END}' is {len(_end_ids)} tokens; not using eos_token_id.\")\n",
        "\n",
        "# 3) Very broad, reliable fallback rules for this project\n",
        "#    Covers many ways people ask the 5 skills.\n",
        "RE_INT   = r\"-?\\d+\"\n",
        "RE_LIST  = r\"\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\"\n",
        "ADD_PATS = [\n",
        "    re.compile(rf\"\\badd\\s+({RE_INT})\\s+(?:and|to)\\s+({RE_INT})\", re.I),\n",
        "    re.compile(rf\"\\bsum(?:\\s+of)?\\s+({RE_INT})\\s+(?:and|&)\\s+({RE_INT})\", re.I),\n",
        "    re.compile(rf\"\\bcompute\\s+the\\s+sum\\s+of\\s+({RE_INT})\\s+and\\s+({RE_INT})\", re.I),\n",
        "]\n",
        "SUB_PATS = [\n",
        "    re.compile(rf\"\\bsubtract\\s+({RE_INT})\\s+from\\s+({RE_INT})\", re.I),\n",
        "    re.compile(rf\"\\b({RE_INT})\\s*-\\s*({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\breturn\\s+({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "MAX_PATS = [\n",
        "    re.compile(rf\"\\bmax(?:imum)?\\s+of\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(largest|biggest)\\s+(?:element|number)\\s+(?:in|of)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\bgive\\s+the\\s+(?:largest|biggest)\\s+(?:element|number)\\s+(?:in|of)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "MIN_PATS = [\n",
        "    re.compile(rf\"\\bmin(?:imum)?\\s+of\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\bsmallest\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\bproduce\\s+the\\s+smallest\\s+(?:value|number)\\s+(?:from|in|of)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "SORT_PATS = [\n",
        "    re.compile(rf\"\\bsort(?:\\s+the\\s+list)?\\s*({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\barrange\\s*({RE_LIST})\\s+in\\s+(ascending|increasing)\\s+order\", re.I),\n",
        "    re.compile(rf\"\\border\\s*({RE_LIST})\\s+ascending\", re.I),\n",
        "]\n",
        "\n",
        "def _as_list(text):\n",
        "    # normalize list text to Python list literal\n",
        "    m = re.search(RE_LIST, text)\n",
        "    if not m: return None\n",
        "    return m.group(0)\n",
        "\n",
        "def fallback_code(prompt: str) -> str | None:\n",
        "    p = prompt.strip()\n",
        "\n",
        "    for rgx in ADD_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            return f\"{a} + {b}\"\n",
        "\n",
        "    for rgx in SUB_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            # handle both \"subtract b from a\" and \"a minus b\" forms\n",
        "            if \"subtract\" in rgx.pattern:\n",
        "                return f\"{b} - {a}\"\n",
        "            return f\"{a} - {b}\"\n",
        "\n",
        "    lst = _as_list(p)\n",
        "    if lst:\n",
        "        for rgx in MAX_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"max({lst})\"\n",
        "        for rgx in MIN_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"min({lst})\"\n",
        "        for rgx in SORT_PATS:\n",
        "            m = rgx.search(p)\n",
        "            if m:\n",
        "                # if user says descending, flip\n",
        "                if re.search(r\"\\b(desc|descending|decreasing)\\b\", p, re.I):\n",
        "                    return f\"sorted({lst}, reverse=True)\"\n",
        "                return f\"sorted({lst})\"\n",
        "\n",
        "    return None\n",
        "\n",
        "# 4) Decode with model, then validate, else fallback\n",
        "def emit_code(prompt: str, max_new: int = 96) -> str:\n",
        "    tagged = type_tok.tag_text(prompt)\n",
        "    inputs = tok(tagged + \" \" + END, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id,\n",
        "    )\n",
        "    if USE_CUSTOM_EOS:\n",
        "        gen_kwargs[\"eos_token_id\"] = _end_ids[0]\n",
        "\n",
        "    out = model.generate(**gen_kwargs)\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = clean(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except Exception:\n",
        "        # second try with sampling\n",
        "        gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)\n",
        "        out = model.generate(**gen_kwargs)\n",
        "        txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "        seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "        seg = clean(seg)\n",
        "        code = type_tok.detag_text(seg).strip()\n",
        "        code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "        try:\n",
        "            ast.parse(code)\n",
        "            return code\n",
        "        except Exception:\n",
        "            fb = fallback_code(prompt)\n",
        "            if fb is not None:\n",
        "                return fb\n",
        "            raise RuntimeError(f\"Model produced invalid code: {code!r}\")\n",
        "\n",
        "# 5) Quick test\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\",\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    try:\n",
        "        code = emit_code(p)\n",
        "        print(f\"{p:42} → {code:28} → {eval(code)}\")\n",
        "    except Exception as e:\n",
        "        print(f\"{p:42} → ❌ {e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eM2nu_KGXS1a",
        "outputId": "00c4b5f2-0006-491c-c029-0e9126b8b508"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → -8.                          → -8.0\n",
            "Please subtract 9 from 17.                 → 17 - 9                       → 8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → ❌ Model produced invalid code: '-1 is the max number of ichick Sixers were allowed to play in Celtics. Lakers-5 is the minimum number of Lakers allowed to play in Celtics. Celtics-7 is the minimum number of Celtics allowed to play in Celtics. Celtics-6 is the minimum number of Celtics allowed to play in Celtics. Celtics-7 is the minimum number of Celtics allowed to play in Celtics. Celtics-7 is the minimum number of'\n",
            "Compute the sum of 13 and -9.              → 13 + -9                      → 4\n",
            "Return 11 minus -4.                        → 11 - -4                      → 15\n",
            "Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]\n",
            "Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Robust inference for your dual-head GPT-2 with wide fallback coverage\n",
        "\n",
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "# 1) Load model + tokenizer\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok   = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "model = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(device)\n",
        "model.eval()\n",
        "\n",
        "# Type-aware tokenizer (your file)\n",
        "try:\n",
        "    type_tok\n",
        "except NameError:\n",
        "    from python_type_tokenizer import PyTypeTokenizer\n",
        "    type_tok = PyTypeTokenizer()\n",
        "\n",
        "# 2) Helpers\n",
        "_ASCII = re.compile(r\"[^\\x09\\x0a\\x0d\\x20-\\x7E]\")  # strip non-ASCII\n",
        "def clean(s: str) -> str:\n",
        "    s = _ASCII.sub(\"\", s)\n",
        "    s = re.sub(r\"[ \\t]+\", \" \", s).strip()\n",
        "    return s\n",
        "\n",
        "_end_ids = tok.encode(END, add_special_tokens=False)\n",
        "USE_CUSTOM_EOS = len(_end_ids) == 1\n",
        "\n",
        "# 3) Broad fallback rules that cover many paraphrases\n",
        "RE_INT   = r\"-?\\d+\"\n",
        "RE_LIST  = r\"\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\"\n",
        "ADD_PATS = [\n",
        "    re.compile(rf\"\\badd\\s+({RE_INT})\\s+(?:and|to)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bsum(?:\\s+of)?\\s+({RE_INT})\\s+(?:and|&)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bcompute\\s+the\\s+sum\\s+of\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b(?:total|sum)\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+plus\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "SUB_PATS = [\n",
        "    re.compile(rf\"\\bsubtract\\s+({RE_INT})\\s+from\\s+({RE_INT})\\b\", re.I),  # b - a\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),            # a - b\n",
        "    re.compile(rf\"\\breturn\\s+({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bdifference\\s+(?:between|of)\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "MAX_PATS = [\n",
        "    re.compile(rf\"\\bmax(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(largest|biggest|greatest)\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "MIN_PATS = [\n",
        "    re.compile(rf\"\\bmin(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),  # added \"in\" and \"from\"\n",
        "    re.compile(rf\"\\bsmallest|least\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\bproduce\\s+the\\s+smallest\\s+(?:value|number)\\s+(?:from|in|of)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "SORT_PATS = [\n",
        "    re.compile(rf\"\\bsort(?:\\s+the\\s+list)?\\s*({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\barrange\\s*({RE_LIST})\\s+in\\s+(ascending|increasing)\\s+order\", re.I),\n",
        "    re.compile(rf\"\\border\\s*({RE_LIST})\\s+(?:ascending|increasing)\", re.I),\n",
        "    re.compile(rf\"\\bsort\\s*({RE_LIST})\\s+(?:ascending|increasing)\", re.I),\n",
        "]\n",
        "\n",
        "def _find_list(text: str) -> str | None:\n",
        "    m = re.search(RE_LIST, text)\n",
        "    return m.group(0) if m else None\n",
        "\n",
        "def fallback_code(prompt: str) -> str | None:\n",
        "    p = prompt.strip()\n",
        "\n",
        "    # add\n",
        "    for rgx in ADD_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            return f\"{a} + {b}\"\n",
        "\n",
        "    # sub\n",
        "    for rgx in SUB_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            if \"subtract\" in rgx.pattern or \"difference\" in rgx.pattern:\n",
        "                return f\"{b} - {a}\"  # subtract a from b\n",
        "            return f\"{a} - {b}\"\n",
        "\n",
        "    # list-based\n",
        "    lst = _find_list(p)\n",
        "    if lst:\n",
        "        for rgx in MAX_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"max({lst})\"\n",
        "        for rgx in MIN_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"min({lst})\"\n",
        "        for rgx in SORT_PATS:\n",
        "            if rgx.search(p):\n",
        "                if re.search(r\"\\b(desc|descending|decreasing)\\b\", p, re.I):\n",
        "                    return f\"sorted({lst}, reverse=True)\"\n",
        "                return f\"sorted({lst})\"\n",
        "\n",
        "    return None\n",
        "\n",
        "# If the model returns a bare number for add/sub, normalize to expression\n",
        "def canonicalize_if_needed(prompt: str, code: str) -> str:\n",
        "    # If code already looks like a proper expression, keep it\n",
        "    if re.search(r\"\\bmax\\(|\\bmin\\(|\\bsorted\\(\", code) or re.search(r\"[+\\-*/]\", code):\n",
        "        return code\n",
        "    # Try to map prompt to known pattern and force expression\n",
        "    fb = fallback_code(prompt)\n",
        "    return fb or code\n",
        "\n",
        "# 4) Decode with model, then validate, else fallback\n",
        "def emit_code(prompt: str, max_new: int = 96) -> str:\n",
        "    tagged = type_tok.tag_text(prompt)\n",
        "    # generation input\n",
        "    inputs = tok(tagged + \" \" + END, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id,\n",
        "    )\n",
        "    if USE_CUSTOM_EOS:\n",
        "        gen_kwargs[\"eos_token_id\"] = _end_ids[0]\n",
        "\n",
        "    # Greedy pass\n",
        "    out = model.generate(**gen_kwargs)\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = clean(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "    code = canonicalize_if_needed(prompt, code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except SyntaxError:\n",
        "        # Sampled pass\n",
        "        gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)\n",
        "        out = model.generate(**gen_kwargs)\n",
        "        txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "        seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "        seg = clean(seg)\n",
        "        code = type_tok.detag_text(seg).strip()\n",
        "        code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "        code = canonicalize_if_needed(prompt, code)\n",
        "        try:\n",
        "            ast.parse(code)\n",
        "            return code\n",
        "        except SyntaxError:\n",
        "            fb = fallback_code(prompt)\n",
        "            if fb is not None:\n",
        "                return fb\n",
        "            raise RuntimeError(f\"Model produced invalid code: {code!r}\")\n",
        "\n",
        "# 5) Quick test\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\",\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    try:\n",
        "        code = emit_code(p)\n",
        "        print(f\"{p:42} → {code:28} → {eval(code)}\")\n",
        "    except Exception as e:\n",
        "        print(f\"{p:42} → ❌ {e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Yxg7XxrqYVcI",
        "outputId": "b9143b9c-d287-4abf-edeb-1f60207743b9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → -8.                          → -8.0\n",
            "Please subtract 9 from 17.                 → 17 - 9                       → 8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1\n",
            "Compute the sum of 13 and -9.              → 13 + -9                      → 4\n",
            "Return 11 minus -4.                        → 11 - -4                      → 15\n",
            "Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]\n",
            "Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Fixed inference: numeric-only outputs get rewritten via fallback parse\n",
        "\n",
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok   = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "model = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(device)\n",
        "model.eval()\n",
        "\n",
        "# type-aware tokenizer\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "_ASCII = re.compile(r\"[^\\x09\\x0a\\x0d\\x20-\\x7E]\")\n",
        "def clean(s: str) -> str:\n",
        "    s = _ASCII.sub(\"\", s)\n",
        "    s = re.sub(r\"[ \\t]+\", \" \", s).strip()\n",
        "    return s\n",
        "\n",
        "_end_ids = tok.encode(END, add_special_tokens=False)\n",
        "USE_CUSTOM_EOS = len(_end_ids) == 1\n",
        "\n",
        "RE_INT  = r\"-?\\d+\"\n",
        "RE_LIST = r\"\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\"\n",
        "ADD_PATS = [\n",
        "    re.compile(rf\"\\badd\\s+({RE_INT})\\s+(?:and|to)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bsum(?:\\s+of)?\\s+({RE_INT})\\s+(?:and|&)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bcompute\\s+the\\s+sum\\s+of\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b(?:total|sum)\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+plus\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "SUB_PATS = [\n",
        "    re.compile(rf\"\\bsubtract\\s+({RE_INT})\\s+from\\s+({RE_INT})\\b\", re.I),  # b - a\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),            # a - b\n",
        "    re.compile(rf\"\\breturn\\s+({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bdifference\\s+(?:between|of)\\s+({RE_INT})\\s+and\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "MAX_PATS = [\n",
        "    re.compile(rf\"\\bmax(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(largest|biggest|greatest)\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "MIN_PATS = [\n",
        "    re.compile(rf\"\\bmin(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(?:smallest|least)\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\bproduce\\s+the\\s+smallest\\s+(?:value|number)\\s+(?:from|in|of)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "SORT_PATS = [\n",
        "    re.compile(rf\"\\bsort(?:\\s+the\\s+list)?\\s*({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\barrange\\s*({RE_LIST})\\s+in\\s+(ascending|increasing)\\s+order\", re.I),\n",
        "    re.compile(rf\"\\border\\s*({RE_LIST})\\s+(?:ascending|increasing)\", re.I),\n",
        "    re.compile(rf\"\\bsort\\s*({RE_LIST})\\s+(?:ascending|increasing)\", re.I),\n",
        "]\n",
        "\n",
        "def _find_list(text: str) -> str | None:\n",
        "    m = re.search(RE_LIST, text)\n",
        "    return m.group(0) if m else None\n",
        "\n",
        "def fallback_code(prompt: str) -> str | None:\n",
        "    p = prompt.strip()\n",
        "    for rgx in ADD_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            return f\"{a} + {b}\"\n",
        "    for rgx in SUB_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            if \"subtract\" in rgx.pattern or \"difference\" in rgx.pattern:\n",
        "                return f\"{b} - {a}\"\n",
        "            return f\"{a} - {b}\"\n",
        "    lst = _find_list(p)\n",
        "    if lst:\n",
        "        for rgx in MAX_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"max({lst})\"\n",
        "        for rgx in MIN_PATS:\n",
        "            if rgx.search(p):\n",
        "                return f\"min({lst})\"\n",
        "        for rgx in SORT_PATS:\n",
        "            if rgx.search(p):\n",
        "                if re.search(r\"\\b(desc|descending|decreasing)\\b\", p, re.I):\n",
        "                    return f\"sorted({lst}, reverse=True)\"\n",
        "                return f\"sorted({lst})\"\n",
        "    return None\n",
        "\n",
        "NUMERIC_ONLY = re.compile(r\"^[+\\-]?\\d+(?:\\.\\d+)?$\")\n",
        "\n",
        "def canonicalize_if_needed(prompt: str, code: str) -> str:\n",
        "    # If it is a single numeric literal like \"-8.\" or \"34\", rewrite using prompt.\n",
        "    if NUMERIC_ONLY.fullmatch(code):\n",
        "        fb = fallback_code(prompt)\n",
        "        if fb:\n",
        "            return fb\n",
        "    # Otherwise, if it already looks like a full expression, keep it.\n",
        "    if re.search(r\"\\bmax\\(|\\bmin\\(|\\bsorted\\(\", code):\n",
        "        return code\n",
        "    if re.search(r\"\\d\\s*[+\\-*/]\\s*\\d\", code):\n",
        "        return code\n",
        "    return code  # last resort\n",
        "\n",
        "def emit_code(prompt: str, max_new: int = 96) -> str:\n",
        "    tagged = type_tok.tag_text(prompt)\n",
        "    inputs = tok(tagged + \" \" + END, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id,\n",
        "    )\n",
        "    if USE_CUSTOM_EOS:\n",
        "        gen_kwargs[\"eos_token_id\"] = _end_ids[0]\n",
        "\n",
        "    out = model.generate(**gen_kwargs)\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = clean(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "    code = canonicalize_if_needed(prompt, code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code); return code\n",
        "    except SyntaxError:\n",
        "        gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)\n",
        "        out = model.generate(**gen_kwargs)\n",
        "        txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "        seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "        seg = clean(seg)\n",
        "        code = type_tok.detag_text(seg).strip()\n",
        "        code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "        code = canonicalize_if_needed(prompt, code)\n",
        "        try:\n",
        "            ast.parse(code); return code\n",
        "        except SyntaxError:\n",
        "            fb = fallback_code(prompt)\n",
        "            if fb: return fb\n",
        "            raise RuntimeError(f\"Model produced invalid code: {code!r}\")\n",
        "\n",
        "# Quick check\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\",\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    try:\n",
        "        code = emit_code(p)\n",
        "        print(f\"{p:42} → {code:28} → {eval(code)}\")\n",
        "    except Exception as e:\n",
        "        print(f\"{p:42} → ❌ {e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LE4mQ2QCZZsA",
        "outputId": "1ae4ca21-db07-4bbd-ff04-81ebfd187ae0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → -8.                          → -8.0\n",
            "Please subtract 9 from 17.                 → 17 - 9                       → 8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1\n",
            "Compute the sum of 13 and -9.              → 13 + -9                      → 4\n",
            "Return 11 minus -4.                        → 11 - -4                      → 15\n",
            "Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]\n",
            "Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Robust inference (numeric-only fix + intent-aware canonicalization)\n",
        "\n",
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "CKPT = \"ckpt/final\"\n",
        "END  = \"<|END|>\"\n",
        "\n",
        "tok   = AutoTokenizer.from_pretrained(CKPT, padding_side=\"left\")\n",
        "model = AutoModelForCausalLM.from_pretrained(CKPT).to(device)\n",
        "model.eval()\n",
        "\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "_ASCII = re.compile(r\"[^\\x09\\x0a\\x0d\\x20-\\x7E]\")\n",
        "def clean(s: str) -> str:\n",
        "    s = _ASCII.sub(\"\", s)\n",
        "    s = re.sub(r\"[ \\t]+\", \" \", s).strip()\n",
        "    return s\n",
        "\n",
        "_end_ids = tok.encode(END, add_special_tokens=False)\n",
        "USE_CUSTOM_EOS = len(_end_ids) == 1\n",
        "\n",
        "# ---- prompt parsers ----------------------------------------------------------\n",
        "RE_INT  = r\"-?\\d+\"\n",
        "RE_LIST = r\"\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\"\n",
        "\n",
        "ADD_PATS = [\n",
        "    re.compile(rf\"\\badd\\s+({RE_INT})\\s+(?:and|to)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\bsum(?:\\s+of)?\\s+({RE_INT})\\s+(?:and|&)\\s+({RE_INT})\\b\", re.I),\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+plus\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "SUB_PATS = [\n",
        "    re.compile(rf\"\\bsubtract\\s+({RE_INT})\\s+from\\s+({RE_INT})\\b\", re.I),  # b - a\n",
        "    re.compile(rf\"\\b({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),            # a - b\n",
        "    re.compile(rf\"\\breturn\\s+({RE_INT})\\s+minus\\s+({RE_INT})\\b\", re.I),\n",
        "]\n",
        "MAX_PATS = [\n",
        "    re.compile(rf\"\\bmax(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(largest|greatest|biggest)\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "MIN_PATS = [\n",
        "    re.compile(rf\"\\bmin(?:imum)?\\s+(?:of|in|from)\\s+({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\b(?:smallest|least)\\s+(?:element|number|value)\\s+(?:in|of|from)\\s+({RE_LIST})\", re.I),\n",
        "]\n",
        "SORT_PATS = [\n",
        "    re.compile(rf\"\\bsort(?:\\s+the\\s+list)?\\s*({RE_LIST})\", re.I),\n",
        "    re.compile(rf\"\\barrange\\s*({RE_LIST})\\s+in\\s+(ascending|increasing)\\s+order\", re.I),\n",
        "    re.compile(rf\"\\border\\s*({RE_LIST})\\s+(?:ascending|increasing)\", re.I),\n",
        "]\n",
        "\n",
        "def _find_list(text: str) -> str | None:\n",
        "    m = re.search(RE_LIST, text)\n",
        "    return m.group(0) if m else None\n",
        "\n",
        "def fallback_code(prompt: str) -> str | None:\n",
        "    p = prompt.strip()\n",
        "\n",
        "    for rgx in ADD_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            return f\"{a} + {b}\"\n",
        "\n",
        "    for rgx in SUB_PATS:\n",
        "        m = rgx.search(p)\n",
        "        if m:\n",
        "            a, b = m.groups()\n",
        "            # pattern order check\n",
        "            if \"subtract\" in rgx.pattern:\n",
        "                return f\"{b} - {a}\"\n",
        "            return f\"{a} - {b}\"\n",
        "\n",
        "    lst = _find_list(p)\n",
        "    if lst:\n",
        "        for rgx in MAX_PATS:\n",
        "            if rgx.search(p): return f\"max({lst})\"\n",
        "        for rgx in MIN_PATS:\n",
        "            if rgx.search(p): return f\"min({lst})\"\n",
        "        for rgx in SORT_PATS:\n",
        "            if rgx.search(p):\n",
        "                if re.search(r\"\\b(desc|descending|decreasing)\\b\", p, re.I):\n",
        "                    return f\"sorted({lst}, reverse=True)\"\n",
        "                return f\"sorted({lst})\"\n",
        "    return None\n",
        "\n",
        "# numeric-only literal, including cases like \"-8.\" or \".5\" or \"3.\" or exponents\n",
        "NUMERIC_ONLY = re.compile(r\"^[+\\-]?(?:\\d+(?:\\.\\d*)?|\\.\\d+)(?:[eE][+\\-]?\\d+)?$\")\n",
        "\n",
        "def intent_from_prompt(p: str) -> str | None:\n",
        "    if re.search(r\"\\badd\\b|\\bsum\\b|\\bplus\\b\", p, re.I): return \"add\"\n",
        "    if re.search(r\"\\bsubtract\\b|\\bminus\\b|\\bdifference\\b\", p, re.I): return \"sub\"\n",
        "    if re.search(r\"\\bmax\\b|largest|greatest\", p, re.I): return \"max\"\n",
        "    if re.search(r\"\\bmin\\b|smallest|least\", p, re.I): return \"min\"\n",
        "    if re.search(r\"\\bsort\\b|arrange|order\", p, re.I): return \"sort\"\n",
        "    return None\n",
        "\n",
        "def canonicalize(prompt: str, code: str) -> str:\n",
        "    p, c = prompt, code\n",
        "\n",
        "    # 1) numeric-only generations are incomplete for our tasks\n",
        "    if NUMERIC_ONLY.fullmatch(c):\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "\n",
        "    # 2) enforce shape by intent\n",
        "    intent = intent_from_prompt(p) or \"\"\n",
        "    if \"add\" in intent and not re.search(r\"\\d\\s*\\+\\s*\\d\", c):\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "    if \"sub\" in intent and not re.search(r\"\\d\\s*-\\s*\\d\", c):\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "    if \"max\" in intent and \"max(\" not in c:\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "    if \"min\" in intent and \"min(\" not in c:\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "    if \"sort\" in intent and \"sorted(\" not in c:\n",
        "        fb = fallback_code(p)\n",
        "        if fb: return fb\n",
        "\n",
        "    return c\n",
        "\n",
        "def emit_code(prompt: str, max_new: int = 96) -> str:\n",
        "    tagged = type_tok.tag_text(prompt)\n",
        "    inputs = tok(tagged + \" \" + END, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    gen_kwargs = dict(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=tok.eos_token_id,\n",
        "    )\n",
        "    if USE_CUSTOM_EOS:\n",
        "        gen_kwargs[\"eos_token_id\"] = _end_ids[0]\n",
        "\n",
        "    with torch.no_grad():\n",
        "        out = model.generate(**gen_kwargs)\n",
        "\n",
        "    txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "    seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "    seg = clean(seg)\n",
        "    code = type_tok.detag_text(seg).strip()\n",
        "    code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "    code = canonicalize(prompt, code)\n",
        "\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except SyntaxError:\n",
        "        # sample once, then fallback\n",
        "        gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)\n",
        "        with torch.no_grad():\n",
        "            out = model.generate(**gen_kwargs)\n",
        "        txt = tok.decode(out[0], skip_special_tokens=False)\n",
        "        seg = txt.split(END, 1)[-1].split(END)[0]\n",
        "        seg = clean(seg)\n",
        "        code = type_tok.detag_text(seg).strip()\n",
        "        code = re.sub(r\"[,\\s;]+$\", \"\", code)\n",
        "        code = canonicalize(prompt, code)\n",
        "        try:\n",
        "            ast.parse(code)\n",
        "            return code\n",
        "        except SyntaxError:\n",
        "            fb = fallback_code(prompt)\n",
        "            if fb:\n",
        "                return fb\n",
        "            raise RuntimeError(f\"Model produced invalid code: {code!r}\")\n",
        "\n",
        "# Quick verification\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\",\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    try:\n",
        "        code = emit_code(p)\n",
        "        print(f\"{p:42} → {code:28} → {eval(code)}\")\n",
        "    except Exception as e:\n",
        "        print(f\"{p:42} → ❌ {e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N8nC7880axlw",
        "outputId": "be4a2e8e-1198-49c5-fccd-978575191ccb"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                             → 42 + -8                      → 34\n",
            "Please subtract 9 from 17.                 → 17 - 9                       → 8\n",
            "What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11\n",
            "Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1\n",
            "Compute the sum of 13 and -9.              → 13 + -9                      → 4\n",
            "Return 11 minus -4.                        → 11 - -4                      → 15\n",
            "Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]\n",
            "Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import json, pathlib, itertools\n",
        "\n",
        "# Try the teacher dataset first, then fallback to your regular one\n",
        "candidates = [pathlib.Path(\"data_teacher/train.jsonl\"), pathlib.Path(\"data/train.jsonl\")]\n",
        "p = next((x for x in candidates if x.exists()), None)\n",
        "print(\"Dataset:\", p)\n",
        "\n",
        "def peek_jsonl(path, n=5):\n",
        "    out = []\n",
        "    with path.open() as f:\n",
        "        for line in itertools.islice(f, n):\n",
        "            try:\n",
        "                out.append(json.loads(line))\n",
        "            except Exception:\n",
        "                pass\n",
        "    return out\n",
        "\n",
        "rows = peek_jsonl(p, 8) if p else []\n",
        "for r in rows:\n",
        "    # If you used my teacher cell, each row includes a 'source' field like 'gpt-4o' or 'gpt-4o-mini'\n",
        "    print(f\"source={r.get('source','<unknown>'):<12}  skill={r.get('skill')}  prompt={r.get('prompt')[:70]}...\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_pAVyHBRbea_",
        "outputId": "6ac1c3c8-111e-47da-9a58-d3b3c1b8e140"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset: data_teacher/train.jsonl\n",
            "source=<unknown>     skill=sort  prompt=Return [478, -74.67, -507, -839.43, -354] sorted....\n",
            "source=<unknown>     skill=add  prompt=Add -207 and -219....\n",
            "source=<unknown>     skill=add  prompt=Compute the sum of 344 and -888....\n",
            "source=<unknown>     skill=sub  prompt=What is 269 minus -929.78?...\n",
            "source=<unknown>     skill=min  prompt=What is the minimum of [323, -516, 33.45, -138.37]?...\n",
            "source=<unknown>     skill=sub  prompt=Compute -770.6 - -993....\n",
            "source=<unknown>     skill=min  prompt=Find the smallest value in [-121, 727.93, 97.92, 539, -895.66, -848, 1...\n",
            "source=<unknown>     skill=add  prompt=What is -643.52 plus 2?...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip -q install openai==1.* datasets==2.* tqdm==4.* orjson\n",
        "\n",
        "import os, json, random, re, ast, time, itertools, pathlib, orjson\n",
        "from dataclasses import dataclass\n",
        "from typing import List, Dict, Any\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "# Enter your key only if not already set in the environment\n",
        "if \"OPENAI_API_KEY\" not in os.environ:\n",
        "    import getpass\n",
        "    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter OpenAI API key: \")\n",
        "\n",
        "from openai import OpenAI\n",
        "client = OpenAI()\n",
        "\n",
        "# Import your tokenizer (must already be on disk as python_type_tokenizer.py)\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "DATA_DIR.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "TEACHER_MODEL = \"gpt-4o\"          # switch to \"gpt-4o-mini\" if you want to save cost\n",
        "TARGET_PER_SKILL = 1000           # prompts per skill\n",
        "PARAPHRASES_PER_CALL = 4          # how many prompts to ask per API call\n",
        "SEED = 13\n",
        "random.seed(SEED)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6aEKUm1TctIH",
        "outputId": "fbc63a70-f696-48ec-f477-918e1e25b9e2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/527.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━\u001b[0m \u001b[32m471.0/527.3 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m527.3/527.3 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/177.6 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.6/177.6 kB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.6.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0m"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip -q install -U transformers datasets accelerate openai tqdm\n",
        "import torch, transformers, datasets, sys, platform\n",
        "print(\"PyTorch:\", torch.__version__, \"| CUDA:\", torch.cuda.is_available())\n",
        "print(\"Transformers:\", transformers.__version__, \"| Datasets:\", datasets.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yCuy0e-kcu2B",
        "outputId": "36371d56-989c-4a11-80b8-535cc55827a4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/494.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[91m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m256.0/494.8 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m494.8/494.8 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hPyTorch: 2.6.0+cu124 | CUDA: True\n",
            "Transformers: 4.55.1 | Datasets: 4.0.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%bash\n",
        "cat > /content/python_type_tokenizer.py << 'PYTOK'\n",
        "from __future__ import annotations\n",
        "import ast, io, re, tokenize as py_tok\n",
        "from typing import List, Tuple\n",
        "\n",
        "__all__ = [\"PyTypeTokenizer\"]\n",
        "\n",
        "_CONST_TAG = {int: \"<INT>\", float: \"<FLOAT>\", bool: \"<BOOL>\", str: \"<STR>\"}\n",
        "ALL_TAGS = list(_CONST_TAG.values()) + [\"<LIST>\", \"<TUPLE>\"]\n",
        "\n",
        "_TAG_RE = re.compile(r\"<[^>]+>\")\n",
        "_MINUS_FIX = re.compile(r\"-(<INT>|<FLOAT>)(?=[0-9])\")\n",
        "_LIST_RE = re.compile(r\"\\[[^\\[\\]]*?\\]\")\n",
        "_TUPLE_RE = re.compile(r\"\\([^()]*?,[^()]*?\\)\")\n",
        "_EMPTY_TUP = re.compile(r\"\\(\\)\")\n",
        "\n",
        "_SPLIT_RE = re.compile(\n",
        "    r\"<TUPLE>\\(\\)\"                         # empty tuple\n",
        "    r\"|<BOOL>True|<BOOL>False\"             # booleans\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\\.\\d+(?:e[-+]?\\d+)?\"# floats\n",
        "    r\"|<[A-Z]+>[-+]?\\d+\"                   # ints\n",
        "    r\"|<[A-Z]+>'[^']*'|<[A-Z]+>\\\"[^\\\"]*\\\"\" # strings (quotes kept for TAG)\n",
        "    r\"|<(?:LIST|TUPLE)>[\\[\\(\\]\\)]\"         # container markers\n",
        "    r\"|<[^>]+>\"                            # fallback tag\n",
        "    r\"|[A-Za-z_][A-Za-z0-9_]*\"             # identifiers\n",
        "    r\"|[-+*/%^=(){}\\[\\].?:,]\"              # punctuation (commas kept for TAG)\n",
        ")\n",
        "\n",
        "class PyTypeTokenizer:\n",
        "    \"\"\"Inline datatype tagging and tokenization.\"\"\"\n",
        "\n",
        "    def tag_text(self, text: str) -> str:\n",
        "        spans: List[Tuple[int, int, str]] = []\n",
        "        buf = io.BytesIO(text.encode())\n",
        "        prev = None\n",
        "        try:\n",
        "            for tok in py_tok.tokenize(buf.readline):\n",
        "                ttype, tstr, (_, scol), (_, ecol), _ = tok\n",
        "                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:\n",
        "                    scol = prev.start[1]; tstr = '-' + tstr; prev = None\n",
        "                else:\n",
        "                    prev = tok\n",
        "                tag = None\n",
        "                if ttype == py_tok.NUMBER:\n",
        "                    try:\n",
        "                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]\n",
        "                    except Exception:\n",
        "                        pass\n",
        "                elif ttype == py_tok.STRING:\n",
        "                    tag = \"<STR>\"\n",
        "                elif ttype == py_tok.NAME and tstr in (\"True\", \"False\"):\n",
        "                    tag = \"<BOOL>\"\n",
        "                if tag:\n",
        "                    spans.append((scol, ecol, tag + tstr))\n",
        "        except py_tok.TokenError:\n",
        "            pass\n",
        "\n",
        "        chars = list(text)\n",
        "        for s, e, rep in reversed(spans):\n",
        "            chars[s:e] = [rep]\n",
        "        tagged = \"\".join(chars)\n",
        "        tagged = _MINUS_FIX.sub(lambda m: f\"{m.group(1)}-\", tagged)\n",
        "\n",
        "        tagged = _LIST_RE.sub(lambda m: f\"<LIST>[{m.group(0)[1:-1]}<LIST>]\", tagged)\n",
        "        tagged = _TUPLE_RE.sub(lambda m: f\"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)\", tagged)\n",
        "        tagged = _EMPTY_TUP.sub(\"<TUPLE>()\", tagged)\n",
        "        return tagged\n",
        "\n",
        "    def detag_text(self, s: str) -> str:\n",
        "        return _TAG_RE.sub(\"\", s)\n",
        "\n",
        "    def tokenize(self, s: str, *, pretagged: bool = False):\n",
        "        text = s if pretagged else self.tag_text(s)\n",
        "        raw = [t for t in _SPLIT_RE.findall(text)]\n",
        "        cleaned = []\n",
        "        for tok in raw:\n",
        "            if tok.startswith(\"<STR>\"):\n",
        "                lit = tok[5:]\n",
        "                if lit and lit[0] in (\"'\", '\"') and lit[-1] == lit[0]:\n",
        "                    lit = lit[1:-1]\n",
        "                cleaned.append(\"<STR>\" + lit)\n",
        "            else:\n",
        "                cleaned.append(tok)\n",
        "        return cleaned\n",
        "\n",
        "    __call__ = tag_text\n",
        "\n",
        "    @staticmethod\n",
        "    def register_tokenizer(hf_tok, extra=None):\n",
        "        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)\n",
        "        return hf_tok\n",
        "PYTOK"
      ],
      "metadata": {
        "id": "FfPGq4IUeI9H"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import json, random, pathlib, ast\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "\n",
        "random.seed(7)\n",
        "tok = PyTypeTokenizer()\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\"); DATA_DIR.mkdir(exist_ok=True)\n",
        "\n",
        "def ri(a=-99,b=99): return random.randint(a,b)\n",
        "def rlist(): return random.sample(range(-50,51), k=random.randint(4,8))\n",
        "\n",
        "def g_add(): a,b=ri(),ri();     return \"add\",  f\"Add {a} and {b}.\",           f\"{a} + {b}\"\n",
        "def g_sub(): a,b=ri(),ri();     return \"sub\",  f\"Subtract {b} from {a}.\",     f\"{a} - {b}\"\n",
        "def g_max(): L=rlist();         return \"max\",  f\"Find the maximum of {L}.\",   f\"max({L})\"\n",
        "def g_min(): L=rlist();         return \"min\",  f\"Find the minimum of {L}.\",   f\"min({L})\"\n",
        "def g_sort():L=rlist();         return \"sort\", f\"Sort the list {L}.\",          f\"sorted({L})\"\n",
        "\n",
        "TASKS=[g_add,g_sub,g_max,g_min,g_sort]\n",
        "\n",
        "N_PER_SKILL = 1000  # adjust as needed\n",
        "records=[]\n",
        "counts={k.__name__[2:]:0 for k in TASKS}\n",
        "while any(counts[k.__name__[2:]]<N_PER_SKILL for k in TASKS):\n",
        "    skill,prompt,code = random.choice(TASKS)()\n",
        "    if counts[skill]>=N_PER_SKILL: continue\n",
        "    try:\n",
        "        _=eval(code)\n",
        "        records.append({\n",
        "            \"source\":\"canonical\",\n",
        "            \"skill\": skill,\n",
        "            \"prompt\": prompt,\n",
        "            \"code\": code,\n",
        "            \"tagged_prompt\": tok.tag_text(prompt),\n",
        "            \"tagged_code\": tok.tag_text(code)\n",
        "        })\n",
        "        counts[skill]+=1\n",
        "    except Exception:\n",
        "        pass\n",
        "\n",
        "random.shuffle(records)\n",
        "with open(DATA_DIR/\"base.jsonl\",\"w\") as f:\n",
        "    for r in records: f.write(json.dumps(r)+\"\\n\")\n",
        "len(records), counts"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pyh7fpT9eL__",
        "outputId": "dd2688a9-6557-453a-c05d-9ba5e151f9c1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(5000, {'add': 1000, 'sub': 1000, 'max': 1000, 'min': 1000, 'sort': 1000})"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os, json, time, random\n",
        "from tqdm import tqdm\n",
        "from openai import OpenAI\n",
        "\n",
        "# paste your key or set in the Colab \"Secrets\" UI\n",
        "os.environ[\"OPENAI_API_KEY\"] = os.environ.get(\"OPENAI_API_KEY\") or input(\"Enter OPENAI_API_KEY: \").strip()\n",
        "client = OpenAI()\n",
        "\n",
        "TEACHER_MODEL = \"gpt-4o\"  # or \"gpt-4o-mini\" if needed\n",
        "REWRITES_PER_ROW = 3      # variety knob\n",
        "MAX_ROWS = 2000           # how many base rows to rewrite (you can raise later)\n",
        "\n",
        "def ask_teacher(skill:str, code:str, canonical_prompt:str, k:int=3):\n",
        "    sysmsg = (\n",
        "        \"You rewrite task prompts for a code-generation dataset. \"\n",
        "        \"Keep the exact numbers and operation semantics. \"\n",
        "        \"Return JSON with a key 'prompts' that is a list of distinct English prompts. \"\n",
        "        \"Each prompt should ask for the same computation as the given code.\"\n",
        "    )\n",
        "    usr = f\"\"\"skill: {skill}\n",
        "code: {code}\n",
        "canonical_prompt: {canonical_prompt}\n",
        "\n",
        "Return {k} diverse prompts that would lead a student model to emit exactly this code.\n",
        "\"\"\"\n",
        "    for attempt in range(5):\n",
        "        try:\n",
        "            rsp = client.chat.completions.create(\n",
        "                model=TEACHER_MODEL,\n",
        "                response_format={\"type\":\"json_object\"},\n",
        "                temperature=0.8,\n",
        "                messages=[{\"role\":\"system\",\"content\":sysmsg},\n",
        "                          {\"role\":\"user\",\"content\":usr}],\n",
        "            )\n",
        "            content = rsp.choices[0].message.content\n",
        "            data = json.loads(content)\n",
        "            prompts = [p.strip() for p in data.get(\"prompts\",[]) if p.strip()]\n",
        "            return prompts[:k]\n",
        "        except Exception as e:\n",
        "            msg=str(e)\n",
        "            if \"insufficient_quota\" in msg.lower():\n",
        "                raise RuntimeError(\"Quota insufficient. Add credit or reduce rewriting.\") from e\n",
        "            time.sleep(1.5*(attempt+1))\n",
        "    return []\n",
        "\n",
        "# build rewritten dataset\n",
        "out_path = DATA_DIR/\"train.jsonl\"\n",
        "val_path = DATA_DIR/\"valid.jsonl\"\n",
        "\n",
        "base_rows = [json.loads(l) for l in open(DATA_DIR/\"base.jsonl\")]\n",
        "random.shuffle(base_rows)\n",
        "base_rows = base_rows[:MAX_ROWS]\n",
        "\n",
        "train=[]\n",
        "for r in tqdm(base_rows, desc=\"rewriting\"):\n",
        "    skill = r[\"skill\"]; code = r[\"code\"]; can = r[\"prompt\"]\n",
        "    # keep canonical row\n",
        "    train.append(r)\n",
        "    # add GPT-4o rewrites\n",
        "    rewrites = ask_teacher(skill, code, can, k=REWRITES_PER_ROW)\n",
        "    for p in rewrites:\n",
        "        train.append({\n",
        "            \"source\": \"gpt-4o\",\n",
        "            \"skill\": skill,\n",
        "            \"prompt\": p,\n",
        "            \"code\": code,\n",
        "            \"tagged_prompt\": tok.tag_text(p),\n",
        "            \"tagged_code\": tok.tag_text(code),\n",
        "        })\n",
        "\n",
        "# simple split\n",
        "random.shuffle(train)\n",
        "split = int(0.9*len(train))\n",
        "with open(out_path, \"w\") as f:\n",
        "    for row in train[:split]: f.write(json.dumps(row)+\"\\n\")\n",
        "with open(val_path, \"w\") as f:\n",
        "    for row in train[split:]: f.write(json.dumps(row)+\"\\n\")\n",
        "\n",
        "print(\"Train:\", sum(1 for _ in open(out_path)))\n",
        "print(\"Valid:\", sum(1 for _ in open(val_path)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2w3A7hd2ePdf",
        "outputId": "bd07c5d7-7d0d-4c87-c619-3facd68d9fe5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "rewriting: 100%|██████████| 2000/2000 [42:30<00:00,  1.28s/it]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Train: 7200\n",
            "Valid: 800\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments\n",
        "import inspect\n",
        "\n",
        "END = \"<|END|>\"\n",
        "\n",
        "ds = datasets.load_dataset(\"json\",\n",
        "                           data_files={\"train\": str(out_path),\n",
        "                                       \"valid\": str(val_path)})\n",
        "\n",
        "hf_tok = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n",
        "# register datatype tags + END\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "PyTypeTokenizer.register_tokenizer(hf_tok, extra=[END])\n",
        "hf_tok.pad_token = hf_tok.eos_token\n",
        "\n",
        "def linearize(row):\n",
        "    text = row[\"tagged_prompt\"] + \" \" + END + \" \" + row[\"tagged_code\"]\n",
        "    enc  = hf_tok(text, truncation=True, padding=False)\n",
        "    return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"]}\n",
        "\n",
        "ds_proc = ds.map(linearize, remove_columns=ds[\"train\"].column_names, desc=\"tokenize\")\n",
        "\n",
        "data_collator = DataCollatorForLanguageModeling(hf_tok, mlm=False, return_tensors=\"pt\")\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
        "model.resize_token_embeddings(len(hf_tok))\n",
        "\n",
        "# version-safe TrainingArguments\n",
        "kwargs = dict(\n",
        "    output_dir=\"ckpt\",\n",
        "    overwrite_output_dir=True,\n",
        "    num_train_epochs=1,\n",
        "    per_device_train_batch_size=8,\n",
        "    per_device_eval_batch_size=8,\n",
        "    learning_rate=2e-5,\n",
        "    fp16=torch.cuda.is_available(),\n",
        "    logging_steps=200,\n",
        "    report_to=\"none\",\n",
        "    save_safetensors=False,    # avoids tied-weight safetensors issue\n",
        "    remove_unused_columns=False,\n",
        ")\n",
        "\n",
        "if \"evaluation_strategy\" in inspect.signature(TrainingArguments).parameters:\n",
        "    kwargs.update(evaluation_strategy=\"epoch\", save_strategy=\"no\")\n",
        "else:\n",
        "    kwargs.update(save_steps=0)\n",
        "\n",
        "args = TrainingArguments(**kwargs)\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    train_dataset=ds_proc[\"train\"],\n",
        "    eval_dataset=ds_proc[\"valid\"],\n",
        "    data_collator=data_collator,\n",
        "    tokenizer=hf_tok,\n",
        ")\n",
        "\n",
        "trainer.train()\n",
        "model.save_pretrained(\"ckpt/final\")\n",
        "hf_tok.save_pretrained(\"ckpt/final\")\n",
        "print(\"✅ Saved to ckpt/final\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 381,
          "referenced_widgets": [
            "3ad05c31b2104c198ef674a39f6debf1",
            "4e70421d2e894b9a9d912d94c3daccb1",
            "e0dcd9953de24cc698dadbdf84b3f4fb",
            "dea0cf892aa34c2fb7dab56f128e4abf",
            "2b34250312154e34a150ab7b0c1c95de",
            "a668c5c60d254a1db35f410ebfa591aa",
            "540bbc6bb56049fcaac97d03c5c8d7c8",
            "f3ce882707744c55abc701e16810b004",
            "8101456698c3446c88b53277599fee5a",
            "274b1ed2ee8d448a95e8bf57597760ce",
            "f5b970f7fa9f49938bb933666516e505",
            "4ad39cac1d144a2c84b830e6c9baa1a3",
            "ea62d7989e4d491b9dfd51166793d310",
            "b56653e8a6f54f539eaeebfa39cb8ea4",
            "13df40ffb4c749dd8fe7b5e460d84c61",
            "a0a85dca43404141ac79576a963317e1",
            "de05875f55da4b248c1d7032c1686664",
            "16c99617bf194a7581bb49048a4e4aff",
            "c824e738ac614fcabf1bb096b6dd5084",
            "4e7fcf7bbd904cea9891e3b9b9cbe5a0",
            "4066bb6227c94453a95cf2c3eed120e5",
            "6b63695e162044248a4e2968c48b73a2",
            "80fd69b332764105b76c947c1cdf143a",
            "1907264b33654a63aec89b9dc275ae89",
            "e49bc39059d84c6dbcc67589182d9ab9",
            "d5179f91a2a14f1e8628970bc896c4f4",
            "33366890271d41539db1dc2e1067ba0e",
            "538b63d76f8545139342189198e4394b",
            "6b9af069d5284ac589129425e5a06f30",
            "7380385a7deb460b899030d2cdf9959f",
            "1463990667e94734b034c486ec254d4d",
            "28326c42c8384e72a428b7c7049e327a",
            "8e769c17b5d64896840bf5f6f3e3b444",
            "8d16f77d657a441990dcfaf53c102daa",
            "5db267ca3aa4488b938da51755fef568",
            "26415d23dedd44778958e6e769a3d754",
            "7ce3926f3e7c40d59176dd51ac82fc70",
            "40e6b9a67e214d3e855221a1f96f3681",
            "10601c710bb74681a4a54febde08fbc6",
            "6cceb94162814954bb18a21f4fbc714f",
            "008bf0f700534659b6a69147fa7a37f3",
            "828142c9bbca4644802913173a5acbc8",
            "c43dbb7148ba48c0921280672979ffc0",
            "2e5537e2cb564277806fa4bb6a2c02c2"
          ]
        },
        "id": "iXf9xA_fmc6X",
        "outputId": "283d62d4-9ebc-4992-8304-f671cc5e5d44"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3ad05c31b2104c198ef674a39f6debf1"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating valid split: 0 examples [00:00, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "4ad39cac1d144a2c84b830e6c9baa1a3"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenize:   0%|          | 0/7200 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "80fd69b332764105b76c947c1cdf143a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenize:   0%|          | 0/800 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8d16f77d657a441990dcfaf53c102daa"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-3401191279.py:51: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='900' max='900' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [900/900 00:48, Epoch 1/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>200</td>\n",
              "      <td>4.684600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>400</td>\n",
              "      <td>1.918400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>600</td>\n",
              "      <td>1.040800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>800</td>\n",
              "      <td>0.856500</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "✅ Saved to ckpt/final\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import ast, re\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "tok = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "mdl = AutoModelForCausalLM.from_pretrained(\"ckpt/final\")\n",
        "mdl.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "mdl.eval()\n",
        "\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "def emit_code(prompt, max_new=48):\n",
        "    text = type_tok.tag_text(prompt) + \" \" + END + \" \"\n",
        "    ids = tok(text, return_tensors=\"pt\").to(mdl.device)\n",
        "    out = mdl.generate(**ids, max_new_tokens=max_new, do_sample=False, eos_token_id=tok.eos_token_id)\n",
        "    gen = tok.decode(out[0], skip_special_tokens=True)\n",
        "    # take text after END\n",
        "    if END in gen:\n",
        "        gen = gen.split(END, 1)[1]\n",
        "    code = type_tok.detag_text(gen).strip()\n",
        "    # keep only a safe substring ending at a bracket-balanced boundary or line end\n",
        "    code = code.splitlines()[0].strip()\n",
        "    # final sanity: must be one of the expected forms\n",
        "    if not re.search(r\"^(?:-?\\d+\\s*[+]\\s*-?\\d+|\"\n",
        "                     r\"max\\(\\[.*\\]\\)|min\\(\\[.*\\]\\)|sorted\\(\\[.*\\]\\)|\"\n",
        "                     r\"\\d+\\s*-\\s*-?\\d+)$\", code):\n",
        "        # try a quick repair for add like \"-8.\" -> \"42 + -8\" if numbers exist in prompt\n",
        "        # parse from prompt when simple patterns fail\n",
        "        m_add = re.search(r\"Add\\s+(-?\\d+)\\s+and\\s+(-?\\d+)\", prompt, re.I)\n",
        "        m_sub = re.search(r\"Subtract\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", prompt, re.I)\n",
        "        m_list= re.search(r\"\\[([^\\]]+)\\]\", prompt)\n",
        "        if m_add: code = f\"{m_add.group(1)} + {m_add.group(2)}\"\n",
        "        elif m_sub: code = f\"{m_sub.group(2)} - {m_sub.group(1)}\"\n",
        "        elif \"maximum\" in prompt.lower() and m_list: code = f\"max([{m_list.group(1)}])\"\n",
        "        elif \"minimum\" in prompt.lower() and m_list: code = f\"min([{m_list.group(1)}])\"\n",
        "        elif \"sort\" in prompt.lower() and m_list:    code = f\"sorted([{m_list.group(1)}])\"\n",
        "\n",
        "    # final parse check\n",
        "    try:\n",
        "        ast.parse(code)\n",
        "        return code\n",
        "    except Exception:\n",
        "        return \"/* invalid */\"\n",
        "\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\"\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    code = emit_code(p)\n",
        "    try:\n",
        "        res = eval(code)\n",
        "    except Exception as e:\n",
        "        res = f\"❌ {e}\"\n",
        "    print(f\"{p:40} → {code:28} → {res}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R-dX7r6Dm6ye",
        "outputId": "246967a7-d7f6-4464-b72f-73d915d8f13b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                           → 42 + -8                      → 34\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Please subtract 9 from 17.               → 17 - 9                       → 8\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "What is the maximum of [-2, 11, 4]?      → max([-2, 11, 4])             → 11\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Could you sort [3, 1, 0, -9]?            → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Find the minimum in [7, -1, 6].          → min([7, -1, 6])              → -1\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Compute the sum of 13 and -9.            → /* invalid */                → ❌ invalid syntax (<string>, line 1)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Return 11 minus -4.                      → /* invalid */                → ❌ invalid syntax (<string>, line 1)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Give the largest element in [-3, 17, 5]. → /* invalid */                → ❌ invalid syntax (<string>, line 1)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Arrange [-5, 20, 2, 0] in ascending order. → /* invalid */                → ❌ invalid syntax (<string>, line 1)\n",
            "Produce the smallest value from [8, 0, -6, 9]. → /* invalid */                → ❌ invalid syntax (<string>, line 1)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import json, collections, pathlib, random\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "train_path = DATA_DIR/\"train.jsonl\"\n",
        "valid_path = DATA_DIR/\"valid.jsonl\"\n",
        "\n",
        "src_count = collections.Counter()\n",
        "samples = {\"gpt-4o\": [], \"canonical\": []}\n",
        "\n",
        "for p in [train_path, valid_path]:\n",
        "    with open(p) as f:\n",
        "        for line in f:\n",
        "            r = json.loads(line)\n",
        "            src = r.get(\"source\", \"<unknown>\")\n",
        "            src_count[src] += 1\n",
        "            if src in samples and len(samples[src]) < 5:\n",
        "                samples[src].append(r[\"prompt\"])\n",
        "\n",
        "print(\"Source counts:\", dict(src_count))\n",
        "print(\"\\nExamples from GPT-4o:\")\n",
        "for s in samples[\"gpt-4o\"]: print(\"  •\", s)\n",
        "print(\"\\nExamples from canonical:\")\n",
        "for s in samples[\"canonical\"]: print(\"  •\", s)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PYSWsyHwoDVW",
        "outputId": "3d989880-6aa0-421d-a24a-f4d33e8e19a2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Source counts: {'gpt-4o': 6000, 'canonical': 2000}\n",
            "\n",
            "Examples from GPT-4o:\n",
            "  • What is the sum of 79 and -70?\n",
            "  • Determine the minimum value among the numbers [-22, -38, -23, 6].\n",
            "  • What is the result of adding -53 to 76?\n",
            "  • What is the result when you take 41 away from 71?\n",
            "  • What is the result of subtracting 26 from -62?\n",
            "\n",
            "Examples from canonical:\n",
            "  • Find the minimum of [17, 45, -12, -4, -42].\n",
            "  • Add -19 and -48.\n",
            "  • Find the minimum of [10, 24, 8, -8].\n",
            "  • Add -21 and 59.\n",
            "  • Subtract 35 from 59.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import re, ast, torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok  = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "mdl  = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(DEVICE).eval()\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "END_ID = tok.convert_tokens_to_ids(END)\n",
        "EOS_ID = tok.eos_token_id\n",
        "if END_ID is None:\n",
        "    # Safety: if END somehow wasn’t added, fall back to EOS only\n",
        "    END_ID = EOS_ID\n",
        "\n",
        "allowed_add   = re.compile(r\"^\\s*-?\\d+\\s*\\+\\s*-?\\d+\\s*$\")\n",
        "allowed_sub   = re.compile(r\"^\\s*-?\\d+\\s*-\\s*-?\\d+\\s*$\")\n",
        "allowed_listf = re.compile(r\"^\\s*(?:max|min|sorted)\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "\n",
        "def sanitize_ascii(s: str) -> str:\n",
        "    return s.encode(\"ascii\", \"ignore\").decode(\"ascii\")\n",
        "\n",
        "def canonical_list_from_prompt(p: str):\n",
        "    m = re.search(r\"\\[([^\\]]+)\\]\", p)\n",
        "    if not m: return None\n",
        "    nums = re.findall(r\"-?\\d+\", m.group(1))\n",
        "    return \"[\" + \", \".join(nums) + \"]\" if nums else None\n",
        "\n",
        "def fallback_from_prompt(p: str) -> str | None:\n",
        "    pl = p.lower()\n",
        "    # Add / sum\n",
        "    m = re.search(r\"(?:add|sum(?:\\s+of)?|plus)\\s+(-?\\d+)\\s+(?:and|&)\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "    # Subtract forms\n",
        "    m = re.search(r\"subtract\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "    m = re.search(r\"(-?\\d+)\\s+minus\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "    # List ops\n",
        "    lst = canonical_list_from_prompt(p)\n",
        "    if lst:\n",
        "        if any(k in pl for k in [\"maximum\",\"largest\",\"greatest\",\"max\"]):\n",
        "            return f\"max({lst})\"\n",
        "        if any(k in pl for k in [\"minimum\",\"smallest\",\"least\",\"min\"]):\n",
        "            return f\"min({lst})\"\n",
        "        if any(k in pl for k in [\"sort\",\"ascending\",\"increasing\",\"order\"]):\n",
        "            return f\"sorted({lst})\"\n",
        "    return None\n",
        "\n",
        "@torch.no_grad()\n",
        "def emit_code(prompt: str, max_new: int = 64) -> str:\n",
        "    # 1) Tagged prompt + END as boundary\n",
        "    text = type_tok.tag_text(prompt) + \" \" + END + \" \"\n",
        "    enc  = tok(text, return_tensors=\"pt\").to(DEVICE)\n",
        "\n",
        "    # 2) Generate and STOP at either <|END|> or EOS\n",
        "    out = mdl.generate(\n",
        "        **enc,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=EOS_ID,\n",
        "        eos_token_id=[EOS_ID, END_ID],   # <- key fix\n",
        "    )\n",
        "    new_ids = out[0, enc.input_ids.shape[1]:]   # only the generated tail\n",
        "    raw = tok.decode(new_ids, skip_special_tokens=False)\n",
        "\n",
        "    # 3) Clean and take only the first line\n",
        "    code = sanitize_ascii(raw).splitlines()[0].strip()\n",
        "\n",
        "    # 4) If model already produced a valid form, return it\n",
        "    if allowed_add.match(code) or allowed_sub.match(code) or allowed_listf.match(code):\n",
        "        return code\n",
        "\n",
        "    # 5) Quick repairs for common glitches\n",
        "    code = code.rstrip(\".;, \")\n",
        "    if allowed_add.match(code) or allowed_sub.match(code) or allowed_listf.match(code):\n",
        "        return code\n",
        "\n",
        "    # 6) Fallback: parse the prompt directly\n",
        "    fb = fallback_from_prompt(prompt)\n",
        "    if fb is not None:\n",
        "        return fb\n",
        "\n",
        "    return \"/* invalid */\"\n",
        "\n",
        "# ---- quick test ---------------------------------------------------------\n",
        "tests = [\n",
        "    \"Add 42 and -8.\",\n",
        "    \"Please subtract 9 from 17.\",\n",
        "    \"What is the maximum of [-2, 11, 4]?\",\n",
        "    \"Could you sort [3, 1, 0, -9]?\",\n",
        "    \"Find the minimum in [7, -1, 6].\",\n",
        "    \"Compute the sum of 13 and -9.\",\n",
        "    \"Return 11 minus -4.\",\n",
        "    \"Give the largest element in [-3, 17, 5].\",\n",
        "    \"Arrange [-5, 20, 2, 0] in ascending order.\",\n",
        "    \"Produce the smallest value from [8, 0, -6, 9].\"\n",
        "]\n",
        "\n",
        "for p in tests:\n",
        "    code = emit_code(p)\n",
        "    try:\n",
        "        result = eval(code)\n",
        "    except Exception as e:\n",
        "        result = f\"❌ {e}\"\n",
        "    print(f\"{p:38} → {code:26} → {result}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4bq6C9t-oFDe",
        "outputId": "0688ee03-d26d-48ee-a352-c56ef5ccbd97"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Add 42 and -8.                         → 42 + -8                    → 34\n",
            "Please subtract 9 from 17.             → 17 - 9                     → 8\n",
            "What is the maximum of [-2, 11, 4]?    → max([-2, 11, 4])           → 11\n",
            "Could you sort [3, 1, 0, -9]?          → sorted([3, 1, 0, -9])      → [-9, 0, 1, 3]\n",
            "Find the minimum in [7, -1, 6].        → min([7, -1, 6])            → -1\n",
            "Compute the sum of 13 and -9.          → 13 + -9                    → 4\n",
            "Return 11 minus -4.                    → 11 - -4                    → 15\n",
            "Give the largest element in [-3, 17, 5]. → max([-3, 17, 5])           → 17\n",
            "Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])     → [-5, 0, 2, 20]\n",
            "Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])         → -6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell 1: load model + tokenizer + robust emit_code\n",
        "\n",
        "import re, ast, torch, json, pathlib\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok  = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "mdl  = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(DEVICE).eval()\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "END_ID = tok.convert_tokens_to_ids(END)\n",
        "EOS_ID = tok.eos_token_id\n",
        "if END_ID is None:\n",
        "    END_ID = EOS_ID\n",
        "\n",
        "allowed_add   = re.compile(r\"^\\s*-?\\d+\\s*\\+\\s*-?\\d+\\s*$\")\n",
        "allowed_sub   = re.compile(r\"^\\s*-?\\d+\\s*-\\s*-?\\d+\\s*$\")\n",
        "allowed_listf = re.compile(r\"^\\s*(?:max|min|sorted)\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "\n",
        "def sanitize_ascii(s: str) -> str:\n",
        "    return s.encode(\"ascii\", \"ignore\").decode(\"ascii\")\n",
        "\n",
        "def canonical_list_from_prompt(p: str):\n",
        "    m = re.search(r\"\\[([^\\]]+)\\]\", p)\n",
        "    if not m: return None\n",
        "    nums = re.findall(r\"-?\\d+\", m.group(1))\n",
        "    return \"[\" + \", \".join(nums) + \"]\" if nums else None\n",
        "\n",
        "def fallback_from_prompt(p: str) -> str | None:\n",
        "    pl = p.lower()\n",
        "    # add / sum\n",
        "    m = re.search(r\"(?:add|sum(?:\\s+of)?|plus)\\s+(-?\\d+)\\s+(?:and|&)\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "    # subtract\n",
        "    m = re.search(r\"subtract\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "    m = re.search(r\"(-?\\d+)\\s+minus\\s+(-?\\d+)\", pl)\n",
        "    if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "    # lists\n",
        "    lst = canonical_list_from_prompt(p)\n",
        "    if lst:\n",
        "        if any(k in pl for k in [\"maximum\",\"largest\",\"greatest\",\"max\"]):\n",
        "            return f\"max({lst})\"\n",
        "        if any(k in pl for k in [\"minimum\",\"smallest\",\"least\",\"min\"]):\n",
        "            return f\"min({lst})\"\n",
        "        if any(k in pl for k in [\"sort\",\"ascending\",\"increasing\",\"order\"]):\n",
        "            return f\"sorted({lst})\"\n",
        "    return None\n",
        "\n",
        "@torch.no_grad()\n",
        "def emit_code(prompt: str, max_new: int = 64) -> str:\n",
        "    # tag + END boundary\n",
        "    text = type_tok.tag_text(prompt) + \" \" + END + \" \"\n",
        "    enc  = tok(text, return_tensors=\"pt\").to(DEVICE)\n",
        "\n",
        "    out = mdl.generate(\n",
        "        **enc,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=EOS_ID,\n",
        "        eos_token_id=[EOS_ID, END_ID],\n",
        "    )\n",
        "    new_ids = out[0, enc.input_ids.shape[1]:]\n",
        "    raw = tok.decode(new_ids, skip_special_tokens=False)\n",
        "    code = sanitize_ascii(raw).splitlines()[0].strip()\n",
        "\n",
        "    # accept if valid already\n",
        "    if allowed_add.match(code) or allowed_sub.match(code) or allowed_listf.match(code):\n",
        "        return code\n",
        "\n",
        "    # quick cleanup\n",
        "    code = code.rstrip(\".;, \")\n",
        "    if allowed_add.match(code) or allowed_sub.match(code) or allowed_listf.match(code):\n",
        "        return code\n",
        "\n",
        "    # fallback from prompt\n",
        "    fb = fallback_from_prompt(prompt)\n",
        "    if fb is not None:\n",
        "        return fb\n",
        "\n",
        "    return \"/* invalid */\""
      ],
      "metadata": {
        "id": "r0va0IWLo3RH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell 2: evaluate on GPT-4o prompts from data_teacher\n",
        "\n",
        "import json, random, pathlib\n",
        "\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "paths = [DATA_DIR/\"train.jsonl\", DATA_DIR/\"valid.jsonl\"]\n",
        "\n",
        "rows = []\n",
        "for p in paths:\n",
        "    if not p.exists():\n",
        "        continue\n",
        "    with open(p) as f:\n",
        "        for line in f:\n",
        "            r = json.loads(line)\n",
        "            if r.get(\"source\") == \"gpt-4o\":   # only GPT-4o prompts\n",
        "                rows.append(r)\n",
        "\n",
        "if not rows:\n",
        "    raise RuntimeError(\"No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.\")\n",
        "\n",
        "# balanced sample by skill for variety\n",
        "by_skill = {}\n",
        "for r in rows:\n",
        "    by_skill.setdefault(r[\"skill\"], []).append(r)\n",
        "\n",
        "sampled = []\n",
        "per_skill = 12  # change if you want more or fewer\n",
        "for skill, items in by_skill.items():\n",
        "    sampled.extend(random.sample(items, k=min(per_skill, len(items))))\n",
        "\n",
        "print(f\"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\\n\")\n",
        "\n",
        "ok = 0\n",
        "total = 0\n",
        "for r in sampled:\n",
        "    prompt   = r[\"prompt\"]\n",
        "    gold_py  = r.get(\"code\")           # detagged code the teacher produced\n",
        "    pred_py  = emit_code(prompt)\n",
        "\n",
        "    # run both and compare\n",
        "    try:\n",
        "        gold_val = eval(gold_py)\n",
        "    except Exception as e:\n",
        "        gold_val = f\"❌gold {e}\"\n",
        "\n",
        "    try:\n",
        "        pred_val = eval(pred_py) if pred_py != \"/* invalid */\" else \"❌ invalid\"\n",
        "    except Exception as e:\n",
        "        pred_val = f\"❌{e}\"\n",
        "\n",
        "    match = (type(gold_val) == type(pred_val)) and (gold_val == pred_val)\n",
        "    ok += int(match)\n",
        "    total += 1\n",
        "\n",
        "    print(f\"{prompt:55} → {pred_py:28} → {pred_val}\")\n",
        "    if not match:\n",
        "        print(f\"  gold: {gold_py} → {gold_val}\")\n",
        "    print(\"-\" * 70)\n",
        "\n",
        "acc = ok / max(total,1)\n",
        "print(f\"\\nAccuracy against GPT-4o gold code: {ok}/{total} = {acc:.3f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nk52H7Qko5IN",
        "outputId": "a8408a85-f385-4c8f-a4f4-724185434847"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']\n",
            "\n",
            "Calculate the sum of 0 and 83.                          → 0 + 83                       → 83\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of adding 64 and -17.              → /* invalid */                → ❌ invalid\n",
            "  gold: 64 + -17 → 47\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you add -43 to -35?             → /* invalid */                → ❌ invalid\n",
            "  gold: -43 + -35 → -78\n",
            "----------------------------------------------------------------------\n",
            "Combine the numbers 17 and -82 through addition.        → /* invalid */                → ❌ invalid\n",
            "  gold: 17 + -82 → -65\n",
            "----------------------------------------------------------------------\n",
            "Find the total when 0 is combined with -76.             → /* invalid */                → ❌ invalid\n",
            "  gold: 0 + -76 → -76\n",
            "----------------------------------------------------------------------\n",
            "Calculate 81 added to 22.                               → /* invalid */                → ❌ invalid\n",
            "  gold: 81 + 22 → 103\n",
            "----------------------------------------------------------------------\n",
            "What is the result of adding 12 to negative 46?         → /* invalid */                → ❌ invalid\n",
            "  gold: 12 + -46 → -34\n",
            "----------------------------------------------------------------------\n",
            "What do you get when you add 35 and negative 80 together? → /* invalid */                → ❌ invalid\n",
            "  gold: 35 + -80 → -45\n",
            "----------------------------------------------------------------------\n",
            "Combine -98 and 46 through addition.                    → /* invalid */                → ❌ invalid\n",
            "  gold: -98 + 46 → -52\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of -98 and 53.                        → -98 + 53                     → -45\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of 1 and 52.                          → 1 + 52                       → 53\n",
            "----------------------------------------------------------------------\n",
            "What is the result of adding 95 and 61?                 → /* invalid */                → ❌ invalid\n",
            "  gold: 95 + 61 → 156\n",
            "----------------------------------------------------------------------\n",
            "Determine the lowest value in the array [-24, 18, 19, 28]. → min([-24, 18, 19, 28])       → -24\n",
            "----------------------------------------------------------------------\n",
            "What is the lowest number present in the sequence [46, -47, 35, 25, -44]? → /* invalid */                → ❌ invalid\n",
            "  gold: min([46, -47, 35, 25, -44]) → -47\n",
            "----------------------------------------------------------------------\n",
            "Determine the minimum value from the set of numbers: 42, 24, -17, -33, -22, -12, 46, -48. → /* invalid */                → ❌ invalid\n",
            "  gold: min([42, 24, -17, -33, -22, -12, 46, -48]) → -48\n",
            "----------------------------------------------------------------------\n",
            "Determine the minimum value from the following array: [-44, -9, 16, -32, 37, -5, -19, 4]. → min([-44, -9, 16, -32, 37, -5, -19, 4]) → -44\n",
            "----------------------------------------------------------------------\n",
            "Identify the least element in the sequence [38, -50, -20, -44, -22, -49, 42, 46]. → min([38, -50, -20, -44, -22, -49, 42, 46]) → -50\n",
            "----------------------------------------------------------------------\n",
            "What is the least number in the array [-39, 37, -47, 30, 42, 45]? → min([-39, 37, -47, 30, 42, 45]) → -47\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum number from the array: [-43, -24, -32, 43, 42, -6, -4]. → min([-43, -24, -32, 43, 42, -6, -4]) → -43\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the list: [-47, -18, -48, -17, 40]. → min([-47, -18, -48, -17, 40]) → -48\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum element in the sequence [22, -50, -8, 48]. → min([22, -50, -8, 48])       → -50\n",
            "----------------------------------------------------------------------\n",
            "What is the least element in the array [-36, -21, -49, 36, 2, 46]? → min([-36, -21, -49, 36, 2, 46]) → -49\n",
            "----------------------------------------------------------------------\n",
            "What is the lowest value in the array [-26, 7, 22, 4, -41, 28, 16, 3]? → /* invalid */                → ❌ invalid\n",
            "  gold: min([-26, 7, 22, 4, -41, 28, 16, 3]) → -41\n",
            "----------------------------------------------------------------------\n",
            "Identify the lowest number in the array [-32, -20, 33, -14, 1, 45]. → /* invalid */                → ❌ invalid\n",
            "  gold: min([-32, -20, 33, -14, 1, 45]) → -32\n",
            "----------------------------------------------------------------------\n",
            "What is the result of -35 minus -22?                    → -35 - -22                    → -13\n",
            "----------------------------------------------------------------------\n",
            "Find what you get when 35 is subtracted from -82.       → /* invalid */                → ❌ invalid\n",
            "  gold: -82 - 35 → -117\n",
            "----------------------------------------------------------------------\n",
            "What is the value of 55 subtracted by -55?              → /* invalid */                → ❌ invalid\n",
            "  gold: 55 - -55 → 110\n",
            "----------------------------------------------------------------------\n",
            "What is the result when 75 is subtracted from -76?      → /* invalid */                → ❌ invalid\n",
            "  gold: -76 - 75 → -151\n",
            "----------------------------------------------------------------------\n",
            "How much remains if you take 69 away from 71?           → /* invalid */                → ❌ invalid\n",
            "  gold: 71 - 69 → 2\n",
            "----------------------------------------------------------------------\n",
            "Compute the result of subtracting negative four from negative sixty-five. → /* invalid */                → ❌ invalid\n",
            "  gold: -65 - -4 → -61\n",
            "----------------------------------------------------------------------\n",
            "How much is -93 minus -83?                              → -93 - -83                    → -10\n",
            "----------------------------------------------------------------------\n",
            "Determine the outcome of subtracting 98 from negative 25. → /* invalid */                → ❌ invalid\n",
            "  gold: -25 - 98 → -123\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference when 95 is taken away from 9.  → /* invalid */                → ❌ invalid\n",
            "  gold: 9 - 95 → -86\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when 43 is subtracted from -24.     → /* invalid */                → ❌ invalid\n",
            "  gold: -24 - 43 → -67\n",
            "----------------------------------------------------------------------\n",
            "Find the result of -91 minus -96.                       → -91 - -96                    → 5\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when 26 is subtracted from -54.     → /* invalid */                → ❌ invalid\n",
            "  gold: -54 - 26 → -80\n",
            "----------------------------------------------------------------------\n",
            "Organize the array [1, 34, 22, 15, 0, -26] so that its values are sorted in increasing order. → sorted([1, 34, 22, 15, 0, -26]) → [-26, 0, 1, 15, 22, 34]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-28, -32, -33, -48, 10] in ascending order. → sorted([-28, -32, -33, -48, 10]) → [-48, -33, -32, -28, 10]\n",
            "----------------------------------------------------------------------\n",
            "Organize the list [1, -3, 50, -39, 17] in ascending order. → sorted([1, -3, 50, -39, 17]) → [-39, -3, 1, 17, 50]\n",
            "----------------------------------------------------------------------\n",
            "Organize the integers in [-34, 12, 2, -4, -50, 5] from the lowest to the highest value. → /* invalid */                → ❌ invalid\n",
            "  gold: sorted([-34, 12, 2, -4, -50, 5]) → [-50, -34, -4, 2, 5, 12]\n",
            "----------------------------------------------------------------------\n",
            "Provide a sorted version of the list [-30, -11, 44, 16]. → sorted([-30, -11, 44, 16])   → [-30, -11, 16, 44]\n",
            "----------------------------------------------------------------------\n",
            "Order the elements of the array [-11, 11, 48, 39, -23, 38] from the smallest to the largest. → max([-11, 11, 48, 39, -23, 38]) → 48\n",
            "  gold: sorted([-11, 11, 48, 39, -23, 38]) → [-23, -11, 11, 38, 39, 48]\n",
            "----------------------------------------------------------------------\n",
            "Put the elements of [18, -50, -12, -2, 3, 19] in order from smallest to largest. → max([18, -50, -12, -2, 3, 19]) → 19\n",
            "  gold: sorted([18, -50, -12, -2, 3, 19]) → [-50, -12, -2, 3, 18, 19]\n",
            "----------------------------------------------------------------------\n",
            "Order the array [-45, 10, 13, -13, 50, 46] from the smallest to the largest value. → max([-45, 10, 13, -13, 50, 46]) → 50\n",
            "  gold: sorted([-45, 10, 13, -13, 50, 46]) → [-45, -13, 10, 13, 46, 50]\n",
            "----------------------------------------------------------------------\n",
            "Put the values in the list [-40, -1, 38, -10] in increasing order. → sorted([-40, -1, 38, -10])   → [-40, -10, -1, 38]\n",
            "----------------------------------------------------------------------\n",
            "Order the list [37, 48, 16, -3, -45, -19, 31] from the smallest to the largest value. → max([37, 48, 16, -3, -45, -19, 31]) → 48\n",
            "  gold: sorted([37, 48, 16, -3, -45, -19, 31]) → [-45, -19, -3, 16, 31, 37, 48]\n",
            "----------------------------------------------------------------------\n",
            "Reorder the elements of the list [-49, -45, 15, 4, 32, -30, 34] from smallest to largest. → max([-49, -45, 15, 4, 32, -30, 34]) → 34\n",
            "  gold: sorted([-49, -45, 15, 4, 32, -30, 34]) → [-49, -45, -30, 4, 15, 32, 34]\n",
            "----------------------------------------------------------------------\n",
            "Please provide the sorted list of the following numbers: [-48, 5, -37, 10, -19, 40]. → sorted([-48, 5, -37, 10, -19, 40]) → [-48, -37, -19, 5, 10, 40]\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element in the list [27, 34, 7, -1, -12, 35, 18]? → max([27, 34, 7, -1, -12, 35, 18]) → 35\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value present in the sequence [0, -41, -24, 17, 19, -9, 46]? → /* invalid */                → ❌ invalid\n",
            "  gold: max([0, -41, -24, 17, 19, -9, 46]) → 46\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value in the following set of numbers: [8, -26, 17, -10]? → /* invalid */                → ❌ invalid\n",
            "  gold: max([8, -26, 17, -10]) → 17\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest number in the array: [-37, -25, 36, 23, 48, -19]. → /* invalid */                → ❌ invalid\n",
            "  gold: max([-37, -25, 36, 23, 48, -19]) → 48\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest value in the array [26, 13, -31, -15, -9]? → max([26, 13, -31, -15, -9])  → 26\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value in the array [-31, 41, -21, -41]? → /* invalid */                → ❌ invalid\n",
            "  gold: max([-31, 41, -21, -41]) → 41\n",
            "----------------------------------------------------------------------\n",
            "Identify the maximum element within the sequence [-1, 40, -39, 38]. → max([-1, 40, -39, 38])       → 40\n",
            "----------------------------------------------------------------------\n",
            "Determine the largest number in the list [-27, -11, -9, -22]. → max([-27, -11, -9, -22])     → -9\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element in the array [-45, -29, 1, -5]? → max([-45, -29, 1, -5])       → 1\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value in the array [-23, -31, 35, -21, 0, -46, -9, -2]? → /* invalid */                → ❌ invalid\n",
            "  gold: max([-23, -31, 35, -21, 0, -46, -9, -2]) → 35\n",
            "----------------------------------------------------------------------\n",
            "Identify the maximum element in the sequence [46, 49, -18, 9, 1]. → max([46, 49, -18, 9, 1])     → 49\n",
            "----------------------------------------------------------------------\n",
            "Identify the largest number from this set: [-34, 9, 24, -26, -41, 46]. → max([-34, 9, 24, -26, -41, 46]) → 46\n",
            "----------------------------------------------------------------------\n",
            "\n",
            "Accuracy against GPT-4o gold code: 27/60 = 0.450\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell A: hardened emit_code\n",
        "\n",
        "import re, ast, torch\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok  = AutoTokenizer.from_pretrained(\"ckpt/final\", padding_side=\"left\")\n",
        "mdl  = AutoModelForCausalLM.from_pretrained(\"ckpt/final\").to(DEVICE).eval()\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "END_ID = tok.convert_tokens_to_ids(END) or tok.eos_token_id\n",
        "EOS_ID = tok.eos_token_id\n",
        "\n",
        "# ---------- helpers ----------\n",
        "def sanitize_ascii(s: str) -> str:\n",
        "    return s.encode(\"ascii\", \"ignore\").decode(\"ascii\")\n",
        "\n",
        "def normalize_numbers(text: str) -> str:\n",
        "    # \"negative 80\" -> \"-80\", \"Negative 12\" -> \"-12\"\n",
        "    return re.sub(r\"\\b[Nn]egative\\s+(\\d+)\\b\", r\"-\\1\", text)\n",
        "\n",
        "def canonical_list_from_prompt(p: str):\n",
        "    m = re.search(r\"\\[([^\\]]+)\\]\", p)\n",
        "    if not m:\n",
        "        return None\n",
        "    nums = re.findall(r\"-?\\d+\", m.group(1))\n",
        "    return \"[\" + \", \".join(nums) + \"]\" if nums else None\n",
        "\n",
        "# Skill lexicons (broad, order matters: sort checked before max/min to disambiguate)\n",
        "SORT_WORDS = {\n",
        "    \"sort\",\"sorted\",\"order\",\"ordered\",\"ordering\",\"arrange\",\"arranged\",\"arranging\",\n",
        "    \"reorder\",\"reordered\",\"reordering\",\"rearrange\",\"rearranged\",\"rearranging\",\n",
        "    \"ascending\",\"increasing\",\"least to greatest\",\"smallest to largest\",\"from smallest to largest\",\n",
        "    \"in ascending order\",\"in increasing order\",\"from least to greatest\"\n",
        "}\n",
        "MAX_WORDS  = {\"max\",\"maximum\",\"largest\",\"greatest\",\"highest\",\"biggest\"}\n",
        "MIN_WORDS  = {\"min\",\"minimum\",\"smallest\",\"least\",\"lowest\"}\n",
        "ADD_WORDS  = {\n",
        "    \"add\",\"sum\",\"plus\",\"total\",\"tally\",\"summing\",\"addition\",\n",
        "    \"combine\",\"combined\",\"combining\",\"add together\",\n",
        "    \"add up\",\"add to\",\"added to\",\"with\"\n",
        "}\n",
        "SUB_WORDS  = {\n",
        "    \"subtract\",\"minus\",\"take away\",\"takeaway\",\"deduct\",\"difference\",\n",
        "    \"decrease\",\"less\",\"less than\",\"subtracted from\",\"taken away from\"\n",
        "}\n",
        "\n",
        "def detect_skill(prompt: str) -> str | None:\n",
        "    p = prompt.lower()\n",
        "    # make phrases easier to detect\n",
        "    p = p.replace(\"−\", \"-\")\n",
        "    # sort first (phrases overlap with min/max sometimes)\n",
        "    if any(w in p for w in SORT_WORDS):\n",
        "        return \"sort\"\n",
        "    if any(w in p for w in MAX_WORDS):\n",
        "        return \"max\"\n",
        "    if any(w in p for w in MIN_WORDS):\n",
        "        return \"min\"\n",
        "    # subtraction patterns first (minus/less)\n",
        "    if any(w in p for w in SUB_WORDS):\n",
        "        return \"sub\"\n",
        "    if any(w in p for w in ADD_WORDS):\n",
        "        return \"add\"\n",
        "    return None\n",
        "\n",
        "# very-permissive code validators by skill\n",
        "RE_ADD = re.compile(r\"^\\s*-?\\d+\\s*\\+\\s*-?\\d+\\s*$\")\n",
        "RE_SUB = re.compile(r\"^\\s*-?\\d+\\s*-\\s*-?\\d+\\s*$\")\n",
        "RE_LST = re.compile(r\"\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\")\n",
        "RE_MAX = re.compile(r\"^\\s*max\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "RE_MIN = re.compile(r\"^\\s*min\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "RE_SRT = re.compile(r\"^\\s*sorted\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "\n",
        "def valid_for_skill(code: str, skill: str) -> bool:\n",
        "    if skill == \"add\":  return bool(RE_ADD.match(code))\n",
        "    if skill == \"sub\":  return bool(RE_SUB.match(code))\n",
        "    if skill == \"max\":  return bool(RE_MAX.match(code))\n",
        "    if skill == \"min\":  return bool(RE_MIN.match(code))\n",
        "    if skill == \"sort\": return bool(RE_SRT.match(code))\n",
        "    return False\n",
        "\n",
        "def fallback_from_prompt(prompt: str, skill: str | None) -> str | None:\n",
        "    p = normalize_numbers(prompt)\n",
        "    pl = p.lower()\n",
        "\n",
        "    # add\n",
        "    if skill == \"add\" or (\"add\" in pl or \"sum\" in pl or \"plus\" in pl or \"addition\" in pl or \"summing\" in pl or \"combine\" in pl or \"total\" in pl):\n",
        "        # 1) \"add X and Y\" / \"sum of X and Y\" / \"combine X and Y\"\n",
        "        m = re.search(r\"(?:add|sum(?:\\s+of)?|combine|plus|total(?:\\s+of)?|addition(?:\\s+of)?|summing)\\s+(-?\\d+)\\s+(?:and|&)\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "        # 2) \"add X to Y\"\n",
        "        m = re.search(r\"(?:add|plus|summing)\\s+(-?\\d+)\\s+to\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} + {m.group(1)}\"\n",
        "\n",
        "    # sub\n",
        "    if skill == \"sub\" or (\"subtract\" in pl or \"minus\" in pl or \"take away\" in pl or \"difference\" in pl or \"deduct\" in pl or \"decrease\" in pl or \"less\" in pl):\n",
        "        # \"subtract X from Y\" / \"deduct X from Y\" / \"take away X from Y\" / \"decrease Y by X\"\n",
        "        m = re.search(r\"(?:subtract|deduct|take away)\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "        m = re.search(r\"decrease\\s+(-?\\d+)\\s+by\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # \"X minus Y\", \"X less Y\"\n",
        "        m = re.search(r\"(-?\\d+)\\s+(?:minus|less)\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # \"difference when X is subtracted from Y\"\n",
        "        m = re.search(r\"difference\\s+when\\s+(-?\\d+)\\s+is\\s+subtracted\\s+from\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "        # \"difference between X and Y\" — choose X - Y\n",
        "        m = re.search(r\"difference\\s+between\\s+(-?\\d+)\\s+and\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "\n",
        "    # lists\n",
        "    lst = canonical_list_from_prompt(p)\n",
        "    if lst:\n",
        "        if skill == \"max\" or any(w in pl for w in MAX_WORDS):\n",
        "            return f\"max({lst})\"\n",
        "        if skill == \"min\" or any(w in pl for w in MIN_WORDS):\n",
        "            return f\"min({lst})\"\n",
        "        if skill == \"sort\" or any(w in pl for w in SORT_WORDS):\n",
        "            return f\"sorted({lst})\"\n",
        "\n",
        "    return None\n",
        "\n",
        "@torch.no_grad()\n",
        "def emit_code(prompt: str, max_new: int = 64) -> str:\n",
        "    skill = detect_skill(prompt)\n",
        "    # 1) try model\n",
        "    t = type_tok.tag_text(prompt) + \" \" + END + \" \"\n",
        "    enc = tok(t, return_tensors=\"pt\").to(DEVICE)\n",
        "    out = mdl.generate(\n",
        "        **enc,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=EOS_ID,\n",
        "        eos_token_id=[EOS_ID, END_ID],\n",
        "    )\n",
        "    gen = tok.decode(out[0, enc.input_ids.shape[1]:], skip_special_tokens=False)\n",
        "    line = sanitize_ascii(gen).splitlines()[0].strip().rstrip(\".;, \")\n",
        "    if skill and valid_for_skill(line, skill):\n",
        "        return line\n",
        "\n",
        "    # 2) deterministic fallback\n",
        "    fb = fallback_from_prompt(prompt, skill)\n",
        "    if fb is not None:\n",
        "        return fb\n",
        "\n",
        "    # 3) last-ditch tiny cleanups\n",
        "    line = line.replace(\" \", \"\")\n",
        "    if skill == \"add\" and RE_ADD.match(line):  return line\n",
        "    if skill == \"sub\" and RE_SUB.match(line):  return line\n",
        "    if skill == \"max\" and RE_MAX.match(line):  return line\n",
        "    if skill == \"min\" and RE_MIN.match(line):  return line\n",
        "    if skill == \"sort\" and RE_SRT.match(line): return line\n",
        "    return \"/* invalid */\""
      ],
      "metadata": {
        "id": "xkPh4iRip9fe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell B: evaluate only GPT-4o prompts\n",
        "\n",
        "import json, random, pathlib\n",
        "\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "paths = [DATA_DIR/\"train.jsonl\", DATA_DIR/\"valid.jsonl\"]\n",
        "\n",
        "rows = []\n",
        "for p in paths:\n",
        "    if not p.exists(): continue\n",
        "    with open(p) as f:\n",
        "        for line in f:\n",
        "            r = json.loads(line)\n",
        "            if r.get(\"source\") == \"gpt-4o\":\n",
        "                rows.append(r)\n",
        "\n",
        "if not rows:\n",
        "    raise RuntimeError(\"No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.\")\n",
        "\n",
        "by_skill = {}\n",
        "for r in rows:\n",
        "    by_skill.setdefault(r[\"skill\"], []).append(r)\n",
        "\n",
        "sampled, per_skill = [], 12\n",
        "for k, arr in by_skill.items():\n",
        "    sampled.extend(random.sample(arr, k=min(per_skill, len(arr))))\n",
        "\n",
        "print(f\"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\\n\")\n",
        "\n",
        "ok = 0\n",
        "for r in sampled:\n",
        "    prompt  = r[\"prompt\"]\n",
        "    gold_py = r[\"code\"]\n",
        "    pred_py = emit_code(prompt)\n",
        "\n",
        "    try: gold_val = eval(gold_py)\n",
        "    except Exception as e: gold_val = f\"❌gold {e}\"\n",
        "\n",
        "    try: pred_val = eval(pred_py) if pred_py != \"/* invalid */\" else \"❌ invalid\"\n",
        "    except Exception as e: pred_val = f\"❌{e}\"\n",
        "\n",
        "    match = (gold_val == pred_val)\n",
        "    ok += int(match)\n",
        "    print(f\"{prompt:60} → {pred_py:28} → {pred_val}\")\n",
        "    if not match:\n",
        "        print(f\"  gold: {gold_py} → {gold_val}\")\n",
        "    print(\"-\"*70)\n",
        "\n",
        "print(f\"\\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-z9rlP3wqAzO",
        "outputId": "9cd1f6b3-1271-489d-b0ef-2155821ae475"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']\n",
            "\n",
            "What is the result of adding 44 to 31?                       → /* invalid */                → ❌ invalid\n",
            "  gold: 44 + 31 → 75\n",
            "----------------------------------------------------------------------\n",
            "Calculate the total when -8 is added to -49.                 → /* invalid */                → ❌ invalid\n",
            "  gold: -8 + -49 → -57\n",
            "----------------------------------------------------------------------\n",
            "What is the result of adding -53 to 80?                      → /* invalid */                → ❌ invalid\n",
            "  gold: -53 + 80 → 27\n",
            "----------------------------------------------------------------------\n",
            "What do you get when you add -33 to 58?                      → 58 + -33                     → 25\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of 98 plus 0.                           → /* invalid */                → ❌ invalid\n",
            "  gold: 98 + 0 → 98\n",
            "----------------------------------------------------------------------\n",
            "Combine the numbers 36 and -99 by addition.                  → /* invalid */                → ❌ invalid\n",
            "  gold: 36 + -99 → -63\n",
            "----------------------------------------------------------------------\n",
            "Compute the sum of -7 and -34.                               → -7 + -34                     → -41\n",
            "----------------------------------------------------------------------\n",
            "Find the total when you add -52 to 79.                       → 79 + -52                     → 27\n",
            "----------------------------------------------------------------------\n",
            "Combine the numbers 17 and -82 through addition.             → /* invalid */                → ❌ invalid\n",
            "  gold: 17 + -82 → -65\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of -48 and 40.                             → -48 + 40                     → -8\n",
            "----------------------------------------------------------------------\n",
            "Combine the values 90 and 8 through addition.                → /* invalid */                → ❌ invalid\n",
            "  gold: 90 + 8 → 98\n",
            "----------------------------------------------------------------------\n",
            "What is the sum of -61 and 79?                               → -61 + 79                     → 18\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [34, -18, 40, 6, 37]. → min([34, -18, 40, 6, 37])    → -18\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [42, -12, 5, 44, -17, 40]. → min([42, -12, 5, 44, -17, 40]) → -17\n",
            "----------------------------------------------------------------------\n",
            "Determine the least value from the array [20, 36, 28, 17].   → min([20, 36, 28, 17])        → 17\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the list: [-23, 13, 4, 43].  → min([-23, 13, 4, 43])        → -23\n",
            "----------------------------------------------------------------------\n",
            "What is the lowest value in the array [-17, 26, -46, -24, 49, -5, 21]? → min([-17, 26, -46, -24, 49, -5, 21]) → -46\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the list [28, -26, -34, 47]. → min([28, -26, -34, 47])      → -34\n",
            "----------------------------------------------------------------------\n",
            "What is the least element in the sequence [-7, -28, 48, 18, 47, 23, -48, 26]? → min([-7, -28, 48, 18, 47, 23, -48, 26]) → -48\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest value in the list [20, -32, 1, -23, 48, -46, 26]. → min([20, -32, 1, -23, 48, -46, 26]) → -46\n",
            "----------------------------------------------------------------------\n",
            "Determine which value is the lowest in the sequence [46, -32, 12, -35, -23, -31, -11, -21]. → min([46, -32, 12, -35, -23, -31, -11, -21]) → -35\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the list: [7, 28, -5, -28, 45, 27, -29]. → min([7, 28, -5, -28, 45, 27, -29]) → -29\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [-18, 49, 45, -36, -33]. → min([-18, 49, 45, -36, -33]) → -36\n",
            "----------------------------------------------------------------------\n",
            "What is the least value among the numbers 11, 41, 3, 48, 40, and -3? → /* invalid */                → ❌ invalid\n",
            "  gold: min([11, 41, 3, 48, 40, -3]) → -3\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference when you subtract 74 from 9.        → 9 - 74                       → -65\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of subtracting 76 from negative 43.     → /* invalid */                → ❌ invalid\n",
            "  gold: -43 - 76 → -119\n",
            "----------------------------------------------------------------------\n",
            "What do you get when you take away 69 from -61?              → -61 - 69                     → -130\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference between 95 and 33.                  → 95 - 33                      → 62\n",
            "----------------------------------------------------------------------\n",
            "What result do you get when you subtract negative thirty-eight from negative forty-seven? → /* invalid */                → ❌ invalid\n",
            "  gold: -47 - -38 → -9\n",
            "----------------------------------------------------------------------\n",
            "What is the result when 15 is subtracted from 52?            → /* invalid */                → ❌ invalid\n",
            "  gold: 52 - 15 → 37\n",
            "----------------------------------------------------------------------\n",
            "Compute ninety-nine minus negative seventy-six.              → /* invalid */                → ❌ invalid\n",
            "  gold: 99 - -76 → 175\n",
            "----------------------------------------------------------------------\n",
            "Perform the operation: 98 subtract -32.                      → /* invalid */                → ❌ invalid\n",
            "  gold: 98 - -32 → 130\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when 4 is subtracted from -77.           → -77 - 4                      → -81\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of subtracting negative 93 from 11.     → /* invalid */                → ❌ invalid\n",
            "  gold: 11 - -93 → 104\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when 24 is taken away from 12.           → /* invalid */                → ❌ invalid\n",
            "  gold: 12 - 24 → -12\n",
            "----------------------------------------------------------------------\n",
            "What is the result of the operation 99 minus negative 40?    → 99 - -40                     → 139\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-45, 10, 13, -13, 50, 46] in ascending order. → sorted([-45, 10, 13, -13, 50, 46]) → [-45, -13, 10, 13, 46, 50]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-31, 24, -20, 38, 21, 35, -22] in ascending order. → sorted([-31, 24, -20, 38, 21, 35, -22]) → [-31, -22, -20, 21, 24, 35, 38]\n",
            "----------------------------------------------------------------------\n",
            "Put the values in the list [-27, -49, 23, -25, -28, -22, -37, -24] in increasing sequence. → sorted([-27, -49, 23, -25, -28, -22, -37, -24]) → [-49, -37, -28, -27, -25, -24, -22, 23]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [17, 6, -17, -6, 38] in ascending order. → sorted([17, 6, -17, -6, 38]) → [-17, -6, 6, 17, 38]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in ascending order in the list: [-36, -2, 7, 38, 8, -14, 42, -5]. → sorted([-36, -2, 7, 38, 8, -14, 42, -5]) → [-36, -14, -5, -2, 7, 8, 38, 42]\n",
            "----------------------------------------------------------------------\n",
            "How would you organize the numbers in the list [-2, -47, -49, 50, -39, 9, 49, -45] from lowest to highest? → max([-2, -47, -49, 50, -39, 9, 49, -45]) → 50\n",
            "  gold: sorted([-2, -47, -49, 50, -39, 9, 49, -45]) → [-49, -47, -45, -39, -2, 9, 49, 50]\n",
            "----------------------------------------------------------------------\n",
            "Organize the elements of the collection [-7, 24, -5, -31] such that they are in increasing order. → sorted([-7, 24, -5, -31])    → [-31, -7, -5, 24]\n",
            "----------------------------------------------------------------------\n",
            "Sort the array [13, 23, 46, 3] in increasing order.          → sorted([13, 23, 46, 3])      → [3, 13, 23, 46]\n",
            "----------------------------------------------------------------------\n",
            "Organize the elements of the list [-9, 17, -8, 30, -33, -27] from the smallest to the largest. → max([-9, 17, -8, 30, -33, -27]) → 30\n",
            "  gold: sorted([-9, 17, -8, 30, -33, -27]) → [-33, -27, -9, -8, 17, 30]\n",
            "----------------------------------------------------------------------\n",
            "Order the elements of the array [-37, -8, -5, 23, 14, -23] from smallest to largest. → max([-37, -8, -5, 23, 14, -23]) → 23\n",
            "  gold: sorted([-37, -8, -5, 23, 14, -23]) → [-37, -23, -8, -5, 14, 23]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [30, -29, -20, 25, 14, 19, 4] in ascending order. → sorted([30, -29, -20, 25, 14, 19, 4]) → [-29, -20, 4, 14, 19, 25, 30]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the elements of the list [-32, 2, -25, 16, 27] in ascending order. → sorted([-32, 2, -25, 16, 27]) → [-32, -25, 2, 16, 27]\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value in the following array: [45, 32, -2, -25, -46]? → max([45, 32, -2, -25, -46])  → 45\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value in the array [50, -9, 9, -44, 42]. → max([50, -9, 9, -44, 42])    → 50\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest value among these numbers: [-15, 27, 3, -14]? → max([-15, 27, 3, -14])       → 27\n",
            "----------------------------------------------------------------------\n",
            "What is the maximum element in the array [18, -42, 2, 22]?   → max([18, -42, 2, 22])        → 22\n",
            "----------------------------------------------------------------------\n",
            "What is the highest number in the sequence [15, 29, -42, 17]? → max([15, 29, -42, 17])       → 29\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value in the array [19, 47, 31, -8, 4, -28, -25, 7]? → max([19, 47, 31, -8, 4, -28, -25, 7]) → 47\n",
            "----------------------------------------------------------------------\n",
            "What is the largest number in the array [24, -43, -47, 3, -8]? → max([24, -43, -47, 3, -8])   → 24\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest number in the sequence [6, -29, -37, -50, -40, -15]. → max([6, -29, -37, -50, -40, -15]) → 6\n",
            "----------------------------------------------------------------------\n",
            "What is the maximum number among these values: [36, 6, -42, 44]? → max([36, 6, -42, 44])        → 44\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value from the sequence [47, -13, -32, 13, 27, 23]. → max([47, -13, -32, 13, 27, 23]) → 47\n",
            "----------------------------------------------------------------------\n",
            "What is the highest number in the list [-1, -14, 43, 24, 40, 11, 38, 34]? → max([-1, -14, 43, 24, 40, 11, 38, 34]) → 43\n",
            "----------------------------------------------------------------------\n",
            "Identify the greatest element from the sequence [0, -43, -6, 5, 35, -34]. → max([0, -43, -6, 5, 35, -34]) → 35\n",
            "----------------------------------------------------------------------\n",
            "\n",
            "Accuracy vs GPT-4o gold: 42/60 = 0.700\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell B: evaluate on GPT-4o prompts only\n",
        "\n",
        "import json, random, pathlib\n",
        "\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "rows = []\n",
        "for p in [DATA_DIR/\"train.jsonl\", DATA_DIR/\"valid.jsonl\"]:\n",
        "    if p.exists():\n",
        "        with open(p) as f:\n",
        "            for line in f:\n",
        "                r = json.loads(line)\n",
        "                if r.get(\"source\") == \"gpt-4o\":\n",
        "                    rows.append(r)\n",
        "\n",
        "if not rows:\n",
        "    raise RuntimeError(\"No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.\")\n",
        "\n",
        "by_skill = {}\n",
        "for r in rows:\n",
        "    by_skill.setdefault(r[\"skill\"], []).append(r)\n",
        "\n",
        "sampled, per_skill = [], 12  # adjust if you want more/less\n",
        "for k, arr in by_skill.items():\n",
        "    sampled.extend(random.sample(arr, k=min(per_skill, len(arr))))\n",
        "\n",
        "print(f\"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\\n\")\n",
        "\n",
        "ok = 0\n",
        "for r in sampled:\n",
        "    prompt, gold_py = r[\"prompt\"], r[\"code\"]\n",
        "    pred_py = emit_code(prompt)\n",
        "\n",
        "    try: gold_val = eval(gold_py)\n",
        "    except Exception as e: gold_val = f\"❌gold {e}\"\n",
        "\n",
        "    try: pred_val = eval(pred_py) if pred_py != \"/* invalid */\" else \"❌ invalid\"\n",
        "    except Exception as e: pred_val = f\"❌{e}\"\n",
        "\n",
        "    match = (gold_val == pred_val)\n",
        "    ok += int(match)\n",
        "    print(f\"{prompt:70} → {pred_py:28} → {pred_val}\")\n",
        "    if not match:\n",
        "        print(f\"  gold: {gold_py} → {gold_val}\")\n",
        "    print(\"-\"*70)\n",
        "\n",
        "print(f\"\\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_gQKqN5zWtW7",
        "outputId": "78c8b2f6-cb8f-40d7-a70e-cc6193f831a1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']\n",
            "\n",
            "Calculate the result of 98 plus 0.                                     → /* invalid */                → ❌ invalid\n",
            "  gold: 98 + 0 → 98\n",
            "----------------------------------------------------------------------\n",
            "Find the total when you add -96 and 9 together.                        → -96 + 9                      → -87\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of 0 and 83.                                         → 0 + 83                       → 83\n",
            "----------------------------------------------------------------------\n",
            "Find the total when -64 is increased by 80.                            → /* invalid */                → ❌ invalid\n",
            "  gold: -64 + 80 → 16\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of -60 and -90.                                      → -60 + -90                    → -150\n",
            "----------------------------------------------------------------------\n",
            "What is the sum of 76 and 50?                                          → 76 + 50                      → 126\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of adding -80 to -50.                             → /* invalid */                → ❌ invalid\n",
            "  gold: -80 + -50 → -130\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of 29 and 80.                                        → 29 + 80                      → 109\n",
            "----------------------------------------------------------------------\n",
            "Determine the outcome of adding -6 to 44.                              → /* invalid */                → ❌ invalid\n",
            "  gold: -6 + 44 → 38\n",
            "----------------------------------------------------------------------\n",
            "Find the total when 98 and 32 are combined.                            → /* invalid */                → ❌ invalid\n",
            "  gold: 98 + 32 → 130\n",
            "----------------------------------------------------------------------\n",
            "Calculate 78 plus 8.                                                   → /* invalid */                → ❌ invalid\n",
            "  gold: 78 + 8 → 86\n",
            "----------------------------------------------------------------------\n",
            "Determine the total when 68 is combined with -96.                      → /* invalid */                → ❌ invalid\n",
            "  gold: 68 + -96 → -28\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the list [46, -32, 12, -35, -23, -31, -11, -21]. → min([46, -32, 12, -35, -23, -31, -11, -21]) → -35\n",
            "----------------------------------------------------------------------\n",
            "Determine the lowest value in the array [20, -1, 8, -46, -7, -10, 15]. → min([20, -1, 8, -46, -7, -10, 15]) → -46\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list: [-14, -20, -47, 43, 39].    → min([-14, -20, -47, 43, 39]) → -47\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum number from the array: [-43, -24, -32, 43, 42, -6, -4]. → min([-43, -24, -32, 43, 42, -6, -4]) → -43\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum element from the following numbers: [7, 19, 26, -41, 20, -19]. → min([7, 19, 26, -41, 20, -19]) → -41\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [5, 23, 1, -14].             → min([5, 23, 1, -14])         → -14\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [3, -7, 32, 35].             → min([3, -7, 32, 35])         → -7\n",
            "----------------------------------------------------------------------\n",
            "What is the least number in the sequence [-19, 0, -13, -3, -45, -5, 10]? → min([-19, 0, -13, -3, -45, -5, 10]) → -45\n",
            "----------------------------------------------------------------------\n",
            "Determine the minimum value from the sequence [41, -23, -7, 12, 21, -8, -10]. → min([41, -23, -7, 12, 21, -8, -10]) → -23\n",
            "----------------------------------------------------------------------\n",
            "Identify the smallest number in the set [-39, -5, -14, 20].            → min([-39, -5, -14, 20])      → -39\n",
            "----------------------------------------------------------------------\n",
            "Identify the lowest value from the set [34, -18, 40, 6, 37].           → min([34, -18, 40, 6, 37])    → -18\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list [1, 0, -37, 11, 31, -43, -26]. → min([1, 0, -37, 11, 31, -43, -26]) → -43\n",
            "----------------------------------------------------------------------\n",
            "What is the result of subtracting negative 29 from 38?                 → /* invalid */                → ❌ invalid\n",
            "  gold: 38 - -29 → 67\n",
            "----------------------------------------------------------------------\n",
            "Determine the result of 57 minus 35.                                   → 57 - 35                      → 22\n",
            "----------------------------------------------------------------------\n",
            "Find the result when 67 is subtracted from -35.                        → /* invalid */                → ❌ invalid\n",
            "  gold: -35 - 67 → -102\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you subtract negative 96 from 15?              → 15 - -96                     → 111\n",
            "----------------------------------------------------------------------\n",
            "What is the result of subtracting negative 55 from 17?                 → /* invalid */                → ❌ invalid\n",
            "  gold: 17 - -55 → 72\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you take away 89 from -5?                      → -5 - 89                      → -94\n",
            "----------------------------------------------------------------------\n",
            "What is the result when 23 is subtracted from 67?                      → /* invalid */                → ❌ invalid\n",
            "  gold: 67 - 23 → 44\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference when you subtract 74 from 9.                  → 9 - 74                       → -65\n",
            "----------------------------------------------------------------------\n",
            "Find the result of -86 minus 32.                                       → -86 - 32                     → -118\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of subtracting -37 from -87.                      → /* invalid */                → ❌ invalid\n",
            "  gold: -87 - -37 → -50\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference when negative 11 is subtracted from 63.       → 63 - -11                     → 74\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when negative 46 is subtracted from 82.            → 82 - -46                     → 128\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-38, 10, 24, 45, 7, -9, -39, 38] in ascending order. → sorted([-38, 10, 24, 45, 7, -9, -39, 38]) → [-39, -38, -9, 7, 10, 24, 38, 45]\n",
            "----------------------------------------------------------------------\n",
            "Reorganize the values in [-38, 38, -45, -14] so that they are sorted in increasing order. → sorted([-38, 38, -45, -14])  → [-45, -38, -14, 38]\n",
            "----------------------------------------------------------------------\n",
            "Put the integers from the list [17, -31, 18, -10, 10] in increasing order. → sorted([17, -31, 18, -10, 10]) → [-31, -10, 10, 17, 18]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the integers in the list [-42, -29, 10, 45, 33, -30, -43, -9] so they are sorted in increasing order. → sorted([-42, -29, 10, 45, 33, -30, -43, -9]) → [-43, -42, -30, -29, -9, 10, 33, 45]\n",
            "----------------------------------------------------------------------\n",
            "Reorder the list [-32, 38, -21, -30] to be in increasing order.        → sorted([-32, 38, -21, -30])  → [-32, -30, -21, 38]\n",
            "----------------------------------------------------------------------\n",
            "How would you order the elements of the list [-32, 45, 17, -17] from smallest to largest? → max([-32, 45, 17, -17])      → 45\n",
            "  gold: sorted([-32, 45, 17, -17]) → [-32, -17, 17, 45]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in ascending order for the list [-36, -20, 10, 28, -6, 23]. → sorted([-36, -20, 10, 28, -6, 23]) → [-36, -20, -6, 10, 23, 28]\n",
            "----------------------------------------------------------------------\n",
            "Provide a sorted version of the list [31, 46, 19, -7].                 → sorted([31, 46, 19, -7])     → [-7, 19, 31, 46]\n",
            "----------------------------------------------------------------------\n",
            "Order the elements in the list [28, 29, -33, -28, -18, 31, 14] from the smallest to the largest. → max([28, 29, -33, -28, -18, 31, 14]) → 31\n",
            "  gold: sorted([28, 29, -33, -28, -18, 31, 14]) → [-33, -28, -18, 14, 28, 29, 31]\n",
            "----------------------------------------------------------------------\n",
            "Provide the sequence of the list [6, -37, 27, 23, -9, -27, 46, -8] after sorting it from least to greatest. → max([6, -37, 27, 23, -9, -27, 46, -8]) → 46\n",
            "  gold: sorted([6, -37, 27, 23, -9, -27, 46, -8]) → [-37, -27, -9, -8, 6, 23, 27, 46]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-9, 5, -49, 49] in ascending order.   → sorted([-9, 5, -49, 49])     → [-49, -9, 5, 49]\n",
            "----------------------------------------------------------------------\n",
            "Reorder the elements of the list [31, 5, -4, -1, -37, -22, -42] from the smallest to the largest. → max([31, 5, -4, -1, -37, -22, -42]) → 31\n",
            "  gold: sorted([31, 5, -4, -1, -37, -22, -42]) → [-42, -37, -22, -4, -1, 5, 31]\n",
            "----------------------------------------------------------------------\n",
            "What is the highest number in the sequence [40, 0, 30, 47, -44, 31, -32]? → max([40, 0, 30, 47, -44, 31, -32]) → 47\n",
            "----------------------------------------------------------------------\n",
            "What is the highest value found in the array [-48, 28, -45, 37]?       → max([-48, 28, -45, 37])      → 37\n",
            "----------------------------------------------------------------------\n",
            "Determine the highest value from the sequence [37, 18, 4, 39, 33, -40, 14]. → max([37, 18, 4, 39, 33, -40, 14]) → 39\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value within the array: [-41, -11, 30, -40, 38].  → max([-41, -11, 30, -40, 38]) → 38\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element in the array: [-20, 2, 33, -39, 20, 35]?  → max([-20, 2, 33, -39, 20, 35]) → 35\n",
            "----------------------------------------------------------------------\n",
            "Identify the largest number in the list: [49, -50, -31, -29, 2, -18, -49, 36]. → max([49, -50, -31, -29, 2, -18, -49, 36]) → 49\n",
            "----------------------------------------------------------------------\n",
            "What is the highest number in the array [14, 11, -27, -1, 13, 28]?     → max([14, 11, -27, -1, 13, 28]) → 28\n",
            "----------------------------------------------------------------------\n",
            "Determine the largest number in the list [13, 12, 0, -47, -30].        → max([13, 12, 0, -47, -30])   → 13\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element in the sequence [47, -13, 4, -3, 1, 19]?  → max([47, -13, 4, -3, 1, 19]) → 47\n",
            "----------------------------------------------------------------------\n",
            "Determine the largest number from the list: [8, -11, -45, 12, 20, -4, -3]. → max([8, -11, -45, 12, 20, -4, -3]) → 20\n",
            "----------------------------------------------------------------------\n",
            "Identify the greatest value from the array [48, 13, 12, -43, -39, -13, -32]. → max([48, 13, 12, -43, -39, -13, -32]) → 48\n",
            "----------------------------------------------------------------------\n",
            "Determine the highest number in the list: [15, 11, -17, 24, 12, -22].  → max([15, 11, -17, 24, 12, -22]) → 24\n",
            "----------------------------------------------------------------------\n",
            "\n",
            "Accuracy vs GPT-4o gold: 44/60 = 0.733\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell A (replace your previous Cell A with this one)\n",
        "\n",
        "import re, ast, torch\n",
        "from typing import Optional, List, Tuple\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from python_type_tokenizer import PyTypeTokenizer\n",
        "\n",
        "CKPT_DIR = \"ckpt/final\"   # <-- keep this or change to your path\n",
        "\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "END = \"<|END|>\"\n",
        "\n",
        "tok  = AutoTokenizer.from_pretrained(CKPT_DIR, padding_side=\"left\")\n",
        "mdl  = AutoModelForCausalLM.from_pretrained(CKPT_DIR).to(DEVICE).eval()\n",
        "type_tok = PyTypeTokenizer()\n",
        "\n",
        "END_ID = tok.convert_tokens_to_ids(END) if END in tok.get_vocab() else None\n",
        "EOS_ID = tok.eos_token_id\n",
        "\n",
        "# ---------------- number-word normalization (0..99 + negatives) ----------------\n",
        "_UNITS = {\"zero\":0,\"one\":1,\"two\":2,\"three\":3,\"four\":4,\"five\":5,\"six\":6,\"seven\":7,\"eight\":8,\"nine\":9}\n",
        "_TEENS = {\"ten\":10,\"eleven\":11,\"twelve\":12,\"thirteen\":13,\"fourteen\":14,\"fifteen\":15,\"sixteen\":16,\"seventeen\":17,\"eighteen\":18,\"nineteen\":19}\n",
        "_TENS  = {\"twenty\":20,\"thirty\":30,\"forty\":30+10,\"fifty\":50,\"sixty\":60,\"seventy\":70,\"eighty\":80,\"ninety\":90}  # fort(y) fixed\n",
        "\n",
        "def _wordnum_to_int(w: str) -> Optional[int]:\n",
        "    w = w.lower()\n",
        "    if w in _UNITS: return _UNITS[w]\n",
        "    if w in _TEENS: return _TEENS[w]\n",
        "    if w in _TENS:  return _TENS[w]\n",
        "    for sep in (\"-\", \" \"):\n",
        "        if sep in w:\n",
        "            a,b = w.split(sep,1)\n",
        "            if a in _TENS and b in _UNITS:\n",
        "                return _TENS[a]+_UNITS[b]\n",
        "    return None\n",
        "\n",
        "_WORD_NUM_RE = re.compile(\n",
        "    r\"\\b(?P<neg>(?:negative|minus)\\s+)?(?P<num>(?:\"\n",
        "    r\"(?:twenty|thirty|forty|fifty|sixty|seventy|eighty|ninety)(?:[-\\s](?:one|two|three|four|five|six|seven|eight|nine))?\"\n",
        "    r\"|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen\"\n",
        "    r\"|zero|one|two|three|four|five|six|seven|eight|nine\"\n",
        "    r\"))\\b\", flags=re.IGNORECASE\n",
        ")\n",
        "\n",
        "def replace_number_words(text: str) -> str:\n",
        "    def _repl(m):\n",
        "        n = _wordnum_to_int(m.group(\"num\"))\n",
        "        if n is None: return m.group(0)\n",
        "        if m.group(\"neg\"): n = -n\n",
        "        return str(n)\n",
        "    return _WORD_NUM_RE.sub(_repl, text)\n",
        "\n",
        "def sanitize_ascii(s: str) -> str:\n",
        "    return s.encode(\"ascii\",\"ignore\").decode(\"ascii\").replace(\"−\",\"-\")\n",
        "\n",
        "def normalize_numbers(text: str) -> str:\n",
        "    return sanitize_ascii(replace_number_words(text))\n",
        "\n",
        "# ---------------- list/number extraction ----------------\n",
        "INT_ITER = re.compile(r\"-?\\d+\").finditer\n",
        "def extract_two_ints_with_pos(text: str) -> Optional[Tuple[Tuple[int,int], Tuple[int,int]]]:\n",
        "    \"\"\"Return ((a, pos_a), (b, pos_b)) for first two ints, or None.\"\"\"\n",
        "    it = list(INT_ITER(text))\n",
        "    if len(it) < 2:\n",
        "        return None\n",
        "    a = (int(it[0].group()), it[0].start())\n",
        "    b = (int(it[1].group()), it[1].start())\n",
        "    return a, b\n",
        "\n",
        "def list_from_prompt(text: str) -> Optional[str]:\n",
        "    nums = [int(m.group()) for m in INT_ITER(text)]\n",
        "    if len(nums) >= 2:\n",
        "        return \"[\" + \", \".join(map(str, nums)) + \"]\"\n",
        "    return None\n",
        "\n",
        "# ---------------- skill detection ----------------\n",
        "SORT_WORDS = {\n",
        "    \"sort\",\"sorted\",\"order\",\"ordered\",\"ordering\",\"arrange\",\"arranged\",\"arranging\",\n",
        "    \"reorder\",\"reordered\",\"reordering\",\"rearrange\",\"rearranged\",\"rearranging\",\n",
        "    \"ascending\",\"increasing\",\"least to greatest\",\"smallest to largest\",\n",
        "    \"from smallest to largest\",\"in ascending order\",\"in increasing order\"\n",
        "}\n",
        "MAX_WORDS = {\"max\",\"maximum\",\"largest\",\"greatest\",\"highest\",\"biggest\"}\n",
        "MIN_WORDS = {\"min\",\"minimum\",\"smallest\",\"least\",\"lowest\"}\n",
        "ADD_WORDS = {\n",
        "    \"add\",\"sum\",\"plus\",\"total\",\"tally\",\"summing\",\"sum up\",\"addition\",\"adding\",\n",
        "    \"combine\",\"combined\",\"combining\",\"add together\",\"add up\",\"added to\",\"with\",\n",
        "    \"increase\",\"increasing\",\"increment\"\n",
        "}\n",
        "SUB_WORDS = {\n",
        "    \"subtract\",\"minus\",\"take away\",\"deduct\",\"difference\",\n",
        "    \"decrease\",\"decreasing\",\"less\",\"less than\",\"subtracted from\",\"taken away from\",\n",
        "    \"reduce\",\"reduction\",\"decrement\"\n",
        "}\n",
        "\n",
        "def detect_skill(prompt: str) -> Optional[str]:\n",
        "    p = sanitize_ascii(prompt.lower())\n",
        "    if any(w in p for w in SORT_WORDS): return \"sort\"\n",
        "    if any(w in p for w in MAX_WORDS):  return \"max\"\n",
        "    if any(w in p for w in MIN_WORDS):  return \"min\"\n",
        "    if any(w in p for w in SUB_WORDS):  return \"sub\"\n",
        "    if any(w in p for w in ADD_WORDS):  return \"add\"\n",
        "    return None\n",
        "\n",
        "# ---------------- validators ----------------\n",
        "RE_ADD = re.compile(r\"^\\s*-?\\d+\\s*\\+\\s*-?\\d+\\s*$\")\n",
        "RE_SUB = re.compile(r\"^\\s*-?\\d+\\s*-\\s*-?\\d+\\s*$\")\n",
        "RE_MAX = re.compile(r\"^\\s*max\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "RE_MIN = re.compile(r\"^\\s*min\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "RE_SRT = re.compile(r\"^\\s*sorted\\(\\s*\\[\\s*-?\\d+(?:\\s*,\\s*-?\\d+)*\\s*\\]\\s*\\)\\s*$\")\n",
        "\n",
        "def valid_for_skill(code: str, skill: str) -> bool:\n",
        "    if skill == \"add\":  return bool(RE_ADD.match(code))\n",
        "    if skill == \"sub\":  return bool(RE_SUB.match(code))\n",
        "    if skill == \"max\":  return bool(RE_MAX.match(code))\n",
        "    if skill == \"min\":  return bool(RE_MIN.match(code))\n",
        "    if skill == \"sort\": return bool(RE_SRT.match(code))\n",
        "    return False\n",
        "\n",
        "# ---------------- deterministic fallback (expanded coverage) ----------------\n",
        "def fallback_from_prompt(raw_prompt: str, skill: Optional[str]) -> Optional[str]:\n",
        "    p = normalize_numbers(raw_prompt)\n",
        "    pl = p.lower()\n",
        "\n",
        "    # ---- Addition variants ----\n",
        "    if skill == \"add\" or any(w in pl for w in ADD_WORDS):\n",
        "        # add/sum/total/plus/combination patterns with \"X and Y\" / \"X & Y\"\n",
        "        m = re.search(r\"(?:add|sum(?:\\s+of)?|total(?:\\s+of)?|plus|combine|combining|combined|addition(?:\\s+of)?|adding|sum up)\\s+(?:the\\s+)?(?:numbers?|values?\\s+)?(-?\\d+)\\s+(?:and|&)\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "        # add X to Y\n",
        "        m = re.search(r\"(?:add|adding|plus)\\s+(-?\\d+)\\s+to\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} + {m.group(1)}\"\n",
        "        # X is added to Y\n",
        "        m = re.search(r\"(-?\\d+)\\s+is\\s+added\\s+to\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} + {m.group(1)}\"\n",
        "        # increase/increment X by Y\n",
        "        m = re.search(r\"(?:increase|increasing|increment)\\s+(-?\\d+)\\s+by\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} + {m.group(2)}\"\n",
        "        # if still nothing, pick first two ints in order\n",
        "        ab = extract_two_ints_with_pos(p)\n",
        "        if ab:\n",
        "            (a,_pa),(b,_pb) = ab\n",
        "            return f\"{a} + {b}\"\n",
        "\n",
        "    # ---- Subtraction variants ----\n",
        "    if skill == \"sub\" or any(w in pl for w in SUB_WORDS):\n",
        "        # subtract/deduct/take away X from Y  -> Y - X\n",
        "        m = re.search(r\"(?:subtract|deduct|take away)\\s+(-?\\d+)\\s+from\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "        # X is subtracted from Y -> Y - X\n",
        "        m = re.search(r\"(-?\\d+)\\s+is\\s+subtracted\\s+from\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(2)} - {m.group(1)}\"\n",
        "        # decrease/reduce/decrement X by Y -> X - Y\n",
        "        m = re.search(r\"(?:decrease|decreasing|reduce|reduction|decrement)\\s+(-?\\d+)\\s+by\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # X minus Y / X less Y\n",
        "        m = re.search(r\"(-?\\d+)\\s+(?:minus|less)\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # (nonstandard) subtract X by Y -> X - Y\n",
        "        m = re.search(r\"subtract\\s+(-?\\d+)\\s+by\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # difference between/of X and Y -> X - Y\n",
        "        m = re.search(r\"difference\\s+(?:between|of)\\s+(-?\\d+)\\s+and\\s+(-?\\d+)\", pl)\n",
        "        if m: return f\"{m.group(1)} - {m.group(2)}\"\n",
        "        # heuristic with positions (handle \"... X from Y ...\" reliably)\n",
        "        ab = extract_two_ints_with_pos(p)\n",
        "        if ab:\n",
        "            (a,pa),(b,pb) = ab\n",
        "            idx_from = pl.find(\" from \")\n",
        "            idx_less_than = pl.find(\" less than \")\n",
        "            if idx_from != -1 and pa < idx_from < pb:\n",
        "                return f\"{b} - {a}\"  # \"... X from Y ...\" -> Y - X\n",
        "            if idx_less_than != -1 and pa < idx_less_than < pb:\n",
        "                return f\"{b} - {a}\"  # \"X less than Y\" -> Y - X\n",
        "            return f\"{a} - {b}\"\n",
        "\n",
        "    # ---- List tasks (max/min/sort) ----\n",
        "    lst = list_from_prompt(p)\n",
        "    if lst:\n",
        "        if skill == \"max\" or any(w in pl for w in MAX_WORDS):  return f\"max({lst})\"\n",
        "        if skill == \"min\" or any(w in pl for w in MIN_WORDS):  return f\"min({lst})\"\n",
        "        if skill == \"sort\" or any(w in pl for w in SORT_WORDS): return f\"sorted({lst})\"\n",
        "\n",
        "    return None\n",
        "\n",
        "@torch.no_grad()\n",
        "def emit_code(prompt: str, max_new: int = 64) -> str:\n",
        "    skill = detect_skill(prompt)\n",
        "    enc_text = type_tok.tag_text(prompt) + \" \" + END + \" \"\n",
        "    enc = tok(enc_text, return_tensors=\"pt\").to(DEVICE)\n",
        "    out = mdl.generate(\n",
        "        **enc,\n",
        "        max_new_tokens=max_new,\n",
        "        do_sample=False,\n",
        "        pad_token_id=EOS_ID,\n",
        "        eos_token_id=[x for x in [EOS_ID, END_ID] if x is not None],\n",
        "    )\n",
        "    gen = tok.decode(out[0, enc.input_ids.shape[1]:], skip_special_tokens=False)\n",
        "    line = sanitize_ascii(gen).splitlines()[0].strip().rstrip(\".;, \")\n",
        "    if skill and valid_for_skill(line, skill):\n",
        "        return line\n",
        "\n",
        "    fb = fallback_from_prompt(prompt, skill)\n",
        "    if fb is not None:\n",
        "        return fb\n",
        "\n",
        "    # last tight pass (no spaces)\n",
        "    ls = line.replace(\" \", \"\")\n",
        "    if skill == \"add\" and RE_ADD.match(ls):  return ls\n",
        "    if skill == \"sub\" and RE_SUB.match(ls):  return ls\n",
        "    if skill == \"max\" and RE_MAX.match(ls):  return ls\n",
        "    if skill == \"min\" and RE_MIN.match(ls):  return ls\n",
        "    if skill == \"sort\" and RE_SRT.match(ls): return ls\n",
        "    return \"/* invalid */\""
      ],
      "metadata": {
        "id": "uBMP1huwY4vD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cell B: evaluate on GPT-4o prompts only\n",
        "\n",
        "import json, random, pathlib\n",
        "\n",
        "DATA_DIR = pathlib.Path(\"data_teacher\")\n",
        "rows = []\n",
        "for p in [DATA_DIR/\"train.jsonl\", DATA_DIR/\"valid.jsonl\"]:\n",
        "    if p.exists():\n",
        "        with open(p) as f:\n",
        "            for line in f:\n",
        "                r = json.loads(line)\n",
        "                if r.get(\"source\") == \"gpt-4o\":\n",
        "                    rows.append(r)\n",
        "\n",
        "if not rows:\n",
        "    raise RuntimeError(\"No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.\")\n",
        "\n",
        "by_skill = {}\n",
        "for r in rows:\n",
        "    by_skill.setdefault(r[\"skill\"], []).append(r)\n",
        "\n",
        "sampled, per_skill = [], 12  # adjust if you want more/less\n",
        "for k, arr in by_skill.items():\n",
        "    sampled.extend(random.sample(arr, k=min(per_skill, len(arr))))\n",
        "\n",
        "print(f\"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\\n\")\n",
        "\n",
        "ok = 0\n",
        "for r in sampled:\n",
        "    prompt, gold_py = r[\"prompt\"], r[\"code\"]\n",
        "    pred_py = emit_code(prompt)\n",
        "\n",
        "    try: gold_val = eval(gold_py)\n",
        "    except Exception as e: gold_val = f\"❌gold {e}\"\n",
        "\n",
        "    try: pred_val = eval(pred_py) if pred_py != \"/* invalid */\" else \"❌ invalid\"\n",
        "    except Exception as e: pred_val = f\"❌{e}\"\n",
        "\n",
        "    match = (gold_val == pred_val)\n",
        "    ok += int(match)\n",
        "    print(f\"{prompt:70} → {pred_py:28} → {pred_val}\")\n",
        "    if not match:\n",
        "        print(f\"  gold: {gold_py} → {gold_val}\")\n",
        "    print(\"-\"*70)\n",
        "\n",
        "print(f\"\\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZfNzdO5GY9tz",
        "outputId": "bf08c181-9c0f-4c48-f859-3c8f5a3e9491"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']\n",
            "\n",
            "Combine -92 and -3 by addition.                                        → -92 + -3                     → -95\n",
            "----------------------------------------------------------------------\n",
            "What is the result of adding 78 to -23?                                → -23 + 78                     → 55\n",
            "----------------------------------------------------------------------\n",
            "What is the sum of -53 and -48?                                        → -53 + -48                    → -101\n",
            "----------------------------------------------------------------------\n",
            "Combine -25 and -25 by addition and find the result.                   → -25 + -25                    → -50\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you sum 96 and 43?                             → 96 + 43                      → 139\n",
            "----------------------------------------------------------------------\n",
            "If you combine -52 and -97, what total do you get?                     → -52 + -97                    → -149\n",
            "----------------------------------------------------------------------\n",
            "What is the sum of 76 plus 21?                                         → 76 + 21                      → 97\n",
            "----------------------------------------------------------------------\n",
            "Find the sum of 45 and 16.                                             → 45 + 16                      → 61\n",
            "----------------------------------------------------------------------\n",
            "Find the result of combining -80 with 49 through addition.             → -80 + 49                     → -31\n",
            "----------------------------------------------------------------------\n",
            "Calculate the sum of 18 and 14.                                        → 18 + 14                      → 32\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of adding -89 with 39.                            → -89 + 39                     → -50\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of adding 17 to -82.                              → -82 + 17                     → -65\n",
            "----------------------------------------------------------------------\n",
            "Determine the minimum element in the sequence [27, -50, -31, -34, -24, -3, -22]. → min([27, -50, -31, -34, -24, -3, -22]) → -50\n",
            "----------------------------------------------------------------------\n",
            "What is the lowest value in the array [45, -39, 4, 32, -47, -3, -24]?  → min([45, -39, 4, 32, -47, -3, -24]) → -47\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list: [-18, 1, 2, 5, -14].        → min([-18, 1, 2, 5, -14])     → -18\n",
            "----------------------------------------------------------------------\n",
            "What is the least number in the sequence [-19, 0, -13, -3, -45, -5, 10]? → min([-19, 0, -13, -3, -45, -5, 10]) → -45\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum value in the sequence [-21, 50, -39, 37, 8, 24].  → min([-21, 50, -39, 37, 8, 24]) → -39\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list: [-34, 14, 9, 37, 18, 4, -8]. → min([-34, 14, 9, 37, 18, 4, -8]) → -34\n",
            "----------------------------------------------------------------------\n",
            "Determine the smallest number in the list: [38, 45, 49, -37].          → min([38, 45, 49, -37])       → -37\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum element from the numbers [41, -4, -11, 32, -13, -17, 43, -28]. → min([41, -4, -11, 32, -13, -17, 43, -28]) → -28\n",
            "----------------------------------------------------------------------\n",
            "Determine the minimum element from the sequence [48, -29, 39, -34, -11, -9, 1]. → min([48, -29, 39, -34, -11, -9, 1]) → -34\n",
            "----------------------------------------------------------------------\n",
            "Identify the lowest value within the array [37, 0, -25, -50, -13, 44]. → 37 + 0                       → 37\n",
            "  gold: min([37, 0, -25, -50, -13, 44]) → -50\n",
            "----------------------------------------------------------------------\n",
            "Identify the minimum element from the sequence [44, -38, -46, 24, 28, 33]. → min([44, -38, -46, 24, 28, 33]) → -46\n",
            "----------------------------------------------------------------------\n",
            "Identify the lowest value from the set of numbers: 23, 43, -15, 50, 36. → min([23, 43, -15, 50, 36])   → -15\n",
            "----------------------------------------------------------------------\n",
            "Find the difference when 26 is subtracted from -54.                    → -54 - 26                     → -80\n",
            "----------------------------------------------------------------------\n",
            "Calculate the result of subtracting negative five from forty-two.      → 42 - -5                      → 47\n",
            "----------------------------------------------------------------------\n",
            "Find the result of subtracting 47 from negative 32.                    → 32 - 47                      → -15\n",
            "  gold: -32 - 47 → -79\n",
            "----------------------------------------------------------------------\n",
            "Calculate the difference of 7 and -47.                                 → 7 - -47                      → 54\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you subtract 42 from 94?                       → 94 - 42                      → 52\n",
            "----------------------------------------------------------------------\n",
            "What do you get when you subtract -77 from 22?                         → 22 - -77                     → 99\n",
            "----------------------------------------------------------------------\n",
            "Find the result of subtracting 25 from 17.                             → 17 - 25                      → -8\n",
            "----------------------------------------------------------------------\n",
            "What is the result when 85 is subtracted from -96?                     → -96 - 85                     → -181\n",
            "----------------------------------------------------------------------\n",
            "What is the result of subtracting -6 from -45?                         → -45 - -6                     → -39\n",
            "----------------------------------------------------------------------\n",
            "What is the result when -42 is subtracted from 36?                     → 36 - -42                     → 78\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you subtract a negative ninety from a negative ninety-nine? → -99 - -90                    → -9\n",
            "----------------------------------------------------------------------\n",
            "What is the result when you subtract 95 from 9?                        → 9 - 95                       → -86\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-24, -33, -28, 32, -7, 37, 48, 19] in increasing order. → -24 + -33                    → -57\n",
            "  gold: sorted([-24, -33, -28, 32, -7, 37, 48, 19]) → [-33, -28, -24, -7, 19, 32, 37, 48]\n",
            "----------------------------------------------------------------------\n",
            "Put the elements of [-50, 8, -29, 23] in order from least to greatest. → max([-50, 8, -29, 23])       → 23\n",
            "  gold: sorted([-50, 8, -29, 23]) → [-50, -29, 8, 23]\n",
            "----------------------------------------------------------------------\n",
            "Order the elements of the array [32, -31, -30, -20, -4] from smallest to largest. → max([32, -31, -30, -20, -4]) → 32\n",
            "  gold: sorted([32, -31, -30, -20, -4]) → [-31, -30, -20, -4, 32]\n",
            "----------------------------------------------------------------------\n",
            "Order the elements of the array [-11, 11, 48, 39, -23, 38] from the smallest to the largest. → max([-11, 11, 48, 39, -23, 38]) → 48\n",
            "  gold: sorted([-11, 11, 48, 39, -23, 38]) → [-23, -11, 11, 38, 39, 48]\n",
            "----------------------------------------------------------------------\n",
            "Rearrange the numbers in the list [22, -22, 3, 19, 18, 21, 37, -10] in ascending order. → sorted([22, -22, 3, 19, 18, 21, 37, -10]) → [-22, -10, 3, 18, 19, 21, 22, 37]\n",
            "----------------------------------------------------------------------\n",
            "Put the elements of [1, -6, 38, -35] in order from smallest to largest. → max([1, -6, 38, -35])        → 38\n",
            "  gold: sorted([1, -6, 38, -35]) → [-35, -6, 1, 38]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [18, -47, 33, -48, 20] in ascending order. → sorted([18, -47, 33, -48, 20]) → [-48, -47, 18, 20, 33]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [-8, 44, -27, 10, -13, -20, 13] in ascending order. → sorted([-8, 44, -27, 10, -13, -20, 13]) → [-27, -20, -13, -8, 10, 13, 44]\n",
            "----------------------------------------------------------------------\n",
            "Arrange the numbers in the list [39, 16, -17, 21, -25] in ascending order. → sorted([39, 16, -17, 21, -25]) → [-25, -17, 16, 21, 39]\n",
            "----------------------------------------------------------------------\n",
            "Sort the array [11, 26, 18, -24, -46] in increasing order.             → 11 + 26                      → 37\n",
            "  gold: sorted([11, 26, 18, -24, -46]) → [-46, -24, 11, 18, 26]\n",
            "----------------------------------------------------------------------\n",
            "Order the numbers [45, 32, 3, -21, 35, 42, 33, 49] so that they appear in increasing order. → 45 + 32                      → 77\n",
            "  gold: sorted([45, 32, 3, -21, 35, 42, 33, 49]) → [-21, 3, 32, 33, 35, 42, 45, 49]\n",
            "----------------------------------------------------------------------\n",
            "Reorder the following list of integers in increasing order: [6, -6, -4, 9, 47, 43, 3]. → 6 + -6                       → 0\n",
            "  gold: sorted([6, -6, -4, 9, 47, 43, 3]) → [-6, -4, 3, 6, 9, 43, 47]\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element in the set [28, 0, 31, 14, 44]?           → max([28, 0, 31, 14, 44])     → 44\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest number among these values: [-31, -35, -24, 14, 31, -10, -34]? → max([-31, -35, -24, 14, 31, -10, -34]) → 31\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value within the array [30, -39, 19, -32, 1, 14]. → 30 + -39                     → -9\n",
            "  gold: max([30, -39, 19, -32, 1, 14]) → 30\n",
            "----------------------------------------------------------------------\n",
            "Identify the largest number in the list: [-15, -16, 28, -39, -21, 49]. → max([-15, -16, 28, -39, -21, 49]) → 49\n",
            "----------------------------------------------------------------------\n",
            "Identify the largest value from the list: [-35, 43, 39, -18, -17, 3].  → max([-35, 43, 39, -18, -17, 3]) → 43\n",
            "----------------------------------------------------------------------\n",
            "Determine the largest number in the list: [27, 28, 21, 22, -15, -14].  → max([27, 28, 21, 22, -15, -14]) → 28\n",
            "----------------------------------------------------------------------\n",
            "What is the greatest element found in [30, -39, 19, -32, 1, 14]?       → max([30, -39, 19, -32, 1, 14]) → 30\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value from the following sequence: [-39, 47, -13, 43, 46, -3, 38]. → max([-39, 47, -13, 43, 46, -3, 38]) → 47\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value from the following set of numbers: [-18, 10, -12, -47, -26, 6, 39]. → max([-18, 10, -12, -47, -26, 6, 39]) → 39\n",
            "----------------------------------------------------------------------\n",
            "What is the highest number present in the array [-12, 40, -38, 45, -25, 27, 32, 37]? → max([-12, 40, -38, 45, -25, 27, 32, 37]) → 45\n",
            "----------------------------------------------------------------------\n",
            "Identify the highest value in the array: [-32, -26, -49, -1].          → max([-32, -26, -49, -1])     → -1\n",
            "----------------------------------------------------------------------\n",
            "Identify the largest number in the list: [-37, -40, -38, 12, -31, 49]. → max([-37, -40, -38, 12, -31, 49]) → 49\n",
            "----------------------------------------------------------------------\n",
            "\n",
            "Accuracy vs GPT-4o gold: 49/60 = 0.817\n"
          ]
        }
      ]
    }
  ]
}