{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "82b36dff-b312-4e44-9281-1f1a4450c9bb",
   "metadata": {},
   "source": [
    "# Covid Tweet Legitimacy Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f07f762-cd3f-4005-b234-455cefeef18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the model we are going to fine-tune: Covid-Twitter-BERT v2\n",
    "# check text classification models here: https://huggingface.co/models?filter=text-classification\n",
    "# https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2\n",
    "model_name = \"digitalepidemiologylab/covid-twitter-bert-v2\"\n",
    "tokenizer_name = \"digitalepidemiologylab/covid-twitter-bert-v2\"\n",
    "max_length = 96"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec301661-26ff-45db-81d0-2d481772d85d",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "81cd42d2-a4fd-48d3-831a-88f3747424db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import regex as re\n",
    "import string\n",
    "\n",
    "import torch\n",
    "from transformers.file_utils import is_tf_available, is_torch_available, is_torch_tpu_available\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AdamW,\n",
    "    Trainer,\n",
    "    TrainingArguments\n",
    ")\n",
    "\n",
    "import random\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import WordNetLemmatizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1ff6c8-4104-4921-8b0f-7833c9131d3a",
   "metadata": {},
   "source": [
    "### Set seed for consistency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a0455553-4aba-43ea-9087-b14747aa32e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed: int):\n",
    "    \"\"\"\n",
    "    Seed every random number generator this notebook relies on so runs are repeatable.\n",
    "\n",
    "    ``random`` and ``numpy`` are always seeded; ``torch`` and ``tf`` are seeded only when\n",
    "    ``transformers`` reports them as available.\n",
    "\n",
    "    Args:\n",
    "        seed (:obj:`int`): The seed to set.\n",
    "    \"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed):\n",
    "        seed_fn(seed)\n",
    "\n",
    "    if is_torch_available():\n",
    "        # the CUDA variant is safe to call even when cuda is not available\n",
    "        for torch_seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):\n",
    "            torch_seed_fn(seed)\n",
    "\n",
    "    if is_tf_available():\n",
    "        # imported lazily: TensorFlow is only needed when it is installed\n",
    "        import tensorflow as tf\n",
    "\n",
    "        tf.random.set_seed(seed)\n",
    "\n",
    "set_seed(5)     # For 'first-augmented-miscov19-covid-twitter-bert-v2' model\n",
    "# set_seed(314)   # For 'second-augmented-miscov19-covid-twitter-bert-v2' model\n",
    "# set_seed(2718)  # For 'third-augmented-miscov19-covid-twitter-bert-v2' model\n",
    "# set_seed(4669)  # For 'fourth-augmented-miscov19-covid-twitter-bert-v2' model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6723eac7-a78f-4e0e-8207-72eabbcc7c19",
   "metadata": {},
   "source": [
    "### Load data and create the datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "91b20d0d-dbd4-4b69-8d25-5940b5ce5005",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_12006/3992126773.py:4: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df['text'] = df['text'].astype(str)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>6572</th>\n",
       "      <td>nan</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6573</th>\n",
       "      <td>Why is everyone pretending like theyre immune ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6574</th>\n",
       "      <td>CDC said if you snort enough cocaine it would ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6575</th>\n",
       "      <td>A message hailing the powers of 'boiled garlic...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6576</th>\n",
       "      <td>5G Networks? Bleach? Ultraviolet light? The co...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   text  label\n",
       "6572                                                nan      1\n",
       "6573  Why is everyone pretending like theyre immune ...      0\n",
       "6574  CDC said if you snort enough cocaine it would ...      2\n",
       "6575  A message hailing the powers of 'boiled garlic...      0\n",
       "6576  5G Networks? Bleach? Ultraviolet light? The co...      0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_data = pd.read_csv('miscov19_augmented_final_hydrated.csv')\n",
    "# dropna() returns a new frame, so its result has to be kept; otherwise rows with a missing\n",
    "# tweet survive and astype(str) below turns them into the literal text 'nan'.\n",
    "# copy() makes df independent of raw_data, which avoids SettingWithCopyWarning on assignment.\n",
    "df = raw_data[['text', 'label']].dropna().copy()\n",
    "df['text'] = df['text'].astype(str)\n",
    "df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f7b1f5c-e8e8-4344-8197-8d9ffa438725",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_names = ['legitimate','misinformation','irrelevant']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "36027e63-4772-40ed-aa5d-0b65682b1cb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop = stopwords.words('english')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a9403363-b5fb-4950-9fb2-abdf10babe38",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_12006/3244438452.py:34: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df['text'] = df['text'].apply(clean_text)\n"
     ]
    }
   ],
   "source": [
    "def clean_text(row):\n",
    "    \"\"\"\n",
    "    Normalize one tweet before it is tokenized.\n",
    "\n",
    "    Steps, in order: lower-case, strip URLs and @mentions, drop non-ASCII characters,\n",
    "    drop punctuation, remove English stop words, collapse whitespace, lemmatize each token.\n",
    "\n",
    "    Args:\n",
    "        row (str): Raw tweet text.\n",
    "\n",
    "    Returns:\n",
    "        str: The cleaned tokens joined by single spaces (empty string when nothing is left).\n",
    "    \"\"\"\n",
    "    # Lower case\n",
    "    row = row.lower()\n",
    "\n",
    "    # Remove URLs (raw string so \\S reaches the regex engine; the dot after www is escaped)\n",
    "    row = re.sub(r'http\\S+|www\\.\\S+', '', row)\n",
    "\n",
    "    # Remove @mentions (handles may contain underscores)\n",
    "    row = re.sub(r'@[A-Za-z0-9_]+', '', row)\n",
    "\n",
    "    # Remove non-standard characters\n",
    "    row = row.encode(\"ascii\", \"ignore\").decode()\n",
    "\n",
    "    # Remove punctuation\n",
    "    punct_table = str.maketrans('', '', string.punctuation)\n",
    "    row = row.translate(punct_table)\n",
    "\n",
    "    # Remove stop words. str.replace() only substitutes literal text, so the pattern must go\n",
    "    # through re.sub(). The text has already lost its punctuation, so the stop words are\n",
    "    # stripped the same way before being matched as whole words.\n",
    "    # NOTE(review): the NLTK list includes negations such as 'not' and 'no'; confirm that\n",
    "    # removing them is intended for this classification task.\n",
    "    stop_words = sorted({w.translate(punct_table) for w in stop} - {''})\n",
    "    if stop_words:\n",
    "        pat = r'\\b(?:{})\\b'.format('|'.join(re.escape(w) for w in stop_words))\n",
    "        row = re.sub(pat, '', row)\n",
    "\n",
    "    # Remove extraneous whitespace\n",
    "    row = re.sub(r'\\s+', ' ', row).strip()\n",
    "\n",
    "    # Lemmatization\n",
    "    wordnet_lemmatizer = WordNetLemmatizer()\n",
    "    lemmas = [wordnet_lemmatizer.lemmatize(w) for w in nltk.word_tokenize(row)]\n",
    "\n",
    "    # join() avoids the stray leading space produced by repeated concatenation\n",
    "    return \" \".join(lemmas)\n",
    "\n",
    "df['text'] = df['text'].apply(clean_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b1bca589-1c11-4e24-afc1-1197711b7485",
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = df['text'].tolist()\n",
    "labels = df['label'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9c2af05f-2a95-4b4a-a113-77b98036547f",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 0.2 # Percentage of dataset used for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0996c948-7963-4d56-afde-1b9eebf25e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_texts, valid_texts, train_labels, valid_labels = train_test_split(documents, labels, test_size=test_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "87073b2c-75e0-4d8f-a283-f437d2047e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use only if training notebook from scratch\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ef01c177-0032-444e-b2b2-015c00723244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize the dataset, truncating sequences longer than `max_length`\n",
    "# and padding shorter ones to the longest sequence in each split (`padding=True`)\n",
    "train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)\n",
    "valid_encodings = tokenizer(valid_texts, truncation=True, padding=True, max_length=max_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "42f205a3-aa10-443c-9aab-cfecda1fa281",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiscovDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Torch dataset that pairs tokenizer encodings with their class labels.\"\"\"\n",
    "\n",
    "    def __init__(self, encodings, labels):\n",
    "        # encodings: mapping of feature name -> per-example values, indexed by example\n",
    "        self.encodings = encodings\n",
    "        # labels: one label per example, in the same order as the encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {}\n",
    "        for name, values in self.encodings.items():\n",
    "            item[name] = torch.tensor(values[idx])\n",
    "        # the label is wrapped in a list, so it comes out as a one-element tensor\n",
    "        item[\"labels\"] = torch.tensor([self.labels[idx]])\n",
    "        return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d80f9181-2978-4ae5-abb3-7b7b0980bc2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert our tokenized data into a torch Dataset\n",
    "train_dataset = MiscovDataset(train_encodings, train_labels)\n",
    "valid_dataset = MiscovDataset(valid_encodings, valid_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ebc20513-8047-4f63-bc96-a9246ce951f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at digitalepidemiologylab/covid-twitter-bert-v2 were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at digitalepidemiologylab/covid-twitter-bert-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# fall back to the CPU when no GPU is visible so the notebook still runs end to end\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(target_names)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1bb4bc6a-ac91-4f08-9f40-04e6df29e359",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    \"\"\"\n",
    "    Metric callback for the Trainer: accuracy of the arg-max class against the gold labels.\n",
    "\n",
    "    Args:\n",
    "        pred: Object exposing ``predictions`` (per-class scores) and ``label_ids``.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``{'accuracy': <score from sklearn's accuracy_score>}``.\n",
    "    \"\"\"\n",
    "    # the highest-scoring class along the last axis is the predicted label\n",
    "    predicted_classes = pred.predictions.argmax(-1)\n",
    "    return {'accuracy': accuracy_score(pred.label_ids, predicted_classes)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ebfbc6d5-035d-4568-9d10-88c190620d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for fine-tuning; checkpoints are written under ./results and logs under ./logs\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',          # output directory for checkpoints\n",
    "    num_train_epochs=3,              # total number of training epochs\n",
    "    per_device_train_batch_size=4,  # batch size per device during training\n",
    "    per_device_eval_batch_size=20,   # batch size per device during evaluation\n",
    "    warmup_steps=500,                # number of warmup steps for the learning rate scheduler\n",
    "    weight_decay=0.01,               # strength of weight decay\n",
    "    learning_rate = 1e-5,            # learning rate\n",
    "    logging_dir='./logs',            # directory for storing logs\n",
    "    load_best_model_at_end=True,     # reload the best checkpoint once training has finished\n",
    "    # the best checkpoint is picked by `metric_for_best_model` (accuracy here) instead of the loss\n",
    "    metric_for_best_model='accuracy',\n",
    "    logging_steps=400,               # log every 400 steps\n",
    "    save_steps=400,\n",
    "    evaluation_strategy=\"steps\",     # evaluate each `logging_steps`\n",
    ")\n",
    "\n",
    "# Trainer ties together the model, the arguments above, both datasets and the metric callback\n",
    "trainer = Trainer(\n",
    "    model=model,                         # the instantiated Transformers model to be trained\n",
    "    args=training_args,                  # training arguments, defined above\n",
    "    train_dataset=train_dataset,         # training dataset\n",
    "    eval_dataset=valid_dataset,          # evaluation dataset\n",
    "    compute_metrics=compute_metrics,     # the callback that computes metrics of interest\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "81ac665f-33a6-4feb-aa7f-59a4a28e0ae2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 5261\n",
      "  Num Epochs = 3\n",
      "  Instantaneous batch size per device = 4\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 4\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 3948\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3948' max='3948' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3948/3948 26:17, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>1.065000</td>\n",
       "      <td>0.968088</td>\n",
       "      <td>0.531915</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.842900</td>\n",
       "      <td>0.701968</td>\n",
       "      <td>0.718845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.693700</td>\n",
       "      <td>0.620640</td>\n",
       "      <td>0.771277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.621300</td>\n",
       "      <td>0.669740</td>\n",
       "      <td>0.775076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.576000</td>\n",
       "      <td>0.781552</td>\n",
       "      <td>0.778116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.661700</td>\n",
       "      <td>0.674528</td>\n",
       "      <td>0.784195</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.503100</td>\n",
       "      <td>0.731292</td>\n",
       "      <td>0.781155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.382100</td>\n",
       "      <td>0.845506</td>\n",
       "      <td>0.787234</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.498600</td>\n",
       "      <td>0.844073</td>\n",
       "      <td>0.775836</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-400\n",
      "Configuration saved in ./results/checkpoint-400/config.json\n",
      "Model weights saved in ./results/checkpoint-400/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-800\n",
      "Configuration saved in ./results/checkpoint-800/config.json\n",
      "Model weights saved in ./results/checkpoint-800/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-1200\n",
      "Configuration saved in ./results/checkpoint-1200/config.json\n",
      "Model weights saved in ./results/checkpoint-1200/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-1600\n",
      "Configuration saved in ./results/checkpoint-1600/config.json\n",
      "Model weights saved in ./results/checkpoint-1600/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-2000\n",
      "Configuration saved in ./results/checkpoint-2000/config.json\n",
      "Model weights saved in ./results/checkpoint-2000/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-2400\n",
      "Configuration saved in ./results/checkpoint-2400/config.json\n",
      "Model weights saved in ./results/checkpoint-2400/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-2800\n",
      "Configuration saved in ./results/checkpoint-2800/config.json\n",
      "Model weights saved in ./results/checkpoint-2800/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-3200\n",
      "Configuration saved in ./results/checkpoint-3200/config.json\n",
      "Model weights saved in ./results/checkpoint-3200/pytorch_model.bin\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n",
      "Saving model checkpoint to ./results/checkpoint-3600\n",
      "Configuration saved in ./results/checkpoint-3600/config.json\n",
      "Model weights saved in ./results/checkpoint-3600/pytorch_model.bin\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n",
      "Loading best model from ./results/checkpoint-3200 (score: 0.7872340425531915).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=3948, training_loss=0.6327137739341914, metrics={'train_runtime': 1577.7288, 'train_samples_per_second': 10.004, 'train_steps_per_second': 2.502, 'total_flos': 2269509902321058.0, 'train_loss': 0.6327137739341914, 'epoch': 3.0})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "434e5c7d-83d2-426d-af6e-9bfe773ced30",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 1316\n",
      "  Batch size = 20\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='66' max='66' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [66/66 00:21]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.8455063104629517,\n",
       " 'eval_accuracy': 0.7872340425531915,\n",
       " 'eval_runtime': 21.7769,\n",
       " 'eval_samples_per_second': 60.431,\n",
       " 'eval_steps_per_second': 3.031,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate the current model after training\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5388717b-f491-45b4-a469-5f9438472bb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Configuration saved in sixth-augmented-miscov19-covid-twitter-bert-v2/config.json\n",
      "Model weights saved in sixth-augmented-miscov19-covid-twitter-bert-v2/pytorch_model.bin\n",
      "tokenizer config file saved in sixth-augmented-miscov19-covid-twitter-bert-v2/tokenizer_config.json\n",
      "Special tokens file saved in sixth-augmented-miscov19-covid-twitter-bert-v2/special_tokens_map.json\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('sixth-augmented-miscov19-covid-twitter-bert-v2/tokenizer_config.json',\n",
       " 'sixth-augmented-miscov19-covid-twitter-bert-v2/special_tokens_map.json',\n",
       " 'sixth-augmented-miscov19-covid-twitter-bert-v2/vocab.txt',\n",
       " 'sixth-augmented-miscov19-covid-twitter-bert-v2/added_tokens.json',\n",
       " 'sixth-augmented-miscov19-covid-twitter-bert-v2/tokenizer.json')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# saving the fine tuned model & tokenizer\n",
    "model_path = \"first-augmented-miscov19-covid-twitter-bert-v2\"\n",
    "model.save_pretrained(model_path)\n",
    "tokenizer.save_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "78fbd67c-e876-4339-8461-e8f6268bfdc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prediction(text):\n",
    "    \"\"\"\n",
    "    Classify a single piece of text with the fine-tuned model.\n",
    "\n",
    "    Args:\n",
    "        text (str): Text to classify.\n",
    "\n",
    "    Returns:\n",
    "        str: The entry of ``target_names`` with the highest predicted probability.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the training texts went through clean_text() but inference uses the raw\n",
    "    # text; confirm this mismatch is intended before re-enabling clean_text(text) here.\n",
    "    inputs = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors=\"pt\")\n",
    "    # send the inputs to whichever device the model lives on instead of assuming a GPU\n",
    "    inputs = inputs.to(model.device)\n",
    "    # inference only: switch off dropout and skip gradient tracking\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    # turn the logits into class probabilities\n",
    "    probs = outputs[0].softmax(1)\n",
    "    # .item() converts the arg-max tensor into a plain int usable as a list index\n",
    "    return target_names[probs.argmax().item()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "445660a1-e733-451e-85c2-c8502cbff1c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "irrelevant\n"
     ]
    }
   ],
   "source": [
    "# Example #1\n",
    "text = \"DP Dough is the best restaurant in New York\"\n",
    "print(get_prediction(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f2e1066c-8b17-4f42-b433-ddb7973bd307",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "misinformation\n"
     ]
    }
   ],
   "source": [
    "# Example #2\n",
    "text2 = \"Vaccines cause autism\"\n",
    "print(get_prediction(text2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e841cfcd-2b3f-4961-925e-a357b7bd789d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "legitimate\n"
     ]
    }
   ],
   "source": [
    "# Example #3\n",
    "text3 = \"Vaccinations prevent over 90% of Covid infections! #Science\"\n",
    "print(get_prediction(text3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "0be6a33c-6eb9-4fa9-aa7f-fe951d5fecf2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "legitimate\n"
     ]
    }
   ],
   "source": [
    "# Example #4\n",
    "text4 = \"Vaccines will end the pandemic\"\n",
    "print(get_prediction(text4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "548e8f26-eaaf-40e1-9a37-a85e9559dfea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "misinformation\n"
     ]
    }
   ],
   "source": [
    "# Example #5\n",
    "text5 = \"scientists say bleach will prevent covid\"\n",
    "print(get_prediction(text5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "47d5c740-13dd-44d4-bb1f-8e02e10c6327",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "misinformation\n"
     ]
    }
   ],
   "source": [
    "# Example #6\n",
    "text6 = \"Biden says vaccines prevent over 90% of Covid infections!\"\n",
    "print(get_prediction(text6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "90e2c91a-42e9-466a-9217-943ee1b0f5cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "misinformation\n"
     ]
    }
   ],
   "source": [
    "# Example #7\n",
    "text7 = \"Biden says vaccines cause autism!\"\n",
    "print(get_prediction(text7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0a405ca8-f932-48fa-a048-4dab794e1b32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "legitimate\n"
     ]
    }
   ],
   "source": [
    "# Example #8\n",
    "text8 = \"In Portugal, with 89% of the total population fully vaccinated, almost 90% of UCI Covid patients are unvaccinated\"\n",
    "print(get_prediction(text8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "31626f6f-954b-4eb8-b0af-70bf373a9acf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "irrelevant\n"
     ]
    }
   ],
   "source": [
    "# Example #9\n",
    "text9 = \"President Trump has covid\"\n",
    "print(get_prediction(text9))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "78c4a372-4ea9-4ea6-bb5f-ad57b2161d62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "legitimate\n"
     ]
    }
   ],
   "source": [
    "# Example #10\n",
    "text10 = \"Vaccines don't stop you from getting covid.\"\n",
    "print(get_prediction(text10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "3072412f-f4fc-4756-b67d-cd5bc4e92680",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "misinformation\n"
     ]
    }
   ],
   "source": [
    "# Example #11\n",
    "text11 = \"Vaccinations stop you from getting covid.\"\n",
    "print(get_prediction(text11))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353bc78e-29a6-4f1e-917f-a59bb7a31982",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca1b18d0-83d2-4ad6-9288-9292cdd61e36",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": ".venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
