{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6VFCTA5Z7sB6"
   },
   "source": [
    "# Training on Wikiann"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HV3IYEdoHnNF"
   },
   "source": [
    "In this notebook, we use [MAD-X 2.0](https://arxiv.org/pdf/2012.15562.pdf) with a stacked language and task adapter setup to zero-shot cross-lingual transfer for NER.\n",
    "We use a NER adapter from [AdapterHub.ml](https://adapterhub.ml/explore) pre-trained on the **English** portion of the [WikiAnn](https://www.aclweb.org/anthology/P17-1178.pdf) dataset and transfer to **Guarani** with a pre-trained language adapter.\n",
    "This notebook is similar to the 'run_ner.py' example script in 'examples/pytorch/token-classification/'.\n",
    "\n",
    "First, let's install 'adapter-transformers' and other required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3AJLS4197g-i",
    "outputId": "1c58522a-5353-487b-ede2-842d7ecdd2ef"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/hSterz/adapter-transformers.git@dev/batch_split\n",
      "  Cloning https://github.com/hSterz/adapter-transformers.git (to revision dev/batch_split) to /tmp/pip-req-build-70gxfnbd\n",
      "  Running command git clone -q https://github.com/hSterz/adapter-transformers.git /tmp/pip-req-build-70gxfnbd\n",
      "  Running command git checkout -b dev/batch_split --track origin/dev/batch_split\n",
      "  Switched to a new branch 'dev/batch_split'\n",
      "  Branch 'dev/batch_split' set up to track remote branch 'dev/batch_split' from 'origin'.\n",
      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied, skipping upgrade: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (4.41.1)\n",
      "Requirement already satisfied, skipping upgrade: filelock in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (3.0.12)\n",
      "Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (4.0.1)\n",
      "Requirement already satisfied, skipping upgrade: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (2019.12.20)\n",
      "Requirement already satisfied, skipping upgrade: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (0.10.3)\n",
      "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (2.23.0)\n",
      "Requirement already satisfied, skipping upgrade: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (1.19.5)\n",
      "Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (20.9)\n",
      "Requirement already satisfied, skipping upgrade: sacremoses in /usr/local/lib/python3.7/dist-packages (from adapter-transformers==2.0.0) (0.0.45)\n",
      "Requirement already satisfied, skipping upgrade: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0) (3.7.4.3)\n",
      "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->adapter-transformers==2.0.0) (3.4.1)\n",
      "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0) (2020.12.5)\n",
      "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0) (3.0.4)\n",
      "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0) (2.10)\n",
      "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->adapter-transformers==2.0.0) (1.24.3)\n",
      "Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->adapter-transformers==2.0.0) (2.4.7)\n",
      "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0) (1.15.0)\n",
      "Requirement already satisfied, skipping upgrade: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0) (7.1.2)\n",
      "Requirement already satisfied, skipping upgrade: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->adapter-transformers==2.0.0) (1.0.1)\n",
      "Building wheels for collected packages: adapter-transformers\n",
      "  Building wheel for adapter-transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for adapter-transformers: filename=adapter_transformers-2.0.0-cp37-none-any.whl size=2099599 sha256=3feb17dca0940df00288ead1da3264cf964d2029f5524f92859b76a0c06627c8\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-qwpi72pg/wheels/00/83/9b/5a3af174c4f3ad9b708896e050432422002e934b835dc9dd72\n",
      "Successfully built adapter-transformers\n",
      "Installing collected packages: adapter-transformers\n",
      "  Found existing installation: adapter-transformers 2.0.0\n",
      "    Uninstalling adapter-transformers-2.0.0:\n",
      "      Successfully uninstalled adapter-transformers-2.0.0\n",
      "Successfully installed adapter-transformers-2.0.0\n",
      "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (1.8.0)\n",
      "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.11.1)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.7/dist-packages (from datasets) (2021.6.0)\n",
      "Requirement already satisfied: huggingface-hub<0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.0.10)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.1.5)\n",
      "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
      "Requirement already satisfied: pyarrow<4.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from datasets) (4.0.1)\n",
      "Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.3)\n",
      "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.41.1)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (20.9)\n",
      "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (2.0.2)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.19.5)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<0.1.0->datasets) (3.0.12)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<0.1.0->datasets) (3.7.4.3)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.1)\n",
      "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2018.9)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2020.12.5)\n",
      "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->datasets) (3.4.1)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (2.4.7)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
      "Requirement already satisfied: seqeval in /usr/local/lib/python3.7/dist-packages (1.2.2)\n",
      "Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.7/dist-packages (from seqeval) (1.19.5)\n",
      "Requirement already satisfied: scikit-learn>=0.21.3 in /usr/local/lib/python3.7/dist-packages (from seqeval) (0.22.2.post1)\n",
      "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.21.3->seqeval) (1.4.1)\n",
      "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.21.3->seqeval) (1.0.1)\n"
     ]
    }
   ],
   "source": [
    "!pip install -U adapter-transformers\n",
    "!pip install datasets\n",
    "!pip install seqeval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "knPViYpFJzup"
   },
   "source": [
    "Next, we initialize the tokenizer and the model with the correct labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vgQ5DC_48_r_",
    "outputId": "6cdd6d98-9308-4395-8d0f-71b6588e06d5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import Dataset\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm.notebook import tqdm\n",
    "from torch import nn\n",
    "import copy\n",
    "#The labels for the NER task and the dictionaries to map the to ids or \n",
    "#the other way around\n",
    "labels = [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\"]\n",
    "id_2_label = {id_: label for id_, label in enumerate(labels)}\n",
    "label_2_id = {label: id_ for id_, label in enumerate(labels)}\n",
    "\n",
    "model_name = \"bert-base-multilingual-cased\"\n",
    "config = AutoConfig.from_pretrained(model_name, num_labels=len(labels), label2id=label_2_id, id2label=id_2_label)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForTokenClassification.from_pretrained(model_name, config=config)\n",
    "\n",
    "print(model.get_labels())\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LYRG9o0HKv7_"
   },
   "source": [
    "Now, we load the task and the language adapter. For both adapters, we drop the adapter in the last layer following MAD-X 2.0. We then set both adapters as active adapters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "gZj6Jt-P7Ysd"
   },
   "outputs": [],
   "source": [
    "from transformers import AdapterConfig\n",
    "target_language = \"gn\" # choose any language that a bert-base-multilingual-cased language adapter is available for\n",
    "source_language = \"en\" # We support  \"en\", \"ja\", \"zh\", and \"ar\"\n",
    "\n",
    "adapter_config = AdapterConfig.load(\n",
    "    None,\n",
    "    leave_out=[11]\n",
    ")\n",
    "\n",
    "model.load_adapter(\n",
    "    \"wikiann/\" + source_language + \"@ukp\",\n",
    "    config=adapter_config,\n",
    "    load_as=\"wikiann\",\n",
    ")\n",
    "    \n",
    "lang_adapter_name = model.load_adapter(\n",
    "    target_language + \"/wiki@ukp\",\n",
    "    load_as=target_language,\n",
    "    leave_out=[11],\n",
    ")\n",
    "# Set the adapters to be used in every forward pass\n",
    "model.set_active_adapters([lang_adapter_name, \"wikiann\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UyOE0neNKDmr"
   },
   "source": [
    "Next, we can download the dataset and initialize the trainings arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "o-SUUa367TBr",
    "outputId": "97840fcc-2f68-42e0-cb81-a20843f84584"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset wikiann (/root/.cache/huggingface/datasets/wikiann/gn/1.1.0/c0a0280cc1c835e2bb7db29f43ef89c8ea30e145b78c1ba2746d709fae3da112)\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "datasets = load_dataset('wikiann', target_language)\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    per_device_eval_batch_size=64,\n",
    "    do_predict=True,\n",
    "    output_dir=\"ner_models/madx/\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DQ4-v6BuKJzl"
   },
   "source": [
    "This method is taken from the example script 'run_ner.py'. It prepares the input tokens such that they are tokenized by the correct tokenizer and the labels are adapted to the new tokenization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "bbfMD2EM8OAs"
   },
   "outputs": [],
   "source": [
    "# This method is adapted from the huggingface transformers run_ner.py example script\n",
    "# Tokenize all texts and align the labels with them.\n",
    "def tokenize_and_align_labels(examples):\n",
    "    text_column_name = \"tokens\"\n",
    "    label_column_name = \"ner_tags\"\n",
    "    tokenized_inputs = tokenizer(\n",
    "        examples[text_column_name],\n",
    "        padding=False,\n",
    "        truncation=True,\n",
    "        # We use this argument because the texts in our dataset are lists of words (with a label for each word).\n",
    "        is_split_into_words=True,\n",
    "    )\n",
    "    labels = []\n",
    "    for i, label in enumerate(examples[label_column_name]):\n",
    "        word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
    "        previous_word_idx = None\n",
    "        label_ids = []\n",
    "        for word_idx in word_ids:\n",
    "            # Special tokens have a word id that is None. We set the label to -100 so they are automatically\n",
    "            # ignored in the loss function.\n",
    "            if word_idx is None:\n",
    "                label_ids.append(-100)\n",
    "            # We set the label for the first token of each word.\n",
    "            elif word_idx != previous_word_idx:\n",
    "                label_ids.append(label[word_idx])\n",
    "            # For the other tokens in a word, we set the label to either the current label or -100, depending on\n",
    "            # the label_all_tokens flag.\n",
    "            else:\n",
    "                label_ids.append(-100)\n",
    "            previous_word_idx = word_idx\n",
    "\n",
    "        labels.append(label_ids)\n",
    "    tokenized_inputs[\"labels\"] = labels\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lBiIMcWOKjih"
   },
   "source": [
    "We apply the previous method to the test dataset to prepare it for prediction. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xrnTO-hV8pRc",
    "outputId": "274359ac-9304-4701-cf92-9780dd861688"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/wikiann/gn/1.1.0/c0a0280cc1c835e2bb7db29f43ef89c8ea30e145b78c1ba2746d709fae3da112/cache-2d9847f420f61d2c.arrow\n"
     ]
    }
   ],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "test_dataset = datasets[\"test\"]\n",
    "test_dataset = test_dataset.map(\n",
    "    tokenize_and_align_labels,\n",
    "    batched=True,\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rR8_a-MILFLZ"
   },
   "source": [
    "We use HuggingFace's `Trainer` class to evaluate zero-shot transfer on the WikiAnn test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 67,
     "referenced_widgets": [
      "7bf73b5dd83b4fcab81abaa7af4b795b",
      "57894087d9434abfa908e0db6ff9a004",
      "28d8969a53c0442f9f3692da6eda74bc",
      "e29f6202600948c1af646921940261d6",
      "c2a748c71bad475393a4786c63f1a977",
      "0df8aeb8bbdc4a81b7f1197ec1d5a436",
      "783b944e92374bbbb755f943ed6c2f1a",
      "29522662989f4b2fbb24c07337d68717"
     ]
    },
    "id": "H07g_-if9xb0",
    "outputId": "4b08082c-76bb-4dd5-a738-422288fabb0d"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7bf73b5dd83b4fcab81abaa7af4b795b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2482.0, style=ProgressStyle(description…"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, AdapterTrainer, EvalPrediction\n",
    "from datasets import load_metric\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Metrics\n",
    "metric = load_metric(\"seqeval\")\n",
    "\n",
    "def compute_metrics(p):\n",
    "    predictions, labels = p\n",
    "    predictions = np.argmax(predictions, axis=2)\n",
    "    label_list = id_2_label\n",
    "\n",
    "    # Remove ignored index (special tokens)\n",
    "    true_predictions = [\n",
    "        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
    "        for prediction, label in zip(predictions, labels)\n",
    "    ]\n",
    "    true_labels = [\n",
    "        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
    "        for prediction, label in zip(predictions, labels)\n",
    "    ]\n",
    "\n",
    "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    return {\n",
    "        \"precision\": results[\"overall_precision\"],\n",
    "        \"recall\": results[\"overall_recall\"],\n",
    "        \"f1\": results[\"overall_f1\"],\n",
    "        \"accuracy\": results[\"overall_accuracy\"],\n",
    "    }\n",
    "\n",
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=None,\n",
    "    eval_dataset=test_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F5lKBoZ9LPnb"
   },
   "source": [
    "Finally we can predict the labels for the test set and evaluate he predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 313
    },
    "id": "HAxswwN3hmyw",
    "outputId": "a5914764-a126-4667-a336-117b876d4891"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='2' max='2' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2/2 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_accuracy': 0.7981220657276995,\n",
       " 'eval_f1': 0.5714285714285714,\n",
       " 'eval_loss': 0.9542072415351868,\n",
       " 'eval_mem_cpu_alloc_delta': 10735616,\n",
       " 'eval_mem_cpu_peaked_delta': 0,\n",
       " 'eval_mem_gpu_alloc_delta': 0,\n",
       " 'eval_mem_gpu_peaked_delta': 352350208,\n",
       " 'eval_precision': 0.4642857142857143,\n",
       " 'eval_recall': 0.7428571428571429,\n",
       " 'eval_runtime': 2.0225,\n",
       " 'eval_samples_per_second': 49.444,\n",
       " 'init_mem_cpu_alloc_delta': 842862592,\n",
       " 'init_mem_cpu_peaked_delta': 0,\n",
       " 'init_mem_gpu_alloc_delta': 754506240,\n",
       " 'init_mem_gpu_peaked_delta': 0}"
      ]
     },
     "execution_count": 8,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Wikiann.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "0df8aeb8bbdc4a81b7f1197ec1d5a436": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "28d8969a53c0442f9f3692da6eda74bc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Downloading: ",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0df8aeb8bbdc4a81b7f1197ec1d5a436",
      "max": 2482,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_c2a748c71bad475393a4786c63f1a977",
      "value": 2482
     }
    },
    "29522662989f4b2fbb24c07337d68717": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "57894087d9434abfa908e0db6ff9a004": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "783b944e92374bbbb755f943ed6c2f1a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "7bf73b5dd83b4fcab81abaa7af4b795b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_28d8969a53c0442f9f3692da6eda74bc",
       "IPY_MODEL_e29f6202600948c1af646921940261d6"
      ],
      "layout": "IPY_MODEL_57894087d9434abfa908e0db6ff9a004"
     }
    },
    "c2a748c71bad475393a4786c63f1a977": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "e29f6202600948c1af646921940261d6": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_29522662989f4b2fbb24c07337d68717",
      "placeholder": "​",
      "style": "IPY_MODEL_783b944e92374bbbb755f943ed6c2f1a",
      "value": " 6.34k/? [00:00&lt;00:00, 7.43kB/s]"
     }
    }
   }
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}