{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "<a href=\"https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ],
   "id": "c2b23870bed9245e"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we are going to fine-tune the LayoutLM model by Microsoft Research on the [FUNSD](https://guillaumejaume.github.io/FUNSD/) dataset, which is a collection of annotated form documents. The goal of our model is to learn the annotations of a number of labels (\"question\", \"answer\", \"header\" and \"other\") on those forms, such that it can be used to annotate unseen forms in the future.\n",
    "\n",
    "* Original LayoutLM paper: https://arxiv.org/abs/1912.13318\n",
    "\n",
    "* Original FUNSD paper: https://arxiv.org/abs/1905.13538\n"
   ],
   "id": "8bdadcf9ea4de59f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Install libraries\n",
    "\n",
    "Currently you have to first install the `unilm` package, and then the `transformers` package (which updates the outdated `transformers` package that is included in the `unilm` package). The reason we also install the `unilm` package is because we need its preprocessing files. I've forked it, and removed some statements which introduced some issues."
   ],
   "id": "2d63baefdfbb7bd3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# ! rm -r unilm\n",
    "# ! pip install unilm"
   ],
   "id": "c8f8d275365f4521"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Getting the data\n",
    "\n",
    "Here we download the data of the [FUNSD dataset](https://guillaumejaume.github.io/FUNSD/) from the web. This results in a directory called \"data\" being created, which has 2 subdirectories, one for training and one for testing. Each of those has 2 subdirectories in turn, one containing the images as png files and one containing the annotations in json format."
   ],
   "id": "5cc8a78d82074003"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# ! wget https://guillaumejaume.github.io/FUNSD/dataset.zip\n",
    "# ! unzip dataset.zip && mv dataset data && rm -rf dataset.zip __MACOSX"
   ],
   "id": "11b0b0a63d5b14f0"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's take a look at a training example. For this, we are going to use PIL (Python Image Library).",
   "id": "4872c49115baead0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import os\n",
    "\n",
    "base_path = \"/home/sourab/temp/data/dataset\"\n",
    "\n",
    "image = Image.open(os.path.join(base_path, \"training_data/images/0000971160.png\"))\n",
    "image = image.convert(\"RGB\")\n",
    "image"
   ],
   "id": "a30b57478c45694b"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Now let's plot its corresponding annotations. Basically, if you type `data['form']`, you get a list of all general annotations. Each general annotation has a label, a bounding box, and one or more words, which in also have their own bounding box. The bounding boxes are in [xleft, ytop, xright, ybottom] format.\n",
    " "
   ],
   "id": "8d7ffbe8076a8bd9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import json\n",
    "\n",
    "with open(os.path.join(base_path, \"training_data/annotations/0000971160.json\")) as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "for annotation in data[\"form\"]:\n",
    "    print(annotation)"
   ],
   "id": "d033488a5f96ee91"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The PIL library has a handy ImageDraw module, which -you guessed it- allows to draw things (such as rectangles) on an image:",
   "id": "19d855a40766ea0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "draw = ImageDraw.Draw(image, \"RGBA\")\n",
    "\n",
    "font = ImageFont.load_default()\n",
    "\n",
    "label2color = {\"question\": \"blue\", \"answer\": \"green\", \"header\": \"orange\", \"other\": \"violet\"}\n",
    "\n",
    "for annotation in data[\"form\"]:\n",
    "    label = annotation[\"label\"]\n",
    "    general_box = annotation[\"box\"]\n",
    "    draw.rectangle(general_box, outline=label2color[label], width=2)\n",
    "    draw.text((general_box[0] + 10, general_box[1] - 10), label, fill=label2color[label], font=font)\n",
    "    words = annotation[\"words\"]\n",
    "    for word in words:\n",
    "        box = word[\"box\"]\n",
    "        draw.rectangle(box, outline=label2color[label], width=1)\n",
    "\n",
    "image"
   ],
   "id": "89f8d95722121ec9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Preprocessing the data\n",
    "\n",
    "Next, we need to turn the document images into individual tokens and corresponding labels (BIOES format, see further). We do this both for the training and test datasets. Make sure to run this from the `/content` directory:"
   ],
   "id": "31afa62425bc01b8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# ! python unilm/layoutlm/examples/seq_labeling/preprocess.py --data_dir data/dataset/training_data/annotations \\\n",
    "#                                                       --data_split train \\\n",
    "#                                                       --output_dir data \\\n",
    "#                                                       --model_name_or_path microsoft/layoutlm-base-uncased \\\n",
    "#                                                       --max_len 510\n",
    "\n",
    "# ! python unilm/layoutlm/examples/seq_labeling/preprocess.py --data_dir data/dataset/testing_data/annotations \\\n",
    "#                                                       --data_split test \\\n",
    "#                                                       --output_dir data \\\n",
    "#                                                       --model_name_or_path microsoft/layoutlm-base-uncased \\\n",
    "#                                                       --max_len 510"
   ],
   "id": "4cdc3cb54b7d1f69"
  },
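  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The preprocessing script also rescales every bounding box from pixels to the 0-1000 range that LayoutLM expects, dividing by the page width and height (the `InputFeatures` class further below asserts this range). A minimal sketch of that conversion, using a hypothetical helper `normalize_box` that is not part of the script:",
   "id": "9e1f0c6a2b7d4e31"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def normalize_box(box, width, height):\n",
    "    \"\"\"Rescale a pixel box [x0, y0, x1, y1] to the 0-1000 range used by LayoutLM (illustrative sketch).\"\"\"\n",
    "    x0, y0, x1, y1 = box\n",
    "    return [\n",
    "        int(1000 * x0 / width),\n",
    "        int(1000 * y0 / height),\n",
    "        int(1000 * x1 / width),\n",
    "        int(1000 * y1 / height),\n",
    "    ]\n",
    "\n",
    "\n",
    "# Example: a box on a hypothetical 762 x 1000 pixel page\n",
    "print(normalize_box([381, 100, 762, 250], width=762, height=1000))  # [500, 100, 1000, 250]"
   ],
   "id": "9e1f0c6a2b7d4e32"
  },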
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Next, we create a labels.txt file that contains the unique labels of the FUNSD dataset:",
   "id": "5d2dc30fecd2dd90"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "# ! cat data/train.txt | cut -d$'\\t' -f 2 | grep -v \"^$\"| sort | uniq > data/labels.txt",
   "id": "37e21a6820382423"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Define a PyTorch dataset\n",
    "\n",
    "First, we create a list containing the unique labels based on `data/labels.txt` (run this from the content directory):"
   ],
   "id": "17881c19f3036bbf"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "\n",
    "def get_labels(path):\n",
    "    with open(path, \"r\") as f:\n",
    "        labels = f.read().splitlines()\n",
    "    if \"O\" not in labels:\n",
    "        labels = [\"O\"] + labels\n",
    "    return labels\n",
    "\n",
    "\n",
    "labels = get_labels(\"data/labels.txt\")\n",
    "num_labels = len(labels)\n",
    "label_map = {i: label for i, label in enumerate(labels)}\n",
    "# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later\n",
    "pad_token_label_id = CrossEntropyLoss().ignore_index"
   ],
   "id": "d15bebe3fbca00a4"
  },
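  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Why use the loss's `ignore_index` (-100 in PyTorch) as the padding label? `CrossEntropyLoss` skips every position whose target equals `ignore_index`, both when summing and when averaging. As a result, padding tokens, special tokens and the extra word pieces of a word do not affect training. Here is a quick check with random logits and made-up targets:",
   "id": "4b7c2d9e1a0f6c51"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "loss_fct = CrossEntropyLoss()\n",
    "print(loss_fct.ignore_index)  # -100\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(4, 3)  # 4 token positions, 3 hypothetical label classes\n",
    "targets = torch.tensor([2, 0, -100, -100])  # the last two positions act as padding\n",
    "\n",
    "# The loss over all 4 positions equals the loss over only the 2 real ones\n",
    "loss_all = loss_fct(logits, targets)\n",
    "loss_real = loss_fct(logits[:2], targets[:2])\n",
    "print(torch.allclose(loss_all, loss_real))  # True"
   ],
   "id": "4b7c2d9e1a0f6c52"
  },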
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can see that the dataset uses the so-called BIOES annotation scheme to annotate the tokens. This means that a given token can be either at the beginning (B), inside (I), outside (O), at the end (E) or start (S) of a given entity. Entities include ANSWER, QUESTION, HEADER and OTHER: ",
   "id": "3fe4ba0728125e84"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "print(labels)",
   "id": "f8be9f1958183c7b"
  },
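  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "To make the scheme concrete, here is a sketch of how the words of one FUNSD annotation map to BIOES tags. The helper `to_bioes` is our own illustration of the scheme described above. It is not code taken from the preprocessing script:",
   "id": "7d3a5e8c0b2f1d61"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def to_bioes(words, label):\n",
    "    \"\"\"Return one BIOES tag per word of a single annotated entity (illustrative sketch).\"\"\"\n",
    "    if not words:\n",
    "        return []\n",
    "    if label == \"other\":\n",
    "        return [\"O\"] * len(words)  # \"other\" words are outside any entity\n",
    "    entity = label.upper()\n",
    "    if len(words) == 1:\n",
    "        return [\"S-\" + entity]  # single-word entity\n",
    "    return [\"B-\" + entity] + [\"I-\" + entity] * (len(words) - 2) + [\"E-\" + entity]\n",
    "\n",
    "\n",
    "print(to_bioes([\"Date:\"], \"question\"))  # ['S-QUESTION']\n",
    "print(to_bioes([\"March\", \"14,\", \"1994\"], \"answer\"))  # ['B-ANSWER', 'I-ANSWER', 'E-ANSWER']\n",
    "print(to_bioes([\"Page\", \"2\"], \"other\"))  # ['O', 'O']"
   ],
   "id": "7d3a5e8c0b2f1d62"
  },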
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Next, we can create a PyTorch dataset and corresponding dataloader (both for training and evaluation):",
   "id": "4843714f42ee2a07"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import logging\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "class FunsdDataset(Dataset):\n",
    "    def __init__(self, args, tokenizer, labels, pad_token_label_id, mode):\n",
    "        if args.local_rank not in [-1, 0] and mode == \"train\":\n",
    "            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache\n",
    "\n",
    "        # Load data features from cache or dataset file\n",
    "        cached_features_file = os.path.join(\n",
    "            args.data_dir,\n",
    "            \"cached_{}_{}_{}\".format(\n",
    "                mode,\n",
    "                list(filter(None, args.model_name_or_path.split(\"/\"))).pop(),\n",
    "                str(args.max_seq_length),\n",
    "            ),\n",
    "        )\n",
    "        if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
    "            logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
    "            features = torch.load(cached_features_file)\n",
    "        else:\n",
    "            logger.info(\"Creating features from dataset file at %s\", args.data_dir)\n",
    "            examples = read_examples_from_file(args.data_dir, mode)\n",
    "            features = convert_examples_to_features(\n",
    "                examples,\n",
    "                labels,\n",
    "                args.max_seq_length,\n",
    "                tokenizer,\n",
    "                cls_token_at_end=bool(args.model_type in [\"xlnet\"]),\n",
    "                # xlnet has a cls token at the end\n",
    "                cls_token=tokenizer.cls_token,\n",
    "                cls_token_segment_id=2 if args.model_type in [\"xlnet\"] else 0,\n",
    "                sep_token=tokenizer.sep_token,\n",
    "                sep_token_extra=bool(args.model_type in [\"roberta\"]),\n",
    "                # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805\n",
    "                pad_on_left=bool(args.model_type in [\"xlnet\"]),\n",
    "                # pad on the left for xlnet\n",
    "                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],\n",
    "                pad_token_segment_id=4 if args.model_type in [\"xlnet\"] else 0,\n",
    "                pad_token_label_id=pad_token_label_id,\n",
    "            )\n",
    "            # if args.local_rank in [-1, 0]:\n",
    "            # logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
    "            # torch.save(features, cached_features_file)\n",
    "\n",
    "        if args.local_rank == 0 and mode == \"train\":\n",
    "            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache\n",
    "\n",
    "        self.features = features\n",
    "        # Convert to Tensors and build dataset\n",
    "        self.all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
    "        self.all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n",
    "        self.all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)\n",
    "        self.all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)\n",
    "        self.all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.features)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return (\n",
    "            self.all_input_ids[index],\n",
    "            self.all_input_mask[index],\n",
    "            self.all_segment_ids[index],\n",
    "            self.all_label_ids[index],\n",
    "            self.all_bboxes[index],\n",
    "        )\n",
    "\n",
    "\n",
    "class InputExample(object):\n",
    "    \"\"\"A single training/test example for token classification.\"\"\"\n",
    "\n",
    "    def __init__(self, guid, words, labels, boxes, actual_bboxes, file_name, page_size):\n",
    "        \"\"\"Constructs a InputExample.\n",
    "\n",
    "        Args:\n",
    "            guid: Unique id for the example.\n",
    "            words: list. The words of the sequence.\n",
    "            labels: (Optional) list. The labels for each word of the sequence. This should be\n",
    "            specified for train and dev examples, but not for test examples.\n",
    "        \"\"\"\n",
    "        self.guid = guid\n",
    "        self.words = words\n",
    "        self.labels = labels\n",
    "        self.boxes = boxes\n",
    "        self.actual_bboxes = actual_bboxes\n",
    "        self.file_name = file_name\n",
    "        self.page_size = page_size\n",
    "\n",
    "\n",
    "class InputFeatures(object):\n",
    "    \"\"\"A single set of features of data.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            input_ids,\n",
    "            input_mask,\n",
    "            segment_ids,\n",
    "            label_ids,\n",
    "            boxes,\n",
    "            actual_bboxes,\n",
    "            file_name,\n",
    "            page_size,\n",
    "    ):\n",
    "        assert (\n",
    "                0 <= all(boxes) <= 1000\n",
    "        ), \"Error with input bbox ({}): the coordinate value is not between 0 and 1000\".format(boxes)\n",
    "        self.input_ids = input_ids\n",
    "        self.input_mask = input_mask\n",
    "        self.segment_ids = segment_ids\n",
    "        self.label_ids = label_ids\n",
    "        self.boxes = boxes\n",
    "        self.actual_bboxes = actual_bboxes\n",
    "        self.file_name = file_name\n",
    "        self.page_size = page_size\n",
    "\n",
    "\n",
    "def read_examples_from_file(data_dir, mode):\n",
    "    file_path = os.path.join(data_dir, \"{}.txt\".format(mode))\n",
    "    box_file_path = os.path.join(data_dir, \"{}_box.txt\".format(mode))\n",
    "    image_file_path = os.path.join(data_dir, \"{}_image.txt\".format(mode))\n",
    "    guid_index = 1\n",
    "    examples = []\n",
    "    with open(file_path, encoding=\"utf-8\") as f, open(box_file_path, encoding=\"utf-8\") as fb, open(\n",
    "            image_file_path, encoding=\"utf-8\"\n",
    "    ) as fi:\n",
    "        words = []\n",
    "        boxes = []\n",
    "        actual_bboxes = []\n",
    "        file_name = None\n",
    "        page_size = None\n",
    "        labels = []\n",
    "        for line, bline, iline in zip(f, fb, fi):\n",
    "            if line.startswith(\"-DOCSTART-\") or line == \"\" or line == \"\\n\":\n",
    "                if words:\n",
    "                    examples.append(\n",
    "                        InputExample(\n",
    "                            guid=\"{}-{}\".format(mode, guid_index),\n",
    "                            words=words,\n",
    "                            labels=labels,\n",
    "                            boxes=boxes,\n",
    "                            actual_bboxes=actual_bboxes,\n",
    "                            file_name=file_name,\n",
    "                            page_size=page_size,\n",
    "                        )\n",
    "                    )\n",
    "                    guid_index += 1\n",
    "                    words = []\n",
    "                    boxes = []\n",
    "                    actual_bboxes = []\n",
    "                    file_name = None\n",
    "                    page_size = None\n",
    "                    labels = []\n",
    "            else:\n",
    "                splits = line.split(\"\\t\")\n",
    "                bsplits = bline.split(\"\\t\")\n",
    "                isplits = iline.split(\"\\t\")\n",
    "                assert len(splits) == 2\n",
    "                assert len(bsplits) == 2\n",
    "                assert len(isplits) == 4\n",
    "                assert splits[0] == bsplits[0]\n",
    "                words.append(splits[0])\n",
    "                if len(splits) > 1:\n",
    "                    labels.append(splits[-1].replace(\"\\n\", \"\"))\n",
    "                    box = bsplits[-1].replace(\"\\n\", \"\")\n",
    "                    box = [int(b) for b in box.split()]\n",
    "                    boxes.append(box)\n",
    "                    actual_bbox = [int(b) for b in isplits[1].split()]\n",
    "                    actual_bboxes.append(actual_bbox)\n",
    "                    page_size = [int(i) for i in isplits[2].split()]\n",
    "                    file_name = isplits[3].strip()\n",
    "                else:\n",
    "                    # Examples could have no label for mode = \"test\"\n",
    "                    labels.append(\"O\")\n",
    "        if words:\n",
    "            examples.append(\n",
    "                InputExample(\n",
    "                    guid=\"%s-%d\".format(mode, guid_index),\n",
    "                    words=words,\n",
    "                    labels=labels,\n",
    "                    boxes=boxes,\n",
    "                    actual_bboxes=actual_bboxes,\n",
    "                    file_name=file_name,\n",
    "                    page_size=page_size,\n",
    "                )\n",
    "            )\n",
    "    return examples\n",
    "\n",
    "\n",
    "def convert_examples_to_features(\n",
    "        examples,\n",
    "        label_list,\n",
    "        max_seq_length,\n",
    "        tokenizer,\n",
    "        cls_token_at_end=False,\n",
    "        cls_token=\"[CLS]\",\n",
    "        cls_token_segment_id=1,\n",
    "        sep_token=\"[SEP]\",\n",
    "        sep_token_extra=False,\n",
    "        pad_on_left=False,\n",
    "        pad_token=0,\n",
    "        cls_token_box=[0, 0, 0, 0],\n",
    "        sep_token_box=[1000, 1000, 1000, 1000],\n",
    "        pad_token_box=[0, 0, 0, 0],\n",
    "        pad_token_segment_id=0,\n",
    "        pad_token_label_id=-1,\n",
    "        sequence_a_segment_id=0,\n",
    "        mask_padding_with_zero=True,\n",
    "):\n",
    "    \"\"\"Loads a data file into a list of `InputBatch`s\n",
    "    `cls_token_at_end` define the location of the CLS token:\n",
    "        - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]\n",
    "        - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]\n",
    "    `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)\n",
    "    \"\"\"\n",
    "\n",
    "    label_map = {label: i for i, label in enumerate(label_list)}\n",
    "\n",
    "    features = []\n",
    "    for ex_index, example in enumerate(examples):\n",
    "        file_name = example.file_name\n",
    "        page_size = example.page_size\n",
    "        width, height = page_size\n",
    "        if ex_index % 10000 == 0:\n",
    "            logger.info(\"Writing example %d of %d\", ex_index, len(examples))\n",
    "\n",
    "        tokens = []\n",
    "        token_boxes = []\n",
    "        actual_bboxes = []\n",
    "        label_ids = []\n",
    "        for word, label, box, actual_bbox in zip(example.words, example.labels, example.boxes, example.actual_bboxes):\n",
    "            word_tokens = tokenizer.tokenize(word)\n",
    "            tokens.extend(word_tokens)\n",
    "            token_boxes.extend([box] * len(word_tokens))\n",
    "            actual_bboxes.extend([actual_bbox] * len(word_tokens))\n",
    "            # Use the real label id for the first token of the word, and padding ids for the remaining tokens\n",
    "            label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))\n",
    "\n",
    "        # Account for [CLS] and [SEP] with \"- 2\" and with \"- 3\" for RoBERTa.\n",
    "        special_tokens_count = 3 if sep_token_extra else 2\n",
    "        if len(tokens) > max_seq_length - special_tokens_count:\n",
    "            tokens = tokens[: (max_seq_length - special_tokens_count)]\n",
    "            token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]\n",
    "            actual_bboxes = actual_bboxes[: (max_seq_length - special_tokens_count)]\n",
    "            label_ids = label_ids[: (max_seq_length - special_tokens_count)]\n",
    "\n",
    "        # The convention in BERT is:\n",
    "        # (a) For sequence pairs:\n",
    "        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]\n",
    "        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1\n",
    "        # (b) For single sequences:\n",
    "        #  tokens:   [CLS] the dog is hairy . [SEP]\n",
    "        #  type_ids:   0   0   0   0  0     0   0\n",
    "        #\n",
    "        # Where \"type_ids\" are used to indicate whether this is the first\n",
    "        # sequence or the second sequence. The embedding vectors for `type=0` and\n",
    "        # `type=1` were learned during pre-training and are added to the wordpiece\n",
    "        # embedding vector (and position vector). This is not *strictly* necessary\n",
    "        # since the [SEP] token unambiguously separates the sequences, but it makes\n",
    "        # it easier for the model to learn the concept of sequences.\n",
    "        #\n",
    "        # For classification tasks, the first vector (corresponding to [CLS]) is\n",
    "        # used as as the \"sentence vector\". Note that this only makes sense because\n",
    "        # the entire model is fine-tuned.\n",
    "        tokens += [sep_token]\n",
    "        token_boxes += [sep_token_box]\n",
    "        actual_bboxes += [[0, 0, width, height]]\n",
    "        label_ids += [pad_token_label_id]\n",
    "        if sep_token_extra:\n",
    "            # roberta uses an extra separator b/w pairs of sentences\n",
    "            tokens += [sep_token]\n",
    "            token_boxes += [sep_token_box]\n",
    "            actual_bboxes += [[0, 0, width, height]]\n",
    "            label_ids += [pad_token_label_id]\n",
    "        segment_ids = [sequence_a_segment_id] * len(tokens)\n",
    "\n",
    "        if cls_token_at_end:\n",
    "            tokens += [cls_token]\n",
    "            token_boxes += [cls_token_box]\n",
    "            actual_bboxes += [[0, 0, width, height]]\n",
    "            label_ids += [pad_token_label_id]\n",
    "            segment_ids += [cls_token_segment_id]\n",
    "        else:\n",
    "            tokens = [cls_token] + tokens\n",
    "            token_boxes = [cls_token_box] + token_boxes\n",
    "            actual_bboxes = [[0, 0, width, height]] + actual_bboxes\n",
    "            label_ids = [pad_token_label_id] + label_ids\n",
    "            segment_ids = [cls_token_segment_id] + segment_ids\n",
    "\n",
    "        input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "        # The mask has 1 for real tokens and 0 for padding tokens. Only real\n",
    "        # tokens are attended to.\n",
    "        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)\n",
    "\n",
    "        # Zero-pad up to the sequence length.\n",
    "        padding_length = max_seq_length - len(input_ids)\n",
    "        if pad_on_left:\n",
    "            input_ids = ([pad_token] * padding_length) + input_ids\n",
    "            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask\n",
    "            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids\n",
    "            label_ids = ([pad_token_label_id] * padding_length) + label_ids\n",
    "            token_boxes = ([pad_token_box] * padding_length) + token_boxes\n",
    "        else:\n",
    "            input_ids += [pad_token] * padding_length\n",
    "            input_mask += [0 if mask_padding_with_zero else 1] * padding_length\n",
    "            segment_ids += [pad_token_segment_id] * padding_length\n",
    "            label_ids += [pad_token_label_id] * padding_length\n",
    "            token_boxes += [pad_token_box] * padding_length\n",
    "\n",
    "        assert len(input_ids) == max_seq_length\n",
    "        assert len(input_mask) == max_seq_length\n",
    "        assert len(segment_ids) == max_seq_length\n",
    "        assert len(label_ids) == max_seq_length\n",
    "        assert len(token_boxes) == max_seq_length\n",
    "\n",
    "        if ex_index < 5:\n",
    "            logger.info(\"*** Example ***\")\n",
    "            logger.info(\"guid: %s\", example.guid)\n",
    "            logger.info(\"tokens: %s\", \" \".join([str(x) for x in tokens]))\n",
    "            logger.info(\"input_ids: %s\", \" \".join([str(x) for x in input_ids]))\n",
    "            logger.info(\"input_mask: %s\", \" \".join([str(x) for x in input_mask]))\n",
    "            logger.info(\"segment_ids: %s\", \" \".join([str(x) for x in segment_ids]))\n",
    "            logger.info(\"label_ids: %s\", \" \".join([str(x) for x in label_ids]))\n",
    "            logger.info(\"boxes: %s\", \" \".join([str(x) for x in token_boxes]))\n",
    "            logger.info(\"actual_bboxes: %s\", \" \".join([str(x) for x in actual_bboxes]))\n",
    "\n",
    "        features.append(\n",
    "            InputFeatures(\n",
    "                input_ids=input_ids,\n",
    "                input_mask=input_mask,\n",
    "                segment_ids=segment_ids,\n",
    "                label_ids=label_ids,\n",
    "                boxes=token_boxes,\n",
    "                actual_bboxes=actual_bboxes,\n",
    "                file_name=file_name,\n",
    "                page_size=page_size,\n",
    "            )\n",
    "        )\n",
    "    return features"
   ],
   "id": "1d74364da270388a"
  },
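  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The central step of `convert_examples_to_features` is aligning word-level labels with word pieces. Each word is tokenized and every piece inherits the word's bounding box. Only the first piece keeps the word's label, and the remaining pieces get `pad_token_label_id`. The sketch below reproduces that logic with a hypothetical `toy_tokenize` function, which just cuts words into 4-character pieces, so that it runs without downloading a tokenizer:",
   "id": "2c8e6f1a9d4b3e71"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "PAD_LABEL_ID = -100  # same value as CrossEntropyLoss().ignore_index\n",
    "\n",
    "\n",
    "def toy_tokenize(word):\n",
    "    # Hypothetical stand-in for tokenizer.tokenize: cut the word into 4-character pieces\n",
    "    return [word[i : i + 4] if i == 0 else \"##\" + word[i : i + 4] for i in range(0, len(word), 4)]\n",
    "\n",
    "\n",
    "def align_labels(words, word_label_ids, word_boxes):\n",
    "    tokens, token_label_ids, token_boxes = [], [], []\n",
    "    for word, label_id, box in zip(words, word_label_ids, word_boxes):\n",
    "        pieces = toy_tokenize(word)\n",
    "        if not pieces:  # skip empty words so labels stay aligned with tokens\n",
    "            continue\n",
    "        tokens.extend(pieces)\n",
    "        token_boxes.extend([box] * len(pieces))  # every piece gets the word's box\n",
    "        # real label on the first piece, ignored label on the rest\n",
    "        token_label_ids.extend([label_id] + [PAD_LABEL_ID] * (len(pieces) - 1))\n",
    "    return tokens, token_label_ids, token_boxes\n",
    "\n",
    "\n",
    "tokens, token_label_ids, token_boxes = align_labels(\n",
    "    [\"Date:\", \"Tobacco\"], [1, 2], [[10, 10, 60, 20], [70, 10, 150, 20]]\n",
    ")\n",
    "print(tokens)  # ['Date', '##:', 'Toba', '##cco']\n",
    "print(token_label_ids)  # [1, -100, 2, -100]\n",
    "print(token_boxes)  # [[10, 10, 60, 20], [10, 10, 60, 20], [70, 10, 150, 20], [70, 10, 150, 20]]"
   ],
   "id": "2c8e6f1a9d4b3e72"
  },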
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import LayoutLMTokenizer\n",
    "\n",
    "# from .unilm.layoutlm.data.funsd import FunsdDataset, InputFeatures\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "batch_size = 16\n",
    "args = {\n",
    "    \"local_rank\": -1,\n",
    "    \"overwrite_cache\": True,\n",
    "    \"data_dir\": \"/home/sourab/temp/data/\",\n",
    "    \"model_name_or_path\": \"microsoft/layoutlm-base-uncased\",\n",
    "    \"max_seq_length\": 512,\n",
    "    \"model_type\": \"layoutlm\",\n",
    "}\n",
    "\n",
    "\n",
    "# class to turn the keys of a dict into attributes (thanks Stackoverflow)\n",
    "class AttrDict(dict):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super(AttrDict, self).__init__(*args, **kwargs)\n",
    "        self.__dict__ = self\n",
    "\n",
    "\n",
    "args = AttrDict(args)\n",
    "\n",
    "tokenizer = LayoutLMTokenizer.from_pretrained(\"microsoft/layoutlm-base-uncased\")\n",
    "\n",
    "# the LayoutLM authors already defined a specific FunsdDataset, so we are going to use this here\n",
    "train_dataset = FunsdDataset(args, tokenizer, labels, pad_token_label_id, mode=\"train\")\n",
    "train_sampler = RandomSampler(train_dataset)\n",
    "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)\n",
    "\n",
    "eval_dataset = FunsdDataset(args, tokenizer, labels, pad_token_label_id, mode=\"test\")\n",
    "eval_sampler = SequentialSampler(eval_dataset)\n",
    "eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size)"
   ],
   "id": "271621212676891"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "len(train_dataloader)",
   "id": "b6f39d1a386cbb98"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "len(eval_dataloader)",
   "id": "5c481f1e2ac44eb9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "batch = next(iter(train_dataloader))\n",
    "input_ids = batch[0][0]\n",
    "tokenizer.decode(input_ids)"
   ],
   "id": "e5c0e9d812c84642"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Define and fine-tune the model\n",
    "\n",
    "As this is a sequence labeling task, we are going to load `LayoutLMForTokenClassification` (the base sized model) from the hub. We are going to fine-tune it on a downstream task, namely FUNSD."
   ],
   "id": "49c53929acc6b6ac"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import get_peft_config, PeftModel, get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.TOKEN_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias=\"all\"\n",
    ")\n",
    "peft_config"
   ],
   "id": "388ad990fc6c972a"
  },
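  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "For intuition about `r` and `lora_alpha`: LoRA keeps a pretrained weight matrix $W$ frozen and learns a low-rank update. An adapted linear layer therefore computes $h = Wx + \\frac{\\alpha}{r} BAx$, where $A$ has shape $(r, d_{in})$, $B$ has shape $(d_{out}, r)$ and $\\alpha$ is `lora_alpha`. $B$ is initialized to zero, so before training the adapted model behaves exactly like the pretrained one. With `r=16` and `lora_alpha=16` as above, the scaling factor is 1. The `LoRALinear` class below is a minimal illustration of this idea, not PEFT's implementation. It omits dropout and freezes the bias, whereas `bias=\"all\"` also trains biases:",
   "id": "6a0d4c7e3f9b2a81"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class LoRALinear(nn.Module):\n",
    "    \"\"\"A frozen linear layer plus a trainable low-rank update (illustrative sketch).\"\"\"\n",
    "\n",
    "    def __init__(self, base: nn.Linear, r: int = 16, lora_alpha: int = 16):\n",
    "        super().__init__()\n",
    "        self.base = base\n",
    "        for p in self.base.parameters():\n",
    "            p.requires_grad = False  # pretrained weight and bias stay frozen\n",
    "        self.lora_A = nn.Linear(base.in_features, r, bias=False)\n",
    "        self.lora_B = nn.Linear(r, base.out_features, bias=False)\n",
    "        nn.init.zeros_(self.lora_B.weight)  # the update B @ A starts at zero\n",
    "        self.scaling = lora_alpha / r\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "base = nn.Linear(768, 768)  # 768 = hidden size of the base-sized model\n",
    "lora = LoRALinear(base, r=16, lora_alpha=16)\n",
    "x = torch.randn(2, 768)\n",
    "print(torch.allclose(lora(x), base(x)))  # True, since B is all zeros\n",
    "trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)\n",
    "print(trainable)  # 2 * 16 * 768 = 24576, versus 768 * 768 + 768 frozen"
   ],
   "id": "6a0d4c7e3f9b2a82"
  },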
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import LayoutLMForTokenClassification\n",
    "import torch\n",
    "from transformers import set_seed\n",
    "\n",
    "seed = 100\n",
    "set_seed(seed)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = LayoutLMForTokenClassification.from_pretrained(\"microsoft/layoutlm-base-uncased\", num_labels=num_labels)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.to(device)"
   ],
   "id": "c64d23039deb0bc4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "print(model.model.layoutlm.encoder.layer[0].attention.self.query.weight)\n",
    "print(model.model.layoutlm.encoder.layer[0].attention.self.query.lora_A.weight)\n",
    "print(model.model.classifier.weight)"
   ],
   "id": "ec5bb9ab8d4be95f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now we can start training:",
   "id": "1e0a3b113cf2a217"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import AdamW, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "\n",
    "num_train_epochs = 100\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)\n",
    "\n",
    "t_total = len(train_dataloader) * num_train_epochs  # total number of training steps\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=int(0.06 * t_total),  # warm up over the first 6% of steps\n",
    "    num_training_steps=t_total,\n",
    ")\n",
    "\n",
    "global_step = 0\n",
    "\n",
    "# put the model in training mode\n",
    "model.train()\n",
    "for epoch in range(num_train_epochs):\n",
    "    for batch in tqdm(train_dataloader, desc=\"Training\"):\n",
    "        input_ids = batch[0].to(device)\n",
    "        bbox = batch[4].to(device)\n",
    "        attention_mask = batch[1].to(device)\n",
    "        token_type_ids = batch[2].to(device)\n",
    "        labels = batch[3].to(device)\n",
    "\n",
    "        # forward pass\n",
    "        outputs = model(\n",
    "            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels\n",
    "        )\n",
    "        loss = outputs.loss\n",
    "\n",
    "        # print loss every 10 steps\n",
    "        if global_step % 10 == 0:\n",
    "            print(f\"Loss after {global_step} steps: {loss.item()}\")\n",
    "\n",
    "        # backward pass to get the gradients\n",
    "        loss.backward()\n",
    "\n",
    "        # print(\"Gradients on classification head:\")\n",
    "        # print(model.classifier.weight.grad[6,:].sum())\n",
    "\n",
    "        # update\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        global_step += 1"
   ],
   "id": "1338eeacdc4348c9"
  },
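  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "`get_linear_schedule_with_warmup` scales the optimizer's base learning rate by a factor that rises linearly from 0 to 1 over `num_warmup_steps`. It then decays linearly back to 0 at `num_training_steps`. The pure-Python sketch below reproduces that multiplier, so you can inspect the schedule without building a model. The function name `linear_warmup_factor` is hypothetical. `total = 1000` is an illustrative step count, not the actual `len(train_dataloader) * num_train_epochs` for FUNSD.\n",
    "\n",
    "```python\n",
    "def linear_warmup_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:\n",
    "    # Warmup phase: rise linearly from 0 to 1\n",
    "    if step < num_warmup_steps:\n",
    "        return step / max(1, num_warmup_steps)\n",
    "    # Decay phase: fall linearly from 1 to 0, never below 0\n",
    "    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))\n",
    "\n",
    "\n",
    "total = 1000  # illustrative number of optimizer steps\n",
    "warmup = int(0.06 * total)  # 6% warmup, as in the training loop\n",
    "base_lr = 3e-3\n",
    "for step in (0, 30, 60, 530, 1000):\n",
    "    print(step, base_lr * linear_warmup_factor(step, warmup, total))\n",
    "```\n",
    "\n",
    "The peak learning rate of `3e-3` is therefore reached only at the end of warmup, and it is back at 0 after the final step."
   ],
   "id": "b4d2e8f1a7c90356"
  },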
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "from seqeval.metrics import (\n",
    "    classification_report,\n",
    "    f1_score,\n",
    "    precision_score,\n",
    "    recall_score,\n",
    ")\n",
    "\n",
    "eval_loss = 0.0\n",
    "nb_eval_steps = 0\n",
    "preds = None\n",
    "out_label_ids = None\n",
    "\n",
    "# put model in evaluation mode\n",
    "model.eval()\n",
    "for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
    "    with torch.no_grad():\n",
    "        input_ids = batch[0].to(device)\n",
    "        bbox = batch[4].to(device)\n",
    "        attention_mask = batch[1].to(device)\n",
    "        token_type_ids = batch[2].to(device)\n",
    "        labels = batch[3].to(device)\n",
    "\n",
    "        # forward pass\n",
    "        outputs = model(\n",
    "            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels\n",
    "        )\n",
    "        # get the loss and logits\n",
    "        tmp_eval_loss = outputs.loss\n",
    "        logits = outputs.logits\n",
    "\n",
    "        eval_loss += tmp_eval_loss.item()\n",
    "        nb_eval_steps += 1\n",
    "\n",
    "        # compute the predictions\n",
    "        if preds is None:\n",
    "            preds = logits.detach().cpu().numpy()\n",
    "            out_label_ids = labels.detach().cpu().numpy()\n",
    "        else:\n",
    "            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)\n",
    "            out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)\n",
    "\n",
    "# compute average evaluation loss\n",
    "eval_loss = eval_loss / nb_eval_steps\n",
    "preds = np.argmax(preds, axis=2)\n",
    "\n",
    "out_label_list = [[] for _ in range(out_label_ids.shape[0])]\n",
    "preds_list = [[] for _ in range(out_label_ids.shape[0])]\n",
    "\n",
    "for i in range(out_label_ids.shape[0]):\n",
    "    for j in range(out_label_ids.shape[1]):\n",
    "        if out_label_ids[i, j] != pad_token_label_id:\n",
    "            out_label_list[i].append(label_map[out_label_ids[i][j]])\n",
    "            preds_list[i].append(label_map[preds[i][j]])\n",
    "\n",
    "results = {\n",
    "    \"loss\": eval_loss,\n",
    "    \"precision\": precision_score(out_label_list, preds_list),\n",
    "    \"recall\": recall_score(out_label_list, preds_list),\n",
    "    \"f1\": f1_score(out_label_list, preds_list),\n",
    "}\n",
    "print(results)"
   ],
   "id": "189e2a98e6e3c453"
  },
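  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "`seqeval` scores at the entity level rather than per token. A predicted entity counts as correct only if both its type and its full span match a gold entity. The toy BIO sequences below are made up for illustration using the FUNSD entity types. Missing a single `I-QUESTION` tag turns the question into one false positive and one false negative. As a result, precision, recall and F1 all drop to 0.5, even though 3 of the 4 tokens are tagged correctly.\n",
    "\n",
    "```python\n",
    "from seqeval.metrics import f1_score, precision_score, recall_score\n",
    "\n",
    "# Toy gold and predicted tag sequences (one document each), made up for illustration\n",
    "y_true = [['B-QUESTION', 'I-QUESTION', 'O', 'B-ANSWER']]\n",
    "y_pred = [['B-QUESTION', 'O', 'O', 'B-ANSWER']]\n",
    "\n",
    "print(precision_score(y_true, y_pred))  # 1 of 2 predicted entities is correct\n",
    "print(recall_score(y_true, y_pred))  # 1 of 2 gold entities is found\n",
    "print(f1_score(y_true, y_pred))\n",
    "```"
   ],
   "id": "e91a5c3d7b2f4608"
  },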
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model.print_trainable_parameters()",
   "id": "5389ff66666b04a1"
  },
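  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A rank-`r` LoRA adapter on a linear layer with `d_in` inputs and `d_out` outputs adds `r * (d_in + d_out)` trainable weights, since `A` is `r x d_in` and `B` is `d_out x r`. The sketch below does that arithmetic for LayoutLM-base. It assumes a hidden size of 768, 12 encoder layers, and that PEFT's default LoRA targets for LayoutLM are the `query` and `value` projections. It ignores the biases made trainable by `bias=\"all\"` and the classification head, so it is only a lower bound on the count reported by `print_trainable_parameters()`.\n",
    "\n",
    "```python\n",
    "def lora_param_count(d_in: int, d_out: int, r: int) -> int:\n",
    "    # A has shape (r, d_in) and B has shape (d_out, r)\n",
    "    return r * (d_in + d_out)\n",
    "\n",
    "\n",
    "hidden, num_layers, r = 768, 12, 16\n",
    "targets_per_layer = 2  # assumed: the query and value projections of each layer\n",
    "lora_weights = num_layers * targets_per_layer * lora_param_count(hidden, hidden, r)\n",
    "print(lora_weights)  # 589824\n",
    "```"
   ],
   "id": "5d8b0f2c6e1a9374"
  },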
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model.save_pretrained(\"peft_layoutlm\")",
   "id": "280de51988086821"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "!du -h peft_layoutlm/adapter_model.*",
   "id": "8821f152f0d416bb"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
