{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "09c55664cbe14403a25a67f54f0eb6b6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_152991c9e59a4b9fa7ce74825990f088",
              "IPY_MODEL_a52276734fa14efebd79d70b73d15663",
              "IPY_MODEL_3aa4286cf0554619ab7f2010c8c18a46"
            ],
            "layout": "IPY_MODEL_2af3f29b14744fdfbbe64ed9acb4c19c"
          }
        },
        "152991c9e59a4b9fa7ce74825990f088": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c75c09818ed54b0bba86c0ab81392b69",
            "placeholder": "​",
            "style": "IPY_MODEL_ded114a3001e4921945bf11b34aaac18",
            "value": "100%"
          }
        },
        "a52276734fa14efebd79d70b73d15663": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_dae4385fddef40c5aeec6df44a716d89",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d53b6085c30345558a28f32a12381c8f",
            "value": 1
          }
        },
        "3aa4286cf0554619ab7f2010c8c18a46": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_767d92c109664748b652ba628159dc5e",
            "placeholder": "​",
            "style": "IPY_MODEL_e5643fb918064e0cb30292d96ce10b55",
            "value": " 1/1 [00:00&lt;00:00,  6.95it/s]"
          }
        },
        "2af3f29b14744fdfbbe64ed9acb4c19c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c75c09818ed54b0bba86c0ab81392b69": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ded114a3001e4921945bf11b34aaac18": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "dae4385fddef40c5aeec6df44a716d89": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d53b6085c30345558a28f32a12381c8f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "767d92c109664748b652ba628159dc5e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e5643fb918064e0cb30292d96ce10b55": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
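        "# Empirically Estimating the Augmentation Complexity\n",
        "\n",
        "This notebook demonstrates how to empirically estimate the complexity of different masking augmentations on the Simple English Wikipedia dataset (`wikipedia`, config `20220301.simple`), using a pretrained `bert-base-uncased` masked language model.\n",
        "For each sampled text $x$ it computes $s(x) = \\log \\frac{1}{N}\\sum_{i=1}^{N} \\frac{p(x \\mid a_i)}{p(x)}$ over $N$ random augmentations $a_i$ of $x$, where $p(x)$ is scored from the fully masked text, and finally reports the 99th percentile of $s(x)$ across samples.\n",
        "\n",
        "Running this notebook requires a CUDA-enabled PyTorch installation.\n",
        "The easiest way is to run this notebook on Google Colab with a GPU runtime."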
      ],
      "metadata": {
        "id": "cpNKTZ-q7vfi"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IPHESKiQ7qQk"
      },
      "outputs": [],
      "source": [
        "# arguments\n",
        "alg = 'block'          # Masking method: rand, block, flip or blockflip\n",
        "alpha = 0.15           # Mask ratio\n",
        "num_samples = 1000     # Number of test samples\n",
        "num_augs = 255         # Number of augmentations for each sample\n",
        "max_length = 64        # Max length of each sample\n",
        "batch_size = 128       # Batch size\n",
        "seed = 0               # Random seed"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel\n",
        "%pip install apache_beam\n",
        "%pip install transformers datasets"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "V2qUUABr_Pnf",
        "outputId": "3321546a-4458-446b-ee6a-3ab65ccf48df"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: apache_beam in ./anaconda3/envs/ml/lib/python3.9/site-packages (2.47.0)\n",
            "Requirement already satisfied: dill<0.3.2,>=0.3.1.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (0.3.1.1)\n",
            "Requirement already satisfied: pyarrow<12.0.0,>=3.0.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (9.0.0)\n",
            "Requirement already satisfied: grpcio!=1.48.0,<2,>=1.33.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.51.3)\n",
            "Requirement already satisfied: crcmod<2.0,>=1.7 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.7)\n",
            "Requirement already satisfied: requests<3.0.0,>=2.24.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2.27.1)\n",
            "Requirement already satisfied: python-dateutil<3,>=2.8.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2.8.2)\n",
            "Requirement already satisfied: regex>=2020.6.8 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2022.3.15)\n",
            "Requirement already satisfied: httplib2<0.22.0,>=0.8 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (0.20.4)\n",
            "Requirement already satisfied: fastavro<2,>=0.23.6 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.7.1)\n",
            "Requirement already satisfied: pydot<2,>=1.2.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.4.2)\n",
            "Requirement already satisfied: fasteners<1.0,>=0.3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (0.18)\n",
            "Requirement already satisfied: orjson<4.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (3.8.6)\n",
            "Requirement already satisfied: typing-extensions>=3.7.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (4.5.0)\n",
            "Requirement already satisfied: hdfs<3.0.0,>=2.1.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2.7.0)\n",
            "Requirement already satisfied: protobuf<4.23.0,>=3.20.3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (4.22.0)\n",
            "Requirement already satisfied: numpy<1.25.0,>=1.14.3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.21.5)\n",
            "Requirement already satisfied: proto-plus<2,>=1.7.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (1.22.2)\n",
            "Requirement already satisfied: zstandard<1,>=0.18.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (0.20.0)\n",
            "Requirement already satisfied: cloudpickle~=2.2.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2.2.1)\n",
            "Requirement already satisfied: objsize<0.7.0,>=0.6.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (0.6.1)\n",
            "Requirement already satisfied: pymongo<5.0.0,>=3.8.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (3.13.0)\n",
            "Requirement already satisfied: pytz>=2018.3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from apache_beam) (2021.3)\n",
            "Requirement already satisfied: docopt in ./anaconda3/envs/ml/lib/python3.9/site-packages (from hdfs<3.0.0,>=2.1.0->apache_beam) (0.6.2)\n",
            "Requirement already satisfied: six>=1.9.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from hdfs<3.0.0,>=2.1.0->apache_beam) (1.16.0)\n",
            "Requirement already satisfied: pyparsing!=3.0.0,!=3.0.1,!=3.0.2,!=3.0.3,<4,>=2.4.2 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from httplib2<0.22.0,>=0.8->apache_beam) (3.0.4)\n",
            "Requirement already satisfied: charset-normalizer~=2.0.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests<3.0.0,>=2.24.0->apache_beam) (2.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests<3.0.0,>=2.24.0->apache_beam) (2022.12.7)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests<3.0.0,>=2.24.0->apache_beam) (1.26.9)\n",
            "Requirement already satisfied: idna<4,>=2.5 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests<3.0.0,>=2.24.0->apache_beam) (3.3)\n",
            "Requirement already satisfied: transformers in ./anaconda3/envs/ml/lib/python3.9/site-packages (4.29.0.dev0)\n",
            "Requirement already satisfied: datasets in ./anaconda3/envs/ml/lib/python3.9/site-packages (2.5.2)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (0.12.1)\n",
            "Requirement already satisfied: requests in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (2.27.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (2022.3.15)\n",
            "Requirement already satisfied: pyyaml>=5.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (6.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (4.64.0)\n",
            "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (0.13.2)\n",
            "Requirement already satisfied: filelock in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (3.6.0)\n",
            "Requirement already satisfied: packaging>=20.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (21.3)\n",
            "Requirement already satisfied: numpy>=1.17 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from transformers) (1.21.5)\n",
            "Requirement already satisfied: fsspec[http]>=2021.11.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (2022.2.0)\n",
            "Requirement already satisfied: multiprocess in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (0.70.13)\n",
            "Requirement already satisfied: xxhash in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (3.0.0)\n",
            "Requirement already satisfied: pyarrow>=6.0.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (9.0.0)\n",
            "Requirement already satisfied: aiohttp in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (3.8.1)\n",
            "Requirement already satisfied: responses<0.19 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (0.18.0)\n",
            "Requirement already satisfied: dill<0.3.6 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (0.3.1.1)\n",
            "Requirement already satisfied: pandas in ./anaconda3/envs/ml/lib/python3.9/site-packages (from datasets) (1.4.2)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n",
            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from packaging>=20.0->transformers) (3.0.4)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests->transformers) (1.26.9)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests->transformers) (2022.12.7)\n",
            "Requirement already satisfied: charset-normalizer~=2.0.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests->transformers) (2.0.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from requests->transformers) (3.3)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (1.2.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (1.2.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (1.6.3)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (5.2.0)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (4.0.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from aiohttp->datasets) (21.4.0)\n",
            "Collecting dill<0.3.6\n",
            "  Using cached dill-0.3.5.1-py2.py3-none-any.whl (95 kB)\n",
            "Requirement already satisfied: python-dateutil>=2.8.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from pandas->datasets) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from pandas->datasets) (2021.3)\n",
            "Requirement already satisfied: six>=1.5 in ./anaconda3/envs/ml/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
            "Installing collected packages: dill\n",
            "  Attempting uninstall: dill\n",
            "    Found existing installation: dill 0.3.1.1\n",
            "    Uninstalling dill-0.3.1.1:\n",
            "      Successfully uninstalled dill-0.3.1.1\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "apache-beam 2.47.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.5.1 which is incompatible.\u001b[0m\n",
            "Successfully installed dill-0.3.5.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import os\n",
        "# Force the default Hugging Face Hub endpoint; set before datasets/transformers are imported below\n",
        "os.environ[\"HF_ENDPOINT\"] = \"https://huggingface.co\"\n",
        "from datasets import load_dataset\n",
        "from tqdm import tqdm\n",
        "from transformers import AutoTokenizer, BertForMaskedLM\n",
        "\n",
        "# Seed both RNGs: numpy chooses the samples and block positions, while torch draws the\n",
        "# random positions (randperm) and random replacement tokens (randint_like) of the augmentations.\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
        "model = BertForMaskedLM.from_pretrained(\"bert-base-uncased\")\n",
        "model = model.to('cuda')\n",
        "model.eval()  # inference only; make sure dropout is disabled\n",
        "dataset = load_dataset(\"wikipedia\", \"20220301.simple\")"
      ],
      "metadata": {
        "id": "PTHiMfHT8Vrd",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118,
          "referenced_widgets": [
            "09c55664cbe14403a25a67f54f0eb6b6",
            "152991c9e59a4b9fa7ce74825990f088",
            "a52276734fa14efebd79d70b73d15663",
            "3aa4286cf0554619ab7f2010c8c18a46",
            "2af3f29b14744fdfbbe64ed9acb4c19c",
            "c75c09818ed54b0bba86c0ab81392b69",
            "ded114a3001e4921945bf11b34aaac18",
            "dae4385fddef40c5aeec6df44a716d89",
            "d53b6085c30345558a28f32a12381c8f",
            "767d92c109664748b652ba628159dc5e",
            "e5643fb918064e0cb30292d96ce10b55"
          ]
        },
        "outputId": "bc8cd689-2b72-48a6-f3bb-1950df2125d9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
            "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Found cached dataset wikipedia (/usr1/rzhai/huggingface/datasets/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "09c55664cbe14403a25a67f54f0eb6b6"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "sample_ids = np.random.choice(len(dataset['train']), num_samples, replace=False)\n",
        "mask_token = tokenizer.mask_token_id   # [MASK] id (103 for bert-base-uncased)\n",
        "vocab_size = model.config.vocab_size   # 30522 for bert-base-uncased; upper bound for random replacement tokens\n",
        "s_arr = np.zeros(num_samples)          # s_arr[i] is the estimate for sample_ids[i]\n",
        "\n",
        "for i in tqdm(range(num_samples)):\n",
        "    tokens = tokenizer(dataset['train'][int(sample_ids[i])]['text'], return_tensors=\"pt\",\n",
        "                       padding=False, max_length=max_length, truncation=True)['input_ids'][0]\n",
        "    l_text = len(tokens)       # tokens[0] and tokens[-1] are special tokens and are never augmented\n",
        "    r = int(alpha * l_text)    # number of augmented positions\n",
        "    r1 = r // 2                # flip / blockflip: positions replaced by [MASK]\n",
        "    r2 = r - r1                # flip / blockflip: positions replaced by a random token\n",
        "\n",
        "    # Slot 0 is reserved for the fully masked text used to compute p(x);\n",
        "    # slots 1..n_augs-1 hold the actual augmentations.\n",
        "    n_augs = num_augs + 1\n",
        "    r0 = r if alg == 'block' else r1    # length of the contiguous masked block\n",
        "    if alg == 'block':\n",
        "        # Valid block starts are 1..l_text-1-r0 (l_text-1-r0 choices). Sampling n_augs of\n",
        "        # them without replacement needs n_augs <= l_text-1-r0, hence `>=` below.\n",
        "        if n_augs >= l_text - r0:\n",
        "            # Text is too short: enumerate every possible block instead.\n",
        "            # block_ids[0] == 0 only fills slot 0, which is overwritten by the full mask.\n",
        "            n_augs = l_text - r0\n",
        "            block_ids = np.arange(n_augs)\n",
        "        else:\n",
        "            block_ids = np.random.choice(l_text - 1 - r0, n_augs, replace=False) + 1\n",
        "    elif alg == 'blockflip':\n",
        "        block_ids = np.random.choice(l_text - 1 - r0, n_augs, replace=True) + 1\n",
        "\n",
        "    log_probs_arr = torch.zeros(n_augs, device='cuda')\n",
        "\n",
        "    for lid in range(0, n_augs, batch_size):\n",
        "        rid = min(lid + batch_size, n_augs)\n",
        "        bsz = rid - lid\n",
        "\n",
        "        # g_tokens - original text, one row per augmentation in the batch\n",
        "        g_tokens = tokens.repeat(bsz, 1).to('cuda')\n",
        "\n",
        "        # r_tokens - augmented texts, kept flat while editing: row k, position p\n",
        "        # lives at flat index row_offsets[k] + p\n",
        "        r_tokens = tokens.repeat(bsz)\n",
        "        row_offsets = torch.arange(bsz).view(-1, 1) * l_text\n",
        "        if alg == 'rand':\n",
        "            # Mask r random non-special positions\n",
        "            r_ids = torch.vstack([torch.randperm(l_text - 2) for _ in range(bsz)]) + 1\n",
        "            r_ids = r_ids[:, :r] + row_offsets\n",
        "            r_tokens[r_ids.flatten()] = mask_token\n",
        "        elif alg == 'block':\n",
        "            # Mask r consecutive positions starting at each block id\n",
        "            b_ids = block_ids[lid:rid]\n",
        "            r_ids = torch.vstack([torch.arange(r) + b for b in b_ids]) + row_offsets\n",
        "            r_tokens[r_ids.flatten()] = mask_token\n",
        "        elif alg == 'flip':\n",
        "            # Pick r random non-special positions: mask the first r1, randomize the other r2\n",
        "            r_ids = torch.vstack([torch.randperm(l_text - 2) for _ in range(bsz)]) + 1\n",
        "            r_ids = r_ids[:, :r] + row_offsets\n",
        "            r1_ids = r_ids[:, :r1].flatten()\n",
        "            r2_ids = r_ids[:, r1:].flatten()\n",
        "            r_tokens[r1_ids] = mask_token\n",
        "            r_tokens[r2_ids] = torch.randint_like(r2_ids, vocab_size)\n",
        "        elif alg == 'blockflip':\n",
        "            # Mask r1 consecutive positions starting at each block id, and randomize r2\n",
        "            # non-special positions outside that block\n",
        "            b_ids = block_ids[lid:rid]\n",
        "            r2_ids = torch.vstack([torch.randperm(l_text - 2 - r1) for _ in range(bsz)]) + 1\n",
        "            r2_ids = r2_ids[:, :r2]\n",
        "            # Map 1..l_text-2-r1 onto the non-special positions outside [b, b + r1): indices at\n",
        "            # or after the block start are shifted past the block (`>` would let them hit b)\n",
        "            b_col = torch.tensor(b_ids.reshape(bsz, 1))\n",
        "            r2_ids[r2_ids >= b_col] += r1\n",
        "            r2_ids = (r2_ids + row_offsets).flatten()\n",
        "            r_tokens[r2_ids] = torch.randint_like(r2_ids, vocab_size)\n",
        "            r1_ids = torch.vstack([torch.arange(r1) + b for b in b_ids]) + row_offsets\n",
        "            r_tokens[r1_ids.flatten()] = mask_token\n",
        "\n",
        "        r_tokens = r_tokens.view(bsz, -1).to('cuda')\n",
        "\n",
        "        if lid == 0:\n",
        "            # The first augmented text is all masked, for computing p(x)\n",
        "            r_tokens[0, 1:-1] = mask_token\n",
        "\n",
        "        # Compute log p(x|a), scoring positions right to left: before scoring position k,\n",
        "        # token k of the input is replaced by the augmented token, so the model sees the\n",
        "        # original tokens before k and the augmented tokens from k onwards, and we add the\n",
        "        # log-probability it assigns to the original token at position k.\n",
        "        with torch.no_grad():\n",
        "            for k in range(l_text - 2, 0, -1):\n",
        "                label = int(tokens[k])\n",
        "                g_tokens[:, k] = r_tokens[:, k]\n",
        "                logits = model(input_ids=g_tokens).logits[:, k]\n",
        "                log_probs_arr[lid:rid] += logits[:, label] - torch.logsumexp(logits, dim=1)\n",
        "\n",
        "    # log_probs_arr[0] = log p(x); s = log of the mean over augmentations of p(x|a) / p(x).\n",
        "    # torch.logsumexp is numerically stable on its own, so no manual max-shift is needed.\n",
        "    log_ratios = log_probs_arr[1:] - log_probs_arr[0]\n",
        "    s = torch.logsumexp(log_ratios, dim=0).item() - np.log(n_augs - 1)\n",
        "    s_arr[i] = s\n",
        "\n"
      ],
      "metadata": {
        "id": "dpLPGCRr8hQu",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "52d37236-2156-4ab2-e7cf-d201f728f76f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [44:02<00:00,  2.64s/it]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The results are now in s_arr. Print out the 99-th percentile.\n",
        "# Sort a copy so that s_arr[i] keeps lining up with sample_ids[i], and use a dedicated\n",
        "# index name instead of reusing `r` (the per-sample mask count of the estimation loop).\n",
        "s_sorted = np.sort(s_arr)\n",
        "pct_idx = int(0.99 * num_samples) - 1\n",
        "print('99-th percentile = {}'.format(s_sorted[pct_idx]))"
      ],
      "metadata": {
        "id": "p6d8tlNRBim2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5837d0c1-14aa-4605-9bdf-1046218fc3b3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "99-th percentile = 330.97245924871606\n"
          ]
        }
      ]
    }
  ]
}