{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mI6aVBQRy6LC"
   },
   "source": [
    "# 3️⃣ Combining Pre-Trained Adapters using AdapterFusion\n",
    "\n",
    "In [the previous notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/02_Adapter_Inference.ipynb), we loaded a single pre-trained adapter from _AdapterHub_. Now we will explore how to take advantage of multiple pre-trained adapters to combine their knowledge on a new task. This setup is called **AdapterFusion** ([Pfeiffer et al., 2020](https://arxiv.org/pdf/2005.00247.pdf)).\n",
    "\n",
    "For this guide, we select **CommitmentBank** ([De Marneffe et al., 2019](https://github.com/mcdm/CommitmentBank)), a three-class textual entailment dataset, as our target task. We will fuse [adapters from AdapterHub](https://adapterhub.ml/explore/) which were pre-trained on different tasks. During training, their represantions are kept fix while a newly introduced fusion layer is trained. As our base model, we will use BERT (`bert-base-uncased`). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3L9gYpCV28OA"
   },
   "source": [
    "## Installation\n",
    "\n",
    "Again, we install `adapter-transformers` and HuggingFace's `datasets` library first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "qL3Sq1HQynCq",
    "outputId": "0ea7eca1-e986-4a92-f18d-94b97f4ae790"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/calpt/adapter-transformers.git@sync/v4.4.x\n",
      "  Cloning https://github.com/calpt/adapter-transformers.git (to revision sync/v4.4.x) to /tmp/pip-req-build-4ftwv687\n",
      "  Running command git clone -q https://github.com/calpt/adapter-transformers.git /tmp/pip-req-build-4ftwv687\n",
      "  Running command git checkout -b sync/v4.4.x --track origin/sync/v4.4.x\n",
      "  Switched to a new branch 'sync/v4.4.x'\n",
      "  Branch 'sync/v4.4.x' set up to track remote branch 'sync/v4.4.x' from 'origin'.\n",
      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (20.9)\n",
      "Requirement already satisfied, skipping upgrade: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (4.41.1)\n",
      "Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (3.7.2)\n",
      "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (2.23.0)\n",
      "Requirement already satisfied, skipping upgrade: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (2019.12.20)\n",
      "Requirement already satisfied, skipping upgrade: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (1.19.5)\n",
      "Requirement already satisfied, skipping upgrade: sacremoses in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (0.0.43)\n",
      "Requirement already satisfied, skipping upgrade: filelock in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (3.0.12)\n",
      "Requirement already satisfied, skipping upgrade: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (0.10.1)\n",
      "Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->adapter-transformers==2.0.0a1) (2.4.7)\n",
      "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0a1) (3.4.1)\n",
      "Requirement already satisfied, skipping upgrade: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0a1) (3.7.4.3)\n",
      "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (2.10)\n",
      "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (2020.12.5)\n",
      "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (1.24.3)\n",
      "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (3.0.4)\n",
      "Requirement already satisfied, skipping upgrade: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (7.1.2)\n",
      "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (1.15.0)\n",
      "Requirement already satisfied, skipping upgrade: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (1.0.1)\n",
      "Building wheels for collected packages: adapter-transformers\n",
      "  Building wheel for adapter-transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for adapter-transformers: filename=adapter_transformers-2.0.0a1-cp37-none-any.whl size=2009312 sha256=d46bab14e9170519f32ec04fee520b0f7a97cc8f8d4a1784682cd7c962e88fe1\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-7padf2op/wheels/71/c6/6a/e39f2a61cb162d4d61971ff214d433e84984b7e8ff45710298\n",
      "Successfully built adapter-transformers\n",
      "Installing collected packages: adapter-transformers\n",
      "  Found existing installation: adapter-transformers 2.0.0a1\n",
      "    Uninstalling adapter-transformers-2.0.0a1:\n",
      "      Successfully uninstalled adapter-transformers-2.0.0a1\n",
      "Successfully installed adapter-transformers-2.0.0a1\n"
     ]
    },
    {
     "data": {
      "application/vnd.colab-display-data+json": {
       "pip_warning": {
        "packages": [
         "transformers"
        ]
       }
      }
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (1.5.0)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.7/dist-packages (from datasets) (0.8.7)\n",
      "Requirement already satisfied: pyarrow>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
      "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.41.1)\n",
      "Requirement already satisfied: huggingface-hub<0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.0.7)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.1.5)\n",
      "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (2.0.0)\n",
      "Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.3)\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from datasets) (3.7.2)\n",
      "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.11.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.19.5)\n",
      "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<0.1.0->datasets) (3.0.12)\n",
      "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2018.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.1)\n",
      "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->datasets) (3.7.4.3)\n",
      "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->datasets) (3.4.1)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2020.12.5)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install -U adapter-transformers\n",
    "!pip install datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fP9OMSUT-FtL"
   },
   "source": [
    "## Dataset Preprocessing\n",
    "\n",
    "Before setting up training, we first prepare the training data. CommimentBank is part of the SuperGLUE benchmark and can be loaded via HuggingFace `datasets` using one line of code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "INW7UEhC-I6b",
    "outputId": "811a71c5-8e14-49ee-a4cf-577f8675ca49"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset super_glue (/root/.cache/huggingface/datasets/super_glue/cb/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'test': 250, 'train': 250, 'validation': 56}"
      ]
     },
     "execution_count": 14,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"super_glue\", \"cb\")\n",
    "dataset.num_rows"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "epiKaEz5dDVe"
   },
   "source": [
    "Every dataset sample has a premise, a hypothesis and a three-class class label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ifu4q5IJ-hYI",
    "outputId": "429a2e8c-f3e3-4740-ecd4-fa48189f4812"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'hypothesis': Value(dtype='string', id=None),\n",
       " 'idx': Value(dtype='int32', id=None),\n",
       " 'label': ClassLabel(num_classes=3, names=['entailment', 'contradiction', 'neutral'], names_file=None, id=None),\n",
       " 'premise': Value(dtype='string', id=None)}"
      ]
     },
     "execution_count": 15,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset['train'].features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uVa3Vk0QdNYI"
   },
   "source": [
    "Now, we need to encode all dataset samples to valid inputs for our `bert-base-uncased` model. Using `dataset.map()`, we can pass the full dataset through the tokenizer in batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hEnRCQfE_Oi3",
    "outputId": "34d18f5f-baac-4d6a-8900-111748433209"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/super_glue/cb/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-e30f5a8850c231d1.arrow\n",
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/super_glue/cb/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-185b7ced0a635e54.arrow\n",
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/super_glue/cb/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-3efe51a43ae4e2a3.arrow\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "def encode_batch(batch):\n",
    "  \"\"\"Encodes a batch of input data using the model tokenizer.\"\"\"\n",
    "  return tokenizer(\n",
    "      batch[\"premise\"],\n",
    "      batch[\"hypothesis\"],\n",
    "      max_length=180,\n",
    "      truncation=True,\n",
    "      padding=\"max_length\"\n",
    "  )\n",
    "\n",
    "# Encode the input data\n",
    "dataset = dataset.map(encode_batch, batched=True)\n",
    "# The transformers model expects the target class column to be named \"labels\"\n",
    "dataset.rename_column_(\"label\", \"labels\")\n",
    "# Transform to pytorch tensors and only output the required columns\n",
    "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TsXEy1TNeAnI"
   },
   "source": [
    "New we're ready to setup AdapterFusion..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xs21MzEQ_0v4"
   },
   "source": [
    "## Fusion Training\n",
    "\n",
    "We use a pre-trained BERT model from HuggingFace and instantiate our model using `BertModelWithHeads`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fnq8n_KP_3aX",
    "outputId": "04b34ee7-6faf-4226-9b31-9fc39fcf15c4"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModelWithHeads: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModelWithHeads from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModelWithHeads from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertConfig, BertModelWithHeads\n",
    "\n",
    "id2label = {id: label for (id, label) in enumerate(dataset[\"train\"].features[\"labels\"].names)}\n",
    "\n",
    "config = BertConfig.from_pretrained(\n",
    "    \"bert-base-uncased\",\n",
    "    id2label=id2label,\n",
    ")\n",
    "model = BertModelWithHeads.from_pretrained(\n",
    "    \"bert-base-uncased\",\n",
    "    config=config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BDCNJzTXezcn"
   },
   "source": [
    "Now we have everything set up to load our _AdapterFusion_ setup. First, we load three adapters pre-trained on different tasks from the Hub: MultiNLI, QQP and QNLI. As we don't need their prediction heads, we pass `with_head=False` to the loading method. Next, we add a new fusion layer that combines all the adapters we've just loaded. Finally, we add a new classification head for our target task on top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jRqbBgS0BoHJ"
   },
   "outputs": [],
   "source": [
    "from transformers.adapters.composition import Fuse\n",
    "\n",
    "# Load the pre-trained adapters we want to fuse\n",
    "model.load_adapter(\"nli/multinli@ukp\", load_as=\"multinli\", with_head=False)\n",
    "model.load_adapter(\"sts/qqp@ukp\", with_head=False)\n",
    "model.load_adapter(\"nli/qnli@ukp\", with_head=False)\n",
    "# Add a fusion layer for all loaded adapters\n",
    "model.add_adapter_fusion(Fuse(\"multinli\", \"qqp\", \"qnli\"))\n",
    "model.set_active_adapters(Fuse(\"multinli\", \"qqp\", \"qnli\"))\n",
    "\n",
    "# Add a classification head for our target task\n",
    "model.add_classification_head(\"cb\", num_labels=len(id2label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "60qIas8-il92"
   },
   "source": [
    "The last preparation step is to define and activate our adapter setup. Similar to `train_adapter()`, `train_adapter_fusion()` does two things: It freezes all weights of the model (including adapters!) except for the fusion layer and classification head. It also activates the given adapter setup to be used in very forward pass.\n",
    "\n",
    "The syntax for the adapter setup (which is also applied to other methods such as `set_active_adapters()`) works as follows:\n",
    "\n",
    "- a single string is interpreted as a single adapter\n",
    "- a list of strings is interpreted as a __stack__ of adapters\n",
    "- a _nested_ list of strings is interpreted as a __fusion__ of adapters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zgGqHJQbijgg"
   },
   "outputs": [],
   "source": [
    "# Unfreeze and activate fusion setup\n",
    "adapter_setup = Fuse(\"multinli\", \"qqp\", \"qnli\")\n",
    "model.train_adapter_fusion(adapter_setup)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Fb8BY5RAmzkd"
   },
   "source": [
    "For training, we make use of the `Trainer` class built-in into `transformers`. We configure the training process using a `TrainingArguments` object and define a method that will calculate the evaluation accuracy in the end. We pass both, together with the training and validation split of our dataset, to the trainer instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j0gFxQRdDkQ6"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from transformers import TrainingArguments, AdapterTrainer, EvalPrediction\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    logging_steps=200,\n",
    "    output_dir=\"./training_output\",\n",
    "    overwrite_output_dir=True,\n",
    "    # The next line is important to ensure the dataset labels are properly passed to the model\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "def compute_accuracy(p: EvalPrediction):\n",
    "  preds = np.argmax(p.predictions, axis=1)\n",
    "  return {\"acc\": (preds == p.label_ids).mean()}\n",
    "\n",
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"validation\"],\n",
    "    compute_metrics=compute_accuracy,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iKlYjA9rm2Kp"
   },
   "source": [
    "Start the training 🚀 (this will take a while)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 114
    },
    "id": "vSUs2FjXDmsx",
    "outputId": "934e087a-8b31-4861-c9d1-854290dd3c4c"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='40' max='40' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [40/40 00:50, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=40, training_loss=0.8146671295166016, metrics={'train_runtime': 52.2937, 'train_samples_per_second': 0.765, 'total_flos': 180914605650000.0, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 53542, 'init_mem_gpu_alloc_delta': 536864256, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 172358, 'train_mem_gpu_alloc_delta': 265284608, 'train_mem_cpu_peaked_delta': 338785, 'train_mem_gpu_peaked_delta': 7706785792})"
      ]
     },
     "execution_count": 21,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "axvDsmnJnGUG"
   },
   "source": [
    "After completed training, let's check how well our setup performs on the validation set of our target dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "id": "PqcApZ_-DpwK",
    "outputId": "8a67eafc-3815-42ad-ab23-dfaba42d34a8"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='2' max='2' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2/2 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'epoch': 5.0,\n",
       " 'eval_acc': 0.6964285714285714,\n",
       " 'eval_loss': 0.8152657747268677,\n",
       " 'eval_mem_cpu_alloc_delta': 68149,\n",
       " 'eval_mem_cpu_peaked_delta': 19846,\n",
       " 'eval_mem_gpu_alloc_delta': 0,\n",
       " 'eval_mem_gpu_peaked_delta': 409437696,\n",
       " 'eval_runtime': 1.0951,\n",
       " 'eval_samples_per_second': 51.139}"
      ]
     },
     "execution_count": 22,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "op8M7AfYnWhs"
   },
   "source": [
    "We can also use our setup to make some predictions (the example is from the test set of CB):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 36
    },
    "id": "e9wQcaNLNPAT",
    "outputId": "63841f17-590e-48f0-9677-2c590aa6bdcb"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'entailment'"
      ]
     },
     "execution_count": 23,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "def predict(premise, hypothesis):\n",
    "  encoded = tokenizer(premise, hypothesis, return_tensors=\"pt\")\n",
    "  if torch.cuda.is_available():\n",
    "    encoded.to(\"cuda\")\n",
    "  logits = model(**encoded)[0]\n",
    "  pred_class = torch.argmax(logits).item()\n",
    "  return id2label[pred_class]\n",
    "\n",
    "predict(\"\"\"\n",
    "``It doesn't happen very often.'' Karen went home\n",
    "happy at the end of the day. She didn't think that\n",
    "the work was difficult.\n",
    "\"\"\",\n",
    "\"the work was difficult\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H-fJK8OtnrAu"
   },
   "source": [
    "Finally, we can extract and save our fusion layer as well as all the adapters we used for training. Both can later be reloaded into the pre-trained model again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YpHIRIHxJd5d",
    "outputId": "d92cc522-f412-4a69-c6a9-9deed1bee98b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 83056\n",
      "-rw-r--r-- 1 root root      407 Mar 28 10:12 adapter_fusion_config.json\n",
      "drwxr-xr-x 2 root root     4096 Mar 28 10:07 multinli\n",
      "-rw-r--r-- 1 root root 85031691 Mar 28 10:12 pytorch_model_adapter_fusion.bin\n",
      "drwxr-xr-x 2 root root     4096 Mar 28 10:07 qnli\n",
      "drwxr-xr-x 2 root root     4096 Mar 28 10:07 qqp\n"
     ]
    }
   ],
   "source": [
    "model.save_adapter_fusion(\"./saved\", \"multinli,qqp,qnli\")\n",
    "model.save_all_adapters(\"./saved\")\n",
    "\n",
    "!ls -l saved"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4PHWpVLApDWt"
   },
   "source": [
    "That's it. Do check out [the paper on AdapterFusion](https://arxiv.org/pdf/2005.00247.pdf) for a more theoretical view on what we've just seen.\n",
    "\n",
    "➡️ `adapter-transformers` also enables other composition methods beyond AdapterFusion. For example, check out [the next notebook in this series](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/04_Cross_Lingual_Transfer.ipynb) on cross-lingual transfer."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "03_Adapter_Fusion.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}