{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4d4526c",
   "metadata": {},
   "source": [
    "# Few-Bit: Fine-tuning RoBERTa on GLUE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cab74b9",
   "metadata": {},
   "source": [
    "Based on HuggingFace's tutorial on [\"Fine-tuning on Classification Tasks\"][1]\n",
    "and the [pre-trained RoBERTa][2] model.\n",
    "\n",
    "[1]: https://huggingface.co/docs/transformers/notebooks\n",
    "[2]: https://huggingface.co/roberta-base"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "623986ee",
   "metadata": {
    "id": "kTCFado4IrIc"
   },
   "source": [
    "The GLUE Benchmark is a group of nine tasks on single sentences or pairs of\n",
    "sentences (eight classification tasks and one regression task):\n",
    "\n",
    "- [CoLA][1] (abbrv. _Corpus of Linguistic Acceptability_) Determine if a\n",
    "  sentence is grammatically correct or not.\n",
    "- [MNLI][2] (abbrv. _Multi-Genre Natural Language Inference_) Determine if a\n",
    "  sentence entails, contradicts or is unrelated to a given hypothesis. (This\n",
    "  dataset has two versions: a matched one, whose validation and test sets come\n",
    "  from the same distribution as the training set, and a mismatched one, whose\n",
    "  validation and test sets use out-of-domain data.)\n",
    "- [MRPC][3] (abbrv. _Microsoft Research Paraphrase Corpus_) Determine if two\n",
    "  sentences are paraphrases of one another or not.\n",
    "- [QNLI][4] (abbrv. _Question-answering Natural Language Inference_)\n",
    "  Determine if the answer to a question is in the second sentence or not.\n",
    "- [QQP][5] (abbrv. _Quora Question Pairs2_) Determine if two questions are\n",
    "  semantically equivalent or not.\n",
    "- [RTE][6] (abbrv. _Recognizing Textual Entailment_) Determine if a sentence\n",
    "  entails a given hypothesis or not.\n",
    "- [SST-2][7] (abbrv. _Stanford Sentiment Treebank_) Determine if the sentence\n",
    "  has a positive or negative sentiment.\n",
    "- [STS-B][8] (abbrv. _Semantic Textual Similarity Benchmark_) Determine the\n",
    "  similarity of two sentences with a score from 0 to 5.\n",
    "- [WNLI][9] (abbrv. _Winograd Natural Language Inference_) Determine if a\n",
    "  sentence with an ambiguous pronoun entails the same sentence with that\n",
    "  pronoun replaced by a candidate referent.\n",
    "\n",
    "[1]: https://nyu-mll.github.io/CoLA/\n",
    "[2]: https://arxiv.org/abs/1704.05426\n",
    "[3]: https://www.microsoft.com/en-us/download/details.aspx?id=52398\n",
    "[4]: https://rajpurkar.github.io/SQuAD-explorer/\n",
    "[5]: https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs\n",
    "[6]: https://aclweb.org/aclwiki/Recognizing_Textual_Entailment\n",
    "[7]: https://nlp.stanford.edu/sentiment/index.html\n",
    "[8]: http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark\n",
    "[9]: https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7e1d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import builtins\n",
    "\n",
    "from argparse import ArgumentParser\n",
    "from functools import partial\n",
    "from os import environ, makedirs\n",
    "from pathlib import Path\n",
    "from typing import Optional\n",
    "\n",
    "# Force use of a fixed CUDA device (GPU 0) unless one is already selected.\n",
    "if 'CUDA_VISIBLE_DEVICES' not in environ:\n",
    "    environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb2ec709",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_metric\n",
    "from torch import manual_seed\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from transformers import (RobertaTokenizerFast as Tokenizer,\n",
    "                          RobertaForSequenceClassification as Model,\n",
    "                          Trainer, TrainerCallback, TrainingArguments)\n",
    "from transformers.integrations import TensorBoardCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5611c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c908fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0x12c946425095e587\n",
    "TASK = 'cola'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce6d621",
   "metadata": {},
   "outputs": [],
   "source": [
    "CACHE_DIR = Path('~/.cache/fewbit').expanduser()\n",
    "DATA_DIR = Path('../data/huggingface')\n",
    "LOG_DIR = Path('../log')\n",
    "MODEL_DIR = Path('../model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "888ede9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_TO_KEYS = {\n",
    "    'cola': ('sentence', None),\n",
    "    'mnli': ('premise', 'hypothesis'),\n",
    "    'mnli-mm': ('premise', 'hypothesis'),\n",
    "    'mrpc': ('sentence1', 'sentence2'),\n",
    "    'qnli': ('question', 'sentence'),\n",
    "    'qqp': ('question1', 'question2'),\n",
    "    'rte': ('sentence1', 'sentence2'),\n",
    "    'sst2': ('sentence', None),\n",
    "    'stsb': ('sentence1', 'sentence2'),\n",
    "    'wnli': ('sentence1', 'sentence2'),\n",
    "}"
   ]
  },
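  {
   "cell_type": "markdown",
   "id": "a1f0c3d2",
   "metadata": {},
   "source": [
    "Each entry of `TASK_TO_KEYS` names the text column(s) passed to the tokenizer;\n",
    "single-sentence tasks (CoLA, SST-2) use `None` as the second key. The sketch\n",
    "below mirrors the argument selection done by `preprocess` (defined further\n",
    "down) but swaps in a hypothetical stub, `fake_tokenizer`, for the RoBERTa\n",
    "tokenizer, so it runs without downloading anything.\n",
    "\n",
    "```python\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "def fake_tokenizer(*texts, **kwargs):\n",
    "    # Hypothetical stand-in for the RoBERTa tokenizer: it only reports\n",
    "    # how many text arguments it received.\n",
    "    return {'num_texts': len(texts)}\n",
    "\n",
    "\n",
    "def select_texts(tokenizer, lhs, rhs, sample):\n",
    "    # Same branching as `preprocess`: pass one column or a pair of columns.\n",
    "    if rhs is None:\n",
    "        args = (sample[lhs],)\n",
    "    else:\n",
    "        args = (sample[lhs], sample[rhs])\n",
    "    return tokenizer(*args)\n",
    "\n",
    "\n",
    "cola = partial(select_texts, fake_tokenizer, 'sentence', None)\n",
    "rte = partial(select_texts, fake_tokenizer, 'sentence1', 'sentence2')\n",
    "\n",
    "print(cola({'sentence': ['The book was read.']}))     # {'num_texts': 1}\n",
    "print(rte({'sentence1': ['A'], 'sentence2': ['B']}))  # {'num_texts': 2}\n",
    "```\n",
    "\n",
    "With the real tokenizer, a pair of columns is encoded as a single input\n",
    "sequence with separator tokens between the two texts."
   ]
  },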
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692d6eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_TO_HYPERPARAMS = {\n",
    "    'cola': (16, 1e-5),\n",
    "    'mnli': (16, 1e-5),\n",
    "    'mnli-mm': (16, 1e-5),\n",
    "    'mrpc': (16, 1e-5),\n",
    "    'qnli': (32, 1e-5),\n",
    "    'qqp': (32, 1e-5),\n",
    "    'rte': (16, 2e-5),\n",
    "    'sst2': (32, 2e-5),\n",
    "    'stsb': (16, 1e-5),\n",
    "    'wnli': (32, 1e-5),\n",
    "}"
   ]
  },
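  {
   "cell_type": "markdown",
   "id": "b7e2d941",
   "metadata": {},
   "source": [
    "`setup` derives the warm-up length from these batch sizes as 6% of\n",
    "`len(dataset_train) * epochs / batch_size`. As a worked example, assume CoLA's\n",
    "8,551 training sentences, batch size 16 and 10 epochs on a single GPU without\n",
    "gradient accumulation; the helper names below are illustrative only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def total_steps(num_samples, epochs, batch_size):\n",
    "    # One optimizer step per batch; the last, smaller batch still counts\n",
    "    # (assumes a single device and no gradient accumulation).\n",
    "    return math.ceil(num_samples / batch_size) * epochs\n",
    "\n",
    "\n",
    "def warmup_steps_as_in_setup(num_samples, epochs, batch_size, ratio=0.06):\n",
    "    # Same expression as in `setup`: ignores the partial batch of each epoch.\n",
    "    return int(ratio * num_samples * epochs / batch_size)\n",
    "\n",
    "\n",
    "# CoLA: 8,551 training sentences, batch size 16, 10 epochs.\n",
    "print(total_steps(8551, 10, 16))               # 5350\n",
    "print(warmup_steps_as_in_setup(8551, 10, 16))  # 320\n",
    "```\n",
    "\n",
    "Training actually runs 5,350 steps, and 6% of that is 321, so the shortcut is\n",
    "off by one step. That is harmless here; alternatively, passing\n",
    "`warmup_ratio=0.06` to `TrainingArguments` lets `Trainer` compute the warm-up\n",
    "from the real number of steps."
   ]
  },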
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ed7266e",
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = ArgumentParser()\n",
    "\n",
    "parser.add_argument('-c', '--cache-dir',\n",
    "                    default=CACHE_DIR,\n",
    "                    type=Path,\n",
                    help='Directory to cache preprocessed dataset files.')\n",
    "\n",
    "parser.add_argument('-d', '--data-dir',\n",
    "                    default=DATA_DIR,\n",
    "                    type=Path,\n",
                    help='Directory to cache original dataset files.')\n",
    "\n",
    "parser.add_argument('-l', '--log-dir',\n",
    "                    default=LOG_DIR,\n",
    "                    type=Path,\n",
    "                    help='Directory for TensorBoard logs.')\n",
    "\n",
    "parser.add_argument('-m', '--model-dir',\n",
    "                    default=MODEL_DIR,\n",
    "                    type=Path,\n",
    "                    help='Directory to save checkpoint files.')\n",
    "\n",
    "parser.add_argument('-n', '--num-bits',\n",
    "                    default=None,\n",
    "                    type=int,\n",
                    help='Number of bits of the GELU approximation (exact GELU if omitted).')\n",
    "\n",
    "parser.add_argument('-s', '--seed',\n",
    "                    default=SEED,\n",
    "                    type=int,\n",
    "                    help='Random seed for reproducibility.')\n",
    "\n",
    "parser.add_argument('task',\n",
    "                    default=TASK,\n",
    "                    choices=sorted(TASK_TO_HYPERPARAMS),\n",
    "                    nargs='?',\n",
    "                    help='GLUE task to learn.')"
   ]
  },
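  {
   "cell_type": "markdown",
   "id": "c93a5f10",
   "metadata": {},
   "source": [
    "Because `task` is declared with `nargs='?'`, it may be omitted on the command\n",
    "line and then falls back to `TASK`. The trimmed-down `demo_parser` below is a\n",
    "hypothetical reproduction of just two of the options above (not the full\n",
    "parser) that shows how a couple of command lines are parsed. Under IPython, the\n",
    "last cell parses an empty argument list, i.e. it uses all defaults.\n",
    "\n",
    "```python\n",
    "from argparse import ArgumentParser\n",
    "\n",
    "demo_parser = ArgumentParser()\n",
    "demo_parser.add_argument('-n', '--num-bits', default=None, type=int)\n",
    "demo_parser.add_argument('task', default='cola', nargs='?',\n",
    "                         choices=['cola', 'mrpc', 'rte', 'stsb'])\n",
    "\n",
    "# No arguments: every option keeps its default (as in the IPython branch).\n",
    "print(vars(demo_parser.parse_args([])))\n",
    "# {'num_bits': None, 'task': 'cola'}\n",
    "\n",
    "# Explicit task and a 3-bit GELU approximation.\n",
    "print(vars(demo_parser.parse_args(['rte', '-n', '3'])))\n",
    "# {'num_bits': 3, 'task': 'rte'}\n",
    "```"
   ]
  },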
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09e63da5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metric(task, metric, inputs):\n",
    "    predictions, references = inputs\n",
    "    if task != 'stsb':\n",
    "        predictions = predictions.argmax(axis=1)\n",
    "    else:\n",
    "        predictions = predictions[..., 0]\n",
    "    return metric.compute(predictions=predictions, references=references)\n",
    "\n",
    "\n",
    "def preprocess(tokenizer, lhs, rhs, sample):\n",
    "    if rhs is None:\n",
    "        args = (sample[lhs],)\n",
    "    else:\n",
    "        args = (sample[lhs], sample[rhs])\n",
    "    return tokenizer(*args,\n",
    "                     max_length=512,\n",
    "                     padding=True,\n",
    "                     truncation=True,\n",
    "                     return_tensors='np')\n",
    "\n",
    "\n",
    "def setup(task: str,\n",
    "          cache_dir: Path = Path('cache'),\n",
    "          data_dir: Path = Path('data'),\n",
    "          model_dir: Path = Path('model'),\n",
    "          callback: Optional[TrainerCallback] = None):\n",
    "    # Load and configure model output head.\n",
    "    if task in ('mnli', 'mnli-mm'):\n",
    "        num_labels = 3\n",
    "    elif task == 'stsb':\n",
    "        num_labels = 1\n",
    "    else:\n",
    "        num_labels = 2\n",
    "    model_path = 'roberta-base'\n",
    "    model = Model.from_pretrained(model_path, num_labels=num_labels)\n",
    "\n",
    "    # Load tokenizer from checkpoint.\n",
    "    tokenizer = Tokenizer.from_pretrained(model_path)\n",
    "\n",
    "    # Make dataset preprocessor.\n",
    "    keys = TASK_TO_KEYS[task]\n",
    "    func = partial(preprocess, tokenizer, *keys)\n",
    "\n",
    "    # Load and preprocess dataset.\n",
    "    dataset_path = 'glue'\n",
    "    dataset_name = 'mnli' if task == 'mnli-mm' else task\n",
    "    dataset = load_dataset(dataset_path, dataset_name, cache_dir=str(data_dir))\n",
    "    dataset_cache = {key: str(cache_dir / f'glue-{task}-{key}.arrow')\n",
    "                     for key in dataset.keys()}\n",
    "    dataset_encoded = dataset.map(func,\n",
    "                                  batched=True,\n",
    "                                  cache_file_names=dataset_cache)\n",
    "\n",
    "    # Load dataset metric.\n",
    "    metric = load_metric(dataset_path, dataset_name)\n",
    "    metric_compute = partial(compute_metric, task, metric)\n",
    "\n",
    "    # Pick right evaluation metric.\n",
    "    eval_metric_name = 'accuracy'\n",
    "    if task == 'cola':\n",
    "        eval_metric_name = 'matthews_correlation'\n",
    "    elif task == 'stsb':\n",
    "        eval_metric_name = 'pearson'\n",
    "\n",
    "    # Pick right dataset for train/evaluation stage.\n",
    "    dataset_train = dataset_encoded['train']\n",
    "    dataset_eval = dataset_encoded.get('validation')\n",
    "    if task == 'mnli-mm':\n",
    "        dataset_eval = dataset_encoded['validation_mismatched']\n",
    "    elif task == 'mnli':\n",
    "        dataset_eval = dataset_encoded['validation_matched']\n",
    "\n",
    "    # Get hyperparameters from task name.\n",
    "    bs, lr = TASK_TO_HYPERPARAMS[task]\n",
    "\n",
    "    # Use 6% of the total number of training steps as warm-up steps.\n",
    "    noepoches = 10\n",
    "    warmup_steps = int(0.06 * len(dataset_train) * noepoches / bs)\n",
    "\n",
    "    # Initialize training driver.\n",
    "    args = TrainingArguments(output_dir=str(model_dir / f'glue-{task}'),\n",
    "                             save_strategy='epoch',\n",
    "                             evaluation_strategy='epoch',\n",
    "                             per_device_train_batch_size=bs,\n",
    "                             per_device_eval_batch_size=bs,\n",
    "                             num_train_epochs=noepoches,\n",
    "                             load_best_model_at_end=True,\n",
    "                             metric_for_best_model=eval_metric_name,\n",
    "                             logging_strategy='epoch',\n",
    "                             log_level='warning',\n",
    "                             learning_rate=lr,\n",
    "                             weight_decay=0.1,\n",
    "                             adam_beta1=0.9,\n",
    "                             adam_beta2=0.98,\n",
    "                             adam_epsilon=1e-6,\n",
    "                             lr_scheduler_type='polynomial',\n",
    "                             warmup_steps=warmup_steps,\n",
    "                             push_to_hub=False)\n",
    "\n",
    "    trainer = Trainer(model=model.to(DEVICE),\n",
    "                      args=args,\n",
    "                      train_dataset=dataset_train,\n",
    "                      eval_dataset=dataset_eval,\n",
    "                      tokenizer=tokenizer,\n",
    "                      compute_metrics=metric_compute,\n",
    "                      callbacks=[callback] if callback is not None else None)\n",
    "\n",
    "    return trainer"
   ]
  },
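  {
   "cell_type": "markdown",
   "id": "d4b81e6a",
   "metadata": {},
   "source": [
    "`compute_metric` turns raw model outputs into what the GLUE metric expects:\n",
    "classification heads return one logit per class, so the predicted label is the\n",
    "arg-max over classes, while the STS-B head (`num_labels = 1`) returns a single\n",
    "column that is used directly as the regression score. The NumPy sketch below\n",
    "uses made-up logits to show both branches without loading a metric.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_predictions(task, logits):\n",
    "    # Same branching as `compute_metric`, minus the call to `metric.compute`.\n",
    "    if task != 'stsb':\n",
    "        return logits.argmax(axis=1)  # class index per example\n",
    "    return logits[..., 0]             # drop the trailing singleton axis\n",
    "\n",
    "\n",
    "cls_logits = np.array([[0.2, 1.3], [2.0, -0.5], [0.1, 0.4]])\n",
    "reg_logits = np.array([[3.8], [1.2]])\n",
    "\n",
    "print(to_predictions('cola', cls_logits))  # [1 0 1]\n",
    "print(to_predictions('stsb', reg_logits))  # [3.8 1.2]\n",
    "```"
   ]
  },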
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ecbee2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(task: str, cache_dir: Path, data_dir: Path, log_dir: Path,\n",
    "          model_dir: Path, num_bits: Optional[int], seed: int):\n",
    "    if num_bits:\n",
    "        # NOTE Monkey patching of HuggingFace's transformers v4.12.5 in order\n",
    "        # to replace the standard GELU with our few-bit GELU approximation\n",
    "        # (`num_bits` bits). Another possible solution is subclassing the\n",
    "        # model and its configuration.\n",
    "        #\n",
    "        # TODO Now we can use fewbit.util.map_module to transform the RoBERTa\n",
    "        # model and replace its GELU with ours.\n",
    "        import fewbit.functional as F\n",
    "        import transformers.activations\n",
    "        gelu = partial(F.gelu, bits=num_bits)\n",
    "        transformers.activations.ACT2FN['gelu'] = gelu\n",
    "        transformers.activations.gelu = gelu\n",
    "\n",
    "    makedirs(cache_dir, exist_ok=True)\n",
    "    makedirs(log_dir, exist_ok=True)\n",
    "    makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "    manual_seed(seed)\n",
    "\n",
    "    tensorboard_sm = SummaryWriter(log_dir / task)\n",
    "    tensorboard_cb = TensorBoardCallback(tensorboard_sm)\n",
    "\n",
    "    trainer = setup(task, cache_dir, data_dir, model_dir, tensorboard_cb)\n",
    "    trainer.train()\n",
    "\n",
    "    tensorboard_sm.flush()\n",
    "    tensorboard_sm.close()\n",
    "\n",
    "    return trainer"
   ]
  },
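  {
   "cell_type": "markdown",
   "id": "e6c2a7f5",
   "metadata": {},
   "source": [
    "The patch in `train` works only because it runs before `setup` builds the\n",
    "model: RoBERTa's layers look up their activation in `ACT2FN` when they are\n",
    "constructed, so a model created before the patch keeps the original GELU. The\n",
    "self-contained sketch below reproduces that timing with a hypothetical registry\n",
    "(`ACTIVATIONS`) and layer class (`Layer`) instead of the real `transformers`\n",
    "objects; the tanh approximation merely stands in for fewbit's GELU.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def exact_gelu(x):\n",
    "    # Reference GELU: x * Phi(x), with Phi the standard normal CDF.\n",
    "    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "def patched_gelu(x):\n",
    "    # Stand-in replacement: the common tanh approximation of GELU.\n",
    "    c = math.sqrt(2.0 / math.pi)\n",
    "    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))\n",
    "\n",
    "\n",
    "ACTIVATIONS = {'gelu': exact_gelu}  # hypothetical stand-in for ACT2FN\n",
    "\n",
    "\n",
    "class Layer:\n",
    "    def __init__(self, act='gelu'):\n",
    "        # The lookup happens once, at construction time.\n",
    "        self.act_fn = ACTIVATIONS[act]\n",
    "\n",
    "\n",
    "before = Layer()\n",
    "ACTIVATIONS['gelu'] = patched_gelu  # the monkey patch\n",
    "after = Layer()\n",
    "\n",
    "print(before.act_fn is exact_gelu)   # True\n",
    "print(after.act_fn is patched_gelu)  # True\n",
    "```"
   ]
  },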
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbc53c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check whether we are running under an IPython kernel.\n",
    "if getattr(builtins, '__IPYTHON__', False):\n",
    "    args = parser.parse_args(args=[])\n",
    "else:\n",
    "    args = parser.parse_args()\n",
    "\n",
    "# Run training finally!\n",
    "train(**args.__dict__)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
