{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "\n",
    "import nltk\n",
    "import sklearn\n",
    "import scipy.stats\n",
    "from sklearn.metrics import make_scorer\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "\n",
    "import sklearn_crfsuite\n",
    "from sklearn_crfsuite import scorers\n",
    "from sklearn_crfsuite import metrics\n",
    "\n",
    "# imports\n",
    "from vocab_mismatch_utils import *\n",
    "from data_formatter_utils import *\n",
    "from datasets import DatasetDict\n",
    "from datasets import Dataset\n",
    "from datasets import list_datasets\n",
    "from datasets import load_dataset, load_metric\n",
    "import transformers\n",
    "import pandas as pd\n",
    "import operator\n",
    "from collections import OrderedDict\n",
    "from tqdm import tqdm, trange\n",
    "from seqeval.metrics import sequence_labeling\n",
    "\n",
    "import collections\n",
    "import os\n",
    "import unicodedata\n",
    "from typing import List, Optional, Tuple\n",
    "\n",
    "from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace\n",
    "from transformers.utils import logging\n",
    "import torch\n",
    "logger = logging.get_logger(__name__)\n",
    "import numpy as np\n",
    "import copy\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "lemmatizer = WordNetLemmatizer() \n",
    "from word_forms.word_forms import get_word_forms\n",
    "from functools import partial\n",
    "import matplotlib.ticker as mticker\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.dummy import DummyClassifier\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"font.family\"] = \"DejaVu Serif\"\n",
    "font = {'family' : 'DejaVu Serif',\n",
    "        'size'   : 20}\n",
    "plt.rc('font', **font)\n",
    "\n",
    "# Task name -> directory name joined onto `external_output_dirname` when the\n",
    "# task is loaded from local TSV files.\n",
    "FILENAME_CONFIG = {\n",
    "    \"sst3\" : \"sst-tenary\",\n",
    "    \"cola\" : \"cola\",\n",
    "    \"mnli\" : \"mnli\",\n",
    "    \"snli\" : \"snli\",\n",
    "    \"mrpc\" : \"mrpc\",\n",
    "    \"qnli\" : \"qnli\",\n",
    "    \"conll2003\" : \"conll2003\",\n",
    "    \"en_ewt\" : \"en_ewt\"\n",
    "}\n",
    "# Task name -> (first input column, second input column or None).\n",
    "TASK_CONFIG = {\n",
    "    \"wiki-text\": (\"text\", None),\n",
    "    \"sst3\": (\"text\", None),\n",
    "    \"cola\": (\"sentence\", None),\n",
    "    \"mnli\": (\"premise\", \"hypothesis\"),\n",
    "    \"snli\": (\"premise\", \"hypothesis\"),\n",
    "    \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
    "    \"qnli\": (\"question\", \"sentence\"),\n",
    "    \"conll2003\" : (\"tokens\", None),\n",
    "    \"en_ewt\" : (\"tokens\", None)\n",
    "}\n",
    "# Token-classification task name -> name of the column holding the tags.\n",
    "TAG_CONFIG = {\n",
    "    \"conll2003\" : \"ner_tags\",\n",
    "    \"en_ewt\" : \"upos\"\n",
    "}\n",
    "\n",
    "# Directory passed as `cache_dir` to `load_dataset`.\n",
    "cache_dir = \"../tmp/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Random Classifier for Token Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task selection and reproducibility setup.\n",
    "# `random` is imported explicitly here: no cell imports it by name, so it was\n",
    "# presumably only available through the wildcard imports of the first cell.\n",
    "import random\n",
    "\n",
    "# Token-classification task to run; later cells index TAG_CONFIG with it.\n",
    "task_name = \"conll2003\"\n",
    "\n",
    "# One seed shared by torch, numpy and the stdlib RNG.\n",
    "# WARNING: results depend on the seed; try a few different seeds.\n",
    "seed = 8\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset conll2003 (../tmp/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)\n"
     ]
    }
   ],
   "source": [
    "# Load the train / validation / test splits of the selected task.\n",
    "if task_name in (\"conll2003\", \"en_ewt\"):\n",
    "    # Datasets fetched through `load_dataset`, cached under `cache_dir`.\n",
    "    if task_name == \"conll2003\":\n",
    "        dataset = load_dataset(\"conll2003\", cache_dir=cache_dir)\n",
    "    else:\n",
    "        dataset = load_dataset(\"universal_dependencies\", \"en_ewt\", cache_dir=cache_dir)\n",
    "    train_df, eval_df, test_df = (\n",
    "        dataset[split_name] for split_name in (\"train\", \"validation\", \"test\")\n",
    "    )\n",
    "else:\n",
    "    # Any other task is read from local TSV files and wrapped as Datasets.\n",
    "    # NOTE(review): `external_output_dirname` is not defined in the visible cells;\n",
    "    # presumably it comes from one of the wildcard imports -- confirm.\n",
    "    task_dir = os.path.join(external_output_dirname, FILENAME_CONFIG[task_name])\n",
    "    train_df, eval_df, test_df = (\n",
    "        Dataset.from_pandas(pd.read_csv(os.path.join(task_dir, tsv_name), delimiter=\"\\t\"))\n",
    "        for tsv_name in (\"train.tsv\", \"dev.tsv\", \"test.tsv\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if task_name == \"conll2003\":\n",
    "    from datasets import ClassLabel\n",
    "\n",
    "    def get_label_list(labels):\n",
    "        \"\"\"Return the sorted list of distinct labels found in `labels`.\n",
    "\n",
    "        `labels` is an iterable of per-sentence label sequences.\n",
    "        \"\"\"\n",
    "        unique_labels = set()\n",
    "        for sentence_labels in labels:\n",
    "            unique_labels.update(sentence_labels)\n",
    "        return sorted(unique_labels)\n",
    "\n",
    "    features = train_df.features\n",
    "    label_column_name = TAG_CONFIG[task_name]\n",
    "    if isinstance(features[label_column_name].feature, ClassLabel):\n",
    "        label_list = features[label_column_name].feature.names\n",
    "        # No need to convert the labels since they are already ints.\n",
    "        label_to_id = {i: i for i in range(len(label_list))}\n",
    "    else:\n",
    "        # Fixed: this branch read the undefined name `inoculation_train_df`;\n",
    "        # the labels are collected from the training split instead.\n",
    "        label_list = get_label_list(train_df[label_column_name])\n",
    "        label_to_id = {l: i for i, l in enumerate(label_list)}\n",
    "    num_labels = len(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n"
     ]
    }
   ],
   "source": [
    "# Show the tag names; a tag's position in this list is its label id.\n",
    "if task_name == \"conll2003\":\n",
    "    print(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the splits by name, then report how many examples each one holds.\n",
    "datasets = {\n",
    "    \"train\": train_df,\n",
    "    \"validation\": eval_df,\n",
    "    \"test\": test_df,\n",
    "}\n",
    "\n",
    "print(\"**** Dataset Statistics ****\")\n",
    "for split_label, split_name in (\n",
    "    (\"training\", \"train\"),\n",
    "    (\"validation\", \"validation\"),\n",
    "    (\"testing\", \"test\"),\n",
    "):\n",
    "    print(f\"{split_label} example = {len(datasets[split_name])}\")\n",
    "print(\"****************************\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the test example at index 1.\n",
    "test_df[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline classifier using the \"stratified\" strategy (a dummy classifier just with the label prob).\n",
    "dummy_clf = DummyClassifier(strategy=\"stratified\")\n",
    "\n",
    "# Flatten the tags of every training sentence into one token-level label list.\n",
    "all_labels_train = [\n",
    "    tag\n",
    "    for row_idx in range(len(train_df))\n",
    "    for tag in train_df[row_idx][TAG_CONFIG[task_name]]\n",
    "]\n",
    "\n",
    "# Constant placeholder inputs, one entry per label.\n",
    "mock_x = [0 for _ in all_labels_train]\n",
    "dummy_clf.fit(mock_x, all_labels_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw dummy predictions for each test sentence and collect gold / predicted tags.\n",
    "tag_column = TAG_CONFIG[task_name]\n",
    "actual_labels, predicted_labels = [], []\n",
    "for row_idx in range(len(test_df)):\n",
    "    gold_tags = test_df[row_idx][tag_column]\n",
    "    # The gold tags themselves are passed as the input of predict().\n",
    "    dummy_labels = dummy_clf.predict(gold_tags)\n",
    "    sampled_tags = dummy_labels.tolist()\n",
    "    if task_name != \"conll2003\":\n",
    "        # Other tasks: one flat list of stringified labels.\n",
    "        actual_labels.extend(str(tag) for tag in gold_tags)\n",
    "        predicted_labels.extend(str(tag) for tag in sampled_tags)\n",
    "    else:\n",
    "        # need to do ner special handlings: one list of tag names per sentence\n",
    "        actual_labels.append([label_list[label_to_id[tag]] for tag in gold_tags])\n",
    "        predicted_labels.append([label_list[label_to_id[tag]] for tag in sampled_tags])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'LOC': {'precision': 0.02319451765946231, 'recall': 0.026378896882494004, 'f1': 0.02468443197755961, 'number': 1668}, 'MISC': {'precision': 0.009930486593843098, 'recall': 0.014245014245014245, 'f1': 0.011702750146284378, 'number': 702}, 'ORG': {'precision': 0.018469656992084433, 'recall': 0.025285972305839857, 'f1': 0.021346886912325287, 'number': 1661}, 'PER': {'precision': 0.015561015561015561, 'recall': 0.02350030921459493, 'f1': 0.018723823601872382, 'number': 1617}, 'overall_precision': 0.01758530183727034, 'overall_recall': 0.023725212464589234, 'overall_f1': 0.02019897497738921, 'overall_accuracy': 0.6927533110800043}\n"
     ]
    }
   ],
   "source": [
    "# Score the baseline: seqeval for conll2003, a classification report otherwise.\n",
    "if task_name != \"conll2003\":\n",
    "    from sklearn.metrics import classification_report\n",
    "    report = classification_report(actual_labels, predicted_labels, digits=5)\n",
    "    print(report)\n",
    "else:\n",
    "    seqeval = load_metric(\"seqeval\")\n",
    "    results = seqeval.compute(references=actual_labels, predictions=predicted_labels)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CRF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task name -> (first input column, second input column or None).\n",
    "TASK_CONFIG = {\n",
    "    \"sst3\": (\"text\", None),\n",
    "    \"cola\": (\"sentence\", None),\n",
    "    \"mnli\": (\"premise\", \"hypothesis\"),\n",
    "    \"snli\": (\"premise\", \"hypothesis\"),\n",
    "    \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
    "    \"qnli\": (\"question\", \"sentence\"),\n",
    "    \"conll2003\" : (\"tokens\", None),\n",
    "    \"en_ewt\" : (\"tokens\", None)\n",
    "}\n",
    "# WARNING: you don't need the BERT tokenizer\n",
    "# original_vocab = load_bert_vocab(\"../data-files/bert_vocab.txt\")\n",
    "# original_tokenizer = transformers.BertTokenizer(\n",
    "#     vocab_file=\"../data-files/bert_vocab.txt\")\n",
    "# Just use some basic white space tokenizer here!\n",
    "modified_basic_tokenizer = ModifiedBasicTokenizer()\n",
    "max_length = 128\n",
    "per_device_train_batch_size = 128\n",
    "per_device_eval_batch_size = 128\n",
    "no_cuda = True\n",
    "# Use CUDA only when it is available and not disabled through `no_cuda`.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() and not no_cuda else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count() if not no_cuda else 1 # 1 means just on cpu\n",
    "seed = 42\n",
    "lr = 1e-3\n",
    "num_train_epochs = 10\n",
    "sentence1_key, sentence2_key = TASK_CONFIG[task_name]\n",
    "\n",
    "# Re-seed every RNG with the seed chosen in this cell.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if n_gpu > 0 and not no_cuda:\n",
    "    # Fixed: this used `args.seed`, but no `args` object is defined in the\n",
    "    # visible cells; the seed lives in the `seed` variable above.\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "# get the vocab i think?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14041/14041 [00:01<00:00, 10155.58it/s]\n",
      "100%|██████████| 3250/3250 [00:00<00:00, 9136.55it/s]\n",
      "100%|██████████| 3453/3453 [00:00<00:00, 10833.27it/s]\n"
     ]
    }
   ],
   "source": [
    "def sanity_check_non_empty(sentece):\n",
    "    \"\"\"Return True when `sentece` is neither None nor empty.\"\"\"\n",
    "    return sentece is not None and len(sentece) != 0\n",
    "\n",
    "\n",
    "def _add_split_to_vocab(split_name):\n",
    "    \"\"\"Add every unseen token of `datasets[split_name]` to `original_vocab`.\n",
    "\n",
    "    Tokens receive consecutive ids in order of first appearance. A split that\n",
    "    is absent from `datasets` and examples with no tokens are skipped.\n",
    "    \"\"\"\n",
    "    if split_name not in datasets:\n",
    "        return\n",
    "    for example in tqdm(datasets[split_name]):\n",
    "        sentence_tokens = example[sentence1_key]\n",
    "        if not sanity_check_non_empty(sentence_tokens):\n",
    "            # Fixed: an empty example used to silently re-use the tokens of\n",
    "            # the preceding example (or raise NameError on the very first one).\n",
    "            continue\n",
    "        for token in sentence_tokens:\n",
    "            if token not in original_vocab:\n",
    "                # The current size of the vocab is the next free id.\n",
    "                original_vocab[token] = len(original_vocab)\n",
    "\n",
    "\n",
    "if sentence2_key is not None:\n",
    "    # The sentence-pair branch was a bare `pass` that fell through to an unset\n",
    "    # variable; fail with an explicit message instead.\n",
    "    raise NotImplementedError(\n",
    "        f\"vocabulary building only supports single-sequence tasks, got {task_name!r}\"\n",
    "    )\n",
    "\n",
    "# create the vocab file\n",
    "original_vocab = OrderedDict()\n",
    "train_data_only = False\n",
    "split_names = [\"train\"] if train_data_only else [\"train\", \"validation\", \"test\"]\n",
    "for split_name in split_names:\n",
    "    _add_split_to_vocab(split_name)\n",
    "# Kept for backward compatibility: the next unused vocabulary id.\n",
    "vocab_index = len(original_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "alpha = list(string.ascii_lowercase) + list(string.ascii_uppercase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14041/14041 [00:03<00:00, 4245.50it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build per-token CRF features (one count per ASCII letter) and string labels\n",
    "# for the training split.\n",
    "train_features, train_labels = [], []\n",
    "for example in tqdm(datasets[\"train\"]):\n",
    "    sentence_features = []\n",
    "    for token in example[\"tokens\"]:\n",
    "        # Start every letter at zero so all tokens share the same feature keys.\n",
    "        letter_counts = dict.fromkeys(alpha, 0)\n",
    "        for ch in token:\n",
    "            if ch in letter_counts:\n",
    "                letter_counts[ch] += 1\n",
    "        sentence_features.append(letter_counts)\n",
    "    train_features.append(sentence_features)\n",
    "    # Labels are stringified before being handed to the CRF.\n",
    "    train_labels.append(list(map(str, example[TAG_CONFIG[task_name]])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/afs/cs.stanford.edu/u/wuzhengx/.local/lib/python3.7/site-packages/sklearn/base.py:213: FutureWarning: From version 0.24, get_params will raise an AttributeError if a parameter cannot be retrieved as an instance attribute. Previously it would return None.\n",
      "  FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CRF(algorithm='lbfgs', all_possible_transitions=True, c1=0.1, c2=0.1,\n",
       "    keep_tempfiles=None, max_iterations=100)"
      ]
     },
     "execution_count": 191,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Settings passed to the CRF tagger.\n",
    "crf_params = dict(\n",
    "    algorithm=\"lbfgs\",\n",
    "    c1=0.1,\n",
    "    c2=0.1,\n",
    "    max_iterations=100,\n",
    "    all_possible_transitions=True,\n",
    ")\n",
    "crf = sklearn_crfsuite.CRF(**crf_params)\n",
    "\n",
    "# Train on the letter-count features; fit() is the last expression of the\n",
    "# cell, so its return value is displayed.\n",
    "crf.fit(train_features, train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['3', '0', '7', '1', '2', '5', '4', '8', '6']"
      ]
     },
     "execution_count": 192,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labels known to the fitted CRF; they are the stringified tag ids used for training.\n",
    "labels = list(crf.classes_)\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3453/3453 [00:00<00:00, 4240.01it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build per-token CRF features (one count per ASCII letter) and string labels\n",
    "# for the test split.\n",
    "test_features, test_labels = [], []\n",
    "for example in tqdm(datasets[\"test\"]):\n",
    "    sentence_features = []\n",
    "    for token in example[\"tokens\"]:\n",
    "        # Start every letter at zero so all tokens share the same feature keys.\n",
    "        letter_counts = dict.fromkeys(alpha, 0)\n",
    "        for ch in token:\n",
    "            if ch in letter_counts:\n",
    "                letter_counts[ch] += 1\n",
    "        sentence_features.append(letter_counts)\n",
    "    test_features.append(sentence_features)\n",
    "    # Labels are stringified, matching the format of the training labels.\n",
    "    test_labels.append(list(map(str, example[TAG_CONFIG[task_name]])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'LOC': {'precision': 0.3340471092077088, 'recall': 0.2805755395683453, 'f1': 0.30498533724340177, 'number': 1668}, 'MISC': {'precision': 0.13043478260869565, 'recall': 0.038461538461538464, 'f1': 0.05940594059405941, 'number': 702}, 'ORG': {'precision': 0.28093023255813954, 'recall': 0.18181818181818182, 'f1': 0.22076023391812868, 'number': 1661}, 'PER': {'precision': 0.43007518796992483, 'recall': 0.35374149659863946, 'f1': 0.38819138106549034, 'number': 1617}, 'overall_precision': 0.34114129080488415, 'overall_recall': 0.24238668555240794, 'overall_f1': 0.2834075147500259, 'overall_accuracy': 0.8360288575428018}\n"
     ]
    }
   ],
   "source": [
    "y_pred = crf.predict(test_features)\n",
    "if task_name != \"conll2003\":\n",
    "    # Micro-averaged flat F1 over the labels known to the CRF.\n",
    "    single_acc = metrics.flat_f1_score(\n",
    "        test_labels, y_pred, labels=labels, average=\"micro\"\n",
    "    )\n",
    "    print(single_acc)\n",
    "else:\n",
    "    # Map the stringified tag ids back to tag names, sentence by sentence.\n",
    "    actual_labels = []\n",
    "    predicted_labels = []\n",
    "    for sent_idx, pred_ids in enumerate(y_pred):\n",
    "        gold_names = [label_list[label_to_id[int(tag_id)]] for tag_id in test_labels[sent_idx]]\n",
    "        pred_names = [label_list[label_to_id[int(tag_id)]] for tag_id in pred_ids]\n",
    "        assert len(gold_names) == len(pred_names)\n",
    "        actual_labels.append(gold_names)\n",
    "        predicted_labels.append(pred_names)\n",
    "    seqeval = load_metric(\"seqeval\")\n",
    "    results = seqeval.compute(references=actual_labels, predictions=predicted_labels)\n",
    "    print(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
