{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "LongBART",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm",
      "authorship_tag": "ABX9TyMBu/tl3uAemtoSjaCYca2U",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/patil-suraj/longbart/blob/master/LongBART.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PkmDekoURprl",
        "colab_type": "text"
      },
      "source": [
        "# LongBART"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FNfNMSjTK61d",
        "colab_type": "code",
        "outputId": "f2278a45-26ee-48dd-ab62-4a34f8c56662",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "!git clone https://github.com/patil-suraj/longbart.git\n",
        "%cd longbart"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/longbart\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yOSSddShCaz1",
        "colab_type": "code",
        "outputId": "8fc06a89-e421-4f67-df2d-004400b60fbc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 658
        }
      },
      "source": [
        "!pip install git+https://github.com/huggingface/transformers.git"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting git+https://github.com/huggingface/transformers.git\n",
            "  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-69qg24g6\n",
            "  Running command git clone -q https://github.com/huggingface/transformers.git /tmp/pip-req-build-69qg24g6\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (1.18.4)\n",
            "Collecting tokenizers==0.7.0\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)\n",
            "\u001b[K     |████████████████████████████████| 3.8MB 3.4MB/s \n",
            "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (20.4)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (3.0.12)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2.23.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (4.41.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (2019.12.20)\n",
            "Collecting sentencepiece\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
            "\u001b[K     |████████████████████████████████| 1.1MB 59.3MB/s \n",
            "\u001b[?25hCollecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
            "\u001b[K     |████████████████████████████████| 890kB 60.2MB/s \n",
            "\u001b[?25hRequirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from transformers==2.10.0) (0.7)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (1.12.0)\n",
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==2.10.0) (2.4.7)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2020.4.5.1)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.10.0) (2.9)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (7.1.2)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.10.0) (0.15.1)\n",
            "Building wheels for collected packages: transformers, sacremoses\n",
            "  Building wheel for transformers (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for transformers: filename=transformers-2.10.0-cp36-none-any.whl size=667026 sha256=1fbfcea1f14b529238dfb962701daa0b4df4df60b6927f0793ca24f52b161af8\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-gv3xom6x/wheels/33/eb/3b/4bf5dd835e865e472d4fc0754f35ac0edb08fe852e8f21655f\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=87f114a7eda7007f2e152396969d392ab29f68a5812c6a57e3538111bdf7c32e\n",
            "  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
            "Successfully built transformers sacremoses\n",
            "Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers\n",
            "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.10.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3-EQ0F6qCm_d",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import logging\n",
        "import os\n",
        "import math\n",
        "from dataclasses import dataclass, field\n",
        "from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TextDataset, DataCollatorForLanguageModeling, Trainer\n",
        "from transformers import BartTokenizer\n",
        "from transformers import TrainingArguments, HfArgumentParser\n",
        "from transformers.modeling_longformer import LongformerSelfAttention\n",
        "\n",
        "from modeling_bart import BartForConditionalGeneration\n",
        "\n",
        "logger = logging.getLogger(__name__)\n",
        "logging.basicConfig(level=logging.INFO)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oO4RNqIODK9z",
        "colab_type": "code",
        "outputId": "bc3a13ca-03ed-45e9-f593-490fe0b36d61",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# lets use a tiny version of bart for initial experiment \n",
        "tokenizer = BartTokenizer.from_pretrained('sshleifer/bart-tiny-random')\n",
        "bart = BartForConditionalGeneration.from_pretrained('sshleifer/bart-tiny-random')\n",
        "\n",
        "# load ROBERta model to see the difference between bart encoder layer and roberta encoder layer \n",
        "roberta = RobertaForMaskedLM.from_pretrained('roberta-base')"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.tokenization_utils:Model name 'sshleifer/bart-tiny-random' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'sshleifer/bart-tiny-random' is a path, a model identifier, or url to a directory containing tokenizer files.\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/vocab.json from cache at /root/.cache/torch/transformers/70b9426bcc7c2cd96de53c16f7e13eabbc8373cecf5c38d68ced2fcc25e3382a.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/merges.txt from cache at /root/.cache/torch/transformers/dc37af6307b1a17037d2d066cb55af9cc1cf55d38d3b1f862221fc8d87b9a672.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/added_tokens.json from cache at None\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/special_tokens_map.json from cache at None\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/tokenizer_config.json from cache at None\n",
            "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n",
            "INFO:transformers.configuration_utils:Model config BartConfig {\n",
            "  \"_num_labels\": 3,\n",
            "  \"activation_dropout\": 0.0,\n",
            "  \"activation_function\": \"gelu\",\n",
            "  \"add_bias_logits\": false,\n",
            "  \"add_final_layer_norm\": false,\n",
            "  \"architectures\": [\n",
            "    \"BartForConditionalGeneration\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classif_dropout\": 0.0,\n",
            "  \"d_model\": 24,\n",
            "  \"decoder_attention_heads\": 2,\n",
            "  \"decoder_ffn_dim\": 16,\n",
            "  \"decoder_layerdrop\": 0.0,\n",
            "  \"decoder_layers\": 2,\n",
            "  \"decoder_max_position_embeddings\": 1024,\n",
            "  \"decoder_start_token_id\": 2,\n",
            "  \"dropout\": 0.1,\n",
            "  \"encoder_attention_heads\": 2,\n",
            "  \"encoder_ffn_dim\": 16,\n",
            "  \"encoder_layerdrop\": 0.0,\n",
            "  \"encoder_layers\": 2,\n",
            "  \"encoder_max_position_embeddings\": 1024,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\"\n",
            "  },\n",
            "  \"init_std\": 0.02,\n",
            "  \"is_encoder_decoder\": true,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2\n",
            "  },\n",
            "  \"max_position_embeddings\": 1024,\n",
            "  \"model_type\": \"bart\",\n",
            "  \"normalize_before\": false,\n",
            "  \"normalize_embedding\": true,\n",
            "  \"num_hidden_layers\": 2,\n",
            "  \"output_past\": true,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"prefix\": \" \",\n",
            "  \"scale_embedding\": false,\n",
            "  \"static_position_embeddings\": false,\n",
            "  \"task_specific_params\": {\n",
            "    \"summarization\": {\n",
            "      \"early_stopping\": true,\n",
            "      \"length_penalty\": 2.0,\n",
            "      \"max_length\": 142,\n",
            "      \"min_length\": 56,\n",
            "      \"no_repeat_ngram_size\": 3,\n",
            "      \"num_beams\": 4\n",
            "    }\n",
            "  },\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at /root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n",
            "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
            "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /root/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690\n",
            "INFO:transformers.configuration_utils:Model config RobertaConfig {\n",
            "  \"architectures\": [\n",
            "    \"RobertaForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"layer_norm_eps\": 1e-05,\n",
            "  \"max_position_embeddings\": 514,\n",
            "  \"model_type\": \"roberta\",\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"type_vocab_size\": 1,\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/roberta-base-pytorch_model.bin from cache at /root/.cache/torch/transformers/80b4a484eddeb259bec2f06a6f2f05d90934111628e0e1c09a33bd4a121358e1.49b88ba7ec2c26a7558dda98ca3884c3b80fa31cf43a1b1f23aef3ff81ba344e\n",
            "INFO:transformers.modeling_utils:Weights of RobertaForMaskedLM not initialized from pretrained model: ['lm_head.decoder.bias']\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kZcws3kaMFrp",
        "colab_type": "code",
        "outputId": "1ef71a3a-5f3c-4ed6-d033-a4b55a259325",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 370
        }
      },
      "source": [
        "roberta.config"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "RobertaConfig {\n",
              "  \"architectures\": [\n",
              "    \"RobertaForMaskedLM\"\n",
              "  ],\n",
              "  \"attention_probs_dropout_prob\": 0.1,\n",
              "  \"bos_token_id\": 0,\n",
              "  \"eos_token_id\": 2,\n",
              "  \"hidden_act\": \"gelu\",\n",
              "  \"hidden_dropout_prob\": 0.1,\n",
              "  \"hidden_size\": 768,\n",
              "  \"initializer_range\": 0.02,\n",
              "  \"intermediate_size\": 3072,\n",
              "  \"layer_norm_eps\": 1e-05,\n",
              "  \"max_position_embeddings\": 514,\n",
              "  \"model_type\": \"roberta\",\n",
              "  \"num_attention_heads\": 12,\n",
              "  \"num_hidden_layers\": 12,\n",
              "  \"pad_token_id\": 1,\n",
              "  \"type_vocab_size\": 1,\n",
              "  \"vocab_size\": 50265\n",
              "}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2ucGw5qrEtp8",
        "colab_type": "code",
        "outputId": "d844579b-a54a-4799-e783-4945930d6eb7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "bart.config"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "BartConfig {\n",
              "  \"_num_labels\": 3,\n",
              "  \"activation_dropout\": 0.0,\n",
              "  \"activation_function\": \"gelu\",\n",
              "  \"add_bias_logits\": false,\n",
              "  \"add_final_layer_norm\": false,\n",
              "  \"architectures\": [\n",
              "    \"BartForConditionalGeneration\"\n",
              "  ],\n",
              "  \"attention_dropout\": 0.0,\n",
              "  \"bos_token_id\": 0,\n",
              "  \"classif_dropout\": 0.0,\n",
              "  \"d_model\": 24,\n",
              "  \"decoder_attention_heads\": 2,\n",
              "  \"decoder_ffn_dim\": 16,\n",
              "  \"decoder_layerdrop\": 0.0,\n",
              "  \"decoder_layers\": 2,\n",
              "  \"decoder_max_position_embeddings\": 1024,\n",
              "  \"decoder_start_token_id\": 2,\n",
              "  \"dropout\": 0.1,\n",
              "  \"encoder_attention_heads\": 2,\n",
              "  \"encoder_ffn_dim\": 16,\n",
              "  \"encoder_layerdrop\": 0.0,\n",
              "  \"encoder_layers\": 2,\n",
              "  \"encoder_max_position_embeddings\": 1024,\n",
              "  \"eos_token_id\": 2,\n",
              "  \"id2label\": {\n",
              "    \"0\": \"LABEL_0\",\n",
              "    \"1\": \"LABEL_1\",\n",
              "    \"2\": \"LABEL_2\"\n",
              "  },\n",
              "  \"init_std\": 0.02,\n",
              "  \"is_encoder_decoder\": true,\n",
              "  \"label2id\": {\n",
              "    \"LABEL_0\": 0,\n",
              "    \"LABEL_1\": 1,\n",
              "    \"LABEL_2\": 2\n",
              "  },\n",
              "  \"max_position_embeddings\": 1024,\n",
              "  \"model_type\": \"bart\",\n",
              "  \"normalize_before\": false,\n",
              "  \"normalize_embedding\": true,\n",
              "  \"num_hidden_layers\": 2,\n",
              "  \"output_past\": true,\n",
              "  \"pad_token_id\": 1,\n",
              "  \"prefix\": \" \",\n",
              "  \"scale_embedding\": false,\n",
              "  \"static_position_embeddings\": false,\n",
              "  \"task_specific_params\": {\n",
              "    \"summarization\": {\n",
              "      \"early_stopping\": true,\n",
              "      \"length_penalty\": 2.0,\n",
              "      \"max_length\": 142,\n",
              "      \"min_length\": 56,\n",
              "      \"no_repeat_ngram_size\": 3,\n",
              "      \"num_beams\": 4\n",
              "    }\n",
              "  },\n",
              "  \"vocab_size\": 50265\n",
              "}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RNA4Z21mEvGt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "bart_layer = bart.model.encoder.layers[0]\n",
        "roberta_layer = roberta.roberta.encoder.layer[0]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2KAOBHPdFGn_",
        "colab_type": "code",
        "outputId": "9121f646-3762-4016-b62d-62523686f8af",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 403
        }
      },
      "source": [
        "roberta_layer"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "BertLayer(\n",
              "  (attention): BertAttention(\n",
              "    (self): BertSelfAttention(\n",
              "      (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "    (output): BertSelfOutput(\n",
              "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "  )\n",
              "  (intermediate): BertIntermediate(\n",
              "    (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "  )\n",
              "  (output): BertOutput(\n",
              "    (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "    (dropout): Dropout(p=0.1, inplace=False)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "okn8MAVcFDS9",
        "colab_type": "code",
        "outputId": "2f6a72a3-d7e3-46d6-c81f-a1fc1af85258",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        }
      },
      "source": [
        "bart_layer"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "EncoderLayer(\n",
              "  (self_attn): SelfAttention(\n",
              "    (k_proj): Linear(in_features=24, out_features=24, bias=True)\n",
              "    (v_proj): Linear(in_features=24, out_features=24, bias=True)\n",
              "    (q_proj): Linear(in_features=24, out_features=24, bias=True)\n",
              "    (out_proj): Linear(in_features=24, out_features=24, bias=True)\n",
              "  )\n",
              "  (self_attn_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n",
              "  (fc1): Linear(in_features=24, out_features=16, bias=True)\n",
              "  (fc2): Linear(in_features=16, out_features=24, bias=True)\n",
              "  (final_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pURj--xZL7Ef",
        "colab_type": "text"
      },
      "source": [
        "BART calculates the output projection in the attention layer itself, also the `forward` paramter names of `SelfAttention` layer used in BART are different than that of `BertSelfAttention`. So we'll need to wrap `LongformerSelfAttention` to use it for BART"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jfPgsJ8YQR4A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import math\n",
        "from typing import Dict, List, Optional, Tuple\n",
        "\n",
        "import torch\n",
        "from torch import Tensor, nn\n",
        "\n",
        "class LongformerSelfAttentionForBart(nn.Module):\n",
        "  def __init__(self, config, layer_id):\n",
        "    super().__init__()\n",
        "    self.embed_dim = config.d_model\n",
        "    self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)\n",
        "    self.output = nn.Linear(self.embed_dim, self.embed_dim)\n",
        "  \n",
        "  def forward(\n",
        "    self,\n",
        "    query,\n",
        "    key: Optional[Tensor],\n",
        "    key_padding_mask: Optional[Tensor] = None,\n",
        "    layer_state: Optional[Dict[str, Optional[Tensor]]] = None,\n",
        "    attn_mask: Optional[Tensor] = None,\n",
        "    need_weights=False,\n",
        "  ) -> Tuple[Tensor, Optional[Tensor]]:\n",
        "    \n",
        "    tgt_len, bsz, embed_dim = query.size()\n",
        "    assert embed_dim == self.embed_dim\n",
        "    assert list(query.size()) == [tgt_len, bsz, embed_dim]\n",
        "\n",
        "    # LongformerSelfAttention expects this shape\n",
        "    query = query.view(bsz, tgt_len, embed_dim)\n",
        "\n",
        "    outputs = self.longformer_self_attn(\n",
        "        query,\n",
        "        attention_mask=attn_mask,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "    )\n",
        "\n",
        "    attn_output = outputs[0] \n",
        "    attn_output = attn_output.contiguous().view(tgt_len, bsz, embed_dim)\n",
        "    attn_output = self.output(attn_output)\n",
        "\n",
        "    return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6VIx_TmOELqF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LongBartForConditionalGeneration(BartForConditionalGeneration):\n",
        "  def __init__(self, config):\n",
        "    super().__init__(config)\n",
        "    for i, layer in enumerate(self.model.encoder.layers):\n",
        "      # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n",
        "      layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lx28-eLNEou5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def create_long_model(save_model_to, base_model='bart-large', attention_window=512, max_pos=4096):\n",
        "    model = BartForConditionalGeneration.from_pretrained(base_model)\n",
        "    tokenizer = BartTokenizer.from_pretrained('bart-large', model_max_length=max_pos)\n",
        "    config = model.config\n",
        "\n",
        "    # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention\n",
        "    # expects attention_probs_dropout_prob, so set it here  \n",
        "    config.attention_probs_dropout_prob = config.attention_dropout\n",
        "\n",
        "    # extend position embeddings\n",
        "    tokenizer.model_max_length = max_pos\n",
        "    tokenizer.init_kwargs['model_max_length'] = max_pos\n",
        "    current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape\n",
        "    # config.max_position_embeddings = max_pos\n",
        "    config.encoder_max_position_embeddings = max_pos\n",
        "    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2\n",
        "    assert max_pos > current_max_pos\n",
        "    # allocate a larger position embedding matrix\n",
        "    new_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)\n",
        "    # copy position embeddings over and over to initialize the new position embeddings\n",
        "    k = 2\n",
        "    step = current_max_pos - 2\n",
        "    while k < max_pos - 1:\n",
        "        new_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]\n",
        "        k += step\n",
        "    model.model.encoder.embed_positions.weight.data = new_pos_embed\n",
        "\n",
        "    # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`\n",
        "    config.attention_window = [attention_window] * config.num_hidden_layers\n",
        "    for i, layer in enumerate(model.model.encoder.layers):\n",
        "        longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)\n",
        "        \n",
        "        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj\n",
        "        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj\n",
        "        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj\n",
        "\n",
        "        longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj\n",
        "        longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj\n",
        "        longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj\n",
        "\n",
        "        longformer_self_attn_for_bart.output = layer.self_attn.out_proj\n",
        "\n",
        "        layer.self_attn = longformer_self_attn_for_bart\n",
        "\n",
        "    logger.info(f'saving model to {save_model_to}')\n",
        "    model.save_pretrained(save_model_to)\n",
        "    tokenizer.save_pretrained(save_model_to)\n",
        "    return model, tokenizer"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XuYkE_kGO-U-",
        "colab_type": "code",
        "outputId": "ee5176ee-677c-406b-feb2-1d1aa5ea2b23",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# model_path = f'{training_args.output_dir}/roberta-base-{model_args.max_pos}'\n",
        "base_model = \"sshleifer/bart-tiny-random\"\n",
        "model_path = \"bart-tiny-random-4096\"\n",
        "attention_window = 512\n",
        "max_pos = 4096\n",
        "\n",
        "if not os.path.exists(model_path):\n",
        "    os.makedirs(model_path)\n",
        "\n",
        "# logger.info(f'Converting roberta-base into roberta-base-{model_args.max_pos}')\n",
        "model, tokenizer = create_long_model(\n",
        "    save_model_to=model_path,\n",
        "    base_model=base_model,\n",
        "    attention_window=attention_window,\n",
        "    max_pos=max_pos\n",
        ")"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/sshleifer/bart-tiny-random/config.json from cache at /root/.cache/torch/transformers/ce13c5b4dd7e5d8a0d2417a7842224d1535d0cd14dd928809bdb6029e1fa7af3.0a5a7d7a4a1c79b5dce5d054a64dd329deefdcbe16b8cf8a4e825bbed4186047\n",
            "INFO:transformers.configuration_utils:Model config BartConfig {\n",
            "  \"_num_labels\": 3,\n",
            "  \"activation_dropout\": 0.0,\n",
            "  \"activation_function\": \"gelu\",\n",
            "  \"add_bias_logits\": false,\n",
            "  \"add_final_layer_norm\": false,\n",
            "  \"architectures\": [\n",
            "    \"BartForConditionalGeneration\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classif_dropout\": 0.0,\n",
            "  \"d_model\": 24,\n",
            "  \"decoder_attention_heads\": 2,\n",
            "  \"decoder_ffn_dim\": 16,\n",
            "  \"decoder_layerdrop\": 0.0,\n",
            "  \"decoder_layers\": 2,\n",
            "  \"decoder_max_position_embeddings\": 1024,\n",
            "  \"decoder_start_token_id\": 2,\n",
            "  \"dropout\": 0.1,\n",
            "  \"encoder_attention_heads\": 2,\n",
            "  \"encoder_ffn_dim\": 16,\n",
            "  \"encoder_layerdrop\": 0.0,\n",
            "  \"encoder_layers\": 2,\n",
            "  \"encoder_max_position_embeddings\": 1024,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\"\n",
            "  },\n",
            "  \"init_std\": 0.02,\n",
            "  \"is_encoder_decoder\": true,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2\n",
            "  },\n",
            "  \"max_position_embeddings\": 1024,\n",
            "  \"model_type\": \"bart\",\n",
            "  \"normalize_before\": false,\n",
            "  \"normalize_embedding\": true,\n",
            "  \"num_hidden_layers\": 2,\n",
            "  \"output_past\": true,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"prefix\": \" \",\n",
            "  \"scale_embedding\": false,\n",
            "  \"static_position_embeddings\": false,\n",
            "  \"task_specific_params\": {\n",
            "    \"summarization\": {\n",
            "      \"early_stopping\": true,\n",
            "      \"length_penalty\": 2.0,\n",
            "      \"max_length\": 142,\n",
            "      \"min_length\": 56,\n",
            "      \"no_repeat_ngram_size\": 3,\n",
            "      \"num_beams\": 4\n",
            "    }\n",
            "  },\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/sshleifer/bart-tiny-random/pytorch_model.bin from cache at /root/.cache/torch/transformers/002911b8e4cea0a107864f5b17f20c10f613d256e92e3c1247d6d174fbf56fe5.bf6ebaf6162cfbfbad2ce1909278a9ea1fbfe9284d318bff8bccddfdaa104205\n",
            "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
            "INFO:__main__:saving model to bart-tiny-random-4096\n",
            "INFO:transformers.configuration_utils:Configuration saved in bart-tiny-random-4096/config.json\n",
            "INFO:transformers.modeling_utils:Model weights saved in bart-tiny-random-4096/pytorch_model.bin\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XsnW-waPU3Ua",
        "colab_type": "code",
        "outputId": "5f6d97f5-2863-4535-969a-1fa884b5635e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# Reload the converted tiny checkpoint from the directory it was saved to, using the LongBart model class.\n",
        "long_model_tiny = LongBartForConditionalGeneration.from_pretrained('bart-tiny-random-4096')"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.configuration_utils:loading configuration file bart-tiny-random-4096/config.json\n",
            "INFO:transformers.configuration_utils:Model config BartConfig {\n",
            "  \"_num_labels\": 3,\n",
            "  \"activation_dropout\": 0.0,\n",
            "  \"activation_function\": \"gelu\",\n",
            "  \"add_bias_logits\": false,\n",
            "  \"add_final_layer_norm\": false,\n",
            "  \"architectures\": [\n",
            "    \"BartForConditionalGeneration\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"attention_probs_dropout_prob\": 0.0,\n",
            "  \"attention_window\": [\n",
            "    512,\n",
            "    512\n",
            "  ],\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classif_dropout\": 0.0,\n",
            "  \"d_model\": 24,\n",
            "  \"decoder_attention_heads\": 2,\n",
            "  \"decoder_ffn_dim\": 16,\n",
            "  \"decoder_layerdrop\": 0.0,\n",
            "  \"decoder_layers\": 2,\n",
            "  \"decoder_max_position_embeddings\": 1024,\n",
            "  \"decoder_start_token_id\": 2,\n",
            "  \"dropout\": 0.1,\n",
            "  \"encoder_attention_heads\": 2,\n",
            "  \"encoder_ffn_dim\": 16,\n",
            "  \"encoder_layerdrop\": 0.0,\n",
            "  \"encoder_layers\": 2,\n",
            "  \"encoder_max_position_embeddings\": 4096,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\"\n",
            "  },\n",
            "  \"init_std\": 0.02,\n",
            "  \"is_encoder_decoder\": true,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2\n",
            "  },\n",
            "  \"max_position_embeddings\": 1024,\n",
            "  \"model_type\": \"bart\",\n",
            "  \"normalize_before\": false,\n",
            "  \"normalize_embedding\": true,\n",
            "  \"num_hidden_layers\": 2,\n",
            "  \"output_past\": true,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"prefix\": \" \",\n",
            "  \"scale_embedding\": false,\n",
            "  \"static_position_embeddings\": false,\n",
            "  \"task_specific_params\": {\n",
            "    \"summarization\": {\n",
            "      \"early_stopping\": true,\n",
            "      \"length_penalty\": 2.0,\n",
            "      \"max_length\": 142,\n",
            "      \"min_length\": 56,\n",
            "      \"no_repeat_ngram_size\": 3,\n",
            "      \"num_beams\": 4\n",
            "    }\n",
            "  },\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file bart-tiny-random-4096/pytorch_model.bin\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z5QKIIdeYRDL",
        "colab_type": "code",
        "outputId": "dd326b65-6bc7-4a91-e670-e41b6f64784f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import torch\n",
        "\n",
        "TXT = \"My friends are <mask> but they eat too many carbs.\"\n",
        "\n",
        "# Pad the single sentence to 4096 tokens so the forward pass runs at the extended input length.\n",
        "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=4096, pad_to_max_length=True)['input_ids']\n",
        "\n",
        "# Inference only: no_grad() stops autograd from keeping intermediate activations alive for a backward pass.\n",
        "with torch.no_grad():\n",
        "    logits = long_model_tiny(input_ids)[0]\n",
        "\n",
        "# Top-5 vocabulary predictions at the <mask> position (.item() requires exactly one <mask> in the input).\n",
        "# NOTE(review): the checkpoint is randomly initialised per its name, so this presumably only checks that the pipeline runs.\n",
        "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n",
        "probs = logits[0, masked_index].softmax(dim=0)\n",
        "values, predictions = probs.topk(5)\n",
        "tokenizer.decode(predictions).split()"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['<unk>.<pad><s></s>']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNmzwNHAN1AI",
        "colab_type": "text"
      },
      "source": [
        "Now let's try with bart-large"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vAzZdj-1N3Is",
        "colab_type": "code",
        "outputId": "85cf0a32-fc89-4137-f61e-d2710ea85bc4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# Convert the full-size bart-large checkpoint with the same settings used for the tiny model.\n",
        "base_model = \"bart-large\"\n",
        "model_path = \"bart-large-4096\"  # directory the converted model and tokenizer are saved to\n",
        "attention_window = 512  # forwarded to create_long_model as `attention_window`\n",
        "max_pos = 4096  # forwarded to create_long_model as `max_pos`\n",
        "\n",
        "# exist_ok=True keeps this cell safe to re-run and avoids the check-then-create race.\n",
        "os.makedirs(model_path, exist_ok=True)\n",
        "\n",
        "# create_long_model saves the converted model and tokenizer into model_path and returns both.\n",
        "model, tokenizer = create_long_model(\n",
        "    save_model_to=model_path,\n",
        "    base_model=base_model,\n",
        "    attention_window=attention_window,\n",
        "    max_pos=max_pos\n",
        ")"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json from cache at /root/.cache/torch/transformers/7f6632e580b7d9fd4f611dd96dab877cccfc319867b53b8b72fddca7fd64de5c.40bd49bcec9d93d8b0bfbd020088e2e1b6e6bb03e8e80aa5144638f90ca6bd61\n",
            "INFO:transformers.configuration_utils:Model config BartConfig {\n",
            "  \"_num_labels\": 3,\n",
            "  \"activation_dropout\": 0.0,\n",
            "  \"activation_function\": \"gelu\",\n",
            "  \"add_bias_logits\": false,\n",
            "  \"add_final_layer_norm\": false,\n",
            "  \"architectures\": [\n",
            "    \"BartModel\",\n",
            "    \"BartForMaskedLM\",\n",
            "    \"BartForSequenceClassification\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classif_dropout\": 0.0,\n",
            "  \"d_model\": 1024,\n",
            "  \"decoder_attention_heads\": 16,\n",
            "  \"decoder_ffn_dim\": 4096,\n",
            "  \"decoder_layerdrop\": 0.0,\n",
            "  \"decoder_layers\": 12,\n",
            "  \"decoder_max_position_embeddings\": 1024,\n",
            "  \"decoder_start_token_id\": 2,\n",
            "  \"dropout\": 0.1,\n",
            "  \"encoder_attention_heads\": 16,\n",
            "  \"encoder_ffn_dim\": 4096,\n",
            "  \"encoder_layerdrop\": 0.0,\n",
            "  \"encoder_layers\": 12,\n",
            "  \"encoder_max_position_embeddings\": 1024,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\"\n",
            "  },\n",
            "  \"init_std\": 0.02,\n",
            "  \"is_encoder_decoder\": true,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2\n",
            "  },\n",
            "  \"max_position_embeddings\": 1024,\n",
            "  \"model_type\": \"bart\",\n",
            "  \"normalize_before\": false,\n",
            "  \"normalize_embedding\": true,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"output_past\": false,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"prefix\": \" \",\n",
            "  \"scale_embedding\": false,\n",
            "  \"static_position_embeddings\": false,\n",
            "  \"task_specific_params\": {\n",
            "    \"summarization\": {\n",
            "      \"early_stopping\": true,\n",
            "      \"length_penalty\": 2.0,\n",
            "      \"max_length\": 142,\n",
            "      \"min_length\": 56,\n",
            "      \"no_repeat_ngram_size\": 3,\n",
            "      \"num_beams\": 4\n",
            "    }\n",
            "  },\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin from cache at /root/.cache/torch/transformers/2e7cae41bb1dd1f18e498ff4ff0ea85f7e9bc2b637439e2d95c485c5d5bdd579.5be2a88ec29f5969270f98902db392beab8be8a6a7ecc588d410ada3e32c4263\n",
            "INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']\n",
            "INFO:transformers.modeling_utils:Weights from pretrained model not used in BartForConditionalGeneration: ['encoder.version', 'decoder.version']\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /root/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /root/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
            "INFO:__main__:saving model to bart-large-4096\n",
            "INFO:transformers.configuration_utils:Configuration saved in bart-large-4096/config.json\n",
            "INFO:transformers.modeling_utils:Model weights saved in bart-large-4096/pytorch_model.bin\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9-lgWprfN8QW",
        "colab_type": "code",
        "outputId": "c1a817ea-3a05-4676-acea-81b27b3f6591",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# Reload the converted bart-large model and its tokenizer from the directory written by create_long_model.\n",
        "long_model = LongBartForConditionalGeneration.from_pretrained('bart-large-4096')\n",
        "tokenizer = BartTokenizer.from_pretrained('bart-large-4096')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.configuration_utils:loading configuration file bart-large-4096/config.json\n",
            "INFO:transformers.configuration_utils:Model config BartConfig {\n",
            "  \"_num_labels\": 3,\n",
            "  \"activation_dropout\": 0.0,\n",
            "  \"activation_function\": \"gelu\",\n",
            "  \"add_bias_logits\": false,\n",
            "  \"add_final_layer_norm\": false,\n",
            "  \"architectures\": [\n",
            "    \"BartForConditionalGeneration\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"attention_probs_dropout_prob\": 0.0,\n",
            "  \"attention_window\": [\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512,\n",
            "    512\n",
            "  ],\n",
            "  \"bos_token_id\": 0,\n",
            "  \"classif_dropout\": 0.0,\n",
            "  \"d_model\": 1024,\n",
            "  \"decoder_attention_heads\": 16,\n",
            "  \"decoder_ffn_dim\": 4096,\n",
            "  \"decoder_layerdrop\": 0.0,\n",
            "  \"decoder_layers\": 12,\n",
            "  \"decoder_max_position_embeddings\": 1024,\n",
            "  \"decoder_start_token_id\": 2,\n",
            "  \"dropout\": 0.1,\n",
            "  \"encoder_attention_heads\": 16,\n",
            "  \"encoder_ffn_dim\": 4096,\n",
            "  \"encoder_layerdrop\": 0.0,\n",
            "  \"encoder_layers\": 12,\n",
            "  \"encoder_max_position_embeddings\": 4096,\n",
            "  \"eos_token_id\": 2,\n",
            "  \"id2label\": {\n",
            "    \"0\": \"LABEL_0\",\n",
            "    \"1\": \"LABEL_1\",\n",
            "    \"2\": \"LABEL_2\"\n",
            "  },\n",
            "  \"init_std\": 0.02,\n",
            "  \"is_encoder_decoder\": true,\n",
            "  \"label2id\": {\n",
            "    \"LABEL_0\": 0,\n",
            "    \"LABEL_1\": 1,\n",
            "    \"LABEL_2\": 2\n",
            "  },\n",
            "  \"max_position_embeddings\": 1024,\n",
            "  \"model_type\": \"bart\",\n",
            "  \"normalize_before\": false,\n",
            "  \"normalize_embedding\": true,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"output_past\": false,\n",
            "  \"pad_token_id\": 1,\n",
            "  \"prefix\": \" \",\n",
            "  \"scale_embedding\": false,\n",
            "  \"static_position_embeddings\": false,\n",
            "  \"task_specific_params\": {\n",
            "    \"summarization\": {\n",
            "      \"early_stopping\": true,\n",
            "      \"length_penalty\": 2.0,\n",
            "      \"max_length\": 142,\n",
            "      \"min_length\": 56,\n",
            "      \"no_repeat_ngram_size\": 3,\n",
            "      \"num_beams\": 4\n",
            "    }\n",
            "  },\n",
            "  \"vocab_size\": 50265\n",
            "}\n",
            "\n",
            "INFO:transformers.modeling_utils:loading weights file bart-large-4096/pytorch_model.bin\n",
            "INFO:transformers.tokenization_utils:Model name 'bart-large-4096' not found in model shortcut name list (bart-large, bart-large-mnli, bart-large-cnn, bart-large-xsum). Assuming 'bart-large-4096' is a path, a model identifier, or url to a directory containing tokenizer files.\n",
            "INFO:transformers.tokenization_utils:Didn't find file bart-large-4096/added_tokens.json. We won't load it.\n",
            "INFO:transformers.tokenization_utils:loading file bart-large-4096/vocab.json\n",
            "INFO:transformers.tokenization_utils:loading file bart-large-4096/merges.txt\n",
            "INFO:transformers.tokenization_utils:loading file None\n",
            "INFO:transformers.tokenization_utils:loading file bart-large-4096/special_tokens_map.json\n",
            "INFO:transformers.tokenization_utils:loading file bart-large-4096/tokenizer_config.json\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cLhZFQMYONPb",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "f4eda1b3-5333-4144-bdd4-9804046dd30a"
      },
      "source": [
        "import torch\n",
        "\n",
        "TXT = \"My friends are <mask> but they eat too many carbs.\"\n",
        "\n",
        "# A 4096 sequence length crashed even with 35 GB of memory, so the input is padded to 2560 tokens instead;\n",
        "# the decoder probably needs sliding-window attention as well.\n",
        "# NOTE(review): that crash was observed without torch.no_grad(); 4096 has not been re-tested with it.\n",
        "input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt', max_length=2560, pad_to_max_length=True)['input_ids']\n",
        "\n",
        "# Inference only: no_grad() stops autograd from keeping intermediate activations alive for a backward pass.\n",
        "with torch.no_grad():\n",
        "    logits = long_model(input_ids)[0]\n",
        "\n",
        "# Top-5 vocabulary predictions at the <mask> position (.item() requires exactly one <mask> in the input).\n",
        "masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()\n",
        "probs = logits[0, masked_index].softmax(dim=0)\n",
        "values, predictions = probs.topk(5)\n",
        "tokenizer.decode(predictions).split()"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['having', 'still', 'going', 'getting', 'not']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    }
  ]
}