{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FG7kC-qt4C99"
   },
   "source": [
    "# Load Head with id2label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 835
    },
    "id": "BDbnBJQ64ByT",
    "outputId": "2ecdcdc1-e88e-437b-f47c-99d4d3944317"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/hSterz/adapter-transformers.git@dev/notebooks\n",
      "  Cloning https://github.com/hSterz/adapter-transformers.git (to revision dev/notebooks) to /tmp/pip-req-build-bndohr9x\n",
      "  Running command git clone -q https://github.com/hSterz/adapter-transformers.git /tmp/pip-req-build-bndohr9x\n",
      "  Running command git checkout -b dev/notebooks --track origin/dev/notebooks\n",
      "  Switched to a new branch 'dev/notebooks'\n",
      "  Branch 'dev/notebooks' set up to track remote branch 'dev/notebooks' from 'origin'.\n",
      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied, skipping upgrade: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (0.10.1)\n",
      "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (2.23.0)\n",
      "Requirement already satisfied, skipping upgrade: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (1.19.5)\n",
      "Requirement already satisfied, skipping upgrade: filelock in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (3.0.12)\n",
      "Requirement already satisfied, skipping upgrade: sacremoses in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (0.0.43)\n",
      "Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (20.9)\n",
      "Requirement already satisfied, skipping upgrade: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (4.41.1)\n",
      "Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (3.7.2)\n",
      "Requirement already satisfied, skipping upgrade: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0a1) (2019.12.20)\n",
      "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (1.24.3)\n",
      "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (3.0.4)\n",
      "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (2020.12.5)\n",
      "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0a1) (2.10)\n",
      "Requirement already satisfied, skipping upgrade: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (1.0.1)\n",
      "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (1.15.0)\n",
      "Requirement already satisfied, skipping upgrade: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0a1) (7.1.2)\n",
      "Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->adapter-transformers==2.0.0a1) (2.4.7)\n",
      "Requirement already satisfied, skipping upgrade: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0a1) (3.7.4.3)\n",
      "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0a1) (3.4.1)\n",
      "Building wheels for collected packages: adapter-transformers\n",
      "  Building wheel for adapter-transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for adapter-transformers: filename=adapter_transformers-2.0.0a1-cp37-none-any.whl size=2009407 sha256=c7167dd920abbd2b8cbf8e86b5035f8528068fa35fef8b8c444d87cdbef8b17f\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-s5fyrokm/wheels/11/f2/6e/5198e56bf8636ec6e948cb84619c6a1c6a7a5c25af561e7a79\n",
      "Successfully built adapter-transformers\n",
      "Installing collected packages: adapter-transformers\n",
      "  Found existing installation: adapter-transformers 2.0.0a1\n",
      "    Uninstalling adapter-transformers-2.0.0a1:\n",
      "      Successfully uninstalled adapter-transformers-2.0.0a1\n",
      "Successfully installed adapter-transformers-2.0.0a1\n"
     ]
    },
    {
     "data": {
      "application/vnd.colab-display-data+json": {
       "pip_warning": {
        "packages": [
         "transformers"
        ]
       }
      }
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -U adapter-transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zd8fsYXX4uxm"
   },
   "source": [
    "First, we create the model and load the trained adapter and head. Together with the head, the labels and the id2label mapping are loaded automatically. You can access them with `model.get_labels()` and `model.get_labels_dict()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3uGJQZPz4QNz",
    "outputId": "580de827-4766-4fca-9bbf-8ff2ee6288d3"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModelWithHeads: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModelWithHeads from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModelWithHeads from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']\n",
      "{1: 'B-LOC', 7: 'B-MISC', 5: 'B-ORG', 3: 'B-PER', 2: 'I-LOC', 8: 'I-MISC', 6: 'I-ORG', 4: 'I-PER', 0: 'O'}\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AdapterConfig, AutoAdapterModel, AutoTokenizer\n",
    "\n",
    "model_name = \"bert-base-uncased\"\n",
    "config = AdapterConfig.load(\"pfeiffer\")\n",
    "model = AutoAdapterModel.from_pretrained(model_name)\n",
    "# Load the pre-trained adapter together with its prediction head.\n",
    "# `load_as` fixes the local name to \"ner\" so that it matches the name activated below.\n",
    "model.load_adapter(\"ner/conll2003@ukp\", config=config, load_as=\"ner\")\n",
    "\n",
    "model.set_active_adapters([\"ner\"])\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# How you can access the labels and the mapping for a pretrained head\n",
    "print(model.get_labels())\n",
    "print(model.get_labels_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1-Ps9GKJ5ahe"
   },
   "source": [
    "This helper function returns the tokens of an input sentence together with the sequence of predicted label ids for those tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "ylUZiWHs6Oay"
   },
   "outputs": [],
   "source": [
    "def predict(sentence):\n",
    "  \"\"\"Tag `sentence` and return its word pieces with the predicted label ids.\n",
    "\n",
    "  Args:\n",
    "    sentence: Raw input text.\n",
    "\n",
    "  Returns:\n",
    "    A tuple `(tokens, label_ids)`: the result of `tokenizer.tokenize(sentence)`\n",
    "    and a NumPy array holding the arg-max label id per position, with the\n",
    "    first and last position dropped (presumably the special tokens added by\n",
    "    `encode` -- confirm when switching tokenizers).\n",
    "  \"\"\"\n",
    "  token_ids = tokenizer.encode(\n",
    "        sentence,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "  model.eval()\n",
    "  # Inference only, so skip autograd bookkeeping. The adapter that is used is\n",
    "  # the one activated on the model via `set_active_adapters`.\n",
    "  with torch.no_grad():\n",
    "    logits = model(token_ids)[0]\n",
    "  preds = np.argmax(logits.numpy(), axis=2)\n",
    "  return tokenizer.tokenize(sentence), preds.squeeze()[1:-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WuW4eYx5WJE9"
   },
   "source": [
    "If we want to use the model to predict the labels of a sentence, we can use the `model.get_labels_dict()` function to map the predicted label ids to the corresponding labels, as shown for the example text below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-q1ErWcF5Z2w",
    "outputId": "baa59a26-2ba5-4ec3-fd6c-6a326a3c0ec6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "germany(B-LOC) '(O) s(O) representative(O) to(O) the(O) european(B-ORG) union(I-ORG) '(O) s(O) veterinary(B-ORG) committee(I-ORG) werner(B-PER) z(I-PER) ##wing(I-PER) ##mann(I-PER) said(O) on(O) wednesday(O) consumers(O) should(O) buy(O) sheep(O) ##me(O) ##at(O) from(O) countries(O) other(O) than(O) britain(B-LOC) until(O) the(O) scientific(O) advice(O) was(O) clearer(O) .(O) "
     ]
    }
   ],
   "source": [
    "example_text=\"Germany's representative to the European Union\\'s veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer.\"\n",
    "\n",
    "# Fetch the id -> label table of the loaded head, then tag the sample sentence.\n",
    "label_map = model.get_labels_dict()\n",
    "tokens, preds = predict(example_text)\n",
    "\n",
    "# Emit every word piece as `token(LABEL) ` on a single line, without a trailing newline.\n",
    "annotated = \"\".join(f\"{token}({label_map[pred]}) \" for token, pred in zip(tokens, preds))\n",
    "print(annotated, end=\"\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "load_id2label.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
