{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load modules, mainly huggingface basic model handlers.\n",
    "# Make sure you install huggingface and other packages properly.\n",
    "from collections import Counter\n",
    "import json\n",
    "\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "\n",
    "import logging\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "import os\n",
    "os.environ[\"TRANSFORMERS_CACHE\"] = \"../huggingface_cache/\" # Not overload common dir \n",
    "                                                           # if run in shared resources.\n",
    "\n",
    "import random\n",
    "import sys\n",
    "from typing import Optional\n",
    "import torch\n",
    "import argparse\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_dataset, load_metric\n",
    "from datasets import Dataset\n",
    "from datasets import DatasetDict\n",
    "from datasets import ClassLabel\n",
    "\n",
    "import transformers\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoModelForTokenClassification,\n",
    "    AutoTokenizer,\n",
    "    EvalPrediction,\n",
    "    HfArgumentParser,\n",
    "    PretrainedConfig,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    set_seed,\n",
    "    EarlyStoppingCallback,\n",
    "    DataCollatorForTokenClassification,\n",
    "    PreTrainedTokenizerFast\n",
    ")\n",
    "from transformers.trainer_utils import is_main_process, EvaluationStrategy\n",
    "from functools import partial\n",
    "\n",
    "from vocab_mismatch_utils import *\n",
    "basic_tokenizer = ModifiedBasicTokenizer()\n",
    "from models.modeling_bert import CustomerizedBertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_training_args(args, inoculation_step):\n",
    "    training_args = TrainingArguments(\"tmp_trainer\")\n",
    "    training_args.no_cuda = args.no_cuda\n",
    "    training_args.seed = args.seed\n",
    "    training_args.do_train = args.do_train\n",
    "    training_args.do_eval = args.do_eval\n",
    "    training_args.output_dir = os.path.join(args.output_dir, str(inoculation_step)+\"-sample\")\n",
    "    training_args.evaluation_strategy = args.evaluation_strategy # evaluation is done after each epoch\n",
    "    training_args.metric_for_best_model = args.metric_for_best_model\n",
    "    training_args.greater_is_better = args.greater_is_better\n",
    "    training_args.logging_dir = args.logging_dir\n",
    "    training_args.task_name = args.task_name\n",
    "    training_args.learning_rate = args.learning_rate\n",
    "    training_args.per_device_train_batch_size = args.per_device_train_batch_size\n",
    "    training_args.per_device_eval_batch_size = args.per_device_eval_batch_size\n",
    "    training_args.num_train_epochs = args.num_train_epochs # this is the maximum num_train_epochs, we set this to be 100.\n",
    "    training_args.eval_steps = args.eval_steps\n",
    "    training_args.logging_steps = args.logging_steps\n",
    "    training_args.load_best_model_at_end = args.load_best_model_at_end\n",
    "    if args.save_total_limit != -1:\n",
    "        # only set if it is specified\n",
    "        training_args.save_total_limit = args.save_total_limit\n",
    "    import datetime\n",
    "    date_time = \"{}-{}\".format(datetime.datetime.now().month, datetime.datetime.now().day)\n",
    "    run_name = \"{0}_{1}_{2}_{3}_mlen_{4}_lr_{5}_seed_{6}_metrics_{7}\".format(\n",
    "        args.run_name,\n",
    "        args.task_name,\n",
    "        args.model_type,\n",
    "        date_time,\n",
    "        args.max_seq_length,\n",
    "        args.learning_rate,\n",
    "        args.seed,\n",
    "        args.metric_for_best_model\n",
    "    )\n",
    "    training_args.run_name = run_name\n",
    "    training_args_dict = training_args.to_dict()\n",
    "    # for PR\n",
    "    _n_gpu = training_args_dict[\"_n_gpu\"]\n",
    "    del training_args_dict[\"_n_gpu\"]\n",
    "    training_args_dict[\"n_gpu\"] = _n_gpu\n",
    "    HfParser = HfArgumentParser((TrainingArguments))\n",
    "    training_args = HfParser.parse_dict(training_args_dict)[0]\n",
    "\n",
    "    if args.model_path == \"\":\n",
    "        args.model_path = args.model_type\n",
    "        if args.model_type == \"\":\n",
    "            assert False # you have to provide one of them.\n",
    "    # Set seed before initializing model.\n",
    "    set_seed(training_args.seed)\n",
    "\n",
    "    # Setup logging\n",
    "    logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,\n",
    "    )\n",
    "\n",
    "    # Log on each process the small summary:\n",
    "    logger.warning(\n",
    "        f\"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\"\n",
    "        + f\"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}\"\n",
    "    )\n",
    "\n",
    "    # Set the verbosity to info of the Transformers logger (on main process only):\n",
    "    if is_main_process(training_args.local_rank):\n",
    "        transformers.utils.logging.set_verbosity_info()\n",
    "        transformers.utils.logging.enable_default_handler()\n",
    "        transformers.utils.logging.enable_explicit_format()\n",
    "    logger.info(f\"Training/evaluation parameters {training_args}\")\n",
    "    return training_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    ## Required parameters\n",
    "    parser.add_argument(\"--wandb_proj_name\",\n",
    "                        default=\"\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--task_name\",\n",
    "                        default=\"conll2003\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--token_type\",\n",
    "                        default=\"pos\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--run_name\",\n",
    "                        default=\"inoculation-run\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--inoculation_data_path\",\n",
    "                        default=\"../data-files/sst-tenary/sst-tenary-train.tsv\",\n",
    "                        type=str,\n",
    "                        help=\"The input data dir. Should contain the .tsv files (or other data files) for the task.\")\n",
    "    parser.add_argument(\"--eval_data_path\",\n",
    "                        default=\"../data-files/sst-tenary/sst-tenary-dev.tsv\",\n",
    "                        type=str,\n",
    "                        help=\"The input data dir. Should contain the .tsv files (or other data files) for the task.\")\n",
    "    parser.add_argument(\"--model_path\",\n",
    "                        default=\"\",\n",
    "                        type=str,\n",
    "                        help=\"The pretrained model binary file.\")\n",
    "    parser.add_argument(\"--model_type\",\n",
    "                        default=\"bert-base-uncased\",\n",
    "                        type=str,\n",
    "                        help=\"The pretrained model binary file.\")\n",
    "    parser.add_argument(\"--do_train\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--do_eval\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--no_cuda\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--evaluation_strategy\",\n",
    "                        default=\"steps\",\n",
    "                        type=str,\n",
    "                        help=\"When you evaluate your training model on eval set.\")\n",
    "    parser.add_argument(\"--cache_dir\",\n",
    "                        default=\"../tmp/\",\n",
    "                        type=str,\n",
    "                        help=\"Cache directory for the evaluation pipeline (not HF cache).\")\n",
    "    parser.add_argument(\"--logging_dir\",\n",
    "                        default=\"../tmp/\",\n",
    "                        type=str,\n",
    "                        help=\"Logging directory.\")\n",
    "    parser.add_argument(\"--output_dir\",\n",
    "                        default=\"../results/\",\n",
    "                        type=str,\n",
    "                        help=\"Output directory of this training process.\")\n",
    "    parser.add_argument(\"--max_seq_length\",\n",
    "                        default=128,\n",
    "                        type=int,\n",
    "                        help=\"The maximum total input sequence length after WordPiece tokenization. \\n\"\n",
    "                             \"Sequences longer than this will be truncated, and sequences shorter \\n\"\n",
    "                             \"than this will be padded.\")\n",
    "    parser.add_argument(\"--learning_rate\",\n",
    "                        default=2e-5,\n",
    "                        type=float,\n",
    "                        help=\"The initial learning rate for Adam.\")\n",
    "    parser.add_argument(\"--seed\",\n",
    "                        default=42,\n",
    "                        type=int,\n",
    "                        help=\"Random seed\")\n",
    "    parser.add_argument(\"--metric_for_best_model\",\n",
    "                        default=\"f1\",\n",
    "                        type=str,\n",
    "                        help=\"The metric to use to compare two different models.\")\n",
    "    parser.add_argument(\"--greater_is_better\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether the `metric_for_best_model` should be maximized or not.\")\n",
    "    parser.add_argument(\"--is_tensorboard\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If tensorboard is connected.\")\n",
    "    parser.add_argument(\"--load_best_model_at_end\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether load best model and evaluate at the end.\")\n",
    "    parser.add_argument(\"--label_all_tokens\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether to put the label for one word on all tokens of generated by that word or just on the \"\n",
    "                             \"one (in which case the other tokens will have a padding index).\")\n",
    "    parser.add_argument(\"--eval_steps\",\n",
    "                        default=10,\n",
    "                        type=float,\n",
    "                        help=\"The total steps to flush logs to wandb specifically.\")\n",
    "    parser.add_argument(\"--logging_steps\",\n",
    "                        default=10,\n",
    "                        type=float,\n",
    "                        help=\"The total steps to flush logs to wandb specifically.\")\n",
    "    parser.add_argument(\"--save_total_limit\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output dir.\")\n",
    "    # these are arguments for inoculations\n",
    "    parser.add_argument(\"--inoculation_patience_count\",\n",
    "                        default=5,\n",
    "                        type=int,\n",
    "                        help=\"If the evaluation metrics is not increasing with maximum this step number, the training will be stopped.\")\n",
    "    parser.add_argument(\"--inoculation_step_sample_size\",\n",
    "                        default=0.05,\n",
    "                        type=float,\n",
    "                        help=\"For each step, how many more adverserial samples you want to add in.\")\n",
    "    parser.add_argument(\"--per_device_train_batch_size\",\n",
    "                        default=8,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--per_device_eval_batch_size\",\n",
    "                        default=8,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--preprocessing_num_workers\",\n",
    "                        default=8,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--eval_sample_limit\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--num_train_epochs\",\n",
    "                        default=3.0,\n",
    "                        type=float,\n",
    "                        help=\"The total number of epochs for training.\")\n",
    "    parser.add_argument(\"--no_pretrain\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether to use pretrained model if provided.\")\n",
    "    parser.add_argument(\"--train_embeddings_only\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If only train embeddings not the whole model.\")\n",
    "    parser.add_argument(\"--train_linear_layer_only\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If only train embeddings not the whole model.\")\n",
    "    # these are arguments for scrambling texts\n",
    "    parser.add_argument(\"--scramble_proportion\",\n",
    "                        default=0.0,\n",
    "                        type=float,\n",
    "                        help=\"What is the percentage of text you want to scramble.\")\n",
    "    parser.add_argument(\"--eval_with_scramble\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If you are also evaluating with scrambled texts.\")\n",
    "    parser.add_argument(\"--overwrite_cache\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Overwrite the cached training and evaluation sets.\")\n",
    "    parser.add_argument(\"--n_layer_to_finetune\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"Indicate a number that is less than original layer if you only want to finetune with earlier layers only.\")\n",
    "    try:\n",
    "        get_ipython().run_line_magic('matplotlib', 'inline')\n",
    "        args = parser.parse_args([])\n",
    "    except:\n",
    "        args = parser.parse_args()\n",
    "    # os.environ[\"WANDB_DISABLED\"] = \"NO\" if args.is_tensorboard else \"YES\" # BUG\n",
    "    os.environ[\"TRANSFORMERS_CACHE\"] = \"../huggingface_inoculation_cache/\"\n",
    "    os.environ[\"WANDB_PROJECT\"] = f\"{args.task_name}_bertonomy\"\n",
    "    # if cache does not exist, create one\n",
    "    if not os.path.exists(os.environ[\"TRANSFORMERS_CACHE\"]): \n",
    "        os.makedirs(os.environ[\"TRANSFORMERS_CACHE\"])\n",
    "    TASK_CONFIG = {\n",
    "        \"conll2003\" : (\"tokens\", None),\n",
    "        \"en_ewt\" : (\"tokens\", None)\n",
    "    }\n",
    "    # we need some prehandling here for token classifications!\n",
    "    # we use panda loader now, to make sure it is backward compatible\n",
    "    # with our file writer.\n",
    "    pd_format = True\n",
    "    if \"data-files\" not in args.inoculation_data_path:\n",
    "        # here, we download from the huggingface server probably!\n",
    "        if args.inoculation_data_path == \"en_ewt\":\n",
    "            dataset = load_dataset(\"universal_dependencies\", args.inoculation_data_path, cache_dir=args.cache_dir) # make sure cache dir is self-contained!\n",
    "        else:\n",
    "            dataset = load_dataset(args.inoculation_data_path, cache_dir=args.cache_dir) # make sure cache dir is self-contained!\n",
    "        inoculation_train_df = dataset[\"train\"]\n",
    "        eval_df = dataset[\"validation\"]\n",
    "    else:  \n",
    "        pd_format = True\n",
    "        if args.inoculation_data_path.split(\".\")[-1] != \"tsv\":\n",
    "            if len(args.inoculation_data_path.split(\".\")) > 1:\n",
    "                logger.info(f\"***** Loading pre-loaded datasets from the disk directly! *****\")\n",
    "                pd_format = False\n",
    "                datasets = DatasetDict.load_from_disk(args.inoculation_data_path)\n",
    "                inoculation_step_sample_size = int(len(datasets[\"train\"]) * args.inoculation_step_sample_size)\n",
    "                logger.info(f\"***** Inoculation Sample Count: %s *****\"%(inoculation_step_sample_size))\n",
    "                # this may not always start for zero inoculation\n",
    "                training_args = generate_training_args(args, inoculation_step=inoculation_step_sample_size)\n",
    "                datasets[\"train\"] = datasets[\"train\"].shuffle(seed=args.seed)\n",
    "                inoculation_train_df = datasets[\"train\"].select(range(inoculation_step_sample_size))\n",
    "                eval_df = datasets[\"validation\"]\n",
    "                datasets[\"validation\"] = datasets[\"validation\"].shuffle(seed=args.seed)\n",
    "                if args.eval_sample_limit != -1:\n",
    "                    datasets[\"validation\"] = datasets[\"validation\"].select(range(args.eval_sample_limit))\n",
    "            else:\n",
    "                logger.info(f\"***** Loading downloaded huggingface datasets: {args.inoculation_data_path}! *****\")\n",
    "                pd_format = False\n",
    "                if args.inoculation_data_path in [\"sst3\", \"cola\", \"mnli\", \"snli\", \"mrps\", \"qnli\", \"conll2003\", \"en_ewt\"]:\n",
    "                    pass\n",
    "                raise NotImplementedError()\n",
    "    \n",
    "    text_column_name = TASK_CONFIG[args.task_name][0]\n",
    "    label_column_name = args.token_type\n",
    "    features = inoculation_train_df.features\n",
    "    \n",
    "    def get_label_list(labels):\n",
    "        unique_labels = set()\n",
    "        for label in labels:\n",
    "            unique_labels = unique_labels | set(label)\n",
    "        label_list = list(unique_labels)\n",
    "        label_list.sort()\n",
    "        return label_list\n",
    "\n",
    "    if isinstance(features[label_column_name].feature, ClassLabel):\n",
    "        label_list = features[label_column_name].feature.names\n",
    "        # No need to convert the labels since they are already ints.\n",
    "        label_to_id = {i: i for i in range(len(label_list))}\n",
    "    else:\n",
    "        label_list = get_label_list(inoculation_train_df[label_column_name])\n",
    "        label_to_id = {l: i for i, l in enumerate(label_list)}\n",
    "    num_labels = len(label_list)\n",
    "    \n",
    "    # Load pretrained model and tokenizer\n",
    "    MAX_SEQ_LEN = args.max_seq_length\n",
    "    training_args = generate_training_args(args, inoculation_step=0)\n",
    "    config = AutoConfig.from_pretrained(\n",
    "        args.model_type,\n",
    "        num_labels=num_labels,\n",
    "        finetuning_task=args.task_name,\n",
    "        cache_dir=args.cache_dir\n",
    "    )\n",
    "    if args.n_layer_to_finetune != -1:\n",
    "        # then we are only finetuning n-th layer, not all the layers\n",
    "        if args.n_layer_to_finetune > config.num_hidden_layers:\n",
    "            logger.info(f\"***** WARNING: You are trying to train with first {args.n_layer_to_finetune} layers only *****\")\n",
    "            logger.info(f\"***** WARNING: But the model has only {config.num_hidden_layers} layers *****\")\n",
    "            logger.info(f\"***** WARNING: Training with all layers instead! *****\")\n",
    "            pass # just to let it happen, just train it with all layers\n",
    "        else:\n",
    "            # overwrite\n",
    "            logger.info(f\"***** WARNING: You are trying to train with first {args.n_layer_to_finetune} layers only *****\")\n",
    "            logger.info(f\"***** WARNING: But the model has only {config.num_hidden_layers} layers *****\")\n",
    "            config.num_hidden_layers = args.n_layer_to_finetune\n",
    "    \n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        args.model_type,\n",
    "        use_fast=True,\n",
    "        cache_dir=args.cache_dir\n",
    "    )\n",
    "    # Tokenizer check: this script requires a fast tokenizer.\n",
    "    if not isinstance(tokenizer, PreTrainedTokenizerFast):\n",
    "        raise ValueError(\n",
    "            \"This example script only works for models that have a fast tokenizer. Checkout the big table of models \"\n",
    "            \"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this \"\n",
    "            \"requirement\"\n",
    "        )\n",
    "    \n",
    "    if args.no_pretrain:\n",
    "        logger.info(\"***** Training new model from scratch *****\")\n",
    "        model = AutoModelForTokenClassification.from_config(config)\n",
    "    else:\n",
    "        model = AutoModelForTokenClassification.from_pretrained(\n",
    "            args.model_path,\n",
    "            from_tf=False,\n",
    "            config=config,\n",
    "            cache_dir=args.cache_dir\n",
    "        )\n",
    "        \n",
    "    if args.train_embeddings_only:\n",
    "        logger.info(\"***** We only train embeddings, not other layers *****\")\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'word_embeddings' not in name: # only word embeddings\n",
    "                param.requires_grad = False\n",
    "    \n",
    "    if args.train_linear_layer_only:\n",
    "        logger.info(\"***** We only train classifier head, not other layers *****\")\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'classifier' not in name: # only word embeddings\n",
    "                param.requires_grad = False\n",
    "        \n",
    "    logger.info(f\"***** TASK NAME: {args.task_name} *****\")\n",
    "    datasets = {}\n",
    "    datasets[\"train\"] = inoculation_train_df\n",
    "    datasets[\"validation\"] = eval_df\n",
    "    logger.info(f\"***** Train Sample Count (Verify): %s *****\"%(len(datasets[\"train\"])))\n",
    "    logger.info(f\"***** Valid Sample Count (Verify): %s *****\"%(len(datasets[\"validation\"])))\n",
    "    \n",
    "    padding = \"max_length\"\n",
    "    \n",
    "    # Tokenize all texts and align the labels with them.\n",
    "    def tokenize_and_align_labels(examples):\n",
    "        tokenized_inputs = tokenizer(\n",
    "            examples[text_column_name],\n",
    "            padding=padding,\n",
    "            truncation=True,\n",
    "            # We use this argument because the texts in our dataset are lists of words (with a label for each word).\n",
    "            is_split_into_words=True,\n",
    "        )\n",
    "        labels = []\n",
    "        for i, label in enumerate(examples[label_column_name]):\n",
    "            word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
    "            previous_word_idx = None\n",
    "            label_ids = []\n",
    "            for word_idx in word_ids:\n",
    "                # Special tokens have a word id that is None. We set the label to -100 so they are automatically\n",
    "                # ignored in the loss function.\n",
    "                if word_idx is None:\n",
    "                    label_ids.append(-100)\n",
    "                # We set the label for the first token of each word.\n",
    "                elif word_idx != previous_word_idx:\n",
    "                    label_ids.append(label_to_id[label[word_idx]])\n",
    "                # For the other tokens in a word, we set the label to either the current label or -100, depending on\n",
    "                # the label_all_tokens flag.\n",
    "                else:\n",
    "                    label_ids.append(label_to_id[label[word_idx]] if args.label_all_tokens else -100)\n",
    "                previous_word_idx = word_idx\n",
    "\n",
    "            labels.append(label_ids)\n",
    "        tokenized_inputs[\"labels\"] = labels\n",
    "        return tokenized_inputs\n",
    "    \n",
    "    # preparing datasets\n",
    "    datasets[\"train\"] = datasets[\"train\"].map(\n",
    "        tokenize_and_align_labels,\n",
    "        batched=True,\n",
    "        num_proc=args.preprocessing_num_workers,\n",
    "        load_from_cache_file=not args.overwrite_cache,\n",
    "    )\n",
    "    \n",
    "    datasets[\"validation\"] = datasets[\"validation\"].map(\n",
    "        tokenize_and_align_labels,\n",
    "        batched=True,\n",
    "        num_proc=args.preprocessing_num_workers,\n",
    "        load_from_cache_file=not args.overwrite_cache,\n",
    "    )\n",
    "    \n",
    "    data_collator = DataCollatorForTokenClassification(tokenizer, padding=padding)\n",
    "    \n",
    "    # Metrics\n",
    "    ner_metric = load_metric(\"seqeval\")\n",
    "\n",
    "    def compute_metrics_ner(p):\n",
    "        predictions, labels = p\n",
    "        predictions = np.argmax(predictions, axis=2)\n",
    "\n",
    "        # Remove ignored index (special tokens)\n",
    "        true_predictions = [\n",
    "            [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
    "            for prediction, label in zip(predictions, labels)\n",
    "        ]\n",
    "        true_labels = [\n",
    "            [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
    "            for prediction, label in zip(predictions, labels)\n",
    "        ]\n",
    "\n",
    "        results = ner_metric.compute(predictions=true_predictions, references=true_labels)\n",
    "        return {\n",
    "            \"precision\": results[\"overall_precision\"],\n",
    "            \"recall\": results[\"overall_recall\"],\n",
    "            \"f1\": results[\"overall_f1\"],\n",
    "            \"accuracy\": results[\"overall_accuracy\"],\n",
    "        }\n",
    "        \n",
    "    def compute_metrics_pos(p):\n",
    "        predictions, labels = p\n",
    "        predictions = np.argmax(predictions, axis=2)\n",
    "\n",
    "        # Remove ignored index (special tokens)\n",
    "        true_predictions = [\n",
    "            [p for (p, l) in zip(prediction, label) if l != -100]\n",
    "            for prediction, label in zip(predictions, labels)\n",
    "        ]\n",
    "        true_labels = [\n",
    "            [l for (p, l) in zip(prediction, label) if l != -100]\n",
    "            for prediction, label in zip(predictions, labels)\n",
    "        ]\n",
    "        \n",
    "        true_predictions = [item for sublist in true_predictions for item in sublist]\n",
    "        true_labels = [item for sublist in true_labels for item in sublist]\n",
    "        \n",
    "        result_to_print = classification_report(true_labels, true_predictions, digits=5, output_dict=True)\n",
    "        print(classification_report(true_labels, true_predictions, digits=5))\n",
    "        \n",
    "        return {\n",
    "            \"f1\": result_to_print[\"macro avg\"][\"f1-score\"],\n",
    "            \"accuracy\": result_to_print[\"accuracy\"],\n",
    "        }\n",
    "        \n",
    "    train_dataset = datasets[\"train\"]\n",
    "    eval_dataset = datasets[\"validation\"]\n",
    "\n",
    "    # Log a few random samples from the training set:\n",
    "    for index in random.sample(range(len(train_dataset)), 3):\n",
    "        logger.info(f\"Sample {index} of the training set: {train_dataset[index]}.\")\n",
    "        \n",
    "    # Initialize our Trainer\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=eval_dataset,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=data_collator,\n",
    "        compute_metrics=compute_metrics_pos if args.task_name == \"en_ewt\" else compute_metrics_ner,\n",
    "    )\n",
    "    \n",
    "    # Early stop\n",
    "    if args.inoculation_patience_count != -1:\n",
    "        trainer.add_callback(EarlyStoppingCallback(args.inoculation_patience_count))\n",
    "\n",
    "    # Training\n",
    "    if training_args.do_train:\n",
    "        checkpoint = None\n",
    "        train_result = trainer.train(resume_from_checkpoint=checkpoint)\n",
    "        metrics = train_result.metrics\n",
    "        trainer.save_model()  # Saves the tokenizer too for easy upload\n",
    "\n",
    "        metrics[\"train_samples\"] = len(train_dataset)\n",
    "\n",
    "        # trainer.log_metrics(\"train\", metrics)\n",
    "        # trainer.save_metrics(\"train\", metrics)\n",
    "        # trainer.save_state()\n",
    "        \n",
    "    # Evaluation\n",
    "    if training_args.do_eval:\n",
    "        logger.info(\"*** Evaluate ***\")\n",
    "\n",
    "        metrics = trainer.evaluate()\n",
    "\n",
    "        max_val_samples = len(eval_dataset)\n",
    "        metrics[\"eval_samples\"] = min(max_val_samples, len(eval_dataset))\n",
    "\n",
    "        # trainer.log_metrics(\"eval\", metrics)\n",
    "        # trainer.save_metrics(\"eval\", metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
