{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pipeline for analyzing adverserial examples via fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load modules, mainly huggingface basic model handlers.\n",
    "# Make sure you install huggingface and other packages properly.\n",
    "from collections import Counter\n",
    "import json\n",
    "\n",
    "from nltk.tokenize import TweetTokenizer\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "\n",
    "import logging\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "import os\n",
    "os.environ[\"TRANSFORMERS_CACHE\"] = \"../huggingface_cache/\" # Not overload common dir \n",
    "                                                           # if run in shared resources.\n",
    "\n",
    "import random\n",
    "import sys\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "import torch\n",
    "import argparse\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_dataset, load_metric\n",
    "from datasets import Dataset\n",
    "from datasets import DatasetDict\n",
    "\n",
    "import transformers\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    EvalPrediction,\n",
    "    HfArgumentParser,\n",
    "    PretrainedConfig,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    set_seed,\n",
    "    EarlyStoppingCallback\n",
    ")\n",
    "from transformers.trainer_utils import is_main_process, EvaluationStrategy\n",
    "from functools import partial\n",
    "\n",
    "from vocab_mismatch_utils import *\n",
    "basic_tokenizer = ModifiedBasicTokenizer()\n",
    "from models.modeling_bert import CustomerizedBertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_training_args(args, inoculation_step):\n",
    "    training_args = TrainingArguments(\"tmp_trainer\")\n",
    "    training_args.no_cuda = args.no_cuda\n",
    "    training_args.seed = args.seed\n",
    "    training_args.do_train = args.do_train\n",
    "    training_args.do_eval = args.do_eval\n",
    "    training_args.output_dir = os.path.join(args.output_dir, str(inoculation_step)+\"-sample\")\n",
    "    training_args.evaluation_strategy = args.evaluation_strategy # evaluation is done after each epoch\n",
    "    training_args.metric_for_best_model = args.metric_for_best_model\n",
    "    training_args.greater_is_better = args.greater_is_better\n",
    "    training_args.logging_dir = args.logging_dir\n",
    "    training_args.task_name = args.task_name\n",
    "    training_args.learning_rate = args.learning_rate\n",
    "    training_args.per_device_train_batch_size = args.per_device_train_batch_size\n",
    "    training_args.per_device_eval_batch_size = args.per_device_eval_batch_size\n",
    "    training_args.num_train_epochs = args.num_train_epochs # this is the maximum num_train_epochs, we set this to be 100.\n",
    "    training_args.eval_steps = args.eval_steps\n",
    "    training_args.logging_steps = args.logging_steps\n",
    "    training_args.load_best_model_at_end = args.load_best_model_at_end\n",
    "    if args.save_total_limit != -1:\n",
    "        # only set if it is specified\n",
    "        training_args.save_total_limit = args.save_total_limit\n",
    "    import datetime\n",
    "    date_time = \"{}-{}\".format(datetime.datetime.now().month, datetime.datetime.now().day)\n",
    "    run_name = \"{0}_{1}_{2}_{3}_mlen_{4}_lr_{5}_seed_{6}_metrics_{7}\".format(\n",
    "        args.run_name,\n",
    "        args.task_name,\n",
    "        args.model_type,\n",
    "        date_time,\n",
    "        args.max_seq_length,\n",
    "        args.learning_rate,\n",
    "        args.seed,\n",
    "        args.metric_for_best_model\n",
    "    )\n",
    "    training_args.run_name = run_name\n",
    "    training_args_dict = training_args.to_dict()\n",
    "    # for PR\n",
    "    _n_gpu = training_args_dict[\"_n_gpu\"]\n",
    "    del training_args_dict[\"_n_gpu\"]\n",
    "    training_args_dict[\"n_gpu\"] = _n_gpu\n",
    "    HfParser = HfArgumentParser((TrainingArguments))\n",
    "    training_args = HfParser.parse_dict(training_args_dict)[0]\n",
    "\n",
    "    if args.model_path == \"\":\n",
    "        args.model_path = args.model_type\n",
    "        if args.model_type == \"\":\n",
    "            assert False # you have to provide one of them.\n",
    "    # Set seed before initializing model.\n",
    "    set_seed(training_args.seed)\n",
    "\n",
    "    # Setup logging\n",
    "    logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,\n",
    "    )\n",
    "\n",
    "    # Log on each process the small summary:\n",
    "    logger.warning(\n",
    "        f\"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\"\n",
    "        + f\"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}\"\n",
    "    )\n",
    "\n",
    "    # Set the verbosity to info of the Transformers logger (on main process only):\n",
    "    if is_main_process(training_args.local_rank):\n",
    "        transformers.utils.logging.set_verbosity_info()\n",
    "        transformers.utils.logging.enable_default_handler()\n",
    "        transformers.utils.logging.enable_explicit_format()\n",
    "    logger.info(f\"Training/evaluation parameters {training_args}\")\n",
    "    return training_args\n",
    "\n",
    "class HuggingFaceRoBERTaBase:\n",
    "    \"\"\"\n",
    "    An extension for evaluation based off the huggingface module.\n",
    "    \"\"\"\n",
    "    def __init__(self, tokenizer, model, task_name, task_config):\n",
    "        self.task_name = task_name\n",
    "        self.task_config = task_config\n",
    "        self.tokenizer = tokenizer\n",
    "        self.model = model\n",
    "        \n",
    "    def train(self, inoculation_train_df, eval_df, model_path, training_args, max_length=128,\n",
    "              inoculation_patience_count=5, pd_format=True, \n",
    "              scramble_proportion=0.0, eval_with_scramble=False):\n",
    "\n",
    "        if pd_format:\n",
    "            datasets = {}\n",
    "            datasets[\"train\"] = Dataset.from_pandas(inoculation_train_df)\n",
    "            datasets[\"validation\"] = Dataset.from_pandas(eval_df)\n",
    "        else:\n",
    "            datasets = {}\n",
    "            datasets[\"train\"] = inoculation_train_df\n",
    "            datasets[\"validation\"] = eval_df\n",
    "        logger.info(f\"***** Train Sample Count (Verify): %s *****\"%(len(datasets[\"train\"])))\n",
    "        logger.info(f\"***** Valid Sample Count (Verify): %s *****\"%(len(datasets[\"validation\"])))\n",
    "    \n",
    "        label_list = datasets[\"validation\"].unique(\"label\")\n",
    "        label_list.sort()  # Let's sort it for determinism\n",
    "\n",
    "        sentence1_key, sentence2_key = self.task_config\n",
    "        \n",
    "        # we will scramble out input sentence here\n",
    "        # TODO: we scramble both train and eval sets\n",
    "        if self.task_name == \"sst3\" or self.task_name == \"cola\":\n",
    "            def scramble_inputs(proportion, example):\n",
    "                original_text = example[sentence1_key]\n",
    "                original_sentence = basic_tokenizer.tokenize(original_text)\n",
    "                max_length = len(original_sentence)\n",
    "                scramble_length = int(max_length*proportion)\n",
    "                scramble_start = random.randint(0, len(original_sentence)-scramble_length)\n",
    "                scramble_end = scramble_start + scramble_length\n",
    "                scramble_sentence = original_sentence[scramble_start:scramble_end]\n",
    "                random.shuffle(scramble_sentence)\n",
    "                scramble_text = original_sentence[:scramble_start] + scramble_sentence + original_sentence[scramble_end:]\n",
    "\n",
    "                out_string = \" \".join(scramble_text).replace(\" ##\", \"\").strip()\n",
    "                example[sentence1_key] = out_string\n",
    "                return example\n",
    "        elif self.task_name == \"snli\" or \\\n",
    "            self.task_name == \"mrpc\" or \\\n",
    "            self.task_name == \"qnli\":\n",
    "            def scramble_inputs(proportion, example):\n",
    "                original_premise = example[sentence1_key]\n",
    "                original_hypothesis = example[sentence2_key]\n",
    "                if original_hypothesis == None:\n",
    "                    original_hypothesis = \"\"\n",
    "                try:\n",
    "                    original_premise_tokens = basic_tokenizer.tokenize(original_premise)\n",
    "                    original_hypothesis_tokens = basic_tokenizer.tokenize(original_hypothesis)\n",
    "                except:\n",
    "                    print(\"Please debug these sequence...\")\n",
    "                    print(original_premise)\n",
    "                    print(original_hypothesis)\n",
    "\n",
    "                max_length = len(original_premise_tokens)\n",
    "                scramble_length = int(max_length*proportion)\n",
    "                scramble_start = random.randint(0, max_length-scramble_length)\n",
    "                scramble_end = scramble_start + scramble_length\n",
    "                scramble_sentence = original_premise_tokens[scramble_start:scramble_end]\n",
    "                random.shuffle(scramble_sentence)\n",
    "                scramble_text_premise = original_premise_tokens[:scramble_start] + scramble_sentence + original_premise_tokens[scramble_end:]\n",
    "\n",
    "                max_length = len(original_hypothesis_tokens)\n",
    "                scramble_length = int(max_length*proportion)\n",
    "                scramble_start = random.randint(0, max_length-scramble_length)\n",
    "                scramble_end = scramble_start + scramble_length\n",
    "                scramble_sentence = original_hypothesis_tokens[scramble_start:scramble_end]\n",
    "                random.shuffle(scramble_sentence)\n",
    "                scramble_text_hypothesis = original_hypothesis_tokens[:scramble_start] + scramble_sentence + original_hypothesis_tokens[scramble_end:]\n",
    "\n",
    "                out_string_premise = \" \".join(scramble_text_premise).replace(\" ##\", \"\").strip()\n",
    "                out_string_hypothesis = \" \".join(scramble_text_hypothesis).replace(\" ##\", \"\").strip()\n",
    "                example[sentence1_key] = out_string_premise\n",
    "                example[sentence2_key] = out_string_hypothesis\n",
    "                return example\n",
    "        \n",
    "        if scramble_proportion > 0.0:\n",
    "            logger.info(f\"You are scrambling the inputs to test syntactic feature importance!\")\n",
    "            datasets[\"train\"] = datasets[\"train\"].map(partial(scramble_inputs, scramble_proportion))\n",
    "            if eval_with_scramble:\n",
    "                logger.info(f\"You are scrambling the evaluation data as well!\")\n",
    "                datasets[\"validation\"] = datasets[\"validation\"].map(partial(scramble_inputs, scramble_proportion))\n",
    "        \n",
    "        padding = \"max_length\"\n",
    "        sentence1_key, sentence2_key = self.task_config\n",
    "        label_to_id = None\n",
    "        def preprocess_function(examples):\n",
    "            # Tokenize the texts\n",
    "            args = (\n",
    "                (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n",
    "            )\n",
    "            result = self.tokenizer(*args, padding=padding, max_length=max_length, truncation=True)\n",
    "            # Map labels to IDs (not necessary for GLUE tasks)\n",
    "            if label_to_id is not None and \"label\" in examples:\n",
    "                result[\"label\"] = [label_to_id[l] for l in examples[\"label\"]]\n",
    "            return result\n",
    "        datasets[\"train\"] = datasets[\"train\"].map(preprocess_function, batched=True)\n",
    "        datasets[\"validation\"] = datasets[\"validation\"].map(preprocess_function, batched=True)\n",
    "        \n",
    "        train_dataset = datasets[\"train\"]\n",
    "        eval_dataset = datasets[\"validation\"]\n",
    "        \n",
    "        # Log a few random samples from the training set:\n",
    "        for index in random.sample(range(len(train_dataset)), 3):\n",
    "            logger.info(f\"Sample {index} of the training set: {train_dataset[index]}.\")\n",
    "            \n",
    "        metric = load_metric(\"glue\", \"sst2\") # any glue task will do the job, just for eval loss\n",
    "        \n",
    "        def asenti_compute_metrics(p: EvalPrediction):\n",
    "            preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n",
    "            preds = np.argmax(preds, axis=1)\n",
    "            result_to_print = classification_report(p.label_ids, preds, digits=5, output_dict=True)\n",
    "            print(classification_report(p.label_ids, preds, digits=5))\n",
    "            mcc_scores = matthews_corrcoef(p.label_ids, preds)\n",
    "            logger.info(f\"MCC scores: {mcc_scores}.\")\n",
    "            result_to_return = metric.compute(predictions=preds, references=p.label_ids)\n",
    "            result_to_return[\"Macro-F1\"] = result_to_print[\"macro avg\"][\"f1-score\"]\n",
    "            result_to_return[\"MCC\"] = mcc_scores\n",
    "            return result_to_return\n",
    "\n",
    "        # Initialize our Trainer. We are only intersted in evaluations\n",
    "        trainer = Trainer(\n",
    "            model=model,\n",
    "            args=training_args,\n",
    "            train_dataset=train_dataset,\n",
    "            eval_dataset=eval_dataset,\n",
    "            compute_metrics=asenti_compute_metrics,\n",
    "            tokenizer=self.tokenizer,\n",
    "            # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.\n",
    "            data_collator=default_data_collator\n",
    "        )\n",
    "        # Early stop\n",
    "        if inoculation_patience_count != -1:\n",
    "            trainer.add_callback(EarlyStoppingCallback(inoculation_patience_count))\n",
    "        \n",
    "        # Training\n",
    "        if training_args.do_train:\n",
    "            logger.info(\"*** Training our model ***\")\n",
    "            trainer.train(\n",
    "                model_path=model_path\n",
    "            )\n",
    "            trainer.save_model()  # Saves the tokenizer too for easy upload\n",
    "        \n",
    "        # Evaluation\n",
    "        eval_results = {}\n",
    "        if training_args.do_eval:\n",
    "            logger.info(\"*** Evaluate ***\")\n",
    "            tasks = [self.task_name]\n",
    "            eval_datasets = [eval_dataset]\n",
    "            for eval_dataset, task in zip(eval_datasets, tasks):\n",
    "                eval_result = trainer.evaluate(eval_dataset=eval_dataset)\n",
    "                output_eval_file = os.path.join(training_args.output_dir, f\"eval_results_{task}.txt\")\n",
    "                if trainer.is_world_process_zero():\n",
    "                    with open(output_eval_file, \"w\") as writer:\n",
    "                        logger.info(f\"***** Eval results {task} *****\")\n",
    "                        for key, value in eval_result.items():\n",
    "                            logger.info(f\"  {key} = {value}\")\n",
    "                            writer.write(f\"{key} = {value}\\n\")\n",
    "                eval_results.update(eval_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    ## Required parameters\n",
    "    parser.add_argument(\"--wandb_proj_name\",\n",
    "                        default=\"\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--task_name\",\n",
    "                        default=\"sst3\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--run_name\",\n",
    "                        default=\"inoculation-run\",\n",
    "                        type=str)\n",
    "    parser.add_argument(\"--inoculation_data_path\",\n",
    "                        default=\"../data-files/sst-tenary/sst-tenary-train.tsv\",\n",
    "                        type=str,\n",
    "                        help=\"The input data dir. Should contain the .tsv files (or other data files) for the task.\")\n",
    "    parser.add_argument(\"--eval_data_path\",\n",
    "                        default=\"../data-files/sst-tenary/sst-tenary-dev.tsv\",\n",
    "                        type=str,\n",
    "                        help=\"The input data dir. Should contain the .tsv files (or other data files) for the task.\")\n",
    "    parser.add_argument(\"--model_path\",\n",
    "                        default=\"\",\n",
    "                        type=str,\n",
    "                        help=\"The pretrained model binary file.\")\n",
    "    parser.add_argument(\"--model_type\",\n",
    "                        default=\"bert-base-uncased\",\n",
    "                        type=str,\n",
    "                        help=\"The pretrained model binary file.\")\n",
    "    parser.add_argument(\"--do_train\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--do_eval\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--no_cuda\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether not to use CUDA when available\")\n",
    "    parser.add_argument(\"--evaluation_strategy\",\n",
    "                        default=\"steps\",\n",
    "                        type=str,\n",
    "                        help=\"When you evaluate your training model on eval set.\")\n",
    "    parser.add_argument(\"--cache_dir\",\n",
    "                        default=\"../tmp/\",\n",
    "                        type=str,\n",
    "                        help=\"Cache directory for the evaluation pipeline (not HF cache).\")\n",
    "    parser.add_argument(\"--logging_dir\",\n",
    "                        default=\"../tmp/\",\n",
    "                        type=str,\n",
    "                        help=\"Logging directory.\")\n",
    "    parser.add_argument(\"--output_dir\",\n",
    "                        default=\"../results/\",\n",
    "                        type=str,\n",
    "                        help=\"Output directory of this training process.\")\n",
    "    parser.add_argument(\"--max_seq_length\",\n",
    "                        default=128,\n",
    "                        type=int,\n",
    "                        help=\"The maximum total input sequence length after WordPiece tokenization. \\n\"\n",
    "                             \"Sequences longer than this will be truncated, and sequences shorter \\n\"\n",
    "                             \"than this will be padded.\")\n",
    "    parser.add_argument(\"--learning_rate\",\n",
    "                        default=2e-5,\n",
    "                        type=float,\n",
    "                        help=\"The initial learning rate for Adam.\")\n",
    "    parser.add_argument(\"--seed\",\n",
    "                        default=42,\n",
    "                        type=int,\n",
    "                        help=\"Random seed\")\n",
    "    parser.add_argument(\"--metric_for_best_model\",\n",
    "                        default=\"Macro-F1\",\n",
    "                        type=str,\n",
    "                        help=\"The metric to use to compare two different models.\")\n",
    "    parser.add_argument(\"--greater_is_better\",\n",
    "                        default=True,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether the `metric_for_best_model` should be maximized or not.\")\n",
    "    parser.add_argument(\"--is_tensorboard\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If tensorboard is connected.\")\n",
    "    parser.add_argument(\"--load_best_model_at_end\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether load best model and evaluate at the end.\")\n",
    "    parser.add_argument(\"--eval_steps\",\n",
    "                        default=10,\n",
    "                        type=float,\n",
    "                        help=\"The total steps to flush logs to wandb specifically.\")\n",
    "    parser.add_argument(\"--logging_steps\",\n",
    "                        default=10,\n",
    "                        type=float,\n",
    "                        help=\"The total steps to flush logs to wandb specifically.\")\n",
    "    parser.add_argument(\"--save_total_limit\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output dir.\")\n",
    "    # these are arguments for inoculations\n",
    "    parser.add_argument(\"--inoculation_patience_count\",\n",
    "                        default=5,\n",
    "                        type=int,\n",
    "                        help=\"If the evaluation metrics is not increasing with maximum this step number, the training will be stopped.\")\n",
    "    parser.add_argument(\"--inoculation_step_sample_size\",\n",
    "                        default=0.05,\n",
    "                        type=float,\n",
    "                        help=\"For each step, how many more adverserial samples you want to add in.\")\n",
    "    parser.add_argument(\"--per_device_train_batch_size\",\n",
    "                        default=8,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--per_device_eval_batch_size\",\n",
    "                        default=8,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--eval_sample_limit\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"\")\n",
    "    parser.add_argument(\"--num_train_epochs\",\n",
    "                        default=3.0,\n",
    "                        type=float,\n",
    "                        help=\"The total number of epochs for training.\")\n",
    "    parser.add_argument(\"--no_pretrain\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"Whether to use pretrained model if provided.\")\n",
    "    parser.add_argument(\"--train_embeddings_only\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If only train embeddings not the whole model.\")\n",
    "    parser.add_argument(\"--train_linear_layer_only\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If only train embeddings not the whole model.\")\n",
    "    # these are arguments for scrambling texts\n",
    "    parser.add_argument(\"--scramble_proportion\",\n",
    "                        default=0.0,\n",
    "                        type=float,\n",
    "                        help=\"What is the percentage of text you want to scramble.\")\n",
    "    parser.add_argument(\"--eval_with_scramble\",\n",
    "                        default=False,\n",
    "                        action='store_true',\n",
    "                        help=\"If you are also evaluating with scrambled texts.\")\n",
    "    parser.add_argument(\"--n_layer_to_finetune\",\n",
    "                        default=-1,\n",
    "                        type=int,\n",
    "                        help=\"Indicate a number that is less than original layer if you only want to finetune with earlier layers only.\")\n",
    "    try:\n",
    "        get_ipython().run_line_magic('matplotlib', 'inline')\n",
    "        args = parser.parse_args([])\n",
    "    except:\n",
    "        args = parser.parse_args()\n",
    "    # os.environ[\"WANDB_DISABLED\"] = \"NO\" if args.is_tensorboard else \"YES\" # BUG\n",
    "    os.environ[\"TRANSFORMERS_CACHE\"] = \"../huggingface_inoculation_cache/\"\n",
    "    os.environ[\"WANDB_PROJECT\"] = f\"{args.task_name}_bertonomy\"\n",
    "    # if cache does not exist, create one\n",
    "    if not os.path.exists(os.environ[\"TRANSFORMERS_CACHE\"]): \n",
    "        os.makedirs(os.environ[\"TRANSFORMERS_CACHE\"])\n",
    "    TASK_CONFIG = {\n",
    "        \"sst3\": (\"text\", None),\n",
    "        \"cola\": (\"sentence\", None),\n",
    "        \"mnli\": (\"premise\", \"hypothesis\"),\n",
    "        \"snli\": (\"premise\", \"hypothesis\"),\n",
    "        \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
    "        \"qnli\": (\"question\", \"sentence\")\n",
    "    }\n",
    "    # Load pretrained model and tokenizer\n",
    "    NUM_LABELS = 2 if args.task_name == \"cola\" or args.task_name == \"mrpc\" or args.task_name == \"qnli\" else 3\n",
    "    MAX_SEQ_LEN = args.max_seq_length\n",
    "    training_args = generate_training_args(args, inoculation_step=0)\n",
    "    config = AutoConfig.from_pretrained(\n",
    "        args.model_type,\n",
    "        num_labels=NUM_LABELS,\n",
    "        finetuning_task=args.task_name,\n",
    "        cache_dir=args.cache_dir\n",
    "    )\n",
    "    if args.n_layer_to_finetune != -1:\n",
    "        # then we are only finetuning n-th layer, not all the layers\n",
    "        if args.n_layer_to_finetune > config.num_hidden_layers:\n",
    "            logger.info(f\"***** WARNING: You are trying to train with first {args.n_layer_to_finetune} layers only *****\")\n",
    "            logger.info(f\"***** WARNING: But the model has only {config.num_hidden_layers} layers *****\")\n",
    "            logger.info(f\"***** WARNING: Training with all layers instead! *****\")\n",
    "            pass # just to let it happen, just train it with all layers\n",
    "        else:\n",
    "            # overwrite\n",
    "            logger.info(f\"***** WARNING: You are trying to train with first {args.n_layer_to_finetune} layers only *****\")\n",
    "            logger.info(f\"***** WARNING: But the model has only {config.num_hidden_layers} layers *****\")\n",
    "            config.num_hidden_layers = args.n_layer_to_finetune\n",
    "    \n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        args.model_type,\n",
    "        use_fast=False,\n",
    "        cache_dir=args.cache_dir\n",
    "    )\n",
    "    if args.no_pretrain:\n",
    "        logger.info(\"***** Training new model from scratch *****\")\n",
    "        model = AutoModelForSequenceClassification.from_config(config)\n",
    "    else:\n",
    "        model = AutoModelForSequenceClassification.from_pretrained(\n",
    "            args.model_path,\n",
    "            from_tf=False,\n",
    "            config=config,\n",
    "            cache_dir=args.cache_dir\n",
    "        )\n",
    "        \n",
    "    if args.train_embeddings_only:\n",
    "        logger.info(\"***** We only train embeddings, not other layers *****\")\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'word_embeddings' not in name: # only word embeddings\n",
    "                param.requires_grad = False\n",
    "    \n",
    "    if args.train_linear_layer_only:\n",
    "        logger.info(\"***** We only train classifier head, not other layers *****\")\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'classifier' not in name: # only word embeddings\n",
    "                param.requires_grad = False\n",
    "        \n",
    "    train_pipeline = HuggingFaceRoBERTaBase(tokenizer, \n",
    "                                            model, args.task_name, \n",
    "                                            TASK_CONFIG[args.task_name])\n",
    "    logger.info(f\"***** TASK NAME: {args.task_name} *****\")\n",
    "    # we use panda loader now, to make sure it is backward compatible\n",
    "    # with our file writer.\n",
    "    pd_format = True\n",
    "    if args.inoculation_data_path.split(\".\")[-1] != \"tsv\":\n",
    "        if len(args.inoculation_data_path.split(\".\")) > 1:\n",
    "            logger.info(f\"***** Loading pre-loaded datasets from the disk directly! *****\")\n",
    "            pd_format = False\n",
    "            datasets = DatasetDict.load_from_disk(args.inoculation_data_path)\n",
    "            inoculation_step_sample_size = int(len(datasets[\"train\"]) * args.inoculation_step_sample_size)\n",
    "            logger.info(f\"***** Inoculation Sample Count: %s *****\"%(inoculation_step_sample_size))\n",
    "            # this may not always start for zero inoculation\n",
    "            training_args = generate_training_args(args, inoculation_step=inoculation_step_sample_size)\n",
    "            datasets[\"train\"] = datasets[\"train\"].shuffle(seed=args.seed)\n",
    "            inoculation_train_df = datasets[\"train\"].select(range(inoculation_step_sample_size))\n",
    "            eval_df = datasets[\"validation\"]\n",
    "            datasets[\"validation\"] = datasets[\"validation\"].shuffle(seed=args.seed)\n",
    "            if args.eval_sample_limit != -1:\n",
    "                datasets[\"validation\"] = datasets[\"validation\"].select(range(args.eval_sample_limit))\n",
    "        else:\n",
    "            logger.info(f\"***** Loading downloaded huggingface datasets: {args.inoculation_data_path}! *****\")\n",
    "            pd_format = False\n",
    "            if args.inoculation_data_path in [\"sst3\", \"cola\", \"mnli\", \"snli\", \"mrps\", \"qnli\"]:\n",
    "                pass\n",
    "            raise NotImplementedError()\n",
    "    else:\n",
    "        train_df = pd.read_csv(args.inoculation_data_path, delimiter=\"\\t\")\n",
    "        eval_df = pd.read_csv(args.eval_data_path, delimiter=\"\\t\")\n",
    "        inoculation_step_sample_size = int(len(train_df) * args.inoculation_step_sample_size)\n",
    "        logger.info(f\"***** Inoculation Sample Count: %s *****\"%(inoculation_step_sample_size))\n",
    "        # this may not always start for zero inoculation\n",
    "        training_args = generate_training_args(args, inoculation_step=inoculation_step_sample_size)\n",
    "        inoculation_train_df = train_df.sample(n=inoculation_step_sample_size, \n",
    "                                               replace=False, \n",
    "                                               random_state=args.seed) # seed here could not a little annoying.\n",
    "\n",
    "\n",
    "    train_pipeline.train(inoculation_train_df, eval_df, \n",
    "                         args.model_path,\n",
    "                         training_args, max_length=args.max_seq_length,\n",
    "                         inoculation_patience_count=args.inoculation_patience_count, pd_format=pd_format, \n",
    "                         scramble_proportion=args.scramble_proportion, eval_with_scramble=args.eval_with_scramble)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
