{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redirect the huggingface cache so a shared home directory is not overloaded.\n",
    "# NOTE: this must run BEFORE transformers is imported (directly, or via the\n",
    "# star-imported project utils below): transformers reads TRANSFORMERS_CACHE\n",
    "# when it is first imported, so setting it afterwards has no effect.\n",
    "# Restart the kernel if transformers was already imported in this session.\n",
    "import os\n",
    "os.environ[\"TRANSFORMERS_CACHE\"] = \"../huggingface_cache/\"\n",
    "\n",
    "from vocab_mismatch_utils import *\n",
    "from data_formatter_utils import *\n",
    "from datasets import DatasetDict\n",
    "from datasets import Dataset\n",
    "from datasets import load_dataset, load_metric\n",
    "import transformers\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "import operator\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torch.utils.data.sampler import RandomSampler, SequentialSampler\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "# Load modules, mainly huggingface basic model handlers.\n",
    "# Make sure you install huggingface and other packages properly.\n",
    "from collections import Counter\n",
    "import json\n",
    "\n",
    "from nltk.tokenize import TweetTokenizer\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import math\n",
    "import statistics\n",
    "\n",
    "import logging\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "import random\n",
    "import sys\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "import argparse\n",
    "import numpy as np\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    EvalPrediction,\n",
    "    HfArgumentParser,\n",
    "    PretrainedConfig,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    set_seed,\n",
    "    EarlyStoppingCallback\n",
    ")\n",
    "from transformers.trainer_utils import is_main_process, EvaluationStrategy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "font = {'family' : 'Times New Roman',\n",
    "        'size'   : 30}\n",
    "plt.rc('font', **font)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name = \"cola\"  # key into TASK_CONFIG: one of sst3, cola, mnli, snli, mrpc, qnli"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(inoculation_data_path, eval_data_path=None, test_data_path=None,\n",
    "                inoculation_step_sample_size=1.0,\n",
    "                eval_sample_limit=-1, seed=42):\n",
    "    \"\"\"\n",
    "    Load the train / validation (/ test) splits as huggingface datasets.\n",
    "\n",
    "    Two layouts are supported:\n",
    "      * `inoculation_data_path` ends in `.tsv`: the train split is read from\n",
    "        it, validation from `eval_data_path` and, when given, test from\n",
    "        `test_data_path`.\n",
    "      * any other path containing a dot (e.g. a relative `../` directory):\n",
    "        it is loaded with `DatasetDict.load_from_disk`, and\n",
    "        `eval_data_path` / `test_data_path` are ignored.\n",
    "\n",
    "    Args:\n",
    "        inoculation_data_path: train `.tsv` file or saved DatasetDict dir.\n",
    "        eval_data_path: validation `.tsv` file (tsv layout only).\n",
    "        test_data_path: optional test `.tsv` file (tsv layout only); when\n",
    "            None, no \"test\" split is returned.\n",
    "        inoculation_step_sample_size: fraction of the train split to keep.\n",
    "        eval_sample_limit: max number of (shuffled) validation examples to\n",
    "            keep in the saved-dataset layout; -1 keeps all of them.\n",
    "        seed: seed for train sub-sampling and shuffling.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping \"train\", \"validation\" and (tsv layout with a test\n",
    "        file) \"test\" to huggingface `Dataset` objects.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: for bare dataset names without a dot\n",
    "            (downloading from the hub is not implemented).\n",
    "    \"\"\"\n",
    "    if inoculation_data_path.split(\".\")[-1] != \"tsv\":\n",
    "        if len(inoculation_data_path.split(\".\")) <= 1:\n",
    "            logger.info(f\"***** Loading downloaded huggingface datasets: {inoculation_data_path}! *****\")\n",
    "            raise NotImplementedError(\n",
    "                f\"Loading the hub dataset '{inoculation_data_path}' is not supported; \"\n",
    "                \"pass a .tsv file or a saved DatasetDict directory instead.\")\n",
    "\n",
    "        logger.info(\"***** Loading pre-loaded datasets from the disk directly! *****\")\n",
    "        loaded = DatasetDict.load_from_disk(inoculation_data_path)\n",
    "        sample_count = int(len(loaded[\"train\"]) * inoculation_step_sample_size)\n",
    "        logger.info(\"***** Inoculation Sample Count: %s *****\", sample_count)\n",
    "        # A fraction of 0.0 yields an empty train split (\"zero inoculation\").\n",
    "        train_split = loaded[\"train\"].shuffle(seed=seed).select(range(sample_count))\n",
    "        # The returned validation split is the shuffled one, so that\n",
    "        # eval_sample_limit actually takes effect.\n",
    "        validation_split = loaded[\"validation\"].shuffle(seed=seed)\n",
    "        if eval_sample_limit != -1:\n",
    "            limit = min(eval_sample_limit, len(validation_split))\n",
    "            validation_split = validation_split.select(range(limit))\n",
    "        return {\"train\": train_split, \"validation\": validation_split}\n",
    "\n",
    "    train_df = pd.read_csv(inoculation_data_path, delimiter=\"\\t\")\n",
    "    eval_df = pd.read_csv(eval_data_path, delimiter=\"\\t\")\n",
    "    sample_count = int(len(train_df) * inoculation_step_sample_size)\n",
    "    logger.info(\"***** Inoculation Sample Count: %s *****\", sample_count)\n",
    "    # A fraction of 0.0 yields an empty train split (\"zero inoculation\").\n",
    "    # `seed` fixes which rows are kept, so the subsample is reproducible.\n",
    "    inoculation_train_df = train_df.sample(n=sample_count, replace=False,\n",
    "                                           random_state=seed)\n",
    "    splits = {\n",
    "        \"train\": Dataset.from_pandas(inoculation_train_df),\n",
    "        \"validation\": Dataset.from_pandas(eval_df),\n",
    "    }\n",
    "    if test_data_path is not None:\n",
    "        test_df = pd.read_csv(test_data_path, delimiter=\"\\t\")\n",
    "        splits[\"test\"] = Dataset.from_pandas(test_df)\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column(s) holding the input text for each task: (first key, second key),\n",
    "# where the second key is None for single-sentence tasks.\n",
    "TASK_CONFIG = {\n",
    "    \"sst3\": (\"text\", None),\n",
    "    \"cola\": (\"sentence\", None),\n",
    "    \"mnli\": (\"premise\", \"hypothesis\"),\n",
    "    \"snli\": (\"premise\", \"hypothesis\"),\n",
    "    \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
    "    \"qnli\": (\"question\", \"sentence\")\n",
    "}\n",
    "# No BERT vocabulary/tokenizer is needed for the BoW baseline; the project's\n",
    "# ModifiedBasicTokenizer (a basic whitespace-style tokenizer) is used instead.\n",
    "modified_basic_tokenizer = ModifiedBasicTokenizer()\n",
    "max_length = 128\n",
    "per_device_train_batch_size = 128\n",
    "per_device_eval_batch_size = 128\n",
    "no_cuda = True  # set to False to use the GPU(s) when available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() and not no_cuda else \"cpu\")\n",
    "# Number of devices a batch is spread over. Derived from `device` so it is\n",
    "# 1 on CPU even when no_cuda is False but no GPU exists (0 would make the\n",
    "# DataLoader batch size 0).\n",
    "n_gpu = torch.cuda.device_count() if device.type == \"cuda\" else 1\n",
    "seed = 42\n",
    "lr = 1e-3  # NOTE: overridden in the model-setup cell\n",
    "num_train_epochs = 10  # NOTE: overridden in the model-setup cell\n",
    "sentence1_key, sentence2_key = TASK_CONFIG[task_name]\n",
    "\n",
    "# Seed every RNG used by the notebook for reproducibility.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if device.type == \"cuda\":\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:15 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****\n",
      "03/29/2021 14:17:15 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****\n",
      "03/29/2021 14:17:15 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****\n",
      "03/29/2021 14:17:15 - INFO - __main__ - ***** Test Sample Count (Verify): 1063 *****\n"
     ]
    }
   ],
   "source": [
    "# Send INFO-level log records to stderr with timestamps.\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    ")\n",
    "# Files follow ../data-files/NAME/NAME-{train,dev,test}.tsv; only SST-3 uses a\n",
    "# folder name that differs from its task name.\n",
    "data_file_name = {\"sst3\": \"sst-tenary\"}.get(task_name, task_name)\n",
    "datasets = get_dataset(*(\n",
    "    f\"../data-files/{data_file_name}/{data_file_name}-{split}.tsv\"\n",
    "    for split in (\"train\", \"dev\", \"test\")\n",
    "))\n",
    "logger.info(\"***** Train Sample Count (Verify): %s *****\", len(datasets[\"train\"]))\n",
    "logger.info(\"***** Valid Sample Count (Verify): %s *****\", len(datasets[\"validation\"]))\n",
    "logger.info(\"***** Test Sample Count (Verify): %s *****\", len(datasets[\"test\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### BoW preprocessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8551/8551 [00:01<00:00, 8031.65it/s]\n",
      "100%|██████████| 1043/1043 [00:00<00:00, 8366.12it/s]\n",
      "100%|██████████| 1063/1063 [00:00<00:00, 8201.12it/s]\n"
     ]
    }
   ],
   "source": [
    "def sanity_check_non_empty(sentece):\n",
    "    \"\"\"Return True if `sentece` is a usable, non-blank string.\n",
    "\n",
    "    None, whitespace-only strings and the literal text \"None\" count as empty.\n",
    "    \"\"\"\n",
    "    return sentece is not None and sentece.strip() not in (\"\", \"None\")\n",
    "\n",
    "\n",
    "def _example_text(example):\n",
    "    \"\"\"Text of one example as it is tokenized for the vocabulary.\n",
    "\n",
    "    Single-sentence tasks use the first sentence; pair tasks join both sides\n",
    "    with \" [SEP] \". An empty side contributes the empty string instead of\n",
    "    reusing the previous example's text.\n",
    "    \"\"\"\n",
    "    s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else \"\"\n",
    "    if sentence2_key is None:\n",
    "        return s1\n",
    "    s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else \"\"\n",
    "    return s1 + \" [SEP] \" + s2\n",
    "\n",
    "\n",
    "def _add_split_to_vocab(split):\n",
    "    \"\"\"Give every token of `split` not yet in `original_vocab` the next free index.\"\"\"\n",
    "    for example in tqdm(split):\n",
    "        for token in modified_basic_tokenizer.tokenize(_example_text(example)):\n",
    "            # len() is evaluated before insertion, so indices run 0, 1, 2, ...\n",
    "            original_vocab.setdefault(token, len(original_vocab))\n",
    "\n",
    "\n",
    "# Build the BoW vocabulary, mapping each token to its column index in\n",
    "# first-seen order.\n",
    "original_vocab = OrderedDict()\n",
    "# The BoW feature cells look tokens up directly in original_vocab, so every\n",
    "# split they featurize must be indexed here; with train_data_only=True,\n",
    "# validation/test tokens unseen in train would raise a KeyError there.\n",
    "train_data_only = False\n",
    "splits_to_index = [\"train\"] if train_data_only else [\"train\", \"validation\", \"test\"]\n",
    "for split_name in splits_to_index:\n",
    "    if split_name in datasets:\n",
    "        _add_split_to_vocab(datasets[split_name])\n",
    "vocab_index = len(original_vocab)  # next free index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▉         | 759/8551 [00:00<00:02, 3775.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example sentence: Where all did they go for their holidays?\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8551/8551 [00:02<00:00, 3760.38it/s]\n",
      "/afs/cs.stanford.edu/u/wuzhengx/.local/lib/python3.7/site-packages/ipykernel_launcher.py:39: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
     ]
    }
   ],
   "source": [
    "# BoW feature vectors for the train split.\n",
    "# Single-sentence tasks: one token-count vector of length len(original_vocab).\n",
    "# Sentence-pair tasks: one count vector per side, concatenated (twice as long).\n",
    "train_input_features = []\n",
    "train_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(datasets[\"train\"])):\n",
    "    if sentence2_key is None:\n",
    "        # Default to \"\" so an empty sentence never reuses the previous example's text.\n",
    "        sentence_combined = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            sentence_combined = example[sentence1_key]\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence: \" + sentence_combined)\n",
    "        bow_feature = torch.zeros(len(original_vocab))\n",
    "        for t in modified_basic_tokenizer.tokenize(sentence_combined):\n",
    "            bow_feature[original_vocab[t]] += 1\n",
    "    else:\n",
    "        s1 = \"\"\n",
    "        s2 = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            s1 = example[sentence1_key]\n",
    "        if sanity_check_non_empty(example[sentence2_key]):\n",
    "            s2 = example[sentence2_key]\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence 1: \" + s1)\n",
    "            print(\"Example sentence 2: \" + s2)\n",
    "        bow_feature_1 = torch.zeros(len(original_vocab))\n",
    "        bow_feature_2 = torch.zeros(len(original_vocab))\n",
    "        for t in modified_basic_tokenizer.tokenize(s1):\n",
    "            bow_feature_1[original_vocab[t]] += 1\n",
    "        for t in modified_basic_tokenizer.tokenize(s2):\n",
    "            bow_feature_2[original_vocab[t]] += 1\n",
    "        bow_feature = torch.cat([bow_feature_1, bow_feature_2], dim=-1)\n",
    "    train_input_features.append(bow_feature)\n",
    "    train_label_ids.append(example[\"label\"])\n",
    "\n",
    "# .float() instead of re-wrapping with torch.tensor(...): avoids an extra copy\n",
    "# and the \"copy construct from a tensor\" UserWarning.\n",
    "train_input_features = torch.stack(train_input_features, dim=0).float()\n",
    "train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)\n",
    "train_data = TensorDataset(train_input_features, train_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 651/1043 [00:00<00:00, 3389.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example sentence: All who lost money in the scam are eligible for the program.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1043/1043 [00:00<00:00, 3073.06it/s]\n",
      "/afs/cs.stanford.edu/u/wuzhengx/.local/lib/python3.7/site-packages/ipykernel_launcher.py:39: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
     ]
    }
   ],
   "source": [
    "# BoW feature vectors for the validation split (same scheme as the train split).\n",
    "# Single-sentence tasks: one token-count vector of length len(original_vocab).\n",
    "# Sentence-pair tasks: one count vector per side, concatenated (twice as long).\n",
    "validation_input_features = []\n",
    "validation_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(datasets[\"validation\"])):\n",
    "    if sentence2_key is None:\n",
    "        # Default to \"\" so an empty sentence never reuses the previous example's text.\n",
    "        sentence_combined = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            sentence_combined = example[sentence1_key]\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence: \" + sentence_combined)\n",
    "        bow_feature = torch.zeros(len(original_vocab))\n",
    "        for t in modified_basic_tokenizer.tokenize(sentence_combined):\n",
    "            bow_feature[original_vocab[t]] += 1\n",
    "    else:\n",
    "        s1 = \"\"\n",
    "        s2 = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            s1 = example[sentence1_key]\n",
    "        if sanity_check_non_empty(example[sentence2_key]):\n",
    "            s2 = example[sentence2_key]\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence 1: \" + s1)\n",
    "            print(\"Example sentence 2: \" + s2)\n",
    "        bow_feature_1 = torch.zeros(len(original_vocab))\n",
    "        bow_feature_2 = torch.zeros(len(original_vocab))\n",
    "        for t in modified_basic_tokenizer.tokenize(s1):\n",
    "            bow_feature_1[original_vocab[t]] += 1\n",
    "        for t in modified_basic_tokenizer.tokenize(s2):\n",
    "            bow_feature_2[original_vocab[t]] += 1\n",
    "        bow_feature = torch.cat([bow_feature_1, bow_feature_2], dim=-1)\n",
    "    validation_input_features.append(bow_feature)\n",
    "    validation_label_ids.append(example[\"label\"])\n",
    "\n",
    "# .float() instead of re-wrapping with torch.tensor(...): avoids an extra copy\n",
    "# and the \"copy construct from a tensor\" UserWarning.\n",
    "validation_input_features = torch.stack(validation_input_features, dim=0).float()\n",
    "validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)\n",
    "validation_data = TensorDataset(validation_input_features, validation_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data loaders. Batch sizes are the per-device size times n_gpu (which the\n",
    "# config cell sets to 1 on CPU). Training batches come in random order;\n",
    "# validation batches are read in dataset order.\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(\n",
    "    train_data,\n",
    "    sampler=train_sampler,\n",
    "    batch_size=n_gpu * per_device_train_batch_size,\n",
    ")\n",
    "validation_dataloader = DataLoader(\n",
    "    validation_data,\n",
    "    sampler=SequentialSampler(validation_data),\n",
    "    batch_size=n_gpu * per_device_eval_batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### BoW Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BOWClassifier(nn.Module):\n",
    "    \"\"\"Bag-of-words linear classifier (multinomial logistic regression).\n",
    "\n",
    "    A single biased linear layer maps a feature vector of size `vocab_size`\n",
    "    to `num_labels` logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, vocab_size):\n",
    "        super().__init__()\n",
    "        self.classifier = nn.Linear(vocab_size, num_labels, bias=True)\n",
    "\n",
    "    def forward(self, x, labels=None):\n",
    "        \"\"\"Return logits, or (cross-entropy loss, logits) when `labels` is given.\"\"\"\n",
    "        logits = self.classifier(x)\n",
    "        if labels is None:\n",
    "            return logits\n",
    "        loss = CrossEntropyLoss()(logits, labels)\n",
    "        return loss, logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MockBERTBOWClassifier(nn.Module):\n",
    "    \"\"\"BoW classifier with a small bias-free tanh hidden layer.\n",
    "\n",
    "    A miniature stand-in for a BERT-style pipeline: `mock_bert` projects the\n",
    "    input vector to a 32-dim hidden state, tanh plays the role of the pooled\n",
    "    \"CLS\" activation, and `classifier` maps it to `num_labels` logits.\n",
    "    Neither linear layer has a bias.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, vocab_size):\n",
    "        super().__init__()\n",
    "        hidden_dim = 32\n",
    "        self.mock_bert = nn.Linear(vocab_size, hidden_dim, bias=False)\n",
    "        self.mock_activation = nn.Tanh()\n",
    "        self.classifier = nn.Linear(hidden_dim, num_labels, bias=False)\n",
    "\n",
    "    def forward(self, x, labels=None):\n",
    "        \"\"\"Return logits, or (cross-entropy loss, logits) when `labels` is given.\"\"\"\n",
    "        hidden = self.mock_bert(x)\n",
    "        logits = self.classifier(self.mock_activation(hidden))\n",
    "        if labels is None:\n",
    "            return logits\n",
    "        return CrossEntropyLoss()(logits, labels), logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final hyper-parameters for the BoW run; these override the lr /\n",
    "# num_train_epochs values set in the config cell.\n",
    "lr = 1e-3\n",
    "num_train_epochs = 50\n",
    "# Pair tasks concatenate one count vector per sentence, doubling the input size.\n",
    "in_dim = len(original_vocab) if sentence2_key is None else len(original_vocab) * 2\n",
    "# Size the output layer from the largest label id seen in train + validation\n",
    "# (assumes integer class ids starting at 0). Counting distinct validation\n",
    "# labels alone undercounts when a class is missing from that split, and\n",
    "# CrossEntropyLoss then fails on out-of-range targets.\n",
    "num_labels = int(torch.cat([train_label_ids, validation_label_ids]).max().item()) + 1\n",
    "# Move the model to `device` before creating the optimizer so its parameters\n",
    "# live where the training loop sends each batch.\n",
    "model = BOWClassifier(num_labels, in_dim).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "if n_gpu > 0 and not no_cuda:\n",
    "    model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Main training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:28 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.32227   0.72360   0.44593       322\n",
      "           1    0.72188   0.32039   0.44380       721\n",
      "\n",
      "    accuracy                        0.44487      1043\n",
      "   macro avg    0.52207   0.52200   0.44487      1043\n",
      "weighted avg    0.59851   0.44487   0.44446      1043\n",
      "\n",
      "Macro-F1:  0.44486852446809977\n",
      "MCC:  0.044067014240337134\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:30 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.54545   0.01863   0.03604       322\n",
      "           1    0.69380   0.99307   0.81689       721\n",
      "\n",
      "    accuracy                        0.69223      1043\n",
      "   macro avg    0.61963   0.50585   0.42646      1043\n",
      "weighted avg    0.64800   0.69223   0.57582      1043\n",
      "\n",
      "Macro-F1:  0.42646068772708823\n",
      "MCC:  0.0529051568355168\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:31 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.39024   0.04969   0.08815       322\n",
      "           1    0.69461   0.96533   0.80789       721\n",
      "\n",
      "    accuracy                        0.68265      1043\n",
      "   macro avg    0.54243   0.50751   0.44802      1043\n",
      "weighted avg    0.60065   0.68265   0.58569      1043\n",
      "\n",
      "Macro-F1:  0.4480237397453669\n",
      "MCC:  0.03569488815055102\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:33 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.44444   0.09938   0.16244       322\n",
      "           1    0.70134   0.94452   0.80496       721\n",
      "\n",
      "    accuracy                        0.68360      1043\n",
      "   macro avg    0.57289   0.52195   0.48370      1043\n",
      "weighted avg    0.62203   0.68360   0.60660      1043\n",
      "\n",
      "Macro-F1:  0.4837005436152212\n",
      "MCC:  0.07999963096499767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:34 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.44706   0.11801   0.18673       322\n",
      "           1    0.70355   0.93481   0.80286       721\n",
      "\n",
      "    accuracy                        0.68265      1043\n",
      "   macro avg    0.57530   0.52641   0.49480      1043\n",
      "weighted avg    0.62436   0.68265   0.61265      1043\n",
      "\n",
      "Macro-F1:  0.4947955156412572\n",
      "MCC:  0.08919578997554263\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:36 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.42574   0.13354   0.20331       322\n",
      "           1    0.70382   0.91956   0.79735       721\n",
      "\n",
      "    accuracy                        0.67689      1043\n",
      "   macro avg    0.56478   0.52655   0.50033      1043\n",
      "weighted avg    0.61797   0.67689   0.61396      1043\n",
      "\n",
      "Macro-F1:  0.5003319359328111\n",
      "MCC:  0.08294222652033967\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:17:37 - INFO - __main__ - ***** Evaluation Interval Hit *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.42342   0.14596   0.21709       322\n",
      "           1    0.70494   0.91123   0.79492       721\n",
      "\n",
      "    accuracy                        0.67498      1043\n",
      "   macro avg    0.56418   0.52860   0.50600      1043\n",
      "weighted avg    0.61803   0.67498   0.61653      1043\n",
      "\n",
      "Macro-F1:  0.5060041997962973\n",
      "MCC:  0.08568412322809724\n",
      "Best Macro-F1:  0.5060041997962973\n",
      "Best MCC:  0.08919578997554263\n"
     ]
    }
   ],
   "source": [
    "def _evaluate(model, dataloader, device):\n",
    "    \"\"\"Predict on every batch of `dataloader`, then print and return the metrics.\n",
    "\n",
    "    Prints sklearn's classification report, macro-F1 and MCC, and returns\n",
    "    (macro_f1, mcc). Leaves `model` in eval mode; callers that keep training\n",
    "    must switch it back with model.train().\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    all_preds = []\n",
    "    all_label_ids = []\n",
    "    with torch.no_grad():\n",
    "        for input_features, label_ids in dataloader:\n",
    "            # Without labels the classifier returns only the logits; their\n",
    "            # argmax is the same as the argmax of their softmax.\n",
    "            logits = model(input_features.to(device))\n",
    "            all_preds.append(logits.argmax(dim=-1).cpu().numpy())\n",
    "            all_label_ids.append(label_ids.cpu().numpy())\n",
    "    all_preds = np.concatenate(all_preds, axis=0)\n",
    "    all_label_ids = np.concatenate(all_label_ids, axis=0)\n",
    "    result_to_save = classification_report(all_label_ids, all_preds, digits=5, output_dict=True)\n",
    "    print(classification_report(all_label_ids, all_preds, digits=5))\n",
    "    f1 = result_to_save[\"macro avg\"][\"f1-score\"]\n",
    "    print(\"Macro-F1: \", f1)\n",
    "    mcc = matthews_corrcoef(all_label_ids, all_preds)\n",
    "    print(\"MCC: \", mcc)\n",
    "    return f1, mcc\n",
    "\n",
    "\n",
    "eval_steps = 500  # evaluate every `eval_steps` optimizer updates, and once at the end\n",
    "global_step = 0\n",
    "best_f1 = -1\n",
    "best_mcc = -1\n",
    "for epoch in range(int(num_train_epochs)):\n",
    "    model.train()\n",
    "    for input_features, label_ids in train_dataloader:\n",
    "        # .to(device) is a no-op when device is the CPU.\n",
    "        input_features = input_features.to(device)\n",
    "        label_ids = label_ids.to(device)\n",
    "\n",
    "        loss, _logits = model(input_features, labels=label_ids)\n",
    "        if n_gpu > 1:\n",
    "            loss = loss.mean()  # mean() to average on multi-gpu.\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.zero_grad()\n",
    "\n",
    "        if global_step % eval_steps == 0:\n",
    "            logger.info(\"***** Evaluation Interval Hit *****\")\n",
    "            f1, mcc = _evaluate(model, validation_dataloader, device)\n",
    "            best_f1 = max(best_f1, f1)\n",
    "            best_mcc = max(best_mcc, mcc)\n",
    "            model.train()  # _evaluate left the model in eval mode\n",
    "        global_step += 1\n",
    "\n",
    "# Also score the final weights unless the last update was already evaluated.\n",
    "if global_step > 0 and (global_step - 1) % eval_steps != 0:\n",
    "    logger.info(\"***** Final Evaluation *****\")\n",
    "    f1, mcc = _evaluate(model, validation_dataloader, device)\n",
    "    best_f1 = max(best_f1, f1)\n",
    "    best_mcc = max(best_mcc, mcc)\n",
    "print(\"Best Macro-F1: \", best_f1)\n",
    "print(\"Best MCC: \", best_mcc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluations with frequency-matched scrambling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:18:31 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****\n",
      "03/29/2021 14:18:31 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****\n",
      "Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-matched/train/cache-f4202775805e0baf.arrow\n",
      "Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-matched/validation/cache-411cd289c16c3d03.arrow\n",
      "03/29/2021 14:18:31 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****\n",
      "03/29/2021 14:18:31 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****\n",
      " 26%|██▌       | 267/1043 [00:00<00:00, 2663.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example sentence: can bill lifted whether and . sushi but includes on . critics the\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1043/1043 [00:00<00:00, 3237.66it/s]\n",
      "/afs/cs.stanford.edu/u/wuzhengx/.local/lib/python3.7/site-packages/ipykernel_launcher.py:48: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "03/29/2021 14:18:31 - INFO - __main__ - ***** Evaluation With Corrupt Data *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.29327   0.18944   0.23019       322\n",
      "           1    0.68743   0.79612   0.73779       721\n",
      "\n",
      "    accuracy                        0.60882      1043\n",
      "   macro avg    0.49035   0.49278   0.48399      1043\n",
      "weighted avg    0.56574   0.60882   0.58108      1043\n",
      "\n",
      "Macro-F1:  0.483988941165058\n",
      "MCC:  -0.016697947067186646\n"
     ]
    }
   ],
   "source": [
    "# Setup logging\n",
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "    level=logging.INFO,\n",
    ")\n",
    "# Evaluate the current `model` on the corrupted (scrambled) validation split.\n",
    "corrupt_method = \"matched\"\n",
    "data_file_name = task_name if task_name != \"sst3\" else \"sst-tenary\"\n",
    "corrupt_datasets = get_dataset(f\"../data-files/{data_file_name}-corrupted-{corrupt_method}\")\n",
    "# Sanity-check the sizes of the corrupted splits that are evaluated below.\n",
    "logger.info(\"***** Train Sample Count (Verify): %s *****\", len(corrupt_datasets[\"train\"]))\n",
    "logger.info(\"***** Valid Sample Count (Verify): %s *****\", len(corrupt_datasets[\"validation\"]))\n",
    "\n",
    "# Bag-of-words count vectors indexed by `original_vocab`; for sentence-pair\n",
    "# tasks one vector per sentence is built and the two are concatenated.\n",
    "corrupt_validation_input_features = []\n",
    "corrupt_validation_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(corrupt_datasets[\"validation\"])):\n",
    "    if sentence2_key is None:\n",
    "        bow_feature = torch.zeros(len(original_vocab))\n",
    "        # Reset per example so an example failing sanity_check_non_empty gives an\n",
    "        # all-zero vector instead of reusing the previous example's text.\n",
    "        sentence_combined = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            sentence_combined = example[sentence1_key]\n",
    "        sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence: \" + sentence_combined)\n",
    "        for t in sentence_tokens:\n",
    "            bow_feature[original_vocab[t]] += 1\n",
    "    else:\n",
    "        bow_feature_1 = torch.zeros(len(original_vocab))\n",
    "        bow_feature_2 = torch.zeros(len(original_vocab))\n",
    "        s1 = \"\"\n",
    "        s2 = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            s1 = example[sentence1_key]\n",
    "        if sanity_check_non_empty(example[sentence2_key]):\n",
    "            s2 = example[sentence2_key]\n",
    "        s1_tokens = modified_basic_tokenizer.tokenize(s1)\n",
    "        s2_tokens = modified_basic_tokenizer.tokenize(s2)\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence 1: \" + s1)\n",
    "            print(\"Example sentence 2: \" + s2)\n",
    "        for t in s1_tokens:\n",
    "            bow_feature_1[original_vocab[t]] += 1\n",
    "        for t in s2_tokens:\n",
    "            bow_feature_2[original_vocab[t]] += 1\n",
    "        bow_feature = torch.cat([bow_feature_1, bow_feature_2], dim=-1)\n",
    "    corrupt_validation_input_features.append(bow_feature)\n",
    "    corrupt_validation_label_ids.append(example[\"label\"])\n",
    "\n",
    "# torch.stack already yields a tensor; .float() casts it without the\n",
    "# torch.tensor(tensor) copy-construct UserWarning.\n",
    "corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0).float()\n",
    "corrupt_validation_label_ids = torch.tensor(corrupt_validation_label_ids, dtype=torch.long)\n",
    "corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)\n",
    "# n_gpu may be 0 on CPU-only runs; clamp so the batch size stays positive.\n",
    "corrupt_validation_dataloader = DataLoader(corrupt_validation_data, batch_size=per_device_eval_batch_size*max(1, n_gpu), shuffle=False)\n",
    "\n",
    "logger.info(\"***** Evaluation With Corrupt Data *****\")\n",
    "model.eval()\n",
    "all_logits = []  # argmax class predictions, one array per batch\n",
    "all_label_ids = []\n",
    "with torch.no_grad():\n",
    "    for step, batch in enumerate(corrupt_validation_dataloader):\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        input_features, label_ids = batch\n",
    "\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            input_features = input_features.to(device)\n",
    "            label_ids = label_ids.to(device)\n",
    "\n",
    "        loss, logits = model(input_features, labels=label_ids)\n",
    "        logits = F.softmax(logits, dim=-1)\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = label_ids.to('cpu').numpy()\n",
    "        outputs = np.argmax(logits, axis=1)\n",
    "        all_logits.append(outputs)\n",
    "        all_label_ids.append(label_ids)\n",
    "\n",
    "all_logits = np.concatenate(all_logits, axis=0)\n",
    "all_label_ids = np.concatenate(all_label_ids, axis=0)\n",
    "result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)\n",
    "print(classification_report(all_label_ids, all_logits, digits=5))\n",
    "print(\"Macro-F1: \", result_to_save[\"macro avg\"][\"f1-score\"])\n",
    "mcc = matthews_corrcoef(all_label_ids, all_logits)\n",
    "print(\"MCC: \", mcc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluations with frequency-unmatched scrambling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/29/2021 14:18:43 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****\n",
      "03/29/2021 14:18:43 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****\n",
      "Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-mismatched/train/cache-452a380612a1acc4.arrow\n",
      "Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-mismatched/validation/cache-83a4d14ad2bf06f0.arrow\n",
      "03/29/2021 14:18:43 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****\n",
      "03/29/2021 14:18:43 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****\n",
      " 59%|█████▊    | 611/1043 [00:00<00:00, 3043.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example sentence: bartlett infinite ? delivered respect outdone 1492 penny nina majestic outdone unpopular ugliest\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1043/1043 [00:00<00:00, 3019.72it/s]\n",
      "/afs/cs.stanford.edu/u/wuzhengx/.local/lib/python3.7/site-packages/ipykernel_launcher.py:48: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "03/29/2021 14:18:43 - INFO - __main__ - ***** Evaluation With Corrupt Data *****\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.38462   0.04658   0.08310       322\n",
      "           1    0.69422   0.96671   0.80812       721\n",
      "\n",
      "    accuracy                        0.68265      1043\n",
      "   macro avg    0.53942   0.50665   0.44561      1043\n",
      "weighted avg    0.59864   0.68265   0.58429      1043\n",
      "\n",
      "Macro-F1:  0.4456092175518889\n",
      "MCC:  0.03237739483038874\n"
     ]
    }
   ],
   "source": [
    "# Setup logging\n",
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "    level=logging.INFO,\n",
    ")\n",
    "# Evaluate the current `model` on the corrupted (scrambled) validation split.\n",
    "corrupt_method = \"mismatched\"\n",
    "data_file_name = task_name if task_name != \"sst3\" else \"sst-tenary\"\n",
    "corrupt_datasets = get_dataset(f\"../data-files/{data_file_name}-corrupted-{corrupt_method}\")\n",
    "# Sanity-check the sizes of the corrupted splits that are evaluated below.\n",
    "logger.info(\"***** Train Sample Count (Verify): %s *****\", len(corrupt_datasets[\"train\"]))\n",
    "logger.info(\"***** Valid Sample Count (Verify): %s *****\", len(corrupt_datasets[\"validation\"]))\n",
    "\n",
    "# Bag-of-words count vectors indexed by `original_vocab`; for sentence-pair\n",
    "# tasks one vector per sentence is built and the two are concatenated.\n",
    "corrupt_validation_input_features = []\n",
    "corrupt_validation_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(corrupt_datasets[\"validation\"])):\n",
    "    if sentence2_key is None:\n",
    "        bow_feature = torch.zeros(len(original_vocab))\n",
    "        # Reset per example so an example failing sanity_check_non_empty gives an\n",
    "        # all-zero vector instead of reusing the previous example's text.\n",
    "        sentence_combined = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            sentence_combined = example[sentence1_key]\n",
    "        sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence: \" + sentence_combined)\n",
    "        for t in sentence_tokens:\n",
    "            bow_feature[original_vocab[t]] += 1\n",
    "    else:\n",
    "        bow_feature_1 = torch.zeros(len(original_vocab))\n",
    "        bow_feature_2 = torch.zeros(len(original_vocab))\n",
    "        s1 = \"\"\n",
    "        s2 = \"\"\n",
    "        if sanity_check_non_empty(example[sentence1_key]):\n",
    "            s1 = example[sentence1_key]\n",
    "        if sanity_check_non_empty(example[sentence2_key]):\n",
    "            s2 = example[sentence2_key]\n",
    "        s1_tokens = modified_basic_tokenizer.tokenize(s1)\n",
    "        s2_tokens = modified_basic_tokenizer.tokenize(s2)\n",
    "        if ex_index % 50000 == 0:\n",
    "            print(\"Example sentence 1: \" + s1)\n",
    "            print(\"Example sentence 2: \" + s2)\n",
    "        for t in s1_tokens:\n",
    "            bow_feature_1[original_vocab[t]] += 1\n",
    "        for t in s2_tokens:\n",
    "            bow_feature_2[original_vocab[t]] += 1\n",
    "        bow_feature = torch.cat([bow_feature_1, bow_feature_2], dim=-1)\n",
    "    corrupt_validation_input_features.append(bow_feature)\n",
    "    corrupt_validation_label_ids.append(example[\"label\"])\n",
    "\n",
    "# torch.stack already yields a tensor; .float() casts it without the\n",
    "# torch.tensor(tensor) copy-construct UserWarning.\n",
    "corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0).float()\n",
    "corrupt_validation_label_ids = torch.tensor(corrupt_validation_label_ids, dtype=torch.long)\n",
    "corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)\n",
    "# n_gpu may be 0 on CPU-only runs; clamp so the batch size stays positive.\n",
    "corrupt_validation_dataloader = DataLoader(corrupt_validation_data, batch_size=per_device_eval_batch_size*max(1, n_gpu), shuffle=False)\n",
    "\n",
    "logger.info(\"***** Evaluation With Corrupt Data *****\")\n",
    "model.eval()\n",
    "all_logits = []  # argmax class predictions, one array per batch\n",
    "all_label_ids = []\n",
    "with torch.no_grad():\n",
    "    for step, batch in enumerate(corrupt_validation_dataloader):\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        input_features, label_ids = batch\n",
    "\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            input_features = input_features.to(device)\n",
    "            label_ids = label_ids.to(device)\n",
    "\n",
    "        loss, logits = model(input_features, labels=label_ids)\n",
    "        logits = F.softmax(logits, dim=-1)\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = label_ids.to('cpu').numpy()\n",
    "        outputs = np.argmax(logits, axis=1)\n",
    "        all_logits.append(outputs)\n",
    "        all_label_ids.append(label_ids)\n",
    "\n",
    "all_logits = np.concatenate(all_logits, axis=0)\n",
    "all_label_ids = np.concatenate(all_label_ids, axis=0)\n",
    "result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)\n",
    "print(classification_report(all_label_ids, all_logits, digits=5))\n",
    "print(\"Macro-F1: \", result_to_save[\"macro avg\"][\"f1-score\"])\n",
    "mcc = matthews_corrcoef(all_label_ids, all_logits)\n",
    "print(\"MCC: \", mcc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Random guessing baseline\n",
    "If we guess the labels at random (stratified by the label distribution), what performance do we get?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.32298   0.32298   0.32298       322\n",
      "           1    0.69764   0.69764   0.69764       721\n",
      "\n",
      "    accuracy                        0.58198      1043\n",
      "   macro avg    0.51031   0.51031   0.51031      1043\n",
      "weighted avg    0.58198   0.58198   0.58198      1043\n",
      "\n",
      "AVG over 100 runs mF1: 0.500884.\n",
      "Standard Deviation of sample is 0.015327163760148876 \n",
      "AVG over 100 runs MCC: 0.002038.\n"
     ]
    }
   ],
   "source": [
    "# Random-guessing baseline: DummyClassifier(strategy=\"stratified\") predicts labels at\n",
    "# random following the class distribution of the labels it is fit on (here the\n",
    "# validation labels). Scores are averaged over `runs` seeded runs.\n",
    "import numpy as np\n",
    "from sklearn.dummy import DummyClassifier\n",
    "\n",
    "mf1s = []\n",
    "mccs = []\n",
    "runs = 100\n",
    "for i in range(runs):\n",
    "    # random_state=i: runs differ from each other, but the whole cell is reproducible.\n",
    "    dummy_clf = DummyClassifier(strategy=\"stratified\", random_state=i)\n",
    "    dummy_clf.fit(validation_input_features, validation_label_ids)\n",
    "    dummy_labels = dummy_clf.predict(validation_input_features)\n",
    "\n",
    "    result_to_save = classification_report(validation_label_ids, dummy_labels, digits=5, output_dict=True)\n",
    "    mf1s.append(result_to_save[\"macro avg\"][\"f1-score\"])\n",
    "    mcc = matthews_corrcoef(validation_label_ids, dummy_labels)\n",
    "    mccs.append(mcc)\n",
    "\n",
    "# Full report for the last run only, as one example of random guessing.\n",
    "print(classification_report(validation_label_ids, dummy_labels, digits=5))\n",
    "print(f\"AVG over {runs} runs mF1: {round(statistics.mean(mf1s), 6)}.\")\n",
    "print(\"Standard Deviation of sample is % s \" % (statistics.stdev(mf1s)))\n",
    "print(f\"AVG over {runs} runs MCC: {round(statistics.mean(mccs), 6)}.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### FrequencyBoW classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# task setups\n",
    "task_name = \"sst3\"\n",
    "num_labels = 3\n",
    "FILENAME_CONFIG = {\n",
    "    \"sst3\" : \"sst-tenary\"\n",
    "}\n",
    "\n",
    "# let us corrupt SST3 in the same way as before\n",
    "def load_tsv_split(split_suffix):\n",
    "    \"\"\"Read `<prefix>/<prefix>-<split_suffix>.tsv` for the current task as a `Dataset`.\n",
    "\n",
    "    `split_suffix` is one of \"train\", \"dev\", \"test\"; `<prefix>` comes from FILENAME_CONFIG.\n",
    "    \"\"\"\n",
    "    file_prefix = FILENAME_CONFIG[task_name]\n",
    "    tsv_path = os.path.join(external_output_dirname, file_prefix, f\"{file_prefix}-{split_suffix}.tsv\")\n",
    "    return Dataset.from_pandas(pd.read_csv(tsv_path, delimiter=\"\\t\"))\n",
    "\n",
    "# NOTE: despite the `_df` suffix, these hold `datasets.Dataset` objects (used by later cells).\n",
    "train_df = load_tsv_split(\"train\")\n",
    "eval_df = load_tsv_split(\"dev\")\n",
    "test_df = load_tsv_split(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_basic_tokenizer = ModifiedBasicTokenizer()\n",
    "# Over the train, dev and test splits, collect (a) all tokens seen under each label\n",
    "# and (b) the raw count of every token. Both maps are rebuilt for each new dataset.\n",
    "label_vocab_map = {}\n",
    "token_frequency_map = {}\n",
    "for split_dataset in (train_df, eval_df, test_df):\n",
    "    for i, example in enumerate(split_dataset):\n",
    "        if i % 10000 == 0 and i != 0:\n",
    "            print(f\"processing #{i} example...\")\n",
    "        original_sentence = example['text']\n",
    "        label = example['label']\n",
    "        if len(original_sentence.strip()) == 0:\n",
    "            continue  # skip blank sentences\n",
    "        tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "        label_vocab_map.setdefault(label, []).extend(tokens)\n",
    "        for t in tokens:\n",
    "            token_frequency_map[t] = token_frequency_map.get(t, 0) + 1\n",
    "# Tokens by descending frequency; sorted() is stable, so ties keep first-seen order.\n",
    "task_token_frequency_map = OrderedDict(\n",
    "    sorted(token_frequency_map.items(), key=operator.itemgetter(1), reverse=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "training BoW with 1st order frequency bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# freq and bucket mappings\n",
    "# Distinct token frequencies observed in the task, ascending.\n",
    "freq_set = sorted(set(task_token_frequency_map.values()))\n",
    "# NOTE(review): these log-spaced edges are not used by the mapping below --\n",
    "# each distinct frequency currently gets its own feature index.\n",
    "bucket_count = 256\n",
    "freq_bucket = np.logspace(math.log(freq_set[0], 10), math.log(freq_set[-1], 10), bucket_count, endpoint=True)\n",
    "freq_bucket = [math.ceil(n) for n in freq_bucket[:-1]]\n",
    "\n",
    "# Map each distinct frequency to a dense index 0..len(freq_set)-1 (ascending frequency).\n",
    "freq_bucket_map = {freq: idx for idx, freq in enumerate(freq_set)}\n",
    "bucket_length = len(freq_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random bucketing: shuffle the task vocabulary and cut it into `bucket_length`\n",
    "# near-equal contiguous chunks; all words in a chunk share one feature index.\n",
    "# NOTE(review): relies on the global `random` state -- seed it upstream for reproducible buckets.\n",
    "freq_count = {}  # frequency -> number of distinct tokens with that frequency\n",
    "for freq in task_token_frequency_map.values():\n",
    "    freq_count[freq] = freq_count.get(freq, 0) + 1\n",
    "vocab = list(task_token_frequency_map.keys())\n",
    "random.shuffle(vocab)\n",
    "bucket_length = 600\n",
    "\n",
    "def split(a, n):\n",
    "    \"\"\"Return a generator of `n` contiguous slices of `a`; sizes differ by at most one, larger first.\"\"\"\n",
    "    base_size, n_larger = divmod(len(a), n)\n",
    "\n",
    "    def _chunks():\n",
    "        start = 0\n",
    "        for chunk_idx in range(n):\n",
    "            stop = start + base_size + (1 if chunk_idx < n_larger else 0)\n",
    "            yield a[start:stop]\n",
    "            start = stop\n",
    "\n",
    "    return _chunks()\n",
    "\n",
    "bucket_vocab_random = split(vocab, bucket_length)\n",
    "random_bucket_vocab_map = {\n",
    "    word: bucket_id\n",
    "    for bucket_id, bucket in enumerate(bucket_vocab_random)\n",
    "    for word in bucket\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FBoW feature vectors for train split\n",
    "# Each example becomes a binary vector over `bucket_length` random buckets:\n",
    "# entry b is 1 if any of its first `max_length` tokens falls into bucket b.\n",
    "train_input_features = []\n",
    "train_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(train_df)):\n",
    "    bow_feature = torch.zeros(bucket_length)\n",
    "    if sentence2_key is None:\n",
    "        sentence_combined = example[sentence1_key]\n",
    "    else:\n",
    "        # NOTE(review): tokens of the \" [SEP] \" marker must exist in random_bucket_vocab_map\n",
    "        # (built from example['text'] only) or the lookup below raises KeyError.\n",
    "        sentence_combined = example[sentence1_key] + \" [SEP] \" + example[sentence2_key]\n",
    "    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "    sentence_tokens = sentence_tokens[:max_length]\n",
    "    for t in sentence_tokens:\n",
    "        # Alternative (frequency bins): bow_feature[freq_bucket_map[token_frequency_map[t]]] = 1\n",
    "        bow_feature[random_bucket_vocab_map[t]] = 1  # presence, not a count\n",
    "    if ex_index % 50000 == 0:\n",
    "        print(\"Example sentence: \" + sentence_combined)\n",
    "        print(bow_feature)\n",
    "    train_input_features.append(bow_feature)\n",
    "    train_label_ids.append(example[\"label\"])\n",
    "\n",
    "# torch.stack already yields a tensor; .float() avoids the torch.tensor(tensor) copy warning.\n",
    "train_input_features = torch.stack(train_input_features, dim=0).float()\n",
    "train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)\n",
    "train_data = TensorDataset(train_input_features, train_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FBoW feature vectors for validation split\n",
    "# Each example becomes a binary vector over `bucket_length` random buckets:\n",
    "# entry b is 1 if any of its first `max_length` tokens falls into bucket b.\n",
    "validation_input_features = []\n",
    "validation_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(eval_df)):\n",
    "    bow_feature = torch.zeros(bucket_length)\n",
    "    if sentence2_key is None:\n",
    "        sentence_combined = example[sentence1_key]\n",
    "    else:\n",
    "        # NOTE(review): tokens of the \" [SEP] \" marker must exist in random_bucket_vocab_map\n",
    "        # (built from example['text'] only) or the lookup below raises KeyError.\n",
    "        sentence_combined = example[sentence1_key] + \" [SEP] \" + example[sentence2_key]\n",
    "    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "    sentence_tokens = sentence_tokens[:max_length]\n",
    "    for t in sentence_tokens:\n",
    "        # Alternative (frequency bins): bow_feature[freq_bucket_map[token_frequency_map[t]]] = 1\n",
    "        bow_feature[random_bucket_vocab_map[t]] = 1  # presence, not a count\n",
    "    validation_input_features.append(bow_feature)\n",
    "    validation_label_ids.append(example[\"label\"])\n",
    "\n",
    "# torch.stack already yields a tensor; .float() avoids the torch.tensor(tensor) copy warning.\n",
    "validation_input_features = torch.stack(validation_input_features, dim=0).float()\n",
    "validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)\n",
    "validation_data = TensorDataset(validation_input_features, validation_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data loader\n",
    "# n_gpu may be 0 on CPU-only runs; clamp so the batch sizes stay positive.\n",
    "effective_n_gpu = max(1, n_gpu)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=per_device_train_batch_size*effective_n_gpu)\n",
    "validation_dataloader = DataLoader(validation_data, batch_size=per_device_eval_batch_size*effective_n_gpu, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some overriding fun stuffs!\n",
    "lr = 1e-3\n",
    "num_train_epochs = 20\n",
    "# Size the output layer from the task's declared label count (set in the task-setup\n",
    "# cell) rather than from labels that happen to occur in the validation split.\n",
    "# NOTE(review): assumes BOWClassifier's first argument is the number of classes.\n",
    "model = BOWClassifier(num_labels, bucket_length)\n",
    "if n_gpu > 0 and not no_cuda:\n",
    "    # DataParallel requires the module's parameters to live on device_ids[0].\n",
    "    model = torch.nn.DataParallel(model.to(device))\n",
    "# Create the optimizer after the device move / wrapping.\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global_step = 0\n",
    "max_score = -1  # best validation macro-F1 seen so far\n",
    "for epoch in range(int(num_train_epochs)):\n",
    "    model.train()\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        input_features, label_ids = batch\n",
    "\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            input_features = input_features.to(device)\n",
    "            label_ids = label_ids.to(device)\n",
    "\n",
    "        loss, _train_logits = model(input_features, labels=label_ids)\n",
    "\n",
    "        if n_gpu > 1:\n",
    "            loss = loss.mean() # mean() to average on multi-gpu.\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        model.zero_grad()\n",
    "\n",
    "        # Validate after the 1st optimizer step and every 500 steps thereafter.\n",
    "        if global_step % 500 == 0:\n",
    "            model.eval()\n",
    "            all_logits = []  # argmax class predictions, one array per batch\n",
    "            all_label_ids = []\n",
    "            with torch.no_grad():\n",
    "                # Separate names so the training-loop variables are not clobbered.\n",
    "                for eval_batch in validation_dataloader:\n",
    "                    if torch.cuda.is_available() and not no_cuda:\n",
    "                        torch.cuda.empty_cache()\n",
    "\n",
    "                    eval_features, eval_label_ids = eval_batch\n",
    "\n",
    "                    if torch.cuda.is_available() and not no_cuda:\n",
    "                        eval_features = eval_features.to(device)\n",
    "                        eval_label_ids = eval_label_ids.to(device)\n",
    "\n",
    "                    _eval_loss, logits = model(eval_features, labels=eval_label_ids)\n",
    "                    logits = F.softmax(logits, dim=-1)\n",
    "                    logits = logits.detach().cpu().numpy()\n",
    "                    all_logits.append(np.argmax(logits, axis=1))\n",
    "                    all_label_ids.append(eval_label_ids.to('cpu').numpy())\n",
    "\n",
    "            all_logits = np.concatenate(all_logits, axis=0)\n",
    "            all_label_ids = np.concatenate(all_label_ids, axis=0)\n",
    "            result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)\n",
    "            macro_f1 = result_to_save[\"macro avg\"][\"f1-score\"]\n",
    "            print(\"Macro-F1: \", macro_f1)\n",
    "            max_score = max(max_score, macro_f1)\n",
    "            # Evaluation happens mid-epoch: restore training mode for the remaining steps.\n",
    "            model.train()\n",
    "\n",
    "        global_step += 1\n",
    "print(\"Best Macro-F1: \", max_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "training BoW with 1st and 2nd order frequency bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# repartition the first order information\n",
    "# Distinct token frequencies observed in the task, ascending.\n",
    "second_order_freq_set = sorted(set(task_token_frequency_map.values()))\n",
    "temp_bucket_count = 24\n",
    "# temp_bucket_count+1 log-spaced points from the min to the max frequency; the last\n",
    "# point is dropped, leaving temp_bucket_count upper edges rounded up to integers.\n",
    "second_order_freq_bucket = np.logspace(\n",
    "    math.log(second_order_freq_set[0], 10),\n",
    "    math.log(second_order_freq_set[-1], 10),\n",
    "    temp_bucket_count + 1,\n",
    "    endpoint=True,\n",
    ")\n",
    "second_order_freq_bucket = [math.ceil(edge) for edge in second_order_freq_bucket[:-1]]\n",
    "\n",
    "def find_bucket_number(freq, freq_bucket):\n",
    "    \"\"\"Return the 1-based position of the first edge in `freq_bucket` that is >= `freq`.\n",
    "\n",
    "    A `freq` larger than every edge is assigned len(freq_bucket), i.e. it shares the\n",
    "    last bucket with frequencies up to the final kept edge.\n",
    "    \"\"\"\n",
    "    for bucket_num, upper_edge in enumerate(freq_bucket, start=1):\n",
    "        if freq <= upper_edge:\n",
    "            return bucket_num\n",
    "    return len(freq_bucket)\n",
    "\n",
    "# Distinct frequency -> bucket number in 1..temp_bucket_count.\n",
    "second_order_freq_bucket_map = {\n",
    "    freq: find_bucket_number(freq, second_order_freq_bucket)\n",
    "    for freq in second_order_freq_set\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_basic_tokenizer = ModifiedBasicTokenizer()\n",
    "# Counts of (bucket_a, bucket_b) pairs, bucket_a <= bucket_b, over all splits.\n",
    "# Rebuilt from scratch whenever this cell runs (re-run it for a new dataset).\n",
    "token_freq_freq_map = {}\n",
    "\n",
    "\n",
    "def _count_bucket_pairs(split, counts):\n",
    "    \"\"\"Accumulate bucket-pair co-occurrence counts for one split into `counts`.\n",
    "\n",
    "    Every non-blank example['text'] is tokenised. For each token-position pair\n",
    "    (i < j), the two tokens' buckets are looked up via token_frequency_map and\n",
    "    second_order_freq_bucket_map, sorted into a tuple, and counted in `counts`.\n",
    "    \"\"\"\n",
    "    for ex_index, example in enumerate(split):\n",
    "        if ex_index % 10000 == 0 and ex_index != 0:\n",
    "            print(f\"processing #{ex_index} example...\")\n",
    "        original_sentence = example['text']\n",
    "        if len(original_sentence.strip()) == 0:\n",
    "            continue\n",
    "        tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "        if len(tokens) < 2:\n",
    "            continue  # no token pairs to count\n",
    "        # Look each token's bucket up once instead of once per pair it appears in.\n",
    "        buckets = [second_order_freq_bucket_map[token_frequency_map[t]] for t in tokens]\n",
    "        for i in range(len(buckets) - 1):\n",
    "            for j in range(i + 1, len(buckets)):\n",
    "                index_tuple = tuple(sorted((buckets[i], buckets[j])))\n",
    "                counts[index_tuple] = counts.get(index_tuple, 0) + 1\n",
    "\n",
    "\n",
    "# Splits are processed train -> eval -> test; this fixes the dict insertion order\n",
    "# and therefore the order of equal counts after the stable sort below.\n",
    "for split_df in (train_df, eval_df, test_df):\n",
    "    _count_bucket_pairs(split_df, token_freq_freq_map)\n",
    "\n",
    "# Order bucket pairs by descending co-occurrence count.\n",
    "task_token_freq_freq_map = sorted(token_freq_freq_map.items(), key=operator.itemgetter(1), reverse=True)\n",
    "task_token_freq_freq_map = OrderedDict(task_token_freq_freq_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-bucket the second-order statistics: every distinct bucket-pair count in\n",
    "# task_token_freq_freq_map becomes its own feature index (0-based, ascending by count).\n",
    "second_order_freq_freq_set = sorted(set(task_token_freq_freq_map.values()))\n",
    "second_order_freq_freq_bucket_map = {\n",
    "    freq: bucket_idx for bucket_idx, freq in enumerate(second_order_freq_freq_set)\n",
    "}\n",
    "\n",
    "# Width of the second-order BoW vector: one slot per distinct pair count.\n",
    "bucket_length = len(second_order_freq_freq_set)\n",
    "# The second-order buckets are ready; the following cells build second-order BoW vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second-order FBoW feature vectors for the train split. Dimension k counts the\n",
    "# token-position pairs whose bucket-pair count maps to second-order bucket k\n",
    "# (no first-order features are included).\n",
    "train_input_features = []\n",
    "train_label_ids = []\n",
    "for (ex_index, example) in enumerate(tqdm(train_df)):\n",
    "    if sentence2_key is None:\n",
    "        sentence_combined = example[sentence1_key]\n",
    "    else:\n",
    "        sentence_combined = example[sentence1_key] + \" [SEP] \" + example[sentence2_key]\n",
    "    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "    sentence_tokens = sentence_tokens[:max_length]\n",
    "    # Collect the second-order bucket of every pair (i < j), then count them with one\n",
    "    # bincount instead of one Python-level tensor write per pair.\n",
    "    pair_buckets = []\n",
    "    for i in range(len(sentence_tokens)-1):\n",
    "        for j in range(i+1, len(sentence_tokens)):\n",
    "            t1 = sentence_tokens[i]\n",
    "            t2 = sentence_tokens[j]\n",
    "            index_tuple = tuple(sorted((second_order_freq_bucket_map[token_frequency_map[t1]],\n",
    "                                        second_order_freq_bucket_map[token_frequency_map[t2]])))\n",
    "            pair_buckets.append(second_order_freq_freq_bucket_map[task_token_freq_freq_map[index_tuple]])\n",
    "    bow_feature = torch.bincount(torch.tensor(pair_buckets, dtype=torch.long),\n",
    "                                 minlength=bucket_length).float()\n",
    "\n",
    "    if ex_index % 50000 == 0:\n",
    "        print(\"Example sentence: \" + sentence_combined)\n",
    "        print(bow_feature)\n",
    "    train_input_features.append(bow_feature)\n",
    "    train_label_ids.append(example[\"label\"])\n",
    "\n",
    "# Rows are already float32, so .float() is a no-op rather than the warning-emitting\n",
    "# copy that torch.tensor(<tensor>) performs.\n",
    "train_input_features = torch.stack(train_input_features, dim=0).float()\n",
    "train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)\n",
    "train_data = TensorDataset(train_input_features, train_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second-order FBoW feature vectors for the validation split, built exactly like the\n",
    "# train split: one dimension per second-order bucket, no first-order part.\n",
    "validation_input_features = []\n",
    "validation_label_ids = []\n",
    "for example in tqdm(eval_df):\n",
    "    if sentence2_key is None:\n",
    "        sentence_combined = example[sentence1_key]\n",
    "    else:\n",
    "        sentence_combined = example[sentence1_key] + \" [SEP] \" + example[sentence2_key]\n",
    "    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)\n",
    "    sentence_tokens = sentence_tokens[:max_length]\n",
    "    # Collect the second-order bucket of every pair (i < j), then count them with one\n",
    "    # bincount instead of one Python-level tensor write per pair.\n",
    "    pair_buckets = []\n",
    "    for i in range(len(sentence_tokens)-1):\n",
    "        for j in range(i+1, len(sentence_tokens)):\n",
    "            t1 = sentence_tokens[i]\n",
    "            t2 = sentence_tokens[j]\n",
    "            index_tuple = tuple(sorted((second_order_freq_bucket_map[token_frequency_map[t1]],\n",
    "                                        second_order_freq_bucket_map[token_frequency_map[t2]])))\n",
    "            pair_buckets.append(second_order_freq_freq_bucket_map[task_token_freq_freq_map[index_tuple]])\n",
    "    bow_feature = torch.bincount(torch.tensor(pair_buckets, dtype=torch.long),\n",
    "                                 minlength=bucket_length).float()\n",
    "\n",
    "    validation_input_features.append(bow_feature)\n",
    "    validation_label_ids.append(example[\"label\"])\n",
    "\n",
    "# Rows are already float32, so .float() is a no-op rather than the warning-emitting\n",
    "# copy that torch.tensor(<tensor>) performs.\n",
    "validation_input_features = torch.stack(validation_input_features, dim=0).float()\n",
    "validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)\n",
    "validation_data = TensorDataset(validation_input_features, validation_label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data loader\n",
    "# n_gpu can be 0 (CPU-only runs; the model cell guards on n_gpu > 0), and DataLoader\n",
    "# rejects batch_size=0, so scale the per-device batch size by at least one device.\n",
    "n_batch_devices = max(1, n_gpu)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=per_device_train_batch_size*n_batch_devices)\n",
    "validation_dataloader = DataLoader(validation_data, batch_size=per_device_eval_batch_size*n_batch_devices, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# restart the model\n",
    "# Size the output layer from the labels of BOTH splits: a class that is absent from\n",
    "# the validation split would otherwise get no output unit, leaving train labels of\n",
    "# that class out of range. Assumes integer class ids 0..K-1.\n",
    "num_labels = int(torch.cat((train_label_ids, validation_label_ids)).max()) + 1\n",
    "model = BOWClassifier(num_labels, bucket_length)\n",
    "# The training loop moves every batch to `device` under this same condition, so the\n",
    "# weights must live there too (DataParallel also expects them on device_ids[0]).\n",
    "# Move before building the optimizer so it holds the on-device parameters.\n",
    "if torch.cuda.is_available() and not no_cuda:\n",
    "    model.to(device)\n",
    "lr = 1e-3\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "if n_gpu > 0 and not no_cuda:\n",
    "    model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_macro_f1(model, dataloader):\n",
    "    \"\"\"Return the macro-averaged F1 (sklearn classification_report) of `model` on `dataloader`.\n",
    "\n",
    "    Puts the model in eval mode and runs under torch.no_grad(); the caller must\n",
    "    switch it back to train mode before continuing training.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    all_preds = []\n",
    "    all_label_ids = []\n",
    "    with torch.no_grad():\n",
    "        for input_features, label_ids in dataloader:\n",
    "            if torch.cuda.is_available() and not no_cuda:\n",
    "                torch.cuda.empty_cache()\n",
    "                input_features = input_features.to(device)\n",
    "                label_ids = label_ids.to(device)\n",
    "            _, logits = model(input_features, labels=label_ids)\n",
    "            # softmax is monotonic, so the argmax of the raw logits is the prediction.\n",
    "            all_preds.append(logits.argmax(dim=-1).cpu().numpy())\n",
    "            all_label_ids.append(label_ids.cpu().numpy())\n",
    "    all_preds = np.concatenate(all_preds, axis=0)\n",
    "    all_label_ids = np.concatenate(all_label_ids, axis=0)\n",
    "    report = classification_report(all_label_ids, all_preds, digits=5, output_dict=True)\n",
    "    return report[\"macro avg\"][\"f1-score\"]\n",
    "\n",
    "\n",
    "global_step = 0\n",
    "num_train_epochs = 20\n",
    "max_score = -1\n",
    "for _epoch in range(int(num_train_epochs)):\n",
    "    model.train()\n",
    "    for input_features, label_ids in train_dataloader:\n",
    "        if torch.cuda.is_available() and not no_cuda:\n",
    "            torch.cuda.empty_cache()\n",
    "            input_features = input_features.to(device)\n",
    "            label_ids = label_ids.to(device)\n",
    "\n",
    "        loss, _ = model(input_features, labels=label_ids)\n",
    "        if n_gpu > 1:\n",
    "            loss = loss.mean() # mean() to average on multi-gpu.\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.zero_grad()\n",
    "\n",
    "        if global_step % 500 == 0:\n",
    "            score = evaluate_macro_f1(model, validation_dataloader)\n",
    "            print(\"Macro-F1: \", score)\n",
    "            if score > max_score:\n",
    "                max_score = score\n",
    "            # evaluate_macro_f1 left the model in eval mode; without this the rest of\n",
    "            # the epoch would train with any train-only layers (e.g. dropout) disabled.\n",
    "            model.train()\n",
    "\n",
    "        global_step += 1\n",
    "print(\"Best Macro-F1: \", max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
