{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2iYLQO5Evvqy"
   },
   "source": [
    "# Training a \"Robust\" Adapter with AdapterDrop\n",
    "\n",
    "This notebook extends our quickstart adapter training notebook to illustrate how we can use AdapterDrop\n",
    "to robustly train an adapter, i.e. an adapter that allows us to dynamically drop adapter layers for faster multi-task inference.\n",
    "Please have a look at the original adapter training notebook for more details on the setup."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a-XTIOLv0isn"
   },
   "source": [
    "## Installation\n",
    "\n",
    "First, let's install the required libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "ju-alwbHmKYA",
    "outputId": "44bbe24a-0925-46c2-aaf8-d9b5cdd7d0e3"
   },
   "outputs": [],
   "source": [
    "# `%pip` installs into the environment of the running kernel (unlike `!pip`)\n",
    "%pip install -U adapter-transformers\n",
    "%pip install datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7Mx916lBCfoL",
    "outputId": "bec39206-a7d9-415a-a3fa-90b948d51489",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset rotten_tomatoes_movie_review (/root/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/9198dbc50858df8bdb0d5f18ccaf33125800af96ad8434bc8b829918c987ee8a)\n",
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/9198dbc50858df8bdb0d5f18ccaf33125800af96ad8434bc8b829918c987ee8a/cache-eea56258438b8fa9.arrow\n",
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/9198dbc50858df8bdb0d5f18ccaf33125800af96ad8434bc8b829918c987ee8a/cache-f71bd7769a0d3d1b.arrow\n",
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/9198dbc50858df8bdb0d5f18ccaf33125800af96ad8434bc8b829918c987ee8a/cache-99fbbc1783d63db9.arrow\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import RobertaTokenizer\n",
    "\n",
    "dataset = load_dataset(\"rotten_tomatoes\")\n",
    "tokenizer = RobertaTokenizer.from_pretrained(\"roberta-base\")\n",
    "\n",
    "def encode_batch(batch):\n",
    "  \"\"\"Encodes a batch of input data using the model tokenizer.\n",
    "\n",
    "  Every text is truncated or padded to exactly 80 tokens.\n",
    "  \"\"\"\n",
    "  return tokenizer(batch[\"text\"], max_length=80, truncation=True, padding=\"max_length\")\n",
    "\n",
    "# Encode the input data\n",
    "dataset = dataset.map(encode_batch, batched=True)\n",
    "# The transformers model expects the target class column to be named \"labels\".\n",
    "# `rename_column` returns a new DatasetDict; the in-place `rename_column_` is\n",
    "# deprecated and no longer exists in recent versions of `datasets`.\n",
    "dataset = dataset.rename_column(\"label\", \"labels\")\n",
    "# Transform to pytorch tensors and only output the required columns\n",
    "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S2-2CbfPGYvi"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Tp9uG-pT-qgv",
    "outputId": "c6a19f25-26b9-44c7-887c-1d1a71e957fa"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModelWithHeads: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n",
      "- This IS expected if you are initializing RobertaModelWithHeads from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaModelWithHeads from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaModelWithHeads were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import RobertaConfig, RobertaModelWithHeads\n",
    "\n",
    "# Binary sentiment labels: 0 -> negative (👎), 1 -> positive (👍)\n",
    "id2label = {0: \"👎\", 1: \"👍\"}\n",
    "config = RobertaConfig.from_pretrained(\"roberta-base\", num_labels=len(id2label), id2label=id2label)\n",
    "model = RobertaModelWithHeads.from_pretrained(\"roberta-base\", config=config)\n",
    "\n",
    "# Add a new adapter named after the task, plus a matching two-class\n",
    "# classification head registered under the same name\n",
    "model.add_adapter(\"rotten_tomatoes\")\n",
    "model.add_classification_head(\"rotten_tomatoes\", num_labels=2)\n",
    "\n",
    "# Activate the adapter for training\n",
    "model.train_adapter(\"rotten_tomatoes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ev5t_8i8HzJB"
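   },
   "source": [
    "To dynamically drop adapter layers during training, we make use of Hugging Face's `TrainerCallback`: before every training step, the callback skips the adapter in the first k layers, where k is drawn at random."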
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "5FRft_5AAlQd"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from transformers import TrainingArguments, AdapterTrainer, EvalPrediction, TrainerCallback\n",
    "\n",
    "class AdapterDropTrainerCallback(TrainerCallback):\n",
    "  \"\"\"Randomly drops the adapter in the first layers during training (AdapterDrop).\n",
    "\n",
    "  Before every training step the adapter is skipped in the first k layers\n",
    "  (indices 0..k-1), where k is drawn uniformly from [0, max_skip_layers).\n",
    "  After every training step all adapter layers are re-enabled, so any\n",
    "  evaluation -- during training or after `trainer.train()` -- uses the full adapter.\n",
    "\n",
    "  Args:\n",
    "    adapter_name: name of the adapter passed to `set_active_adapters`.\n",
    "    max_skip_layers: exclusive upper bound for the number of skipped layers.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, adapter_name=\"rotten_tomatoes\", max_skip_layers=11):\n",
    "    self.adapter_name = adapter_name\n",
    "    self.max_skip_layers = max_skip_layers\n",
    "\n",
    "  def on_step_begin(self, args, state, control, **kwargs):\n",
    "    skip_layers = list(range(np.random.randint(0, self.max_skip_layers)))\n",
    "    kwargs['model'].set_active_adapters(self.adapter_name, skip_layers=skip_layers)\n",
    "\n",
    "  def on_step_end(self, args, state, control, **kwargs):\n",
    "    # Restore all adapter layers right after each step. Resetting in `on_evaluate`\n",
    "    # is too late: that event fires *after* an evaluation has finished, so the\n",
    "    # evaluation itself (including the final `trainer.evaluate()`) would still\n",
    "    # run with the last randomly chosen skip_layers.\n",
    "    kwargs['model'].set_active_adapters(self.adapter_name, skip_layers=None)\n",
    "\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    learning_rate=1e-4,\n",
    "    num_train_epochs=6,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    logging_steps=200,\n",
    "    output_dir=\"./training_output\",\n",
    "    overwrite_output_dir=True,\n",
    "    remove_unused_columns=False,\n",
    "    # The Trainer seeds python/numpy/torch with this value, which also makes\n",
    "    # the random AdapterDrop choices above reproducible\n",
    "    seed=42,\n",
    ")\n",
    "\n",
    "def compute_accuracy(p: EvalPrediction):\n",
    "  \"\"\"Returns the fraction of examples whose argmax prediction equals the label.\"\"\"\n",
    "  preds = np.argmax(p.predictions, axis=1)\n",
    "  return {\"acc\": (preds == p.label_ids).mean()}\n",
    "\n",
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"validation\"],\n",
    "    compute_metrics=compute_accuracy,\n",
    ")\n",
    "\n",
    "trainer.add_callback(AdapterDropTrainerCallback())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9iHhoYuLIdX3",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "We can now train and evaluate our robustly trained adapter!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 514
    },
    "id": "zZxaujENntNR",
    "outputId": "6700e4ac-1258-4bd8-ac30-a578c8d6c4ba",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='1602' max='1602' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1602/1602 07:38, Epoch 6/6]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.573800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.362600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.326900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.318700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.303100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.293200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.282300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.284900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='34' max='34' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [34/34 00:04]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'epoch': 6.0,\n",
       " 'eval_acc': 0.8799249530956847,\n",
       " 'eval_loss': 0.29594820737838745,\n",
       " 'eval_mem_cpu_alloc_delta': 112111,\n",
       " 'eval_mem_cpu_peaked_delta': 214624,\n",
       " 'eval_mem_gpu_alloc_delta': 0,\n",
       " 'eval_mem_gpu_peaked_delta': 94487040,\n",
       " 'eval_runtime': 5.0862,\n",
       " 'eval_samples_per_second': 209.586}"
      ]
     },
     "execution_count": 20,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train with randomly dropped adapter layers, then evaluate on the validation split\n",
    "train_output = trainer.train()\n",
    "trainer.evaluate()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Adapter_Drop_Training.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}